message_flags: Remove inappropriate use of zerver.lib.timeout.

zerver.lib.timeout abuses asynchronous exceptions, so it’s only safe
to use on CPU computations with no side effects.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
(cherry picked from commit 95a1481f99)
This commit is contained in:
Anders Kaseorg
2024-04-18 10:25:46 -07:00
committed by Alex Vandiver
parent 82e3b33f1e
commit 7384f61556
3 changed files with 17 additions and 16 deletions

View File

@@ -1,3 +1,4 @@
import time
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Set
@@ -32,7 +33,11 @@ class ReadMessagesEvent:
flag: str = field(default="read", init=False)
def do_mark_all_as_read(user_profile: UserProfile) -> int:
def do_mark_all_as_read(
user_profile: UserProfile, *, timeout: Optional[float] = None
) -> Optional[int]:
start_time = time.monotonic()
# First, we clear mobile push notifications. This is safer in the
# event that the below logic times out and we're killed.
all_push_message_ids = (
@@ -49,6 +54,9 @@ def do_mark_all_as_read(user_profile: UserProfile) -> int:
batch_size = 2000
count = 0
while True:
if timeout is not None and time.monotonic() >= start_time + timeout:
return None
with transaction.atomic(savepoint=False):
query = (
UserMessage.select_for_update_query()

View File

@@ -24,8 +24,7 @@ from zerver.lib.message import (
)
from zerver.lib.message_cache import MessageDict
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import get_subscription, timeout_mock
from zerver.lib.timeout import TimeoutExpiredError
from zerver.lib.test_helpers import get_subscription
from zerver.models import (
Message,
Recipient,
@@ -66,8 +65,7 @@ class FirstUnreadAnchorTests(ZulipTestCase):
self.login("hamlet")
# Mark all existing messages as read
with timeout_mock("zerver.views.message_flags"):
result = self.client_post("/json/mark_all_as_read")
result = self.client_post("/json/mark_all_as_read")
result_dict = self.assert_json_success(result)
self.assertTrue(result_dict["complete"])
@@ -127,8 +125,7 @@ class FirstUnreadAnchorTests(ZulipTestCase):
def test_visible_messages_use_first_unread_anchor(self) -> None:
self.login("hamlet")
with timeout_mock("zerver.views.message_flags"):
result = self.client_post("/json/mark_all_as_read")
result = self.client_post("/json/mark_all_as_read")
result_dict = self.assert_json_success(result)
self.assertTrue(result_dict["complete"])
@@ -659,8 +656,7 @@ class PushNotificationMarkReadFlowsTest(ZulipTestCase):
[third_message_id, fourth_message_id],
)
with timeout_mock("zerver.views.message_flags"):
result = self.client_post("/json/mark_all_as_read", {})
result = self.client_post("/json/mark_all_as_read", {})
self.assertEqual(self.get_mobile_push_notification_ids(user_profile), [])
mock_push_notifications.assert_called()
@@ -682,8 +678,7 @@ class MarkAllAsReadEndpointTest(ZulipTestCase):
.count()
)
self.assertNotEqual(unread_count, 0)
with timeout_mock("zerver.views.message_flags"):
result = self.client_post("/json/mark_all_as_read", {})
result = self.client_post("/json/mark_all_as_read", {})
result_dict = self.assert_json_success(result)
self.assertTrue(result_dict["complete"])
@@ -696,7 +691,7 @@ class MarkAllAsReadEndpointTest(ZulipTestCase):
def test_mark_all_as_read_timeout_response(self) -> None:
self.login("hamlet")
with mock.patch("zerver.views.message_flags.timeout", side_effect=TimeoutExpiredError):
with mock.patch("time.monotonic", side_effect=[10000, 10051]):
result = self.client_post("/json/mark_all_as_read", {})
result_dict = self.assert_json_success(result)
self.assertFalse(result_dict["complete"])

View File

@@ -18,7 +18,6 @@ from zerver.lib.narrow import (
from zerver.lib.request import REQ, RequestNotes, has_request_variables
from zerver.lib.response import json_success
from zerver.lib.streams import access_stream_by_id
from zerver.lib.timeout import TimeoutExpiredError, timeout
from zerver.lib.topic import user_message_exists_for_topic
from zerver.lib.validator import check_bool, check_int, check_list, to_non_negative_int
from zerver.models import UserActivity, UserProfile
@@ -120,9 +119,8 @@ def update_message_flags_for_narrow(
@has_request_variables
def mark_all_as_read(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
request_notes = RequestNotes.get_notes(request)
try:
count = timeout(50, lambda: do_mark_all_as_read(user_profile))
except TimeoutExpiredError:
count = do_mark_all_as_read(user_profile, timeout=50)
if count is None:
return json_success(request, data={"complete": False})
log_data_str = f"[{count} updated]"