mirror of
https://github.com/zulip/zulip.git
synced 2025-11-10 17:07:07 +00:00
queue: Fix strict_optional errors.
Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
committed by
Tim Abbott
parent
6c9c12ee2d
commit
489d73f63a
@@ -80,9 +80,14 @@ class SimpleQueueClient:
|
||||
|
||||
def _reconnect_consumer_callback(self, queue: str, consumer: Consumer) -> None:
|
||||
self.log.info(f"Queue reconnecting saved consumer {consumer} to queue {queue}")
|
||||
self.ensure_queue(queue, lambda: self.channel.basic_consume(queue,
|
||||
consumer,
|
||||
consumer_tag=self._generate_ctag(queue)))
|
||||
self.ensure_queue(
|
||||
queue,
|
||||
lambda channel: channel.basic_consume(
|
||||
queue,
|
||||
consumer,
|
||||
consumer_tag=self._generate_ctag(queue),
|
||||
),
|
||||
)
|
||||
|
||||
def _reconnect_consumer_callbacks(self) -> None:
|
||||
for queue, consumers in self.consumers.items():
|
||||
@@ -96,20 +101,21 @@ class SimpleQueueClient:
|
||||
def ready(self) -> bool:
|
||||
return self.channel is not None
|
||||
|
||||
def ensure_queue(self, queue_name: str, callback: Callable[[], None]) -> None:
|
||||
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None:
|
||||
'''Ensure that a given queue has been declared, and then call
|
||||
the callback with no arguments.'''
|
||||
if self.connection is None or not self.connection.is_open:
|
||||
self._connect()
|
||||
|
||||
assert self.channel is not None
|
||||
if queue_name not in self.queues:
|
||||
self.channel.queue_declare(queue=queue_name, durable=True)
|
||||
self.queues.add(queue_name)
|
||||
callback()
|
||||
callback(self.channel)
|
||||
|
||||
def publish(self, queue_name: str, body: bytes) -> None:
|
||||
def do_publish() -> None:
|
||||
self.channel.basic_publish(
|
||||
def do_publish(channel: BlockingChannel) -> None:
|
||||
channel.basic_publish(
|
||||
exchange='',
|
||||
routing_key=queue_name,
|
||||
properties=pika.BasicProperties(delivery_mode=2),
|
||||
@@ -143,9 +149,14 @@ class SimpleQueueClient:
|
||||
raise e
|
||||
|
||||
self.consumers[queue_name].add(wrapped_consumer)
|
||||
self.ensure_queue(queue_name,
|
||||
lambda: self.channel.basic_consume(queue_name, wrapped_consumer,
|
||||
consumer_tag=self._generate_ctag(queue_name)))
|
||||
self.ensure_queue(
|
||||
queue_name,
|
||||
lambda channel: channel.basic_consume(
|
||||
queue_name,
|
||||
wrapped_consumer,
|
||||
consumer_tag=self._generate_ctag(queue_name),
|
||||
),
|
||||
)
|
||||
|
||||
def register_json_consumer(self, queue_name: str,
|
||||
callback: Callable[[Dict[str, Any]], None]) -> None:
|
||||
@@ -160,14 +171,14 @@ class SimpleQueueClient:
|
||||
"Returns all messages in the desired queue"
|
||||
messages = []
|
||||
|
||||
def opened() -> None:
|
||||
def opened(channel: BlockingChannel) -> None:
|
||||
while True:
|
||||
(meta, _, message) = self.channel.basic_get(queue_name)
|
||||
(meta, _, message) = channel.basic_get(queue_name)
|
||||
|
||||
if message is None:
|
||||
break
|
||||
|
||||
self.channel.basic_ack(meta.delivery_tag)
|
||||
channel.basic_ack(meta.delivery_tag)
|
||||
messages.append(message)
|
||||
|
||||
self.ensure_queue(queue_name, opened)
|
||||
@@ -177,12 +188,15 @@ class SimpleQueueClient:
|
||||
return list(map(ujson.loads, self.drain_queue(queue_name)))
|
||||
|
||||
def queue_size(self) -> int:
|
||||
assert self.channel is not None
|
||||
return len(self.channel._pending_events)
|
||||
|
||||
def start_consuming(self) -> None:
|
||||
assert self.channel is not None
|
||||
self.channel.start_consuming()
|
||||
|
||||
def stop_consuming(self) -> None:
|
||||
assert self.channel is not None
|
||||
self.channel.stop_consuming()
|
||||
|
||||
# Patch pika.adapters.tornado_connection.TornadoConnection so that a socket error doesn't
|
||||
@@ -206,7 +220,7 @@ class TornadoQueueClient(SimpleQueueClient):
|
||||
super().__init__(
|
||||
# TornadoConnection can process heartbeats, so enable them.
|
||||
rabbitmq_heartbeat=None)
|
||||
self._on_open_cbs: List[Callable[[], None]] = []
|
||||
self._on_open_cbs: List[Callable[[BlockingChannel], None]] = []
|
||||
self._connection_failure_count = 0
|
||||
|
||||
def _connect(self) -> None:
|
||||
@@ -275,25 +289,28 @@ class TornadoQueueClient(SimpleQueueClient):
|
||||
def _on_channel_open(self, channel: BlockingChannel) -> None:
|
||||
self.channel = channel
|
||||
for callback in self._on_open_cbs:
|
||||
callback()
|
||||
callback(channel)
|
||||
self._reconnect_consumer_callbacks()
|
||||
self.log.info('TornadoQueueClient connected')
|
||||
|
||||
def ensure_queue(self, queue_name: str, callback: Callable[[], None]) -> None:
|
||||
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None:
|
||||
def finish(frame: Any) -> None:
|
||||
assert self.channel is not None
|
||||
self.queues.add(queue_name)
|
||||
callback()
|
||||
callback(self.channel)
|
||||
|
||||
if queue_name not in self.queues:
|
||||
# If we're not connected yet, send this message
|
||||
# once we have created the channel
|
||||
if not self.ready():
|
||||
self._on_open_cbs.append(lambda: self.ensure_queue(queue_name, callback))
|
||||
self._on_open_cbs.append(lambda channel: self.ensure_queue(queue_name, callback))
|
||||
return
|
||||
|
||||
assert self.channel is not None
|
||||
self.channel.queue_declare(queue=queue_name, durable=True, callback=finish)
|
||||
else:
|
||||
callback()
|
||||
assert self.channel is not None
|
||||
callback(self.channel)
|
||||
|
||||
def register_consumer(self, queue_name: str, consumer: Consumer) -> None:
|
||||
def wrapped_consumer(ch: BlockingChannel,
|
||||
@@ -308,9 +325,14 @@ class TornadoQueueClient(SimpleQueueClient):
|
||||
return
|
||||
|
||||
self.consumers[queue_name].add(wrapped_consumer)
|
||||
self.ensure_queue(queue_name,
|
||||
lambda: self.channel.basic_consume(queue_name, wrapped_consumer,
|
||||
consumer_tag=self._generate_ctag(queue_name)))
|
||||
self.ensure_queue(
|
||||
queue_name,
|
||||
lambda channel: channel.basic_consume(
|
||||
queue_name,
|
||||
wrapped_consumer,
|
||||
consumer_tag=self._generate_ctag(queue_name),
|
||||
),
|
||||
)
|
||||
|
||||
queue_client: Optional[SimpleQueueClient] = None
|
||||
def get_queue_client() -> SimpleQueueClient:
|
||||
@@ -320,6 +342,8 @@ def get_queue_client() -> SimpleQueueClient:
|
||||
queue_client = TornadoQueueClient()
|
||||
elif settings.USING_RABBITMQ:
|
||||
queue_client = SimpleQueueClient()
|
||||
else:
|
||||
raise RuntimeError("Cannot get a queue client without USING_RABBITMQ")
|
||||
|
||||
return queue_client
|
||||
|
||||
@@ -330,9 +354,11 @@ def get_queue_client() -> SimpleQueueClient:
|
||||
# randomly close.
|
||||
queue_lock = threading.RLock()
|
||||
|
||||
def queue_json_publish(queue_name: str,
|
||||
event: Dict[str, Any],
|
||||
processor: Callable[[Any], None]=None) -> None:
|
||||
def queue_json_publish(
|
||||
queue_name: str,
|
||||
event: Dict[str, Any],
|
||||
processor: Optional[Callable[[Any], None]] = None,
|
||||
) -> None:
|
||||
# most events are dicts, but zerver.middleware.write_log_line uses a str
|
||||
with queue_lock:
|
||||
if settings.USING_RABBITMQ:
|
||||
|
||||
Reference in New Issue
Block a user