import logging
from collections import defaultdict
from datetime import datetime
from typing import Callable, TypedDict

from django.db import connection, transaction
from django.db.models import QuerySet
from django.utils.timezone import now as timezone_now
from psycopg2.sql import SQL, Literal
from sqlalchemy.sql import ClauseElement, and_, column, not_, or_
from sqlalchemy.types import Integer

from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.topic_sqlalchemy import topic_match_sa
from zerver.lib.types import UserTopicDict
from zerver.models import UserProfile, UserTopic
from zerver.models.streams import get_stream


def get_user_topics(
    user_profile: UserProfile,
    include_deactivated: bool = False,
    include_stream_name: bool = False,
    visibility_policy: int | None = None,
) -> list[UserTopicDict]:
    """
    Fetches UserTopic objects associated with the target user.

    * include_deactivated: Whether to include those associated with
      deactivated streams.
    * include_stream_name: Whether to include stream names in the
      returned dictionaries.
    * visibility_policy: If specified, returns only UserTopic objects
      with the specified visibility_policy value.

    Each returned dictionary carries "stream_id", "topic_name",
    "visibility_policy" and "last_updated" (converted with
    datetime_to_timestamp), plus "stream__name" when requested.
    """
    query = UserTopic.objects.filter(user_profile=user_profile)

    if visibility_policy is not None:
        query = query.filter(visibility_policy=visibility_policy)

    # Exclude user topics that are part of deactivated streams unless
    # explicitly requested.
    if not include_deactivated:
        query = query.filter(stream__deactivated=False)

    rows = query.values(
        "stream_id", "stream__name", "topic_name", "last_updated", "visibility_policy"
    )

    result = []
    for row in rows:
        user_topic_dict: UserTopicDict = {
            "stream_id": row["stream_id"],
            "topic_name": row["topic_name"],
            "visibility_policy": row["visibility_policy"],
            "last_updated": datetime_to_timestamp(row["last_updated"]),
        }

        if include_stream_name:
            user_topic_dict["stream__name"] = row["stream__name"]

        result.append(user_topic_dict)

    return result


def get_topic_mutes(
    user_profile: UserProfile, include_deactivated: bool = False
) -> list[tuple[str, str, int]]:
    """
    Returns the user's muted topics as
    (stream name, topic name, last_updated timestamp) tuples.

    * include_deactivated: Whether to include topics in deactivated
      streams; forwarded to get_user_topics.
    """
    user_topics = get_user_topics(
        user_profile=user_profile,
        include_deactivated=include_deactivated,
        include_stream_name=True,
        visibility_policy=UserTopic.VisibilityPolicy.MUTED,
    )

    return [
        (user_topic["stream__name"], user_topic["topic_name"], user_topic["last_updated"])
        for user_topic in user_topics
    ]


@transaction.atomic(savepoint=False)
def set_topic_visibility_policy(
    user_profile: UserProfile,
    topics: list[list[str]],
    visibility_policy: int,
    last_updated: datetime | None = None,
) -> None:
    """
    This is only used in tests.

    Deletes every UserTopic row the user currently has with the given
    visibility_policy, then sets that policy on each
    [stream_name, topic_name] pair in `topics`. When last_updated is
    omitted, the current time is used for all rows.
    """

    UserTopic.objects.filter(
        user_profile=user_profile,
        visibility_policy=visibility_policy,
    ).delete()

    if last_updated is None:
        last_updated = timezone_now()
    for stream_name, topic_name in topics:
        stream = get_stream(stream_name, user_profile.realm)
        recipient_id = stream.recipient_id
        assert recipient_id is not None

        bulk_set_user_topic_visibility_policy_in_database(
            user_profiles=[user_profile],
            stream_id=stream.id,
            recipient_id=recipient_id,
            topic_name=topic_name,
            visibility_policy=visibility_policy,
            last_updated=last_updated,
        )


