diff --git a/zerver/lib/subscription_info.py b/zerver/lib/subscription_info.py index a62371b9e7..31520f166c 100644 --- a/zerver/lib/subscription_info.py +++ b/zerver/lib/subscription_info.py @@ -466,7 +466,6 @@ def bulk_get_subscriber_user_ids( stream_dicts: Collection[Mapping[str, Any]], user_profile: UserProfile, subscribed_stream_ids: set[int], - partial_subscribers: bool = False, ) -> dict[int, list[int]]: """sub_dict maps stream_id => whether the user is subscribed to that stream.""" target_stream_dicts = [] @@ -534,20 +533,6 @@ def bulk_get_subscriber_user_ids( stream_id = recip_to_stream_id[recip_id] result[stream_id] = list(user_profile_ids) - # Eventually this will return (at minimum): - # (1) if we’re in a channel view, which users are subscribed to the - # current channel - # (2) subscriptions for all bots - # - # For now, we're only doing (2). - if partial_subscribers: - bot_users = set( - UserProfile.objects.filter( - is_bot=True, realm=user_profile.realm, is_active=True - ).values_list("id", flat=True) - ) - for stream_id, users in result.items(): - result[stream_id] = [user_id for user_id in users if user_id in bot_users] return result @@ -824,19 +809,42 @@ def gather_subscriptions_helper( all_stream_dicts, user_profile, subscribed_stream_ids, - include_subscribers == "partial", ) + # Eventually "partial subscribers" will return (at minimum): + # (1) all subscriptions for recently active users (for the buddy list) + # (2) subscriptions for all bots + # + # For now, we're only doing (2). + send_partial_subscribers = include_subscribers == "partial" + partial_subscriber_map: dict[int, list[int]] = dict() + if send_partial_subscribers: + bot_users = set( + UserProfile.objects.filter( + is_bot=True, realm=user_profile.realm, is_active=True + ).values_list("id", flat=True) + ) + for stream_id, users in subscriber_map.items(): + partial_subscribers = [user_id for user_id in users if user_id in bot_users] + if len(partial_subscribers) != len(subscriber_map[stream_id]): + partial_subscriber_map[stream_id] = partial_subscribers + for lst in [subscribed, unsubscribed]: for stream_dict in lst: assert isinstance(stream_dict["stream_id"], int) stream_id = stream_dict["stream_id"] - stream_dict["subscribers"] = subscriber_map[stream_id] + if send_partial_subscribers and partial_subscriber_map.get(stream_id) is not None: + stream_dict["partial_subscribers"] = partial_subscriber_map[stream_id] + else: + stream_dict["subscribers"] = subscriber_map[stream_id] for slim_stream_dict in never_subscribed: assert isinstance(slim_stream_dict["stream_id"], int) stream_id = slim_stream_dict["stream_id"] - slim_stream_dict["subscribers"] = subscriber_map[stream_id] + if send_partial_subscribers and partial_subscriber_map.get(stream_id) is not None: + slim_stream_dict["partial_subscribers"] = partial_subscriber_map[stream_id] + else: + slim_stream_dict["subscribers"] = subscriber_map[stream_id] subscribed.sort(key=lambda x: x["name"]) unsubscribed.sort(key=lambda x: x["name"]) diff --git a/zerver/lib/types.py b/zerver/lib/types.py index 37bb7a9f7d..29e2f566bf 100644 --- a/zerver/lib/types.py +++ b/zerver/lib/types.py @@ -233,6 +233,7 @@ class SubscriptionStreamDict(TypedDict): stream_post_policy: int stream_weekly_traffic: int | None subscribers: NotRequired[list[int]] + partial_subscribers: NotRequired[list[int]] wildcard_mentions_notify: bool | None @@ -259,6 +260,7 @@ class NeverSubscribedStreamDict(TypedDict): stream_post_policy: int stream_weekly_traffic: int | None subscribers: NotRequired[list[int]] + partial_subscribers: NotRequired[list[int]] class DefaultStreamDict(TypedDict): diff --git a/zerver/tests/test_subs.py b/zerver/tests/test_subs.py index 16612d1d06..8c74892d42 100644 --- a/zerver/tests/test_subs.py +++ b/zerver/tests/test_subs.py @@ -7667,6 +7667,80 @@ class GetSubscribersTest(ZulipTestCase): stream_name = gather_subscriptions(self.user_profile)[0][0]["name"] self.make_successful_subscriber_request(stream_name) + def test_gather_partial_subscriptions(self) -> None: + othello = self.example_user("othello") + bot = self.create_test_bot("bot", othello, "Foo Bot") + + stream_names = [ + "never_subscribed_only_bots", + "never_subscribed_more_than_bots", + "unsubscribed_only_bots", + "subscribed_more_than_bots", + ] + for stream_name in stream_names: + self.make_stream(stream_name) + + self.subscribe_via_post( + self.user_profile, + ["never_subscribed_only_bots"], + dict(principals=orjson.dumps([bot.id]).decode()), + ) + self.subscribe_via_post( + self.user_profile, + ["never_subscribed_more_than_bots"], + dict(principals=orjson.dumps([bot.id, othello.id]).decode()), + ) + self.subscribe_via_post( + self.user_profile, + ["unsubscribed_only_bots"], + dict(principals=orjson.dumps([bot.id, self.user_profile.id]).decode()), + ) + self.unsubscribe( + self.user_profile, + "unsubscribed_only_bots", + ) + self.subscribe_via_post( + self.user_profile, + ["subscribed_more_than_bots"], + dict(principals=orjson.dumps([bot.id, othello.id, self.user_profile.id]).decode()), + ) + + with self.assert_database_query_count(10): + sub_data = gather_subscriptions_helper(self.user_profile, include_subscribers="partial") + never_subscribed_streams = sub_data.never_subscribed + unsubscribed_streams = sub_data.unsubscribed + subscribed_streams = sub_data.subscriptions + self.assertGreaterEqual(len(never_subscribed_streams), 2) + self.assertGreaterEqual(len(unsubscribed_streams), 1) + self.assertGreaterEqual(len(subscribed_streams), 1) + + # Streams with only bots have sent all of their subscribers, + # since we always send bots. We tell the client it doesn't + # need to fetch more, by filling "subscribers" instead + # of "partial_subscribers". If there are non-bot subscribers, + # a partial fetch will return only partial subscribers. + + for sub in never_subscribed_streams: + if sub["name"] == "never_subscribed_only_bots": + self.assert_length(sub["subscribers"], 1) + self.assertIsNone(sub.get("partial_subscribers")) + continue + if sub["name"] == "never_subscribed_more_than_bots": + self.assert_length(sub["partial_subscribers"], 1) + self.assertIsNone(sub.get("subscribers")) + + for sub in unsubscribed_streams: + if sub["name"] == "unsubscribed_only_bots": + self.assert_length(sub["subscribers"], 1) + self.assertIsNone(sub.get("partial_subscribers")) + break + + for sub in subscribed_streams: + if sub["name"] == "subscribed_more_than_bots": + self.assert_length(sub["partial_subscribers"], 1) + self.assertIsNone(sub.get("subscribers")) + break + def test_gather_subscriptions(self) -> None: """ gather_subscriptions returns correct results with only 3 queries @@ -7678,7 +7752,6 @@ class GetSubscribersTest(ZulipTestCase): cordelia = self.example_user("cordelia") othello = self.example_user("othello") polonius = self.example_user("polonius") - bot = self.create_test_bot("bot", cordelia, "Foo Bot") realm = hamlet.realm stream_names = [f"stream_{i}" for i in range(10)] @@ -7689,7 +7762,6 @@ class GetSubscribersTest(ZulipTestCase): othello.id, cordelia.id, polonius.id, - bot.id, ] with self.assert_database_query_count(50): @@ -7774,17 +7846,6 @@ class GetSubscribersTest(ZulipTestCase): acting_user=hamlet, ) - # Test partial subscribers - with self.assert_database_query_count(10): - sub_data = gather_subscriptions_helper(self.user_profile, include_subscribers="partial") - subscribed_streams = sub_data.subscriptions - self.assertGreaterEqual(len(subscribed_streams), 11) - for sub in subscribed_streams: - if not sub["name"].startswith("stream_"): - continue - # Just the bot - self.assert_length(sub["subscribers"], 1) - with self.assert_database_query_count(9): subscribed_streams, _ = gather_subscriptions( self.user_profile, include_subscribers=True