Introduce StreamRecipient class.

This class encapsulates the mapping of stream ids to
recipient ids, and it is optimized for bulk use and
repeated use (i.e. it remembers values it already fetched).

This particular commit barely improves the performance
of gather_subscriptions_helper, but it sets us up for
further optimizations.

Long term, we may try to denormalize stream_id on to the
Subscriber table or otherwise modify the database so we
don't have to jump through hoops to do this kind of mapping.
This commit will help enable those changes, because we
isolate the mapping to this one new class.
This commit is contained in:
Steve Howell
2017-09-13 11:00:36 -07:00
committed by Tim Abbott
parent fc2e485ca7
commit 1553dc00e0
4 changed files with 161 additions and 21 deletions

View File

@@ -64,6 +64,7 @@ from zerver.models import Realm, RealmEmoji, Stream, UserProfile, UserActivity,
from zerver.lib.alert_words import alert_words_in_realm from zerver.lib.alert_words import alert_words_in_realm
from zerver.lib.avatar import avatar_url from zerver.lib.avatar import avatar_url
from zerver.lib.stream_recipient import StreamRecipientMap
from django.db import transaction, IntegrityError, connection from django.db import transaction, IntegrityError, connection
from django.db.models import F, Q, Max from django.db.models import F, Q, Max
@@ -1736,8 +1737,8 @@ def validate_user_access_to_subscribers_helper(user_profile, stream_dict, check_
raise JsonableError(_("Unable to retrieve subscribers for invite-only stream")) raise JsonableError(_("Unable to retrieve subscribers for invite-only stream"))
# sub_dict is a dictionary mapping stream_id => whether the user is subscribed to that stream # sub_dict is a dictionary mapping stream_id => whether the user is subscribed to that stream
def bulk_get_subscriber_user_ids(stream_dicts, user_profile, sub_dict): def bulk_get_subscriber_user_ids(stream_dicts, user_profile, sub_dict, stream_recipient):
# type: (Iterable[Mapping[str, Any]], UserProfile, Mapping[int, bool]) -> Dict[int, List[int]] # type: (Iterable[Mapping[str, Any]], UserProfile, Mapping[int, bool], StreamRecipientMap) -> Dict[int, List[int]]
target_stream_dicts = [] target_stream_dicts = []
for stream_dict in stream_dicts: for stream_dict in stream_dicts:
try: try:
@@ -1747,15 +1748,30 @@ def bulk_get_subscriber_user_ids(stream_dicts, user_profile, sub_dict):
continue continue
target_stream_dicts.append(stream_dict) target_stream_dicts.append(stream_dict)
subscriptions = Subscription.objects.select_related("recipient").filter( stream_ids = [stream['id'] for stream in target_stream_dicts]
recipient__type=Recipient.STREAM, stream_recipient.populate_for_stream_ids(stream_ids)
recipient__type_id__in=[stream["id"] for stream in target_stream_dicts], recipient_ids = sorted([
stream_recipient.recipient_id_for(stream_id)
for stream_id in stream_ids
])
subscriptions = Subscription.objects.filter(
recipient_id__in=recipient_ids,
user_profile__is_active=True, user_profile__is_active=True,
active=True).values("user_profile_id", "recipient__type_id") active=True
).values(
'recipient_id',
'user_profile_id',
).order_by('recipient_id')
subscriptions = list(subscriptions)
result = dict((stream["id"], []) for stream in stream_dicts) # type: Dict[int, List[int]] result = dict((stream["id"], []) for stream in stream_dicts) # type: Dict[int, List[int]]
recip_to_stream_id = stream_recipient.recipient_to_stream_id_dict()
for sub in subscriptions: for sub in subscriptions:
result[sub["recipient__type_id"]].append(sub["user_profile_id"]) recip_id = sub['recipient_id']
stream_id = recip_to_stream_id[recip_id]
user_profile_id = sub['user_profile_id']
result[stream_id].append(user_profile_id)
return result return result
@@ -3200,13 +3216,27 @@ def decode_email_address(email):
# subscriptions, so it's worth optimizing. # subscriptions, so it's worth optimizing.
def gather_subscriptions_helper(user_profile, include_subscribers=True): def gather_subscriptions_helper(user_profile, include_subscribers=True):
# type: (UserProfile, bool) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]] # type: (UserProfile, bool) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]
sub_dicts = Subscription.objects.select_related("recipient").filter( sub_dicts = Subscription.objects.filter(
user_profile = user_profile, user_profile = user_profile,
recipient__type = Recipient.STREAM).values( recipient__type = Recipient.STREAM
"recipient__type_id", "in_home_view", "color", "desktop_notifications", ).values(
"audible_notifications", "push_notifications", "active", "pin_to_top") "recipient_id", "in_home_view", "color", "desktop_notifications",
"audible_notifications", "push_notifications", "active", "pin_to_top"
).order_by("recipient_id")
sub_dicts = list(sub_dicts)
sub_recipient_ids = [
sub['recipient_id']
for sub in sub_dicts
]
stream_recipient = StreamRecipientMap()
stream_recipient.populate_for_recipient_ids(sub_recipient_ids)
stream_ids = set() # type: Set[int]
for sub in sub_dicts:
sub['stream_id'] = stream_recipient.stream_id_for(sub['recipient_id'])
stream_ids.add(sub['stream_id'])
stream_ids = set([sub["recipient__type_id"] for sub in sub_dicts])
all_streams = get_active_streams(user_profile.realm).select_related( all_streams = get_active_streams(user_profile.realm).select_related(
"realm").values("id", "name", "invite_only", "realm_id", "realm").values("id", "name", "invite_only", "realm_id",
"email_token", "description") "email_token", "description")
@@ -3223,15 +3253,20 @@ def gather_subscriptions_helper(user_profile, include_subscribers=True):
never_subscribed = [] never_subscribed = []
# Deactivated streams aren't in stream_hash. # Deactivated streams aren't in stream_hash.
streams = [stream_hash[sub["recipient__type_id"]] for sub in sub_dicts streams = [stream_hash[sub["stream_id"]] for sub in sub_dicts
if sub["recipient__type_id"] in stream_hash] if sub["stream_id"] in stream_hash]
streams_subscribed_map = dict((sub["recipient__type_id"], sub["active"]) for sub in sub_dicts) streams_subscribed_map = dict((sub["stream_id"], sub["active"]) for sub in sub_dicts)
# Add never subscribed streams to streams_subscribed_map # Add never subscribed streams to streams_subscribed_map
streams_subscribed_map.update({stream['id']: False for stream in all_streams if stream not in streams}) streams_subscribed_map.update({stream['id']: False for stream in all_streams if stream not in streams})
if include_subscribers: if include_subscribers:
subscriber_map = bulk_get_subscriber_user_ids(all_streams, user_profile, streams_subscribed_map) # type: Mapping[int, Optional[List[int]]] subscriber_map = bulk_get_subscriber_user_ids(
all_streams,
user_profile,
streams_subscribed_map,
stream_recipient
) # type: Mapping[int, Optional[List[int]]]
else: else:
# If we're not including subscribers, always return None, # If we're not including subscribers, always return None,
# which the below code needs to check for anyway. # which the below code needs to check for anyway.
@@ -3239,8 +3274,8 @@ def gather_subscriptions_helper(user_profile, include_subscribers=True):
sub_unsub_stream_ids = set() sub_unsub_stream_ids = set()
for sub in sub_dicts: for sub in sub_dicts:
sub_unsub_stream_ids.add(sub["recipient__type_id"]) sub_unsub_stream_ids.add(sub["stream_id"])
stream = stream_hash.get(sub["recipient__type_id"]) stream = stream_hash.get(sub["stream_id"])
if not stream: if not stream:
# This stream has been deactivated, don't include it. # This stream has been deactivated, don't include it.
continue continue

View File

@@ -0,0 +1,102 @@
from __future__ import absolute_import
from __future__ import print_function
from typing import (Dict, List)
from django.db import connection
from zerver.models import Recipient
class StreamRecipientMap(object):
'''
This class maps stream_id -> recipient_id and vice versa.
It is useful for bulk operations. Call the populate_* methods
to initialize the data structures. You should try to avoid
excessive queries by finding ids up front, but you can call
this repeatedly, and it will only look up new ids.
You should ONLY use this class for READ operations.
Note that this class uses raw SQL, because we want to highly
optimize page loads.
'''
def __init__(self):
# type: () -> None
self.recip_to_stream = dict() # type: Dict[int, int]
self.stream_to_recip = dict() # type: Dict[int, int]
def populate_for_stream_ids(self, stream_ids):
# type: (List[int]) -> None
stream_ids = sorted([
stream_id for stream_id in stream_ids
if stream_id not in self.stream_to_recip
])
if not stream_ids:
return
# see comment at the top of the class
id_list = ', '.join(str(stream_id) for stream_id in stream_ids)
query = '''
SELECT
zerver_recipient.id as recipient_id,
zerver_stream.id as stream_id
FROM
zerver_stream
INNER JOIN zerver_recipient ON
zerver_stream.id = zerver_recipient.type_id
WHERE
zerver_recipient.type = %d
AND
zerver_stream.id in (%s)
''' % (Recipient.STREAM, id_list)
self._process_query(query)
def populate_for_recipient_ids(self, recipient_ids):
# type: (List[int]) -> None
recipient_ids = sorted([
recip_id for recip_id in recipient_ids
if recip_id not in self.recip_to_stream
])
if not recipient_ids:
return
# see comment at the top of the class
id_list = ', '.join(str(recip_id) for recip_id in recipient_ids)
query = '''
SELECT
zerver_recipient.id as recipient_id,
zerver_stream.id as stream_id
FROM
zerver_recipient
INNER JOIN zerver_stream ON
zerver_stream.id = zerver_recipient.type_id
WHERE
zerver_recipient.type = %d
AND
zerver_recipient.id in (%s)
''' % (Recipient.STREAM, id_list)
self._process_query(query)
def _process_query(self, query):
# type: (str) -> None
cursor = connection.cursor()
cursor.execute(query)
rows = cursor.fetchall()
cursor.close()
for recip_id, stream_id in rows:
self.recip_to_stream[recip_id] = stream_id
self.stream_to_recip[stream_id] = recip_id
def recipient_id_for(self, stream_id):
# type: (int) -> int
return self.stream_to_recip[stream_id]
def stream_id_for(self, recip_id):
# type: (int) -> int
return self.recip_to_stream[recip_id]
def recipient_to_stream_id_dict(self):
# type: () -> Dict[int, int]
return self.recip_to_stream

