diff --git a/zerver/tests/test_queue_worker.py b/zerver/tests/test_queue_worker.py index 3358aaad65..05295e444a 100644 --- a/zerver/tests/test_queue_worker.py +++ b/zerver/tests/test_queue_worker.py @@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional from unittest.mock import MagicMock, patch import orjson +import time_machine from django.conf import settings from django.db.utils import IntegrityError from django.test import override_settings @@ -159,10 +160,6 @@ class WorkerTest(ZulipTestCase): events = [hamlet_event1, hamlet_event2, othello_event] - fake_client = FakeClient() - for event in events: - fake_client.enqueue("missedmessage_emails", event) - mmw = MissedMessageWorker() batch_duration = datetime.timedelta( seconds=hamlet.email_notifications_batching_period_seconds @@ -172,21 +169,6 @@ class WorkerTest(ZulipTestCase): == othello.email_notifications_batching_period_seconds ) - class MockTimer: - is_running = False - - def is_alive(self) -> bool: - return self.is_running - - def start(self) -> None: - self.is_running = True - - timer = MockTimer() - timer_mock = patch( - "zerver.worker.queue_processors.Timer", - return_value=timer, - ) - send_mock = patch( "zerver.lib.email_notifications.do_send_missedmessage_events_reply_in_zulip", ) @@ -206,106 +188,110 @@ class WorkerTest(ZulipTestCase): self.assertEqual(row.scheduled_timestamp, scheduled_timestamp) self.assertEqual(row.mentioned_user_group_id, mentioned_user_group_id) - with send_mock as sm, timer_mock as tm: - with simulated_queue_client(fake_client): - self.assertFalse(timer.is_alive()) + def advance() -> Optional[float]: + mmw.stopping = False - time_zero = datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc) - expected_scheduled_timestamp = time_zero + batch_duration - with patch("zerver.worker.queue_processors.timezone_now", return_value=time_zero): - mmw.setup() - mmw.start() + def inner(check: Callable[[], bool], timeout: Optional[float]) -> bool: + # The check should never pass, since we've just (with + # the lock) ascertained above the cv.wait that its + # conditions are not met. + self.assertFalse(check()) - # The events should be saved in the database - hamlet_row1 = ScheduledMessageNotificationEmail.objects.get( - user_profile_id=hamlet.id, message_id=hamlet1_msg_id - ) - check_row(hamlet_row1, expected_scheduled_timestamp, None) + # Set ourself to stop at the top of the next loop, but + # pretend we didn't get an event + mmw.stopping = True + return False - hamlet_row2 = ScheduledMessageNotificationEmail.objects.get( - user_profile_id=hamlet.id, message_id=hamlet2_msg_id - ) - check_row(hamlet_row2, expected_scheduled_timestamp, 4) + with patch.object(mmw.cv, "wait_for", side_effect=inner): + mmw.work() + return mmw.has_timeout - othello_row1 = ScheduledMessageNotificationEmail.objects.get( - user_profile_id=othello.id, message_id=othello_msg_id - ) - check_row(othello_row1, expected_scheduled_timestamp, None) + # With nothing enqueued, the condition variable is pending + # forever. We double-check that the condition is false in + # steady-state. + has_timeout = advance() + self.assertFalse(has_timeout) - # Additionally, the timer should have be started - self.assertTrue(timer.is_alive()) + # Enqueues the events to the internal queue, as if from RabbitMQ + time_zero = datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc) + with time_machine.travel(time_zero, tick=False), patch.object( + mmw.cv, "notify" + ) as notify_mock: + for event in events: + mmw.consume_single_event(event) + # All of these notify, because has_timeout is still false in + # each case. This represents multiple consume() calls getting + # the lock before the worker escapes the wait_for, and is + # unlikely in real life but does not lead to incorrect + # behaviour. + self.assertEqual(notify_mock.call_count, 3) - # If another event is received, test that it gets saved with the same - # `expected_scheduled_timestamp` as the earlier events. - fake_client.enqueue("missedmessage_emails", bonus_event_hamlet) - self.assertTrue(timer.is_alive()) - few_moments_later = time_zero + datetime.timedelta(seconds=3) - with patch( - "zerver.worker.queue_processors.timezone_now", return_value=few_moments_later - ): - # Double-calling start is our way to get it to run again - mmw.start() - hamlet_row3 = ScheduledMessageNotificationEmail.objects.get( - user_profile_id=hamlet.id, message_id=hamlet3_msg_id - ) - check_row(hamlet_row3, expected_scheduled_timestamp, None) + # This leaves a timeout set, since there are objects pending + with time_machine.travel(time_zero, tick=False): + has_timeout = advance() + self.assertTrue(has_timeout) - # Now let us test `maybe_send_batched_emails` - # If called too early, it shouldn't process the emails. - one_minute_premature = expected_scheduled_timestamp - datetime.timedelta(seconds=60) - with patch( - "zerver.worker.queue_processors.timezone_now", return_value=one_minute_premature - ): - mmw.maybe_send_batched_emails() - self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 4) + expected_scheduled_timestamp = time_zero + batch_duration - # If called after `expected_scheduled_timestamp`, it should process all emails. - one_minute_overdue = expected_scheduled_timestamp + datetime.timedelta(seconds=60) - with self.assertLogs(level="INFO") as info_logs, patch( - "zerver.worker.queue_processors.timezone_now", return_value=one_minute_overdue - ): - mmw.maybe_send_batched_emails() - self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 0) + # The events should be saved in the database + hamlet_row1 = ScheduledMessageNotificationEmail.objects.get( + user_profile_id=hamlet.id, message_id=hamlet1_msg_id + ) + check_row(hamlet_row1, expected_scheduled_timestamp, None) - self.assert_length(info_logs.output, 2) - self.assertIn( - f"INFO:root:Batch-processing 3 missedmessage_emails events for user {hamlet.id}", - info_logs.output, - ) - self.assertIn( - f"INFO:root:Batch-processing 1 missedmessage_emails events for user {othello.id}", - info_logs.output, - ) + hamlet_row2 = ScheduledMessageNotificationEmail.objects.get( + user_profile_id=hamlet.id, message_id=hamlet2_msg_id + ) + check_row(hamlet_row2, expected_scheduled_timestamp, 4) - # All batches got processed. Verify that the timer isn't running. - self.assertEqual(mmw.timer_event, None) + othello_row1 = ScheduledMessageNotificationEmail.objects.get( + user_profile_id=othello.id, message_id=othello_msg_id + ) + check_row(othello_row1, expected_scheduled_timestamp, None) - # Hacky test coming up! We want to test the try-except block in the consumer which handles - # IntegrityErrors raised when the message was deleted before it processed the notification - # event. - # However, Postgres defers checking ForeignKey constraints to when the current transaction - # commits. This poses some difficulties in testing because of Django running tests inside a - # transaction which never commits. See https://code.djangoproject.com/ticket/22431 for more - # details, but the summary is that IntegrityErrors due to database constraints are raised at - # the end of the test, not inside the `try` block. So, we have the code inside the `try` block - # raise `IntegrityError` by mocking. - def raise_error(**kwargs: Any) -> None: - raise IntegrityError + # If another event is received, test that it gets saved with the same + # `expected_scheduled_timestamp` as the earlier events. - fake_client.enqueue("missedmessage_emails", hamlet_event1) + few_moments_later = time_zero + datetime.timedelta(seconds=3) + with time_machine.travel(few_moments_later, tick=False), patch.object( + mmw.cv, "notify" + ) as notify_mock: + mmw.consume_single_event(bonus_event_hamlet) + self.assertEqual(notify_mock.call_count, 0) - with patch( - "zerver.models.ScheduledMessageNotificationEmail.objects.create", - side_effect=raise_error, - ), self.assertLogs(level="DEBUG") as debug_logs: - mmw.start() - self.assertIn( - "DEBUG:root:ScheduledMessageNotificationEmail row could not be created. The message may have been deleted. Skipping event.", - debug_logs.output, - ) + with time_machine.travel(few_moments_later, tick=False): + has_timeout = advance() + self.assertTrue(has_timeout) + hamlet_row3 = ScheduledMessageNotificationEmail.objects.get( + user_profile_id=hamlet.id, message_id=hamlet3_msg_id + ) + check_row(hamlet_row3, expected_scheduled_timestamp, None) - # Check that the frequency of calling maybe_send_batched_emails is correct (5 seconds) - self.assertEqual(tm.call_args[0][0], 5) + # Now let us test `maybe_send_batched_emails` + # If called too early, it shouldn't process the emails. + one_minute_premature = expected_scheduled_timestamp - datetime.timedelta(seconds=60) + with time_machine.travel(one_minute_premature, tick=False): + has_timeout = advance() + self.assertTrue(has_timeout) + self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 4) + + # If called after `expected_scheduled_timestamp`, it should process all emails. + one_minute_overdue = expected_scheduled_timestamp + datetime.timedelta(seconds=60) + with time_machine.travel(one_minute_overdue, tick=True): + with send_mock as sm, self.assertLogs(level="INFO") as info_logs: + has_timeout = advance() + self.assertTrue(has_timeout) + self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 0) + has_timeout = advance() + self.assertFalse(has_timeout) + + self.assertEqual( + [ + f"INFO:root:Batch-processing 3 missedmessage_emails events for user {hamlet.id}", + f"INFO:root:Batch-processing 1 missedmessage_emails events for user {othello.id}", + ], + info_logs.output, + ) # Verify the payloads now args = [c[0] for c in sm.call_args_list] @@ -331,32 +317,62 @@ class WorkerTest(ZulipTestCase): {"where art thou, othello?"}, ) - with send_mock as sm, timer_mock as tm: - with simulated_queue_client(fake_client): - time_zero = datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc) - # Verify that we make forward progress if one of the messages throws an exception - fake_client.enqueue("missedmessage_emails", hamlet_event1) - fake_client.enqueue("missedmessage_emails", hamlet_event2) - fake_client.enqueue("missedmessage_emails", othello_event) - with patch("zerver.worker.queue_processors.timezone_now", return_value=time_zero): - mmw.setup() - mmw.start() + # Hacky test coming up! We want to test the try-except block in the consumer which handles + # IntegrityErrors raised when the message was deleted before it processed the notification + # event. + # However, Postgres defers checking ForeignKey constraints to when the current transaction + # commits. This poses some difficulties in testing because of Django running tests inside a + # transaction which never commits. See https://code.djangoproject.com/ticket/22431 for more + # details, but the summary is that IntegrityErrors due to database constraints are raised at + # the end of the test, not inside the `try` block. So, we have the code inside the `try` block + # raise `IntegrityError` by mocking. + with patch( + "zerver.models.ScheduledMessageNotificationEmail.objects.create", + side_effect=IntegrityError, + ), self.assertLogs(level="DEBUG") as debug_logs, patch.object( + mmw.cv, "notify" + ) as notify_mock: + mmw.consume_single_event(hamlet_event1) + self.assertEqual(notify_mock.call_count, 0) + self.assertIn( + "DEBUG:root:ScheduledMessageNotificationEmail row could not be created. The message may have been deleted. Skipping event.", + debug_logs.output, + ) - def fail_some(user: UserProfile, *args: Any) -> None: - if user.id == hamlet.id: - raise RuntimeError + # Verify that we make forward progress if one of the messages + # throws an exception. First, enqueue the messages, and get + # them to create database rows: + time_zero = datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc) + with time_machine.travel(time_zero, tick=False), patch.object( + mmw.cv, "notify" + ) as notify_mock: + mmw.consume_single_event(hamlet_event1) + mmw.consume_single_event(hamlet_event2) + mmw.consume_single_event(othello_event) + # See above note about multiple notifies + self.assertEqual(notify_mock.call_count, 3) + has_timeout = advance() + self.assertTrue(has_timeout) - sm.side_effect = fail_some - one_minute_overdue = expected_scheduled_timestamp + datetime.timedelta(seconds=60) - with patch( - "zerver.worker.queue_processors.timezone_now", return_value=one_minute_overdue - ), self.assertLogs(level="ERROR") as error_logs: - mmw.maybe_send_batched_emails() - self.assertIn( - "ERROR:root:Failed to process 2 missedmessage_emails for user 10", - error_logs.output[0], - ) - self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 0) + # Next, set up a fail-y consumer: + def fail_some(user: UserProfile, *args: Any) -> None: + if user.id == hamlet.id: + raise RuntimeError + + one_minute_overdue = expected_scheduled_timestamp + datetime.timedelta(seconds=60) + with time_machine.travel(one_minute_overdue, tick=False), self.assertLogs( + level="ERROR" + ) as error_logs, send_mock as sm: + sm.side_effect = fail_some + has_timeout = advance() + self.assertTrue(has_timeout) + self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 0) + has_timeout = advance() + self.assertFalse(has_timeout) + self.assertIn( + "ERROR:root:Failed to process 2 missedmessage_emails for user 10", + error_logs.output[0], + ) def test_push_notifications_worker(self) -> None: """ diff --git a/zerver/worker/queue_processors.py b/zerver/worker/queue_processors.py index 47e8f85c5f..b04eafaf56 100644 --- a/zerver/worker/queue_processors.py +++ b/zerver/worker/queue_processors.py @@ -10,13 +10,13 @@ import os import signal import socket import tempfile +import threading import time import urllib from abc import ABC, abstractmethod from collections import deque from email.message import EmailMessage from functools import wraps -from threading import RLock, Timer from types import FrameType from typing import ( Any, @@ -566,48 +566,72 @@ class UserPresenceWorker(QueueProcessingWorker): @assign_queue("missedmessage_emails") class MissedMessageWorker(QueueProcessingWorker): - # Aggregate all messages received over the last BATCH_DURATION - # seconds to let someone finish sending a batch of messages and/or - # editing them before they are sent out as emails to recipients. + # Aggregate all messages received over the last several seconds + # (configurable by each recipient) to let someone finish sending a + # batch of messages and/or editing them before they are sent out + # as emails to recipients. # - # The timer is running whenever; we poll at most every TIMER_FREQUENCY - # seconds, to avoid excessive activity. - TIMER_FREQUENCY = 5 - timer_event: Optional[Timer] = None + # The batch interval is best-effort -- we poll at most every + # CHECK_FREQUENCY_SECONDS, to avoid excessive activity. + CHECK_FREQUENCY_SECONDS = 5 - # This lock protects access to all of the data structures declared - # above. A lock is required because maybe_send_batched_emails, as - # the argument to Timer, runs in a separate thread from the rest - # of the consumer. This is a _re-entrant_ lock because we may - # need to take the lock when we already have it during shutdown - # (see the stop method). - lock = RLock() + worker_thread: Optional[threading.Thread] = None - # Because the background `maybe_send_batched_email` thread can - # hold the lock for an indeterminate amount of time, the `consume` - # can block on that for longer than 30s, the default worker - # timeout. Allow arbitrarily-long worker `consume` calls. - MAX_CONSUME_SECONDS = None + # This condition variable mediates the stopping and has_timeout + # pieces of state, below it. + cv = threading.Condition() + stopping = False + has_timeout = False + # The main thread, which handles the RabbitMQ connection and creates + # database rows from them. def consume(self, event: Dict[str, Any]) -> None: - with self.lock: - logging.debug("Received missedmessage_emails event: %s", event) + logging.debug("Processing missedmessage_emails event: %s", event) + # When we consume an event, check if there are existing pending emails + # for that user, and if so use the same scheduled timestamp. + user_profile_id: int = event["user_profile_id"] + user_profile = get_user_profile_by_id(user_profile_id) + batch_duration_seconds = user_profile.email_notifications_batching_period_seconds + batch_duration = datetime.timedelta(seconds=batch_duration_seconds) - # When we consume an event, check if there are existing pending emails - # for that user, and if so use the same scheduled timestamp. - user_profile_id: int = event["user_profile_id"] - user_profile = get_user_profile_by_id(user_profile_id) - batch_duration_seconds = user_profile.email_notifications_batching_period_seconds - batch_duration = datetime.timedelta(seconds=batch_duration_seconds) - - try: - pending_email = ScheduledMessageNotificationEmail.objects.filter( - user_profile_id=user_profile_id - )[0] - scheduled_timestamp = pending_email.scheduled_timestamp - except IndexError: - scheduled_timestamp = timezone_now() + batch_duration + try: + pending_email = ScheduledMessageNotificationEmail.objects.filter( + user_profile_id=user_profile_id + )[0] + scheduled_timestamp = pending_email.scheduled_timestamp + except IndexError: + scheduled_timestamp = timezone_now() + batch_duration + with self.cv: + # We now hold the lock, so there are three places the + # worker thread can be: + # + # 1. In maybe_send_batched_emails, and will have to take + # the lock (and thus block insertions of new rows + # here) to decide if there are any rows and if it thus + # needs a timeout. + # + # 2. In the cv.wait_for with a timeout because there were + # rows already. There's nothing for us to do, since + # the newly-inserted row will get checked upon that + # timeout. + # + # 3. In the cv.wait_for without a timeout, because there + # weren't any rows (which we're about to change). + # + # Notifying in (1) is irrelevant, since the thread is not + # waiting. If we over-notify by doing so for both (2) and + # (3), the behaviour is correct but slightly inefficient, + # as the thread will be needlessly awoken and will just + # re-wait. However, if we fail to awake case (3), the + # worker thread will never wake up, and the + # ScheduledMessageNotificationEmail internal queue will + # back up. + # + # Use the self.has_timeout property (which is protected by + # the lock) to determine which of cases (2) or (3) we are + # in, and as such if we need to notify after making the + # row. try: ScheduledMessageNotificationEmail.objects.create( user_profile_id=user_profile_id, @@ -616,108 +640,126 @@ class MissedMessageWorker(QueueProcessingWorker): scheduled_timestamp=scheduled_timestamp, mentioned_user_group_id=event.get("mentioned_user_group_id"), ) - - self.ensure_timer() + if not self.has_timeout: + self.cv.notify() except IntegrityError: logging.debug( "ScheduledMessageNotificationEmail row could not be created. The message may have been deleted. Skipping event." ) - def ensure_timer(self) -> None: - # The caller is responsible for ensuring self.lock is held when it calls this. - if self.timer_event is not None: - return + def start(self) -> None: + with self.cv: + self.stopping = False + self.worker_thread = threading.Thread(target=lambda: self.work()) + self.worker_thread.start() + super().start() - self.timer_event = Timer( - self.TIMER_FREQUENCY, MissedMessageWorker.maybe_send_batched_emails, [self] - ) - self.timer_event.start() + def work(self) -> None: + while True: + with self.cv: + if self.stopping: + return + # There are three conditions which we wait for: + # + # 1. We are being explicitly asked to stop; see the + # notify() call in stop() + # + # 2. We have no ScheduledMessageNotificationEmail + # objects currently (has_timeout = False) and the + # first one was just enqueued; see the notify() + # call in consume(). We break out so that we can + # come back around the loop and re-wait with a + # timeout (see next condition). + # + # 3. One or more ScheduledMessageNotificationEmail + # exist in the database, so we need to re-check + # them regularly; this happens by hitting the + # timeout and calling maybe_send_batched_emails(). + # There is no explicit notify() for this. + timeout: Optional[int] = None + if ScheduledMessageNotificationEmail.objects.exists(): + timeout = self.CHECK_FREQUENCY_SECONDS + self.has_timeout = timeout is not None + + def wait_condition() -> bool: + if self.stopping: + # Condition (1) + return True + if timeout is None: + # Condition (2). We went to sleep with no + # ScheduledMessageNotificationEmail existing, + # and one has just been made. We re-check + # that is still true now that we have the + # lock, and if we see it, we stop waiting. + return ScheduledMessageNotificationEmail.objects.exists() + # This should only happen at the start or end of + # the wait, when we haven't been notified, but are + # re-checking the condition. + return False + + was_notified = self.cv.wait_for(wait_condition, timeout=timeout) + + # Being notified means that we are in conditions (1) or + # (2), above. In neither case do we need to look at if + # there are batches to send -- (2) means that the + # ScheduledMessageNotificationEmail was _just_ created, so + # there is no need to check it now. + if not was_notified: + self.maybe_send_batched_emails() def maybe_send_batched_emails(self) -> None: - with self.lock: - # self.timer_event just triggered execution of this - # function in a thread, so now that we hold the lock, we - # clear the timer_event attribute to record that no Timer - # is active. If it is already None, stop() is shutting us - # down. - if self.timer_event is None: - return - self.timer_event = None + current_time = timezone_now() - current_time = timezone_now() + with transaction.atomic(): + events_to_process = ScheduledMessageNotificationEmail.objects.filter( + scheduled_timestamp__lte=current_time + ).select_related() - with transaction.atomic(): - events_to_process = ScheduledMessageNotificationEmail.objects.filter( - scheduled_timestamp__lte=current_time - ).select_related() + # Batch the entries by user + events_by_recipient: Dict[int, List[Dict[str, Any]]] = {} + for event in events_to_process: + entry = dict( + user_profile_id=event.user_profile_id, + message_id=event.message_id, + trigger=event.trigger, + mentioned_user_group_id=event.mentioned_user_group_id, + ) + if event.user_profile_id in events_by_recipient: + events_by_recipient[event.user_profile_id].append(entry) + else: + events_by_recipient[event.user_profile_id] = [entry] - # Batch the entries by user - events_by_recipient: Dict[int, List[Dict[str, Any]]] = {} - for event in events_to_process: - entry = dict( - user_profile_id=event.user_profile_id, - message_id=event.message_id, - trigger=event.trigger, - mentioned_user_group_id=event.mentioned_user_group_id, - ) - if event.user_profile_id in events_by_recipient: - events_by_recipient[event.user_profile_id].append(entry) - else: - events_by_recipient[event.user_profile_id] = [entry] + for user_profile_id in events_by_recipient: + events: List[Dict[str, Any]] = events_by_recipient[user_profile_id] - for user_profile_id in events_by_recipient: - events: List[Dict[str, Any]] = events_by_recipient[user_profile_id] - - logging.info( - "Batch-processing %s missedmessage_emails events for user %s", + logging.info( + "Batch-processing %s missedmessage_emails events for user %s", + len(events), + user_profile_id, + ) + try: + # Because we process events in batches, an + # escaped exception here would lead to + # duplicate messages being sent for other + # users in the same events_to_process batch, + # and no guarantee of forward progress. + handle_missedmessage_emails(user_profile_id, events) + except Exception: + logging.exception( + "Failed to process %d missedmessage_emails for user %s", len(events), user_profile_id, + stack_info=True, ) - try: - # Because we process events in batches, an - # escaped exception here would lead to - # duplicate messages being sent for other - # users in the same events_to_process batch, - # and no guarantee of forward progress. - handle_missedmessage_emails(user_profile_id, events) - except Exception: - logging.exception( - "Failed to process %d missedmessage_emails for user %s", - len(events), - user_profile_id, - stack_info=True, - ) - events_to_process.delete() - - # By only restarting the timer if there are actually events in - # the queue, we ensure this queue processor is idle when there - # are no missed-message emails to process. This avoids - # constant CPU usage when there is no work to do. - if ScheduledMessageNotificationEmail.objects.exists(): - self.ensure_timer() + events_to_process.delete() def stop(self) -> None: - # This may be called from a signal handler when we _already_ - # have the lock. Python doesn't give us a way to check if our - # thread has the lock, so we instead use a re-entrant lock to - # always take it. - with self.lock: - # With the lock,we can safely inspect the timer_event and - # cancel it if it is still pending. - if self.timer_event is not None: - # We cancel and then join the timer with a timeout to - # prevent deadlock, where we took the lock, the timer - # then ran out and started maybe_send_batched_emails, - # and then it started waiting for the lock. The timer - # isn't running anymore so can't be canceled, and the - # thread is blocked on the lock, so will never join(). - self.timer_event.cancel() - self.timer_event.join(timeout=1) - # In case we did hit this deadlock, we signal to - # maybe_send_batched_emails that it should abort by, - # before releasing the lock, unsetting the timer. - self.timer_event = None + with self.cv: + self.stopping = True + self.cv.notify() + if self.worker_thread is not None: + self.worker_thread.join() super().stop()