diff --git a/zerver/lib/push_notifications.py b/zerver/lib/push_notifications.py index 97bdc10aa3..1547e2b4d5 100644 --- a/zerver/lib/push_notifications.py +++ b/zerver/lib/push_notifications.py @@ -5,7 +5,7 @@ import copy import logging import re from collections.abc import Iterable, Mapping, Sequence -from dataclasses import dataclass +from dataclasses import asdict, dataclass, field from email.headerregistry import Address from functools import cache from typing import TYPE_CHECKING, Any, Final, Literal, Optional, TypeAlias, Union @@ -1366,6 +1366,51 @@ class SendNotificationResponseData(TypedDict): realm_push_status: NotRequired[RealmPushStatusDict] +FCMPriority: TypeAlias = Literal["high", "normal"] +APNsPriority: TypeAlias = Literal[10, 5, 1] + + +@dataclass +class PushRequestBasePayload: + push_account_id: int + encrypted_data: str + + +@dataclass +class FCMPushRequest: + device_id: int + fcm_priority: FCMPriority + payload: PushRequestBasePayload + + +@dataclass +class APNsHTTPHeaders: + apns_priority: APNsPriority + apns_push_type: PushType + + +@dataclass +class APNsPayload(PushRequestBasePayload): + aps: dict[str, int | dict[str, str]] = field( + default_factory=lambda: {"mutable-content": 1, "alert": {"title": "New notification"}} + ) + + +@dataclass +class APNsPushRequest: + device_id: int + http_headers: APNsHTTPHeaders + payload: APNsPayload + + +def get_encrypted_data(payload_data_to_encrypt: dict[str, Any], public_key_str: str) -> str: + public_key = PublicKey(public_key_str.encode("utf-8"), Base64Encoder) + sealed_box = SealedBox(public_key) + encrypted_data_bytes = sealed_box.encrypt(orjson.dumps(payload_data_to_encrypt), Base64Encoder) + encrypted_data = encrypted_data_bytes.decode("utf-8") + return encrypted_data + + def send_push_notifications( user_profile: UserProfile, apns_payload_data_to_encrypt: dict[str, Any], @@ -1382,51 +1427,64 @@ def send_push_notifications( ) return - # Prepare payload with encrypted data to send. - device_id_to_encrypted_data: dict[str, str] = {} - for push_device in push_devices: - public_key_str: str = push_device.push_public_key - public_key = PublicKey(public_key_str.encode("utf-8"), Base64Encoder) - sealed_box = SealedBox(public_key) - - if push_device.token_kind == PushDevice.TokenKind.APNS: - encrypted_data_bytes = sealed_box.encrypt( - orjson.dumps(apns_payload_data_to_encrypt), Base64Encoder - ) - else: - encrypted_data_bytes = sealed_box.encrypt( - orjson.dumps(fcm_payload_data_to_encrypt), Base64Encoder - ) - - encrypted_data = encrypted_data_bytes.decode("utf-8") - assert push_device.bouncer_device_id is not None # for mypy - device_id_to_encrypted_data[str(push_device.bouncer_device_id)] = encrypted_data - # Note: The "Final" qualifier serves as a shorthand # for declaring that a variable is effectively Literal. fcm_priority: Final = "normal" if is_removal else "high" apns_priority: Final = 5 if is_removal else 10 apns_push_type = PushType.BACKGROUND if is_removal else PushType.ALERT + # Prepare payload to send. + push_requests: list[FCMPushRequest | APNsPushRequest] = [] + for push_device in push_devices: + assert push_device.bouncer_device_id is not None # for mypy + if push_device.token_kind == PushDevice.TokenKind.APNS: + apns_http_headers = APNsHTTPHeaders( + apns_priority=apns_priority, + apns_push_type=apns_push_type, + ) + encrypted_data = get_encrypted_data( + apns_payload_data_to_encrypt, + push_device.push_public_key, + ) + apns_payload = APNsPayload( + push_account_id=push_device.push_account_id, + encrypted_data=encrypted_data, + ) + apns_push_request = APNsPushRequest( + device_id=push_device.bouncer_device_id, + http_headers=apns_http_headers, + payload=apns_payload, + ) + push_requests.append(apns_push_request) + else: + encrypted_data = get_encrypted_data( + fcm_payload_data_to_encrypt, + push_device.push_public_key, + ) + fcm_payload = PushRequestBasePayload( + push_account_id=push_device.push_account_id, + encrypted_data=encrypted_data, + ) + fcm_push_request = FCMPushRequest( + device_id=push_device.bouncer_device_id, + fcm_priority=fcm_priority, + payload=fcm_payload, + ) + push_requests.append(fcm_push_request) + # Send push notification try: if settings.ZILENCER_ENABLED: from zilencer.lib.push_notifications import send_e2ee_push_notifications response_data: SendNotificationResponseData = send_e2ee_push_notifications( - device_id_to_encrypted_data, - fcm_priority=fcm_priority, - apns_priority=apns_priority, - apns_push_type=apns_push_type, + push_requests, realm=user_profile.realm, ) else: post_data = { "realm_uuid": str(user_profile.realm.uuid), - "device_id_to_encrypted_data": device_id_to_encrypted_data, - "fcm_priority": fcm_priority, - "apns_priority": apns_priority, - "apns_push_type": apns_push_type, + "push_requests": [asdict(push_request) for push_request in push_requests], } result = send_json_to_push_bouncer("POST", "push/e2ee/notify", post_data) assert isinstance(result["android_successfully_sent_count"], int) # for mypy diff --git a/zerver/tests/test_e2ee_push_notifications.py b/zerver/tests/test_e2ee_push_notifications.py index 98e4ccb625..6cf08d1d52 100644 --- a/zerver/tests/test_e2ee_push_notifications.py +++ b/zerver/tests/test_e2ee_push_notifications.py @@ -258,6 +258,45 @@ class SendPushNotificationTest(E2EEPushNotificationTestCase): zerver_logger.output[1], ) + def test_early_return_if_expired_time_set(self, unused_mock: mock.MagicMock) -> None: + aaron = self.example_user("aaron") + hamlet = self.example_user("hamlet") + + registered_device_apple, registered_device_android = ( + self.register_push_devices_for_notification() + ) + registered_device_apple.expired_time = datetime(2099, 4, 24, tzinfo=timezone.utc) + registered_device_android.expired_time = datetime(2099, 4, 24, tzinfo=timezone.utc) + registered_device_apple.save(update_fields=["expired_time"]) + registered_device_android.save(update_fields=["expired_time"]) + + self.assertEqual(PushDevice.objects.count(), 2) + + message_id = self.send_personal_message( + from_user=aaron, to_user=hamlet, skip_capture_on_commit_callbacks=True + ) + missed_message = { + "message_id": message_id, + "trigger": NotificationTriggers.DIRECT_MESSAGE, + } + + # Since 'expired_time' is set for concerned 'RemotePushDevice' rows, + # the bouncer will not attempt to send notification and instead returns + # a list of device IDs which server should erase on their own end. + with ( + mock.patch( + "zilencer.lib.push_notifications.send_e2ee_push_notification_apple" + ) as send_apple, + mock.patch( + "zilencer.lib.push_notifications.send_e2ee_push_notification_android" + ) as send_android, + ): + handle_push_notification(hamlet.id, missed_message) + + send_apple.assert_not_called() + send_android.assert_not_called() + self.assertEqual(PushDevice.objects.count(), 0) + @responses.activate @override_settings(ZILENCER_ENABLED=False) def test_success_self_hosted(self, unused_mock: mock.MagicMock) -> None: diff --git a/zilencer/lib/push_notifications.py b/zilencer/lib/push_notifications.py index 40bf0a6233..b88e4ffa61 100644 --- a/zilencer/lib/push_notifications.py +++ b/zilencer/lib/push_notifications.py @@ -1,15 +1,17 @@ import asyncio import logging from collections.abc import Iterable -from typing import Literal, TypeAlias +from dataclasses import asdict -from aioapns import NotificationRequest, PushType +from aioapns import NotificationRequest from django.utils.timezone import now as timezone_now from firebase_admin import exceptions as firebase_exceptions from firebase_admin import messaging as firebase_messaging from firebase_admin.messaging import UnregisteredError as FCMUnregisteredError from zerver.lib.push_notifications import ( + APNsPushRequest, + FCMPushRequest, SendNotificationResponseData, fcm_app, get_apns_context, @@ -20,9 +22,6 @@ from zilencer.models import RemotePushDevice, RemoteRealm logger = logging.getLogger(__name__) -FCMPriority: TypeAlias = Literal["high", "normal"] -APNsPriority: TypeAlias = Literal[10, 5, 1] - def send_e2ee_push_notification_apple( apns_requests: list[NotificationRequest], @@ -119,11 +118,8 @@ def send_e2ee_push_notification_android( def send_e2ee_push_notifications( - device_id_to_encrypted_data: dict[str, str], + push_requests: list[APNsPushRequest | FCMPushRequest], *, - fcm_priority: FCMPriority, - apns_priority: APNsPriority, - apns_push_type: PushType, realm: Realm | None = None, remote_realm: RemoteRealm | None = None, ) -> SendNotificationResponseData: @@ -131,13 +127,15 @@ def send_e2ee_push_notifications( import aioapns - device_ids = [int(device_id_str) for device_id_str in device_id_to_encrypted_data] + device_ids = {push_request.device_id for push_request in push_requests} remote_push_devices = RemotePushDevice.objects.filter( device_id__in=device_ids, expired_time__isnull=True, realm=realm, remote_realm=remote_realm ) - unexpired_remote_push_device_ids = { - remote_push_device.device_id for remote_push_device in remote_push_devices + device_id_to_remote_push_device = { + remote_push_device.device_id: remote_push_device + for remote_push_device in remote_push_devices } + unexpired_remote_push_device_ids = set(device_id_to_remote_push_device.keys()) # Device IDs which should be deleted on server. # Either the device ID is invalid or the token @@ -148,44 +146,35 @@ def send_e2ee_push_notifications( apns_requests = [] apns_remote_push_devices: list[RemotePushDevice] = [] - apns_base_message_payload = { - "aps": { - "mutable-content": 1, - "alert": { - "title": "New notification", - }, - }, - } fcm_requests = [] fcm_remote_push_devices: list[RemotePushDevice] = [] - for remote_push_device in remote_push_devices: - message_payload = { - "encrypted_data": device_id_to_encrypted_data[str(remote_push_device.device_id)], - "push_account_id": remote_push_device.push_account_id, - } + for push_request in push_requests: + device_id = push_request.device_id + if device_id not in unexpired_remote_push_device_ids: + continue + + remote_push_device = device_id_to_remote_push_device[device_id] if remote_push_device.token_kind == RemotePushDevice.TokenKind.APNS: - apns_message_payload = { - **apns_base_message_payload, - **message_payload, - } + assert isinstance(push_request, APNsPushRequest) apns_requests.append( aioapns.NotificationRequest( apns_topic=remote_push_device.ios_app_id, device_token=remote_push_device.token, - message=apns_message_payload, - priority=apns_priority, - push_type=apns_push_type, + message=asdict(push_request.payload), + priority=push_request.http_headers.apns_priority, + push_type=push_request.http_headers.apns_push_type, ) ) apns_remote_push_devices.append(remote_push_device) else: + assert isinstance(push_request, FCMPushRequest) fcm_requests.append( firebase_messaging.Message( - data=message_payload, + data=asdict(push_request.payload), token=remote_push_device.token, - android=firebase_messaging.AndroidConfig(priority=fcm_priority), + android=firebase_messaging.AndroidConfig(priority=push_request.fcm_priority), ) ) fcm_remote_push_devices.append(remote_push_device) diff --git a/zilencer/views.py b/zilencer/views.py index ffcfa0be38..6b7fcc381d 100644 --- a/zilencer/views.py +++ b/zilencer/views.py @@ -8,7 +8,6 @@ from uuid import UUID import orjson import requests.exceptions -from aioapns import PushType from django.conf import settings from django.core.exceptions import ValidationError from django.core.validators import URLValidator, validate_email @@ -56,6 +55,8 @@ from zerver.lib.exceptions import ( from zerver.lib.outgoing_http import OutgoingSession from zerver.lib.push_notifications import ( PUSH_REGISTRATION_LIVENESS_TIMEOUT, + APNsPushRequest, + FCMPushRequest, HostnameAlreadyInUseBouncerError, InvalidRemotePushDeviceTokenError, RealmPushStatusDict, @@ -93,7 +94,7 @@ from zilencer.auth import ( generate_registration_transfer_verification_secret, validate_registration_transfer_verification_secret, ) -from zilencer.lib.push_notifications import APNsPriority, FCMPriority, send_e2ee_push_notifications +from zilencer.lib.push_notifications import send_e2ee_push_notifications from zilencer.lib.remote_counts import MissingDataError from zilencer.models import ( RemoteInstallationCount, @@ -1807,10 +1808,7 @@ def remote_server_check_analytics(request: HttpRequest, server: RemoteZulipServe class SendE2EEPushNotificationPayload(BaseModel): realm_uuid: str - device_id_to_encrypted_data: dict[str, str] - fcm_priority: FCMPriority - apns_priority: APNsPriority - apns_push_type: PushType + push_requests: list[APNsPushRequest | FCMPushRequest] @typed_endpoint @@ -1837,21 +1835,18 @@ def remote_server_send_e2ee_push_notification( reason = push_status.message raise PushNotificationsDisallowedError(reason=reason) - device_id_to_encrypted_data = payload.device_id_to_encrypted_data + push_requests = payload.push_requests do_increment_logging_stat( remote_realm, COUNT_STATS["mobile_pushes_received::day"], None, timezone_now(), - increment=len(device_id_to_encrypted_data), + increment=len(push_requests), ) response_data = send_e2ee_push_notifications( - device_id_to_encrypted_data, - fcm_priority=payload.fcm_priority, - apns_priority=payload.apns_priority, - apns_push_type=payload.apns_push_type, + push_requests, remote_realm=remote_realm, )