parallel: Factor out multiple callsites that use ProcessPoolExecutor.

Alex Vandiver
2024-01-18 15:26:41 +00:00
committed by Tim Abbott
parent 454905f988
commit 131580f23c
6 changed files with 377 additions and 93 deletions
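
The refactor replaces each open-coded ProcessPoolExecutor loop with the single zerver/lib/parallel.run_parallel() helper added below. A rough sketch of what a converted callsite looks like follows; download_item and the URL list are hypothetical stand-ins for the real import jobs, while the keyword arguments are the ones defined in zerver/lib/parallel.py in this commit:

    import logging

    from zerver.lib.parallel import run_parallel


    def download_item(url: str) -> None:  # hypothetical per-record job
        ...


    run_parallel(
        download_item,
        ["https://example.com/a.png", "https://example.com/b.png"],
        processes=6,  # processes=1 runs the records inline, without forking
        catch=True,  # log a failing record and continue, instead of raising
        report_every=1000,
        report=lambda count: logging.info("Finished %s items", count),
    )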

View File

@@ -5,7 +5,6 @@ import shutil
 from collections import defaultdict
 from collections.abc import Callable, Iterable, Iterator, Mapping
 from collections.abc import Set as AbstractSet
-from concurrent.futures import ProcessPoolExecutor, as_completed
 from typing import Any, Protocol, TypeAlias, TypeVar
 
 import orjson
@@ -19,6 +18,7 @@ from zerver.data_import.sequencer import NEXT_ID
 from zerver.lib.avatar_hash import user_avatar_base_path_from_ids
 from zerver.lib.message import normalize_body_for_import
 from zerver.lib.mime_types import INLINE_MIME_TYPES, guess_extension
+from zerver.lib.parallel import run_parallel
 from zerver.lib.partial import partial
 from zerver.lib.stream_color import STREAM_ASSIGNMENT_COLORS as STREAM_COLORS
 from zerver.lib.thumbnail import THUMBNAIL_ACCEPT_IMAGE_TYPES, BadImageError
@@ -634,38 +634,18 @@ def process_avatars(
         avatar_original_list.append(avatar_original)
 
     # Run downloads in parallel
-    run_parallel_wrapper(
-        partial(get_avatar, avatar_dir, size_url_suffix), avatar_upload_list, threads=threads
+    run_parallel(
+        partial(get_avatar, avatar_dir, size_url_suffix),
+        avatar_upload_list,
+        processes=threads,
+        catch=True,
+        report=lambda count: logging.info("Finished %s items", count),
     )
 
     logging.info("######### GETTING AVATARS FINISHED #########\n")
     return avatar_list + avatar_original_list
 
 
-ListJobData = TypeVar("ListJobData")
-
-
-def wrapping_function(f: Callable[[ListJobData], None], item: ListJobData) -> None:
-    try:
-        f(item)
-    except Exception:
-        logging.exception("Error processing item: %s", item, stack_info=True)
-
-
-def run_parallel_wrapper(
-    f: Callable[[ListJobData], None], full_items: list[ListJobData], threads: int = 6
-) -> None:
-    logging.info("Distributing %s items across %s threads", len(full_items), threads)
-
-    with ProcessPoolExecutor(max_workers=threads) as executor:
-        for count, future in enumerate(
-            as_completed(executor.submit(wrapping_function, f, item) for item in full_items), 1
-        ):
-            future.result()
-            if count % 1000 == 0:
-                logging.info("Finished %s items", count)
-
-
 def get_uploads(upload_dir: str, upload: list[str]) -> None:
     upload_url = upload[0]
     upload_path = upload[1]
@@ -697,7 +677,13 @@ def process_uploads(
         upload["path"] = upload_s3_path
 
     # Run downloads in parallel
-    run_parallel_wrapper(partial(get_uploads, upload_dir), upload_url_list, threads=threads)
+    run_parallel(
+        partial(get_uploads, upload_dir),
+        upload_url_list,
+        processes=threads,
+        catch=True,
+        report=lambda count: logging.info("Finished %s items", count),
+    )
 
     logging.info("######### GETTING ATTACHMENTS FINISHED #########\n")
     return upload_list

View File

@@ -2,17 +2,14 @@ import collections
 import logging
 import os
 import shutil
-from concurrent.futures import ProcessPoolExecutor, as_completed
 from datetime import datetime, timedelta, timezone
 from difflib import unified_diff
 from typing import Any, TypeAlias
 
-import bmemcached
 import orjson
 import pyvips
 from bs4 import BeautifulSoup
 from django.conf import settings
-from django.core.cache import cache
 from django.core.management.base import CommandError
 from django.core.validators import validate_email
 from django.db import connection, transaction
@@ -44,6 +41,7 @@ from zerver.lib.onboarding import (
     send_initial_direct_messages_to_user,
     send_initial_realm_messages,
 )
+from zerver.lib.parallel import run_parallel
 from zerver.lib.partial import partial
 from zerver.lib.push_notifications import sends_notifications_directly
 from zerver.lib.remote_server import maybe_enqueue_audit_log_upload
@@ -1143,19 +1141,12 @@ def import_uploads(
         # TODO: This implementation is hacky, both in that it
         # does get_user_profile_by_id for each user, and in that it
        # might be better to require the export to just have these.
-        if processes == 1:
-            for record in records:
-                process_func(record)
-        else:
-            connection.close()
-            _cache = cache._cache  # type: ignore[attr-defined]  # not in stubs
-            assert isinstance(_cache, bmemcached.Client)
-            _cache.disconnect_all()
-            with ProcessPoolExecutor(max_workers=processes) as executor:
-                for future in as_completed(
-                    executor.submit(process_func, record) for record in records
-                ):
-                    future.result()
+        run_parallel(
+            process_func,
+            records,
+            processes if s3_uploads else 1,
+            report=lambda count: logging.info("Processed %s/%s avatars", count, len(records)),
+        )
 
 
 def disable_restricted_authentication_methods(data: ImportedTableData) -> None:

zerver/lib/parallel.py Normal file (+77 lines)
View File

@@ -0,0 +1,77 @@
import logging
from collections.abc import Callable, Iterable
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import current_process
from typing import Any, TypeVar

import bmemcached
from django.conf import settings
from django.core.cache import cache
from django.db import connection

from zerver.lib.partial import partial
from zerver.lib.queue import get_queue_client

ParallelRecordType = TypeVar("ParallelRecordType")


def _disconnect() -> None:
    # Close our database, cache, and RabbitMQ connections, so our
    # forked children do not share them.  Django will transparently
    # re-open them as needed.
    connection.close()

    _cache = cache._cache  # type: ignore[attr-defined]  # not in stubs
    if isinstance(_cache, bmemcached.Client):  # nocoverage
        # In tests, this is an OrderedDict
        _cache.disconnect_all()

    if settings.USING_RABBITMQ:  # nocoverage
        rabbitmq_client = get_queue_client()
        if rabbitmq_client.connection and rabbitmq_client.connection.is_open:
            rabbitmq_client.close()


def func_with_catch(func: Callable[[ParallelRecordType], None], item: ParallelRecordType) -> None:
    try:
        return func(item)
    except Exception:
        logging.exception("Error processing item: %s", item)


def run_parallel(
    func: Callable[[ParallelRecordType], None],
    records: Iterable[ParallelRecordType],
    processes: int,
    *,
    initializer: Callable[..., None] | None = None,
    initargs: tuple[Any, ...] = tuple(),
    catch: bool = False,
    report_every: int = 1000,
    report: Callable[[int], None] | None = None,
) -> None:  # nocoverage
    assert processes > 0
    if settings.TEST_SUITE and current_process().daemon:
        assert processes == 1, "Only one process possible under parallel tests"

    wrapped_func = partial(func_with_catch, func) if catch else func

    if processes == 1:
        if initializer is not None:
            initializer(*initargs)
        for count, record in enumerate(records, 1):
            wrapped_func(record)
            if report is not None and count % report_every == 0:
                report(count)
        return

    _disconnect()

    with ProcessPoolExecutor(
        max_workers=processes, initializer=initializer, initargs=initargs
    ) as executor:
        for count, future in enumerate(
            as_completed(executor.submit(wrapped_func, record) for record in records), 1
        ):
            future.result()
            if report is not None and count % report_every == 0:
                report(count)

View File

@@ -1,16 +1,13 @@
 import logging
 import os
-from concurrent.futures import ProcessPoolExecutor, as_completed
 from glob import glob
 
-import bmemcached
 import magic
 from django.conf import settings
-from django.core.cache import cache
-from django.db import connection
 
 from zerver.lib.avatar_hash import user_avatar_path
 from zerver.lib.mime_types import guess_type
+from zerver.lib.parallel import run_parallel
 from zerver.lib.thumbnail import BadImageError
 from zerver.lib.upload import upload_emoji_image, write_avatar_images
 from zerver.lib.upload.s3 import S3UploadBackend, upload_content_to_s3
@@ -52,20 +49,7 @@ def _transfer_avatar_to_s3(user: UserProfile) -> None:
 
 
 def transfer_avatars_to_s3(processes: int) -> None:
-    users = list(UserProfile.objects.all())
-    if processes == 1:
-        for user in users:
-            _transfer_avatar_to_s3(user)
-    else:  # nocoverage
-        connection.close()
-        _cache = cache._cache  # type: ignore[attr-defined]  # not in stubs
-        assert isinstance(_cache, bmemcached.Client)
-        _cache.disconnect_all()
-        with ProcessPoolExecutor(max_workers=processes) as executor:
-            for future in as_completed(
-                executor.submit(_transfer_avatar_to_s3, user) for user in users
-            ):
-                future.result()
+    run_parallel(_transfer_avatar_to_s3, UserProfile.objects.all(), processes)
 
 
 def _transfer_message_files_to_s3(attachment: Attachment) -> None:
@@ -118,21 +102,7 @@ def _transfer_message_files_to_s3(attachment: Attachment) -> None:
 
 
 def transfer_message_files_to_s3(processes: int) -> None:
-    attachments = list(Attachment.objects.all())
-    if processes == 1:
-        for attachment in attachments:
-            _transfer_message_files_to_s3(attachment)
-    else:  # nocoverage
-        connection.close()
-        _cache = cache._cache  # type: ignore[attr-defined]  # not in stubs
-        assert isinstance(_cache, bmemcached.Client)
-        _cache.disconnect_all()
-        with ProcessPoolExecutor(max_workers=processes) as executor:
-            for future in as_completed(
-                executor.submit(_transfer_message_files_to_s3, attachment)
-                for attachment in attachments
-            ):
-                future.result()
+    run_parallel(_transfer_message_files_to_s3, Attachment.objects.all(), processes)
 
 
 def _transfer_emoji_to_s3(realm_emoji: RealmEmoji) -> None:
@@ -164,17 +134,4 @@ def _transfer_emoji_to_s3(realm_emoji: RealmEmoji) -> None:
 
 
 def transfer_emoji_to_s3(processes: int) -> None:
-    realm_emojis = list(RealmEmoji.objects.filter())
-    if processes == 1:
-        for realm_emoji in realm_emojis:
-            _transfer_emoji_to_s3(realm_emoji)
-    else:  # nocoverage
-        connection.close()
-        _cache = cache._cache  # type: ignore[attr-defined]  # not in stubs
-        assert isinstance(_cache, bmemcached.Client)
-        _cache.disconnect_all()
-        with ProcessPoolExecutor(max_workers=processes) as executor:
-            for future in as_completed(
-                executor.submit(_transfer_emoji_to_s3, realm_emoji) for realm_emoji in realm_emojis
-            ):
-                future.result()
+    run_parallel(_transfer_emoji_to_s3, RealmEmoji.objects.filter(), processes)

View File

@@ -0,0 +1,273 @@
import glob
import logging
import os
import shutil
import tempfile
import time
from collections import Counter
from multiprocessing import current_process

from django.db import connection

from zerver.lib.parallel import _disconnect, run_parallel
from zerver.lib.partial import partial
from zerver.lib.test_classes import ZulipTestCase
from zerver.models import Realm


class RunNotParallelTest(ZulipTestCase):
    def test_disconnect(self) -> None:
        self.assertTrue(connection.is_usable())
        self.assertEqual(Realm.objects.count(), 4)
        _disconnect()
        self.assertFalse(connection.is_usable())

    def test_not_parallel(self) -> None:
        # Nothing here is parallel, or forks at all
        events = []
        run_parallel(
            lambda item: events.append(f"Item: {item}"),
            range(100, 110),
            processes=1,
            initializer=lambda a, b: events.append(f"Init: {a}, {b}"),
            initargs=("alpha", "bravo"),
            report_every=3,
            report=lambda n: events.append(f"Completed {n}"),
        )
        self.assertEqual(
            events,
            [
                "Init: alpha, bravo",
                "Item: 100",
                "Item: 101",
                "Item: 102",
                "Completed 3",
                "Item: 103",
                "Item: 104",
                "Item: 105",
                "Completed 6",
                "Item: 106",
                "Item: 107",
                "Item: 108",
                "Completed 9",
                "Item: 109",
            ],
        )

    def test_not_parallel_throw(self) -> None:
        events = []

        def do_work(item: int) -> None:
            if item == 103:
                raise Exception("I don't like threes")
            events.append(f"Item: {item}")

        with self.assertRaisesRegex(Exception, "I don't like threes"):
            run_parallel(
                do_work,
                range(100, 110),
                processes=1,
                report_every=5,
                report=lambda n: events.append(f"Completed {n}"),
                catch=False,
            )
        self.assertEqual(
            events,
            [
                "Item: 100",
                "Item: 101",
                "Item: 102",
            ],
        )

    def test_not_parallel_catch(self) -> None:
        events = []

        def do_work(item: int) -> None:
            if item == 103:
                raise Exception("I don't like threes")
            events.append(f"Item: {item}")

        with self.assertLogs(level="ERROR") as error_logs:
            run_parallel(
                do_work,
                range(100, 105),
                processes=1,
                report_every=5,
                report=lambda n: events.append(f"Completed {n}"),
                catch=True,
            )
        self.assert_length(error_logs.output, 1)
        self.assertTrue(
            error_logs.output[0].startswith("ERROR:root:Error processing item: 103\nTraceback")
        )
        self.assertIn("I don't like threes", error_logs.output[0])
        self.assertEqual(
            events,
            [
                "Item: 100",
                "Item: 101",
                "Item: 102",
                "Item: 104",
                # We "completed" the one which raised an exception,
                # despite it not having output
                "Completed 5",
            ],
        )


def write_number(
    output_dir: str, total_processes: int, fail: set[int], item: int
) -> None:  # nocoverage
    if item in fail:
        raise Exception("Whoops")
    with open(f"{output_dir}/{os.getpid()}.output", "a") as fh:
        fh.write(f"{item}\n")
    # We wait to exit until we see total_processes unique files in the
    # output directory, so we ensure that every PID got a chance to
    # run.
    slept = 0
    while len(glob.glob(f"{output_dir}/*.output")) < total_processes and slept < 5:
        time.sleep(1)
        slept += 1


def db_query(output_dir: str, total_processes: int, item: int) -> None:  # nocoverage
    connection.connect()
    with open(f"{output_dir}/{os.getpid()}.output", "a") as fh:
        fh.write(f"{Realm.objects.count()}\n")
    slept = 0
    while len(glob.glob(f"{output_dir}/*.output")) < total_processes and slept < 5:
        time.sleep(1)
        slept += 1


class RunParallelTest(ZulipTestCase):
    def skip_in_parallel_harness(self) -> None:
        if current_process().daemon:
            self.skipTest("Testing of parallel pool is skipped under the parallel test harness")

    def test_parallel(self) -> None:  # nocoverage
        self.skip_in_parallel_harness()
        output_dir = tempfile.mkdtemp()
        report_lines = []
        try:
            run_parallel(
                partial(write_number, output_dir, 4, set()),
                range(100, 110),
                processes=4,
                report_every=3,
                report=lambda n: report_lines.append(f"Completed {n}"),
            )
            files = glob.glob(f"{output_dir}/*.output")
            self.assert_length(files, 4)
            all_lines: Counter[str] = Counter()
            for output_path in files:
                with open(output_path) as output_file:
                    file_lines = output_file.readlines()
                    self.assertGreater(len(file_lines), 0)
                    self.assertLessEqual(len(file_lines), 10 - (4 - 1))
                    self.assertEqual(sorted(file_lines), file_lines)
                    all_lines.update(file_lines)
            self.assertEqual(all_lines.total(), 10)
            self.assertEqual(sorted(all_lines.keys()), [f"{n}\n" for n in range(100, 110)])
            self.assertEqual(report_lines, ["Completed 3", "Completed 6", "Completed 9"])
        finally:
            shutil.rmtree(output_dir)

    def test_parallel_throw(self) -> None:  # nocoverage
        self.skip_in_parallel_harness()
        output_dir = tempfile.mkdtemp()
        report_lines = []
        try:
            with self.assertRaisesMessage(Exception, "Whoops"):
                run_parallel(
                    partial(write_number, output_dir, 4, {103}),
                    range(100, 105),
                    processes=2,
                    report_every=5,
                    report=lambda n: report_lines.append(f"Completed {n}"),
                )
            output_files = glob.glob(f"{output_dir}/*.output")
            self.assert_length(output_files, 2)
            all_lines: set[int] = set()
            for output_path in output_files:
                with open(output_path) as output_file:
                    all_lines.update(int(line) for line in output_file)
            self.assertIn(100, all_lines)
            self.assertIn(101, all_lines)
            self.assertNotIn(103, all_lines)
            self.assertEqual(report_lines, [])
        finally:
            shutil.rmtree(output_dir)

    def test_parallel_catch(self) -> None:  # nocoverage
        self.skip_in_parallel_harness()
        output_dir = tempfile.mkdtemp()
        report_lines = []

        def set_file_logger(output_dir: str) -> None:
            # In each worker process, we set up the logger to write to
            # a (pid).error file.
            logging.basicConfig(
                filename=f"{output_dir}/{os.getpid()}.error",
                level=logging.INFO,
                filemode="w",
                force=True,
            )

        try:
            run_parallel(
                partial(write_number, output_dir, 4, {103}),
                range(100, 105),
                processes=2,
                report_every=5,
                report=lambda n: report_lines.append(f"Completed {n}"),
                catch=True,
                initializer=set_file_logger,
                initargs=(output_dir,),
            )
            output_files = glob.glob(f"{output_dir}/*.output")
            self.assert_length(output_files, 2)
            all_lines: set[int] = set()
            for output_path in output_files:
                with open(output_path) as output_file:
                    all_lines.update(int(line) for line in output_file)
            self.assertEqual(sorted(all_lines), [100, 101, 102, 104])
            self.assertEqual(report_lines, ["Completed 5"])

            error_files = glob.glob(f"{output_dir}/*.error")
            error_lines = []
            self.assert_length(error_files, 2)
            for error_path in error_files:
                with open(error_path) as error_file:
                    error_lines.extend(error_file.readlines())
            self.assertEqual(error_lines[0], "ERROR:root:Error processing item: 103\n")
        finally:
            shutil.rmtree(output_dir)

    def test_parallel_reconnect(self) -> None:  # nocoverage
        self.skip_in_parallel_harness()
        output_dir = tempfile.mkdtemp()
        run_parallel(
            partial(db_query, output_dir, 2),
            range(100, 105),
            processes=2,
        )
        output_files = glob.glob(f"{output_dir}/*.output")
        self.assert_length(output_files, 2)
        all_lines: set[int] = set()
        for output_path in output_files:
            with open(output_path) as output_file:
                all_lines.update(int(line) for line in output_file)
        self.assertEqual(all_lines, {4})

View File

@@ -1929,7 +1929,7 @@ by Pieter
         with self.assertLogs(level="INFO"), self.settings(EXTERNAL_HOST="zulip.example.com"):
             # We need to mock EXTERNAL_HOST to be a valid domain because Slack's importer
             # uses it to generate email addresses for users without an email specified.
-            do_convert_zipfile(test_slack_zip_file, output_dir, token)
+            do_convert_zipfile(test_slack_zip_file, output_dir, token, threads=1)
 
         self.assertTrue(os.path.exists(output_dir))
         self.assertTrue(os.path.exists(output_dir + "/realm.json"))
@@ -2138,7 +2138,7 @@ by Pieter
         with self.assertLogs(level="INFO"), self.settings(EXTERNAL_HOST="zulip.example.com"):
             # We need to mock EXTERNAL_HOST to be a valid domain because Slack's importer
             # uses it to generate email addresses for users without an email specified.
-            do_convert_zipfile(test_slack_zip_file, output_dir, token)
+            do_convert_zipfile(test_slack_zip_file, output_dir, token, threads=1)
 
     @mock.patch("zerver.data_import.slack.check_slack_token_access")
     @responses.activate