diff --git a/zerver/lib/actions.py b/zerver/lib/actions.py index 09e0eb0837..1cdf0cecfa 100644 --- a/zerver/lib/actions.py +++ b/zerver/lib/actions.py @@ -2141,6 +2141,10 @@ def do_add_submessage( msg_type: str, content: str, ) -> None: + """Should be called while holding a SELECT FOR UPDATE lock + (e.g. via access_message(..., lock_message=True)) on the + Message row, to prevent race conditions. + """ submessage = SubMessage( sender_id=sender_id, message_id=message_id, @@ -2160,7 +2164,7 @@ def do_add_submessage( ums = UserMessage.objects.filter(message_id=message_id) target_user_ids = [um.user_profile_id for um in ums] - send_event(realm, event, target_user_ids) + transaction.on_commit(lambda: send_event(realm, event, target_user_ids)) def notify_reaction_update( diff --git a/zerver/tests/test_submessage.py b/zerver/tests/test_submessage.py index cf1dfe46d7..d71c6ad6b5 100644 --- a/zerver/tests/test_submessage.py +++ b/zerver/tests/test_submessage.py @@ -1,5 +1,7 @@ from typing import Any, Dict, List, Mapping +from unittest import mock +from zerver.lib.actions import do_add_submessage from zerver.lib.message import MessageDict from zerver.lib.test_classes import ZulipTestCase from zerver.models import Message, SubMessage @@ -147,3 +149,19 @@ class TestBasics(ZulipTestCase): sender_id=cordelia.id, ) self.assertEqual(row, expected_data) + + def test_submessage_event_sent_after_transaction_commits(self) -> None: + """ + Tests that `send_event` is hooked to `transaction.on_commit`. This is important, because + we don't want to end up holding locks on message rows for too long if the event queue runs + into a problem. + """ + hamlet = self.example_user("hamlet") + message_id = self.send_stream_message(hamlet, "Scotland") + + with self.tornado_redirected_to_list([], expected_num_events=1): + with mock.patch("zerver.lib.actions.send_event") as m: + m.side_effect = AssertionError( + "Events should be sent only after the transaction commits." + ) + do_add_submessage(hamlet.realm, hamlet.id, message_id, "whatever", "whatever") diff --git a/zerver/views/submessage.py b/zerver/views/submessage.py index 8e5e7cab3c..5a1d51ca75 100644 --- a/zerver/views/submessage.py +++ b/zerver/views/submessage.py @@ -1,4 +1,5 @@ import orjson +from django.db import transaction from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _ @@ -10,6 +11,8 @@ from zerver.lib.validator import check_int from zerver.models import UserProfile +# transaction.atomic is required since we use FOR UPDATE queries in access_message. +@transaction.atomic @has_request_variables def process_submessage( request: HttpRequest, @@ -18,7 +21,7 @@ def process_submessage( msg_type: str = REQ(), content: str = REQ(), ) -> HttpResponse: - message, user_message = access_message(user_profile, message_id) + message, user_message = access_message(user_profile, message_id, lock_message=True) try: orjson.loads(content)