decorator: Add wrapper to directly pass remote_realm to view_func.

This commit is contained in:
Aman Agrawal
2023-11-26 17:34:33 +00:00
committed by Tim Abbott
parent 354330d81b
commit ede73fc2c6
2 changed files with 41 additions and 7 deletions

View File

@@ -1,16 +1,23 @@
from functools import wraps
from typing import Callable, TypeVar
from typing import Callable
from django.conf import settings
from django.http import HttpRequest, HttpResponse
from django.shortcuts import render
from typing_extensions import Concatenate, ParamSpec
from corporate.lib.remote_billing_util import get_remote_realm_from_session
from zerver.lib.subdomains import get_subdomain
from zilencer.models import RemoteRealm
ParamT = ParamSpec("ParamT")
def is_self_hosting_management_subdomain(request: HttpRequest) -> bool: # nocoverage
subdomain = get_subdomain(request)
return settings.DEVELOPMENT and subdomain == settings.SELF_HOSTING_MANAGEMENT_SUBDOMAIN
def self_hosting_management_endpoint(
view_func: Callable[Concatenate[HttpRequest, ParamT], HttpResponse]
) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]: # nocoverage
@@ -18,9 +25,36 @@ def self_hosting_management_endpoint(
def _wrapped_view_func(
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
) -> HttpResponse:
subdomain = get_subdomain(request)
if not settings.DEVELOPMENT or subdomain != settings.SELF_HOSTING_MANAGEMENT_SUBDOMAIN:
if not is_self_hosting_management_subdomain(request):
return render(request, "404.html", status=404)
return view_func(request, *args, **kwargs)
return _wrapped_view_func
def authenticated_remote_realm_management_endpoint(
view_func: Callable[Concatenate[HttpRequest, RemoteRealm, ParamT], HttpResponse]
) -> Callable[Concatenate[HttpRequest, RemoteRealm, ParamT], HttpResponse]: # nocoverage
@wraps(view_func)
def _wrapped_view_func(
request: HttpRequest,
remote_realm: RemoteRealm,
/,
*args: ParamT.args,
**kwargs: ParamT.kwargs,
) -> HttpResponse:
if not is_self_hosting_management_subdomain(request):
return render(request, "404.html", status=404)
realm_uuid = kwargs.get("realm_uuid")
server_uuid = kwargs.get("server_uuid")
if realm_uuid is not None and not isinstance(realm_uuid, str):
raise TypeError("realm_uuid must be a string or None")
if server_uuid is not None and not isinstance(server_uuid, str):
raise TypeError("server_uuid must be a string or None")
remote_realm = get_remote_realm_from_session(
request, realm_uuid=realm_uuid, server_uuid=server_uuid
)
return view_func(request, remote_realm, *args, **kwargs)
return _wrapped_view_func

View File

@@ -8,8 +8,7 @@ from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.shortcuts import render
from pydantic import Json
from corporate.lib.decorator import self_hosting_management_endpoint
from corporate.lib.remote_billing_util import get_remote_realm_from_session
from corporate.lib.decorator import authenticated_remote_realm_management_endpoint
from corporate.lib.stripe import (
VALID_BILLING_MODALITY_VALUES,
VALID_BILLING_SCHEDULE_VALUES,
@@ -30,6 +29,7 @@ from zerver.lib.send_email import FromAddress, send_email
from zerver.lib.typed_endpoint import PathOnly, typed_endpoint
from zerver.lib.validator import check_bool, check_int, check_string_in
from zerver.models import UserProfile, get_org_type_display_name
from zilencer.models import RemoteRealm
billing_logger = logging.getLogger("corporate.stripe")
@@ -107,15 +107,15 @@ def upgrade_page(
return response
@self_hosting_management_endpoint
@authenticated_remote_realm_management_endpoint
@typed_endpoint
def remote_realm_upgrade_page(
request: HttpRequest,
remote_realm: RemoteRealm,
*,
realm_uuid: PathOnly[str],
manual_license_management: Json[bool] = False,
) -> HttpResponse: # nocoverage
remote_realm = get_remote_realm_from_session(request, realm_uuid)
initial_upgrade_request = InitialUpgradeRequest(
manual_license_management=manual_license_management,
tier=CustomerPlan.STANDARD,