View File

@@ -446,6 +446,9 @@ class EventsRegisterTest(ZulipTestCase):
def normalize(state): def normalize(state):
# type: (Dict[str, Any]) -> None # type: (Dict[str, Any]) -> None
state['realm_users'] = {u['email']: u for u in state['realm_users']} state['realm_users'] = {u['email']: u for u in state['realm_users']}
for u in state['never_subscribed']:
if 'subscribers' in u:
u['subscribers'].sort()
for u in state['subscriptions']: for u in state['subscriptions']:
if 'subscribers' in u: if 'subscribers' in u:
u['subscribers'].sort() u['subscribers'].sort()

View File

@@ -2469,7 +2469,7 @@ class GetSubscribersTest(ZulipTestCase):
if not sub["name"].startswith("stream_"): if not sub["name"].startswith("stream_"):
continue continue
self.assertTrue(len(sub["subscribers"]) == len(users_to_subscribe)) self.assertTrue(len(sub["subscribers"]) == len(users_to_subscribe))
self.assert_length(queries, 4) self.assert_length(queries, 6)
@slow("common_subscribe_to_streams is slow") @slow("common_subscribe_to_streams is slow")
def test_never_subscribed_streams(self): def test_never_subscribed_streams(self):
@@ -2527,7 +2527,7 @@ class GetSubscribersTest(ZulipTestCase):
with queries_captured() as queries: with queries_captured() as queries:
sub_data = gather_subscriptions_helper(self.user_profile) sub_data = gather_subscriptions_helper(self.user_profile)
never_subscribed = sub_data[2] never_subscribed = sub_data[2]
self.assert_length(queries, 3) self.assert_length(queries, 5)
# Ignore old streams. # Ignore old streams.
never_subscribed = [ never_subscribed = [
@@ -2595,7 +2595,7 @@ class GetSubscribersTest(ZulipTestCase):
self.assertTrue(len(sub["subscribers"]) == len(users_to_subscribe)) self.assertTrue(len(sub["subscribers"]) == len(users_to_subscribe))
else: else:
self.assertTrue(len(sub["subscribers"]) == 0) self.assertTrue(len(sub["subscribers"]) == 0)
self.assert_length(queries, 4) self.assert_length(queries, 5)
def test_nonsubscriber(self): def test_nonsubscriber(self):
# type: () -> None # type: () -> None