remote_billing: Add redirect flow for users with expired session.

Implements a nice redirect flow to give a good UX for users attempting
to access a remote billing page with an expired RemoteRealm session e.g.
/realm/some-uuid/sponsorship - perhaps through their browser
history or just their session expired while they were doing things in
this billing system.

The logic has a few pieces:
1. get_remote_realm_from_session, if the user doesn't have a
   identity_dict will raise RemoteBillingAuthenticationError.
2. If the user has an identity_dict, but it's expired, then
   get_identity_dict_from_session inside of get_remote_realm_from_session
   will raise RemoteBillingIdentityExpiredError.
3. The decorator authenticated_remote_realm_management_endpoint
   catches that exception and uses some general logic, described in more
   detail in the comments in the code, to figure out the right URL to
   redirect them to. Something like:
   https://theirserver.example.com/self-hosted-billing/?next_page=...
   where the next_page param is determined based on parsing request.path
   to see what kind of endpoint they're trying to access.
4. The remote_server_billing_entry endpoint is tweaked to also send
   its uri scheme to the bouncer, so that the bouncer can know whether
   to do the redirect on http or https.
This commit is contained in:
Mateusz Mandera
2023-12-02 22:37:54 +01:00
committed by Tim Abbott
parent 4987600edc
commit ec7245d4e1
6 changed files with 194 additions and 12 deletions

View File

