user_groups: Add function to get union of members of two groups.

This helps us important database queries when we want to perform a union
on the members of multiple user groups.
This commit is contained in:
Shubham Padia
2025-02-06 20:50:14 +00:00
committed by Tim Abbott
parent 121af1c815
commit 33ea2b366e
5 changed files with 62 additions and 22 deletions

View File

@@ -24,7 +24,7 @@ from zerver.lib.string_validation import check_stream_name
from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.types import AnonymousSettingGroupDict, APIStreamDict
from zerver.lib.user_groups import (
get_recursive_group_members,
get_recursive_group_members_union_for_groups,
get_recursive_membership_groups,
get_role_based_system_groups_dict,
user_has_permission_for_group_setting,
@@ -181,17 +181,11 @@ def get_default_values_for_stream_permission_group_settings(
def get_user_ids_with_metadata_access_via_permission_groups(stream: Stream) -> set[int]:
stream_admin_user_ids = set(
get_recursive_group_members(stream.can_administer_channel_group_id).values_list(
"id", flat=True
)
return set(
get_recursive_group_members_union_for_groups(
[stream.can_add_subscribers_group_id, stream.can_administer_channel_group_id]
).values_list("id", flat=True)
)
stream_add_subscribers_group_user_ids = set(
get_recursive_group_members(stream.can_add_subscribers_group_id).values_list(
"id", flat=True
)
)
return stream_admin_user_ids | stream_add_subscribers_group_user_ids
@transaction.atomic(savepoint=False)

View File

@@ -669,9 +669,9 @@ def get_direct_memberships_of_users(user_group: UserGroup, members: list[UserPro
# https://code.djangoproject.com/ticket/28919
def get_recursive_subgroups(user_group_id: int) -> QuerySet[UserGroup]:
def get_recursive_subgroups_union_for_groups(user_group_ids: list[int]) -> QuerySet[UserGroup]:
cte = With.recursive(
lambda cte: UserGroup.objects.filter(id=user_group_id)
lambda cte: UserGroup.objects.filter(id__in=user_group_ids)
.values(group_id=F("id"))
.union(
cte.join(NamedUserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id"))
@@ -680,6 +680,10 @@ def get_recursive_subgroups(user_group_id: int) -> QuerySet[UserGroup]:
return cte.join(UserGroup, id=cte.col.group_id).with_cte(cte)
def get_recursive_subgroups(user_group_id: int) -> QuerySet[UserGroup]:
return get_recursive_subgroups_union_for_groups([user_group_id])
def get_recursive_strict_subgroups(user_group: UserGroup) -> QuerySet[NamedUserGroup]:
# Same as get_recursive_subgroups but does not include the
# user_group passed.
@@ -695,8 +699,15 @@ def get_recursive_strict_subgroups(user_group: UserGroup) -> QuerySet[NamedUserG
def get_recursive_group_members(user_group_id: int) -> QuerySet[UserProfile]:
return get_recursive_group_members_union_for_groups([user_group_id])
def get_recursive_group_members_union_for_groups(
user_group_ids: list[int],
) -> QuerySet[UserProfile]:
return UserProfile.objects.filter(
is_active=True, direct_groups__in=get_recursive_subgroups(user_group_id)
is_active=True,
direct_groups__in=get_recursive_subgroups_union_for_groups(user_group_ids),
)

View File

@@ -3248,7 +3248,7 @@ class StreamAdminTest(ZulipTestCase):
are on.
"""
result = self.attempt_unsubscribe_of_principal(
query_count=19,
query_count=16,
target_users=[self.example_user("cordelia")],
is_realm_admin=True,
is_subbed=True,
@@ -3265,7 +3265,7 @@ class StreamAdminTest(ZulipTestCase):
streams you aren't on.
"""
result = self.attempt_unsubscribe_of_principal(
query_count=19,
query_count=16,
target_users=[self.example_user("cordelia")],
is_realm_admin=True,
is_subbed=False,
@@ -5992,7 +5992,7 @@ class SubscriptionAPITest(ZulipTestCase):
# Sends 3 peer-remove events, 2 unsubscribe events
# and 2 stream delete events for private streams.
with (
self.assert_database_query_count(20),
self.assert_database_query_count(19),
self.assert_memcached_count(3),
self.capture_send_event_calls(expected_num_events=7) as events,
):
@@ -6548,7 +6548,7 @@ class SubscriptionAPITest(ZulipTestCase):
)
# Test creating private stream.
with self.assert_database_query_count(50):
with self.assert_database_query_count(48):
self.subscribe_via_post(
self.test_user,
[new_streams[1]],

View File

@@ -40,9 +40,11 @@ from zerver.lib.types import AnonymousSettingGroupDict
from zerver.lib.user_groups import (
get_direct_user_groups,
get_recursive_group_members,
get_recursive_group_members_union_for_groups,
get_recursive_membership_groups,
get_recursive_strict_subgroups,
get_recursive_subgroups,
get_recursive_subgroups_union_for_groups,
get_role_based_system_groups_dict,
get_subgroup_ids,
get_user_group_member_ids,
@@ -249,6 +251,8 @@ class UserGroupTestCase(ZulipTestCase):
iago = self.example_user("iago")
desdemona = self.example_user("desdemona")
shiva = self.example_user("shiva")
aaron = self.example_user("aaron")
prospero = self.example_user("prospero")
leadership_group = check_add_user_group(
realm, "Leadership", [desdemona], acting_user=desdemona
@@ -257,8 +261,14 @@ class UserGroupTestCase(ZulipTestCase):
staff_group = check_add_user_group(realm, "Staff", [iago], acting_user=iago)
GroupGroupMembership.objects.create(supergroup=staff_group, subgroup=leadership_group)
manager_group = check_add_user_group(
realm, "Managers", [aaron, prospero], acting_user=aaron
)
GroupGroupMembership.objects.create(supergroup=manager_group, subgroup=leadership_group)
everyone_group = check_add_user_group(realm, "Everyone", [shiva], acting_user=shiva)
GroupGroupMembership.objects.create(supergroup=everyone_group, subgroup=staff_group)
GroupGroupMembership.objects.create(supergroup=everyone_group, subgroup=manager_group)
self.assertCountEqual(
list(get_recursive_subgroups(leadership_group.id)), [leadership_group.usergroup_ptr]
@@ -273,6 +283,16 @@ class UserGroupTestCase(ZulipTestCase):
leadership_group.usergroup_ptr,
staff_group.usergroup_ptr,
everyone_group.usergroup_ptr,
manager_group.usergroup_ptr,
],
)
self.assertCountEqual(
list(get_recursive_subgroups_union_for_groups([staff_group.id, manager_group.id])),
[
leadership_group.usergroup_ptr,
staff_group.usergroup_ptr,
manager_group.usergroup_ptr,
],
)
@@ -280,28 +300,43 @@ class UserGroupTestCase(ZulipTestCase):
self.assertCountEqual(list(get_recursive_strict_subgroups(staff_group)), [leadership_group])
self.assertCountEqual(
list(get_recursive_strict_subgroups(everyone_group)),
[leadership_group, staff_group],
[leadership_group, staff_group, manager_group],
)
self.assertCountEqual(list(get_recursive_group_members(leadership_group.id)), [desdemona])
self.assertCountEqual(list(get_recursive_group_members(staff_group.id)), [desdemona, iago])
self.assertCountEqual(
list(get_recursive_group_members(everyone_group.id)), [desdemona, iago, shiva]
list(get_recursive_group_members(everyone_group.id)),
[desdemona, iago, shiva, aaron, prospero],
)
self.assertCountEqual(
list(get_recursive_group_members_union_for_groups([staff_group.id, manager_group.id])),
[iago, desdemona, aaron, prospero],
)
self.assertCountEqual(
list(
get_recursive_group_members_union_for_groups([leadership_group.id, staff_group.id])
),
[desdemona, iago],
)
self.assertIn(leadership_group.usergroup_ptr, get_recursive_membership_groups(desdemona))
self.assertIn(staff_group.usergroup_ptr, get_recursive_membership_groups(desdemona))
self.assertIn(everyone_group.usergroup_ptr, get_recursive_membership_groups(desdemona))
self.assertIn(manager_group.usergroup_ptr, get_recursive_membership_groups(desdemona))
self.assertIn(staff_group.usergroup_ptr, get_recursive_membership_groups(iago))
self.assertIn(everyone_group.usergroup_ptr, get_recursive_membership_groups(iago))
self.assertNotIn(manager_group.usergroup_ptr, get_recursive_membership_groups(iago))
self.assertIn(everyone_group.usergroup_ptr, get_recursive_membership_groups(shiva))
do_deactivate_user(iago, acting_user=None)
self.assertCountEqual(list(get_recursive_group_members(staff_group.id)), [desdemona])
self.assertCountEqual(
list(get_recursive_group_members(everyone_group.id)), [desdemona, shiva]
list(get_recursive_group_members(everyone_group.id)),
[desdemona, shiva, aaron, prospero],
)
def test_subgroups_of_role_based_system_groups(self) -> None:

View File

@@ -1021,7 +1021,7 @@ class QueryCountTest(ZulipTestCase):
prereg_user = PreregistrationUser.objects.get(email="fred@zulip.com")
with (
self.assert_database_query_count(93),
self.assert_database_query_count(87),
self.assert_memcached_count(19),
self.capture_send_event_calls(expected_num_events=10) as events,
):