def get_topic_visibility_policy(
    user_profile: UserProfile,
    stream_id: int,
    topic_name: str,
) -> int:
    """
    Returns the visibility policy the user has for the given topic,
    matching the topic name case-insensitively. Falls back to
    UserTopic.VisibilityPolicy.INHERIT when no UserTopic row exists.
    """
    try:
        user_topic = UserTopic.objects.get(
            user_profile=user_profile, stream_id=stream_id, topic_name__iexact=topic_name
        )
        visibility_policy = user_topic.visibility_policy
    except UserTopic.DoesNotExist:
        visibility_policy = UserTopic.VisibilityPolicy.INHERIT
    return visibility_policy


@transaction.atomic(savepoint=False)
def bulk_set_user_topic_visibility_policy_in_database(
    user_profiles: list[UserProfile],
    stream_id: int,
    topic_name: str,
    *,
    visibility_policy: int,
    recipient_id: int | None = None,
    last_updated: datetime | None = None,
) -> list[UserProfile]:
    """
    Sets visibility_policy for the given topic for all of user_profiles.

    * visibility_policy == INHERIT: the matching UserTopic rows are
      deleted; recipient_id and last_updated are not needed.
    * any other value: rows are inserted or updated in a single
      INSERT ... ON CONFLICT statement; recipient_id and last_updated
      must both be provided (enforced with assertions).

    Topic names are matched case-insensitively.

    Returns the list of user_profiles whose user_topic row
    is either deleted, updated, or created.
    """
    rows = UserTopic.objects.filter(
        user_profile__in=user_profiles,
        stream_id=stream_id,
        topic_name__iexact=topic_name,
    ).select_related("user_profile", "user_profile__realm")

    user_profiles_with_visibility_policy = [row.user_profile for row in rows]
    user_profiles_without_visibility_policy = list(
        set(user_profiles) - set(user_profiles_with_visibility_policy)
    )

    if visibility_policy == UserTopic.VisibilityPolicy.INHERIT:
        for user_profile in user_profiles_without_visibility_policy:
            # The user doesn't already have a visibility_policy for this topic.
            logging.info(
                "User %s tried to remove visibility_policy, which actually doesn't exist",
                user_profile.id,
            )
        rows.delete()
        return user_profiles_with_visibility_policy

    assert last_updated is not None
    assert recipient_id is not None

    # Users without a row always need one created; users with a row are
    # appended below only when their stored policy differs.
    user_profiles_seeking_user_topic_update_or_create: list[UserProfile] = (
        user_profiles_without_visibility_policy
    )
    for row in rows:
        if row.visibility_policy == visibility_policy:
            logging.info(
                "User %s tried to set visibility_policy to its current value of %s",
                row.user_profile_id,
                visibility_policy,
            )
            continue
        # The request is to just 'update' the visibility policy of a topic
        user_profiles_seeking_user_topic_update_or_create.append(row.user_profile)

    if user_profiles_seeking_user_topic_update_or_create:
        # Every value is embedded through psycopg2's Literal, so the
        # composed statement is escaped by the driver rather than built
        # by string concatenation.
        user_profile_ids_array = SQL("ARRAY[{}]").format(
            SQL(", ").join(
                [
                    Literal(user_profile.id)
                    for user_profile in user_profiles_seeking_user_topic_update_or_create
                ]
            )
        )

        query = SQL("""
            INSERT INTO zerver_usertopic(user_profile_id, stream_id, recipient_id, topic_name, last_updated, visibility_policy)
            SELECT * FROM UNNEST({user_profile_ids_array}) AS user_profile(user_profile_id)
            CROSS JOIN (VALUES ({stream_id}, {recipient_id}, {topic_name}, {last_updated}, {visibility_policy}))
            AS other_values(stream_id, recipient_id, topic_name, last_updated, visibility_policy)
            ON CONFLICT (user_profile_id, stream_id, lower(topic_name)) DO UPDATE SET
            last_updated = EXCLUDED.last_updated,
            visibility_policy = EXCLUDED.visibility_policy;
        """).format(
            user_profile_ids_array=user_profile_ids_array,
            stream_id=Literal(stream_id),
            recipient_id=Literal(recipient_id),
            topic_name=Literal(topic_name),
            last_updated=Literal(last_updated),
            visibility_policy=Literal(visibility_policy),
        )

        with connection.cursor() as cursor:
            cursor.execute(query)

    return user_profiles_seeking_user_topic_update_or_create


