From 99b97ea883aadeda365d146014b58d3fc6737b0b Mon Sep 17 00:00:00 2001 From: Mateusz Mandera Date: Fri, 11 Jul 2025 16:00:05 +0800 Subject: [PATCH] saml: Don't put group_memberships_sync_map in the session. In 40956ae4c524416e07399ceffa742e2671bed77f we implemented group sync via SAML during sign in and sign up. The sign up implementation used a session variable group_memberships_sync_map to plumb through the sync information to the registration codepath, to execute group sync after user creation. We can use a more robust approach instead, and just amend groups on the `PreregistrationUser` object that's going to be used for registration. --- zerver/tests/test_auth_backends.py | 14 ++++++++---- zerver/views/auth.py | 19 +++++----------- zerver/views/registration.py | 31 ------------------------- zproject/backends.py | 36 ++++++++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 49 deletions(-) diff --git a/zerver/tests/test_auth_backends.py b/zerver/tests/test_auth_backends.py index 8dc561b49c..6ad8beaf48 100644 --- a/zerver/tests/test_auth_backends.py +++ b/zerver/tests/test_auth_backends.py @@ -3656,8 +3656,11 @@ class SAMLAuthBackendTest(SocialAuthBase): self.logger_output("Returning role owner for user creation", type="info"), ) - self.assertIn( - f"INFO:root:Syncing groups post-registration for new user {user_profile.id} in realm {realm.id}", + prereg_user = PreregistrationUser.objects.last() + assert prereg_user is not None + self.assertEqual( + f"INFO:root:Synced user groups for PreregistrationUser {prereg_user.id} in {realm.id}: " + '{"testgroup1": true, "testgroup2": false}. Final groups set: {\'testgroup1\'}', mock_root_logger.output[0], ) @@ -3735,8 +3738,11 @@ class SAMLAuthBackendTest(SocialAuthBase): ) ) - self.assertIn( - f"INFO:root:Syncing groups post-registration for new user {user_profile.id} in realm {realm.id}", + prereg_user = PreregistrationUser.objects.last() + assert prereg_user is not None + self.assertEqual( + f"INFO:root:Synced user groups for PreregistrationUser {prereg_user.id} in {realm.id}: " + '{"testgroup1": true, "testgroup2": false}. Final groups set: {\'testgroup1\'}', mock_root_logger.output[0], ) diff --git a/zerver/views/auth.py b/zerver/views/auth.py index 2d11c1c0a5..deaaf4cc9f 100644 --- a/zerver/views/auth.py +++ b/zerver/views/auth.py @@ -103,6 +103,7 @@ from zproject.backends import ( ldap_auth_enabled, password_auth_enabled, saml_auth_enabled, + sync_groups_for_prereg_user, validate_otp_params, ) @@ -232,18 +233,6 @@ def maybe_send_to_registration( expiry_seconds=EXPIRABLE_SESSION_VAR_DEFAULT_EXPIRY_SECS, ) - if group_memberships_sync_map: - set_expirable_session_var( - request.session, - "registration_group_memberships_sync_map", - orjson.dumps(group_memberships_sync_map).decode(), - expiry_seconds=EXPIRABLE_SESSION_VAR_DEFAULT_EXPIRY_SECS, - ) - elif "registration_group_memberships_sync_map" in request.session: # nocoverage - # Ensure it isn't possible to leak this state across - # registration attempts. - del request.session["registration_group_memberships_sync_map"] - try: # TODO: This should use get_realm_from_request, but a bunch of tests # rely on mocking get_subdomain here, so they'll need to be tweaked first. @@ -327,8 +316,10 @@ def maybe_send_to_registration( if streams_to_subscribe: prereg_user.streams.set(streams_to_subscribe) - if user_groups: - prereg_user.groups.set(user_groups) + if user_groups or group_memberships_sync_map: + prereg_user.groups.set(user_groups or []) + if group_memberships_sync_map: + sync_groups_for_prereg_user(prereg_user, group_memberships_sync_map) if include_realm_default_subscriptions is not None: prereg_user.include_realm_default_subscriptions = include_realm_default_subscriptions diff --git a/zerver/views/registration.py b/zerver/views/registration.py index 56071344da..f804305859 100644 --- a/zerver/views/registration.py +++ b/zerver/views/registration.py @@ -135,7 +135,6 @@ from zproject.backends import ( get_external_method_dicts, ldap_auth_enabled, password_auth_enabled, - sync_groups, ) logger = logging.getLogger("") @@ -843,8 +842,6 @@ def registration_helper( # duplicate email address. Redirect them to the login # form. return redirect_to_email_login_url(email) - else: - sync_groups_post_registration(request=request, user_profile=user_profile) if realm_creation: # Because for realm creation, registration happens on the @@ -917,34 +914,6 @@ def registration_helper( return TemplateResponse(request, "zerver/register.html", context=context) -def sync_groups_post_registration(request: HttpRequest, user_profile: UserProfile) -> None: - group_memberships_sync_data = get_expirable_session_var( - request.session, "registration_group_memberships_sync_map", delete=True - ) - if not group_memberships_sync_data: - return - - group_memberships_sync_map = orjson.loads(group_memberships_sync_data) - if group_memberships_sync_map: - logger.info( - "Syncing groups post-registration for new user %s in realm %s: %s", - user_profile.id, - user_profile.realm_id, - group_memberships_sync_map, - ) - assert isinstance(group_memberships_sync_map, dict) - sync_groups( - all_group_names=set(group_memberships_sync_map.keys()), - intended_group_names={ - group_name - for group_name, is_member in group_memberships_sync_map.items() - if is_member - }, - user_profile=user_profile, - logger=logger, - ) - - def login_and_go_to_home(request: HttpRequest, user_profile: UserProfile) -> HttpResponse: mobile_flow_otp = get_expirable_session_var( request.session, "registration_mobile_flow_otp", delete=True diff --git a/zproject/backends.py b/zproject/backends.py index 5ecbc0c8dd..b2faab40e1 100644 --- a/zproject/backends.py +++ b/zproject/backends.py @@ -1803,6 +1803,42 @@ def redirect_deactivated_user_to_login(realm: Realm, email: str) -> HttpResponse return HttpResponseRedirect(redirect_url) +@transaction.atomic(savepoint=False) +def sync_groups_for_prereg_user( + prereg_user: PreregistrationUser, group_memberships_sync_map: dict[str, bool] +) -> None: + assert prereg_user.realm is not None + realm = prereg_user.realm + + group_names_to_ensure_member = [ + group_name for group_name, is_member in group_memberships_sync_map.items() if is_member + ] + group_names_to_ensure_not_member = [ + group_name for group_name, is_member in group_memberships_sync_map.items() if not is_member + ] + + groups_to_ensure_member = list( + NamedUserGroup.objects.filter(realm=realm, name__in=group_names_to_ensure_member) + ) + groups_to_ensure_not_member = list( + NamedUserGroup.objects.filter(realm=realm, name__in=group_names_to_ensure_not_member) + ) + + prereg_user.groups.add(*groups_to_ensure_member) + prereg_user.groups.remove(*groups_to_ensure_not_member) + + final_group_names = set(prereg_user.groups.all().values_list("name", flat=True)) + + stringified_dict = json.dumps(group_memberships_sync_map, sort_keys=True) + logging.info( + "Synced user groups for PreregistrationUser %s in %s: %s. Final groups set: %s", + prereg_user.id, + realm.id, + stringified_dict, + final_group_names, + ) + + def sync_groups( all_group_names: set[str], intended_group_names: set[str],