mirror of
https://github.com/zulip/zulip.git
synced 2025-10-23 04:52:12 +00:00
235 lines
9.7 KiB
Python
235 lines
9.7 KiB
Python
import inspect
|
|
from collections.abc import Callable
|
|
from functools import wraps
|
|
from typing import TYPE_CHECKING, Concatenate
|
|
from urllib.parse import urlencode, urljoin
|
|
|
|
from django.conf import settings
|
|
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect, JsonResponse
|
|
from django.shortcuts import render
|
|
from django.urls import reverse
|
|
from typing_extensions import ParamSpec
|
|
|
|
from corporate.lib.remote_billing_util import (
|
|
RemoteBillingIdentityExpiredError,
|
|
get_remote_realm_and_user_from_session,
|
|
get_remote_server_and_user_from_session,
|
|
)
|
|
from zerver.lib.exceptions import RemoteBillingAuthenticationError
|
|
from zerver.lib.subdomains import get_subdomain
|
|
from zerver.lib.url_encoding import append_url_query_string
|
|
from zilencer.models import RemoteRealm
|
|
|
|
if TYPE_CHECKING:
|
|
from corporate.lib.stripe import RemoteRealmBillingSession, RemoteServerBillingSession
|
|
|
|
# ParamSpec standing for the wrapped view's parameters that follow the ones
# the decorators in this module name explicitly in their Concatenate[...] types.
ParamT = ParamSpec("ParamT")
|
|
|
|
|
|
def session_expired_ajax_response(login_url: str) -> JsonResponse:  # nocoverage
    """Build the JSON response used when remote billing authentication has expired.

    The response has status 401 and a body with a fixed "error_message"
    plus the given ``login_url`` under the "login_url" key.
    """
    payload = {
        "error_message": "Remote billing authentication expired",
        "login_url": login_url,
    }
    return JsonResponse(payload, status=401)
|
|
|
|
|
|
def is_self_hosting_management_subdomain(request: HttpRequest) -> bool:
    """Return whether the request's subdomain equals settings.SELF_HOSTING_MANAGEMENT_SUBDOMAIN."""
    return get_subdomain(request) == settings.SELF_HOSTING_MANAGEMENT_SUBDOMAIN
|
|
|
|
|
|
def self_hosting_management_endpoint(
    view_func: Callable[Concatenate[HttpRequest, ParamT], HttpResponse],
) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]:
    """Decorator restricting a view to the self-hosting management subdomain.

    Requests on that subdomain are passed through to ``view_func`` unchanged.
    Any other request gets the "404.html" template rendered with status 404,
    and ``view_func`` is not called.
    """

    @wraps(view_func)
    def _wrapped_view_func(
        request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
    ) -> HttpResponse:
        if is_self_hosting_management_subdomain(request):
            return view_func(request, *args, **kwargs)
        return render(request, "404.html", status=404)  # nocoverage

    return _wrapped_view_func
|
|
|
|
|
|
def authenticated_remote_realm_management_endpoint(
    view_func: Callable[
        Concatenate[HttpRequest, "RemoteRealmBillingSession", ParamT], HttpResponse
    ],
) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]:
    """Decorator for remote realm billing views that need an authenticated session.

    The wrapper removes the ``realm_uuid`` keyword argument, looks up the
    remote realm and billing user from the session, and calls ``view_func``
    with a ``RemoteRealmBillingSession`` inserted right after ``request``.

    Requests outside the self-hosting management subdomain get a 404 page.
    If the session's identity has expired, the response points the user at
    ``/self-hosted-billing/`` on their realm's host: an HTTP redirect when
    the request prefers ``text/html``, otherwise the JSON response built by
    ``session_expired_ajax_response``.

    The returned wrapper's ``__signature__`` and ``__annotations__`` are
    rewritten to omit the billing session parameter of ``view_func``.
    """

    @wraps(view_func)
    def _wrapped_view_func(
        request: HttpRequest,
        /,
        *args: ParamT.args,
        **kwargs: ParamT.kwargs,
    ) -> HttpResponse:
        # Imported at call time; at module level this name is only imported
        # under TYPE_CHECKING.
        from corporate.lib.stripe import RemoteRealmBillingSession

        if not is_self_hosting_management_subdomain(request):  # nocoverage
            return render(request, "404.html", status=404)

        realm_uuid = kwargs.pop("realm_uuid")
        if not (realm_uuid is None or isinstance(realm_uuid, str)):  # nocoverage
            raise TypeError("realm_uuid must be a string or None")

        try:
            remote_realm, remote_billing_user = get_remote_realm_and_user_from_session(
                request, realm_uuid
            )
        except RemoteBillingIdentityExpiredError as expired:
            # The session held an identity_dict, but it has expired. Send the
            # user back to the start of the login flow at
            # {realm.host}/self-hosted-billing/, with a next_page parameter
            # so that they land on the page they were after once they have
            # re-authenticated.
            #
            # Note: the realm_uuid from the request path or params could in
            # theory be used to find remote_realm.host for the redirect, but
            # that would leak the .host value to anyone who knows the uuid.
            # So only the realm_uuid carried by the expired identity is used:
            # it proves the user was, at least previously, successfully
            # authenticated as a billing admin of that realm.
            expired_realm_uuid = expired.realm_uuid
            expired_server_uuid = expired.server_uuid
            login_scheme = expired.uri_scheme
            if expired_realm_uuid is None:
                # Doesn't make sense: an expired identity dict found by
                # get_remote_realm_and_user_from_session should have had a
                # realm_uuid.
                raise AssertionError

            assert expired_server_uuid is not None, (
                "identity_dict with realm_uuid must have server_uuid"
            )
            assert login_scheme is not None, "identity_dict with realm_uuid must have uri_scheme"

            try:
                expired_realm = RemoteRealm.objects.get(
                    uuid=expired_realm_uuid, server__uuid=expired_server_uuid
                )
            except RemoteRealm.DoesNotExist:
                # Should be impossible, unless the RemoteRealm existed and
                # its row was somehow deleted.
                raise AssertionError

            # The URL scheme is the one carried by the expired identity.
            login_url = urljoin(login_scheme + expired_realm.host, "/self-hosted-billing/")

            next_page = get_next_page_param_from_request_path(request)
            if next_page is not None:
                login_url = append_url_query_string(
                    login_url, urlencode({"next_page": next_page})
                )

            # AJAX requests get an error payload carrying the URL instead of
            # a redirect.
            preferred_type = request.get_preferred_type(["application/json", "text/html"])
            if preferred_type != "text/html":  # nocoverage
                return session_expired_ajax_response(login_url)

            return HttpResponseRedirect(login_url)

        billing_session = RemoteRealmBillingSession(
            remote_realm, remote_billing_user=remote_billing_user
        )
        return view_func(request, billing_session, *args, **kwargs)

    # Expose view_func's signature and annotations minus its second
    # (billing session) parameter, which the wrapper supplies itself.
    view_signature = inspect.signature(view_func)
    first_parameter, session_parameter, *remaining_parameters = (
        view_signature.parameters.values()
    )
    _wrapped_view_func.__signature__ = view_signature.replace(  # type: ignore[attr-defined] # too magic
        parameters=[first_parameter, *remaining_parameters]
    )
    exposed_annotations = dict(view_func.__annotations__)
    exposed_annotations.pop(session_parameter.name, None)
    _wrapped_view_func.__annotations__ = exposed_annotations

    return _wrapped_view_func
|
|
|
|
|
|
def get_next_page_param_from_request_path(request: HttpRequest) -> str | None:
    """Return the last segment of the request path if it is a valid next page.

    Endpoint URLs in this subsystem end with something like /sponsorship
    or /plans, so the final path segment (ignoring one trailing slash)
    identifies the kind of page being accessed. That segment is returned
    when it is in VALID_NEXT_PAGES; otherwise None is returned.
    """
    trimmed_path = request.path.removesuffix("/")
    last_segment = trimmed_path.rpartition("/")[2]

    from corporate.views.remote_billing_page import (
        VALID_NEXT_PAGES as REMOTE_BILLING_VALID_NEXT_PAGES,
    )

    if last_segment not in REMOTE_BILLING_VALID_NEXT_PAGES:
        # Not somewhere to send the user after a login, so the caller
        # falls back to the default page.
        return None  # nocoverage

    return last_segment
