Annotate zerver/lib/event_queue.py.

This commit is contained in:
Conrad Dean
2016-07-03 19:28:59 +05:30
committed by Eklavya Sharma
parent 9812e676f0
commit bbf7a9c801

View File

@@ -1,5 +1,5 @@
from __future__ import absolute_import from __future__ import absolute_import
from typing import Any, Dict, Iterable, Union from typing import cast, AbstractSet, Any, Optional, Iterable, Sequence, Mapping, MutableMapping, Callable, Union
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from django.conf import settings from django.conf import settings
@@ -19,6 +19,7 @@ import tornado
import tornado.autoreload import tornado.autoreload
import random import random
import traceback import traceback
from zerver.models import UserProfile, Client
from zerver.decorator import RespondAsynchronously from zerver.decorator import RespondAsynchronously
from zerver.lib.cache import cache_get_many, message_cache_key, \ from zerver.lib.cache import cache_get_many, message_cache_key, \
user_profile_by_id_cache_key, cache_save_user_profile, cache_with_key user_profile_by_id_cache_key, cache_save_user_profile, cache_with_key
@@ -32,6 +33,7 @@ from zerver.lib.request import JsonableError
from zerver.lib.timestamp import timestamp_to_datetime from zerver.lib.timestamp import timestamp_to_datetime
import copy import copy
import six import six
from six import text_type
# The idle timeout used to be a week, but we found that in that # The idle timeout used to be a week, but we found that in that
# situation, queues from dead browser sessions would grow quite large # situation, queues from dead browser sessions would grow quite large
@@ -54,6 +56,7 @@ class ClientDescriptor(object):
def __init__(self, user_profile_id, user_profile_email, realm_id, event_queue, def __init__(self, user_profile_id, user_profile_email, realm_id, event_queue,
event_types, client_type_name, apply_markdown=True, event_types, client_type_name, apply_markdown=True,
all_public_streams=False, lifespan_secs=0, narrow=[]): all_public_streams=False, lifespan_secs=0, narrow=[]):
# type: (int, text_type, int, EventQueue, Optional[Sequence[str]], text_type, bool, bool, int, Iterable[Sequence[text_type]]) -> None
# These objects are serialized on shutdown and restored on restart. # These objects are serialized on shutdown and restored on restart.
# If fields are added or semantics are changed, temporary code must be # If fields are added or semantics are changed, temporary code must be
# added to load_event_queues() to update the restored objects. # added to load_event_queues() to update the restored objects.
@@ -61,8 +64,8 @@ class ClientDescriptor(object):
self.user_profile_id = user_profile_id self.user_profile_id = user_profile_id
self.user_profile_email = user_profile_email self.user_profile_email = user_profile_email
self.realm_id = realm_id self.realm_id = realm_id
self.current_handler_id = None # type: int self.current_handler_id = None # type: Optional[int]
self.current_client_name = None # type: str self.current_client_name = None # type: Optional[text_type]
self.event_queue = event_queue self.event_queue = event_queue
self.queue_timeout = lifespan_secs self.queue_timeout = lifespan_secs
self.event_types = event_types self.event_types = event_types
@@ -78,6 +81,7 @@ class ClientDescriptor(object):
self.queue_timeout = max(IDLE_EVENT_QUEUE_TIMEOUT_SECS, min(self.queue_timeout, MAX_QUEUE_TIMEOUT_SECS)) self.queue_timeout = max(IDLE_EVENT_QUEUE_TIMEOUT_SECS, min(self.queue_timeout, MAX_QUEUE_TIMEOUT_SECS))
def to_dict(self): def to_dict(self):
# type: () -> Dict[str, Any]
# If you add a new key to this dict, make sure you add appropriate # If you add a new key to this dict, make sure you add appropriate
# migration code in from_dict or load_event_queues to account for # migration code in from_dict or load_event_queues to account for
# loading event queues that lack that key. # loading event queues that lack that key.
@@ -94,10 +98,12 @@ class ClientDescriptor(object):
client_type_name=self.client_type_name) client_type_name=self.client_type_name)
def __repr__(self): def __repr__(self):
# type: () -> str
return "ClientDescriptor<%s>" % (self.event_queue.id,) return "ClientDescriptor<%s>" % (self.event_queue.id,)
@classmethod @classmethod
def from_dict(cls, d): def from_dict(cls, d):
# type: (MutableMapping[str, Any]) -> ClientDescriptor
if 'user_profile_email' not in d: if 'user_profile_email' not in d:
# Temporary migration for the addition of the new user_profile_email field # Temporary migration for the addition of the new user_profile_email field
from zerver.models import get_user_profile_by_id from zerver.models import get_user_profile_by_id
@@ -113,10 +119,12 @@ class ClientDescriptor(object):
return ret return ret
def prepare_for_pickling(self): def prepare_for_pickling(self):
# type: () -> None
self.current_handler_id = None self.current_handler_id = None
self._timeout_handle = None self._timeout_handle = None
def add_event(self, event): def add_event(self, event):
# type: (Dict[str, Any]) -> None
if self.current_handler_id is not None: if self.current_handler_id is not None:
handler = get_handler_by_id(self.current_handler_id) handler = get_handler_by_id(self.current_handler_id)
async_request_restart(handler._request) async_request_restart(handler._request)
@@ -125,6 +133,7 @@ class ClientDescriptor(object):
self.finish_current_handler() self.finish_current_handler()
def finish_current_handler(self, need_timeout=False): def finish_current_handler(self, need_timeout=False):
# type: (bool) -> bool
if self.current_handler_id is not None: if self.current_handler_id is not None:
err_msg = "Got error finishing handler for queue %s" % (self.event_queue.id,) err_msg = "Got error finishing handler for queue %s" % (self.event_queue.id,)
try: try:
@@ -138,6 +147,7 @@ class ClientDescriptor(object):
return False return False
def accepts_event(self, event): def accepts_event(self, event):
# type: (Mapping[str, Any]) -> bool
if self.event_types is not None and event["type"] not in self.event_types: if self.event_types is not None and event["type"] not in self.event_types:
return False return False
if event["type"] == "message": if event["type"] == "message":
@@ -146,9 +156,11 @@ class ClientDescriptor(object):
# TODO: Refactor so we don't need this function # TODO: Refactor so we don't need this function
def accepts_messages(self): def accepts_messages(self):
# type: () -> bool
return self.event_types is None or "message" in self.event_types return self.event_types is None or "message" in self.event_types
def idle(self, now): def idle(self, now):
# type: (float) -> bool
if not hasattr(self, 'queue_timeout'): if not hasattr(self, 'queue_timeout'):
self.queue_timeout = IDLE_EVENT_QUEUE_TIMEOUT_SECS self.queue_timeout = IDLE_EVENT_QUEUE_TIMEOUT_SECS
@@ -156,11 +168,13 @@ class ClientDescriptor(object):
and now - self.last_connection_time >= self.queue_timeout) and now - self.last_connection_time >= self.queue_timeout)
def connect_handler(self, handler_id, client_name): def connect_handler(self, handler_id, client_name):
# type: (int, text_type) -> None
self.current_handler_id = handler_id self.current_handler_id = handler_id
self.current_client_name = client_name self.current_client_name = client_name
set_descriptor_by_handler_id(handler_id, self) set_descriptor_by_handler_id(handler_id, self)
self.last_connection_time = time.time() self.last_connection_time = time.time()
def timeout_callback(): def timeout_callback():
# type: () -> None
self._timeout_handle = None self._timeout_handle = None
# All clients get heartbeat events # All clients get heartbeat events
self.add_event(dict(type='heartbeat')) self.add_event(dict(type='heartbeat'))
@@ -170,6 +184,7 @@ class ClientDescriptor(object):
self._timeout_handle = ioloop.add_timeout(heartbeat_time, timeout_callback) self._timeout_handle = ioloop.add_timeout(heartbeat_time, timeout_callback)
def disconnect_handler(self, client_closed=False, need_timeout=True): def disconnect_handler(self, client_closed=False, need_timeout=True):
# type: (bool, bool) -> None
if self.current_handler_id: if self.current_handler_id:
clear_descriptor_by_handler_id(self.current_handler_id, None) clear_descriptor_by_handler_id(self.current_handler_id, None)
clear_handler_by_id(self.current_handler_id) clear_handler_by_id(self.current_handler_id)
@@ -185,6 +200,7 @@ class ClientDescriptor(object):
self._timeout_handle = None self._timeout_handle = None
def cleanup(self): def cleanup(self):
# type: () -> None
# Before we can GC the event queue, we need to disconnect the # Before we can GC the event queue, we need to disconnect the
# handler and notify the client (or connection server) so that # handler and notify the client (or connection server) so that
# they can cleanup their own state related to the GC'd event # they can cleanup their own state related to the GC'd event
@@ -198,15 +214,19 @@ class ClientDescriptor(object):
descriptors_by_handler_id = {} # type: Dict[int, ClientDescriptor] descriptors_by_handler_id = {} # type: Dict[int, ClientDescriptor]
def get_descriptor_by_handler_id(handler_id): def get_descriptor_by_handler_id(handler_id):
# type: (int) -> ClientDescriptor
return descriptors_by_handler_id.get(handler_id) return descriptors_by_handler_id.get(handler_id)
def set_descriptor_by_handler_id(handler_id, client_descriptor): def set_descriptor_by_handler_id(handler_id, client_descriptor):
# type: (int, ClientDescriptor) -> None
descriptors_by_handler_id[handler_id] = client_descriptor descriptors_by_handler_id[handler_id] = client_descriptor
def clear_descriptor_by_handler_id(handler_id, client_descriptor): def clear_descriptor_by_handler_id(handler_id, client_descriptor):
# type: (int, Optional[ClientDescriptor]) -> None
del descriptors_by_handler_id[handler_id] del descriptors_by_handler_id[handler_id]
def compute_full_event_type(event): def compute_full_event_type(event):
# type: (Mapping[str, Any]) -> str
if event["type"] == "update_message_flags": if event["type"] == "update_message_flags":
if event["all"]: if event["all"]:
# Put the "all" case in its own category # Put the "all" case in its own category
@@ -219,10 +239,11 @@ class EventQueue(object):
# type: (str) -> None # type: (str) -> None
self.queue = deque() # type: deque[Dict[str, Any]] self.queue = deque() # type: deque[Dict[str, Any]]
self.next_event_id = 0 # type: int self.next_event_id = 0 # type: int
self.id = id self.id = id # type: str
self.virtual_events = {} # type: Dict[str, Dict[str, str]] self.virtual_events = {} # type: Dict[str, Dict[str, Any]]
def to_dict(self): def to_dict(self):
# type: () -> Dict[str, Any]
# If you add a new key to this dict, make sure you add appropriate # If you add a new key to this dict, make sure you add appropriate
# migration code in from_dict or load_event_queues to account for # migration code in from_dict or load_event_queues to account for
# loading event queues that lack that key. # loading event queues that lack that key.
@@ -233,6 +254,7 @@ class EventQueue(object):
@classmethod @classmethod
def from_dict(cls, d): def from_dict(cls, d):
# type: (Dict[str, Any]) -> EventQueue
ret = cls(d['id']) ret = cls(d['id'])
ret.next_event_id = d['next_event_id'] ret.next_event_id = d['next_event_id']
ret.queue = deque(d['queue']) ret.queue = deque(d['queue'])
@@ -240,6 +262,7 @@ class EventQueue(object):
return ret return ret
def push(self, event): def push(self, event):
# type: (Dict[str, Any]) -> None
event['id'] = self.next_event_id event['id'] = self.next_event_id
self.next_event_id += 1 self.next_event_id += 1
full_event_type = compute_full_event_type(event) full_event_type = compute_full_event_type(event)
@@ -266,19 +289,23 @@ class EventQueue(object):
# current usage since virtual events should always be resolved to # current usage since virtual events should always be resolved to
# a real event before being given to users. # a real event before being given to users.
def pop(self): def pop(self):
# type: () -> Dict[str, Any]
return self.queue.popleft() return self.queue.popleft()
def empty(self): def empty(self):
# type: () -> bool
return len(self.queue) == 0 and len(self.virtual_events) == 0 return len(self.queue) == 0 and len(self.virtual_events) == 0
# See the comment on pop; that applies here as well # See the comment on pop; that applies here as well
def prune(self, through_id): def prune(self, through_id):
# type: (int) -> None
while len(self.queue) != 0 and self.queue[0]['id'] <= through_id: while len(self.queue) != 0 and self.queue[0]['id'] <= through_id:
self.pop() self.pop()
def contents(self): def contents(self):
contents = [] # type: () -> List[Dict[str, Any]]
virtual_id_map = {} contents = [] # type: List[Dict[str, Any]]
virtual_id_map = {} # type: Dict[str, Dict[str, Any]]
for event_type in self.virtual_events: for event_type in self.virtual_events:
virtual_id_map[self.virtual_events[event_type]["id"]] = self.virtual_events[event_type] virtual_id_map[self.virtual_events[event_type]["id"]] = self.virtual_events[event_type]
virtual_ids = sorted(list(virtual_id_map.keys())) virtual_ids = sorted(list(virtual_id_map.keys()))
@@ -311,28 +338,34 @@ realm_clients_all_streams = {} # type: Dict[int, List[ClientDescriptor]]
# last_for_client that is true if this is the last queue pertaining # last_for_client that is true if this is the last queue pertaining
# to this user_profile_id # to this user_profile_id
# that is about to be deleted # that is about to be deleted
gc_hooks = [] gc_hooks = [] # type: List[Callable[[int, ClientDescriptor, bool], None]]
next_queue_id = 0 next_queue_id = 0
def add_client_gc_hook(hook): def add_client_gc_hook(hook):
# type: (Callable[[int, ClientDescriptor, bool], None]) -> None
gc_hooks.append(hook) gc_hooks.append(hook)
def get_client_descriptor(queue_id): def get_client_descriptor(queue_id):
# type: (str) -> ClientDescriptor
return clients.get(queue_id) return clients.get(queue_id)
def get_client_descriptors_for_user(user_profile_id): def get_client_descriptors_for_user(user_profile_id):
# type: (int) -> List[ClientDescriptor]
return user_clients.get(user_profile_id, []) return user_clients.get(user_profile_id, [])
def get_client_descriptors_for_realm_all_streams(realm_id): def get_client_descriptors_for_realm_all_streams(realm_id):
# type: (int) -> List[ClientDescriptor]
return realm_clients_all_streams.get(realm_id, []) return realm_clients_all_streams.get(realm_id, [])
def add_to_client_dicts(client): def add_to_client_dicts(client):
# type: (ClientDescriptor) -> None
user_clients.setdefault(client.user_profile_id, []).append(client) user_clients.setdefault(client.user_profile_id, []).append(client)
if client.all_public_streams or client.narrow != []: if client.all_public_streams or client.narrow != []:
realm_clients_all_streams.setdefault(client.realm_id, []).append(client) realm_clients_all_streams.setdefault(client.realm_id, []).append(client)
def allocate_client_descriptor(new_queue_data): def allocate_client_descriptor(new_queue_data):
# type: (MutableMapping[str, Any]) -> ClientDescriptor
global next_queue_id global next_queue_id
queue_id = str(settings.SERVER_GENERATION) + ':' + str(next_queue_id) queue_id = str(settings.SERVER_GENERATION) + ':' + str(next_queue_id)
next_queue_id += 1 next_queue_id += 1
@@ -343,7 +376,9 @@ def allocate_client_descriptor(new_queue_data):
return client return client
def do_gc_event_queues(to_remove, affected_users, affected_realms): def do_gc_event_queues(to_remove, affected_users, affected_realms):
# type: (AbstractSet[str], AbstractSet[int], AbstractSet[int]) -> None
def filter_client_dict(client_dict, key): def filter_client_dict(client_dict, key):
# type: (MutableMapping[int, List[ClientDescriptor]], int) -> None
if key not in client_dict: if key not in client_dict:
return return
@@ -365,10 +400,11 @@ def do_gc_event_queues(to_remove, affected_users, affected_realms):
del clients[id] del clients[id]
def gc_event_queues(): def gc_event_queues():
# type: () -> None
start = time.time() start = time.time()
to_remove = set() to_remove = set() # type: Set[str]
affected_users = set() affected_users = set() # type: Set[int]
affected_realms = set() affected_realms = set() # type: Set[int]
for (id, client) in six.iteritems(clients): for (id, client) in six.iteritems(clients):
if client.idle(start): if client.idle(start):
to_remove.add(id) to_remove.add(id)
@@ -388,6 +424,7 @@ def gc_event_queues():
statsd.gauge('tornado.active_users', len(user_clients)) statsd.gauge('tornado.active_users', len(user_clients))
def dump_event_queues(): def dump_event_queues():
# type: () -> None
start = time.time() start = time.time()
with open(settings.JSON_PERSISTENT_QUEUE_FILENAME, "w") as stored_queues: with open(settings.JSON_PERSISTENT_QUEUE_FILENAME, "w") as stored_queues:
@@ -398,6 +435,7 @@ def dump_event_queues():
% (len(clients), time.time() - start)) % (len(clients), time.time() - start))
def load_event_queues(): def load_event_queues():
# type: () -> None
global clients global clients
start = time.time() start = time.time()
@@ -424,7 +462,8 @@ def load_event_queues():
% (len(clients), time.time() - start)) % (len(clients), time.time() - start))
def send_restart_events(immediate=False): def send_restart_events(immediate=False):
event = dict(type='restart', server_generation=settings.SERVER_GENERATION) # type: (bool) -> None
event = dict(type='restart', server_generation=settings.SERVER_GENERATION) # type: Dict[str, Any]
if immediate: if immediate:
event['immediate'] = True event['immediate'] = True
for client in six.itervalues(clients): for client in six.itervalues(clients):
@@ -432,6 +471,7 @@ def send_restart_events(immediate=False):
client.add_event(event.copy()) client.add_event(event.copy())
def setup_event_queue(): def setup_event_queue():
# type: () -> None
if not settings.TEST_SUITE: if not settings.TEST_SUITE:
load_event_queues() load_event_queues()
atexit.register(dump_event_queues) atexit.register(dump_event_queues)
@@ -453,14 +493,15 @@ def setup_event_queue():
send_restart_events(immediate=settings.DEVELOPMENT) send_restart_events(immediate=settings.DEVELOPMENT)
def fetch_events(query): def fetch_events(query):
queue_id = query["queue_id"] # type: (Mapping[str, Any]) -> Dict[str, Any]
dont_block = query["dont_block"] queue_id = query["queue_id"] # type: str
last_event_id = query["last_event_id"] dont_block = query["dont_block"] # type: bool
user_profile_id = query["user_profile_id"] last_event_id = query["last_event_id"] # type: int
new_queue_data = query.get("new_queue_data") user_profile_id = query["user_profile_id"] # type: int
user_profile_email = query["user_profile_email"] new_queue_data = query.get("new_queue_data") # type: Optional[MutableMapping[str, Any]]
client_type_name = query["client_type_name"] user_profile_email = query["user_profile_email"] # type: text_type
handler_id = query["handler_id"] client_type_name = query["client_type_name"] # type: text_type
handler_id = query["handler_id"] # type: int
try: try:
was_connected = False was_connected = False
@@ -485,7 +526,7 @@ def fetch_events(query):
if not client.event_queue.empty() or dont_block: if not client.event_queue.empty() or dont_block:
response = dict(events=client.event_queue.contents(), response = dict(events=client.event_queue.contents(),
handler_id=handler_id) handler_id=handler_id) # type: Dict[str, Any]
if orig_queue_id is None: if orig_queue_id is None:
response['queue_id'] = queue_id response['queue_id'] = queue_id
extra_log_data = "[%s/%s]" % (queue_id, len(response["events"])) extra_log_data = "[%s/%s]" % (queue_id, len(response["events"]))
@@ -513,14 +554,16 @@ def fetch_events(query):
# from a property to a function # from a property to a function
requests_json_is_function = callable(requests.Response.json) requests_json_is_function = callable(requests.Response.json)
def extract_json_response(resp): def extract_json_response(resp):
# type: (requests.Response) -> Dict[str, Any]
if requests_json_is_function: if requests_json_is_function:
return resp.json() return resp.json()
else: else:
return resp.json return resp.json # type: ignore # mypy trusts the stub, not the runtime type checking of this fn
def request_event_queue(user_profile, user_client, apply_markdown, def request_event_queue(user_profile, user_client, apply_markdown,
queue_lifespan_secs, event_types=None, all_public_streams=False, queue_lifespan_secs, event_types=None, all_public_streams=False,
narrow=[]): narrow=[]):
# type: (UserProfile, Client, bool, int, Optional[Iterable[str]], bool, Iterable[Sequence[text_type]]) -> Optional[str]
if settings.TORNADO_SERVER: if settings.TORNADO_SERVER:
req = {'dont_block' : 'true', req = {'dont_block' : 'true',
'apply_markdown': ujson.dumps(apply_markdown), 'apply_markdown': ujson.dumps(apply_markdown),
@@ -543,6 +586,7 @@ def request_event_queue(user_profile, user_client, apply_markdown,
return None return None
def get_user_events(user_profile, queue_id, last_event_id): def get_user_events(user_profile, queue_id, last_event_id):
# type: (UserProfile, str, int) -> List[Dict]
if settings.TORNADO_SERVER: if settings.TORNADO_SERVER:
resp = requests.get(settings.TORNADO_SERVER + '/api/v1/events', resp = requests.get(settings.TORNADO_SERVER + '/api/v1/events',
auth=requests.auth.HTTPBasicAuth(user_profile.email, auth=requests.auth.HTTPBasicAuth(user_profile.email,
@@ -560,19 +604,20 @@ def get_user_events(user_profile, queue_id, last_event_id):
# Send email notifications to idle users # Send email notifications to idle users
# after they are idle for 1 hour # after they are idle for 1 hour
NOTIFY_AFTER_IDLE_HOURS = 1 NOTIFY_AFTER_IDLE_HOURS = 1
def build_offline_notification(user_profile_id, message_id): def build_offline_notification(user_profile_id, message_id):
# type: (int, int) -> Dict[str, Any]
return {"user_profile_id": user_profile_id, return {"user_profile_id": user_profile_id,
"message_id": message_id, "message_id": message_id,
"timestamp": time.time()} "timestamp": time.time()}
def missedmessage_hook(user_profile_id, queue, last_for_client): def missedmessage_hook(user_profile_id, queue, last_for_client):
# type: (int, ClientDescriptor, bool) -> None
# Only process missedmessage hook when the last queue for a # Only process missedmessage hook when the last queue for a
# client has been garbage collected # client has been garbage collected
if not last_for_client: if not last_for_client:
return return
message_ids_to_notify = [] message_ids_to_notify = [] # type: List[Dict[str, Any]]
for event in queue.event_queue.contents(): for event in queue.event_queue.contents():
if not event['type'] == 'message' or not event['flags']: if not event['type'] == 'message' or not event['flags']:
continue continue
@@ -595,6 +640,7 @@ def missedmessage_hook(user_profile_id, queue, last_for_client):
queue_json_publish("missedmessage_emails", notice, lambda notice: None) queue_json_publish("missedmessage_emails", notice, lambda notice: None)
def receiver_is_idle(user_profile_id, realm_presences): def receiver_is_idle(user_profile_id, realm_presences):
# type: (int, Optional[Dict[int, Dict[text_type, Dict[str, Any]]]]) -> bool
# If a user has no message-receiving event queues, they've got no open zulip # If a user has no message-receiving event queues, they've got no open zulip
# session so we notify them # session so we notify them
all_client_descriptors = get_client_descriptors_for_user(user_profile_id) all_client_descriptors = get_client_descriptors_for_user(user_profile_id)
@@ -629,20 +675,21 @@ def receiver_is_idle(user_profile_id, realm_presences):
return off_zulip or idle return off_zulip or idle
def process_message_event(event_template, users): def process_message_event(event_template, users):
realm_presences = {int(k): v for k, v in event_template['presences'].items()} # type: (Mapping[str, Any], Iterable[Mapping[str, Any]]) -> None
sender_queue_id = event_template.get('sender_queue_id', None) realm_presences = {int(k): v for k, v in event_template['presences'].items()} # type: Dict[int, Dict[text_type, Dict[str, Any]]]
message_dict_markdown = event_template['message_dict_markdown'] sender_queue_id = event_template.get('sender_queue_id', None) # type: Optional[str]
message_dict_no_markdown = event_template['message_dict_no_markdown'] message_dict_markdown = event_template['message_dict_markdown'] # type: Dict[str, Any]
sender_id = message_dict_markdown['sender_id'] message_dict_no_markdown = event_template['message_dict_no_markdown'] # type: Dict[str, Any]
message_id = message_dict_markdown['id'] sender_id = message_dict_markdown['sender_id'] # type: int
message_type = message_dict_markdown['type'] message_id = message_dict_markdown['id'] # type: int
sending_client = message_dict_markdown['client'] message_type = message_dict_markdown['type'] # type: str
sending_client = message_dict_markdown['client'] # type: text_type
# To remove duplicate clients: Maps queue ID to {'client': Client, 'flags': flags} # To remove duplicate clients: Maps queue ID to {'client': Client, 'flags': flags}
send_to_clients = dict() send_to_clients = {} # type: Dict[str, Dict[str, Any]]
# Extra user-specific data to include # Extra user-specific data to include
extra_user_data = {} extra_user_data = {} # type: Dict[int, Any]
if 'stream_name' in event_template and not event_template.get("invite_only"): if 'stream_name' in event_template and not event_template.get("invite_only"):
for client in get_client_descriptors_for_realm_all_streams(event_template['realm_id']): for client in get_client_descriptors_for_realm_all_streams(event_template['realm_id']):
@@ -651,8 +698,8 @@ def process_message_event(event_template, users):
send_to_clients[client.event_queue.id]['is_sender'] = True send_to_clients[client.event_queue.id]['is_sender'] = True
for user_data in users: for user_data in users:
user_profile_id = user_data['id'] user_profile_id = user_data['id'] # type: int
flags = user_data.get('flags', []) flags = user_data.get('flags', []) # type: Iterable[str]
for client in get_client_descriptors_for_user(user_profile_id): for client in get_client_descriptors_for_user(user_profile_id):
send_to_clients[client.event_queue.id] = {'client': client, 'flags': flags} send_to_clients[client.event_queue.id] = {'client': client, 'flags': flags}
@@ -668,7 +715,7 @@ def process_message_event(event_template, users):
if (received_pm or mentioned) and (idle or always_push_notify): if (received_pm or mentioned) and (idle or always_push_notify):
notice = build_offline_notification(user_profile_id, message_id) notice = build_offline_notification(user_profile_id, message_id)
queue_json_publish("missedmessage_mobile_notifications", notice, lambda notice: None) queue_json_publish("missedmessage_mobile_notifications", notice, lambda notice: None)
notified = dict(push_notified=True) notified = dict(push_notified=True) # type: Dict[str, bool]
# Don't send missed message emails if always_push_notify is True # Don't send missed message emails if always_push_notify is True
if idle: if idle:
# We require RabbitMQ to do this, as we can't call the email handler # We require RabbitMQ to do this, as we can't call the email handler
@@ -681,8 +728,8 @@ def process_message_event(event_template, users):
for client_data in six.itervalues(send_to_clients): for client_data in six.itervalues(send_to_clients):
client = client_data['client'] client = client_data['client']
flags = client_data['flags'] flags = client_data['flags']
is_sender = client_data.get('is_sender', False) is_sender = client_data.get('is_sender', False) # type: bool
extra_data = extra_user_data.get(client.user_profile_id, None) extra_data = extra_user_data.get(client.user_profile_id, None) # type: Optional[Mapping[str, bool]]
if not client.accepts_messages(): if not client.accepts_messages():
# The actual check is the accepts_event() check below; # The actual check is the accepts_event() check below;
@@ -700,7 +747,7 @@ def process_message_event(event_template, users):
message_dict = message_dict.copy() message_dict = message_dict.copy()
message_dict["invite_only_stream"] = True message_dict["invite_only_stream"] = True
user_event = dict(type='message', message=message_dict, flags=flags) user_event = dict(type='message', message=message_dict, flags=flags) # type: Dict[str, Any]
if extra_data is not None: if extra_data is not None:
user_event.update(extra_data) user_event.update(extra_data)
@@ -719,15 +766,17 @@ def process_message_event(event_template, users):
client.add_event(user_event) client.add_event(user_event)
def process_event(event, users): def process_event(event, users):
# type: (Mapping[str, Any], Iterable[int]) -> None
for user_profile_id in users: for user_profile_id in users:
for client in get_client_descriptors_for_user(user_profile_id): for client in get_client_descriptors_for_user(user_profile_id):
if client.accepts_event(event): if client.accepts_event(event):
client.add_event(event.copy()) client.add_event(dict(event))
def process_userdata_event(event_template, users): def process_userdata_event(event_template, users):
# type: (Mapping[str, Any], Iterable[Mapping[str, Any]]) -> None
for user_data in users: for user_data in users:
user_profile_id = user_data['id'] user_profile_id = user_data['id']
user_event = event_template.copy() # shallow, but deep enough for our needs user_event = dict(event_template) # shallow copy, but deep enough for our needs
for key in user_data.keys(): for key in user_data.keys():
if key != "id": if key != "id":
user_event[key] = user_data[key] user_event[key] = user_data[key]
@@ -737,14 +786,15 @@ def process_userdata_event(event_template, users):
client.add_event(user_event) client.add_event(user_event)
def process_notification(notice): def process_notification(notice):
event = notice['event'] # type: (Mapping[str, Any]) -> None
users = notice['users'] event = notice['event'] # type: Mapping[str, Any]
users = notice['users'] # type: Union[Iterable[int], Iterable[Mapping[str, Any]]]
if event['type'] in ["update_message"]: if event['type'] in ["update_message"]:
process_userdata_event(event, users) process_userdata_event(event, cast(Iterable[Mapping[str, Any]], users))
elif event['type'] == "message": elif event['type'] == "message":
process_message_event(event, users) process_message_event(event, cast(Iterable[Mapping[str, Any]], users))
else: else:
process_event(event, users) process_event(event, cast(Iterable[int], users))
# Runs in the Django process to send a notification to Tornado. # Runs in the Django process to send a notification to Tornado.
# #
@@ -752,6 +802,7 @@ def process_notification(notice):
# different types and for compatibility with non-HTTP transports. # different types and for compatibility with non-HTTP transports.
def send_notification_http(data): def send_notification_http(data):
# type: (Mapping[str, Any]) -> None
if settings.TORNADO_SERVER and not settings.RUNNING_INSIDE_TORNADO: if settings.TORNADO_SERVER and not settings.RUNNING_INSIDE_TORNADO:
requests.post(settings.TORNADO_SERVER + '/notify_tornado', data=dict( requests.post(settings.TORNADO_SERVER + '/notify_tornado', data=dict(
data = ujson.dumps(data), data = ujson.dumps(data),
@@ -760,11 +811,11 @@ def send_notification_http(data):
process_notification(data) process_notification(data)
def send_notification(data): def send_notification(data):
# type: (Dict[str, Any]) -> None # type: (Mapping[str, Any]) -> None
queue_json_publish("notify_tornado", data, send_notification_http) queue_json_publish("notify_tornado", data, send_notification_http)
def send_event(event, users): def send_event(event, users):
# type: (Dict[str, Any], Union[Iterable[int], Iterable[Dict[str, Any]]]) -> None # type: (Mapping[str, Any], Union[Iterable[int], Iterable[Mapping[str, Any]]]) -> None
"""`users` is a list of user IDs, or in the case of `message` type """`users` is a list of user IDs, or in the case of `message` type
events, a list of dicts describing the users and metadata about events, a list of dicts describing the users and metadata about
the user/message pair.""" the user/message pair."""