diff --git a/zerver/lib/actions.py b/zerver/lib/actions.py index ca2e361443..deda26916a 100644 --- a/zerver/lib/actions.py +++ b/zerver/lib/actions.py @@ -97,7 +97,7 @@ from zerver.models import Realm, RealmEmoji, Stream, UserProfile, UserActivity, ScheduledEmail, MAX_TOPIC_NAME_LENGTH, \ MAX_MESSAGE_LENGTH, get_client, get_stream, get_personal_recipient, \ get_user_profile_by_id, PreregistrationUser, \ - bulk_get_recipients, get_stream_recipient, get_stream_recipients, \ + get_stream_recipient, get_stream_recipients, \ email_allowed_for_realm, email_to_username, \ get_user_by_delivery_email, get_stream_cache_key, active_non_guest_user_ids, \ UserActivityInterval, active_user_ids, get_active_streams, \ @@ -2869,12 +2869,12 @@ def bulk_add_subscriptions(streams: Iterable[Stream], acting_user: Optional[UserProfile]=None) -> SubT: users = list(users) - recipients_map = bulk_get_recipients(Recipient.STREAM, [stream.id for stream in streams]) # type: Mapping[int, Recipient] - recipients = [recipient.id for recipient in recipients_map.values()] # type: List[int] + recipients_map = dict((stream.id, stream.recipient_id) for stream in streams) # type: Dict[int, int] + recipient_ids = [recipient_id for recipient_id in recipients_map.values()] # type: List[int] stream_map = {} # type: Dict[int, Stream] for stream in streams: - stream_map[recipients_map[stream.id].id] = stream + stream_map[recipients_map[stream.id]] = stream subs_by_user = defaultdict(list) # type: Dict[int, List[Subscription]] all_subs_query = get_stream_subscriptions_for_users(users).select_related('user_profile') @@ -2887,7 +2887,7 @@ def bulk_add_subscriptions(streams: Iterable[Stream], subs_to_activate = [] # type: List[Tuple[Subscription, Stream]] new_subs = [] # type: List[Tuple[UserProfile, int, Stream]] for user_profile in users: - needs_new_sub = set(recipients) # type: Set[int] + needs_new_sub = set(recipient_ids) # type: Set[int] for sub in subs_by_user[user_profile.id]: if sub.recipient_id in needs_new_sub: needs_new_sub.remove(sub.recipient_id) diff --git a/zerver/lib/streams.py b/zerver/lib/streams.py index 95acbccd39..dbd5ce1e5f 100644 --- a/zerver/lib/streams.py +++ b/zerver/lib/streams.py @@ -5,7 +5,7 @@ from django.utils.translation import ugettext as _ from zerver.lib.actions import check_stream_name, create_streams_if_needed from zerver.lib.request import JsonableError from zerver.models import UserProfile, Stream, Subscription, \ - Realm, Recipient, bulk_get_recipients, get_stream, \ + Realm, Recipient, get_stream, \ bulk_get_streams, get_realm_stream, DefaultStreamGroup, get_stream_by_id_in_realm from django.db.models.query import QuerySet @@ -202,9 +202,9 @@ def can_access_stream_history_by_id(user_profile: UserProfile, stream_id: int) - def filter_stream_authorization(user_profile: UserProfile, streams: Iterable[Stream]) -> Tuple[List[Stream], List[Stream]]: streams_subscribed = set() # type: Set[int] - recipients_map = bulk_get_recipients(Recipient.STREAM, [stream.id for stream in streams]) + recipient_ids = [stream.recipient_id for stream in streams] subs = Subscription.objects.filter(user_profile=user_profile, - recipient__in=list(recipients_map.values()), + recipient_id__in=recipient_ids, active=True) for sub in subs: diff --git a/zerver/models.py b/zerver/models.py index 2f410db91d..0ab860ff2f 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -1565,26 +1565,7 @@ def bulk_get_huddle_user_ids(recipients: List[Recipient]) -> Dict[int, List[int] return result_dict -def bulk_get_recipients(type: int, type_ids: List[int]) -> Dict[int, Any]: - def cache_key_function(type_id: int) -> str: - return get_recipient_cache_key(type, type_id) - - def query_function(type_ids: List[int]) -> Sequence[Recipient]: - # TODO: Change return type to QuerySet[Recipient] - return Recipient.objects.filter(type=type, type_id__in=type_ids) - - def recipient_to_type_id(recipient: Recipient) -> int: - return recipient.type_id - - return generic_bulk_cached_fetch(cache_key_function, query_function, type_ids, - id_fetcher=recipient_to_type_id) - def get_stream_recipients(stream_ids: List[int]) -> List[Recipient]: - - ''' - We could call bulk_get_recipients(...).values() here, but it actually - leads to an extra query in test mode. - ''' return Recipient.objects.filter( type=Recipient.STREAM, type_id__in=stream_ids, diff --git a/zerver/tests/test_signup.py b/zerver/tests/test_signup.py index 55208bad97..0da5a248b7 100644 --- a/zerver/tests/test_signup.py +++ b/zerver/tests/test_signup.py @@ -517,7 +517,7 @@ class LoginTest(ZulipTestCase): with queries_captured() as queries: self.register(self.nonreg_email('test'), "test") # Ensure the number of queries we make is not O(streams) - self.assertEqual(len(queries), 80) + self.assertEqual(len(queries), 79) user_profile = self.nonreg_user('test') self.assert_logged_in_user_id(user_profile.id) self.assertFalse(user_profile.enable_stream_desktop_notifications) diff --git a/zerver/tests/test_subs.py b/zerver/tests/test_subs.py index 5832353a9f..29c051e34d 100644 --- a/zerver/tests/test_subs.py +++ b/zerver/tests/test_subs.py @@ -2434,7 +2434,7 @@ class SubscriptionAPITest(ZulipTestCase): streams_to_sub, dict(principals=ujson.dumps([user1.email, user2.email])), ) - self.assert_length(queries, 44) + self.assert_length(queries, 43) self.assert_length(events, 7) for ev in [x for x in events if x['event']['type'] not in ('message', 'stream')]: @@ -2462,7 +2462,7 @@ class SubscriptionAPITest(ZulipTestCase): streams_to_sub, dict(principals=ujson.dumps([self.test_email])), ) - self.assert_length(queries, 16) + self.assert_length(queries, 14) self.assert_length(events, 2) add_event, add_peer_event = events @@ -2759,7 +2759,7 @@ class SubscriptionAPITest(ZulipTestCase): # Make sure Zephyr mirroring realms such as MIT do not get # any tornado subscription events self.assert_length(events, 0) - self.assert_length(queries, 9) + self.assert_length(queries, 8) events = [] with tornado_redirected_to_list(events): @@ -2785,7 +2785,7 @@ class SubscriptionAPITest(ZulipTestCase): dict(principals=ujson.dumps([self.test_email])), ) # Make sure we don't make O(streams) queries - self.assert_length(queries, 21) + self.assert_length(queries, 19) def test_subscriptions_add_for_principal(self) -> None: """ @@ -3174,7 +3174,7 @@ class SubscriptionAPITest(ZulipTestCase): [new_streams[0]], dict(principals=ujson.dumps([user1.email, user2.email])), ) - self.assert_length(queries, 44) + self.assert_length(queries, 43) # Test creating private stream. with queries_captured() as queries: @@ -3184,7 +3184,7 @@ class SubscriptionAPITest(ZulipTestCase): dict(principals=ujson.dumps([user1.email, user2.email])), invite_only=True, ) - self.assert_length(queries, 39) + self.assert_length(queries, 38) # Test creating a public stream with announce when realm has a notification stream. notifications_stream = get_stream(self.streams[0], self.test_realm) @@ -3199,7 +3199,7 @@ class SubscriptionAPITest(ZulipTestCase): principals=ujson.dumps([user1.email, user2.email]) ) ) - self.assert_length(queries, 53) + self.assert_length(queries, 52) class GetBotOwnerStreamsTest(ZulipTestCase): def test_streams_api_for_bot_owners(self) -> None: diff --git a/zerver/views/messages.py b/zerver/views/messages.py index 83f696bdf5..8aea250570 100644 --- a/zerver/views/messages.py +++ b/zerver/views/messages.py @@ -51,7 +51,7 @@ from zerver.lib.validator import \ check_string_or_int_list, check_string_or_int from zerver.lib.zephyr import compute_mit_user_fullname from zerver.models import Message, UserProfile, Stream, Subscription, Client,\ - Realm, RealmDomain, Recipient, UserMessage, bulk_get_recipients, \ + Realm, RealmDomain, Recipient, UserMessage, \ email_to_domain, get_realm, get_active_streams, get_user_including_cross_realm, \ get_user_by_id_in_realm_including_cross_realm @@ -255,9 +255,8 @@ class NarrowBuilder: matching_streams = get_active_streams(self.user_profile.realm).filter( name__iregex=r'^(un)*%s(\.d)*$' % (self._pg_re_escape(base_stream_name),)) - matching_stream_ids = [matching_stream.id for matching_stream in matching_streams] - recipients_map = bulk_get_recipients(Recipient.STREAM, matching_stream_ids) - cond = column("recipient_id").in_([recipient.id for recipient in recipients_map.values()]) + recipient_ids = [matching_stream.recipient_id for matching_stream in matching_streams] + cond = column("recipient_id").in_(recipient_ids) return query.where(maybe_negate(cond)) recipient = stream.recipient