streams: Add bulk_access_stream_metadata_user_ids.

This function will be useful in sending events for users gaining or
losing metadata access when the members of a user group change in any
way.
This commit is contained in:
Shubham Padia
2025-03-20 08:29:51 +00:00
committed by Tim Abbott
parent 55ea5be022
commit 208ee1b8d9
3 changed files with 258 additions and 10 deletions

View File

@@ -141,17 +141,15 @@ def num_subscribers_for_stream_id(stream_id: int) -> int:
).count() ).count()
def get_user_ids_for_streams(stream_ids: set[int]) -> dict[int, set[int]]: def get_user_ids_for_stream_query(
all_subs = ( query: QuerySet[Subscription, Subscription],
get_active_subscriptions_for_stream_ids(stream_ids) ) -> dict[int, set[int]]:
.values( all_subs = query.values(
"recipient__type_id", "recipient__type_id",
"user_profile_id", "user_profile_id",
) ).order_by(
.order_by(
"recipient__type_id", "recipient__type_id",
) )
)
get_stream_id = itemgetter("recipient__type_id") get_stream_id = itemgetter("recipient__type_id")
@@ -163,6 +161,18 @@ def get_user_ids_for_streams(stream_ids: set[int]) -> dict[int, set[int]]:
return result return result
def get_user_ids_for_streams(stream_ids: set[int]) -> dict[int, set[int]]:
return get_user_ids_for_stream_query(get_active_subscriptions_for_stream_ids(stream_ids))
def get_guest_user_ids_for_streams(stream_ids: set[int]) -> dict[int, set[int]]:
return get_user_ids_for_stream_query(
get_active_subscriptions_for_stream_ids(stream_ids).filter(
user_profile__role=UserProfile.ROLE_GUEST
)
)
def get_users_for_streams(stream_ids: set[int]) -> dict[int, set[UserProfile]]: def get_users_for_streams(stream_ids: set[int]) -> dict[int, set[UserProfile]]:
all_subs = ( all_subs = (
get_active_subscriptions_for_stream_ids(stream_ids) get_active_subscriptions_for_stream_ids(stream_ids)

View File

@@ -18,7 +18,9 @@ from zerver.lib.exceptions import (
) )
from zerver.lib.stream_subscription import ( from zerver.lib.stream_subscription import (
get_active_subscriptions_for_stream_id, get_active_subscriptions_for_stream_id,
get_guest_user_ids_for_streams,
get_subscribed_stream_ids_for_user, get_subscribed_stream_ids_for_user,
get_user_ids_for_streams,
) )
from zerver.lib.stream_traffic import get_average_weekly_stream_traffic, get_streams_traffic from zerver.lib.stream_traffic import get_average_weekly_stream_traffic, get_streams_traffic
from zerver.lib.string_validation import check_stream_name from zerver.lib.string_validation import check_stream_name
@@ -993,6 +995,47 @@ def can_access_stream_metadata_user_ids(stream: Stream) -> set[int]:
) )
def bulk_can_access_stream_metadata_user_ids(streams: list[Stream]) -> dict[int, set[int]]:
# return user ids of users who can access the attributes of a
# stream, such as its name/description. Useful for sending events
# to all users with access to a stream's attributes.
result: dict[int, set[int]] = {}
public_streams = []
private_streams = []
for stream in streams:
if stream.is_public():
public_streams.append(stream)
else:
private_streams.append(stream)
if len(public_streams) > 0:
guest_subscriptions = get_guest_user_ids_for_streams(
{stream.id for stream in public_streams}
)
active_non_guest_user_id_set = set(active_non_guest_user_ids(public_streams[0].realm_id))
for stream in public_streams:
result[stream.id] = set(active_non_guest_user_id_set | guest_subscriptions[stream.id])
if len(private_streams) > 0:
private_stream_user_ids = get_user_ids_for_streams(
{stream.id for stream in private_streams}
)
admin_users_and_bots = {user.id for user in stream.realm.get_admin_users_and_bots()}
users_dict_with_metadata_access_to_streams_via_permission_groups = (
get_users_dict_with_metadata_access_to_streams_via_permission_groups(
private_streams, private_streams[0].realm_id
)
)
for stream in private_streams:
result[stream.id] = (
private_stream_user_ids[stream.id]
| admin_users_and_bots
| users_dict_with_metadata_access_to_streams_via_permission_groups[stream.id]
)
return result
def can_access_stream_history(user_profile: UserProfile, stream: Stream) -> bool: def can_access_stream_history(user_profile: UserProfile, stream: Stream) -> bool:
"""Determine whether the provided user is allowed to access the """Determine whether the provided user is allowed to access the
history of the target stream. history of the target stream.

View File

