From d6e21b5ca9d6eb7037861092f47b2950bcfc6b4c Mon Sep 17 00:00:00 2001 From: Steve Howell Date: Thu, 5 Oct 2017 09:35:34 -0700 Subject: [PATCH] Collect sender_ids (by topic) in `unread_msgs`. This will allow the mobile app to say "A, B, and C are talking" in the topic views. --- zerver/lib/message.py | 58 +++++++++++++++++++++++-------------- zerver/tests/test_events.py | 5 ++++ 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 3d9b42a389..28b385c4b5 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -380,50 +380,62 @@ def huddle_users(recipient_id): user_ids = sorted(user_ids) return ','.join(str(uid) for uid in user_ids) -def aggregate_dict(input_dict, lookup_fields, output_field): - # type: (Dict[int, Dict[str, Any]], List[str], str) -> List[Dict[str, Any]] +def aggregate_message_dict(input_dict, lookup_fields, collect_senders): + # type: (Dict[int, Dict[str, Any]], List[str], bool) -> List[Dict[str, Any]] lookup_dict = dict() # type: Dict[Any, Dict] ''' A concrete example might help explain the inputs here: input_dict = { - 1002: dict(stream_id=5, topic='foo'), - 1003: dict(stream_id=5, topic='foo'), - 1004: dict(stream_id=6, topic='baz'), + 1002: dict(stream_id=5, topic='foo', sender_id=40), + 1003: dict(stream_id=5, topic='foo', sender_id=41), + 1004: dict(stream_id=6, topic='baz', sender_id=99), } lookup_fields = ['stream_id', 'topic'] - output_field = 'unread_message_ids' The first time through the loop: - key_to_aggregate = 1002 - attribute_dict = dict(stream_id=5, topic='foo') + attribute_dict = dict(stream_id=5, topic='foo', sender_id=40) + lookup_dict = (5, 'foo') lookup_dict = { - (5, foo): dict(stream_id=5, topic='foo', unread_message_ids=[1002, 1003]), + (5, 'foo'): dict(stream_id=5, topic='foo', + unread_message_ids=[1002, 1003], + sender_ids=[40, 41], + ), ... } result = [ - dict(stream_id=5, topic='foo', unread_message_ids=[1002, 1003]), + dict(stream_id=5, topic='foo', + unread_message_ids=[1002, 1003], + sender_ids=[40, 41], + ), ... ] ''' - for key_to_aggregate, attribute_dict in input_dict.items(): + for message_id, attribute_dict in input_dict.items(): lookup_key = tuple([attribute_dict[f] for f in lookup_fields]) if lookup_key not in lookup_dict: obj = {} for f in lookup_fields: obj[f] = attribute_dict[f] - obj[output_field] = [] + obj['unread_message_ids'] = [] + if collect_senders: + obj['sender_ids'] = set() lookup_dict[lookup_key] = obj - lookup_dict[lookup_key][output_field].append(key_to_aggregate) + bucket = lookup_dict[lookup_key] + bucket['unread_message_ids'].append(message_id) + if collect_senders: + bucket['sender_ids'].add(attribute_dict['sender_id']) for dct in lookup_dict.values(): - dct[output_field].sort() + dct['unread_message_ids'].sort() + if collect_senders: + dct['sender_ids'] = sorted(list(dct['sender_ids'])) sorted_keys = sorted(lookup_dict.keys()) @@ -525,6 +537,7 @@ def get_raw_unread_data(user_profile): message_id = row['message_id'] msg_type = row['message__recipient__type'] recipient_id = row['message__recipient_id'] + sender_id = row['message__sender_id'] if msg_type == Recipient.STREAM: stream_id = row['message__recipient__type_id'] @@ -532,12 +545,12 @@ def get_raw_unread_data(user_profile): stream_dict[message_id] = dict( stream_id=stream_id, topic=topic, + sender_id=sender_id, ) if not is_row_muted(stream_id, recipient_id, topic): unmuted_stream_msgs.add(message_id) elif msg_type == Recipient.PERSONAL: - sender_id = row['message__sender_id'] pm_dict[message_id] = dict( sender_id=sender_id, ) @@ -572,29 +585,29 @@ def aggregate_unread_data(raw_data): count = len(pm_dict) + len(unmuted_stream_msgs) + len(huddle_dict) - pm_objects = aggregate_dict( + pm_objects = aggregate_message_dict( input_dict=pm_dict, lookup_fields=[ 'sender_id', ], - output_field='unread_message_ids', + collect_senders=False, ) - stream_objects = aggregate_dict( + stream_objects = aggregate_message_dict( input_dict=stream_dict, lookup_fields=[ 'stream_id', 'topic', ], - output_field='unread_message_ids', + collect_senders=True, ) - huddle_objects = aggregate_dict( + huddle_objects = aggregate_message_dict( input_dict=huddle_dict, lookup_fields=[ 'user_ids_string', ], - output_field='unread_message_ids', + collect_senders=False, ) result = dict( @@ -623,12 +636,15 @@ def apply_unread_message_event(user_profile, state, message): else: raise AssertionError("Invalid message type %s" % (message['type'],)) + sender_id = message['sender_id'] + if message_type == 'stream': stream_id = message['stream_id'] topic = message['subject'] new_row = dict( stream_id=stream_id, topic=topic, + sender_id=sender_id, ) state['stream_dict'][message_id] = new_row diff --git a/zerver/tests/test_events.py b/zerver/tests/test_events.py index 65260bab67..9623b6b51d 100644 --- a/zerver/tests/test_events.py +++ b/zerver/tests/test_events.py @@ -1732,27 +1732,32 @@ class FetchInitialStateDataTest(ZulipTestCase): unread_pm = result['pms'][0] self.assertEqual(unread_pm['sender_id'], sender_id) self.assertEqual(unread_pm['unread_message_ids'], [pm1_message_id, pm2_message_id]) + self.assertTrue('sender_ids' not in unread_pm) unread_stream = result['streams'][0] self.assertEqual(unread_stream['stream_id'], get_stream('Denmark', user_profile.realm).id) self.assertEqual(unread_stream['topic'], 'muted-topic') self.assertEqual(unread_stream['unread_message_ids'], [muted_topic_message_id]) + self.assertEqual(unread_stream['sender_ids'], [sender_id]) unread_stream = result['streams'][1] self.assertEqual(unread_stream['stream_id'], get_stream('Denmark', user_profile.realm).id) self.assertEqual(unread_stream['topic'], 'test') self.assertEqual(unread_stream['unread_message_ids'], [stream_message_id]) + self.assertEqual(unread_stream['sender_ids'], [sender_id]) unread_stream = result['streams'][2] self.assertEqual(unread_stream['stream_id'], get_stream('Muted Stream', user_profile.realm).id) self.assertEqual(unread_stream['topic'], 'test') self.assertEqual(unread_stream['unread_message_ids'], [muted_stream_message_id]) + self.assertEqual(unread_stream['sender_ids'], [sender_id]) huddle_string = ','.join(str(uid) for uid in sorted([sender_id, user_profile.id, othello.id])) unread_huddle = result['huddles'][0] self.assertEqual(unread_huddle['user_ids_string'], huddle_string) self.assertEqual(unread_huddle['unread_message_ids'], [huddle_message_id]) + self.assertTrue('sender_ids' not in unread_huddle) self.assertEqual(result['mentions'], [])