diff --git a/zerver/lib/events.py b/zerver/lib/events.py index 5509860612..0f135de38b 100644 --- a/zerver/lib/events.py +++ b/zerver/lib/events.py @@ -19,8 +19,9 @@ from zerver.lib.attachments import user_attachments from zerver.lib.avatar import avatar_url, avatar_url_from_dict from zerver.lib.hotspots import get_next_hotspots from zerver.lib.message import ( + aggregate_unread_data, apply_unread_message_event, - get_unread_message_ids_per_recipient, + get_raw_unread_data, ) from zerver.lib.narrow import check_supported_events_narrow_filter from zerver.lib.soft_deactivation import maybe_catch_up_soft_deactivated_user @@ -168,7 +169,7 @@ def fetch_initial_state_data(user_profile, event_types, queue_id, # message updates. This is due to the fact that new messages will not # generate a flag update so we need to use the flags field in the # message event. - state['unread_msgs'] = get_unread_message_ids_per_recipient(user_profile) + state['raw_unread_msgs'] = get_raw_unread_data(user_profile) if want('stream'): state['streams'] = do_get_streams(user_profile) @@ -192,22 +193,15 @@ def fetch_initial_state_data(user_profile, event_types, queue_id, return state -def remove_message_id_from_unread_mgs(state, remove_id): +def remove_message_id_from_unread_mgs(state, message_id): # type: (Dict[str, Dict[str, Any]], int) -> None - for message_type in ['pms', 'streams', 'huddles']: - threads = state['unread_msgs'][message_type] - for obj in threads: - msg_ids = obj['unread_message_ids'] - if remove_id in msg_ids: - state['unread_msgs']['count'] -= 1 - msg_ids.remove(remove_id) - state['unread_msgs'][message_type] = [ - obj for obj in threads - if obj['unread_message_ids'] - ] + raw_unread = state['raw_unread_msgs'] - if remove_id in state['unread_msgs']['mentions']: - state['unread_msgs']['mentions'].remove(remove_id) + for key in ['pm_dict', 'stream_dict', 'huddle_dict']: + raw_unread[key].pop(message_id, None) + + raw_unread['unmuted_stream_msgs'].discard(message_id) + raw_unread['mentions'].discard(message_id) def apply_events(state, events, user_profile, include_subscribers=True, fetch_event_types=None): @@ -229,8 +223,10 @@ def apply_event(state, event, user_profile, include_subscribers): # type: (Dict[str, Any], Dict[str, Any], UserProfile, bool) -> None if event['type'] == "message": state['max_message_id'] = max(state['max_message_id'], event['message']['id']) - if 'unread_msgs' in state: - apply_unread_message_event(state['unread_msgs'], event['message']) + if 'raw_unread_msgs' in state: + apply_unread_message_event(user_profile, + state['raw_unread_msgs'], + event['message']) elif event['type'] == "hotspots": state['hotspots'] = event['hotspots'] @@ -434,13 +430,15 @@ def apply_event(state, event, user_profile, include_subscribers): presence_user_profile = get_user(event['email'], user_profile.realm) state['presences'][event['email']] = UserPresence.get_status_dict_by_user(presence_user_profile)[event['email']] elif event['type'] == "update_message": - # The client will get the updated message directly, but we need to - # update the subjects of our unread message ids - if 'subject' in event and 'unread_msgs' in state: - for obj in state['unread_msgs']['streams']: - if obj['stream_id'] == event['stream_id']: - if obj['topic'] == event['orig_subject']: - obj['topic'] = event['subject'] + # We don't return messages in /register, so we don't need to + # do anything for content updates, but we may need to update + # the unread_msgs data if the topic of an unread message changed. + if 'subject' in event: + stream_dict = state['raw_unread_msgs']['stream_dict'] + topic = event['subject'] + for message_id in event['message_ids']: + if message_id in stream_dict: + stream_dict[message_id]['topic'] = topic elif event['type'] == "delete_message": max_message = Message.objects.filter( usermessage__user_profile=user_profile).order_by('-id').first() @@ -458,8 +456,9 @@ def apply_event(state, event, user_profile, include_subscribers): # Typing notification events are transient and thus ignored pass elif event['type'] == "update_message_flags": - # The client will get the message with the updated flags directly but - # we need to keep the unread_msgs updated. + # We don't return messages in `/register`, so most flags we + # can ignore, but we do need to update the unread_msgs data if + # unread state is changed. if event['flag'] == 'read' and event['operation'] == 'add': for remove_id in event['messages']: remove_message_id_from_unread_mgs(state, remove_id) @@ -527,6 +526,21 @@ def do_events_register(user_profile, user_client, apply_markdown=True, apply_events(ret, events, user_profile, include_subscribers=include_subscribers, fetch_event_types=fetch_event_types) + ''' + NOTE: + + Below is an example of post-processing initial state data AFTER we + apply events. For large payloads like `unread_msgs`, it's helpful + to have an intermediate data structure that is easy to manipulate + with O(1)-type operations as we apply events. + + Then, only at the end, we put it in the form that's more appropriate + for client. + ''' + if 'raw_unread_msgs' in ret: + ret['unread_msgs'] = aggregate_unread_data(ret['raw_unread_msgs']) + del ret['raw_unread_msgs'] + if len(events) > 0: ret['last_event_id'] = events[-1]['id'] else: diff --git a/zerver/lib/message.py b/zerver/lib/message.py index a12d94da7d..3d9b42a389 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -13,7 +13,10 @@ from zerver.lib.cache import cache_with_key, to_dict_cache_key from zerver.lib.request import JsonableError from zerver.lib.str_utils import force_bytes, dict_with_str_keys from zerver.lib.timestamp import datetime_to_timestamp -from zerver.lib.topic_mutes import build_topic_mute_checker +from zerver.lib.topic_mutes import ( + build_topic_mute_checker, + topic_is_muted, +) from zerver.models import ( get_display_recipient_by_id, @@ -552,6 +555,7 @@ def get_raw_unread_data(user_profile): return dict( pm_dict=pm_dict, stream_dict=stream_dict, + muted_stream_ids=muted_stream_ids, unmuted_stream_msgs=unmuted_stream_msgs, huddle_dict=huddle_dict, mentions=mentions, @@ -602,10 +606,8 @@ def aggregate_unread_data(raw_data): return result -def apply_unread_message_event(state, message): - # type: (Dict[str, Any], Dict[str, Any]) -> None - state['count'] += 1 - +def apply_unread_message_event(user_profile, state, message): + # type: (UserProfile, Dict[str, Any], Dict[str, Any]) -> None message_id = message['id'] if message['type'] == 'stream': message_type = 'stream' @@ -622,49 +624,36 @@ def apply_unread_message_event(state, message): raise AssertionError("Invalid message type %s" % (message['type'],)) if message_type == 'stream': - unread_key = 'streams' stream_id = message['stream_id'] topic = message['subject'] - - my_key = (stream_id, topic) # type: Any - - key_func = lambda obj: (obj['stream_id'], obj['topic']) - new_obj = dict( + new_row = dict( stream_id=stream_id, topic=topic, - unread_message_ids=[message_id], ) - elif message_type == 'private': - unread_key = 'pms' - sender_id = message['sender_id'] + state['stream_dict'][message_id] = new_row - my_key = sender_id - key_func = lambda obj: obj['sender_id'] - new_obj = dict( + if stream_id not in state['muted_stream_ids']: + # This next check hits the database. + if not topic_is_muted(user_profile, stream_id, topic): + state['unmuted_stream_msgs'].add(message_id) + + elif message_type == 'private': + sender_id = message['sender_id'] + new_row = dict( sender_id=sender_id, - unread_message_ids=[message_id], ) + state['pm_dict'][message_id] = new_row + else: - unread_key = 'huddles' display_recipient = message['display_recipient'] user_ids = [obj['id'] for obj in display_recipient] user_ids = sorted(user_ids) - my_key = ','.join(str(uid) for uid in user_ids) - key_func = lambda obj: obj['user_ids_string'] - new_obj = dict( - user_ids_string=my_key, - unread_message_ids=[message_id], + user_ids_string = ','.join(str(uid) for uid in user_ids) + new_row = dict( + user_ids_string=user_ids_string, ) + state['huddle_dict'][message_id] = new_row - if message.get('is_mentioned'): - if message_id not in state['mentions']: - state['mentions'].append(message_id) - - for obj in state[unread_key]: - if key_func(obj) == my_key: - obj['unread_message_ids'].append(message_id) - obj['unread_message_ids'].sort() - return - - state[unread_key].append(new_obj) - state[unread_key].sort(key=key_func) + mentioned = message.get('is_mentioned', False) + if mentioned: + state['mentions'].add(message_id) diff --git a/zerver/tests/test_events.py b/zerver/tests/test_events.py index b4bfcfc524..65260bab67 100644 --- a/zerver/tests/test_events.py +++ b/zerver/tests/test_events.py @@ -71,7 +71,11 @@ from zerver.lib.events import ( apply_events, fetch_initial_state_data, ) -from zerver.lib.message import render_markdown +from zerver.lib.message import ( + get_unread_message_ids_per_recipient, + render_markdown, + UnreadMessagesResult, +) from zerver.lib.test_helpers import POSTRequestMock, get_subscription, \ stub_event_queue_user_events from zerver.lib.test_classes import ( @@ -1715,9 +1719,8 @@ class FetchInitialStateDataTest(ZulipTestCase): 'hello3') def get_unread_data(): - # type: () -> Dict[str, Any] - result = fetch_initial_state_data(user_profile, None, "")['unread_msgs'] - return result + # type: () -> UnreadMessagesResult + return get_unread_message_ids_per_recipient(user_profile) result = get_unread_data()