stream: Add subscriber_count field.

Fixes #34246.

Add subscriber_count field to Stream model to track number of
non-deactivated users subscribed to the channel.
This commit is contained in:
bedo
2025-05-10 01:02:03 +03:00
committed by Tim Abbott
parent 54702ba2a0
commit c04558fe31
15 changed files with 628 additions and 37 deletions

View File

@@ -5,7 +5,7 @@ import re
import shutil
import subprocess
import tempfile
from collections.abc import Callable, Collection, Iterator, Mapping, Sequence
from collections.abc import Callable, Collection, Iterable, Iterator, Mapping, Sequence
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Union, cast
from unittest import TestResult, mock, skipUnless
@@ -1453,6 +1453,31 @@ Output:
self.assertEqual(stream.recipient_id, message.recipient_id)
self.assertEqual(stream.name, stream_name)
def assert_stream_subscriber_count(
self,
counts_before: dict[int, int],
counts_after: dict[int, int],
expected_difference: int,
) -> None:
# Normally they should always be equal,
# but just in case this was called in some test where user/s streams have changed
# and we forgot to update streams,
# so this assertion catches that.
self.assertEqual(
set(counts_before),
set(counts_after),
msg="Different streams! You should compare subscriber_count for the same streams.",
)
for stream_id, count_before in counts_before.items():
self.assertEqual(
count_before + expected_difference,
counts_after[stream_id],
msg=f"""
stream of ID ({stream_id}) should have a subscriber_count of {count_before + expected_difference}.
""",
)
def webhook_fixture_data(self, type: str, action: str, file_type: str = "json") -> str:
fn = os.path.join(
os.path.dirname(__file__),
@@ -2242,6 +2267,20 @@ class ZulipTestCase(ZulipTestCaseMixin, TestCase):
with self.captureOnCommitCallbacks(execute=True):
handle_missedmessage_emails(user_profile_id, message_ids)
def build_streams_subscriber_count(self, streams: Iterable[Stream]) -> dict[int, int]:
"""
Callers MUST pass a new db-fetched version of streams each time.
"""
return {stream.id: stream.subscriber_count for stream in streams}
def fetch_streams_subscriber_count(self, stream_ids: set[int]) -> dict[int, int]:
return self.build_streams_subscriber_count(streams=Stream.objects.filter(id__in=stream_ids))
def fetch_other_streams_subscriber_count(self, stream_ids: set[int]) -> dict[int, int]:
return self.build_streams_subscriber_count(
streams=Stream.objects.exclude(id__in=stream_ids)
)
def get_row_ids_in_all_tables() -> Iterator[tuple[str, set[int]]]:
all_models = apps.get_models(include_auto_created=True)