message: Add a bulk_access_stream_messages_query method.

This applies access restrictions in SQL, so that individual messages
do not need to be walked one-by-one.  It only functions for stream
messages.

Use of this method significantly speeds up checks if we moved "all
visible messages" in a topic, since we no longer need to walk every
remaining message in the old topic to determine that at least one was
visible to the user.  Similarly, it significantly speeds up merging
into existing topics, since it no longer must walk every message in
the new topic to determine if the user could see at least one.

Finally, it unlocks the ability to bulk-update only messages the user
has access to, in a single query (see subsequent commit).

(cherry picked from commit 7dcc7540f9)
This commit is contained in:
Alex Vandiver
2023-09-26 15:34:55 +00:00
committed by Tim Abbott
parent 9ac6ca1545
commit 9a2a5b5910
6 changed files with 119 additions and 79 deletions

View File

@@ -33,7 +33,7 @@ from zerver.lib.markdown import version as markdown_version
from zerver.lib.mention import MentionBackend, MentionData, silent_mention_syntax_for_user from zerver.lib.mention import MentionBackend, MentionData, silent_mention_syntax_for_user
from zerver.lib.message import ( from zerver.lib.message import (
access_message, access_message,
bulk_access_messages, bulk_access_stream_messages_query,
check_user_group_mention_allowed, check_user_group_mention_allowed,
normalize_body, normalize_body,
stream_wildcard_mention_allowed, stream_wildcard_mention_allowed,
@@ -808,27 +808,23 @@ def do_update_message(
# full-topic move. # full-topic move.
# #
# For security model reasons, we don't want to allow a # For security model reasons, we don't want to allow a
# user to take any action that would leak information # user to take any action (e.g. post a message about
# about older messages they cannot access (E.g. the only # having not moved the whole topic) that would leak
# remaining messages are in a stream without shared # information about older messages they cannot access
# history). The bulk_access_messages call below addresses # (e.g. there were earlier inaccessible messages in the
# topic, in a stream without shared history). The
# bulk_access_stream_messages_query call below addresses
# that concern. # that concern.
#
# bulk_access_messages is inefficient for this task, since
# we just want to do the exists() version of this
# query. But it's nice to reuse code, and this bulk
# operation is likely cheaper than a `GET /messages`
# unless the topic has thousands of messages of history.
assert stream_being_edited.recipient_id is not None assert stream_being_edited.recipient_id is not None
unmoved_messages = messages_for_topic( unmoved_messages = messages_for_topic(
realm.id, realm.id,
stream_being_edited.recipient_id, stream_being_edited.recipient_id,
orig_topic_name, orig_topic_name,
) )
visible_unmoved_messages = bulk_access_messages( visible_unmoved_messages = bulk_access_stream_messages_query(
user_profile, unmoved_messages, stream=stream_being_edited user_profile, unmoved_messages, stream_being_edited
) )
moved_all_visible_messages = len(visible_unmoved_messages) == 0 moved_all_visible_messages = not visible_unmoved_messages.exists()
# Migrate 'topic with visibility_policy' configuration in the following # Migrate 'topic with visibility_policy' configuration in the following
# circumstances: # circumstances:
@@ -1045,24 +1041,15 @@ def do_update_message(
# avoid leaking information about whether there are # avoid leaking information about whether there are
# messages in the destination topic's deeper history that # messages in the destination topic's deeper history that
# the acting user does not have permission to access. # the acting user does not have permission to access.
#
# TODO: These queries are quite inefficient, in that we're
# fetching full copies of all the messages in the
# destination topic to answer the question of whether the
# current user has access to at least one such message.
#
# The main strength of the current implementation is that
# it reuses existing logic, which is good for keeping it
# correct as we maintain the codebase.
preexisting_topic_messages = messages_for_topic( preexisting_topic_messages = messages_for_topic(
realm.id, stream_for_new_topic.recipient_id, new_topic realm.id, stream_for_new_topic.recipient_id, new_topic
).exclude(id__in=[*changed_message_ids, resolved_topic_message_id]) ).exclude(id__in=[*changed_message_ids, resolved_topic_message_id])
visible_preexisting_messages = bulk_access_messages( visible_preexisting_messages = bulk_access_stream_messages_query(
user_profile, preexisting_topic_messages, stream=stream_for_new_topic user_profile, preexisting_topic_messages, stream_for_new_topic
) )
no_visible_preexisting_messages = len(visible_preexisting_messages) == 0 no_visible_preexisting_messages = not visible_preexisting_messages.exists()
if no_visible_preexisting_messages and moved_all_visible_messages: if no_visible_preexisting_messages and moved_all_visible_messages:
new_thread_notification_string = gettext_lazy( new_thread_notification_string = gettext_lazy(

View File

@@ -21,7 +21,7 @@ import ahocorasick
import orjson import orjson
from django.conf import settings from django.conf import settings
from django.db import connection from django.db import connection
from django.db.models import Max, QuerySet, Sum from django.db.models import Exists, Max, OuterRef, QuerySet, Sum
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django_stubs_ext import ValuesQuerySet from django_stubs_ext import ValuesQuerySet
@@ -996,6 +996,38 @@ def bulk_access_messages(
return filtered_messages return filtered_messages
def bulk_access_stream_messages_query(
user_profile: UserProfile, messages: QuerySet[Message], stream: Stream
) -> QuerySet[Message]:
"""This function mirrors bulk_access_messages, above, but applies the
limits to a QuerySet and returns a new QuerySet which only
contains messages in the given stream which the user can access.
Note that this only works with streams. It may return an empty
QuerySet if the user has access to no messages (for instance, for
a private stream which the user is not subscribed to).
"""
messages = messages.filter(realm_id=user_profile.realm_id, recipient_id=stream.recipient_id)
if stream.is_public() and user_profile.can_access_public_streams():
return messages
if not Subscription.objects.filter(
user_profile=user_profile, active=True, recipient=stream.recipient
).exists():
return Message.objects.none()
if not stream.is_history_public_to_subscribers():
messages = messages.annotate(
has_usermessage=Exists(
UserMessage.objects.filter(
user_profile_id=user_profile.id, message_id=OuterRef("id")
)
)
).filter(has_usermessage=1)
return messages
def get_messages_with_usermessage_rows_for_user( def get_messages_with_usermessage_rows_for_user(
user_profile_id: int, message_ids: Sequence[int] user_profile_id: int, message_ids: Sequence[int]
) -> ValuesQuerySet[UserMessage, int]: ) -> ValuesQuerySet[UserMessage, int]:

View File

@@ -173,9 +173,9 @@ def update_messages_for_topic_edit(
# If we're moving the messages between streams, only move # If we're moving the messages between streams, only move
# messages that the acting user can access, so that one cannot # messages that the acting user can access, so that one cannot
# gain access to messages through moving them. # gain access to messages through moving them.
from zerver.lib.message import bulk_access_messages from zerver.lib.message import bulk_access_stream_messages_query
messages_list = bulk_access_messages(acting_user, messages, stream=old_stream) messages_list = list(bulk_access_stream_messages_query(acting_user, messages, old_stream))
else: else:
# For single-message edits or topic moves within a stream, we # For single-message edits or topic moves within a stream, we
# allow moving history the user may not have access in order # allow moving history the user may not have access in order

View File

@@ -1530,7 +1530,7 @@ class EditMessageTest(EditMessageTestCase):
set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED) set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED)
set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED) set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED)
with self.assert_database_query_count(29): with self.assert_database_query_count(28):
check_update_message( check_update_message(
user_profile=desdemona, user_profile=desdemona,
message_id=message_id, message_id=message_id,
@@ -1561,7 +1561,7 @@ class EditMessageTest(EditMessageTestCase):
set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED) set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED)
set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED) set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED)
with self.assert_database_query_count(34): with self.assert_database_query_count(33):
check_update_message( check_update_message(
user_profile=desdemona, user_profile=desdemona,
message_id=message_id, message_id=message_id,
@@ -1594,7 +1594,7 @@ class EditMessageTest(EditMessageTestCase):
set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED) set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED)
set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED) set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED)
with self.assert_database_query_count(29): with self.assert_database_query_count(28):
check_update_message( check_update_message(
user_profile=desdemona, user_profile=desdemona,
message_id=message_id, message_id=message_id,
@@ -1617,7 +1617,7 @@ class EditMessageTest(EditMessageTestCase):
second_message_id = self.send_stream_message( second_message_id = self.send_stream_message(
hamlet, stream_name, topic_name="changed topic name", content="Second message" hamlet, stream_name, topic_name="changed topic name", content="Second message"
) )
with self.assert_database_query_count(25): with self.assert_database_query_count(23):
check_update_message( check_update_message(
user_profile=desdemona, user_profile=desdemona,
message_id=second_message_id, message_id=second_message_id,
@@ -3785,7 +3785,7 @@ class EditMessageTest(EditMessageTestCase):
"iago", "test move stream", "new stream", "test" "iago", "test move stream", "new stream", "test"
) )
with self.assert_database_query_count(55), self.assert_memcached_count(14): with self.assert_database_query_count(52), self.assert_memcached_count(14):
result = self.client_patch( result = self.client_patch(
f"/json/messages/{msg_id}", f"/json/messages/{msg_id}",
{ {

View File

@@ -19,6 +19,7 @@ from zerver.lib.message import (
aggregate_unread_data, aggregate_unread_data,
apply_unread_message_event, apply_unread_message_event,
bulk_access_messages, bulk_access_messages,
bulk_access_stream_messages_query,
format_unread_message_details, format_unread_message_details,
get_raw_unread_data, get_raw_unread_data,
) )
@@ -1505,6 +1506,30 @@ class MessageAccessTests(ZulipTestCase):
result = self.change_star(message_id) result = self.change_star(message_id)
self.assert_json_success(result) self.assert_json_success(result)
def assert_bulk_access(
self,
user: UserProfile,
message_ids: List[int],
stream: Stream,
bulk_access_messages_count: int,
bulk_access_stream_messages_query_count: int,
) -> List[Message]:
with self.assert_database_query_count(bulk_access_messages_count):
messages = [
Message.objects.select_related("recipient").get(id=message_id)
for message_id in sorted(message_ids)
]
list_result = bulk_access_messages(user, messages, stream=stream)
with self.assert_database_query_count(bulk_access_stream_messages_query_count):
message_query = (
Message.objects.select_related("recipient")
.filter(id__in=message_ids)
.order_by("id")
)
query_result = list(bulk_access_stream_messages_query(user, message_query, stream))
self.assertEqual(query_result, list_result)
return list_result
def test_bulk_access_messages_private_stream(self) -> None: def test_bulk_access_messages_private_stream(self) -> None:
user = self.example_user("hamlet") user = self.example_user("hamlet")
self.login_user(user) self.login_user(user)
@@ -1526,16 +1551,12 @@ class MessageAccessTests(ZulipTestCase):
message_two_id = self.send_stream_message(user, stream_name, "Message two") message_two_id = self.send_stream_message(user, stream_name, "Message two")
message_ids = [message_one_id, message_two_id] message_ids = [message_one_id, message_two_id]
messages = [
Message.objects.select_related("recipient").get(id=message_id)
for message_id in message_ids
]
with self.assert_database_query_count(2):
filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream)
# Message sent before subscribing wouldn't be accessible by later # Message sent before subscribing wouldn't be accessible by later
# subscribed user as stream has protected history # subscribed user as stream has protected history
filtered_messages = self.assert_bulk_access(
later_subscribed_user, message_ids, stream, 4, 2
)
self.assert_length(filtered_messages, 1) self.assert_length(filtered_messages, 1)
self.assertEqual(filtered_messages[0].id, message_two_id) self.assertEqual(filtered_messages[0].id, message_two_id)
@@ -1547,27 +1568,44 @@ class MessageAccessTests(ZulipTestCase):
acting_user=self.example_user("cordelia"), acting_user=self.example_user("cordelia"),
) )
with self.assert_database_query_count(2): # Message sent before subscribing are accessible by user as stream
filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream) # now don't have protected history
filtered_messages = self.assert_bulk_access(
# Message sent before subscribing are accessible by 8user as stream later_subscribed_user, message_ids, stream, 4, 2
# don't have protected history )
self.assert_length(filtered_messages, 2) self.assert_length(filtered_messages, 2)
# Testing messages accessibility for an unsubscribed user # Testing messages accessibility for an unsubscribed user
unsubscribed_user = self.example_user("ZOE") unsubscribed_user = self.example_user("ZOE")
filtered_messages = self.assert_bulk_access(unsubscribed_user, message_ids, stream, 4, 1)
with self.assert_database_query_count(2):
filtered_messages = bulk_access_messages(unsubscribed_user, messages, stream=stream)
self.assert_length(filtered_messages, 0) self.assert_length(filtered_messages, 0)
# Adding more message ids to the list increases the query size
# for bulk_access_messages but not
# bulk_access_stream_messages_query
more_message_ids = [
*message_ids,
self.send_stream_message(user, stream_name, "Message three"),
self.send_stream_message(user, stream_name, "Message four"),
]
filtered_messages = self.assert_bulk_access(
later_subscribed_user, more_message_ids, stream, 6, 2
)
self.assert_length(filtered_messages, 4)
# Verify an exception is thrown if called where the passed # Verify an exception is thrown if called where the passed
# stream not matching the messages. # stream not matching the messages.
other_stream = get_stream("Denmark", unsubscribed_user.realm)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
bulk_access_messages( messages = [Message.objects.get(id=id) for id in message_ids]
unsubscribed_user, messages, stream=get_stream("Denmark", unsubscribed_user.realm) bulk_access_messages(unsubscribed_user, messages, stream=other_stream)
)
# Verify that bulk_access_stream_messages_query is empty with a stream mismatch
message_query = Message.objects.select_related("recipient").filter(id__in=message_ids)
filtered_query = bulk_access_stream_messages_query(
later_subscribed_user, message_query, other_stream
)
self.assert_length(filtered_query, 0)
def test_bulk_access_messages_public_stream(self) -> None: def test_bulk_access_messages_public_stream(self) -> None:
user = self.example_user("hamlet") user = self.example_user("hamlet")
@@ -1585,20 +1623,15 @@ class MessageAccessTests(ZulipTestCase):
message_two_id = self.send_stream_message(user, stream_name, "Message two") message_two_id = self.send_stream_message(user, stream_name, "Message two")
message_ids = [message_one_id, message_two_id] message_ids = [message_one_id, message_two_id]
messages = [
Message.objects.select_related("recipient").get(id=message_id)
for message_id in message_ids
]
# All public stream messages are always accessible # All public stream messages are always accessible
with self.assert_database_query_count(2): filtered_messages = self.assert_bulk_access(
filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream) later_subscribed_user, message_ids, stream, 4, 1
)
self.assert_length(filtered_messages, 2) self.assert_length(filtered_messages, 2)
unsubscribed_user = self.example_user("ZOE") unsubscribed_user = self.example_user("ZOE")
with self.assert_database_query_count(2): filtered_messages = self.assert_bulk_access(unsubscribed_user, message_ids, stream, 4, 1)
filtered_messages = bulk_access_messages(unsubscribed_user, messages, stream=stream)
self.assert_length(filtered_messages, 2) self.assert_length(filtered_messages, 2)

View File

@@ -54,6 +54,7 @@ from zerver.lib.exceptions import (
ResourceNotFoundError, ResourceNotFoundError,
) )
from zerver.lib.mention import MentionBackend, silent_mention_syntax_for_user from zerver.lib.mention import MentionBackend, silent_mention_syntax_for_user
from zerver.lib.message import bulk_access_stream_messages_query
from zerver.lib.request import REQ, has_request_variables from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.retention import STREAM_MESSAGE_BATCH_SIZE as RETENTION_STREAM_MESSAGE_BATCH_SIZE from zerver.lib.retention import STREAM_MESSAGE_BATCH_SIZE as RETENTION_STREAM_MESSAGE_BATCH_SIZE
@@ -99,7 +100,7 @@ from zerver.lib.validator import (
check_union, check_union,
to_non_negative_int, to_non_negative_int,
) )
from zerver.models import Realm, Stream, UserGroup, UserMessage, UserProfile from zerver.models import Realm, Stream, UserGroup, UserProfile
from zerver.models.users import get_system_bot from zerver.models.users import get_system_bot
@@ -925,22 +926,9 @@ def delete_in_topic(
messages = messages_for_topic( messages = messages_for_topic(
user_profile.realm_id, assert_is_not_none(stream.recipient_id), topic_name user_profile.realm_id, assert_is_not_none(stream.recipient_id), topic_name
) )
# Note: It would be better to use bulk_access_messages here, which is our core function # This handles applying access control, such that only messages
# for obtaining the accessible messages - and it's good to use it wherever we can, # the user can see are returned in the query.
# so that we have a central place to keep up to date with our security model for messages = bulk_access_stream_messages_query(user_profile, messages, stream)
# message access.
# However, it fetches the full Message objects, which would be bad here for very large
# topics.
# The access_stream_by_id call above ensures that the acting user currently has access to the
# stream (which entails having an active Subscription in case of private streams), meaning
# that combined with the UserMessage check below, this is a sufficient replacement for
# bulk_access_messages.
if not stream.is_history_public_to_subscribers():
# Don't allow the user to delete messages that they don't have access to.
deletable_message_ids = UserMessage.objects.filter(
user_profile=user_profile, message_id__in=messages
).values_list("message_id", flat=True)
messages = messages.filter(id__in=deletable_message_ids)
def delete_in_batches() -> Literal[True]: def delete_in_batches() -> Literal[True]:
# Topics can be large enough that this request will inevitably time out. # Topics can be large enough that this request will inevitably time out.