From 7e1afa0e8a45aa9925de5799316c94d84256dd00 Mon Sep 17 00:00:00 2001 From: Prakhar Pratyush Date: Fri, 4 Jul 2025 12:59:36 +0530 Subject: [PATCH] push_notification: Send end-to-end encrypted push notifications. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit adds support to send encrypted push notifications to devices registered to receive encrypted notifications. URL: `POST /api/v1/remotes/push/e2ee/notify` payload: `realm_uuid` and `device_id_to_encrypted_data` The POST request needs to be authenticated with the server’s API key. Note: For Zulip Cloud, a background fact about the push bouncer is that it runs on the same server and database as the main application; it’s not a separate service. So, as an optimization we directly call 'send_e2ee_push_notifications' function and skip the HTTP request. --- zerver/lib/push_notifications.py | 337 ++++++++++--- zerver/lib/remote_server.py | 2 + zerver/lib/test_classes.py | 95 +++- ...r_pushdevice_user_bouncer_device_id_idx.py | 23 + zerver/models/push_notifications.py | 9 + zerver/tests/test_e2ee_push_notifications.py | 460 ++++++++++++++++++ zerver/tests/test_handle_push_notification.py | 6 +- zilencer/lib/push_notifications.py | 208 ++++++++ zilencer/urls.py | 2 + zilencer/views.py | 63 +++ 10 files changed, 1134 insertions(+), 71 deletions(-) create mode 100644 zerver/migrations/0741_pushdevice_zerver_pushdevice_user_bouncer_device_id_idx.py create mode 100644 zerver/tests/test_e2ee_push_notifications.py create mode 100644 zilencer/lib/push_notifications.py diff --git a/zerver/lib/push_notifications.py b/zerver/lib/push_notifications.py index 27d7033ba2..0782a8a542 100644 --- a/zerver/lib/push_notifications.py +++ b/zerver/lib/push_notifications.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, TypeAlias, Union import lxml.html import orjson +from aioapns.common import NotificationResult from django.conf import settings from django.db import transaction from django.db.models import F, Q @@ -25,7 +26,9 @@ from firebase_admin import exceptions as firebase_exceptions from firebase_admin import initialize_app as firebase_initialize_app from firebase_admin import messaging as firebase_messaging from firebase_admin.messaging import UnregisteredError as FCMUnregisteredError -from typing_extensions import TypedDict, override +from nacl.encoding import Base64Encoder +from nacl.public import PublicKey, SealedBox +from typing_extensions import NotRequired, TypedDict, override from analytics.lib.counts import COUNT_STATS, do_increment_logging_stat from zerver.actions.realm_settings import ( @@ -35,7 +38,7 @@ from zerver.actions.realm_settings import ( from zerver.lib.avatar import absolute_avatar_url, get_avatar_for_inaccessible_user from zerver.lib.display_recipient import get_display_recipient from zerver.lib.emoji_utils import hex_codepoint_to_emoji -from zerver.lib.exceptions import ErrorCode, JsonableError +from zerver.lib.exceptions import ErrorCode, JsonableError, MissingRemoteRealmError from zerver.lib.message import access_message_and_usermessage, direct_message_group_users from zerver.lib.notification_data import get_mentioned_user_group from zerver.lib.remote_server import ( @@ -72,7 +75,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) if settings.ZILENCER_ENABLED: - from zilencer.models import RemotePushDeviceToken, RemoteZulipServer + from zilencer.models import RemotePushDevice, RemotePushDeviceToken, RemoteZulipServer # Time (in seconds) for which the server should retry registering # a push device to the bouncer. 24 hrs is a good time limit because @@ -252,6 +255,43 @@ def dedupe_device_tokens( return result +@dataclass +class APNsResultInfo: + successfully_sent: bool + delete_device_id: int | None = None + delete_device_token: str | None = None + + +def get_info_from_apns_result( + result: NotificationResult | BaseException, + device: "DeviceToken | RemotePushDevice", + log_context: str, +) -> APNsResultInfo: + import aioapns.exceptions + + result_info = APNsResultInfo(successfully_sent=False) + + if isinstance(result, aioapns.exceptions.ConnectionError): + logger.error("APNs: ConnectionError sending %s; check certificate expiration", log_context) + elif isinstance(result, BaseException): + logger.error("APNs: Error sending %s", log_context, exc_info=result) + elif result.is_successful: + result_info.successfully_sent = True + logger.info("APNs: Success sending %s", log_context) + elif result.description in ["Unregistered", "BadDeviceToken", "DeviceTokenNotForTopic"]: + logger.info( + "APNs: Removing invalid/expired token %s (%s)", device.token, result.description + ) + if isinstance(device, RemotePushDevice): + result_info.delete_device_id = device.device_id + else: + result_info.delete_device_token = device.token + else: + logger.warning("APNs: Failed to send %s: %s", log_context, result.description) + + return result_info + + def send_apple_push_notification( user_identity: UserPushIdentityCompat, devices: Sequence[DeviceToken], @@ -265,7 +305,6 @@ def send_apple_push_notification( # notification queue worker, it's best to only import them in the # code that needs them. import aioapns - import aioapns.exceptions apns_context = get_apns_context() if apns_context is None: @@ -337,40 +376,18 @@ def send_apple_push_notification( successfully_sent_count = 0 for device, result in results: - if isinstance(result, aioapns.exceptions.ConnectionError): - logger.error( - "APNs: ConnectionError sending for user %s to device %s; check certificate expiration", - user_identity, - device.token, - ) - elif isinstance(result, BaseException): - logger.error( - "APNs: Error sending for user %s to device %s", - user_identity, - device.token, - exc_info=result, - ) - elif result.is_successful: + log_context = f"for user {user_identity} to device {device.token}" + result_info = get_info_from_apns_result(result, device, log_context) + + if result_info.successfully_sent: successfully_sent_count += 1 - logger.info( - "APNs: Success sending for user %s to device %s", user_identity, device.token - ) - elif result.description in ["Unregistered", "BadDeviceToken", "DeviceTokenNotForTopic"]: - logger.info( - "APNs: Removing invalid/expired token %s (%s)", device.token, result.description - ) + elif result_info.delete_device_token is not None: # We remove all entries for this token (There # could be multiple for different Zulip servers). DeviceTokenClass._default_manager.alias(lower_token=Lower("token")).filter( - lower_token=device.token.lower(), kind=DeviceTokenClass.APNS + lower_token=result_info.delete_device_token.lower(), + kind=DeviceTokenClass.APNS, ).delete() - else: - logger.warning( - "APNs: Failed to send for user %s to device %s: %s", - user_identity, - device.token, - result.description, - ) return successfully_sent_count @@ -1131,6 +1148,35 @@ def get_apns_badge_count_future( ) +def get_apns_payload_data_to_encrypt( + user_profile: UserProfile, + message: Message, + trigger: str, + mentioned_user_group_id: int | None = None, + mentioned_user_group_name: str | None = None, + can_access_sender: bool = True, +) -> dict[str, Any]: + zulip_data = get_message_payload( + user_profile, message, mentioned_user_group_id, mentioned_user_group_name, can_access_sender + ) + zulip_data.update( + message_ids=[message.id], + ) + + assert message.rendered_content is not None + with override_language(user_profile.default_language): + content, _ = truncate_content(get_mobile_push_content(message.rendered_content)) + + zulip_data["alert_title"] = get_apns_alert_title(message, user_profile.default_language) + zulip_data["alert_subtitle"] = get_apns_alert_subtitle( + message, trigger, user_profile, mentioned_user_group_name, can_access_sender + ) + zulip_data["alert_body"] = content + zulip_data["badge"] = get_apns_badge_count(user_profile) + + return zulip_data + + def get_message_payload_apns( user_profile: UserProfile, message: Message, @@ -1315,6 +1361,179 @@ def handle_remove_push_notification(user_profile_id: int, message_ids: list[int] ).update(flags=F("flags").bitand(~UserMessage.flags.active_mobile_push_notification)) +def send_push_notifications_legacy( + user_profile: UserProfile, + apns_payload: dict[str, Any], + gcm_payload: dict[str, Any], + gcm_options: dict[str, Any], +) -> None: + android_devices = list( + PushDeviceToken.objects.filter(user=user_profile, kind=PushDeviceToken.FCM).order_by("id") + ) + apple_devices = list( + PushDeviceToken.objects.filter(user=user_profile, kind=PushDeviceToken.APNS).order_by("id") + ) + + if uses_notification_bouncer(): + send_notifications_to_bouncer( + user_profile, apns_payload, gcm_payload, gcm_options, android_devices, apple_devices + ) + return + + logger.info( + "Sending mobile push notifications for local user %s: %s via FCM devices, %s via APNs devices", + user_profile.id, + len(android_devices), + len(apple_devices), + ) + user_identity = UserPushIdentityCompat(user_id=user_profile.id) + + apple_successfully_sent_count = send_apple_push_notification( + user_identity, apple_devices, apns_payload + ) + android_successfully_sent_count = send_android_push_notification( + user_identity, android_devices, gcm_payload, gcm_options + ) + + do_increment_logging_stat( + user_profile.realm, + COUNT_STATS["mobile_pushes_sent::day"], + None, + timezone_now(), + increment=apple_successfully_sent_count + android_successfully_sent_count, + ) + + +class RealmPushStatusDict(TypedDict): + can_push: bool + expected_end_timestamp: int | None + + +class SendNotificationResponseData(TypedDict): + android_successfully_sent_count: int + apple_successfully_sent_count: int + delete_device_ids: list[int] + realm_push_status: NotRequired[RealmPushStatusDict] + + +def send_push_notifications( + user_profile: UserProfile, + apns_payload_data_to_encrypt: dict[str, Any], + fcm_payload_data_to_encrypt: dict[str, Any], +) -> None: + # Uses 'zerver_pushdevice_user_bouncer_device_id_idx' index. + push_devices = PushDevice.objects.filter(user=user_profile, bouncer_device_id__isnull=False) + + if len(push_devices) == 0: + logger.info( + "Skipping E2EE push notifications for user %s because there are no registered devices", + user_profile.id, + ) + return + + # Prepare payload with encrypted data to send. + device_id_to_encrypted_data: dict[str, str] = {} + for push_device in push_devices: + public_key_str: str = push_device.push_public_key + public_key = PublicKey(public_key_str.encode("utf-8"), Base64Encoder) + sealed_box = SealedBox(public_key) + + if push_device.token_kind == PushDevice.TokenKind.APNS: + encrypted_data_bytes = sealed_box.encrypt( + orjson.dumps(apns_payload_data_to_encrypt), Base64Encoder + ) + else: + encrypted_data_bytes = sealed_box.encrypt( + orjson.dumps(fcm_payload_data_to_encrypt), Base64Encoder + ) + + encrypted_data = encrypted_data_bytes.decode("utf-8") + assert push_device.bouncer_device_id is not None # for mypy + device_id_to_encrypted_data[str(push_device.bouncer_device_id)] = encrypted_data + + # Send push notification + try: + if settings.ZILENCER_ENABLED: + from zilencer.lib.push_notifications import send_e2ee_push_notifications + + response_data: SendNotificationResponseData = send_e2ee_push_notifications( + device_id_to_encrypted_data, realm=user_profile.realm + ) + else: + post_data = { + "realm_uuid": str(user_profile.realm.uuid), + "device_id_to_encrypted_data": device_id_to_encrypted_data, + } + result = send_json_to_push_bouncer("POST", "push/e2ee/notify", post_data) + assert isinstance(result["android_successfully_sent_count"], int) # for mypy + assert isinstance(result["apple_successfully_sent_count"], int) # for mypy + assert isinstance(result["delete_device_ids"], list) # for mypy + assert isinstance(result["realm_push_status"], dict) # for mypy + response_data = { + "android_successfully_sent_count": result["android_successfully_sent_count"], + "apple_successfully_sent_count": result["apple_successfully_sent_count"], + "delete_device_ids": result["delete_device_ids"], + "realm_push_status": result["realm_push_status"], # type: ignore[typeddict-item] # TODO: Can't use isinstance() with TypedDict type + } + except (MissingRemoteRealmError, PushNotificationsDisallowedByBouncerError) as e: + reason = e.reason if isinstance(e, PushNotificationsDisallowedByBouncerError) else e.msg + logger.warning("Bouncer refused to send E2EE push notification: %s", reason) + do_set_realm_property( + user_profile.realm, + "push_notifications_enabled", + False, + acting_user=None, + ) + do_set_push_notifications_enabled_end_timestamp(user_profile.realm, None, acting_user=None) + return + + # Handle success response data + delete_device_ids = response_data["delete_device_ids"] + apple_successfully_sent_count = response_data["apple_successfully_sent_count"] + android_successfully_sent_count = response_data["android_successfully_sent_count"] + + if len(delete_device_ids) > 0: + logger.info( + "Deleting PushDevice rows with the following device IDs based on response from bouncer: %s", + sorted(delete_device_ids), + ) + # Filtering on `user_profile` is not necessary here, we do it to take + # advantage of 'zerver_pushdevice_user_bouncer_device_id_idx' index. + PushDevice.objects.filter( + user=user_profile, bouncer_device_id__in=delete_device_ids + ).delete() + + do_increment_logging_stat( + user_profile.realm, + COUNT_STATS["mobile_pushes_sent::day"], + None, + timezone_now(), + increment=apple_successfully_sent_count + android_successfully_sent_count, + ) + + logger.info( + "Sent E2EE mobile push notifications for user %s: %s via FCM, %s via APNs", + user_profile.id, + android_successfully_sent_count, + apple_successfully_sent_count, + ) + + realm_push_status_dict = response_data.get("realm_push_status") + if realm_push_status_dict is not None: + can_push = realm_push_status_dict["can_push"] + do_set_realm_property( + user_profile.realm, + "push_notifications_enabled", + can_push, + acting_user=None, + ) + do_set_push_notifications_enabled_end_timestamp( + user_profile.realm, realm_push_status_dict["expected_end_timestamp"], acting_user=None + ) + if can_push: + record_push_notifications_recently_working() + + def handle_push_notification(user_profile_id: int, missed_message: dict[str, Any]) -> None: """ missed_message is the event received by the @@ -1437,43 +1656,25 @@ def handle_push_notification(user_profile_id: int, missed_message: dict[str, Any gcm_payload, gcm_options = get_message_payload_gcm( user_profile, message, mentioned_user_group_id, mentioned_user_group_name, can_access_sender ) + apns_payload_data_to_encrypt = get_apns_payload_data_to_encrypt( + user_profile, + message, + trigger, + mentioned_user_group_id, + mentioned_user_group_name, + can_access_sender, + ) logger.info("Sending push notifications to mobile clients for user %s", user_profile_id) - android_devices = list( - PushDeviceToken.objects.filter(user=user_profile, kind=PushDeviceToken.FCM).order_by("id") - ) - - apple_devices = list( - PushDeviceToken.objects.filter(user=user_profile, kind=PushDeviceToken.APNS).order_by("id") - ) - if uses_notification_bouncer(): - send_notifications_to_bouncer( - user_profile, apns_payload, gcm_payload, gcm_options, android_devices, apple_devices - ) - return - - logger.info( - "Sending mobile push notifications for local user %s: %s via FCM devices, %s via APNs devices", - user_profile_id, - len(android_devices), - len(apple_devices), - ) - user_identity = UserPushIdentityCompat(user_id=user_profile.id) - - apple_successfully_sent_count = send_apple_push_notification( - user_identity, apple_devices, apns_payload - ) - android_successfully_sent_count = send_android_push_notification( - user_identity, android_devices, gcm_payload, gcm_options - ) - - do_increment_logging_stat( - user_profile.realm, - COUNT_STATS["mobile_pushes_sent::day"], - None, - timezone_now(), - increment=apple_successfully_sent_count + android_successfully_sent_count, - ) + # TODO: We plan to offer a personal, realm-level, and server-level setting + # to require all notifications to be end-to-end encrypted. When either setting + # is enabled, we skip calling 'send_push_notifications_legacy'. + send_push_notifications_legacy(user_profile, apns_payload, gcm_payload, gcm_options) + if settings.DEVELOPMENT: + # TODO: Remove the 'settings.DEVELOPMENT' check when mobile clients start + # to offer a way to register for E2EE push notifications; otherwise it'll + # do needless DB query and logging. + send_push_notifications(user_profile, apns_payload_data_to_encrypt, gcm_payload) def send_test_push_notification_directly_to_devices( diff --git a/zerver/lib/remote_server.py b/zerver/lib/remote_server.py index dcffbcd3b1..c27078c5f3 100644 --- a/zerver/lib/remote_server.py +++ b/zerver/lib/remote_server.py @@ -225,6 +225,8 @@ def send_to_push_bouncer( raise RequestExpiredError elif endpoint == "push/e2ee/register" and code == "MISSING_REMOTE_REALM": raise MissingRemoteRealmError + elif endpoint == "push/e2ee/notify" and code == "MISSING_REMOTE_REALM": + raise MissingRemoteRealmError else: # But most other errors coming from the push bouncer # server are client errors (e.g. never-registered token) diff --git a/zerver/lib/test_classes.py b/zerver/lib/test_classes.py index ca0348b221..bbd67c2a4a 100644 --- a/zerver/lib/test_classes.py +++ b/zerver/lib/test_classes.py @@ -95,6 +95,7 @@ from zerver.models import ( Client, Message, NamedUserGroup, + PushDevice, PushDeviceToken, Reaction, Realm, @@ -116,7 +117,13 @@ from zerver.openapi.openapi import validate_test_request, validate_test_response from zerver.tornado.event_queue import clear_client_event_queues_for_testing if settings.ZILENCER_ENABLED: - from zilencer.models import RemotePushDeviceToken, RemoteZulipServer, get_remote_server_by_uuid + from zilencer.models import ( + RemotePushDevice, + RemotePushDeviceToken, + RemoteRealm, + RemoteZulipServer, + get_remote_server_by_uuid, + ) if TYPE_CHECKING: from django.test.client import _MonkeyPatchedWSGIResponse as TestHttpResponse @@ -2835,3 +2842,89 @@ class PushNotificationTestCase(BouncerTestCase): ) -> firebase_messaging.BatchResponse: error_response = firebase_messaging.SendResponse(exception=exception, resp=None) return firebase_messaging.BatchResponse([error_response]) + + +class E2EEPushNotificationTestCase(BouncerTestCase): + def register_push_devices_for_notification( + self, is_server_self_hosted: bool = False + ) -> tuple[RemotePushDevice, RemotePushDevice]: + hamlet = self.example_user("hamlet") + realm = hamlet.realm + + # Hamlet registers both an Android and an Apple device for push notification. + PushDevice.objects.create( + user=hamlet, + push_account_id=10, + bouncer_device_id=1, + token_kind=PushDevice.TokenKind.APNS, + push_public_key="9VvW7k59AET0v3+VFCkKTrNm5DJQ7JTKdvUjZInZZ0Y=", + ) + PushDevice.objects.create( + user=hamlet, + push_account_id=20, + bouncer_device_id=2, + token_kind=PushDevice.TokenKind.FCM, + push_public_key="n4WTVqj8KH6u0vScRycR4TqRaHhFeJ0POvMb8LCu8iI=", + ) + + realm_and_remote_realm_fields: dict[str, Realm | RemoteRealm | None] = { + "realm": realm, + "remote_realm": None, + } + if is_server_self_hosted: + remote_realm = RemoteRealm.objects.get(uuid=realm.uuid) + realm_and_remote_realm_fields = {"realm": None, "remote_realm": remote_realm} + + registered_device_apple = RemotePushDevice.objects.create( + push_account_id=10, + device_id=1, + token_kind=RemotePushDevice.TokenKind.APNS, + token="push-device-token-1", + ios_app_id="abc", + **realm_and_remote_realm_fields, + ) + registered_device_android = RemotePushDevice.objects.create( + push_account_id=20, + device_id=2, + token_kind=RemotePushDevice.TokenKind.FCM, + token="push-device-token-3", + **realm_and_remote_realm_fields, + ) + + return registered_device_apple, registered_device_android + + @contextmanager + def mock_fcm(self) -> Iterator[mock.MagicMock]: + with mock.patch("zilencer.lib.push_notifications.firebase_messaging") as mock_fcm_messaging: + yield mock_fcm_messaging + + @contextmanager + def mock_apns(self) -> Iterator[mock.AsyncMock]: + apns = mock.Mock(spec=aioapns.APNs) + apns.send_notification = mock.AsyncMock() + apns_context = APNsContext( + apns=apns, + loop=asyncio.new_event_loop(), + ) + try: + with mock.patch("zilencer.lib.push_notifications.get_apns_context") as mock_get: + mock_get.return_value = apns_context + yield apns.send_notification + finally: + apns_context.loop.close() + + def make_fcm_success_response(self) -> firebase_messaging.BatchResponse: + device_ids_count = RemotePushDevice.objects.filter( + token_kind=RemotePushDevice.TokenKind.FCM + ).count() + responses = [ + firebase_messaging.SendResponse(exception=None, resp=dict(name=str(idx))) + for idx in range(device_ids_count) + ] + return firebase_messaging.BatchResponse(responses) + + def make_fcm_error_response( + self, exception: firebase_exceptions.FirebaseError + ) -> firebase_messaging.BatchResponse: + error_response = firebase_messaging.SendResponse(exception=exception, resp=None) + return firebase_messaging.BatchResponse([error_response]) diff --git a/zerver/migrations/0741_pushdevice_zerver_pushdevice_user_bouncer_device_id_idx.py b/zerver/migrations/0741_pushdevice_zerver_pushdevice_user_bouncer_device_id_idx.py new file mode 100644 index 0000000000..35301052c5 --- /dev/null +++ b/zerver/migrations/0741_pushdevice_zerver_pushdevice_user_bouncer_device_id_idx.py @@ -0,0 +1,23 @@ +# Generated by Django 5.2.4 on 2025-07-22 11:57 + +from django.contrib.postgres.operations import AddIndexConcurrently +from django.db import migrations, models + + +class Migration(migrations.Migration): + atomic = False + + dependencies = [ + ("zerver", "0740_pushdevicetoken_apns_case_insensitive"), + ] + + operations = [ + AddIndexConcurrently( + model_name="pushdevice", + index=models.Index( + condition=models.Q(("bouncer_device_id__isnull", False)), + fields=["user", "bouncer_device_id"], + name="zerver_pushdevice_user_bouncer_device_id_idx", + ), + ), + ] diff --git a/zerver/models/push_notifications.py b/zerver/models/push_notifications.py index ca504ee136..1f5c3a8e0d 100644 --- a/zerver/models/push_notifications.py +++ b/zerver/models/push_notifications.py @@ -124,6 +124,15 @@ class PushDevice(AbstractPushDevice): name="unique_push_device_user_push_account_id", ), ] + indexes = [ + models.Index( + # Used in 'send_push_notifications' function, + # in 'zerver/lib/push_notifications'. + fields=["user", "bouncer_device_id"], + condition=Q(bouncer_device_id__isnull=False), + name="zerver_pushdevice_user_bouncer_device_id_idx", + ), + ] @property def status(self) -> Literal["active", "pending", "failed"]: diff --git a/zerver/tests/test_e2ee_push_notifications.py b/zerver/tests/test_e2ee_push_notifications.py new file mode 100644 index 0000000000..057d5f72dc --- /dev/null +++ b/zerver/tests/test_e2ee_push_notifications.py @@ -0,0 +1,460 @@ +from datetime import datetime, timezone +from unittest import mock + +import responses +from django.test import override_settings +from firebase_admin.exceptions import InternalError +from firebase_admin.messaging import UnregisteredError + +from analytics.models import RealmCount +from zerver.lib.push_notifications import handle_push_notification +from zerver.lib.test_classes import E2EEPushNotificationTestCase +from zerver.lib.test_helpers import activate_push_notification_service +from zerver.models import PushDevice +from zerver.models.scheduled_jobs import NotificationTriggers +from zilencer.models import RemoteRealm, RemoteRealmCount + + +@activate_push_notification_service() +@mock.patch("zerver.lib.push_notifications.send_push_notifications_legacy") +class SendPushNotificationTest(E2EEPushNotificationTestCase): + def test_success_cloud(self, unused_mock: mock.MagicMock) -> None: + hamlet = self.example_user("hamlet") + aaron = self.example_user("aaron") + + registered_device_apple, registered_device_android = ( + self.register_push_devices_for_notification() + ) + message_id = self.send_personal_message( + from_user=aaron, to_user=hamlet, skip_capture_on_commit_callbacks=True + ) + missed_message = { + "message_id": message_id, + "trigger": NotificationTriggers.DIRECT_MESSAGE, + } + + self.assertEqual(RealmCount.objects.count(), 0) + + with ( + self.mock_fcm() as mock_fcm_messaging, + self.mock_apns() as send_notification, + self.assertLogs("zerver.lib.push_notifications", level="INFO") as zerver_logger, + self.assertLogs("zilencer.lib.push_notifications", level="INFO") as zilencer_logger, + ): + mock_fcm_messaging.send_each.return_value = self.make_fcm_success_response() + send_notification.return_value.is_successful = True + + handle_push_notification(hamlet.id, missed_message) + + mock_fcm_messaging.send_each.assert_called_once() + send_notification.assert_called_once() + + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sending push notifications to mobile clients for user {hamlet.id}", + zerver_logger.output[0], + ) + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"APNs: Success sending to (push_account_id={registered_device_apple.push_account_id}, device={registered_device_apple.token})", + zerver_logger.output[1], + ) + self.assertEqual( + "INFO:zilencer.lib.push_notifications:" + f"FCM: Sent message with ID: 0 to (push_account_id={registered_device_android.push_account_id}, device={registered_device_android.token})", + zilencer_logger.output[0], + ) + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sent E2EE mobile push notifications for user {hamlet.id}: 1 via FCM, 1 via APNs", + zerver_logger.output[2], + ) + + realm_count_dict = ( + RealmCount.objects.filter(property="mobile_pushes_sent::day") + .values("subgroup", "value") + .last() + ) + self.assertEqual(realm_count_dict, dict(subgroup=None, value=2)) + + def test_no_registered_device(self, unused_mock: mock.MagicMock) -> None: + aaron = self.example_user("aaron") + hamlet = self.example_user("hamlet") + + message_id = self.send_personal_message( + from_user=aaron, to_user=hamlet, skip_capture_on_commit_callbacks=True + ) + missed_message = { + "message_id": message_id, + "trigger": NotificationTriggers.DIRECT_MESSAGE, + } + + with self.assertLogs("zerver.lib.push_notifications", level="INFO") as zerver_logger: + handle_push_notification(hamlet.id, missed_message) + + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sending push notifications to mobile clients for user {hamlet.id}", + zerver_logger.output[0], + ) + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Skipping E2EE push notifications for user {hamlet.id} because there are no registered devices", + zerver_logger.output[1], + ) + + def test_invalid_or_expired_token(self, unused_mock: mock.MagicMock) -> None: + aaron = self.example_user("aaron") + hamlet = self.example_user("hamlet") + + registered_device_apple, registered_device_android = ( + self.register_push_devices_for_notification() + ) + self.assertIsNone(registered_device_apple.expired_time) + self.assertIsNone(registered_device_android.expired_time) + self.assertEqual(PushDevice.objects.count(), 2) + + message_id = self.send_personal_message( + from_user=aaron, to_user=hamlet, skip_capture_on_commit_callbacks=True + ) + missed_message = { + "message_id": message_id, + "trigger": NotificationTriggers.DIRECT_MESSAGE, + } + + with ( + self.mock_fcm() as mock_fcm_messaging, + self.mock_apns() as send_notification, + self.assertLogs("zerver.lib.push_notifications", level="INFO") as zerver_logger, + self.assertLogs("zilencer.lib.push_notifications", level="INFO") as zilencer_logger, + ): + mock_fcm_messaging.send_each.return_value = self.make_fcm_error_response( + UnregisteredError("Token expired") + ) + send_notification.return_value.is_successful = False + send_notification.return_value.description = "BadDeviceToken" + + handle_push_notification(hamlet.id, missed_message) + + mock_fcm_messaging.send_each.assert_called_once() + send_notification.assert_called_once() + + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sending push notifications to mobile clients for user {hamlet.id}", + zerver_logger.output[0], + ) + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"APNs: Removing invalid/expired token {registered_device_apple.token} (BadDeviceToken)", + zerver_logger.output[1], + ) + self.assertEqual( + "INFO:zilencer.lib.push_notifications:" + f"FCM: Removing {registered_device_android.token} due to NOT_FOUND", + zilencer_logger.output[0], + ) + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Deleting PushDevice rows with the following device IDs based on response from bouncer: [{registered_device_apple.device_id}, {registered_device_android.device_id}]", + zerver_logger.output[2], + ) + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sent E2EE mobile push notifications for user {hamlet.id}: 0 via FCM, 0 via APNs", + zerver_logger.output[3], + ) + + # Verify `expired_time` set for `RemotePushDevice` entries + # and corresponding `PushDevice` deleted on server. + registered_device_apple.refresh_from_db() + registered_device_android.refresh_from_db() + self.assertIsNotNone(registered_device_apple.expired_time) + self.assertIsNotNone(registered_device_android.expired_time) + self.assertEqual(PushDevice.objects.count(), 0) + + def test_fcm_apns_error(self, unused_mock: mock.MagicMock) -> None: + hamlet = self.example_user("hamlet") + aaron = self.example_user("aaron") + + unused, registered_device_android = self.register_push_devices_for_notification() + message_id = self.send_personal_message( + from_user=aaron, to_user=hamlet, skip_capture_on_commit_callbacks=True + ) + missed_message = { + "message_id": message_id, + "trigger": NotificationTriggers.DIRECT_MESSAGE, + } + + # `get_apns_context` returns `None` + FCM returns error other than UnregisteredError. + with ( + self.mock_fcm() as mock_fcm_messaging, + mock.patch("zilencer.lib.push_notifications.get_apns_context", return_value=None), + self.assertLogs("zerver.lib.push_notifications", level="INFO") as zerver_logger, + self.assertLogs("zilencer.lib.push_notifications", level="DEBUG") as zilencer_logger, + ): + mock_fcm_messaging.send_each.return_value = self.make_fcm_error_response( + InternalError("fcm-error") + ) + + handle_push_notification(hamlet.id, missed_message) + + mock_fcm_messaging.send_each.assert_called_once() + + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sending push notifications to mobile clients for user {hamlet.id}", + zerver_logger.output[0], + ) + self.assertEqual( + "DEBUG:zilencer.lib.push_notifications:" + "APNs: Dropping a notification because nothing configured. " + "Set ZULIP_SERVICES_URL (or APNS_CERT_FILE).", + zilencer_logger.output[0], + ) + self.assertIn( + "WARNING:zilencer.lib.push_notifications:" + f"FCM: Delivery failed for (push_account_id={registered_device_android.push_account_id}, device={registered_device_android.token})", + zilencer_logger.output[1], + ) + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sent E2EE mobile push notifications for user {hamlet.id}: 0 via FCM, 0 via APNs", + zerver_logger.output[1], + ) + + # `firebase_messaging.send_each` raises Error. + message_id = self.send_personal_message( + from_user=aaron, to_user=hamlet, skip_capture_on_commit_callbacks=True + ) + missed_message = { + "message_id": message_id, + "trigger": NotificationTriggers.DIRECT_MESSAGE, + } + + with ( + self.mock_fcm() as mock_fcm_messaging, + mock.patch( + "zilencer.lib.push_notifications.send_e2ee_push_notification_apple", return_value=1 + ), + self.assertLogs("zerver.lib.push_notifications", level="INFO") as zerver_logger, + self.assertLogs("zilencer.lib.push_notifications", level="WARNING") as zilencer_logger, + ): + mock_fcm_messaging.send_each.side_effect = InternalError("server error") + handle_push_notification(hamlet.id, missed_message) + + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sending push notifications to mobile clients for user {hamlet.id}", + zerver_logger.output[0], + ) + self.assertIn( + "WARNING:zilencer.lib.push_notifications:Error while pushing to FCM", + zilencer_logger.output[0], + ) + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sent E2EE mobile push notifications for user {hamlet.id}: 0 via FCM, 1 via APNs", + zerver_logger.output[1], + ) + + @activate_push_notification_service() + @responses.activate + @override_settings(ZILENCER_ENABLED=False) + def test_success_self_hosted(self, unused_mock: mock.MagicMock) -> None: + self.add_mock_response() + + hamlet = self.example_user("hamlet") + aaron = self.example_user("aaron") + realm = hamlet.realm + + registered_device_apple, registered_device_android = ( + self.register_push_devices_for_notification(is_server_self_hosted=True) + ) + message_id = self.send_personal_message( + from_user=aaron, to_user=hamlet, skip_capture_on_commit_callbacks=True + ) + missed_message = { + "message_id": message_id, + "trigger": NotificationTriggers.DIRECT_MESSAGE, + } + + # Setup to verify whether these fields get updated correctly. + realm.push_notifications_enabled = False + realm.push_notifications_enabled_end_timestamp = datetime(2099, 4, 24, tzinfo=timezone.utc) + realm.save( + update_fields=["push_notifications_enabled", "push_notifications_enabled_end_timestamp"] + ) + + self.assertEqual(RealmCount.objects.count(), 0) + self.assertEqual(RemoteRealmCount.objects.count(), 0) + + with ( + self.mock_fcm() as mock_fcm_messaging, + self.mock_apns() as send_notification, + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", + return_value=10, + ), + self.assertLogs("zerver.lib.push_notifications", level="INFO") as zerver_logger, + self.assertLogs("zilencer.lib.push_notifications", level="INFO") as zilencer_logger, + ): + mock_fcm_messaging.send_each.return_value = self.make_fcm_success_response() + send_notification.return_value.is_successful = True + + handle_push_notification(hamlet.id, missed_message) + + mock_fcm_messaging.send_each.assert_called_once() + send_notification.assert_called_once() + + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sending push notifications to mobile clients for user {hamlet.id}", + zerver_logger.output[0], + ) + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"APNs: Success sending to (push_account_id={registered_device_apple.push_account_id}, device={registered_device_apple.token})", + zerver_logger.output[1], + ) + self.assertEqual( + "INFO:zilencer.lib.push_notifications:" + f"FCM: Sent message with ID: 0 to (push_account_id={registered_device_android.push_account_id}, device={registered_device_android.token})", + zilencer_logger.output[0], + ) + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sent E2EE mobile push notifications for user {hamlet.id}: 1 via FCM, 1 via APNs", + zerver_logger.output[2], + ) + + realm_count_dict = ( + RealmCount.objects.filter(property="mobile_pushes_sent::day") + .values("subgroup", "value") + .last() + ) + self.assertEqual(realm_count_dict, dict(subgroup=None, value=2)) + + remote_realm_count_dict = ( + RemoteRealmCount.objects.filter(property="mobile_pushes_received::day") + .values("subgroup", "value") + .last() + ) + self.assertEqual(remote_realm_count_dict, dict(subgroup=None, value=2)) + + remote_realm_count_dict = ( + RemoteRealmCount.objects.filter(property="mobile_pushes_forwarded::day") + .values("subgroup", "value") + .last() + ) + self.assertEqual(remote_realm_count_dict, dict(subgroup=None, value=2)) + + realm.refresh_from_db() + self.assertTrue(realm.push_notifications_enabled) + self.assertIsNone(realm.push_notifications_enabled_end_timestamp) + + @activate_push_notification_service() + @responses.activate + @override_settings(ZILENCER_ENABLED=False) + def test_missing_remote_realm_error(self, unused_mock: mock.MagicMock) -> None: + self.add_mock_response() + + hamlet = self.example_user("hamlet") + aaron = self.example_user("aaron") + realm = hamlet.realm + + self.register_push_devices_for_notification(is_server_self_hosted=True) + message_id = self.send_personal_message( + from_user=aaron, to_user=hamlet, skip_capture_on_commit_callbacks=True + ) + missed_message = { + "message_id": message_id, + "trigger": NotificationTriggers.DIRECT_MESSAGE, + } + + # Setup to verify whether these fields get updated correctly. + realm.push_notifications_enabled = True + realm.push_notifications_enabled_end_timestamp = datetime(2099, 4, 24, tzinfo=timezone.utc) + realm.save( + update_fields=["push_notifications_enabled", "push_notifications_enabled_end_timestamp"] + ) + + # To replicate missing remote realm + RemoteRealm.objects.all().delete() + + with ( + self.assertLogs("zerver.lib.push_notifications", level="INFO") as zerver_logger, + self.assertLogs("zilencer.views", level="INFO") as zilencer_logger, + ): + handle_push_notification(hamlet.id, missed_message) + + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sending push notifications to mobile clients for user {hamlet.id}", + zerver_logger.output[0], + ) + self.assertEqual( + "INFO:zilencer.views:" + f"/api/v1/remotes/push/e2ee/notify: Received request for unknown realm {realm.uuid}, server {self.server.id}", + zilencer_logger.output[0], + ) + self.assertEqual( + "WARNING:zerver.lib.push_notifications:" + "Bouncer refused to send E2EE push notification: Organization not registered", + zerver_logger.output[1], + ) + + realm.refresh_from_db() + self.assertFalse(realm.push_notifications_enabled) + self.assertIsNone(realm.push_notifications_enabled_end_timestamp) + + @activate_push_notification_service() + @responses.activate + @override_settings(ZILENCER_ENABLED=False) + def test_no_plan_error(self, unused_mock: mock.MagicMock) -> None: + self.add_mock_response() + + hamlet = self.example_user("hamlet") + aaron = self.example_user("aaron") + realm = hamlet.realm + + self.register_push_devices_for_notification(is_server_self_hosted=True) + message_id = self.send_personal_message( + from_user=aaron, to_user=hamlet, skip_capture_on_commit_callbacks=True + ) + missed_message = { + "message_id": message_id, + "trigger": NotificationTriggers.DIRECT_MESSAGE, + } + + # Setup to verify whether these fields get updated correctly. + realm.push_notifications_enabled = True + realm.push_notifications_enabled_end_timestamp = datetime(2099, 4, 24, tzinfo=timezone.utc) + realm.save( + update_fields=["push_notifications_enabled", "push_notifications_enabled_end_timestamp"] + ) + + with ( + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", + return_value=100, + ), + self.assertLogs("zerver.lib.push_notifications", level="INFO") as zerver_logger, + ): + handle_push_notification(hamlet.id, missed_message) + + self.assertEqual( + "INFO:zerver.lib.push_notifications:" + f"Sending push notifications to mobile clients for user {hamlet.id}", + zerver_logger.output[0], + ) + self.assertEqual( + "WARNING:zerver.lib.push_notifications:" + "Bouncer refused to send E2EE push notification: Your plan doesn't allow sending push notifications. " + "Reason provided by the server: Push notifications access with 10+ users requires signing up for a plan. https://zulip.com/plans/", + zerver_logger.output[1], + ) + + realm.refresh_from_db() + self.assertFalse(realm.push_notifications_enabled) + self.assertIsNone(realm.push_notifications_enabled_end_timestamp) diff --git a/zerver/tests/test_handle_push_notification.py b/zerver/tests/test_handle_push_notification.py index 509359abac..20e769b910 100644 --- a/zerver/tests/test_handle_push_notification.py +++ b/zerver/tests/test_handle_push_notification.py @@ -148,7 +148,8 @@ class HandlePushNotificationTest(PushNotificationTestCase): @activate_push_notification_service() @responses.activate - def test_end_to_end_failure_due_to_no_plan(self) -> None: + @mock.patch("zerver.lib.push_notifications.send_push_notifications") + def test_end_to_end_failure_due_to_no_plan(self, unused_mock: mock.MagicMock) -> None: self.add_mock_response() self.setup_apns_tokens() @@ -480,7 +481,8 @@ class HandlePushNotificationTest(PushNotificationTestCase): ], ) - def test_send_notifications_to_bouncer(self) -> None: + @mock.patch("zerver.lib.push_notifications.send_push_notifications") + def test_send_notifications_to_bouncer(self, unused_mock: mock.MagicMock) -> None: self.setup_apns_tokens() self.setup_fcm_tokens() diff --git a/zilencer/lib/push_notifications.py b/zilencer/lib/push_notifications.py new file mode 100644 index 0000000000..38a4dcaf47 --- /dev/null +++ b/zilencer/lib/push_notifications.py @@ -0,0 +1,208 @@ +import asyncio +import logging +from collections.abc import Iterable + +from aioapns import NotificationRequest +from django.utils.timezone import now as timezone_now +from firebase_admin import exceptions as firebase_exceptions +from firebase_admin import messaging as firebase_messaging +from firebase_admin.messaging import UnregisteredError as FCMUnregisteredError + +from zerver.lib.push_notifications import ( + SendNotificationResponseData, + fcm_app, + get_apns_context, + get_info_from_apns_result, +) +from zerver.models.realms import Realm +from zilencer.models import RemotePushDevice, RemoteRealm + +logger = logging.getLogger(__name__) + + +def send_e2ee_push_notification_apple( + apns_requests: list[NotificationRequest], + apns_remote_push_devices: list[RemotePushDevice], + delete_device_ids: list[int], +) -> int: + import aioapns + + successfully_sent_count = 0 + apns_context = get_apns_context() + + if apns_context is None: + logger.debug( + "APNs: Dropping a notification because nothing configured. " + "Set ZULIP_SERVICES_URL (or APNS_CERT_FILE)." + ) + return successfully_sent_count + + async def send_all_notifications() -> Iterable[ + tuple[RemotePushDevice, aioapns.common.NotificationResult | BaseException] + ]: + results = await asyncio.gather( + *(apns_context.apns.send_notification(request) for request in apns_requests), + return_exceptions=True, + ) + return zip(apns_remote_push_devices, results, strict=False) + + results = apns_context.loop.run_until_complete(send_all_notifications()) + + for remote_push_device, result in results: + log_context = f"to (push_account_id={remote_push_device.push_account_id}, device={remote_push_device.token})" + result_info = get_info_from_apns_result( + result, + remote_push_device, + log_context, + ) + + if result_info.successfully_sent: + successfully_sent_count += 1 + elif result_info.delete_device_id is not None: + remote_push_device.expired_time = timezone_now() + remote_push_device.save(update_fields=["expired_time"]) + delete_device_ids.append(result_info.delete_device_id) + + return successfully_sent_count + + +def send_e2ee_push_notification_android( + fcm_requests: list[firebase_messaging.Message], + fcm_remote_push_devices: list[RemotePushDevice], + delete_device_ids: list[int], +) -> int: + try: + batch_response = firebase_messaging.send_each(fcm_requests, app=fcm_app) + except firebase_exceptions.FirebaseError: + logger.warning("Error while pushing to FCM", exc_info=True) + return 0 + + successfully_sent_count = 0 + for idx, response in enumerate(batch_response.responses): + # We enumerate to have idx to track which token the response + # corresponds to. send_each() preserves the order of the messages, + # so this works. + + remote_push_device = fcm_remote_push_devices[idx] + token = remote_push_device.token + push_account_id = remote_push_device.push_account_id + if response.success: + successfully_sent_count += 1 + logger.info( + "FCM: Sent message with ID: %s to (push_account_id=%s, device=%s)", + response.message_id, + push_account_id, + token, + ) + else: + error = response.exception + if isinstance(error, FCMUnregisteredError): + remote_push_device.expired_time = timezone_now() + remote_push_device.save(update_fields=["expired_time"]) + delete_device_ids.append(remote_push_device.device_id) + + logger.info("FCM: Removing %s due to %s", token, error.code) + else: + logger.warning( + "FCM: Delivery failed for (push_account_id=%s, device=%s): %s:%s", + push_account_id, + token, + error.__class__, + error, + ) + + return successfully_sent_count + + +def send_e2ee_push_notifications( + device_id_to_encrypted_data: dict[str, str], + *, + realm: Realm | None = None, + remote_realm: RemoteRealm | None = None, +) -> SendNotificationResponseData: + assert (realm is None) ^ (remote_realm is None) + + import aioapns + + device_ids = [int(device_id_str) for device_id_str in device_id_to_encrypted_data] + remote_push_devices = RemotePushDevice.objects.filter( + device_id__in=device_ids, expired_time__isnull=True, realm=realm, remote_realm=remote_realm + ) + unexpired_remote_push_device_ids = { + remote_push_device.device_id for remote_push_device in remote_push_devices + } + + # Device IDs which should be deleted on server. + # Either the device ID is invalid or the token + # associated has been marked invalid/expired by APNs/FCM. + delete_device_ids = list( + filter(lambda device_id: device_id not in unexpired_remote_push_device_ids, device_ids) + ) + + apns_requests = [] + apns_remote_push_devices: list[RemotePushDevice] = [] + apns_base_message_payload = { + "aps": { + "mutable-content": 1, + "alert": { + "title": "New notification", + }, + # TODO: Should we remove `sound` and let the clients add it. + # Then we can rename it as `apns_required_message_payload`. + "sound": "default", + }, + } + + fcm_requests = [] + fcm_remote_push_devices: list[RemotePushDevice] = [] + + # TODO: "normal" if remove event. + priority = "high" + + for remote_push_device in remote_push_devices: + message_payload = { + "encrypted_data": device_id_to_encrypted_data[str(remote_push_device.device_id)], + "push_account_id": remote_push_device.push_account_id, + } + if remote_push_device.token_kind == RemotePushDevice.TokenKind.APNS: + apns_message_payload = { + **apns_base_message_payload, + **message_payload, + } + apns_requests.append( + aioapns.NotificationRequest( + apns_topic=remote_push_device.ios_app_id, + device_token=remote_push_device.token, + message=apns_message_payload, + time_to_live=24 * 3600, + # TODO: priority + ) + ) + apns_remote_push_devices.append(remote_push_device) + else: + fcm_requests.append( + firebase_messaging.Message( + data=message_payload, + token=remote_push_device.token, + android=firebase_messaging.AndroidConfig(priority=priority), + ) + ) + fcm_remote_push_devices.append(remote_push_device) + + apple_successfully_sent_count = 0 + if len(apns_requests) > 0: + apple_successfully_sent_count = send_e2ee_push_notification_apple( + apns_requests, apns_remote_push_devices, delete_device_ids + ) + + android_successfully_sent_count = 0 + if len(fcm_requests) > 0: + android_successfully_sent_count = send_e2ee_push_notification_android( + fcm_requests, fcm_remote_push_devices, delete_device_ids + ) + + return { + "apple_successfully_sent_count": apple_successfully_sent_count, + "android_successfully_sent_count": android_successfully_sent_count, + "delete_device_ids": delete_device_ids, + } diff --git a/zilencer/urls.py b/zilencer/urls.py index fcf8f3602b..34f26bff87 100644 --- a/zilencer/urls.py +++ b/zilencer/urls.py @@ -13,6 +13,7 @@ from zilencer.views import ( remote_server_check_analytics, remote_server_notify_push, remote_server_post_analytics, + remote_server_send_e2ee_push_notification, remote_server_send_test_notification, transfer_remote_server_registration, unregister_all_remote_push_devices, @@ -31,6 +32,7 @@ push_bouncer_patterns = [ remote_server_path("remotes/push/unregister", POST=unregister_remote_push_device), remote_server_path("remotes/push/unregister/all", POST=unregister_all_remote_push_devices), remote_server_path("remotes/push/notify", POST=remote_server_notify_push), + remote_server_path("remotes/push/e2ee/notify", POST=remote_server_send_e2ee_push_notification), remote_server_path("remotes/push/test_notification", POST=remote_server_send_test_notification), # Push signup doesn't use the REST API, since there's no auth. path("remotes/server/register", register_remote_server), diff --git a/zilencer/views.py b/zilencer/views.py index 2b36bbf0ec..cd6e1e539d 100644 --- a/zilencer/views.py +++ b/zilencer/views.py @@ -57,6 +57,7 @@ from zerver.lib.push_notifications import ( PUSH_REGISTRATION_LIVENESS_TIMEOUT, HostnameAlreadyInUseBouncerError, InvalidRemotePushDeviceTokenError, + RealmPushStatusDict, UserPushIdentityCompat, send_android_push_notification, send_apple_push_notification, @@ -91,6 +92,7 @@ from zilencer.auth import ( generate_registration_transfer_verification_secret, validate_registration_transfer_verification_secret, ) +from zilencer.lib.push_notifications import send_e2ee_push_notifications from zilencer.lib.remote_counts import MissingDataError from zilencer.models import ( RemoteInstallationCount, @@ -1800,3 +1802,64 @@ def remote_server_check_analytics(request: HttpRequest, server: RemoteZulipServe "last_realmauditlog_id": get_last_id_from_server(server, RemoteRealmAuditLog), } return json_success(request, data=result) + + +class SendE2EEPushNotificationPayload(BaseModel): + realm_uuid: str + device_id_to_encrypted_data: dict[str, str] + + +@typed_endpoint +def remote_server_send_e2ee_push_notification( + request: HttpRequest, + server: RemoteZulipServer, + *, + payload: JsonBodyPayload[SendE2EEPushNotificationPayload], +) -> HttpResponse: + from corporate.lib.stripe import get_push_status_for_remote_request + + remote_realm = get_remote_realm_helper(request, server, payload.realm_uuid) + if remote_realm is None: + raise MissingRemoteRealmError + else: + remote_realm.last_request_datetime = timezone_now() + remote_realm.save(update_fields=["last_request_datetime"]) + + push_status = get_push_status_for_remote_request(server, remote_realm) + log_data = RequestNotes.get_notes(request).log_data + assert log_data is not None + log_data["extra"] = f"[can_push={push_status.can_push}/{push_status.message}]" + if not push_status.can_push: + reason = push_status.message + raise PushNotificationsDisallowedError(reason=reason) + + device_id_to_encrypted_data = payload.device_id_to_encrypted_data + + do_increment_logging_stat( + remote_realm, + COUNT_STATS["mobile_pushes_received::day"], + None, + timezone_now(), + increment=len(device_id_to_encrypted_data), + ) + + response_data = send_e2ee_push_notifications( + device_id_to_encrypted_data, remote_realm=remote_realm + ) + + do_increment_logging_stat( + remote_realm, + COUNT_STATS["mobile_pushes_forwarded::day"], + None, + timezone_now(), + increment=response_data["apple_successfully_sent_count"] + + response_data["android_successfully_sent_count"], + ) + realm_push_status_dict: RealmPushStatusDict = { + "can_push": push_status.can_push, + "expected_end_timestamp": push_status.expected_end_timestamp, + } + + return json_success( + request, data={**response_data, "realm_push_status": realm_push_status_dict} + )