queue: Fix strict_optional errors.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg
2020-07-04 18:18:11 -07:00
committed by Tim Abbott
parent 6c9c12ee2d
commit 489d73f63a
3 changed files with 52 additions and 34 deletions

View File

@@ -80,9 +80,14 @@ class SimpleQueueClient:
def _reconnect_consumer_callback(self, queue: str, consumer: Consumer) -> None:
self.log.info(f"Queue reconnecting saved consumer {consumer} to queue {queue}")
self.ensure_queue(queue, lambda: self.channel.basic_consume(queue,
consumer,
consumer_tag=self._generate_ctag(queue)))
self.ensure_queue(
queue,
lambda channel: channel.basic_consume(
queue,
consumer,
consumer_tag=self._generate_ctag(queue),
),
)
def _reconnect_consumer_callbacks(self) -> None:
for queue, consumers in self.consumers.items():
@@ -96,20 +101,21 @@ class SimpleQueueClient:
def ready(self) -> bool:
return self.channel is not None
def ensure_queue(self, queue_name: str, callback: Callable[[], None]) -> None:
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None:
'''Ensure that a given queue has been declared, and then call
the callback with no arguments.'''
if self.connection is None or not self.connection.is_open:
self._connect()
assert self.channel is not None
if queue_name not in self.queues:
self.channel.queue_declare(queue=queue_name, durable=True)
self.queues.add(queue_name)
callback()
callback(self.channel)
def publish(self, queue_name: str, body: bytes) -> None:
def do_publish() -> None:
self.channel.basic_publish(
def do_publish(channel: BlockingChannel) -> None:
channel.basic_publish(
exchange='',
routing_key=queue_name,
properties=pika.BasicProperties(delivery_mode=2),
@@ -143,9 +149,14 @@ class SimpleQueueClient:
raise e
self.consumers[queue_name].add(wrapped_consumer)
self.ensure_queue(queue_name,
lambda: self.channel.basic_consume(queue_name, wrapped_consumer,
consumer_tag=self._generate_ctag(queue_name)))
self.ensure_queue(
queue_name,
lambda channel: channel.basic_consume(
queue_name,
wrapped_consumer,
consumer_tag=self._generate_ctag(queue_name),
),
)
def register_json_consumer(self, queue_name: str,
callback: Callable[[Dict[str, Any]], None]) -> None:
@@ -160,14 +171,14 @@ class SimpleQueueClient:
"Returns all messages in the desired queue"
messages = []
def opened() -> None:
def opened(channel: BlockingChannel) -> None:
while True:
(meta, _, message) = self.channel.basic_get(queue_name)
(meta, _, message) = channel.basic_get(queue_name)
if message is None:
break
self.channel.basic_ack(meta.delivery_tag)
channel.basic_ack(meta.delivery_tag)
messages.append(message)
self.ensure_queue(queue_name, opened)
@@ -177,12 +188,15 @@ class SimpleQueueClient:
return list(map(ujson.loads, self.drain_queue(queue_name)))
def queue_size(self) -> int:
assert self.channel is not None
return len(self.channel._pending_events)
def start_consuming(self) -> None:
assert self.channel is not None
self.channel.start_consuming()
def stop_consuming(self) -> None:
assert self.channel is not None
self.channel.stop_consuming()
# Patch pika.adapters.tornado_connection.TornadoConnection so that a socket error doesn't
@@ -206,7 +220,7 @@ class TornadoQueueClient(SimpleQueueClient):
super().__init__(
# TornadoConnection can process heartbeats, so enable them.
rabbitmq_heartbeat=None)
self._on_open_cbs: List[Callable[[], None]] = []
self._on_open_cbs: List[Callable[[BlockingChannel], None]] = []
self._connection_failure_count = 0
def _connect(self) -> None:
@@ -275,25 +289,28 @@ class TornadoQueueClient(SimpleQueueClient):
def _on_channel_open(self, channel: BlockingChannel) -> None:
self.channel = channel
for callback in self._on_open_cbs:
callback()
callback(channel)
self._reconnect_consumer_callbacks()
self.log.info('TornadoQueueClient connected')
def ensure_queue(self, queue_name: str, callback: Callable[[], None]) -> None:
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None:
def finish(frame: Any) -> None:
assert self.channel is not None
self.queues.add(queue_name)
callback()
callback(self.channel)
if queue_name not in self.queues:
# If we're not connected yet, send this message
# once we have created the channel
if not self.ready():
self._on_open_cbs.append(lambda: self.ensure_queue(queue_name, callback))
self._on_open_cbs.append(lambda channel: self.ensure_queue(queue_name, callback))
return
assert self.channel is not None
self.channel.queue_declare(queue=queue_name, durable=True, callback=finish)
else:
callback()
assert self.channel is not None
callback(self.channel)
def register_consumer(self, queue_name: str, consumer: Consumer) -> None:
def wrapped_consumer(ch: BlockingChannel,
@@ -308,9 +325,14 @@ class TornadoQueueClient(SimpleQueueClient):
return
self.consumers[queue_name].add(wrapped_consumer)
self.ensure_queue(queue_name,
lambda: self.channel.basic_consume(queue_name, wrapped_consumer,
consumer_tag=self._generate_ctag(queue_name)))
self.ensure_queue(
queue_name,
lambda channel: channel.basic_consume(
queue_name,
wrapped_consumer,
consumer_tag=self._generate_ctag(queue_name),
),
)
queue_client: Optional[SimpleQueueClient] = None
def get_queue_client() -> SimpleQueueClient:
@@ -320,6 +342,8 @@ def get_queue_client() -> SimpleQueueClient:
queue_client = TornadoQueueClient()
elif settings.USING_RABBITMQ:
queue_client = SimpleQueueClient()
else:
raise RuntimeError("Cannot get a queue client without USING_RABBITMQ")
return queue_client
@@ -330,9 +354,11 @@ def get_queue_client() -> SimpleQueueClient:
# randomly close.
queue_lock = threading.RLock()
def queue_json_publish(queue_name: str,
event: Dict[str, Any],
processor: Callable[[Any], None]=None) -> None:
def queue_json_publish(
queue_name: str,
event: Dict[str, Any],
processor: Optional[Callable[[Any], None]] = None,
) -> None:
# most events are dicts, but zerver.middleware.write_log_line uses a str
with queue_lock:
if settings.USING_RABBITMQ: