mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-04 05:53:43 +00:00 
			
		
		
		
	push_notification: Send a list of push requests.
Earlier, we were passing a map `device_id_to_encrypted_data` and http headers as separate fields to bouncer. The downside of that approach is it restricts the bouncer to process only one type of notice i.e. either notification for a new message or removal of sent notification, because it used to receive a fixed priority and push_type for all the entries in the map. Also, using map restricts the bouncer to receive only one request per device_id. Server can't send multiple notices to a device in a single call to bouncer. Currently, the server isn't modelled in a way to make a single call to the bouncer with: * Both send-notification & remove-notification request data. * Multiple send-notification request data to the same device. This commit replaces the old protocol of sending data with a list of objects where each object has the required data for bouncer to send it to FCM or APNs. This makes things a lot flexible and opens possibility for server to batch requests in a different way if we'd like to.
This commit is contained in:
		
				
					committed by
					
						
						Tim Abbott
					
				
			
			
				
	
			
			
			
						parent
						
							3d3f4d5e62
						
					
				
				
					commit
					6ab6df96c8
				
			@@ -5,7 +5,7 @@ import copy
 | 
			
		||||
import logging
 | 
			
		||||
import re
 | 
			
		||||
from collections.abc import Iterable, Mapping, Sequence
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from dataclasses import asdict, dataclass, field
 | 
			
		||||
from email.headerregistry import Address
 | 
			
		||||
from functools import cache
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Final, Literal, Optional, TypeAlias, Union
 | 
			
		||||
@@ -1366,6 +1366,51 @@ class SendNotificationResponseData(TypedDict):
 | 
			
		||||
    realm_push_status: NotRequired[RealmPushStatusDict]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
FCMPriority: TypeAlias = Literal["high", "normal"]
 | 
			
		||||
