bulk_access_messages: Bulk fetch Subscription details.

This completes the effort to make it possible to use
bulk_access_message in contexts where there are more than a handful of
messages without creating performance issues.
This commit is contained in:
Tim Abbott
2021-05-12 14:40:58 -07:00
committed by Tim Abbott
parent c6e1702335
commit 0bfef96543
3 changed files with 32 additions and 5 deletions

View File

@@ -31,6 +31,7 @@ from zerver.lib.markdown import version as markdown_version
from zerver.lib.request import JsonableError
from zerver.lib.stream_subscription import (
get_stream_subscriptions_for_user,
get_subscribed_stream_recipient_ids_for_user,
num_subscribers_for_stream_id,
)
from zerver.lib.timestamp import datetime_to_timestamp
@@ -729,6 +730,14 @@ def has_message_access(
def bulk_access_messages(
user_profile: UserProfile, messages: Sequence[Message], *, stream: Optional[Stream] = None
) -> List[Message]:
"""This function does the full has_message_access check for each
message. If stream is provided, it is used to avoid unnecessary
database queries, and will use exactly 2 bulk queries instead.
Throws AssertionError if stream is passed and any of the messages
were not sent to that stream.
"""
filtered_messages = []
user_message_set = set(
@@ -737,10 +746,20 @@ def bulk_access_messages(
)
)
# TODO: Ideally, we'd do a similar bulk-stream-fetch if stream is
# None, so that this function is fast with
subscribed_recipient_ids = set(get_subscribed_stream_recipient_ids_for_user(user_profile))
for message in messages:
has_user_message = message.id in user_message_set
is_subscribed = message.recipient_id in subscribed_recipient_ids
if has_message_access(
user_profile, message, has_user_message=has_user_message, stream=stream
user_profile,
message,
has_user_message=has_user_message,
stream=stream,
is_subscribed=is_subscribed,
):
filtered_messages.append(message)
return filtered_messages

View File

@@ -58,6 +58,14 @@ def get_subscribed_stream_ids_for_user(user_profile: UserProfile) -> QuerySet:
).values_list("recipient__type_id", flat=True)
def get_subscribed_stream_recipient_ids_for_user(user_profile: UserProfile) -> QuerySet:
return Subscription.objects.filter(
user_profile_id=user_profile,
recipient__type=Recipient.STREAM,
active=True,
).values_list("recipient_id", flat=True)
def get_stream_subscriptions_for_user(user_profile: UserProfile) -> QuerySet:
# TODO: Change return type to QuerySet[Subscription]
return Subscription.objects.filter(

View File

@@ -1307,7 +1307,7 @@ class MessageAccessTests(ZulipTestCase):
with queries_captured() as queries:
filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream)
self.assert_length(queries, 1)
self.assert_length(queries, 2)
# Message sent before subscribing wouldn't be accessible by later
# subscribed user as stream has protected history
@@ -1329,7 +1329,7 @@ class MessageAccessTests(ZulipTestCase):
with queries_captured() as queries:
filtered_messages = bulk_access_messages(unsubscribed_user, messages, stream=stream)
self.assert_length(queries, 3)
self.assert_length(queries, 2)
self.assertEqual(len(filtered_messages), 0)
@@ -1364,14 +1364,14 @@ class MessageAccessTests(ZulipTestCase):
with queries_captured() as queries:
filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream)
self.assertEqual(len(filtered_messages), 2)
self.assert_length(queries, 1)
self.assert_length(queries, 2)
unsubscribed_user = self.example_user("ZOE")
with queries_captured() as queries:
filtered_messages = bulk_access_messages(unsubscribed_user, messages, stream=stream)
self.assertEqual(len(filtered_messages), 2)
self.assert_length(queries, 1)
self.assert_length(queries, 2)
class PersonalMessagesFlagTest(ZulipTestCase):