@@ -1,17 +1,21 @@
from functools import wraps
from typing import Callable
from urllib.parse import urlencode, urljoin
from django.conf import settings
from django.http import HttpRequest, HttpResponse
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.shortcuts import render
from typing_extensions import Concatenate, ParamSpec
from corporate.lib.remote_billing_util import (
RemoteBillingIdentityExpiredError,
get_remote_realm_from_session,
get_remote_server_from_session,
)
from corporate.lib.stripe import RemoteRealmBillingSession, RemoteServerBillingSession
from zerver.lib.subdomains import get_subdomain
from zerver.lib.url_encoding import append_url_query_string
from zilencer.models import RemoteRealm
ParamT = ParamSpec("ParamT")
@@ -52,7 +56,65 @@ def authenticated_remote_realm_management_endpoint(
if realm_uuid is not None and not isinstance(realm_uuid, str):
raise TypeError("realm_uuid must be a string or None")
try:
remote_realm = get_remote_realm_from_session(request, realm_uuid)
except RemoteBillingIdentityExpiredError as e:
# The user had an authenticated session with an identity_dict,
# but it expired.
# We want to redirect back to the start of their login flow
# at their {realm.host}/self-hosted-billing/ with a proper
# next parameter to take them back to what they're trying
# to access after re-authing.
# Note: Theoretically we could take the realm_uuid from the request
# path or params to figure out the remote_realm.host for the redirect,
# but that would mean leaking that .host value to anyone who knows
# the uuid. Therefore we limit ourselves to taking the realm_uuid
# from the identity_dict - since that proves that the user at least
# previously was successfully authenticated as a billing admin of that
# realm.
realm_uuid = e.realm_uuid
server_uuid = e.server_uuid
uri_scheme = e.uri_scheme
if realm_uuid is None:
# This doesn't make sense - if get_remote_realm_from_session
# found an expired identity dict, it should have had a realm_uuid.
raise AssertionError
assert server_uuid is not None, "identity_dict with realm_uuid must have server_uuid"
assert uri_scheme is not None, "identity_dict with realm_uuid must have uri_scheme"
try:
remote_realm = RemoteRealm.objects.get(uuid=realm_uuid, server__uuid=server_uuid)
except RemoteRealm.DoesNotExist:
# This should be impossible - unless the RemoteRealm existed and somehow the row
# was deleted.
raise AssertionError
# Using EXTERNAL_URI_SCHEME means we'll do https:// in production, which is
# the sane default - while having http:// in development, which will allow
# these redirects to work there for testing.
url = urljoin(uri_scheme + remote_realm.host, "/self-hosted-billing/")
# Our endpoint URLs in this subsystem end with something like
# /sponsorship or /plans etc.
# Therefore we can use this nice property to figure out easily what
# kind of page the user is trying to access and find the right value
# for the `next` query parameter.
path = request.path
if path.endswith("/"): # nocoverage
path = path[:-1]
page_type = path.split("/")[-1]
from corporate.views.remote_billing_page import (
VALID_NEXT_PAGES as REMOTE_BILLING_VALID_NEXT_PAGES,
)
if page_type in REMOTE_BILLING_VALID_NEXT_PAGES:
query = urlencode({"next_page": page_type})
url = append_url_query_string(url, query)
return HttpResponseRedirect(url)
billing_session = RemoteRealmBillingSession(remote_realm)
return view_func(request, billing_session)

View File

@@ -1,11 +1,11 @@
import logging
from typing import Optional, TypedDict, Union, cast
from typing import Literal, Optional, TypedDict, Union, cast
from django.http import HttpRequest
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
from zerver.lib.exceptions import JsonableError
from zerver.lib.exceptions import JsonableError, RemoteBillingAuthenticationError
from zerver.lib.timestamp import datetime_to_timestamp
from zilencer.models import RemoteRealm, RemoteZulipServer
@@ -29,6 +29,7 @@ class RemoteBillingIdentityDict(TypedDict):
remote_realm_uuid: str
authenticated_at: int
uri_scheme: Literal["http://", "https://"]
next_page: Optional[str]
@@ -41,6 +42,19 @@ class LegacyServerIdentityDict(TypedDict):
authenticated_at: int
class RemoteBillingIdentityExpiredError(Exception):
def __init__(
self,
*,
realm_uuid: Optional[str] = None,
server_uuid: Optional[str] = None,
uri_scheme: Optional[Literal["http://", "https://"]] = None,
) -> None:
self.realm_uuid = realm_uuid
self.server_uuid = server_uuid
self.uri_scheme = uri_scheme
def get_identity_dict_from_session(
request: HttpRequest,
*,
@@ -66,7 +80,14 @@ def get_identity_dict_from_session(
datetime_to_timestamp(timezone_now()) - result["authenticated_at"]
> REMOTE_BILLING_SESSION_VALIDITY_SECONDS
):
return None
# In this case we raise, because callers want to catch this as an explicitly
# different scenario from the user not being authenticated, to handle it nicely
# by redirecting them to their login page.
raise RemoteBillingIdentityExpiredError(
realm_uuid=result.get("remote_realm_uuid"),
server_uuid=result.get("remote_server_uuid"),
uri_scheme=result.get("uri_scheme"),
)
return result
@@ -83,7 +104,7 @@ def get_remote_realm_from_session(
)
if identity_dict is None:
raise JsonableError(_("User not authenticated"))
raise RemoteBillingAuthenticationError
remote_server_uuid = identity_dict["remote_server_uuid"]
remote_realm_uuid = identity_dict["remote_realm_uuid"]

View File

@@ -55,6 +55,7 @@ class RemoteBillingAuthenticationTest(BouncerTestCase):
remote_server_uuid=str(self.server.uuid),
remote_realm_uuid=str(user.realm.uuid),
authenticated_at=datetime_to_timestamp(now),
uri_scheme="http://",
next_page=next_page,
)
self.assertEqual(
@@ -154,13 +155,81 @@ class RemoteBillingAuthenticationTest(BouncerTestCase):
self.assert_in_success_response([desdemona.delivery_email], result)
# Now go there again, simulating doing this after the session has expired.
# We should be denied access.
# We should be denied access and redirected to re-auth.
with time_machine.travel(
now + datetime.timedelta(seconds=REMOTE_BILLING_SESSION_VALIDITY_SECONDS + 1),
tick=False,
):
result = self.client_get(final_url, subdomain="selfhosting")
self.assert_json_error(result, "User not authenticated")
self.assertEqual(result.status_code, 302)
self.assertEqual(
result["Location"],
f"http://{desdemona.realm.host}/self-hosted-billing/?next_page=plans",
)
# Opening this re-auth URL in result["Location"] is same as re-doing the auth
# flow via execute_remote_billing_authentication_flow with next_page="plans".
# So let's test that and assert that we end up successfully re-authed on the /plans
# page.
result = self.execute_remote_billing_authentication_flow(desdemona, next_page="plans")
self.assertEqual(result["Location"], f"/realm/{realm.uuid!s}/plans")
result = self.client_get(result["Location"], subdomain="selfhosting")
self.assert_in_success_response(["Your remote user info:"], result)
self.assert_in_success_response([desdemona.delivery_email], result)
@responses.activate
def test_remote_billing_unauthed_access(self) -> None:
now = timezone_now()
self.login("desdemona")
desdemona = self.example_user("desdemona")
realm = desdemona.realm
self.add_mock_response()
send_realms_only_to_push_bouncer()
# Straight-up access without authing at all:
result = self.client_get(f"/realm/{realm.uuid!s}/plans", subdomain="selfhosting")
self.assert_json_error(result, "User not authenticated", 401)
result = self.execute_remote_billing_authentication_flow(desdemona)
self.assertEqual(result["Location"], f"/realm/{realm.uuid!s}/plans")
final_url = result["Location"]
# Sanity check - access is granted after authing:
result = self.client_get(final_url, subdomain="selfhosting")
self.assertEqual(result.status_code, 200)
# Now mess with the identity dict in the session in unlikely ways so that it should
# not grant access.
# First delete the RemoteRealm entry for this session.
RemoteRealm.objects.filter(uuid=realm.uuid).delete()
with self.assertLogs("django.request", "ERROR") as m, self.assertRaises(AssertionError):
self.client_get(final_url, subdomain="selfhosting")
self.assertIn(
"The remote realm is missing despite being in the RemoteBillingIdentityDict",
m.output[0],
)
# Try the case where the identity dict is simultaneously expired.
with time_machine.travel(
now + datetime.timedelta(seconds=REMOTE_BILLING_SESSION_VALIDITY_SECONDS + 30),
tick=False,
):
with self.assertLogs("django.request", "ERROR") as m, self.assertRaises(AssertionError):
self.client_get(final_url, subdomain="selfhosting")
# The django.request log should be a traceback, mentioning the relevant
# exceptions that occurred.
self.assertIn(
"RemoteBillingIdentityExpiredError",
m.output[0],
)
self.assertIn(
"AssertionError",
m.output[0],
)
@responses.activate
def test_remote_billing_authentication_flow_to_sponsorship_page(self) -> None:

View File

@@ -12,7 +12,10 @@ from django.utils.translation import gettext as _
from django.views.decorators.csrf import csrf_exempt
from pydantic import Json
from corporate.lib.decorator import self_hosting_management_endpoint
from corporate.lib.decorator import (
authenticated_remote_realm_management_endpoint,
self_hosting_management_endpoint,
)
from corporate.lib.remote_billing_util import (
REMOTE_BILLING_SESSION_VALIDITY_SECONDS,
LegacyServerIdentityDict,
@@ -20,6 +23,7 @@ from corporate.lib.remote_billing_util import (
RemoteBillingUserDict,
get_identity_dict_from_session,
)
from corporate.lib.stripe import RemoteRealmBillingSession
from zerver.lib.exceptions import JsonableError, MissingRemoteRealmError
from zerver.lib.remote_server import RealmDataForAnalytics, UserDataForRemoteBilling
from zerver.lib.response import json_success
@@ -42,6 +46,7 @@ def remote_server_billing_entry(
*,
user: Json[UserDataForRemoteBilling],
realm: Json[RealmDataForAnalytics],
uri_scheme: Literal["http://", "https://"] = "https://",
next_page: VALID_NEXT_PAGES_TYPE = None,
) -> HttpResponse:
if not settings.DEVELOPMENT:
@@ -61,6 +66,7 @@ def remote_server_billing_entry(
remote_server_uuid=str(remote_server.uuid),
remote_realm_uuid=str(remote_realm.uuid),
authenticated_at=datetime_to_timestamp(timezone_now()),
uri_scheme=uri_scheme,
next_page=next_page,
)
@@ -194,9 +200,11 @@ def remote_billing_plans_common(
return render_tmp_remote_billing_page(request, realm_uuid=realm_uuid, server_uuid=server_uuid)
@self_hosting_management_endpoint
@typed_endpoint
def remote_realm_plans_page(request: HttpRequest, *, realm_uuid: PathOnly[str]) -> HttpResponse:
@authenticated_remote_realm_management_endpoint
def remote_realm_plans_page(
request: HttpRequest, billing_session: RemoteRealmBillingSession
) -> HttpResponse:
realm_uuid = str(billing_session.remote_realm.uuid)
return remote_billing_plans_common(request, realm_uuid=realm_uuid, server_uuid=None)

View File

@@ -49,6 +49,7 @@ class ErrorCode(Enum):
MISSING_REMOTE_REALM = auto()
TOPIC_WILDCARD_MENTION_NOT_ALLOWED = auto()
STREAM_WILDCARD_MENTION_NOT_ALLOWED = auto()
REMOTE_BILLING_UNAUTHENTICATED_USER = auto()
class JsonableError(Exception):
@@ -445,6 +446,22 @@ class MissingAuthenticationError(JsonableError):
# converted into json_unauthorized in Zulip's middleware.
class RemoteBillingAuthenticationError(JsonableError):
# We want this as a distinct class from MissingAuthenticationError,
# as we don't want the json_unauthorized conversion mechanism to apply
# to this.
code = ErrorCode.REMOTE_BILLING_UNAUTHENTICATED_USER
http_status_code = 401
def __init__(self) -> None:
pass
@staticmethod
@override
def msg_format() -> str:
return _("User not authenticated")
class InvalidSubdomainError(JsonableError):
code = ErrorCode.NONEXISTENT_SUBDOMAIN
http_status_code = 404

View File

@@ -148,6 +148,11 @@ def self_hosting_auth_redirect(
post_data = {
"user": user_info.model_dump_json(),
"realm": realm_info.model_dump_json(),
# The uri_scheme is necessary for the bouncer to know the correct URL
# to redirect the user to for re-authing in case the session expires.
# Otherwise, the bouncer would know only the realm.host but be missing
# the knowledge of whether to use http or https.
"uri_scheme": settings.EXTERNAL_URI_SCHEME,
}
if next_page is not None:
post_data["next_page"] = next_page