|
|
|
|
|
|
def authenticated_remote_server_management_endpoint(
    view_func: Callable[
        Concatenate[HttpRequest, "RemoteServerBillingSession", ParamT], HttpResponse
    ],
) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]:
    """Decorator for remote server billing views that need an authenticated session.

    The wrapper removes the ``server_uuid`` keyword argument, looks up the
    remote server and billing user from the session, and calls ``view_func``
    with a ``RemoteServerBillingSession`` inserted right after ``request``.

    Requests outside the self-hosting management subdomain get a 404 page.
    If the identity has expired, authentication fails, or no billing user is
    present, the response points at the "remote_billing_legacy_server_login"
    URL: an HTTP redirect when the request prefers ``text/html``, otherwise
    the JSON response built by ``session_expired_ajax_response``.

    The returned wrapper's ``__signature__`` and ``__annotations__`` are
    rewritten to omit the billing session parameter of ``view_func``.
    """

    @wraps(view_func)
    def _wrapped_view_func(
        request: HttpRequest,
        /,
        *args: ParamT.args,
        **kwargs: ParamT.kwargs,
    ) -> HttpResponse:
        # Imported at call time; at module level this name is only imported
        # under TYPE_CHECKING.
        from corporate.lib.stripe import RemoteServerBillingSession

        if not is_self_hosting_management_subdomain(request):  # nocoverage
            return render(request, "404.html", status=404)

        server_uuid = kwargs.pop("server_uuid")
        if not isinstance(server_uuid, str):
            raise TypeError("server_uuid must be a string")  # nocoverage

        try:
            remote_server, remote_billing_user = get_remote_server_and_user_from_session(
                request, server_uuid=server_uuid
            )
            if remote_billing_user is None:
                # Should only happen when the user has not finished the
                # confirmation flow and so has no fully authenticated session
                # yet. They should not be accessing this endpoint yet.
                raise RemoteBillingAuthenticationError
        except (RemoteBillingIdentityExpiredError, RemoteBillingAuthenticationError):
            # This flow can only redirect to the local "legacy server flow
            # login" page. It is therefore handled the same way whether the
            # identity_dict expired or there is no authentication info at
            # all. There are no security concerns, since this is just a
            # local redirect.
            next_page = get_next_page_param_from_request_path(request)
            login_query = {"next_page": next_page} if next_page is not None else None
            login_url = reverse("remote_billing_legacy_server_login", query=login_query)

            # AJAX requests get an error payload carrying the URL instead of
            # a redirect.
            preferred_type = request.get_preferred_type(["application/json", "text/html"])
            if preferred_type != "text/html":  # nocoverage
                return session_expired_ajax_response(login_url)

            return HttpResponseRedirect(login_url)

        assert remote_billing_user is not None
        billing_session = RemoteServerBillingSession(
            remote_server, remote_billing_user=remote_billing_user
        )
        return view_func(request, billing_session, *args, **kwargs)

    # Expose view_func's signature and annotations minus its second
    # (billing session) parameter, which the wrapper supplies itself.
    view_signature = inspect.signature(view_func)
    first_parameter, session_parameter, *remaining_parameters = (
        view_signature.parameters.values()
    )
    _wrapped_view_func.__signature__ = view_signature.replace(  # type: ignore[attr-defined] # too magic
        parameters=[first_parameter, *remaining_parameters]
    )
    exposed_annotations = dict(view_func.__annotations__)
    exposed_annotations.pop(session_parameter.name, None)
    _wrapped_view_func.__annotations__ = exposed_annotations

    return _wrapped_view_func
|