diff --git a/docs/unread_messages.md b/docs/unread_messages.md new file mode 100644 index 0000000000..a79fda8515 --- /dev/null +++ b/docs/unread_messages.md @@ -0,0 +1,50 @@ +# Unread message synchronization + +In general displaying unread counts for all streams and topics may require +downloading an unbounded number of messages. Consider a user who has a muted +stream or topic and has not read the backlog in a month; to have an accurate +unread count we would need to load all messages this user has received in the +past month. This is inefficient for web clients and even more for mobile +devices. + +We work around this by including a list of unread message ids in the initial +state grouped by relevant conversation keys. This data is included in the +`unread_msgs` key if both `update_message_flags` and `message` are required +in the register call. + +``` +{ + "huddles": [ + { + "user_ids_string": "3,4,6", + "unread_message_ids": [ + 34 + ] + } + ], + "streams": [ + { + "stream_id": 1, + "topic": "test", + "unread_message_ids": [ + 33 + ] + } + ], + "pms": [ + { + "sender_id": 3, + "unread_message_ids": [ + 31, + 32 + ] + } + ] +} +``` + +Three event types are required to correctly maintain the `unread_msgs`. New +messages can be created without the unread flag by the `message` event type. +The unread flag can be added and removed by the `update_message_flags` event, +and the subject of unread messages can be updated by the `update_message` event +type. diff --git a/zerver/lib/actions.py b/zerver/lib/actions.py index 369ff28986..e50c3a2f24 100644 --- a/zerver/lib/actions.py +++ b/zerver/lib/actions.py @@ -47,7 +47,8 @@ from zerver.models import Realm, RealmEmoji, Stream, UserProfile, UserActivity, get_old_unclaimed_attachments, get_cross_realm_emails, \ Reaction, EmailChangeStatus, CustomProfileField, \ custom_profile_fields_for_realm, \ - CustomProfileFieldValue, validate_attachment_request, get_system_bot + CustomProfileFieldValue, validate_attachment_request, get_system_bot, \ + get_display_recipient_by_id from zerver.lib.alert_words import alert_words_in_realm from zerver.lib.avatar import avatar_url diff --git a/zerver/lib/events.py b/zerver/lib/events.py index ac4c1b7203..288b0fa387 100644 --- a/zerver/lib/events.py +++ b/zerver/lib/events.py @@ -12,7 +12,7 @@ from django.conf import settings from importlib import import_module from six.moves import filter, map from typing import ( - Any, Dict, Iterable, List, Optional, Sequence, Set, Text, Tuple + cast, Any, Dict, Iterable, List, Optional, Sequence, Set, Text, Tuple, Union ) session_engine = import_module(settings.SESSION_ENGINE) @@ -21,13 +21,19 @@ from zerver.lib.alert_words import user_alert_words from zerver.lib.attachments import user_attachments from zerver.lib.avatar import avatar_url, avatar_url_from_dict from zerver.lib.hotspots import get_next_hotspots +from zerver.lib.message import ( + apply_unread_message_event, + get_unread_message_ids_per_recipient, +) from zerver.lib.narrow import check_supported_events_narrow_filter from zerver.lib.realm_icon import realm_icon_url from zerver.lib.request import JsonableError -from zerver.lib.actions import validate_user_access_to_subscribers_helper, \ - do_get_streams, get_default_streams_for_realm, \ - gather_subscriptions_helper, get_cross_realm_dicts, \ +from zerver.lib.actions import ( + validate_user_access_to_subscribers_helper, + do_get_streams, get_default_streams_for_realm, + gather_subscriptions_helper, get_cross_realm_dicts, get_status_dict, streams_to_dicts_sorted +) from zerver.tornado.event_queue import request_event_queue, get_user_events from zerver.models import Client, Message, Realm, UserPresence, UserProfile, \ get_user_profile_by_id, \ @@ -151,10 +157,12 @@ def fetch_initial_state_data(user_profile, event_types, queue_id, state['unsubscribed'] = unsubscribed state['never_subscribed'] = never_subscribed - if want('update_message_flags'): - # There's no initial data for message flag updates, client will - # get any updates during a session from get_events() - pass + if want('update_message_flags') and want('message'): + # Keeping unread_msgs updated requires both message flag updates and + # message updates. This is due to the fact that new messages will not + # generate a flag update so we need to use the flags field in the + # message event. + state['unread_msgs'] = get_unread_message_ids_per_recipient(user_profile) if want('stream'): state['streams'] = do_get_streams(user_profile) @@ -177,6 +185,19 @@ def fetch_initial_state_data(user_profile, event_types, queue_id, return state + +def remove_message_id_from_unread_mgs(state, remove_id): + # type: (Dict[str, Dict[str, List[Dict[str, Any]]]], int) -> None + for message_type, threads in state['unread_msgs'].items(): + for obj in threads: + msg_ids = obj['unread_message_ids'] + if remove_id in msg_ids: + msg_ids.remove(remove_id) + state['unread_msgs'][message_type] = [ + obj for obj in threads + if obj['unread_message_ids'] + ] + def apply_events(state, events, user_profile, include_subscribers=True, fetch_event_types=None): # type: (Dict[str, Any], Iterable[Dict[str, Any]], UserProfile, bool, Optional[Iterable[str]]) -> None @@ -197,6 +218,8 @@ def apply_event(state, event, user_profile, include_subscribers): # type: (Dict[str, Any], Dict[str, Any], UserProfile, bool) -> None if event['type'] == "message": state['max_message_id'] = max(state['max_message_id'], event['message']['id']) + apply_unread_message_event(state['unread_msgs'], event['message']) + elif event['type'] == "hotspots": state['hotspots'] = event['hotspots'] elif event['type'] == "custom_profile_fields": @@ -399,8 +422,13 @@ def apply_event(state, event, user_profile, include_subscribers): presence_user_profile = get_user(event['email'], user_profile.realm) state['presences'][event['email']] = UserPresence.get_status_dict_by_user(presence_user_profile)[event['email']] elif event['type'] == "update_message": - # The client will get the updated message directly - pass + # The client will get the updated message directly, but we need to + # update the subjects of our unread message ids + if 'subject' in event and 'unread_msgs' in state: + for obj in state['unread_msgs']['streams']: + if obj['stream_id'] == event['stream_id']: + if obj['topic'] == event['orig_subject']: + obj['topic'] = event['subject'] elif event['type'] == "delete_message": max_message = Message.objects.filter( usermessage__user_profile=user_profile).order_by('-id').first() @@ -408,6 +436,9 @@ def apply_event(state, event, user_profile, include_subscribers): state['max_message_id'] = max_message.id else: state['max_message_id'] = -1 + + remove_id = event['message_id'] + remove_message_id_from_unread_mgs(state, remove_id) elif event['type'] == "reaction": # The client will get the message with the reactions directly pass @@ -415,8 +446,11 @@ def apply_event(state, event, user_profile, include_subscribers): # Typing notification events are transient and thus ignored pass elif event['type'] == "update_message_flags": - # The client will get the message with the updated flags directly - pass + # The client will get the message with the updated flags directly but + # we need to keep the unread_msgs updated. + if event['flag'] == 'read' and event['operation'] == 'add': + for remove_id in event['messages']: + remove_message_id_from_unread_mgs(state, remove_id) elif event['type'] == "realm_domains": if event['op'] == 'add': state['realm_domains'].append(event['realm_domain']) @@ -477,6 +511,7 @@ def do_events_register(user_profile, user_client, apply_markdown=True, events = get_user_events(user_profile, queue_id, -1) apply_events(ret, events, user_profile, include_subscribers=include_subscribers, fetch_event_types=fetch_event_types) + if len(events) > 0: ret['last_event_id'] = events[-1]['id'] else: diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 639dcbb511..74646434c3 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -26,7 +26,7 @@ from zerver.models import ( Reaction ) -from typing import Any, Dict, List, Optional, Set, Tuple, Text +from typing import Any, Dict, List, Optional, Set, Tuple, Text, Union RealmAlertWords = Dict[int, List[Text]] @@ -344,3 +344,173 @@ def render_markdown(message, content, realm=None, realm_alert_words=None, messag message.is_me_message = Message.is_status_message(content, rendered_content) return rendered_content + +def huddle_users(recipient_id): + # type: (int) -> str + display_recipient = get_display_recipient_by_id(recipient_id, + Recipient.HUDDLE, + None) # type: Union[Text, List[Dict[str, Any]]] + + # Text is for streams. + assert not isinstance(display_recipient, Text) + + user_ids = [obj['id'] for obj in display_recipient] # type: List[int] + user_ids = sorted(user_ids) + return ','.join(str(uid) for uid in user_ids) + +def aggregate_dict(input_rows, lookup_fields, input_field, output_field): + # type: (List[Dict[str, Any]], List[str], str, str) -> List[Dict[str, Any]] + lookup_dict = dict() # type: Dict[Any, Dict] + + for input_row in input_rows: + lookup_key = tuple([input_row[f] for f in lookup_fields]) + if lookup_key not in lookup_dict: + obj = {} + for f in lookup_fields: + obj[f] = input_row[f] + obj[output_field] = [] + lookup_dict[lookup_key] = obj + + lookup_dict[lookup_key][output_field].append(input_row[input_field]) + + sorted_keys = sorted(lookup_dict.keys()) + + return [lookup_dict[k] for k in sorted_keys] + +def get_unread_message_ids_per_recipient(user_profile): + # type: (UserProfile) -> Dict[str, List[Dict[str, Any]]] + user_msgs = UserMessage.objects.filter( + user_profile=user_profile + ).extra( + where=[UserMessage.where_unread()] + ).values( + 'message_id', + 'message__sender_id', + 'message__subject', + 'message__recipient_id', + 'message__recipient__type', + 'message__recipient__type_id', + ) + + rows = list(user_msgs) + + pm_msgs = [ + dict( + sender_id=row['message__sender_id'], + message_id=row['message_id'], + ) for row in rows + if row['message__recipient__type'] == Recipient.PERSONAL] + + pm_objects = aggregate_dict( + input_rows=pm_msgs, + lookup_fields=[ + 'sender_id', + ], + input_field='message_id', + output_field='unread_message_ids', + ) + + stream_msgs = [ + dict( + stream_id=row['message__recipient__type_id'], + topic=row['message__subject'], + message_id=row['message_id'], + ) for row in rows + if row['message__recipient__type'] == Recipient.STREAM] + + stream_objects = aggregate_dict( + input_rows=stream_msgs, + lookup_fields=[ + 'stream_id', + 'topic', + ], + input_field='message_id', + output_field='unread_message_ids', + ) + + huddle_msgs = [ + dict( + recipient_id=row['message__recipient_id'], + message_id=row['message_id'], + ) for row in rows + if row['message__recipient__type'] == Recipient.HUDDLE] + + huddle_objects = aggregate_dict( + input_rows=huddle_msgs, + lookup_fields=[ + 'recipient_id', + ], + input_field='message_id', + output_field='unread_message_ids', + ) + + for huddle in huddle_objects: + huddle['user_ids_string'] = huddle_users(huddle['recipient_id']) + del huddle['recipient_id'] + + result = dict( + pms=pm_objects, + streams=stream_objects, + huddles=huddle_objects, + ) + + return result + +def apply_unread_message_event(state, message): + # type: (Dict[str, List[Dict[str, Any]]], Dict[str, Any]) -> None + message_id = message['id'] + if message['type'] == 'stream': + message_type = 'stream' + elif message['type'] == 'private': + others = [ + recip for recip in message['display_recipient'] + if recip['id'] != message['sender_id'] + ] + if len(others) <= 1: + message_type = 'private' + else: + message_type = 'huddle' + + if message_type == 'stream': + unread_key = 'streams' + stream_id = message['stream_id'] + topic = message['subject'] + + my_key = (stream_id, topic) # type: Any + + key_func = lambda obj: (obj['stream_id'], obj['topic']) + new_obj = dict( + stream_id=stream_id, + topic=topic, + unread_message_ids=[message_id], + ) + elif message_type == 'private': + unread_key = 'pms' + sender_id = message['sender_id'] + + my_key = sender_id + key_func = lambda obj: obj['sender_id'] + new_obj = dict( + sender_id=sender_id, + unread_message_ids=[message_id], + ) + else: + unread_key = 'huddles' + display_recipient = message['display_recipient'] + user_ids = [obj['id'] for obj in display_recipient] + user_ids = sorted(user_ids) + my_key = ','.join(str(uid) for uid in user_ids) + key_func = lambda obj: obj['user_ids_string'] + new_obj = dict( + user_ids_string=my_key, + unread_message_ids=[message_id], + ) + + for obj in state[unread_key]: + if key_func(obj) == my_key: + obj['unread_message_ids'].append(message_id) + obj['unread_message_ids'].sort() + return + + state[unread_key].append(new_obj) + state[unread_key].sort(key=key_func) diff --git a/zerver/models.py b/zerver/models.py index 20727fb84b..e50dcba9d6 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -1272,6 +1272,14 @@ class AbstractUserMessage(ModelReprMixin, models.Model): abstract = True unique_together = ("user_profile", "message") + @staticmethod + def where_unread(): + # type: () -> str + # Use this for Django ORM queries where we are getting lots + # of rows. This customer SQL plays nice with our partial indexes. + # Grep the code for example usage. + return 'flags & 1 = 0' + def flags_list(self): # type: () -> List[str] return [flag for flag in self.flags.keys() if getattr(self.flags, flag).is_set] diff --git a/zerver/tests/test_events.py b/zerver/tests/test_events.py index f5bedcfc1f..28b02338bb 100644 --- a/zerver/tests/test_events.py +++ b/zerver/tests/test_events.py @@ -456,7 +456,31 @@ class EventsRegisterTest(ZulipTestCase): required_keys.append(('id', check_int)) return check_dict_only(required_keys) - def test_send_message_events(self): + def test_pm_send_message_events(self): + # type: () -> None + self.do_test( + lambda: self.send_message(self.example_email('cordelia'), + self.example_email('hamlet'), + Recipient.PERSONAL, + 'hola') + + ) + + def test_huddle_send_message_events(self): + # type: () -> None + huddle = [ + self.example_email('hamlet'), + self.example_email('othello'), + ] + self.do_test( + lambda: self.send_message(self.example_email('cordelia'), + huddle, + Recipient.HUDDLE, + 'hola') + + ) + + def test_stream_send_message_events(self): # type: () -> None schema_checker = self.check_events_dict([ ('type', equals('message')), @@ -519,7 +543,7 @@ class EventsRegisterTest(ZulipTestCase): events = self.do_test( lambda: do_update_message(self.user_profile, message, topic, propagate_mode, content, rendered_content), - state_change_expected=False, + state_change_expected=True, ) error = schema_checker('events[0]', events[0]) self.assert_on_error(error) @@ -579,6 +603,35 @@ class EventsRegisterTest(ZulipTestCase): error = schema_checker('events[0]', events[0]) self.assert_on_error(error) + def test_update_read_flag_removes_unread_msg_ids(self): + # type: () -> None + message = self.send_message( + self.example_email('cordelia'), + "Verona", + Recipient.STREAM, + "hello" + ) + + user_profile = self.example_user('hamlet') + self.do_test( + lambda: do_update_message_flags(user_profile, 'add', 'read', + [message], False, None, None), + state_change_expected=True, + ) + + def test_send_message_to_existing_recipient(self): + # type: () -> None + self.send_message( + self.example_email('cordelia'), + "Verona", + Recipient.STREAM, + "hello 1" + ) + self.do_test( + lambda: self.send_message("cordelia@zulip.com", "Verona", Recipient.STREAM, "hello 2"), + state_change_expected=True, + ) + def test_send_reaction(self): # type: () -> None schema_checker = self.check_events_dict([ @@ -1568,6 +1621,46 @@ class FetchInitialStateDataTest(ZulipTestCase): result = fetch_initial_state_data(user_profile, None, "") self.assertEqual(result['max_message_id'], -1) + def test_unread_msgs(self): + # type: () -> None + cordelia = self.example_user('cordelia') + sender_id = cordelia.id + sender_email = cordelia.email + user_profile = self.example_user('hamlet') + othello = self.example_user('othello') + + # our tests rely on order + assert(sender_email < user_profile.email) + assert(user_profile.email < othello.email) + + pm1_message_id = self.send_message(sender_email, user_profile.email, Recipient.PERSONAL, "hello1") + pm2_message_id = self.send_message(sender_email, user_profile.email, Recipient.PERSONAL, "hello2") + + stream_message_id = self.send_message(sender_email, "Denmark", Recipient.STREAM, "hello") + + huddle_message_id = self.send_message(sender_email, + [user_profile.email, othello.email], + Recipient.HUDDLE, + 'hello3') + + result = fetch_initial_state_data(user_profile, None, "")['unread_msgs'] + + unread_pm = result['pms'][0] + self.assertEqual(unread_pm['sender_id'], sender_id) + self.assertEqual(unread_pm['unread_message_ids'], [pm1_message_id, pm2_message_id]) + + unread_stream = result['streams'][0] + self.assertEqual(unread_stream['stream_id'], get_stream('Denmark', user_profile.realm).id) + self.assertEqual(unread_stream['topic'], 'test') + self.assertEqual(unread_stream['unread_message_ids'], [stream_message_id]) + + huddle_string = ','.join(str(uid) for uid in sorted([sender_id, user_profile.id, othello.id])) + + unread_huddle = result['huddles'][0] + self.assertEqual(unread_huddle['user_ids_string'], huddle_string) + self.assertEqual(unread_huddle['unread_message_ids'], [huddle_message_id]) + + class EventQueueTest(TestCase): def test_one_event(self): # type: () -> None diff --git a/zerver/tests/test_home.py b/zerver/tests/test_home.py index 64ea0d2c3c..12e99e3789 100644 --- a/zerver/tests/test_home.py +++ b/zerver/tests/test_home.py @@ -141,6 +141,7 @@ class HomeTest(ZulipTestCase): "timezone", "twenty_four_hour_time", "unread_count", + "unread_msgs", "unsubscribed", "use_websockets", "user_id",