APNsPriority: TypeAlias = Literal[10, 5, 1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class PushRequestBasePayload:
 | 
			
		||||
    push_account_id: int
 | 
			
		||||
    encrypted_data: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class FCMPushRequest:
 | 
			
		||||
    device_id: int
 | 
			
		||||
    fcm_priority: FCMPriority
 | 
			
		||||
    payload: PushRequestBasePayload
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class APNsHTTPHeaders:
 | 
			
		||||
    apns_priority: APNsPriority
 | 
			
		||||
    apns_push_type: PushType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class APNsPayload(PushRequestBasePayload):
 | 
			
		||||
    aps: dict[str, int | dict[str, str]] = field(
 | 
			
		||||
        default_factory=lambda: {"mutable-content": 1, "alert": {"title": "New notification"}}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class APNsPushRequest:
 | 
			
		||||
    device_id: int
 | 
			
		||||
    http_headers: APNsHTTPHeaders
 | 
			
		||||
    payload: APNsPayload
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_encrypted_data(payload_data_to_encrypt: dict[str, Any], public_key_str: str) -> str:
 | 
			
		||||
    public_key = PublicKey(public_key_str.encode("utf-8"), Base64Encoder)
 | 
			
		||||
    sealed_box = SealedBox(public_key)
 | 
			
		||||
    encrypted_data_bytes = sealed_box.encrypt(orjson.dumps(payload_data_to_encrypt), Base64Encoder)
 | 
			
		||||
    encrypted_data = encrypted_data_bytes.decode("utf-8")
 | 
			
		||||
    return encrypted_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def send_push_notifications(
 | 
			
		||||
    user_profile: UserProfile,
 | 
			
		||||
    apns_payload_data_to_encrypt: dict[str, Any],
 | 
			
		||||
@@ -1382,51 +1427,64 @@ def send_push_notifications(
 | 
			
		||||
        )
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    # Prepare payload with encrypted data to send.
 | 
			
		||||
    device_id_to_encrypted_data: dict[str, str] = {}
 | 
			
		||||
    for push_device in push_devices:
 | 
			
		||||
        public_key_str: str = push_device.push_public_key
 | 
			
		||||
        public_key = PublicKey(public_key_str.encode("utf-8"), Base64Encoder)
 | 
			
		||||
        sealed_box = SealedBox(public_key)
 | 
			
		||||
 | 
			
		||||
        if push_device.token_kind == PushDevice.TokenKind.APNS:
 | 
			
		||||
            encrypted_data_bytes = sealed_box.encrypt(
 | 
			
		||||
                orjson.dumps(apns_payload_data_to_encrypt), Base64Encoder
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            encrypted_data_bytes = sealed_box.encrypt(
 | 
			
		||||
                orjson.dumps(fcm_payload_data_to_encrypt), Base64Encoder
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        encrypted_data = encrypted_data_bytes.decode("utf-8")
 | 
			
		||||
        assert push_device.bouncer_device_id is not None  # for mypy
 | 
			
		||||
        device_id_to_encrypted_data[str(push_device.bouncer_device_id)] = encrypted_data
 | 
			
		||||
 | 
			
		||||
    # Note: The "Final" qualifier serves as a shorthand
 | 
			
		||||
    # for declaring that a variable is effectively Literal.
 | 
			
		||||
    fcm_priority: Final = "normal" if is_removal else "high"
 | 
			
		||||
    apns_priority: Final = 5 if is_removal else 10
 | 
			
		||||
    apns_push_type = PushType.BACKGROUND if is_removal else PushType.ALERT
 | 
			
		||||
 | 
			
		||||
    # Prepare payload to send.
 | 
			
		||||
    push_requests: list[FCMPushRequest | APNsPushRequest] = []
 | 
			
		||||
    for push_device in push_devices:
 | 
			
		||||
        assert push_device.bouncer_device_id is not None  # for mypy
 | 
			
		||||
        if push_device.token_kind == PushDevice.TokenKind.APNS:
 | 
			
		||||
            apns_http_headers = APNsHTTPHeaders(
 | 
			
		||||
                apns_priority=apns_priority,
 | 
			
		||||
                apns_push_type=apns_push_type,
 | 
			
		||||
            )
 | 
			
		||||
            encrypted_data = get_encrypted_data(
 | 
			
		||||
                apns_payload_data_to_encrypt,
 | 
			
		||||
                push_device.push_public_key,
 | 
			
		||||
            )
 | 
			
		||||
            apns_payload = APNsPayload(
 | 
			
		||||
                push_account_id=push_device.push_account_id,
 | 
			
		||||
                encrypted_data=encrypted_data,
 | 
			
		||||
            )
 | 
			
		||||
            apns_push_request = APNsPushRequest(
 | 
			
		||||
                device_id=push_device.bouncer_device_id,
 | 
			
		||||
                http_headers=apns_http_headers,
 | 
			
		||||
                payload=apns_payload,
 | 
			
		||||
            )
 | 
			
		||||
            push_requests.append(apns_push_request)
 | 
			
		||||
        else:
 | 
			
		||||
            encrypted_data = get_encrypted_data(
 | 
			
		||||
                fcm_payload_data_to_encrypt,
 | 
			
		||||
                push_device.push_public_key,
 | 
			
		||||
            )
 | 
			
		||||
            fcm_payload = PushRequestBasePayload(
 | 
			
		||||
                push_account_id=push_device.push_account_id,
 | 
			
		||||
                encrypted_data=encrypted_data,
 | 
			
		||||
            )
 | 
			
		||||
            fcm_push_request = FCMPushRequest(
 | 
			
		||||
                device_id=push_device.bouncer_device_id,
 | 
			
		||||
                fcm_priority=fcm_priority,
 | 
			
		||||
                payload=fcm_payload,
 | 
			
		||||
            )
 | 
			
		||||
            push_requests.append(fcm_push_request)
 | 
			
		||||
 | 
			
		||||
    # Send push notification
 | 
			
		||||
    try:
 | 
			
		||||
        if settings.ZILENCER_ENABLED:
 | 
			
		||||
            from zilencer.lib.push_notifications import send_e2ee_push_notifications
 | 
			
		||||
 | 
			
		||||
            response_data: SendNotificationResponseData = send_e2ee_push_notifications(
 | 
			
		||||
                device_id_to_encrypted_data,
 | 
			
		||||
                fcm_priority=fcm_priority,
 | 
			
		||||
                apns_priority=apns_priority,
 | 
			
		||||
                apns_push_type=apns_push_type,
 | 
			
		||||
                push_requests,
 | 
			
		||||
                realm=user_profile.realm,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            post_data = {
 | 
			
		||||
                "realm_uuid": str(user_profile.realm.uuid),
 | 
			
		||||
                "device_id_to_encrypted_data": device_id_to_encrypted_data,
 | 
			
		||||
                "fcm_priority": fcm_priority,
 | 
			
		||||
                "apns_priority": apns_priority,
 | 
			
		||||
                "apns_push_type": apns_push_type,
 | 
			
		||||
                "push_requests": [asdict(push_request) for push_request in push_requests],
 | 
			
		||||
            }
 | 
			
		||||
            result = send_json_to_push_bouncer("POST", "push/e2ee/notify", post_data)
 | 
			
		||||
            assert isinstance(result["android_successfully_sent_count"], int)  # for mypy
 | 
			
		||||
 
 | 
			
		||||
@@ -258,6 +258,45 @@ class SendPushNotificationTest(E2EEPushNotificationTestCase):
 | 
			
		||||
                zerver_logger.output[1],
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_early_return_if_expired_time_set(self, unused_mock: mock.MagicMock) -> None:
 | 
			
		||||
        aaron = self.example_user("aaron")
 | 
			
		||||
        hamlet = self.example_user("hamlet")
 | 
			
		||||
 | 
			
		||||
        registered_device_apple, registered_device_android = (
 | 
			
		||||
            self.register_push_devices_for_notification()
 | 
			
		||||
        )
 | 
			
		||||
        registered_device_apple.expired_time = datetime(2099, 4, 24, tzinfo=timezone.utc)
 | 
			
		||||
        registered_device_android.expired_time = datetime(2099, 4, 24, tzinfo=timezone.utc)
 | 
			
		||||
        registered_device_apple.save(update_fields=["expired_time"])
 | 
			
		||||
        registered_device_android.save(update_fields=["expired_time"])
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(PushDevice.objects.count(), 2)
 | 
			
		||||
 | 
			
		||||
        message_id = self.send_personal_message(
 | 
			
		||||
            from_user=aaron, to_user=hamlet, skip_capture_on_commit_callbacks=True
 | 
			
		||||
        )
 | 
			
		||||
        missed_message = {
 | 
			
		||||
            "message_id": message_id,
 | 
			
		||||
            "trigger": NotificationTriggers.DIRECT_MESSAGE,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        # Since 'expired_time' is set for concerned 'RemotePushDevice' rows,
 | 
			
		||||
        # the bouncer will not attempt to send notification and instead returns
 | 
			
		||||
        # a list of device IDs which server should erase on their own end.
 | 
			
		||||
        with (
 | 
			
		||||
            mock.patch(
 | 
			
		||||
                "zilencer.lib.push_notifications.send_e2ee_push_notification_apple"
 | 
			
		||||
            ) as send_apple,
 | 
			
		||||
            mock.patch(
 | 
			
		||||
                "zilencer.lib.push_notifications.send_e2ee_push_notification_android"
 | 
			
		||||
            ) as send_android,
 | 
			
		||||
        ):
 | 
			
		||||
            handle_push_notification(hamlet.id, missed_message)
 | 
			
		||||
 | 
			
		||||
            send_apple.assert_not_called()
 | 
			
		||||
            send_android.assert_not_called()
 | 
			
		||||
            self.assertEqual(PushDevice.objects.count(), 0)
 | 
			
		||||
 | 
			
		||||
    @responses.activate
 | 
			
		||||
    @override_settings(ZILENCER_ENABLED=False)
 | 
			
		||||
    def test_success_self_hosted(self, unused_mock: mock.MagicMock) -> None:
 | 
			
		||||
 
 | 
			
		||||
@@ -1,15 +1,17 @@
 | 
			
		||||
import asyncio
 | 
			
		||||
import logging
 | 
			
		||||
from collections.abc import Iterable
 | 
			
		||||
from typing import Literal, TypeAlias
 | 
			
		||||
from dataclasses import asdict
 | 
			
		||||
 | 
			
		||||
from aioapns import NotificationRequest, PushType
 | 
			
		||||
from aioapns import NotificationRequest
 | 
			
		||||
from django.utils.timezone import now as timezone_now
 | 
			
		||||
from firebase_admin import exceptions as firebase_exceptions
 | 
			
		||||
from firebase_admin import messaging as firebase_messaging
 | 
			
		||||
from firebase_admin.messaging import UnregisteredError as FCMUnregisteredError
 | 
			
		||||
 | 
			
		||||
from zerver.lib.push_notifications import (
 | 
			
		||||
    APNsPushRequest,
 | 
			
		||||
    FCMPushRequest,
 | 
			
		||||
    SendNotificationResponseData,
 | 
			
		||||
    fcm_app,
 | 
			
		||||
    get_apns_context,
 | 
			
		||||
@@ -20,9 +22,6 @@ from zilencer.models import RemotePushDevice, RemoteRealm
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
FCMPriority: TypeAlias = Literal["high", "normal"]
 | 
			
		||||
APNsPriority: TypeAlias = Literal[10, 5, 1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def send_e2ee_push_notification_apple(
 | 
			
		||||
    apns_requests: list[NotificationRequest],
 | 
			
		||||
@@ -119,11 +118,8 @@ def send_e2ee_push_notification_android(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def send_e2ee_push_notifications(
 | 
			
		||||
    device_id_to_encrypted_data: dict[str, str],
 | 
			
		||||
    push_requests: list[APNsPushRequest | FCMPushRequest],
 | 
			
		||||
    *,
 | 
			
		||||
    fcm_priority: FCMPriority,
 | 
			
		||||
    apns_priority: APNsPriority,
 | 
			
		||||
    apns_push_type: PushType,
 | 
			
		||||
    realm: Realm | None = None,
 | 
			
		||||
    remote_realm: RemoteRealm | None = None,
 | 
			
		||||
) -> SendNotificationResponseData:
 | 
			
		||||
@@ -131,13 +127,15 @@ def send_e2ee_push_notifications(
 | 
			
		||||
 | 
			
		||||
    import aioapns
 | 
			
		||||
 | 
			
		||||
    device_ids = [int(device_id_str) for device_id_str in device_id_to_encrypted_data]
 | 
			
		||||
    device_ids = {push_request.device_id for push_request in push_requests}
 | 
			
		||||
    remote_push_devices = RemotePushDevice.objects.filter(
 | 
			
		||||
        device_id__in=device_ids, expired_time__isnull=True, realm=realm, remote_realm=remote_realm
 | 
			
		||||
    )
 | 
			
		||||
    unexpired_remote_push_device_ids = {
 | 
			
		||||
        remote_push_device.device_id for remote_push_device in remote_push_devices
 | 
			
		||||
    device_id_to_remote_push_device = {
 | 
			
		||||
        remote_push_device.device_id: remote_push_device
 | 
			
		||||
        for remote_push_device in remote_push_devices
 | 
			
		||||
    }
 | 
			
		||||
    unexpired_remote_push_device_ids = set(device_id_to_remote_push_device.keys())
 | 
			
		||||
 | 
			
		||||
    # Device IDs which should be deleted on server.
 | 
			
		||||
    # Either the device ID is invalid or the token
 | 
			
		||||
@@ -148,44 +146,35 @@ def send_e2ee_push_notifications(
 | 
			
		||||
 | 
			
		||||
    apns_requests = []
 | 
			
		||||
    apns_remote_push_devices: list[RemotePushDevice] = []
 | 
			
		||||
    apns_base_message_payload = {
 | 
			
		||||
        "aps": {
 | 
			
		||||
            "mutable-content": 1,
 | 
			
		||||
            "alert": {
 | 
			
		||||
                "title": "New notification",
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fcm_requests = []
 | 
			
		||||
    fcm_remote_push_devices: list[RemotePushDevice] = []
 | 
			
		||||
 | 
			
		||||
    for remote_push_device in remote_push_devices:
 | 
			
		||||
        message_payload = {
 | 
			
		||||
            "encrypted_data": device_id_to_encrypted_data[str(remote_push_device.device_id)],
 | 
			
		||||
            "push_account_id": remote_push_device.push_account_id,
 | 
			
		||||
        }
 | 
			
		||||
    for push_request in push_requests:
 | 
			
		||||
        device_id = push_request.device_id
 | 
			
		||||
        if device_id not in unexpired_remote_push_device_ids:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        remote_push_device = device_id_to_remote_push_device[device_id]
 | 
			
		||||
        if remote_push_device.token_kind == RemotePushDevice.TokenKind.APNS:
 | 
			
		||||
            apns_message_payload = {
 | 
			
		||||
                **apns_base_message_payload,
 | 
			
		||||
                **message_payload,
 | 
			
		||||
            }
 | 
			
		||||
            assert isinstance(push_request, APNsPushRequest)
 | 
			
		||||
            apns_requests.append(
 | 
			
		||||
                aioapns.NotificationRequest(
 | 
			
		||||
                    apns_topic=remote_push_device.ios_app_id,
 | 
			
		||||
                    device_token=remote_push_device.token,
 | 
			
		||||
                    message=apns_message_payload,
 | 
			
		||||
                    priority=apns_priority,
 | 
			
		||||
                    push_type=apns_push_type,
 | 
			
		||||
                    message=asdict(push_request.payload),
 | 
			
		||||
                    priority=push_request.http_headers.apns_priority,
 | 
			
		||||
                    push_type=push_request.http_headers.apns_push_type,
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            apns_remote_push_devices.append(remote_push_device)
 | 
			
		||||
        else:
 | 
			
		||||
            assert isinstance(push_request, FCMPushRequest)
 | 
			
		||||
            fcm_requests.append(
 | 
			
		||||
                firebase_messaging.Message(
 | 
			
		||||
                    data=message_payload,
 | 
			
		||||
                    data=asdict(push_request.payload),
 | 
			
		||||
                    token=remote_push_device.token,
 | 
			
		||||
                    android=firebase_messaging.AndroidConfig(priority=fcm_priority),
 | 
			
		||||
                    android=firebase_messaging.AndroidConfig(priority=push_request.fcm_priority),
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            fcm_remote_push_devices.append(remote_push_device)
 | 
			
		||||
 
 | 
			
		||||
@@ -8,7 +8,6 @@ from uuid import UUID
 | 
			
		||||
 | 
			
		||||
import orjson
 | 
			
		||||
import requests.exceptions
 | 
			
		||||
from aioapns import PushType
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from django.core.exceptions import ValidationError
 | 
			
		||||
from django.core.validators import URLValidator, validate_email
 | 
			
		||||
@@ -56,6 +55,8 @@ from zerver.lib.exceptions import (
 | 
			
		||||
from zerver.lib.outgoing_http import OutgoingSession
 | 
			
		||||
from zerver.lib.push_notifications import (
 | 
			
		||||
    PUSH_REGISTRATION_LIVENESS_TIMEOUT,
 | 
			
		||||
    APNsPushRequest,
 | 
			
		||||
    FCMPushRequest,
 | 
			
		||||
    HostnameAlreadyInUseBouncerError,
 | 
			
		||||
    InvalidRemotePushDeviceTokenError,
 | 
			
		||||
    RealmPushStatusDict,
 | 
			
		||||
@@ -93,7 +94,7 @@ from zilencer.auth import (
 | 
			
		||||
    generate_registration_transfer_verification_secret,
 | 
			
		||||
    validate_registration_transfer_verification_secret,
 | 
			
		||||
)
 | 
			
		||||
from zilencer.lib.push_notifications import APNsPriority, FCMPriority, send_e2ee_push_notifications
 | 
			
		||||
from zilencer.lib.push_notifications import send_e2ee_push_notifications
 | 
			
		||||
from zilencer.lib.remote_counts import MissingDataError
 | 
			
		||||
from zilencer.models import (
 | 
			
		||||
    RemoteInstallationCount,
 | 
			
		||||
@@ -1807,10 +1808,7 @@ def remote_server_check_analytics(request: HttpRequest, server: RemoteZulipServe
 | 
			
		||||
 | 
			
		||||
class SendE2EEPushNotificationPayload(BaseModel):
 | 
			
		||||
    realm_uuid: str
 | 
			
		||||
    device_id_to_encrypted_data: dict[str, str]
 | 
			
		||||
    fcm_priority: FCMPriority
 | 
			
		||||
    apns_priority: APNsPriority
 | 
			
		||||
    apns_push_type: PushType
 | 
			
		||||
    push_requests: list[APNsPushRequest | FCMPushRequest]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@typed_endpoint
 | 
			
		||||
@@ -1837,21 +1835,18 @@ def remote_server_send_e2ee_push_notification(
 | 
			
		||||
        reason = push_status.message
 | 
			
		||||
        raise PushNotificationsDisallowedError(reason=reason)
 | 
			
		||||
 | 
			
		||||
    device_id_to_encrypted_data = payload.device_id_to_encrypted_data
 | 
			
		||||
    push_requests = payload.push_requests
 | 
			
		||||
 | 
			
		||||
    do_increment_logging_stat(
 | 
			
		||||
        remote_realm,
 | 
			
		||||
        COUNT_STATS["mobile_pushes_received::day"],
 | 
			
		||||
        None,
 | 
			
		||||
        timezone_now(),
 | 
			
		||||
        increment=len(device_id_to_encrypted_data),
 | 
			
		||||
        increment=len(push_requests),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    response_data = send_e2ee_push_notifications(
 | 
			
		||||
        device_id_to_encrypted_data,
 | 
			
		||||
        fcm_priority=payload.fcm_priority,
 | 
			
		||||
        apns_priority=payload.apns_priority,
 | 
			
		||||
        apns_push_type=payload.apns_push_type,
 | 
			
		||||
        push_requests,
 | 
			
		||||
        remote_realm=remote_realm,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user