diff --git a/zerver/actions/message_flags.py b/zerver/actions/message_flags.py index 90e94bdb20..b18646774f 100644 --- a/zerver/actions/message_flags.py +++ b/zerver/actions/message_flags.py @@ -288,16 +288,31 @@ def do_update_message_flags( query = UserMessage.select_for_update_query().filter( user_profile=user_profile, message_id__in=messages ) + um_message_ids = {um.message_id for um in query} - historical_message_ids = list(set(messages) - um_message_ids) + if flag == "read" and operation == "add": + # When marking messages as read, creating "historical" + # UserMessage rows would be a waste of storage, because + # `flags.read | flags.historical` is exactly the flags we + # simulate when processing a message for which a user has + # access but no UserMessage row. + messages = [message_id for message_id in messages if message_id in um_message_ids] + else: + # Users can mutate flags for messages that don't have a + # UserMessage yet. Validate that the user is even allowed + # to access these message_ids; if so, we will create + # "historical" UserMessage rows for the messages in question. + # + # See create_historical_user_messages for a more detailed + # explanation. + historical_message_ids = list(set(messages) - um_message_ids) - # Users can mutate flags for messages that don't have a UserMessage yet. - # First, validate that the user is even allowed to access these message_ids. - for message_id in historical_message_ids: - access_message(user_profile, message_id) + for message_id in historical_message_ids: + access_message(user_profile, message_id) - # And then create historical UserMessage records. See the called function for more context. - create_historical_user_messages(user_id=user_profile.id, message_ids=historical_message_ids) + create_historical_user_messages( + user_id=user_profile.id, message_ids=historical_message_ids + ) if operation == "add": count = query.update(flags=F("flags").bitor(flagattr)) diff --git a/zerver/tests/test_message_flags.py b/zerver/tests/test_message_flags.py index 726cf7e3c5..f7a6b5129f 100644 --- a/zerver/tests/test_message_flags.py +++ b/zerver/tests/test_message_flags.py @@ -1791,10 +1791,19 @@ class MarkUnreadTest(ZulipTestCase): """ sender = self.example_user("cordelia") receiver = self.example_user("hamlet") - stream_name = "Denmark" - self.subscribe(receiver, stream_name) - self.subscribe(sender, stream_name) + stream_name = "Test stream" topic_name = "test" + self.subscribe(sender, stream_name) + before_subscribe_stream_message_ids = [ + self.send_stream_message( + sender=sender, + stream_name=stream_name, + topic_name=topic_name, + ) + for i in range(2) + ] + + self.subscribe(receiver, stream_name) subscribed_stream_message_ids = [ self.send_stream_message( sender=sender, @@ -1851,7 +1860,9 @@ class MarkUnreadTest(ZulipTestCase): ) self.assertEqual(um.flags.read, message_id in unsubscribed_stream_message_ids) for message_id in ( - never_subscribed_stream_message_ids + after_unsubscribe_stream_message_ids + before_subscribe_stream_message_ids + + never_subscribed_stream_message_ids + + after_unsubscribe_stream_message_ids ): self.assertFalse( UserMessage.objects.filter( @@ -1886,63 +1897,85 @@ class MarkUnreadTest(ZulipTestCase): ).exists() ) - # Now, explicitly mark them all as read. This will create new - # 'historical' UserMessage rows created for the ones that - # didn't have them previously. - # - # It's not clear that creating these `historical` UserMessage - # rows is useful, like it is for starring messages. But it - # should also be harmless. + # Now, explicitly mark them all as read. The messages which don't + # have UserMessage rows will be ignored. + message_ids = before_subscribe_stream_message_ids + message_ids self.login("hamlet") - result = self.client_post( - "/json/messages/flags", - {"messages": orjson.dumps(message_ids).decode(), "op": "add", "flag": "read"}, - ) + with self.tornado_redirected_to_list(events, expected_num_events=1): + result = self.client_post( + "/json/messages/flags", + {"messages": orjson.dumps(message_ids).decode(), "op": "add", "flag": "read"}, + ) self.assert_json_success(result) + event = events[0]["event"] + self.assertEqual( + event["messages"], subscribed_stream_message_ids + unsubscribed_stream_message_ids + ) - for message_id in message_ids: + for message_id in subscribed_stream_message_ids + unsubscribed_stream_message_ids: um = UserMessage.objects.get( user_profile_id=receiver.id, message_id=message_id, ) self.assertTrue(um.flags.read) - self.assertEqual( - message_id - in never_subscribed_stream_message_ids + after_unsubscribe_stream_message_ids, - bool(um.flags.historical), + for message_id in ( + before_subscribe_stream_message_ids + + never_subscribed_stream_message_ids + + after_unsubscribe_stream_message_ids + ): + self.assertFalse( + UserMessage.objects.filter( + user_profile_id=receiver.id, + message_id=message_id, + ).exists() ) # Now, request marking them all as unread. Since we haven't # resubscribed to any of the streams, we expect this to not # modify the messages in streams we're not subscribed to. + # + # This also create new 'historical' UserMessage rows for the + # messages in subscribed streams that didn't have them + # previously. with self.tornado_redirected_to_list(events, expected_num_events=1): result = self.client_post( "/json/messages/flags", {"messages": orjson.dumps(message_ids).decode(), "op": "remove", "flag": "read"}, ) event = events[0]["event"] - self.assertEqual(event["messages"], subscribed_stream_message_ids) - unread_message_ids = {str(message_id) for message_id in subscribed_stream_message_ids} + self.assertEqual( + event["messages"], before_subscribe_stream_message_ids + subscribed_stream_message_ids + ) + unread_message_ids = { + str(message_id) + for message_id in before_subscribe_stream_message_ids + subscribed_stream_message_ids + } self.assertSetEqual(set(event["message_details"].keys()), unread_message_ids) - for message_id in subscribed_stream_message_ids: + for message_id in before_subscribe_stream_message_ids + subscribed_stream_message_ids: um = UserMessage.objects.get( user_profile_id=receiver.id, message_id=message_id, ) self.assertFalse(um.flags.read) - for message_id in ( - unsubscribed_stream_message_ids - + after_unsubscribe_stream_message_ids - + never_subscribed_stream_message_ids - ): + for message_id in unsubscribed_stream_message_ids: um = UserMessage.objects.get( user_profile_id=receiver.id, message_id=message_id, ) self.assertTrue(um.flags.read) + for message_id in ( + after_unsubscribe_stream_message_ids + never_subscribed_stream_message_ids + ): + self.assertFalse( + UserMessage.objects.filter( + user_profile_id=receiver.id, + message_id=message_id, + ).exists() + ) + def test_pm_messages_unread(self) -> None: sender = self.example_user("cordelia") receiver = self.example_user("hamlet") diff --git a/zerver/tests/test_read_receipts.py b/zerver/tests/test_read_receipts.py index 94e6a68e96..073e1827ca 100644 --- a/zerver/tests/test_read_receipts.py +++ b/zerver/tests/test_read_receipts.py @@ -157,6 +157,7 @@ class TestReadReceipts(ZulipTestCase): hamlet = self.example_user("hamlet") sender = self.example_user("othello") bot = self.example_user("default_bot") + self.subscribe(bot, "Verona") message_id = self.send_stream_message(sender, "Verona", "read receipts")