diff --git a/zerver/tests/test_e2ee_push_notifications.py b/zerver/tests/test_e2ee_push_notifications.py index de53fc6bc7..a96652a024 100644 --- a/zerver/tests/test_e2ee_push_notifications.py +++ b/zerver/tests/test_e2ee_push_notifications.py @@ -28,6 +28,7 @@ from zerver.lib.timestamp import datetime_to_timestamp from zerver.models import PushDevice, UserMessage from zerver.models.realms import get_realm from zerver.models.scheduled_jobs import NotificationTriggers +from zilencer.lib.push_notifications import SentPushNotificationResult from zilencer.models import RemoteRealm, RemoteRealmCount @@ -260,7 +261,11 @@ class SendPushNotificationTest(E2EEPushNotificationTestCase): with ( self.mock_fcm() as mock_fcm_messaging, mock.patch( - "zilencer.lib.push_notifications.send_e2ee_push_notification_apple", return_value=1 + "zilencer.lib.push_notifications.send_e2ee_push_notification_apple", + return_value=SentPushNotificationResult( + successfully_sent_count=1, + delete_device_ids=[], + ), ), self.assertLogs("zerver.lib.push_notifications", level="INFO") as zerver_logger, self.assertLogs("zilencer.lib.push_notifications", level="WARNING") as zilencer_logger, diff --git a/zilencer/lib/push_notifications.py b/zilencer/lib/push_notifications.py index b88e4ffa61..c3adafbb0f 100644 --- a/zilencer/lib/push_notifications.py +++ b/zilencer/lib/push_notifications.py @@ -1,7 +1,7 @@ import asyncio import logging from collections.abc import Iterable -from dataclasses import asdict +from dataclasses import asdict, dataclass from aioapns import NotificationRequest from django.utils.timezone import now as timezone_now @@ -23,14 +23,20 @@ from zilencer.models import RemotePushDevice, RemoteRealm logger = logging.getLogger(__name__) +@dataclass +class SentPushNotificationResult: + successfully_sent_count: int + delete_device_ids: list[int] + + def send_e2ee_push_notification_apple( apns_requests: list[NotificationRequest], apns_remote_push_devices: list[RemotePushDevice], - delete_device_ids: list[int], -) -> int: +) -> SentPushNotificationResult: import aioapns successfully_sent_count = 0 + delete_device_ids: list[int] = [] apns_context = get_apns_context() if apns_context is None: @@ -38,7 +44,10 @@ def send_e2ee_push_notification_apple( "APNs: Dropping a notification because nothing configured. " "Set ZULIP_SERVICES_URL (or APNS_CERT_FILE)." ) - return successfully_sent_count + return SentPushNotificationResult( + successfully_sent_count=successfully_sent_count, + delete_device_ids=delete_device_ids, + ) async def send_all_notifications() -> Iterable[ tuple[RemotePushDevice, aioapns.common.NotificationResult | BaseException] @@ -66,21 +75,28 @@ def send_e2ee_push_notification_apple( remote_push_device.save(update_fields=["expired_time"]) delete_device_ids.append(result_info.delete_device_id) - return successfully_sent_count + return SentPushNotificationResult( + successfully_sent_count=successfully_sent_count, + delete_device_ids=delete_device_ids, + ) def send_e2ee_push_notification_android( fcm_requests: list[firebase_messaging.Message], fcm_remote_push_devices: list[RemotePushDevice], - delete_device_ids: list[int], -) -> int: +) -> SentPushNotificationResult: + successfully_sent_count = 0 + delete_device_ids: list[int] = [] + try: batch_response = firebase_messaging.send_each(fcm_requests, app=fcm_app) except firebase_exceptions.FirebaseError: logger.warning("Error while pushing to FCM", exc_info=True) - return 0 + return SentPushNotificationResult( + successfully_sent_count=successfully_sent_count, + delete_device_ids=delete_device_ids, + ) - successfully_sent_count = 0 for idx, response in enumerate(batch_response.responses): # We enumerate to have idx to track which token the response # corresponds to. send_each() preserves the order of the messages, @@ -114,7 +130,10 @@ def send_e2ee_push_notification_android( error, ) - return successfully_sent_count + return SentPushNotificationResult( + successfully_sent_count=successfully_sent_count, + delete_device_ids=delete_device_ids, + ) def send_e2ee_push_notifications( @@ -181,15 +200,21 @@ def send_e2ee_push_notifications( apple_successfully_sent_count = 0 if len(apns_requests) > 0: - apple_successfully_sent_count = send_e2ee_push_notification_apple( - apns_requests, apns_remote_push_devices, delete_device_ids + sent_push_notification_result = send_e2ee_push_notification_apple( + apns_requests, + apns_remote_push_devices, ) + apple_successfully_sent_count = sent_push_notification_result.successfully_sent_count + delete_device_ids.extend(sent_push_notification_result.delete_device_ids) android_successfully_sent_count = 0 if len(fcm_requests) > 0: - android_successfully_sent_count = send_e2ee_push_notification_android( - fcm_requests, fcm_remote_push_devices, delete_device_ids + sent_push_notification_result = send_e2ee_push_notification_android( + fcm_requests, + fcm_remote_push_devices, ) + android_successfully_sent_count = sent_push_notification_result.successfully_sent_count + delete_device_ids.extend(sent_push_notification_result.delete_device_ids) return { "apple_successfully_sent_count": apple_successfully_sent_count,