stream: Pass group id to get recursive group members.

Previously, we needed to pass the group to the function, which sometimes
meant having 1 extra query to fetch the user group when we just needed
the group id for this function.
This commit is contained in:
Shubham Padia
2025-02-06 20:11:58 +00:00
committed by Tim Abbott
parent 35f9305acb
commit 121af1c815
3 changed files with 21 additions and 17 deletions

View File

@@ -182,12 +182,14 @@ def get_default_values_for_stream_permission_group_settings(
def get_user_ids_with_metadata_access_via_permission_groups(stream: Stream) -> set[int]: def get_user_ids_with_metadata_access_via_permission_groups(stream: Stream) -> set[int]:
stream_admin_user_ids = set( stream_admin_user_ids = set(
get_recursive_group_members(stream.can_administer_channel_group).values_list( get_recursive_group_members(stream.can_administer_channel_group_id).values_list(
"id", flat=True "id", flat=True
) )
) )
stream_add_subscribers_group_user_ids = set( stream_add_subscribers_group_user_ids = set(
get_recursive_group_members(stream.can_add_subscribers_group).values_list("id", flat=True) get_recursive_group_members(stream.can_add_subscribers_group_id).values_list(
"id", flat=True
)
) )
return stream_admin_user_ids | stream_add_subscribers_group_user_ids return stream_admin_user_ids | stream_add_subscribers_group_user_ids

View File

@@ -669,9 +669,9 @@ def get_direct_memberships_of_users(user_group: UserGroup, members: list[UserPro
# https://code.djangoproject.com/ticket/28919 # https://code.djangoproject.com/ticket/28919
def get_recursive_subgroups(user_group: UserGroup) -> QuerySet[UserGroup]: def get_recursive_subgroups(user_group_id: int) -> QuerySet[UserGroup]:
cte = With.recursive( cte = With.recursive(
lambda cte: UserGroup.objects.filter(id=user_group.id) lambda cte: UserGroup.objects.filter(id=user_group_id)
.values(group_id=F("id")) .values(group_id=F("id"))
.union( .union(
cte.join(NamedUserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id")) cte.join(NamedUserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id"))
@@ -694,9 +694,9 @@ def get_recursive_strict_subgroups(user_group: UserGroup) -> QuerySet[NamedUserG
return cte.join(NamedUserGroup, id=cte.col.group_id).with_cte(cte) return cte.join(NamedUserGroup, id=cte.col.group_id).with_cte(cte)
def get_recursive_group_members(user_group: UserGroup) -> QuerySet[UserProfile]: def get_recursive_group_members(user_group_id: int) -> QuerySet[UserProfile]:
return UserProfile.objects.filter( return UserProfile.objects.filter(
is_active=True, direct_groups__in=get_recursive_subgroups(user_group) is_active=True, direct_groups__in=get_recursive_subgroups(user_group_id)
) )
@@ -728,7 +728,7 @@ def is_user_in_group(
if direct_member_only: if direct_member_only:
return get_user_group_direct_members(user_group=user_group).filter(id=user.id).exists() return get_user_group_direct_members(user_group=user_group).filter(id=user.id).exists()
return get_recursive_group_members(user_group=user_group).filter(id=user.id).exists() return get_recursive_group_members(user_group_id=user_group.id).filter(id=user.id).exists()
def is_any_user_in_group( def is_any_user_in_group(
@@ -737,7 +737,7 @@ def is_any_user_in_group(
if direct_member_only: if direct_member_only:
return get_user_group_direct_members(user_group=user_group).filter(id__in=user_ids).exists() return get_user_group_direct_members(user_group=user_group).filter(id__in=user_ids).exists()
return get_recursive_group_members(user_group=user_group).filter(id__in=user_ids).exists() return get_recursive_group_members(user_group_id=user_group.id).filter(id__in=user_ids).exists()
def get_user_group_member_ids( def get_user_group_member_ids(
@@ -746,7 +746,7 @@ def get_user_group_member_ids(
if direct_member_only: if direct_member_only:
member_ids: Iterable[int] = get_user_group_direct_member_ids(user_group) member_ids: Iterable[int] = get_user_group_direct_member_ids(user_group)
else: else:
member_ids = get_recursive_group_members(user_group).values_list("id", flat=True) member_ids = get_recursive_group_members(user_group.id).values_list("id", flat=True)
return list(member_ids) return list(member_ids)

View File

@@ -261,14 +261,14 @@ class UserGroupTestCase(ZulipTestCase):
GroupGroupMembership.objects.create(supergroup=everyone_group, subgroup=staff_group) GroupGroupMembership.objects.create(supergroup=everyone_group, subgroup=staff_group)
self.assertCountEqual( self.assertCountEqual(
list(get_recursive_subgroups(leadership_group)), [leadership_group.usergroup_ptr] list(get_recursive_subgroups(leadership_group.id)), [leadership_group.usergroup_ptr]
) )
self.assertCountEqual( self.assertCountEqual(
list(get_recursive_subgroups(staff_group)), list(get_recursive_subgroups(staff_group.id)),
[leadership_group.usergroup_ptr, staff_group.usergroup_ptr], [leadership_group.usergroup_ptr, staff_group.usergroup_ptr],
) )
self.assertCountEqual( self.assertCountEqual(
list(get_recursive_subgroups(everyone_group)), list(get_recursive_subgroups(everyone_group.id)),
[ [
leadership_group.usergroup_ptr, leadership_group.usergroup_ptr,
staff_group.usergroup_ptr, staff_group.usergroup_ptr,
@@ -283,10 +283,10 @@ class UserGroupTestCase(ZulipTestCase):
[leadership_group, staff_group], [leadership_group, staff_group],
) )
self.assertCountEqual(list(get_recursive_group_members(leadership_group)), [desdemona]) self.assertCountEqual(list(get_recursive_group_members(leadership_group.id)), [desdemona])
self.assertCountEqual(list(get_recursive_group_members(staff_group)), [desdemona, iago]) self.assertCountEqual(list(get_recursive_group_members(staff_group.id)), [desdemona, iago])
self.assertCountEqual( self.assertCountEqual(
list(get_recursive_group_members(everyone_group)), [desdemona, iago, shiva] list(get_recursive_group_members(everyone_group.id)), [desdemona, iago, shiva]
) )
self.assertIn(leadership_group.usergroup_ptr, get_recursive_membership_groups(desdemona)) self.assertIn(leadership_group.usergroup_ptr, get_recursive_membership_groups(desdemona))
@@ -299,8 +299,10 @@ class UserGroupTestCase(ZulipTestCase):
self.assertIn(everyone_group.usergroup_ptr, get_recursive_membership_groups(shiva)) self.assertIn(everyone_group.usergroup_ptr, get_recursive_membership_groups(shiva))
do_deactivate_user(iago, acting_user=None) do_deactivate_user(iago, acting_user=None)
self.assertCountEqual(list(get_recursive_group_members(staff_group)), [desdemona]) self.assertCountEqual(list(get_recursive_group_members(staff_group.id)), [desdemona])
self.assertCountEqual(list(get_recursive_group_members(everyone_group)), [desdemona, shiva]) self.assertCountEqual(
list(get_recursive_group_members(everyone_group.id)), [desdemona, shiva]
)
def test_subgroups_of_role_based_system_groups(self) -> None: def test_subgroups_of_role_based_system_groups(self) -> None:
realm = get_realm("zulip") realm = get_realm("zulip")