From 208ee1b8d98f0630adf60d927fd8fadb0921cedf Mon Sep 17 00:00:00 2001 From: Shubham Padia Date: Thu, 20 Mar 2025 08:29:51 +0000 Subject: [PATCH] streams: Add bulk_can_access_stream_metadata_user_ids. This function will be useful in sending events for users gaining or losing metadata access when the members of a user group change in any way. --- zerver/lib/stream_subscription.py | 30 +++-- zerver/lib/streams.py | 43 +++++++ zerver/tests/test_subs.py | 195 ++++++++++++++++++++++++++++++ 3 files changed, 258 insertions(+), 10 deletions(-) diff --git a/zerver/lib/stream_subscription.py b/zerver/lib/stream_subscription.py index c46cb837df..7cdb83c66b 100644 --- a/zerver/lib/stream_subscription.py +++ b/zerver/lib/stream_subscription.py @@ -141,16 +141,14 @@ def num_subscribers_for_stream_id(stream_id: int) -> int: ).count() -def get_user_ids_for_streams(stream_ids: set[int]) -> dict[int, set[int]]: - all_subs = ( - get_active_subscriptions_for_stream_ids(stream_ids) - .values( - "recipient__type_id", - "user_profile_id", - ) - .order_by( - "recipient__type_id", - ) +def get_user_ids_for_stream_query( + query: QuerySet[Subscription, Subscription], +) -> dict[int, set[int]]: + all_subs = query.values( + "recipient__type_id", + "user_profile_id", + ).order_by( + "recipient__type_id", ) get_stream_id = itemgetter("recipient__type_id") @@ -163,6 +161,18 @@ def get_user_ids_for_streams(stream_ids: set[int]) -> dict[int, set[int]]: return result +def get_user_ids_for_streams(stream_ids: set[int]) -> dict[int, set[int]]: + return get_user_ids_for_stream_query(get_active_subscriptions_for_stream_ids(stream_ids)) + + +def get_guest_user_ids_for_streams(stream_ids: set[int]) -> dict[int, set[int]]: + return get_user_ids_for_stream_query( + get_active_subscriptions_for_stream_ids(stream_ids).filter( + user_profile__role=UserProfile.ROLE_GUEST + ) + ) + + def get_users_for_streams(stream_ids: set[int]) -> dict[int, set[UserProfile]]: all_subs = (
get_active_subscriptions_for_stream_ids(stream_ids) diff --git a/zerver/lib/streams.py b/zerver/lib/streams.py index 99281a4f42..d618bd88f5 100644 --- a/zerver/lib/streams.py +++ b/zerver/lib/streams.py @@ -18,7 +18,9 @@ from zerver.lib.exceptions import ( ) from zerver.lib.stream_subscription import ( get_active_subscriptions_for_stream_id, + get_guest_user_ids_for_streams, get_subscribed_stream_ids_for_user, + get_user_ids_for_streams, ) from zerver.lib.stream_traffic import get_average_weekly_stream_traffic, get_streams_traffic from zerver.lib.string_validation import check_stream_name @@ -993,6 +995,47 @@ def can_access_stream_metadata_user_ids(stream: Stream) -> set[int]: ) +def bulk_can_access_stream_metadata_user_ids(streams: list[Stream]) -> dict[int, set[int]]: + # Return a dict mapping each stream's id to the ids of users who can access that stream's + # attributes, such as its name/description. Useful for sending events to all users with + # access to those attributes. Assumes every stream belongs to the same realm. + result: dict[int, set[int]] = {} + public_streams = [] + private_streams = [] + for stream in streams: + if stream.is_public(): + public_streams.append(stream) + else: + private_streams.append(stream) + + if len(public_streams) > 0: + guest_subscriptions = get_guest_user_ids_for_streams( + {stream.id for stream in public_streams} + ) + active_non_guest_user_id_set = set(active_non_guest_user_ids(public_streams[0].realm_id)) + for stream in public_streams: + result[stream.id] = set(active_non_guest_user_id_set | guest_subscriptions[stream.id]) + + if len(private_streams) > 0: + private_stream_user_ids = get_user_ids_for_streams( + {stream.id for stream in private_streams} + ) + admin_users_and_bots = {u.id for u in private_streams[0].realm.get_admin_users_and_bots()} + users_dict_with_metadata_access_to_streams_via_permission_groups = ( + get_users_dict_with_metadata_access_to_streams_via_permission_groups( + private_streams, private_streams[0].realm_id + ) + ) + for stream in private_streams: + result[stream.id] = (
private_stream_user_ids[stream.id] + | admin_users_and_bots + | users_dict_with_metadata_access_to_streams_via_permission_groups[stream.id] + ) + + return result + + def can_access_stream_history(user_profile: UserProfile, stream: Stream) -> bool: """Determine whether the provided user is allowed to access the history of the target stream. diff --git a/zerver/tests/test_subs.py b/zerver/tests/test_subs.py index 230e28ee89..694113cc41 100644 --- a/zerver/tests/test_subs.py +++ b/zerver/tests/test_subs.py @@ -74,6 +74,7 @@ from zerver.lib.streams import ( StreamsCategorizedByPermissionsForAddingSubscribers, access_stream_by_id, access_stream_by_name, + bulk_can_access_stream_metadata_user_ids, can_access_stream_history, can_access_stream_metadata_user_ids, create_stream_if_needed, @@ -8777,6 +8778,200 @@ class AccessStreamTest(ZulipTestCase): True, ) + def test_can_access_stream_metadata_user_ids(self) -> None: + aaron = self.example_user("aaron") + cordelia = self.example_user("cordelia") + guest_user = self.example_user("polonius") + iago = self.example_user("iago") + desdemona = self.example_user("desdemona") + realm = aaron.realm + public_stream = self.make_stream("public_stream", realm, invite_only=False) + nobody_system_group = NamedUserGroup.objects.get( + name="role:nobody", realm=realm, is_system_group=True + ) + + # Public stream with no subscribers. + expected_public_user_ids = set(active_non_guest_user_ids(realm.id)) + self.assertCountEqual( + can_access_stream_metadata_user_ids(public_stream), expected_public_user_ids + ) + bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids( + [public_stream] + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids + ) + + # Public stream with 1 guest as a subscriber.
+ self.subscribe(guest_user, "public_stream") + expected_public_user_ids.add(guest_user.id) + self.assertCountEqual( + can_access_stream_metadata_user_ids(public_stream), expected_public_user_ids + ) + bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids( + [public_stream] + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids + ) + + test_bot = self.create_test_bot("foo", desdemona) + expected_public_user_ids.add(test_bot.id) + private_stream = self.make_stream("private_stream", realm, invite_only=True) + # Nobody is subscribed yet for the private stream, only admin + # users will turn up for that stream. We will continue testing + # the existing public stream for the bulk function here on. + expected_private_user_ids = {iago.id, desdemona.id} + self.assertCountEqual( + can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids + ) + bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids( + [public_stream, private_stream] + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids + ) + + # Bot with admin privileges should also be part of the result. 
+ do_change_user_role(test_bot, UserProfile.ROLE_REALM_ADMINISTRATOR, acting_user=desdemona) + expected_private_user_ids.add(test_bot.id) + self.assertCountEqual( + can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids + ) + bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids( + [public_stream, private_stream] + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids + ) + + # Subscriber should also be part of the result. + self.subscribe(aaron, "private_stream") + expected_private_user_ids.add(aaron.id) + self.assertCountEqual( + can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids + ) + bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids( + [public_stream, private_stream] + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids + ) + + stream_permission_group_settings = set(Stream.stream_permission_group_settings.keys()) + stream_permission_group_settings_not_granting_metadata_access = ( + stream_permission_group_settings + - set(Stream.stream_permission_group_settings_granting_metadata_access) + ) + for setting_name in stream_permission_group_settings_not_granting_metadata_access: + do_change_stream_group_based_setting( + private_stream, + setting_name, + UserGroupMembersData(direct_members=[cordelia.id], direct_subgroups=[]), + acting_user=cordelia, + ) + self.assertCountEqual( + can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids + ) + with self.assert_database_query_count(6): + bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids( + [public_stream, private_stream] + ) + 
self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids + ) + + for setting_name in Stream.stream_permission_group_settings_granting_metadata_access: + do_change_stream_group_based_setting( + private_stream, + setting_name, + UserGroupMembersData(direct_members=[cordelia.id], direct_subgroups=[]), + acting_user=cordelia, + ) + expected_private_user_ids.add(cordelia.id) + self.assertCountEqual( + can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids + ) + with self.assert_database_query_count(6): + bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids( + [public_stream, private_stream] + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids + ) + + do_change_stream_group_based_setting( + private_stream, setting_name, nobody_system_group, acting_user=cordelia + ) + expected_private_user_ids.remove(cordelia.id) + bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids( + [public_stream, private_stream] + ) + self.assertCountEqual( + can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids + ) + + # Query count should not increase on fetching user ids for an + # additional public stream. 
+ public_stream_2 = self.make_stream("public_stream_2", realm, invite_only=False) + with self.assert_database_query_count(6): + bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids( + [public_stream, public_stream_2, private_stream] + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream_2.id], + active_non_guest_user_ids(realm.id), + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids + ) + + # Query count should not increase on fetching user ids for an + # additional private stream. + private_stream_2 = self.make_stream("private_stream_2", realm, invite_only=True) + self.subscribe(aaron, "private_stream_2") + with self.assert_database_query_count(6): + bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids( + [public_stream, public_stream_2, private_stream, private_stream_2] + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[public_stream_2.id], + active_non_guest_user_ids(realm.id), + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids + ) + self.assertCountEqual( + bulk_access_stream_metadata_user_ids[private_stream_2.id], expected_private_user_ids + ) + class StreamTrafficTest(ZulipTestCase): def test_average_weekly_stream_traffic_calculation(self) -> None: