saml: Don't put group_memberships_sync_map in the session.

In 40956ae4c5 we implemented group sync
via SAML during sign in and sign up. The sign up implementation used a
session variable group_memberships_sync_map to plumb through the sync
information to the registration codepath, to execute group sync after
user creation.

We can use a more robust approach instead, and just amend groups on the
`PreregistrationUser` object that's going to be used for registration.
This commit is contained in:
Mateusz Mandera
2025-07-11 16:00:05 +08:00
committed by Tim Abbott
parent 08e9853850
commit 99b97ea883
4 changed files with 51 additions and 49 deletions

View File

@@ -3656,8 +3656,11 @@ class SAMLAuthBackendTest(SocialAuthBase):
self.logger_output("Returning role owner for user creation", type="info"),
)
self.assertIn(
f"INFO:root:Syncing groups post-registration for new user {user_profile.id} in realm {realm.id}",
prereg_user = PreregistrationUser.objects.last()
assert prereg_user is not None
self.assertEqual(
f"INFO:root:Synced user groups for PreregistrationUser {prereg_user.id} in {realm.id}: "
'{"testgroup1": true, "testgroup2": false}. Final groups set: {\'testgroup1\'}',
mock_root_logger.output[0],
)
@@ -3735,8 +3738,11 @@ class SAMLAuthBackendTest(SocialAuthBase):
)
)
self.assertIn(
f"INFO:root:Syncing groups post-registration for new user {user_profile.id} in realm {realm.id}",
prereg_user = PreregistrationUser.objects.last()
assert prereg_user is not None
self.assertEqual(
f"INFO:root:Synced user groups for PreregistrationUser {prereg_user.id} in {realm.id}: "
'{"testgroup1": true, "testgroup2": false}. Final groups set: {\'testgroup1\'}',
mock_root_logger.output[0],
)

View File

@@ -103,6 +103,7 @@ from zproject.backends import (
ldap_auth_enabled,
password_auth_enabled,
saml_auth_enabled,
sync_groups_for_prereg_user,
validate_otp_params,
)
@@ -232,18 +233,6 @@ def maybe_send_to_registration(
expiry_seconds=EXPIRABLE_SESSION_VAR_DEFAULT_EXPIRY_SECS,
)
if group_memberships_sync_map:
set_expirable_session_var(
request.session,
"registration_group_memberships_sync_map",
orjson.dumps(group_memberships_sync_map).decode(),
expiry_seconds=EXPIRABLE_SESSION_VAR_DEFAULT_EXPIRY_SECS,
)
elif "registration_group_memberships_sync_map" in request.session: # nocoverage
# Ensure it isn't possible to leak this state across
# registration attempts.
del request.session["registration_group_memberships_sync_map"]
try:
# TODO: This should use get_realm_from_request, but a bunch of tests
# rely on mocking get_subdomain here, so they'll need to be tweaked first.
@@ -327,8 +316,10 @@ def maybe_send_to_registration(
if streams_to_subscribe:
prereg_user.streams.set(streams_to_subscribe)
if user_groups:
prereg_user.groups.set(user_groups)
if user_groups or group_memberships_sync_map:
prereg_user.groups.set(user_groups or [])
if group_memberships_sync_map:
sync_groups_for_prereg_user(prereg_user, group_memberships_sync_map)
if include_realm_default_subscriptions is not None:
prereg_user.include_realm_default_subscriptions = include_realm_default_subscriptions

View File

@@ -135,7 +135,6 @@ from zproject.backends import (
get_external_method_dicts,
ldap_auth_enabled,
password_auth_enabled,
sync_groups,
)
logger = logging.getLogger("")
@@ -843,8 +842,6 @@ def registration_helper(
# duplicate email address. Redirect them to the login
# form.
return redirect_to_email_login_url(email)
else:
sync_groups_post_registration(request=request, user_profile=user_profile)
if realm_creation:
# Because for realm creation, registration happens on the
@@ -917,34 +914,6 @@ def registration_helper(
return TemplateResponse(request, "zerver/register.html", context=context)
def sync_groups_post_registration(request: HttpRequest, user_profile: UserProfile) -> None:
group_memberships_sync_data = get_expirable_session_var(
request.session, "registration_group_memberships_sync_map", delete=True
)
if not group_memberships_sync_data:
return
group_memberships_sync_map = orjson.loads(group_memberships_sync_data)
if group_memberships_sync_map:
logger.info(
"Syncing groups post-registration for new user %s in realm %s: %s",
user_profile.id,
user_profile.realm_id,
group_memberships_sync_map,
)
assert isinstance(group_memberships_sync_map, dict)
sync_groups(
all_group_names=set(group_memberships_sync_map.keys()),
intended_group_names={
group_name
for group_name, is_member in group_memberships_sync_map.items()
if is_member
},
user_profile=user_profile,
logger=logger,
)
def login_and_go_to_home(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
mobile_flow_otp = get_expirable_session_var(
request.session, "registration_mobile_flow_otp", delete=True

View File

@@ -1803,6 +1803,42 @@ def redirect_deactivated_user_to_login(realm: Realm, email: str) -> HttpResponse
return HttpResponseRedirect(redirect_url)
@transaction.atomic(savepoint=False)
def sync_groups_for_prereg_user(
prereg_user: PreregistrationUser, group_memberships_sync_map: dict[str, bool]
) -> None:
assert prereg_user.realm is not None
realm = prereg_user.realm
group_names_to_ensure_member = [
group_name for group_name, is_member in group_memberships_sync_map.items() if is_member
]
group_names_to_ensure_not_member = [
group_name for group_name, is_member in group_memberships_sync_map.items() if not is_member
]
groups_to_ensure_member = list(
NamedUserGroup.objects.filter(realm=realm, name__in=group_names_to_ensure_member)
)
groups_to_ensure_not_member = list(
NamedUserGroup.objects.filter(realm=realm, name__in=group_names_to_ensure_not_member)
)
prereg_user.groups.add(*groups_to_ensure_member)
prereg_user.groups.remove(*groups_to_ensure_not_member)
final_group_names = set(prereg_user.groups.all().values_list("name", flat=True))
stringified_dict = json.dumps(group_memberships_sync_map, sort_keys=True)
logging.info(
"Synced user groups for PreregistrationUser %s in %s: %s. Final groups set: %s",
prereg_user.id,
realm.id,
stringified_dict,
final_group_names,
)
def sync_groups(
all_group_names: set[str],
intended_group_names: set[str],