@@ -74,6 +74,7 @@ from zerver.lib.streams import (
StreamsCategorizedByPermissionsForAddingSubscribers, StreamsCategorizedByPermissionsForAddingSubscribers,
access_stream_by_id, access_stream_by_id,
access_stream_by_name, access_stream_by_name,
bulk_can_access_stream_metadata_user_ids,
can_access_stream_history, can_access_stream_history,
can_access_stream_metadata_user_ids, can_access_stream_metadata_user_ids,
create_stream_if_needed, create_stream_if_needed,
@@ -8777,6 +8778,200 @@ class AccessStreamTest(ZulipTestCase):
True, True,
) )
def test_can_access_stream_metadata_user_ids(self) -> None:
aaron = self.example_user("aaron")
cordelia = self.example_user("cordelia")
guest_user = self.example_user("polonius")
iago = self.example_user("iago")
desdemona = self.example_user("desdemona")
realm = aaron.realm
public_stream = self.make_stream("public_stream", realm, invite_only=False)
nobody_system_group = NamedUserGroup.objects.get(
name="role:nobody", realm=realm, is_system_group=True
)
# Public stream with no subscribers.
expected_public_user_ids = set(active_non_guest_user_ids(realm.id))
self.assertCountEqual(
can_access_stream_metadata_user_ids(public_stream), expected_public_user_ids
)
bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids(
[public_stream]
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids
)
# Public stream with 1 guest as a subscriber.
self.subscribe(guest_user, "public_stream")
expected_public_user_ids.add(guest_user.id)
self.assertCountEqual(
can_access_stream_metadata_user_ids(public_stream), expected_public_user_ids
)
bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids(
[public_stream]
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids
)
test_bot = self.create_test_bot("foo", desdemona)
expected_public_user_ids.add(test_bot.id)
private_stream = self.make_stream("private_stream", realm, invite_only=True)
# Nobody is subscribed yet for the private stream, only admin
# users will turn up for that stream. We will continue testing
# the existing public stream for the bulk function here on.
expected_private_user_ids = {iago.id, desdemona.id}
self.assertCountEqual(
can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids
)
bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids(
[public_stream, private_stream]
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids
)
# Bot with admin privileges should also be part of the result.
do_change_user_role(test_bot, UserProfile.ROLE_REALM_ADMINISTRATOR, acting_user=desdemona)
expected_private_user_ids.add(test_bot.id)
self.assertCountEqual(
can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids
)
bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids(
[public_stream, private_stream]
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids
)
# Subscriber should also be part of the result.
self.subscribe(aaron, "private_stream")
expected_private_user_ids.add(aaron.id)
self.assertCountEqual(
can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids
)
bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids(
[public_stream, private_stream]
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids
)
stream_permission_group_settings = set(Stream.stream_permission_group_settings.keys())
stream_permission_group_settings_not_granting_metadata_access = (
stream_permission_group_settings
- set(Stream.stream_permission_group_settings_granting_metadata_access)
)
for setting_name in stream_permission_group_settings_not_granting_metadata_access:
do_change_stream_group_based_setting(
private_stream,
setting_name,
UserGroupMembersData(direct_members=[cordelia.id], direct_subgroups=[]),
acting_user=cordelia,
)
self.assertCountEqual(
can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids
)
with self.assert_database_query_count(6):
bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids(
[public_stream, private_stream]
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids
)
for setting_name in Stream.stream_permission_group_settings_granting_metadata_access:
do_change_stream_group_based_setting(
private_stream,
setting_name,
UserGroupMembersData(direct_members=[cordelia.id], direct_subgroups=[]),
acting_user=cordelia,
)
expected_private_user_ids.add(cordelia.id)
self.assertCountEqual(
can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids
)
with self.assert_database_query_count(6):
bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids(
[public_stream, private_stream]
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids
)
do_change_stream_group_based_setting(
private_stream, setting_name, nobody_system_group, acting_user=cordelia
)
expected_private_user_ids.remove(cordelia.id)
bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids(
[public_stream, private_stream]
)
self.assertCountEqual(
can_access_stream_metadata_user_ids(private_stream), expected_private_user_ids
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids
)
# Query count should not increase on fetching user ids for an
# additional public stream.
public_stream_2 = self.make_stream("public_stream_2", realm, invite_only=False)
with self.assert_database_query_count(6):
bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids(
[public_stream, public_stream_2, private_stream]
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream_2.id],
active_non_guest_user_ids(realm.id),
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids
)
# Query count should not increase on fetching user ids for an
# additional private stream.
private_stream_2 = self.make_stream("private_stream_2", realm, invite_only=True)
self.subscribe(aaron, "private_stream_2")
with self.assert_database_query_count(6):
bulk_access_stream_metadata_user_ids = bulk_can_access_stream_metadata_user_ids(
[public_stream, public_stream_2, private_stream, private_stream_2]
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream.id], expected_public_user_ids
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[public_stream_2.id],
active_non_guest_user_ids(realm.id),
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[private_stream.id], expected_private_user_ids
)
self.assertCountEqual(
bulk_access_stream_metadata_user_ids[private_stream_2.id], expected_private_user_ids
)
class StreamTrafficTest(ZulipTestCase): class StreamTrafficTest(ZulipTestCase):
def test_average_weekly_stream_traffic_calculation(self) -> None: def test_average_weekly_stream_traffic_calculation(self) -> None: