# zulip/zerver/lib/parallel.py
import logging
from collections.abc import Callable, Iterable
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import current_process
from typing import Any, TypeVar
import bmemcached
from django.conf import settings
from django.core.cache import cache
from django.db import connection
from zerver.lib.partial import partial
from zerver.lib.queue import get_queue_client
# Type variable for the records fed to run_parallel; bound per call site
# so that func and records agree on the element type.
ParallelRecordType = TypeVar("ParallelRecordType")
def _disconnect() -> None:
    """Close our database, cache, and RabbitMQ connections, so our
    forked children do not share them.  Django will transparently
    re-open them as needed.
    """
    connection.close()
    _cache = cache._cache  # type: ignore[attr-defined] # not in stubs
    if isinstance(_cache, bmemcached.Client):  # nocoverage
        # In tests, this is an OrderedDict, which has no connections to
        # drop; only a real memcached client needs disconnecting.
        _cache.disconnect_all()
    if settings.USING_RABBITMQ:  # nocoverage
        rabbitmq_client = get_queue_client()
        if rabbitmq_client.connection and rabbitmq_client.connection.is_open:
            rabbitmq_client.close()
def func_with_catch(func: Callable[[ParallelRecordType], None], item: ParallelRecordType) -> None:
    """Invoke ``func(item)``, logging any exception instead of propagating it.

    Used by ``run_parallel(..., catch=True)`` so that one bad record does
    not abort processing of the rest.
    """
    try:
        func(item)
    except Exception:
        logging.exception("Error processing item: %s", item)
def run_parallel(
    func: Callable[[ParallelRecordType], None],
    records: Iterable[ParallelRecordType],
    processes: int,
    *,
    initializer: Callable[..., None] | None = None,
    initargs: tuple[Any, ...] = tuple(),
    catch: bool = False,
    report_every: int = 1000,
    report: Callable[[int], None] | None = None,
) -> None:  # nocoverage
    """Apply ``func`` to every record, fanning out across worker processes.

    With ``processes == 1`` everything runs inline in this process; otherwise
    a ``ProcessPoolExecutor`` with that many workers is used.  ``initializer``
    (with ``initargs``) runs once per worker (or once inline).  With
    ``catch=True``, per-record exceptions are logged and swallowed; otherwise
    the first worker exception is re-raised here.  If ``report`` is given, it
    is called with the running completion count every ``report_every``
    records.
    """
    assert processes > 0
    if settings.TEST_SUITE and current_process().daemon:
        assert processes == 1, "Only one process possible under parallel tests"

    worker = partial(func_with_catch, func) if catch else func

    if processes == 1:
        # Inline path: no forking, but honor the initializer and progress
        # reporting so both paths behave alike.
        if initializer is not None:
            initializer(*initargs)
        for done, record in enumerate(records, 1):
            worker(record)
            if report is not None and done % report_every == 0:
                report(done)
        return

    # Drop database/cache/RabbitMQ connections so the forked children do
    # not share them with us.
    _disconnect()
    with ProcessPoolExecutor(
        max_workers=processes, initializer=initializer, initargs=initargs
    ) as executor:
        submissions = (executor.submit(worker, record) for record in records)
        for done, finished in enumerate(as_completed(submissions), 1):
            # Propagate any exception raised in the worker (unless catch=True
            # already logged and swallowed it there).
            finished.result()
            if report is not None and done % report_every == 0:
                report(done)