def topic_has_visibility_policy(
    user_profile: UserProfile, stream_id: int, topic_name: str, visibility_policy: int
) -> bool:
    """
    Returns whether the user's visibility policy for the topic
    (matched case-insensitively) equals visibility_policy.

    INHERIT is treated as the absence of any UserTopic row for the
    topic, so it is checked by testing that no row exists.
    """
    if visibility_policy == UserTopic.VisibilityPolicy.INHERIT:
        has_user_topic_row = UserTopic.objects.filter(
            user_profile=user_profile, stream_id=stream_id, topic_name__iexact=topic_name
        ).exists()
        return not has_user_topic_row
    has_visibility_policy = UserTopic.objects.filter(
        user_profile=user_profile,
        stream_id=stream_id,
        topic_name__iexact=topic_name,
        visibility_policy=visibility_policy,
    ).exists()
    return has_visibility_policy


def exclude_topic_mutes(
    conditions: list[ClauseElement], user_profile: UserProfile, stream_id: int | None
) -> list[ClauseElement]:
    """
    Returns `conditions` extended with one SQLAlchemy clause that
    excludes every topic the user has muted. When the user has no
    relevant muted topics, `conditions` is returned as is.

    * stream_id: If not None, only muted topics in that stream are
      considered.
    """
    # Note: Unlike get_topic_mutes, here we always want to
    # consider topics in deactivated streams, so they are
    # never filtered from the query in this method.
    query = UserTopic.objects.filter(
        user_profile=user_profile,
        visibility_policy=UserTopic.VisibilityPolicy.MUTED,
    )

    if stream_id is not None:
        # If we are narrowed to a stream, we can optimize the query
        # by not considering topic mutes outside the stream.
        query = query.filter(stream_id=stream_id)

    rows = query.values(
        "recipient_id",
        "topic_name",
    )

    if not rows:
        return conditions

    class RecipientTopicDict(TypedDict):
        recipient_id: int
        topic_name: str

    def mute_cond(row: RecipientTopicDict) -> ClauseElement:
        # Matches messages in one muted (recipient, topic) pair.
        recipient_id = row["recipient_id"]
        topic_name = row["topic_name"]
        stream_cond = column("recipient_id", Integer) == recipient_id
        topic_cond = topic_match_sa(topic_name)
        return and_(stream_cond, topic_cond)

    condition = not_(or_(*map(mute_cond, rows)))
    return [*conditions, condition]


def build_get_topic_visibility_policy(
    user_profile: UserProfile,
) -> Callable[[int, str], int]:
    """Prefetch the visibility policies the user has configured for
    various topics.

    The prefetching helps to avoid the db queries later in the loop
    to determine the user's visibility policy for a topic.

    Returns a function taking (recipient_id, topic_name) and returning
    the stored visibility policy, or 0 (the defaultdict(int) default)
    when the user has no row for that topic. Topic names are compared
    case-insensitively.
    """
    rows = UserTopic.objects.filter(user_profile=user_profile).values(
        "recipient_id",
        "topic_name",
        "visibility_policy",
    )

    topic_to_visibility_policy: dict[tuple[int, str], int] = defaultdict(int)
    for row in rows:
        recipient_id = row["recipient_id"]
        topic_name = row["topic_name"]
        visibility_policy = row["visibility_policy"]
        # Lowercase the key on insert: the lookup below lowercases the
        # requested topic name, so a key kept in its stored case would
        # never match a topic containing uppercase characters.
        topic_to_visibility_policy[(recipient_id, topic_name.lower())] = visibility_policy

    def get_topic_visibility_policy(recipient_id: int, topic_name: str) -> int:
        return topic_to_visibility_policy[(recipient_id, topic_name.lower())]

    return get_topic_visibility_policy


def get_users_with_user_topic_visibility_policy(
    stream_id: int, topic_name: str
) -> QuerySet[UserTopic]:
    """
    Returns the UserTopic rows of all users for the given topic
    (matched case-insensitively), with user_profile and its realm
    loaded via select_related.
    """
    return UserTopic.objects.filter(
        stream_id=stream_id, topic_name__iexact=topic_name
    ).select_related("user_profile", "user_profile__realm")