diff --git a/zerver/lib/actions.py b/zerver/lib/actions.py index 7f12a1724c..e75369a697 100644 --- a/zerver/lib/actions.py +++ b/zerver/lib/actions.py @@ -1432,9 +1432,17 @@ def do_send_messages(messages_maybe_none: Sequence[Optional[MutableMapping[str, do_widget_post_save_actions(message) for message in messages: + realm_id: Optional[int] = None + if message['message'].is_stream_message(): + if message['stream'] is None: + stream_id = message['message'].recipient.type_id + message['stream'] = Stream.objects.select_related().get(id=stream_id) + assert message['stream'] is not None # assert needed because stubs for django are missing + realm_id = message['stream'].realm_id + # Deliver events to the real-time push system, as well as # enqueuing any additional processing triggered by the message. - wide_message_dict = MessageDict.wide_dict(message['message']) + wide_message_dict = MessageDict.wide_dict(message['message'], realm_id) user_flags = user_message_flags.get(message['message'].id, {}) sender = message['message'].sender @@ -1487,9 +1495,6 @@ def do_send_messages(messages_maybe_none: Sequence[Optional[MutableMapping[str, # notify new_message request if it's a public stream, # ensuring that in the tornado server, non-public stream # messages are only associated to their subscribed users. - if message['stream'] is None: - stream_id = message['message'].recipient.type_id - message['stream'] = Stream.objects.select_related().get(id=stream_id) assert message['stream'] is not None # assert needed because stubs for django are missing if message['stream'].is_public(): event['realm_id'] = message['stream'].realm_id @@ -4266,12 +4271,12 @@ def update_user_message_flags(message: Message, ums: Iterable[UserMessage]) -> N for um in changed_ums: um.save(update_fields=['flags']) -def update_to_dict_cache(changed_messages: List[Message]) -> List[int]: +def update_to_dict_cache(changed_messages: List[Message], realm_id: Optional[int]=None) -> List[int]: """Updates the message as stored in the to_dict cache (for serving messages).""" items_for_remote_cache = {} message_ids = [] - changed_messages_to_dict = MessageDict.to_dict_uncached(changed_messages) + changed_messages_to_dict = MessageDict.to_dict_uncached(changed_messages, realm_id) for msg_id, msg in changed_messages_to_dict.items(): message_ids.append(msg_id) key = to_dict_cache_key_id(msg_id) @@ -4472,7 +4477,11 @@ def do_update_message(user_profile: UserProfile, message: Message, # This does message.save(update_fields=[...]) save_message_for_edit_use_case(message=message) - event['message_ids'] = update_to_dict_cache(changed_messages) + realm_id: Optional[int] = None + if stream_being_edited is not None: + realm_id = stream_being_edited.realm_id + + event['message_ids'] = update_to_dict_cache(changed_messages, realm_id) def user_info(um: UserMessage) -> Dict[str, Any]: return { diff --git a/zerver/lib/cache.py b/zerver/lib/cache.py index 2386530411..d634363454 100644 --- a/zerver/lib/cache.py +++ b/zerver/lib/cache.py @@ -606,7 +606,7 @@ def flush_used_upload_space_cache(sender: Any, **kwargs: Any) -> None: def to_dict_cache_key_id(message_id: int) -> str: return 'message_dict:%d' % (message_id,) -def to_dict_cache_key(message: 'Message') -> str: +def to_dict_cache_key(message: 'Message', realm_id: Optional[int]=None) -> str: return to_dict_cache_key_id(message.id) def open_graph_description_cache_key(content: Any, request: HttpRequest) -> str: diff --git a/zerver/lib/message.py b/zerver/lib/message.py index a419063028..e241dd90f1 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -168,8 +168,8 @@ def stringify_message_dict(message_dict: Dict[str, Any]) -> bytes: return zlib.compress(ujson.dumps(message_dict).encode()) @cache_with_key(to_dict_cache_key, timeout=3600*24) -def message_to_dict_json(message: Message) -> bytes: - return MessageDict.to_dict_uncached([message])[message.id] +def message_to_dict_json(message: Message, realm_id: Optional[int]=None) -> bytes: + return MessageDict.to_dict_uncached([message], realm_id)[message.id] def save_message_rendered_content(message: Message, content: str) -> str: rendered_content = render_markdown(message, content, realm=message.get_realm()) @@ -180,13 +180,13 @@ def save_message_rendered_content(message: Message, content: str) -> str: class MessageDict: @staticmethod - def wide_dict(message: Message) -> Dict[str, Any]: + def wide_dict(message: Message, realm_id: Optional[int]=None) -> Dict[str, Any]: ''' The next two lines get the cacheable field related to our message object, with the side effect of populating the cache. ''' - json = message_to_dict_json(message) + json = message_to_dict_json(message, realm_id) obj = extract_message_dict(json) ''' @@ -270,23 +270,23 @@ class MessageDict: return sew_messages_and_reactions(messages, reactions) @staticmethod - def to_dict_uncached(messages: List[Message]) -> Dict[int, bytes]: - messages_dict = MessageDict.to_dict_uncached_helper(messages) + def to_dict_uncached(messages: List[Message], realm_id: Optional[int]=None) -> Dict[int, bytes]: + messages_dict = MessageDict.to_dict_uncached_helper(messages, realm_id) encoded_messages = {msg['id']: stringify_message_dict(msg) for msg in messages_dict} return encoded_messages @staticmethod - def to_dict_uncached_helper(messages: List[Message]) -> List[Dict[str, Any]]: + def to_dict_uncached_helper(messages: List[Message], + realm_id: Optional[int]=None) -> List[Dict[str, Any]]: # Near duplicate of the build_message_dict + get_raw_db_rows # code path that accepts already fetched Message objects # rather than message IDs. - # TODO: We could potentially avoid this database query in - # common cases by optionally passing through the - # stream_realm_id through the code path from do_send_messages - # (where we've already fetched the data). It would involve - # somewhat messy plumbing, but would probably be worth it. def get_rendering_realm_id(message: Message) -> int: + # realm_id can differ among users, currently only possible + # with cross realm bots. + if realm_id is not None: + return realm_id if message.recipient.type == Recipient.STREAM: return Stream.objects.get(id=message.recipient.type_id).realm_id return message.sender.realm_id diff --git a/zerver/tests/test_messages.py b/zerver/tests/test_messages.py index 2760fae945..5ae5760a21 100644 --- a/zerver/tests/test_messages.py +++ b/zerver/tests/test_messages.py @@ -1031,7 +1031,7 @@ class StreamMessagesTest(ZulipTestCase): body=content, ) - self.assert_length(queries, 15) + self.assert_length(queries, 14) def test_stream_message_dict(self) -> None: user_profile = self.example_user('iago') @@ -2675,22 +2675,32 @@ class EditMessageTest(ZulipTestCase): self.login_user(user) stream_name = "public_stream" self.subscribe(user, stream_name) - message_one_id = self.send_stream_message(user, - stream_name, "Message one") - later_subscribed_user = self.example_user("cordelia") - self.subscribe(later_subscribed_user, stream_name) - message_two_id = self.send_stream_message(user, - stream_name, "Message two") - message_ids = [message_one_id, message_two_id] + message_ids = [] + message_ids.append(self.send_stream_message(user, + stream_name, "Message one")) + user_2 = self.example_user("cordelia") + self.subscribe(user_2, stream_name) + message_ids.append(self.send_stream_message(user_2, + stream_name, "Message two")) + self.subscribe(self.notification_bot(), stream_name) + message_ids.append(self.send_stream_message(self.notification_bot(), + stream_name, "Message three")) messages = [Message.objects.select_related().get(id=message_id) for message_id in message_ids] # Check number of queries performed with queries_captured() as queries: MessageDict.to_dict_uncached(messages) - # 1 query for realm_id per message = 2 + # 1 query for realm_id per message = 3 # 1 query each for reactions & submessage for all messages = 2 - self.assertEqual(len(queries), 4) + self.assertEqual(len(queries), 5) + + realm_id = 2 # Fetched from stream object + # Check number of queries performed with realm_id + with queries_captured() as queries: + MessageDict.to_dict_uncached(messages, realm_id) + # 1 query each for reactions & submessage for all messages = 2 + self.assertEqual(len(queries), 2) def test_save_message(self) -> None: """This is also tested by a client test, but here we can verify @@ -3519,7 +3529,7 @@ class EditMessageTest(ZulipTestCase): 'propagate_mode': 'change_all', 'topic': 'new topic' }) - self.assertEqual(len(queries), 54) + self.assertEqual(len(queries), 49) messages = get_topic_messages(user_profile, old_stream, "test") self.assertEqual(len(messages), 1) diff --git a/zerver/tests/test_subs.py b/zerver/tests/test_subs.py index 8f666f906a..a036df03e1 100644 --- a/zerver/tests/test_subs.py +++ b/zerver/tests/test_subs.py @@ -2509,7 +2509,7 @@ class SubscriptionAPITest(ZulipTestCase): streams_to_sub, dict(principals=ujson.dumps([user1.id, user2.id])), ) - self.assert_length(queries, 40) + self.assert_length(queries, 39) self.assert_length(events, 7) for ev in [x for x in events if x['event']['type'] not in ('message', 'stream')]: @@ -3284,7 +3284,7 @@ class SubscriptionAPITest(ZulipTestCase): [new_streams[0]], dict(principals=ujson.dumps([user1.id, user2.id])), ) - self.assert_length(queries, 40) + self.assert_length(queries, 39) # Test creating private stream. with queries_captured() as queries: @@ -3294,7 +3294,7 @@ class SubscriptionAPITest(ZulipTestCase): dict(principals=ujson.dumps([user1.id, user2.id])), invite_only=True, ) - self.assert_length(queries, 40) + self.assert_length(queries, 39) # Test creating a public stream with announce when realm has a notification stream. notifications_stream = get_stream(self.streams[0], self.test_realm) @@ -3309,7 +3309,7 @@ class SubscriptionAPITest(ZulipTestCase): principals=ujson.dumps([user1.id, user2.id]) ) ) - self.assert_length(queries, 52) + self.assert_length(queries, 50) class GetStreamsTest(ZulipTestCase): def test_streams_api_for_bot_owners(self) -> None: