remote_billing: Add redirects to login for unauthed user in legacy flow.

Analogical to the more complex mechanism implemented for the RemoteRealm
flow in a previous commit in
authenticated_remote_realm_management_endpoint.

As explained in the code comment, this is much easier because:

In this flow, we can only redirect to our local "legacy server flow
login" page. That means that we can do it universally whether the user
has an expired
identity_dict, or just lacks any form of authentication info at all -
there are no security concerns since this is just a local redirect.
This commit is contained in:
Mateusz Mandera
2023-12-03 02:58:02 +01:00
committed by Tim Abbott
parent 44ac99b8fc
commit 134e3bfa68
3 changed files with 54 additions and 24 deletions

View File

@@ -1,10 +1,11 @@
from functools import wraps from functools import wraps
from typing import Callable from typing import Callable, Optional
from urllib.parse import urlencode, urljoin from urllib.parse import urlencode, urljoin
from django.conf import settings from django.conf import settings
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.shortcuts import render from django.shortcuts import render
from django.urls import reverse
from typing_extensions import Concatenate, ParamSpec from typing_extensions import Concatenate, ParamSpec
from corporate.lib.remote_billing_util import ( from corporate.lib.remote_billing_util import (
@@ -13,6 +14,7 @@ from corporate.lib.remote_billing_util import (
get_remote_server_from_session, get_remote_server_from_session,
) )
from corporate.lib.stripe import RemoteRealmBillingSession, RemoteServerBillingSession from corporate.lib.stripe import RemoteRealmBillingSession, RemoteServerBillingSession
from zerver.lib.exceptions import RemoteBillingAuthenticationError
from zerver.lib.subdomains import get_subdomain from zerver.lib.subdomains import get_subdomain
from zerver.lib.url_encoding import append_url_query_string from zerver.lib.url_encoding import append_url_query_string
from zilencer.models import RemoteRealm from zilencer.models import RemoteRealm
@@ -95,22 +97,8 @@ def authenticated_remote_realm_management_endpoint(
# these redirects to work there for testing. # these redirects to work there for testing.
url = urljoin(uri_scheme + remote_realm.host, "/self-hosted-billing/") url = urljoin(uri_scheme + remote_realm.host, "/self-hosted-billing/")
# Our endpoint URLs in this subsystem end with something like page_type = get_next_page_param_from_request_path(request)
# /sponsorship or /plans etc. if page_type is not None:
# Therefore we can use this nice property to figure out easily what
# kind of page the user is trying to access and find the right value
# for the `next` query parameter.
path = request.path
if path.endswith("/"): # nocoverage
path = path[:-1]
page_type = path.split("/")[-1]
from corporate.views.remote_billing_page import (
VALID_NEXT_PAGES as REMOTE_BILLING_VALID_NEXT_PAGES,
)
if page_type in REMOTE_BILLING_VALID_NEXT_PAGES:
query = urlencode({"next_page": page_type}) query = urlencode({"next_page": page_type})
url = append_url_query_string(url, query) url = append_url_query_string(url, query)
@@ -122,6 +110,31 @@ def authenticated_remote_realm_management_endpoint(
return _wrapped_view_func return _wrapped_view_func
def get_next_page_param_from_request_path(request: HttpRequest) -> Optional[str]: # nocoverage
# Our endpoint URLs in this subsystem end with something like
# /sponsorship or /plans etc.
# Therefore we can use this nice property to figure out easily what
# kind of page the user is trying to access and find the right value
# for the `next` query parameter.
path = request.path
if path.endswith("/"):
path = path[:-1]
page_type = path.split("/")[-1]
from corporate.views.remote_billing_page import (
VALID_NEXT_PAGES as REMOTE_BILLING_VALID_NEXT_PAGES,
)
if page_type in REMOTE_BILLING_VALID_NEXT_PAGES:
return page_type
# Should be impossible to reach here. If this is reached, it must mean
# we have a registered endpoint that doesn't have a VALID_NEXT_PAGES entry
# or the parsing logic above is failing.
raise AssertionError(f"Unknown page type: {page_type}")
def authenticated_remote_server_management_endpoint( def authenticated_remote_server_management_endpoint(
view_func: Callable[Concatenate[HttpRequest, RemoteServerBillingSession, ParamT], HttpResponse] view_func: Callable[Concatenate[HttpRequest, RemoteServerBillingSession, ParamT], HttpResponse]
) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]: # nocoverage ) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]: # nocoverage
@@ -139,7 +152,21 @@ def authenticated_remote_server_management_endpoint(
if not isinstance(server_uuid, str): if not isinstance(server_uuid, str):
raise TypeError("server_uuid must be a string") raise TypeError("server_uuid must be a string")
remote_server = get_remote_server_from_session(request, server_uuid=server_uuid) try:
remote_server = get_remote_server_from_session(request, server_uuid=server_uuid)
except (RemoteBillingIdentityExpiredError, RemoteBillingAuthenticationError):
# In this flow, we can only redirect to our local "legacy server flow login" page.
# That means that we can do it universally whether the user has an expired
# identity_dict, or just lacks any form of authentication info at all - there
# are no security concerns since this is just a local redirect.
url = reverse("remote_billing_legacy_server_login")
page_type = get_next_page_param_from_request_path(request)
if page_type is not None:
query = urlencode({"next_page": page_type})
url = append_url_query_string(url, query)
return HttpResponseRedirect(url)
billing_session = RemoteServerBillingSession(remote_server) billing_session = RemoteServerBillingSession(remote_server)
return view_func(request, billing_session) return view_func(request, billing_session)

View File

@@ -137,7 +137,7 @@ def get_remote_server_from_session(
) )
if identity_dict is None: if identity_dict is None:
raise JsonableError(_("User not authenticated")) raise RemoteBillingAuthenticationError
remote_server_uuid = identity_dict["remote_server_uuid"] remote_server_uuid = identity_dict["remote_server_uuid"]
try: try:

View File

@@ -14,6 +14,7 @@ from pydantic import Json
from corporate.lib.decorator import ( from corporate.lib.decorator import (
authenticated_remote_realm_management_endpoint, authenticated_remote_realm_management_endpoint,
authenticated_remote_server_management_endpoint,
self_hosting_management_endpoint, self_hosting_management_endpoint,
) )
from corporate.lib.remote_billing_util import ( from corporate.lib.remote_billing_util import (
@@ -23,12 +24,12 @@ from corporate.lib.remote_billing_util import (
RemoteBillingUserDict, RemoteBillingUserDict,
get_identity_dict_from_session, get_identity_dict_from_session,
) )
from corporate.lib.stripe import RemoteRealmBillingSession from corporate.lib.stripe import RemoteRealmBillingSession, RemoteServerBillingSession
from zerver.lib.exceptions import JsonableError, MissingRemoteRealmError from zerver.lib.exceptions import JsonableError, MissingRemoteRealmError
from zerver.lib.remote_server import RealmDataForAnalytics, UserDataForRemoteBilling from zerver.lib.remote_server import RealmDataForAnalytics, UserDataForRemoteBilling
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.typed_endpoint import PathOnly, typed_endpoint from zerver.lib.typed_endpoint import typed_endpoint
from zilencer.models import RemoteRealm, RemoteZulipServer, get_remote_server_by_uuid from zilencer.models import RemoteRealm, RemoteZulipServer, get_remote_server_by_uuid
billing_logger = logging.getLogger("corporate.stripe") billing_logger = logging.getLogger("corporate.stripe")
@@ -208,9 +209,11 @@ def remote_realm_plans_page(
return remote_billing_plans_common(request, realm_uuid=realm_uuid, server_uuid=None) return remote_billing_plans_common(request, realm_uuid=realm_uuid, server_uuid=None)
@self_hosting_management_endpoint @authenticated_remote_server_management_endpoint
@typed_endpoint def remote_server_plans_page(
def remote_server_plans_page(request: HttpRequest, *, server_uuid: PathOnly[str]) -> HttpResponse: request: HttpRequest, billing_session: RemoteServerBillingSession
) -> HttpResponse:
server_uuid = str(billing_session.remote_server.uuid)
return remote_billing_plans_common(request, server_uuid=server_uuid, realm_uuid=None) return remote_billing_plans_common(request, server_uuid=server_uuid, realm_uuid=None)