From c04558fe310ccc56bb83a9874ebb44cafebe0777 Mon Sep 17 00:00:00 2001 From: bedo Date: Sat, 10 May 2025 01:02:03 +0300 Subject: [PATCH] stream: Add subscriber_count field. Fixes #34246. Add subscriber_count field to Stream model to track number of non-deactivated users subscribed to the channel. --- .../commands/populate_analytics_db.py | 9 +- zerver/actions/streams.py | 12 ++ zerver/actions/users.py | 6 +- zerver/lib/stream_subscription.py | 103 ++++++++- zerver/lib/test_classes.py | 41 +++- .../0704_stream_subscriber_count.py | 17 ++ ..._stream_subscriber_count_data_migration.py | 56 +++++ zerver/models/streams.py | 5 + zerver/tests/test_invite.py | 4 +- zerver/tests/test_message_send.py | 7 +- zerver/tests/test_populate_db.py | 32 +++ zerver/tests/test_signup.py | 60 +++++- zerver/tests/test_subs.py | 195 ++++++++++++++++-- zerver/tests/test_users.py | 105 +++++++++- zilencer/management/commands/populate_db.py | 13 +- 15 files changed, 628 insertions(+), 37 deletions(-) create mode 100644 zerver/migrations/0704_stream_subscriber_count.py create mode 100644 zerver/migrations/0705_stream_subscriber_count_data_migration.py diff --git a/analytics/management/commands/populate_analytics_db.py b/analytics/management/commands/populate_analytics_db.py index cd36f599e4..a77f269bf2 100644 --- a/analytics/management/commands/populate_analytics_db.py +++ b/analytics/management/commands/populate_analytics_db.py @@ -22,10 +22,11 @@ from zerver.lib.create_user import create_user from zerver.lib.management import ZulipBaseCommand from zerver.lib.storage import static_path from zerver.lib.stream_color import STREAM_ASSIGNMENT_COLORS +from zerver.lib.stream_subscription import create_stream_subscription from zerver.lib.streams import get_default_values_for_stream_permission_group_settings from zerver.lib.timestamp import floor_to_day from zerver.lib.upload import upload_message_attachment_from_request -from zerver.models import Client, Realm, RealmAuditLog, Recipient, Stream, Subscription, UserProfile +from zerver.models import Client, Realm, RealmAuditLog, Recipient, Stream, UserProfile from zerver.models.groups import NamedUserGroup, SystemGroups, UserGroupMembership from zerver.models.realm_audit_logs import AuditLogEventType @@ -125,10 +126,10 @@ class Command(ZulipBaseCommand): stream.save(update_fields=["recipient"]) # Subscribe shylock to the stream to avoid invariant failures. - Subscription.objects.create( - recipient=recipient, + create_stream_subscription( user_profile=shylock, - is_user_active=shylock.is_active, + recipient=recipient, + stream=stream, color=STREAM_ASSIGNMENT_COLORS[0], ) RealmAuditLog.objects.create( diff --git a/zerver/actions/streams.py b/zerver/actions/streams.py index 2d3bede91c..e1360d765f 100644 --- a/zerver/actions/streams.py +++ b/zerver/actions/streams.py @@ -28,6 +28,7 @@ from zerver.lib.stream_color import pick_colors from zerver.lib.stream_subscription import ( SubInfo, SubscriberPeerInfo, + bulk_update_subscriber_counts, get_active_subscriptions_for_stream_id, get_bulk_stream_subscriber_info, get_used_colors_for_user_ids, @@ -825,9 +826,12 @@ def bulk_add_subscriptions( altered_user_dict: dict[int, set[int]] = defaultdict(set) altered_guests: set[int] = set() altered_streams_dict: dict[UserProfile, set[int]] = defaultdict(set) + subscriber_count_changes: dict[int, set[int]] = defaultdict(set) for sub_info in subs_to_add + subs_to_activate: altered_user_dict[sub_info.stream.id].add(sub_info.user.id) altered_streams_dict[sub_info.user].add(sub_info.stream.id) + if sub_info.user.is_active: + subscriber_count_changes[sub_info.stream.id].add(sub_info.user.id) if sub_info.user.is_guest: altered_guests.add(sub_info.user.id) @@ -843,6 +847,7 @@ def bulk_add_subscriptions( subs_to_add=subs_to_add, subs_to_activate=subs_to_activate, ) + bulk_update_subscriber_counts(direction=1, streams=subscriber_count_changes) stream_dict = {stream.id: stream for stream in streams} @@ -1092,12 +1097,19 @@ def bulk_remove_subscriptions( return ([], not_subscribed) sub_ids_to_deactivate = [sub_info.sub.id for sub_info in subs_to_deactivate] + + subscriber_count_changes: dict[int, set[int]] = defaultdict(set) + for sub_info in subs_to_deactivate: + if sub_info.user.is_active: + subscriber_count_changes[sub_info.stream.id].add(sub_info.user.id) + # We do all the database changes in a transaction to ensure # RealmAuditLog entries are atomically created when making changes. with transaction.atomic(savepoint=False): Subscription.objects.filter( id__in=sub_ids_to_deactivate, ).update(active=False) + bulk_update_subscriber_counts(direction=-1, streams=subscriber_count_changes) # Log subscription activities in RealmAuditLog event_time = timezone_now() diff --git a/zerver/actions/users.py b/zerver/actions/users.py index c6474fcb86..df91ecd691 100644 --- a/zerver/actions/users.py +++ b/zerver/actions/users.py @@ -32,6 +32,7 @@ from zerver.lib.send_email import ( ) from zerver.lib.sessions import delete_user_sessions from zerver.lib.soft_deactivation import queue_soft_reactivation +from zerver.lib.stream_subscription import update_all_subscriber_counts_for_user from zerver.lib.stream_traffic import get_streams_traffic from zerver.lib.streams import ( get_anonymous_group_membership_dict_for_streams, @@ -273,12 +274,15 @@ def change_user_is_active(user_profile: UserProfile, value: bool) -> None: Helper function for changing the .is_active field. Not meant as a standalone function in production code as properly activating/deactivating users requires more steps. This changes the is_active value and saves it, while ensuring - Subscription.is_user_active values are updated in the same db transaction. + Subscription.is_user_active and Stream.subscriber_count values are updated in the same db transaction. """ with transaction.atomic(savepoint=False): user_profile.is_active = value user_profile.save(update_fields=["is_active"]) Subscription.objects.filter(user_profile=user_profile).update(is_user_active=value) + update_all_subscriber_counts_for_user( + user_profile=user_profile, direction=1 if value else -1 + ) def send_group_update_event_for_anonymous_group_setting( diff --git a/zerver/lib/stream_subscription.py b/zerver/lib/stream_subscription.py index 7cdb83c66b..728aceabe1 100644 --- a/zerver/lib/stream_subscription.py +++ b/zerver/lib/stream_subscription.py @@ -3,9 +3,12 @@ from collections import defaultdict from collections.abc import Set as AbstractSet from dataclasses import dataclass from operator import itemgetter -from typing import Any +from typing import Any, Literal -from django.db.models import Q, QuerySet +from django.db import connection, transaction +from django.db.models import F, Q, QuerySet +from psycopg2 import sql +from psycopg2.extras import execute_values from zerver.models import AlertWord, Recipient, Stream, Subscription, UserProfile, UserTopic @@ -76,6 +79,12 @@ def get_stream_subscriptions_for_user(user_profile: UserProfile) -> QuerySet[Sub ) +def get_user_subscribed_streams(user_profile: UserProfile) -> QuerySet[Stream]: + return Stream.objects.filter( + recipient_id__in=get_subscribed_stream_recipient_ids_for_user(user_profile) + ) + + def get_used_colors_for_user_ids(user_ids: list[int]) -> dict[int, set[str]]: """Fetch which stream colors have already been used for each user in user_ids. Uses an optimized query designed to support picking @@ -313,3 +322,93 @@ def get_subscriptions_for_send_message( ) ) return query + + +def update_all_subscriber_counts_for_user( + user_profile: UserProfile, direction: Literal[1, -1] +) -> None: + """ + Increment/Decrement number of stream subscribers by 1, when reactivating/deactivating user. + + direction -> 1=increment, -1=decrement + """ + get_user_subscribed_streams(user_profile).update( + subscriber_count=F("subscriber_count") + direction + ) + + +def bulk_update_subscriber_counts( + direction: Literal[1, -1], + streams: dict[int, set[int]], +) -> None: + """Increment/Decrement number of stream subscribers for multiple users. + + direction -> 1=increment, -1=decrement + """ + if len(streams) == 0: + return + + # list of tuples (stream_id, delta_subscribers) used as the + # columns of the temporary table delta_table. + stream_delta_values = [ + (stream_id, len(subscribers) * direction) for stream_id, subscribers in streams.items() + ] + + # The goal here is to update subscriber_count in a bulk efficient way, + # letting the database handle the deltas to avoid some race conditions. + # + # But unlike update_all_subscriber_counts_for_user which uses F() + # for a single delta value, we can't use F() to apply different + # deltas per row in a single update using ORM, so we use a raw + # SQL query. + query = sql.SQL( + """UPDATE {stream_table} + SET subscriber_count = {stream_table}.subscriber_count + delta_table.delta + FROM (VALUES %s) AS delta_table(id, delta) + WHERE {stream_table}.id = delta_table.id; + """ + ).format(stream_table=sql.Identifier(Stream._meta.db_table)) + + cursor = connection.cursor() + execute_values(cursor.cursor, query, stream_delta_values) + + +@transaction.atomic(savepoint=False) +def create_stream_subscription( + user_profile: UserProfile, + recipient: Recipient, + stream: Stream, + color: str = Subscription.DEFAULT_STREAM_COLOR, +) -> None: + """ + Creates a single stream Subscription object, incrementing + stream.subscriber_count by 1 if user is active, in the same + transaction. + """ + + # We only create a stream subscription in this function + assert recipient.type == Recipient.STREAM + + Subscription.objects.create( + recipient=recipient, + user_profile=user_profile, + is_user_active=user_profile.is_active, + color=color, + ) + + if user_profile.is_active: + Stream.objects.filter(id=stream.id).update(subscriber_count=F("subscriber_count") + 1) + + +@transaction.atomic(savepoint=False) +def bulk_create_stream_subscriptions( # nocoverage + subs: list[Subscription], streams: dict[int, set[int]] +) -> None: + """ + Bulk create subscripions for streams, incrementing + stream.subscriber_count in the same transaction. + + Currently only used in populate_db. + """ + Subscription.objects.bulk_create(subs) + bulk_update_subscriber_counts(direction=1, streams=streams) diff --git a/zerver/lib/test_classes.py b/zerver/lib/test_classes.py index 5c08393b3f..80e5bd9e8e 100644 --- a/zerver/lib/test_classes.py +++ b/zerver/lib/test_classes.py @@ -5,7 +5,7 @@ import re import shutil import subprocess import tempfile -from collections.abc import Callable, Collection, Iterator, Mapping, Sequence +from collections.abc import Callable, Collection, Iterable, Iterator, Mapping, Sequence from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Union, cast from unittest import TestResult, mock, skipUnless @@ -1453,6 +1453,31 @@ Output: self.assertEqual(stream.recipient_id, message.recipient_id) self.assertEqual(stream.name, stream_name) + def assert_stream_subscriber_count( + self, + counts_before: dict[int, int], + counts_after: dict[int, int], + expected_difference: int, + ) -> None: + # Normally they should always be equal, + # but just in case this was called in some test where user/s streams have changed + # and we forgot to update streams, + # so this assertion catches that. + self.assertEqual( + set(counts_before), + set(counts_after), + msg="Different streams! You should compare subscriber_count for the same streams.", + ) + + for stream_id, count_before in counts_before.items(): + self.assertEqual( + count_before + expected_difference, + counts_after[stream_id], + msg=f""" + stream of ID ({stream_id}) should have a subscriber_count of {count_before + expected_difference}. + """, + ) + def webhook_fixture_data(self, type: str, action: str, file_type: str = "json") -> str: fn = os.path.join( os.path.dirname(__file__), @@ -2242,6 +2267,20 @@ class ZulipTestCase(ZulipTestCaseMixin, TestCase): with self.captureOnCommitCallbacks(execute=True): handle_missedmessage_emails(user_profile_id, message_ids) + def build_streams_subscriber_count(self, streams: Iterable[Stream]) -> dict[int, int]: + """ + Callers MUST pass a new db-fetched version of streams each time. + """ + return {stream.id: stream.subscriber_count for stream in streams} + + def fetch_streams_subscriber_count(self, stream_ids: set[int]) -> dict[int, int]: + return self.build_streams_subscriber_count(streams=Stream.objects.filter(id__in=stream_ids)) + + def fetch_other_streams_subscriber_count(self, stream_ids: set[int]) -> dict[int, int]: + return self.build_streams_subscriber_count( + streams=Stream.objects.exclude(id__in=stream_ids) + ) + def get_row_ids_in_all_tables() -> Iterator[tuple[str, set[int]]]: all_models = apps.get_models(include_auto_created=True) diff --git a/zerver/migrations/0704_stream_subscriber_count.py b/zerver/migrations/0704_stream_subscriber_count.py new file mode 100644 index 0000000000..a6eecbebcc --- /dev/null +++ b/zerver/migrations/0704_stream_subscriber_count.py @@ -0,0 +1,17 @@ +# Generated by Django 5.1.7 on 2025-04-08 20:19 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("zerver", "0703_realmuserdefault_resolved_topic_notice_auto_read_policy_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="stream", + name="subscriber_count", + field=models.PositiveIntegerField(db_default=0, default=0), + ), + ] diff --git a/zerver/migrations/0705_stream_subscriber_count_data_migration.py b/zerver/migrations/0705_stream_subscriber_count_data_migration.py new file mode 100644 index 0000000000..0938bc2324 --- /dev/null +++ b/zerver/migrations/0705_stream_subscriber_count_data_migration.py @@ -0,0 +1,56 @@ +from django.db import connection, migrations, transaction +from django.db.backends.base.schema import BaseDatabaseSchemaEditor +from django.db.migrations.state import StateApps +from psycopg2 import sql + + +def set_stream_subscribe_count(apps: StateApps, schema_editor: BaseDatabaseSchemaEditor) -> None: + Realm = apps.get_model("zerver", "realm") + Stream = apps.get_model("zerver", "Stream") + Subscription = apps.get_model("zerver", "subscription") + Recipient = apps.get_model("zerver", "recipient") + + # Here we compute per-stream subscriber_count in the sub-query + # which returns a stream_subscribers_table, then Update it + # in-place using that computed value. The whole query is then + # executed by one batch per realm. + query = sql.SQL( + """UPDATE {stream_table} AS stream + SET subscriber_count = stream_subscribers_table.count + FROM ( + SELECT recipient.type_id AS stream_id, + COUNT(subscription.user_profile_id) AS count + FROM {subscription_table} AS subscription + JOIN {recipient_table} AS recipient + ON subscription.recipient_id = recipient.id + WHERE + recipient.type = 2 + AND subscription.active = True + AND subscription.is_user_active = True + GROUP BY stream_id + ) AS stream_subscribers_table + WHERE + stream.realm_id = %(realm_id)s + AND stream.id = stream_subscribers_table.stream_id; + """ + ).format( + stream_table=sql.Identifier(Stream._meta.db_table), + subscription_table=sql.Identifier(Subscription._meta.db_table), + recipient_table=sql.Identifier(Recipient._meta.db_table), + ) + + for realm in Realm.objects.all(): + with connection.cursor() as cursor, transaction.atomic(durable=True): + cursor.execute(query, {"realm_id": realm.id}) + + +class Migration(migrations.Migration): + atomic = False + + dependencies = [ + ("zerver", "0704_stream_subscriber_count"), + ] + + operations = [ + migrations.RunPython(set_stream_subscribe_count), + ] diff --git a/zerver/models/streams.py b/zerver/models/streams.py index 1a47db16fa..038a61be8e 100644 --- a/zerver/models/streams.py +++ b/zerver/models/streams.py @@ -33,6 +33,11 @@ class Stream(models.Model): description = models.CharField(max_length=MAX_DESCRIPTION_LENGTH, default="") rendered_description = models.TextField(default="") + # Total number of non-deactivated users who are subscribed to the channel. + # It's obvious to be a positive field but also in case it becomes negative + # we know immediately that something is wrong as it raises IntegrityError. + subscriber_count = models.PositiveIntegerField(default=0, db_default=0) + # Foreign key to the Recipient object for STREAM type messages to this stream. recipient = models.ForeignKey(Recipient, null=True, on_delete=models.SET_NULL) diff --git a/zerver/tests/test_invite.py b/zerver/tests/test_invite.py index 00a81d9498..4b6b6e6548 100644 --- a/zerver/tests/test_invite.py +++ b/zerver/tests/test_invite.py @@ -109,7 +109,7 @@ class StreamSetupTest(ZulipTestCase): new_user = self.create_simple_new_user(realm, "alice@zulip.com") - with self.assert_database_query_count(13): + with self.assert_database_query_count(14): set_up_streams_and_groups_for_new_human_user( user_profile=new_user, prereg_user=None, @@ -145,7 +145,7 @@ class StreamSetupTest(ZulipTestCase): new_user = self.create_simple_new_user(realm, new_user_email) - with self.assert_database_query_count(16): + with self.assert_database_query_count(17): set_up_streams_and_groups_for_new_human_user( user_profile=new_user, prereg_user=prereg_user, diff --git a/zerver/tests/test_message_send.py b/zerver/tests/test_message_send.py index a11566168b..a938bba4a5 100644 --- a/zerver/tests/test_message_send.py +++ b/zerver/tests/test_message_send.py @@ -47,6 +47,7 @@ from zerver.lib.exceptions import ( from zerver.lib.message import get_raw_unread_data, get_recent_private_conversations from zerver.lib.message_cache import MessageDict from zerver.lib.per_request_cache import flush_per_request_caches +from zerver.lib.stream_subscription import create_stream_subscription from zerver.lib.streams import create_stream_if_needed from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_helpers import ( @@ -1559,11 +1560,7 @@ class StreamMessagesTest(ZulipTestCase): delivery_email=email, long_term_idle=long_term_idle, ) - Subscription.objects.create( - user_profile=user, - is_user_active=user.is_active, - recipient=recipient, - ) + create_stream_subscription(user_profile=user, recipient=recipient, stream=stream) def send_test_message() -> None: message = Message( diff --git a/zerver/tests/test_populate_db.py b/zerver/tests/test_populate_db.py index 605b8338e4..d50c1cd4f8 100644 --- a/zerver/tests/test_populate_db.py +++ b/zerver/tests/test_populate_db.py @@ -1,6 +1,10 @@ +from collections import defaultdict from datetime import timedelta +from zerver.lib.stream_subscription import get_active_subscriptions_for_stream_ids from zerver.lib.test_classes import ZulipTestCase +from zerver.models import Stream +from zerver.models.realms import get_realm from zilencer.management.commands.populate_db import choose_date_sent @@ -30,3 +34,31 @@ class TestUserTimeZones(ZulipTestCase): self.assertEqual(shiva.timezone, "Asia/Kolkata") cordelia = self.example_user("cordelia") self.assertEqual(cordelia.timezone, "UTC") + + +class TestSubscribeUsers(ZulipTestCase): + def test_bulk_create_stream_subscriptions(self) -> None: + """ + This insures bulk_create_stream_subscriptions() ran successfully when test data is loaded via populate_db.py + """ + + realm = get_realm("zulip") + streams = Stream.objects.filter(realm=realm) + active_subscriptions = get_active_subscriptions_for_stream_ids( + {stream.id for stream in streams} + ).select_related("recipient") + + # Map stream_id to its No. active subscriptions. + expected_subscriber_count: dict[int, int] = defaultdict(int) + + for sub in active_subscriptions: + expected_subscriber_count[sub.recipient.type_id] += 1 + + for stream in streams: + self.assertEqual( + stream.subscriber_count, + expected_subscriber_count[stream.id], + msg=f""" + stream of ID ({stream.id}) should have a subscriber_count of {expected_subscriber_count[stream.id]}. + """, + ) diff --git a/zerver/tests/test_signup.py b/zerver/tests/test_signup.py index 08a1a46799..ef48ebca0a 100644 --- a/zerver/tests/test_signup.py +++ b/zerver/tests/test_signup.py @@ -47,7 +47,10 @@ from zerver.lib.mobile_auth_otp import ( ) from zerver.lib.name_restrictions import is_disposable_domain from zerver.lib.send_email import EmailNotDeliveredError, FromAddress, send_future_email -from zerver.lib.stream_subscription import get_stream_subscriptions_for_user +from zerver.lib.stream_subscription import ( + get_stream_subscriptions_for_user, + get_user_subscribed_streams, +) from zerver.lib.streams import create_stream_if_needed from zerver.lib.subdomains import is_root_domain_available from zerver.lib.test_classes import ZulipTestCase @@ -1029,7 +1032,7 @@ class LoginTest(ZulipTestCase): # to sending messages, such as getting the welcome bot, looking up # the alert words for a realm, etc. with ( - self.assert_database_query_count(95), + self.assert_database_query_count(96), self.assert_memcached_count(18), self.captureOnCommitCallbacks(execute=True), ): @@ -2807,6 +2810,59 @@ class UserSignUpTest(ZulipTestCase): result = self.submit_reg_form_for_user(email, password, default_stream_groups=["group 1"]) self.check_user_subscribed_only_to_streams("newguy", default_streams | set(group1_streams)) + def test_signup_stream_subscriber_count(self) -> None: + """ + Verify that signing up successfully increments subscriber_count by 1 + for that new user subscribed streams. + """ + email = "newguy@zulip.com" + password = "newpassword" + realm = get_realm("zulip") + + all_streams_subscriber_count = self.build_streams_subscriber_count( + streams=Stream.objects.all() + ) + + result = self.verify_signup(email=email, password=password, realm=realm) + assert isinstance(result, UserProfile) + + user_profile = result + user_stream_ids = {stream.id for stream in get_user_subscribed_streams(user_profile)} + + streams_subscriber_counts_before = { + stream_id: count + for stream_id, count in all_streams_subscriber_count.items() + if stream_id in user_stream_ids + } + + other_streams_subscriber_counts_before = { + stream_id: count + for stream_id, count in all_streams_subscriber_count.items() + if stream_id not in user_stream_ids + } + + # DB-refresh streams. + streams_subscriber_counts_after = self.fetch_streams_subscriber_count(user_stream_ids) + + # DB-refresh other_streams. + other_streams_subscriber_counts_after = self.fetch_other_streams_subscriber_count( + user_stream_ids + ) + + # Signing up a user should result in subscriber_count + 1 + self.assert_stream_subscriber_count( + streams_subscriber_counts_before, + streams_subscriber_counts_after, + expected_difference=1, + ) + + # Make sure other streams are not affected upon signup. + self.assert_stream_subscriber_count( + other_streams_subscriber_counts_before, + other_streams_subscriber_counts_after, + expected_difference=0, + ) + def test_signup_two_confirmation_links(self) -> None: email = self.nonreg_email("newguy") password = "newpassword" diff --git a/zerver/tests/test_subs.py b/zerver/tests/test_subs.py index 4c0dbfd3b7..1ac4a4e405 100644 --- a/zerver/tests/test_subs.py +++ b/zerver/tests/test_subs.py @@ -3626,7 +3626,7 @@ class StreamAdminTest(ZulipTestCase): those you aren't on. """ result = self.attempt_unsubscribe_of_principal( - query_count=14, + query_count=15, target_users=[self.example_user("cordelia")], is_realm_admin=True, is_subbed=True, @@ -3653,7 +3653,7 @@ class StreamAdminTest(ZulipTestCase): for name in ["cordelia", "prospero", "iago", "hamlet", "outgoing_webhook_bot"] ] result = self.attempt_unsubscribe_of_principal( - query_count=21, + query_count=22, cache_count=13, target_users=target_users, is_realm_admin=True, @@ -3671,7 +3671,7 @@ class StreamAdminTest(ZulipTestCase): are on. """ result = self.attempt_unsubscribe_of_principal( - query_count=17, + query_count=18, target_users=[self.example_user("cordelia")], is_realm_admin=True, is_subbed=True, @@ -3688,7 +3688,7 @@ class StreamAdminTest(ZulipTestCase): streams you aren't on. """ result = self.attempt_unsubscribe_of_principal( - query_count=17, + query_count=18, target_users=[self.example_user("cordelia")], is_realm_admin=True, is_subbed=False, @@ -3714,7 +3714,7 @@ class StreamAdminTest(ZulipTestCase): def test_admin_remove_others_from_stream_legacy_emails(self) -> None: result = self.attempt_unsubscribe_of_principal( - query_count=14, + query_count=15, target_users=[self.example_user("cordelia")], is_realm_admin=True, is_subbed=True, @@ -3728,7 +3728,7 @@ class StreamAdminTest(ZulipTestCase): def test_admin_remove_multiple_users_from_stream_legacy_emails(self) -> None: result = self.attempt_unsubscribe_of_principal( - query_count=16, + query_count=17, target_users=[self.example_user("cordelia"), self.example_user("prospero")], is_realm_admin=True, is_subbed=True, @@ -3742,7 +3742,7 @@ class StreamAdminTest(ZulipTestCase): def test_remove_unsubbed_user_along_with_subbed(self) -> None: result = self.attempt_unsubscribe_of_principal( - query_count=13, + query_count=14, target_users=[self.example_user("cordelia"), self.example_user("iago")], is_realm_admin=True, is_subbed=True, @@ -3775,7 +3775,7 @@ class StreamAdminTest(ZulipTestCase): webhook_bot = self.example_user("webhook_bot") do_change_bot_owner(webhook_bot, bot_owner=user_profile, acting_user=user_profile) result = self.attempt_unsubscribe_of_principal( - query_count=14, + query_count=15, target_users=[webhook_bot], is_realm_admin=False, is_subbed=True, @@ -6261,7 +6261,7 @@ class SubscriptionAPITest(ZulipTestCase): streams_to_sub = ["multi_user_stream"] with ( self.capture_send_event_calls(expected_num_events=5) as events, - self.assert_database_query_count(43), + self.assert_database_query_count(44), ): self.subscribe_via_post( self.test_user, @@ -6287,7 +6287,7 @@ class SubscriptionAPITest(ZulipTestCase): # Now add ourselves with ( self.capture_send_event_calls(expected_num_events=2) as events, - self.assert_database_query_count(19), + self.assert_database_query_count(20), ): self.subscribe_via_post( self.test_user, @@ -6633,7 +6633,7 @@ class SubscriptionAPITest(ZulipTestCase): # Sends 5 peer-remove events, 2 unsubscribe events # and 2 stream delete events for private streams. with ( - self.assert_database_query_count(26), + self.assert_database_query_count(27), self.assert_memcached_count(5), self.capture_send_event_calls(expected_num_events=9) as events, ): @@ -6775,7 +6775,7 @@ class SubscriptionAPITest(ZulipTestCase): # Verify that peer_event events are never sent in Zephyr # realm. This does generate stream creation events from # send_stream_creation_events_for_previously_inaccessible_streams. - with self.assert_database_query_count(num_streams + 18): + with self.assert_database_query_count(num_streams + 19): with self.capture_send_event_calls(expected_num_events=num_streams + 1) as events: self.subscribe_via_post( mit_user, @@ -6856,7 +6856,7 @@ class SubscriptionAPITest(ZulipTestCase): test_user_ids = [user.id for user in test_users] with ( - self.assert_database_query_count(22), + self.assert_database_query_count(23), self.assert_memcached_count(11), mock.patch("zerver.views.streams.send_messages_for_new_subscribers"), ): @@ -7234,7 +7234,7 @@ class SubscriptionAPITest(ZulipTestCase): ] # Test creating a public stream when realm does not have a notification stream. - with self.assert_database_query_count(43): + with self.assert_database_query_count(44): self.subscribe_via_post( self.test_user, [new_streams[0]], @@ -7242,7 +7242,7 @@ class SubscriptionAPITest(ZulipTestCase): ) # Test creating private stream. - with self.assert_database_query_count(51): + with self.assert_database_query_count(52): self.subscribe_via_post( self.test_user, [new_streams[1]], @@ -7254,7 +7254,7 @@ class SubscriptionAPITest(ZulipTestCase): new_stream_announcements_stream = get_stream(self.streams[0], self.test_realm) self.test_realm.new_stream_announcements_stream_id = new_stream_announcements_stream.id self.test_realm.save() - with self.assert_database_query_count(55): + with self.assert_database_query_count(56): self.subscribe_via_post( self.test_user, [new_streams[2]], @@ -7264,6 +7264,167 @@ class SubscriptionAPITest(ZulipTestCase): ), ) + def test_stream_subscriber_count_upon_bulk_subscription(self) -> None: + """ + Test subscriber_count increases for the correct streams + upon bulk subscription. + + We use the api here as we want this to be end-to-end. + """ + + stream_names = [f"stream_{i}" for i in range(10)] + stream_ids = {self.make_stream(stream_name).id for stream_name in stream_names} + + desdemona = self.example_user("desdemona") + self.login_user(desdemona) + + user_ids = [ + desdemona.id, + self.example_user("cordelia").id, + self.example_user("hamlet").id, + self.example_user("othello").id, + self.example_user("iago").id, + self.example_user("prospero").id, + ] + + streams_subscriber_counts_before_subscribe = self.fetch_streams_subscriber_count(stream_ids) + other_streams_subscriber_counts_before_subscribe = ( + self.fetch_other_streams_subscriber_count(stream_ids) + ) + + # Subscribe users to the streams. + self.subscribe_via_post( + desdemona, + stream_names, + dict(principals=orjson.dumps(user_ids).decode()), + ) + + # DB-refresh streams. + streams_subscriber_counts_after_subscribe = self.fetch_streams_subscriber_count(stream_ids) + # DB-refresh other streams. + other_streams_subscriber_counts_after_subscribe = self.fetch_other_streams_subscriber_count( + stream_ids + ) + + # Ensure an increase in subscriber_count + self.assert_stream_subscriber_count( + streams_subscriber_counts_before_subscribe, + streams_subscriber_counts_after_subscribe, + expected_difference=len(user_ids), + ) + + # Make sure other streams are not affected. + self.assert_stream_subscriber_count( + other_streams_subscriber_counts_before_subscribe, + other_streams_subscriber_counts_after_subscribe, + expected_difference=0, + ) + + # Re-subscribe same users to the same streams. + self.subscribe_via_post( + desdemona, + stream_names, + dict(principals=orjson.dumps(user_ids).decode()), + ) + # DB-refresh streams. + streams_subscriber_counts_after_resubscribe = self.fetch_streams_subscriber_count( + stream_ids + ) + # Ensure Idempotency; subscribing "already" subscribed users shouldn't change subscriber_count. + self.assert_stream_subscriber_count( + streams_subscriber_counts_after_subscribe, + streams_subscriber_counts_after_resubscribe, + expected_difference=0, + ) + + def test_stream_subscriber_count_upon_bulk_unsubscription(self) -> None: + """ + Test subscriber_count decreases for the correct streams + upon bulk un-subscription. + + We use the api here as we want this to be end-to-end. + """ + + stream_names = [f"stream_{i}" for i in range(10)] + stream_ids = {self.make_stream(stream_name).id for stream_name in stream_names} + + desdemona = self.example_user("desdemona") + self.login_user(desdemona) + + user_ids = [ + desdemona.id, + self.example_user("cordelia").id, + self.example_user("hamlet").id, + self.example_user("othello").id, + self.example_user("iago").id, + self.example_user("prospero").id, + ] + + # Subscribe users to the streams. + self.subscribe_via_post( + desdemona, + stream_names, + dict(principals=orjson.dumps(user_ids).decode()), + ) + + streams_subscriber_counts_before_unsubscribe = self.fetch_streams_subscriber_count( + stream_ids + ) + other_streams_subscriber_counts_before_unsubscribe = ( + self.fetch_other_streams_subscriber_count(stream_ids) + ) + + # Unsubscribe users from the same streams. + self.client_delete( + "/json/users/me/subscriptions", + { + "subscriptions": orjson.dumps(stream_names).decode(), + "principals": orjson.dumps(user_ids).decode(), + }, + ) + + # DB-refresh streams. + streams_subscriber_counts_after_unsubscribe = self.fetch_streams_subscriber_count( + stream_ids + ) + # DB-refresh other streams. + other_streams_subscriber_counts_after_unsubscribe = ( + self.fetch_other_streams_subscriber_count(stream_ids) + ) + + # Ensure a decrease in subscriber_count + self.assert_stream_subscriber_count( + streams_subscriber_counts_before_unsubscribe, + streams_subscriber_counts_after_unsubscribe, + expected_difference=-len(user_ids), + ) + + # Make sure other streams are not affected. + self.assert_stream_subscriber_count( + other_streams_subscriber_counts_before_unsubscribe, + other_streams_subscriber_counts_after_unsubscribe, + expected_difference=0, + ) + + # Re-Unsubscribe users from the same streams. + self.client_delete( + "/json/users/me/subscriptions", + { + "subscriptions": orjson.dumps(stream_names).decode(), + "principals": orjson.dumps(user_ids).decode(), + }, + ) + # DB-refresh streams. + streams_subscriber_counts_after_reunsubscribe = self.fetch_streams_subscriber_count( + stream_ids + ) + # Ensure Idempotency; unsubscribing "already" non-subscribed users shouldn't change subscriber_count. + self.assert_stream_subscriber_count( + streams_subscriber_counts_after_unsubscribe, + streams_subscriber_counts_after_reunsubscribe, + expected_difference=0, + ) + class GetStreamsTest(ZulipTestCase): def test_streams_api_for_bot_owners(self) -> None: @@ -8039,7 +8200,7 @@ class GetSubscribersTest(ZulipTestCase): polonius.id, ] - with self.assert_database_query_count(51): + with self.assert_database_query_count(52): self.subscribe_via_post( self.user_profile, stream_names, diff --git a/zerver/tests/test_users.py b/zerver/tests/test_users.py index f9e476057b..63b7afc252 100644 --- a/zerver/tests/test_users.py +++ b/zerver/tests/test_users.py @@ -37,6 +37,7 @@ from zerver.lib.create_user import copy_default_settings from zerver.lib.events import do_events_register from zerver.lib.exceptions import JsonableError from zerver.lib.send_email import clear_scheduled_emails, queue_scheduled_emails, send_future_email +from zerver.lib.stream_subscription import get_user_subscribed_streams from zerver.lib.stream_topic import StreamTopicTarget from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_helpers import ( @@ -982,6 +983,8 @@ class QueryCountTest(ZulipTestCase): ] streams = [get_stream(stream_name, realm) for stream_name in stream_names] + subscriber_count_before = self.build_streams_subscriber_count(streams) + invite_expires_in_minutes = 4 * 24 * 60 with self.captureOnCommitCallbacks(execute=True): do_invite_users( @@ -995,7 +998,7 @@ class QueryCountTest(ZulipTestCase): prereg_user = PreregistrationUser.objects.get(email="fred@zulip.com") with ( - self.assert_database_query_count(85), + self.assert_database_query_count(86), self.assert_memcached_count(23), self.capture_send_event_calls(expected_num_events=11) as events, ): @@ -1021,6 +1024,17 @@ class QueryCountTest(ZulipTestCase): notifications, {"private_stream1", "private_stream2", "Verona", "Denmark,Scotland"} ) + # DB-refresh streams + subscriber_count_after = self.fetch_streams_subscriber_count( + stream_ids=set(subscriber_count_before) + ) + + self.assert_stream_subscriber_count( + subscriber_count_before, + subscriber_count_after, + expected_difference=1, + ) + class BulkCreateUserTest(ZulipTestCase): def test_create_users(self) -> None: @@ -1811,6 +1825,95 @@ class ActivateTest(ZulipTestCase): user = self.example_user("hamlet") self.assertTrue(user.is_active) + def test_stream_subscriber_count_upon_deactivate(self) -> None: + # Test subscriber_count decrements upon deactivating a user. + # We use the api here as we want this to be end-to-end. + + admin = self.example_user("othello") + do_change_user_role(admin, UserProfile.ROLE_REALM_ADMINISTRATOR, acting_user=None) + self.login("othello") + user = self.example_user("hamlet") + + streams_subscriber_counts_before = self.build_streams_subscriber_count( + streams=get_user_subscribed_streams(user) + ) + stream_ids = set(streams_subscriber_counts_before) + other_streams_subscriber_counts_before = self.fetch_other_streams_subscriber_count( + stream_ids + ) + + result = self.client_delete(f"/json/users/{user.id}") + self.assert_json_success(result) + + # DB-refresh streams. + streams_subscriber_counts_after = self.fetch_streams_subscriber_count(stream_ids) + + # DB-refresh other_streams. + other_streams_subscriber_counts_after = self.fetch_other_streams_subscriber_count( + stream_ids + ) + + # Deactivating a user should result in subscriber_count - 1 + self.assert_stream_subscriber_count( + streams_subscriber_counts_before, + streams_subscriber_counts_after, + expected_difference=-1, + ) + + # Make sure other streams are not affected upon deactivation. + self.assert_stream_subscriber_count( + other_streams_subscriber_counts_before, + other_streams_subscriber_counts_after, + expected_difference=0, + ) + + def test_stream_subscriber_count_upon_reactivate(self) -> None: + # Test subscriber_count increments upon reactivating a user. + # We use the api here as we want this to be end-to-end. + + admin = self.example_user("othello") + do_change_user_role(admin, UserProfile.ROLE_REALM_ADMINISTRATOR, acting_user=None) + self.login("othello") + user = self.example_user("hamlet") + + # First, deactivate that user + result = self.client_delete(f"/json/users/{user.id}") + self.assert_json_success(result) + + streams_subscriber_counts_before = self.build_streams_subscriber_count( + streams=get_user_subscribed_streams(user) + ) + stream_ids = set(streams_subscriber_counts_before) + other_streams_subscriber_counts_before = self.fetch_other_streams_subscriber_count( + stream_ids + ) + + # Reactivate user + result = self.client_post(f"/json/users/{user.id}/reactivate") + self.assert_json_success(result) + + # DB-refresh streams. + streams_subscriber_counts_after = self.fetch_streams_subscriber_count(stream_ids) + + # DB-refresh other_streams. + other_streams_subscriber_counts_after = self.fetch_other_streams_subscriber_count( + stream_ids + ) + + # Reactivating a user should result in subscriber_count + 1 + self.assert_stream_subscriber_count( + streams_subscriber_counts_before, + streams_subscriber_counts_after, + expected_difference=1, + ) + + # Make sure other streams are not affected upon reactivation. + self.assert_stream_subscriber_count( + other_streams_subscriber_counts_before, + other_streams_subscriber_counts_after, + expected_difference=0, + ) + def test_email_sent(self) -> None: self.login("iago") user = self.example_user("hamlet") diff --git a/zilencer/management/commands/populate_db.py b/zilencer/management/commands/populate_db.py index 4696756c07..b228b5fb9d 100644 --- a/zilencer/management/commands/populate_db.py +++ b/zilencer/management/commands/populate_db.py @@ -46,6 +46,7 @@ from zerver.lib.remote_server import get_realms_info_for_push_bouncer from zerver.lib.server_initialization import create_internal_realm, create_users from zerver.lib.storage import static_path from zerver.lib.stream_color import STREAM_ASSIGNMENT_COLORS +from zerver.lib.stream_subscription import bulk_create_stream_subscriptions from zerver.lib.types import AnalyticsDataUploadLevel, ProfileFieldData from zerver.lib.users import add_service from zerver.lib.utils import generate_api_key @@ -156,6 +157,7 @@ def subscribe_users_to_streams(realm: Realm, stream_dict: dict[str, dict[str, An subscriptions_to_add = [] event_time = timezone_now() all_subscription_logs = [] + subscriber_count_changes: dict[int, set[int]] = defaultdict(set) profiles = UserProfile.objects.select_related("realm").filter(realm=realm) for i, stream_name in enumerate(stream_dict): stream = Stream.objects.get(name=stream_name, realm=realm) @@ -169,6 +171,8 @@ def subscribe_users_to_streams(realm: Realm, stream_dict: dict[str, dict[str, An color=STREAM_ASSIGNMENT_COLORS[i % len(STREAM_ASSIGNMENT_COLORS)], ) subscriptions_to_add.append(s) + if profile.is_active: + subscriber_count_changes[stream.id].add(profile.id) log = RealmAuditLog( realm=profile.realm, @@ -179,7 +183,7 @@ def subscribe_users_to_streams(realm: Realm, stream_dict: dict[str, dict[str, An event_time=event_time, ) all_subscription_logs.append(log) - Subscription.objects.bulk_create(subscriptions_to_add) + bulk_create_stream_subscriptions(subs=subscriptions_to_add, streams=subscriber_count_changes) RealmAuditLog.objects.bulk_create(all_subscription_logs) @@ -727,6 +731,7 @@ class Command(ZulipBaseCommand): subscriptions_list.append((profile, r)) subscriptions_to_add: list[Subscription] = [] + subscriber_count_changes: dict[int, set[int]] = defaultdict(set) event_time = timezone_now() all_subscription_logs: list[RealmAuditLog] = [] @@ -742,6 +747,8 @@ class Command(ZulipBaseCommand): ) subscriptions_to_add.append(s) + if profile.is_active: + subscriber_count_changes[recipient.type_id].add(profile.id) log = RealmAuditLog( realm=profile.realm, @@ -753,7 +760,9 @@ class Command(ZulipBaseCommand): ) all_subscription_logs.append(log) - Subscription.objects.bulk_create(subscriptions_to_add) + bulk_create_stream_subscriptions( + subs=subscriptions_to_add, streams=subscriber_count_changes + ) RealmAuditLog.objects.bulk_create(all_subscription_logs) # Create custom profile field data