Annotate zerver/lib/queue.py.

This commit is contained in:
Conrad Dean
2016-07-03 19:28:27 +05:30
committed by Eklavya Sharma
parent 192663edcf
commit 9812e676f0

View File

@@ -2,6 +2,8 @@ from __future__ import absolute_import
from django.conf import settings from django.conf import settings
import pika import pika
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic
import logging import logging
import ujson import ujson
import random import random
@@ -11,7 +13,9 @@ import atexit
from collections import defaultdict from collections import defaultdict
from zerver.lib.utils import statsd from zerver.lib.utils import statsd
from typing import Any, Callable from typing import Any, Callable, Dict, Mapping, Optional, Set, Union
Consumer = Callable[[BlockingChannel, Basic.Deliver, pika.BasicProperties, str], None]
# This simple queuing library doesn't expose much of the power of # This simple queuing library doesn't expose much of the power of
# rabbitmq/pika's queuing system; its purpose is to just provide an # rabbitmq/pika's queuing system; its purpose is to just provide an
@@ -19,27 +23,31 @@ from typing import Any, Callable
# out from bots without having to import pika code all over our codebase. # out from bots without having to import pika code all over our codebase.
class SimpleQueueClient(object): class SimpleQueueClient(object):
def __init__(self): def __init__(self):
# type: () -> None
self.log = logging.getLogger('zulip.queue') self.log = logging.getLogger('zulip.queue')
self.queues = set() # type: Set[str] self.queues = set() # type: Set[str]
self.channel = None # type: Any self.channel = None # type: Optional[BlockingChannel]
self.consumers = defaultdict(set) # type: Dict[str, Set[Any]] self.consumers = defaultdict(set) # type: Dict[str, Set[Consumer]]
# Disable RabbitMQ heartbeats since BlockingConnection can't process them # Disable RabbitMQ heartbeats since BlockingConnection can't process them
self.rabbitmq_heartbeat = 0 self.rabbitmq_heartbeat = 0
self._connect() self._connect()
def _connect(self): def _connect(self):
# type: () -> None
start = time.time() start = time.time()
self.connection = pika.BlockingConnection(self._get_parameters()) self.connection = pika.BlockingConnection(self._get_parameters())
self.channel = self.connection.channel() self.channel = self.connection.channel()
self.log.info('SimpleQueueClient connected (connecting took %.3fs)' % (time.time() - start,)) self.log.info('SimpleQueueClient connected (connecting took %.3fs)' % (time.time() - start,))
def _reconnect(self): def _reconnect(self):
# type: () -> None
self.connection = None self.connection = None
self.channel = None self.channel = None
self.queues = set() self.queues = set()
self._connect() self._connect()
def _get_parameters(self): def _get_parameters(self):
# type: () -> pika.ConnectionParameters
# We explicitly disable the RabbitMQ heartbeat feature, since # We explicitly disable the RabbitMQ heartbeat feature, since
# it doesn't make sense with BlockingConnection # it doesn't make sense with BlockingConnection
credentials = pika.PlainCredentials(settings.RABBITMQ_USERNAME, credentials = pika.PlainCredentials(settings.RABBITMQ_USERNAME,
@@ -49,27 +57,33 @@ class SimpleQueueClient(object):
credentials=credentials) credentials=credentials)
def _generate_ctag(self, queue_name): def _generate_ctag(self, queue_name):
# type: (str) -> str
return "%s_%s" % (queue_name, str(random.getrandbits(16))) return "%s_%s" % (queue_name, str(random.getrandbits(16)))
def _reconnect_consumer_callback(self, queue, consumer): def _reconnect_consumer_callback(self, queue, consumer):
# type: (str, Consumer) -> None
self.log.info("Queue reconnecting saved consumer %s to queue %s" % (consumer, queue)) self.log.info("Queue reconnecting saved consumer %s to queue %s" % (consumer, queue))
self.ensure_queue(queue, lambda: self.channel.basic_consume(consumer, self.ensure_queue(queue, lambda: self.channel.basic_consume(consumer,
queue=queue, queue=queue,
consumer_tag=self._generate_ctag(queue))) consumer_tag=self._generate_ctag(queue)))
def _reconnect_consumer_callbacks(self): def _reconnect_consumer_callbacks(self):
# type: () -> None
for queue, consumers in self.consumers.items(): for queue, consumers in self.consumers.items():
for consumer in consumers: for consumer in consumers:
self._reconnect_consumer_callback(queue, consumer) self._reconnect_consumer_callback(queue, consumer)
def close(self): def close(self):
# type: () -> None
if self.connection: if self.connection:
self.connection.close() self.connection.close()
def ready(self): def ready(self):
# type: () -> bool
return self.channel is not None return self.channel is not None
def ensure_queue(self, queue_name, callback): def ensure_queue(self, queue_name, callback):
# type: (str, Callable[[], None]) -> None
'''Ensure that a given queue has been declared, and then call '''Ensure that a given queue has been declared, and then call
the callback with no arguments.''' the callback with no arguments.'''
if not self.connection.is_open: if not self.connection.is_open:
@@ -81,7 +95,9 @@ class SimpleQueueClient(object):
callback() callback()
def publish(self, queue_name, body): def publish(self, queue_name, body):
# type: (str, str) -> None
def do_publish(): def do_publish():
# type: () -> None
self.channel.basic_publish( self.channel.basic_publish(
exchange='', exchange='',
routing_key=queue_name, routing_key=queue_name,
@@ -93,6 +109,8 @@ class SimpleQueueClient(object):
self.ensure_queue(queue_name, do_publish) self.ensure_queue(queue_name, do_publish)
def json_publish(self, queue_name, body): def json_publish(self, queue_name, body):
# type: (str, Union[Mapping[str, Any], str]) -> None
# Union because of zerver.middleware.write_log_line uses a str
try: try:
self.publish(queue_name, ujson.dumps(body)) self.publish(queue_name, ujson.dumps(body))
except (AttributeError, pika.exceptions.AMQPConnectionError): except (AttributeError, pika.exceptions.AMQPConnectionError):
@@ -102,7 +120,9 @@ class SimpleQueueClient(object):
self.publish(queue_name, ujson.dumps(body)) self.publish(queue_name, ujson.dumps(body))
def register_consumer(self, queue_name, consumer): def register_consumer(self, queue_name, consumer):
# type: (str, Consumer) -> None
def wrapped_consumer(ch, method, properties, body): def wrapped_consumer(ch, method, properties, body):
# type: (BlockingChannel, Basic.Deliver, pika.BasicProperties, str) -> None
try: try:
consumer(ch, method, properties, body) consumer(ch, method, properties, body)
ch.basic_ack(delivery_tag=method.delivery_tag) ch.basic_ack(delivery_tag=method.delivery_tag)
@@ -116,14 +136,18 @@ class SimpleQueueClient(object):
consumer_tag=self._generate_ctag(queue_name))) consumer_tag=self._generate_ctag(queue_name)))
def register_json_consumer(self, queue_name, callback): def register_json_consumer(self, queue_name, callback):
# type: (str, Callable[[Mapping[str, Any]], None]) -> None
def wrapped_callback(ch, method, properties, body): def wrapped_callback(ch, method, properties, body):
return callback(ujson.loads(body)) # type: (BlockingChannel, Basic.Deliver, pika.BasicProperties, str) -> None
return self.register_consumer(queue_name, wrapped_callback) callback(ujson.loads(body))
self.register_consumer(queue_name, wrapped_callback)
def drain_queue(self, queue_name, json=False): def drain_queue(self, queue_name, json=False):
# type: (str, bool) -> List[Dict[str, Any]]
"Returns all messages in the desired queue" "Returns all messages in the desired queue"
messages = [] messages = []
def opened(): def opened():
# type: () -> None
while True: while True:
(meta, _, message) = self.channel.basic_get(queue_name) (meta, _, message) = self.channel.basic_get(queue_name)
@@ -139,9 +163,11 @@ class SimpleQueueClient(object):
return messages return messages
def start_consuming(self): def start_consuming(self):
# type: () -> None
self.channel.start_consuming() self.channel.start_consuming()
def stop_consuming(self): def stop_consuming(self):
# type: () -> None
self.channel.stop_consuming() self.channel.stop_consuming()
# Patch pika.adapters.TornadoConnection so that a socket error doesn't # Patch pika.adapters.TornadoConnection so that a socket error doesn't
@@ -149,6 +175,7 @@ class SimpleQueueClient(object):
# queue. Instead, just re-connect as usual # queue. Instead, just re-connect as usual
class ExceptionFreeTornadoConnection(pika.adapters.TornadoConnection): class ExceptionFreeTornadoConnection(pika.adapters.TornadoConnection):
def _adapter_disconnect(self): def _adapter_disconnect(self):
# type: () -> None
try: try:
super(ExceptionFreeTornadoConnection, self)._adapter_disconnect() super(ExceptionFreeTornadoConnection, self)._adapter_disconnect()
except (pika.exceptions.ProbableAuthenticationError, except (pika.exceptions.ProbableAuthenticationError,
@@ -162,14 +189,16 @@ class TornadoQueueClient(SimpleQueueClient):
# Based on: # Based on:
# https://pika.readthedocs.io/en/0.9.8/examples/asynchronous_consumer_example.html # https://pika.readthedocs.io/en/0.9.8/examples/asynchronous_consumer_example.html
def __init__(self): def __init__(self):
# type: () -> None
super(TornadoQueueClient, self).__init__() super(TornadoQueueClient, self).__init__()
# Enable rabbitmq heartbeat since TornadoConection can process them # Enable rabbitmq heartbeat since TornadoConection can process them
self.rabbitmq_heartbeat = None self.rabbitmq_heartbeat = None
self._on_open_cbs = [] # type: List[Callable[[], None]] self._on_open_cbs = [] # type: List[Callable[[], None]]
def _connect(self, on_open_cb = None): def _connect(self, on_open_cb = None):
# type: (Optional[Callable[[], None]]) -> None
self.log.info("Beginning TornadoQueueClient connection") self.log.info("Beginning TornadoQueueClient connection")
if on_open_cb: if on_open_cb is not None:
self._on_open_cbs.append(on_open_cb) self._on_open_cbs.append(on_open_cb)
self.connection = ExceptionFreeTornadoConnection( self.connection = ExceptionFreeTornadoConnection(
self._get_parameters(), self._get_parameters(),
@@ -178,16 +207,19 @@ class TornadoQueueClient(SimpleQueueClient):
self.connection.add_on_close_callback(self._on_connection_closed) self.connection.add_on_close_callback(self._on_connection_closed)
def _reconnect(self): def _reconnect(self):
# type: () -> None
self.connection = None self.connection = None
self.channel = None self.channel = None
self.queues = set() self.queues = set()
self._connect() self._connect()
def _on_open(self, connection): def _on_open(self, connection):
# type: (pika.Connection) -> None
self.connection.channel( self.connection.channel(
on_open_callback = self._on_channel_open) on_open_callback = self._on_channel_open)
def _on_channel_open(self, channel): def _on_channel_open(self, channel):
# type: (BlockingChannel) -> None
self.channel = channel self.channel = channel
for callback in self._on_open_cbs: for callback in self._on_open_cbs:
callback() callback()
@@ -195,12 +227,14 @@ class TornadoQueueClient(SimpleQueueClient):
self.log.info('TornadoQueueClient connected') self.log.info('TornadoQueueClient connected')
def _on_connection_closed(self, connection, reply_code, reply_text): def _on_connection_closed(self, connection, reply_code, reply_text):
# type: (pika.Connection, int, str) -> None
self.log.warning("TornadoQueueClient lost connection to RabbitMQ, reconnecting...") self.log.warning("TornadoQueueClient lost connection to RabbitMQ, reconnecting...")
from tornado import ioloop from tornado import ioloop
# Try to reconnect in two seconds # Try to reconnect in two seconds
retry_seconds = 2 retry_seconds = 2
def on_timeout(): def on_timeout():
# type: () -> None
try: try:
self._reconnect() self._reconnect()
except pika.exceptions.AMQPConnectionError: except pika.exceptions.AMQPConnectionError:
@@ -210,7 +244,9 @@ class TornadoQueueClient(SimpleQueueClient):
ioloop.IOLoop.instance().add_timeout(time.time() + retry_seconds, on_timeout) ioloop.IOLoop.instance().add_timeout(time.time() + retry_seconds, on_timeout)
def ensure_queue(self, queue_name, callback): def ensure_queue(self, queue_name, callback):
# type: (str, Callable[[], None]) -> None
def finish(frame): def finish(frame):
# type: (Any) -> None
self.queues.add(queue_name) self.queues.add(queue_name)
callback() callback()
@@ -226,7 +262,9 @@ class TornadoQueueClient(SimpleQueueClient):
callback() callback()
def register_consumer(self, queue_name, consumer): def register_consumer(self, queue_name, consumer):
# type: (str, Consumer) -> None
def wrapped_consumer(ch, method, properties, body): def wrapped_consumer(ch, method, properties, body):
# type: (BlockingChannel, Basic.Deliver, pika.BasicProperties, str) -> None
consumer(ch, method, properties, body) consumer(ch, method, properties, body)
ch.basic_ack(delivery_tag=method.delivery_tag) ch.basic_ack(delivery_tag=method.delivery_tag)
@@ -239,8 +277,9 @@ class TornadoQueueClient(SimpleQueueClient):
lambda: self.channel.basic_consume(wrapped_consumer, queue=queue_name, lambda: self.channel.basic_consume(wrapped_consumer, queue=queue_name,
consumer_tag=self._generate_ctag(queue_name))) consumer_tag=self._generate_ctag(queue_name)))
queue_client = None # type: SimpleQueueClient queue_client = None # type: Optional[SimpleQueueClient]
def get_queue_client(): def get_queue_client():
# type: () -> SimpleQueueClient
global queue_client global queue_client
if queue_client is None: if queue_client is None:
if settings.RUNNING_INSIDE_TORNADO and settings.USING_RABBITMQ: if settings.RUNNING_INSIDE_TORNADO and settings.USING_RABBITMQ:
@@ -251,6 +290,7 @@ def get_queue_client():
return queue_client return queue_client
def setup_tornado_rabbitmq(): def setup_tornado_rabbitmq():
# type: () -> None
# When tornado is shut down, disconnect cleanly from rabbitmq # When tornado is shut down, disconnect cleanly from rabbitmq
if settings.USING_RABBITMQ: if settings.USING_RABBITMQ:
atexit.register(lambda: queue_client.close()) atexit.register(lambda: queue_client.close())
@@ -263,9 +303,10 @@ def setup_tornado_rabbitmq():
queue_lock = threading.RLock() queue_lock = threading.RLock()
def queue_json_publish(queue_name, event, processor): def queue_json_publish(queue_name, event, processor):
# type: (str, Union[Mapping[str, Any], str], Callable[[Any], None]) -> None
# most events are dicts, but zerver.middleware.write_log_line uses a str
with queue_lock: with queue_lock:
if settings.USING_RABBITMQ: if settings.USING_RABBITMQ:
get_queue_client().json_publish(queue_name, event) get_queue_client().json_publish(queue_name, event)
else: else:
processor(event) processor(event)