diff --git a/zerver/actions/user_topics.py b/zerver/actions/user_topics.py index a6e17768b2..a8cadbea03 100644 --- a/zerver/actions/user_topics.py +++ b/zerver/actions/user_topics.py @@ -6,7 +6,11 @@ from django.utils.translation import gettext as _ from zerver.lib.exceptions import JsonableError from zerver.lib.timestamp import datetime_to_timestamp -from zerver.lib.user_topics import add_topic_mute, get_topic_mutes, remove_topic_mute +from zerver.lib.user_topics import ( + get_topic_mutes, + remove_topic_mute, + set_user_topic_visibility_policy_in_database, +) from zerver.models import Stream, UserProfile, UserTopic from zerver.tornado.django_api import send_event @@ -31,12 +35,13 @@ def do_set_user_topic_visibility_policy( raise JsonableError(_("Topic is not muted")) else: assert stream.recipient_id is not None - add_topic_mute( + set_user_topic_visibility_policy_in_database( user_profile, stream.id, - stream.recipient_id, topic, - last_updated, + visibility_policy=visibility_policy, + recipient_id=stream.recipient_id, + last_updated=last_updated, ignore_duplicate=ignore_duplicate, ) diff --git a/zerver/lib/user_topics.py b/zerver/lib/user_topics.py index be8b086c95..b683f6492e 100644 --- a/zerver/lib/user_topics.py +++ b/zerver/lib/user_topics.py @@ -1,11 +1,14 @@ import datetime -from typing import Callable, List, Optional, Tuple, TypedDict +from typing import Callable, Dict, List, Optional, Tuple, TypedDict +from django.db import transaction from django.db.models import QuerySet from django.utils.timezone import now as timezone_now +from django.utils.translation import gettext as _ from sqlalchemy.sql import ClauseElement, and_, column, not_, or_ from sqlalchemy.types import Integer +from zerver.lib.exceptions import JsonableError from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.topic import topic_match_sa from zerver.lib.types import UserTopicDict @@ -95,39 +98,62 @@ def set_topic_mutes( recipient_id = stream.recipient_id assert recipient_id is not None - add_topic_mute( + set_user_topic_visibility_policy_in_database( user_profile=user_profile, stream_id=stream.id, recipient_id=recipient_id, topic_name=topic_name, - date_muted=date_muted, + visibility_policy=UserTopic.MUTED, + last_updated=date_muted, ) -def add_topic_mute( +@transaction.atomic(savepoint=False) +def set_user_topic_visibility_policy_in_database( user_profile: UserProfile, stream_id: int, - recipient_id: int, topic_name: str, - date_muted: Optional[datetime.datetime] = None, + *, + visibility_policy: int, + recipient_id: int, + last_updated: Optional[datetime.datetime] = None, ignore_duplicate: bool = False, ) -> None: - if date_muted is None: - date_muted = timezone_now() - UserTopic.objects.bulk_create( - [ - UserTopic( - user_profile=user_profile, - stream_id=stream_id, - recipient_id=recipient_id, - topic_name=topic_name, - last_updated=date_muted, - visibility_policy=UserTopic.MUTED, - ), - ], - ignore_conflicts=ignore_duplicate, + assert last_updated is not None + (row, created) = UserTopic.objects.get_or_create( + user_profile=user_profile, + stream_id=stream_id, + topic_name__iexact=topic_name, + recipient_id=recipient_id, + defaults={ + "topic_name": topic_name, + "last_updated": last_updated, + "visibility_policy": visibility_policy, + }, ) + if created: + return + + duplicate_request: bool = row.visibility_policy == visibility_policy + + if duplicate_request and ignore_duplicate: + return + + if duplicate_request and not ignore_duplicate: + visibility_policy_string: Dict[int, str] = { + 1: "muted", + 2: "unmuted", + 3: "followed", + } + raise JsonableError( + _("Topic already {}").format(visibility_policy_string[visibility_policy]) + ) + # The request is to just 'update' the visibility policy of a topic + row.visibility_policy = visibility_policy + row.last_updated = last_updated + row.save(update_fields=["visibility_policy", "last_updated"]) + def remove_topic_mute(user_profile: UserProfile, stream_id: int, topic_name: str) -> None: row = UserTopic.objects.get( diff --git a/zerver/tests/test_events.py b/zerver/tests/test_events.py index 22e3b3356a..909e21a3d4 100644 --- a/zerver/tests/test_events.py +++ b/zerver/tests/test_events.py @@ -1456,6 +1456,17 @@ class NormalActionsTest(BaseAction): ) check_user_topic("events[0]", events[0]) + def test_unmuted_topics_events(self) -> None: + stream = get_stream("Denmark", self.user_profile.realm) + events = self.verify_action( + lambda: do_set_user_topic_visibility_policy( + self.user_profile, stream, "topic", visibility_policy=UserTopic.UNMUTED + ), + num_events=2, + ) + check_muted_topics("events[0]", events[0]) + check_user_topic("events[1]", events[1]) + def test_muted_users_events(self) -> None: muted_user = self.example_user("othello") events = self.verify_action( diff --git a/zerver/tests/test_message_edit.py b/zerver/tests/test_message_edit.py index 17dcf268fb..e3b606fd46 100644 --- a/zerver/tests/test_message_edit.py +++ b/zerver/tests/test_message_edit.py @@ -1338,7 +1338,7 @@ class EditMessageTest(EditMessageTestCase): # This code path adds 9 (1 + 4/user with muted topics) + 1 to # the number of database queries for moving a topic. - with self.assert_database_query_count(19): + with self.assert_database_query_count(21): check_update_message( user_profile=hamlet, message_id=message_id, @@ -1422,7 +1422,7 @@ class EditMessageTest(EditMessageTestCase): set_topic_mutes(desdemona, muted_topics) set_topic_mutes(cordelia, muted_topics) - with self.assert_database_query_count(30): + with self.assert_database_query_count(32): check_update_message( user_profile=desdemona, message_id=message_id, @@ -1453,7 +1453,7 @@ class EditMessageTest(EditMessageTestCase): set_topic_mutes(desdemona, muted_topics) set_topic_mutes(cordelia, muted_topics) - with self.assert_database_query_count(32): + with self.assert_database_query_count(33): check_update_message( user_profile=desdemona, message_id=message_id, @@ -1486,7 +1486,7 @@ class EditMessageTest(EditMessageTestCase): set_topic_mutes(desdemona, muted_topics) set_topic_mutes(cordelia, muted_topics) - with self.assert_database_query_count(30): + with self.assert_database_query_count(32): check_update_message( user_profile=desdemona, message_id=message_id, diff --git a/zerver/tests/test_user_topics.py b/zerver/tests/test_user_topics.py index 6f2f90165c..38e05e13a1 100644 --- a/zerver/tests/test_user_topics.py +++ b/zerver/tests/test_user_topics.py @@ -199,3 +199,47 @@ class MutedTopicsTests(ZulipTestCase): data = {"stream": stream.name, "stream_id": stream.id, "topic": "Verona3", "op": "remove"} result = self.api_patch(user, url, data) self.assert_json_error(result, "Please choose one: 'stream' or 'stream_id'.") + + +class UnmutedTopicsTests(ZulipTestCase): + def test_user_ids_unmuting_topic(self) -> None: + hamlet = self.example_user("hamlet") + cordelia = self.example_user("cordelia") + realm = hamlet.realm + stream = get_stream("Verona", realm) + topic_name = "teST topic" + date_unmuted = datetime(2020, 1, 1, tzinfo=timezone.utc) + + stream_topic_target = StreamTopicTarget( + stream_id=stream.id, + topic_name=topic_name, + ) + + user_ids = stream_topic_target.user_ids_with_visibility_policy(UserTopic.UNMUTED) + self.assertEqual(user_ids, set()) + + def set_topic_visibility_for_user(user: UserProfile, visibility_policy: int) -> None: + do_set_user_topic_visibility_policy( + user, + stream, + "test TOPIC", + visibility_policy=visibility_policy, + last_updated=date_unmuted, + ) + + set_topic_visibility_for_user(hamlet, UserTopic.UNMUTED) + set_topic_visibility_for_user(cordelia, UserTopic.MUTED) + user_ids = stream_topic_target.user_ids_with_visibility_policy(UserTopic.UNMUTED) + self.assertEqual(user_ids, {hamlet.id}) + hamlet_date_unmuted = UserTopic.objects.filter( + user_profile=hamlet, visibility_policy=UserTopic.UNMUTED + )[0].last_updated + self.assertEqual(hamlet_date_unmuted, date_unmuted) + + set_topic_visibility_for_user(cordelia, UserTopic.UNMUTED) + user_ids = stream_topic_target.user_ids_with_visibility_policy(UserTopic.UNMUTED) + self.assertEqual(user_ids, {hamlet.id, cordelia.id}) + cordelia_date_unmuted = UserTopic.objects.filter( + user_profile=cordelia, visibility_policy=UserTopic.UNMUTED + )[0].last_updated + self.assertEqual(cordelia_date_unmuted, date_unmuted) diff --git a/zerver/views/user_topics.py b/zerver/views/user_topics.py index 34b5f9fabe..b7079737cc 100644 --- a/zerver/views/user_topics.py +++ b/zerver/views/user_topics.py @@ -1,13 +1,11 @@ import datetime from typing import Optional -from django.db import IntegrityError from django.http import HttpRequest, HttpResponse from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ from zerver.actions.user_topics import do_set_user_topic_visibility_policy -from zerver.lib.exceptions import JsonableError from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_success from zerver.lib.streams import ( @@ -34,16 +32,13 @@ def mute_topic( assert stream_id is not None (stream, sub) = access_stream_by_id(user_profile, stream_id) - try: - do_set_user_topic_visibility_policy( - user_profile, - stream, - topic_name, - visibility_policy=UserTopic.MUTED, - last_updated=date_muted, - ) - except IntegrityError: - raise JsonableError(_("Topic already muted")) + do_set_user_topic_visibility_policy( + user_profile, + stream, + topic_name, + visibility_policy=UserTopic.MUTED, + last_updated=date_muted, + ) def unmute_topic(