mention: Use filter_stream_authorization.

In preparation for accessing the messages in channels to link topics
in them, we need to check channel access.
This commit is contained in:
Tim Abbott
2025-01-30 17:25:29 -08:00
parent bb8d0684e7
commit bd8b845a4d
5 changed files with 39 additions and 14 deletions

View File

@@ -2740,15 +2740,18 @@ def do_convert(
# the fetches are somewhat expensive and these types of syntax
# are uncommon enough that it's a useful optimization.
if mention_data is None:
mention_backend = MentionBackend(message_realm.id)
message_sender = None
if message is not None:
message_sender = message.sender
if mention_data is None:
mention_backend = MentionBackend(message_realm.id)
mention_data = MentionData(mention_backend, content, message_sender)
stream_names = possible_linked_stream_names(content)
stream_name_info = mention_data.get_stream_name_map(stream_names)
stream_name_info = mention_data.get_stream_name_map(
stream_names, acting_user=message_sender
)
if content_has_emoji_syntax(content):
active_realm_emoji = get_name_keyed_dict_for_active_realm_emoji(message_realm.id)

View File

@@ -8,6 +8,7 @@ from django.conf import settings
from django.db.models import Q
from django_stubs_ext import StrPromise
from zerver.lib.streams import filter_stream_authorization
from zerver.lib.user_groups import get_root_id_annotated_recursive_subgroups_for_groups
from zerver.lib.users import get_inaccessible_user_ids
from zerver.models import NamedUserGroup, UserProfile
@@ -140,7 +141,9 @@ class MentionBackend:
return result
def get_stream_name_map(self, stream_names: set[str]) -> dict[str, int]:
def get_stream_name_map(
self, stream_names: set[str], acting_user: UserProfile | None
) -> dict[str, int]:
if not stream_names:
return {}
@@ -153,9 +156,11 @@ class MentionBackend:
else:
unseen_stream_names.append(stream_name)
if unseen_stream_names:
q_list = {Q(name=name) for name in unseen_stream_names}
if not unseen_stream_names:
return result
q_list = {Q(name=name) for name in unseen_stream_names}
if acting_user is None:
rows = (
get_linkable_streams(
realm_id=self.realm_id,
@@ -168,10 +173,24 @@ class MentionBackend:
"name",
)
)
for row in rows:
self.stream_cache[row["name"]] = row["id"]
result[row["name"]] = row["id"]
else:
authorization = filter_stream_authorization(
acting_user,
list(
get_linkable_streams(
realm_id=self.realm_id,
).filter(
functools.reduce(lambda a, b: a | b, q_list),
)
),
is_subscribing_other_users=False,
)
for stream in authorization.authorized_streams:
self.stream_cache[stream.name] = stream.id
result[stream.name] = stream.id
return result
@@ -334,8 +353,10 @@ class MentionData:
def get_group_members(self, user_group_id: int) -> set[int]:
return self.user_group_members.get(user_group_id, set())
def get_stream_name_map(self, stream_names: set[str]) -> dict[str, int]:
return self.mention_backend.get_stream_name_map(stream_names)
def get_stream_name_map(
self, stream_names: set[str], acting_user: UserProfile | None
) -> dict[str, int]:
return self.mention_backend.get_stream_name_map(stream_names, acting_user=acting_user)
def silent_mention_syntax_for_user(user_profile: UserProfile) -> str:

View File

@@ -15,7 +15,6 @@ from zerver.lib.exceptions import (
JsonableError,
OrganizationOwnerRequiredError,
)
from zerver.lib.markdown import markdown_convert
from zerver.lib.stream_subscription import (
get_active_subscriptions_for_stream_id,
get_subscribed_stream_ids_for_user,
@@ -128,6 +127,8 @@ def get_default_value_for_history_public_to_subscribers(
def render_stream_description(text: str, realm: Realm) -> str:
from zerver.lib.markdown import markdown_convert
return markdown_convert(text, message_realm=realm, no_previews=True).rendered_content

View File

@@ -1201,7 +1201,7 @@ class MessageMoveStreamTest(ZulipTestCase):
"iago", "test move stream", "new stream", "test"
)
with self.assert_database_query_count(55), self.assert_memcached_count(14):
with self.assert_database_query_count(57), self.assert_memcached_count(14):
result = self.client_patch(
f"/json/messages/{msg_id}",
{

View File

@@ -6202,7 +6202,7 @@ class SubscriptionAPITest(ZulipTestCase):
new_stream_announcements_stream = get_stream(self.streams[0], self.test_realm)
self.test_realm.new_stream_announcements_stream_id = new_stream_announcements_stream.id
self.test_realm.save()
with self.assert_database_query_count(53):
with self.assert_database_query_count(54):
self.subscribe_via_post(
self.test_user,
[new_streams[2]],
@@ -6698,7 +6698,7 @@ class GetSubscribersTest(ZulipTestCase):
polonius.id,
]
with self.assert_database_query_count(49):
with self.assert_database_query_count(50):
self.subscribe_via_post(
self.user_profile,
streams,