Mirror of https://github.com/zulip/zulip.git
parallel: Factor out multiple callsites that use ProcessPoolExecutor.

Committed by Tim Abbott
parent 454905f988
commit 131580f23c
zerver/data_import/import_util.py

@@ -5,7 +5,6 @@ import shutil
 from collections import defaultdict
 from collections.abc import Callable, Iterable, Iterator, Mapping
 from collections.abc import Set as AbstractSet
-from concurrent.futures import ProcessPoolExecutor, as_completed
 from typing import Any, Protocol, TypeAlias, TypeVar
 
 import orjson
@@ -19,6 +18,7 @@ from zerver.data_import.sequencer import NEXT_ID
 from zerver.lib.avatar_hash import user_avatar_base_path_from_ids
 from zerver.lib.message import normalize_body_for_import
 from zerver.lib.mime_types import INLINE_MIME_TYPES, guess_extension
+from zerver.lib.parallel import run_parallel
 from zerver.lib.partial import partial
 from zerver.lib.stream_color import STREAM_ASSIGNMENT_COLORS as STREAM_COLORS
 from zerver.lib.thumbnail import THUMBNAIL_ACCEPT_IMAGE_TYPES, BadImageError
@@ -634,38 +634,18 @@ def process_avatars(
         avatar_original_list.append(avatar_original)
 
     # Run downloads in parallel
-    run_parallel_wrapper(
-        partial(get_avatar, avatar_dir, size_url_suffix), avatar_upload_list, threads=threads
+    run_parallel(
+        partial(get_avatar, avatar_dir, size_url_suffix),
+        avatar_upload_list,
+        processes=threads,
+        catch=True,
+        report=lambda count: logging.info("Finished %s items", count),
     )
 
     logging.info("######### GETTING AVATARS FINISHED #########\n")
     return avatar_list + avatar_original_list
 
 
-ListJobData = TypeVar("ListJobData")
-
-
-def wrapping_function(f: Callable[[ListJobData], None], item: ListJobData) -> None:
-    try:
-        f(item)
-    except Exception:
-        logging.exception("Error processing item: %s", item, stack_info=True)
-
-
-def run_parallel_wrapper(
-    f: Callable[[ListJobData], None], full_items: list[ListJobData], threads: int = 6
-) -> None:
-    logging.info("Distributing %s items across %s threads", len(full_items), threads)
-
-    with ProcessPoolExecutor(max_workers=threads) as executor:
-        for count, future in enumerate(
-            as_completed(executor.submit(wrapping_function, f, item) for item in full_items), 1
-        ):
-            future.result()
-            if count % 1000 == 0:
-                logging.info("Finished %s items", count)
-
-
 def get_uploads(upload_dir: str, upload: list[str]) -> None:
     upload_url = upload[0]
     upload_path = upload[1]
@@ -697,7 +677,13 @@ def process_uploads(
         upload["path"] = upload_s3_path
 
     # Run downloads in parallel
-    run_parallel_wrapper(partial(get_uploads, upload_dir), upload_url_list, threads=threads)
+    run_parallel(
+        partial(get_uploads, upload_dir),
+        upload_url_list,
+        processes=threads,
+        catch=True,
+        report=lambda count: logging.info("Finished %s items", count),
+    )
 
     logging.info("######### GETTING ATTACHMENTS FINISHED #########\n")
     return upload_list
zerver/lib/import_realm.py

@@ -2,17 +2,14 @@ import collections
 import logging
 import os
 import shutil
-from concurrent.futures import ProcessPoolExecutor, as_completed
 from datetime import datetime, timedelta, timezone
 from difflib import unified_diff
 from typing import Any, TypeAlias
 
-import bmemcached
 import orjson
 import pyvips
 from bs4 import BeautifulSoup
 from django.conf import settings
-from django.core.cache import cache
 from django.core.management.base import CommandError
 from django.core.validators import validate_email
 from django.db import connection, transaction
@@ -44,6 +41,7 @@ from zerver.lib.onboarding import (
     send_initial_direct_messages_to_user,
     send_initial_realm_messages,
 )
+from zerver.lib.parallel import run_parallel
 from zerver.lib.partial import partial
 from zerver.lib.push_notifications import sends_notifications_directly
 from zerver.lib.remote_server import maybe_enqueue_audit_log_upload
@@ -1143,19 +1141,12 @@ def import_uploads(
     # TODO: This implementation is hacky, both in that it
     # does get_user_profile_by_id for each user, and in that it
     # might be better to require the export to just have these.
-    if processes == 1:
-        for record in records:
-            process_func(record)
-    else:
-        connection.close()
-        _cache = cache._cache  # type: ignore[attr-defined] # not in stubs
-        assert isinstance(_cache, bmemcached.Client)
-        _cache.disconnect_all()
-        with ProcessPoolExecutor(max_workers=processes) as executor:
-            for future in as_completed(
-                executor.submit(process_func, record) for record in records
-            ):
-                future.result()
+    run_parallel(
+        process_func,
+        records,
+        processes if s3_uploads else 1,
+        report=lambda count: logging.info("Processed %s/%s avatars", count, len(records)),
+    )
 
 
 def disable_restricted_authentication_methods(data: ImportedTableData) -> None:
zerver/lib/parallel.py (new file, 77 lines)

@@ -0,0 +1,77 @@
+import logging
+from collections.abc import Callable, Iterable
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from multiprocessing import current_process
+from typing import Any, TypeVar
+
+import bmemcached
+from django.conf import settings
+from django.core.cache import cache
+from django.db import connection
+
+from zerver.lib.partial import partial
+from zerver.lib.queue import get_queue_client
+
+ParallelRecordType = TypeVar("ParallelRecordType")
+
+
+def _disconnect() -> None:
+    # Close our database, cache, and RabbitMQ connections, so our
+    # forked children do not share them. Django will transparently
+    # re-open them as needed.
+    connection.close()
+    _cache = cache._cache  # type: ignore[attr-defined] # not in stubs
+    if isinstance(_cache, bmemcached.Client):  # nocoverage
+        # In tests, this is an OrderedDict
+        _cache.disconnect_all()
+
+    if settings.USING_RABBITMQ:  # nocoverage
+        rabbitmq_client = get_queue_client()
+        if rabbitmq_client.connection and rabbitmq_client.connection.is_open:
+            rabbitmq_client.close()
+
+
+def func_with_catch(func: Callable[[ParallelRecordType], None], item: ParallelRecordType) -> None:
+    try:
+        return func(item)
+    except Exception:
+        logging.exception("Error processing item: %s", item)
+
+
+def run_parallel(
+    func: Callable[[ParallelRecordType], None],
+    records: Iterable[ParallelRecordType],
+    processes: int,
+    *,
+    initializer: Callable[..., None] | None = None,
+    initargs: tuple[Any, ...] = tuple(),
+    catch: bool = False,
+    report_every: int = 1000,
+    report: Callable[[int], None] | None = None,
+) -> None:  # nocoverage
+    assert processes > 0
+    if settings.TEST_SUITE and current_process().daemon:
+        assert processes == 1, "Only one process possible under parallel tests"
+
+    wrapped_func = partial(func_with_catch, func) if catch else func
+
+    if processes == 1:
+        if initializer is not None:
+            initializer(*initargs)
+        for count, record in enumerate(records, 1):
+            wrapped_func(record)
+            if report is not None and count % report_every == 0:
+                report(count)
+        return
+
+    _disconnect()
+
+    with ProcessPoolExecutor(
+        max_workers=processes, initializer=initializer, initargs=initargs
+    ) as executor:
+        for count, future in enumerate(
+            as_completed(executor.submit(wrapped_func, record) for record in records), 1
+        ):
+            future.result()
+            if report is not None and count % report_every == 0:
+                report(count)
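For orientation, a minimal usage sketch of the helper added above (not part of the commit; the worker function and records here are hypothetical, but the keyword arguments follow the run_parallel signature in zerver/lib/parallel.py):

    import logging

    from zerver.lib.parallel import run_parallel


    def fetch_item(item_id: int) -> None:
        # Hypothetical per-record work; runs in a forked worker
        # process, or inline when processes=1.
        logging.info("Processing %s", item_id)


    run_parallel(
        fetch_item,
        range(10_000),
        processes=4,  # processes=1 runs serially in-process, with no forking
        catch=True,  # log-and-continue instead of aborting on a failed record
        report=lambda count: logging.info("Finished %s items", count),
    )

With catch=True a failing record is routed through func_with_catch, so it is logged and skipped rather than cancelling the rest of the pool; without it, the first exception propagates out of future.result().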
zerver/lib/transfer.py

@@ -1,16 +1,13 @@
 import logging
 import os
-from concurrent.futures import ProcessPoolExecutor, as_completed
 from glob import glob
 
-import bmemcached
 import magic
 from django.conf import settings
-from django.core.cache import cache
-from django.db import connection
 
 from zerver.lib.avatar_hash import user_avatar_path
 from zerver.lib.mime_types import guess_type
+from zerver.lib.parallel import run_parallel
 from zerver.lib.thumbnail import BadImageError
 from zerver.lib.upload import upload_emoji_image, write_avatar_images
 from zerver.lib.upload.s3 import S3UploadBackend, upload_content_to_s3
@@ -52,20 +49,7 @@ def _transfer_avatar_to_s3(user: UserProfile) -> None:
 
 
 def transfer_avatars_to_s3(processes: int) -> None:
-    users = list(UserProfile.objects.all())
-    if processes == 1:
-        for user in users:
-            _transfer_avatar_to_s3(user)
-    else:  # nocoverage
-        connection.close()
-        _cache = cache._cache  # type: ignore[attr-defined] # not in stubs
-        assert isinstance(_cache, bmemcached.Client)
-        _cache.disconnect_all()
-        with ProcessPoolExecutor(max_workers=processes) as executor:
-            for future in as_completed(
-                executor.submit(_transfer_avatar_to_s3, user) for user in users
-            ):
-                future.result()
+    run_parallel(_transfer_avatar_to_s3, UserProfile.objects.all(), processes)
 
 
 def _transfer_message_files_to_s3(attachment: Attachment) -> None:
@@ -118,21 +102,7 @@ def _transfer_message_files_to_s3(attachment: Attachment) -> None:
 
 
 def transfer_message_files_to_s3(processes: int) -> None:
-    attachments = list(Attachment.objects.all())
-    if processes == 1:
-        for attachment in attachments:
-            _transfer_message_files_to_s3(attachment)
-    else:  # nocoverage
-        connection.close()
-        _cache = cache._cache  # type: ignore[attr-defined] # not in stubs
-        assert isinstance(_cache, bmemcached.Client)
-        _cache.disconnect_all()
-        with ProcessPoolExecutor(max_workers=processes) as executor:
-            for future in as_completed(
-                executor.submit(_transfer_message_files_to_s3, attachment)
-                for attachment in attachments
-            ):
-                future.result()
+    run_parallel(_transfer_message_files_to_s3, Attachment.objects.all(), processes)
 
 
 def _transfer_emoji_to_s3(realm_emoji: RealmEmoji) -> None:
@@ -164,17 +134,4 @@ def _transfer_emoji_to_s3(realm_emoji: RealmEmoji) -> None:
 
 
 def transfer_emoji_to_s3(processes: int) -> None:
-    realm_emojis = list(RealmEmoji.objects.filter())
-    if processes == 1:
-        for realm_emoji in realm_emojis:
-            _transfer_emoji_to_s3(realm_emoji)
-    else:  # nocoverage
-        connection.close()
-        _cache = cache._cache  # type: ignore[attr-defined] # not in stubs
-        assert isinstance(_cache, bmemcached.Client)
-        _cache.disconnect_all()
-        with ProcessPoolExecutor(max_workers=processes) as executor:
-            for future in as_completed(
-                executor.submit(_transfer_emoji_to_s3, realm_emoji) for realm_emoji in realm_emojis
-            ):
-                future.result()
+    run_parallel(_transfer_emoji_to_s3, RealmEmoji.objects.filter(), processes)
zerver/tests/test_parallel.py (new file, 273 lines)

@@ -0,0 +1,273 @@
+import glob
+import logging
+import os
+import shutil
+import tempfile
+import time
+from collections import Counter
+from multiprocessing import current_process
+
+from django.db import connection
+
+from zerver.lib.parallel import _disconnect, run_parallel
+from zerver.lib.partial import partial
+from zerver.lib.test_classes import ZulipTestCase
+from zerver.models import Realm
+
+
+class RunNotParallelTest(ZulipTestCase):
+    def test_disconnect(self) -> None:
+        self.assertTrue(connection.is_usable())
+        self.assertEqual(Realm.objects.count(), 4)
+        _disconnect()
+        self.assertFalse(connection.is_usable())
+
+    def test_not_parallel(self) -> None:
+        # Nothing here is parallel, or forks at all
+        events = []
+
+        run_parallel(
+            lambda item: events.append(f"Item: {item}"),
+            range(100, 110),
+            processes=1,
+            initializer=lambda a, b: events.append(f"Init: {a}, {b}"),
+            initargs=("alpha", "bravo"),
+            report_every=3,
+            report=lambda n: events.append(f"Completed {n}"),
+        )
+
+        self.assertEqual(
+            events,
+            [
+                "Init: alpha, bravo",
+                "Item: 100",
+                "Item: 101",
+                "Item: 102",
+                "Completed 3",
+                "Item: 103",
+                "Item: 104",
+                "Item: 105",
+                "Completed 6",
+                "Item: 106",
+                "Item: 107",
+                "Item: 108",
+                "Completed 9",
+                "Item: 109",
+            ],
+        )
+
+    def test_not_parallel_throw(self) -> None:
+        events = []
+
+        def do_work(item: int) -> None:
+            if item == 103:
+                raise Exception("I don't like threes")
+            events.append(f"Item: {item}")
+
+        with self.assertRaisesRegex(Exception, "I don't like threes"):
+            run_parallel(
+                do_work,
+                range(100, 110),
+                processes=1,
+                report_every=5,
+                report=lambda n: events.append(f"Completed {n}"),
+                catch=False,
+            )
+
+        self.assertEqual(
+            events,
+            [
+                "Item: 100",
+                "Item: 101",
+                "Item: 102",
+            ],
+        )
+
+    def test_not_parallel_catch(self) -> None:
+        events = []
+
+        def do_work(item: int) -> None:
+            if item == 103:
+                raise Exception("I don't like threes")
+            events.append(f"Item: {item}")
+
+        with self.assertLogs(level="ERROR") as error_logs:
+            run_parallel(
+                do_work,
+                range(100, 105),
+                processes=1,
+                report_every=5,
+                report=lambda n: events.append(f"Completed {n}"),
+                catch=True,
+            )
+
+        self.assert_length(error_logs.output, 1)
+        self.assertTrue(
+            error_logs.output[0].startswith("ERROR:root:Error processing item: 103\nTraceback")
+        )
+        self.assertIn("I don't like threes", error_logs.output[0])
+
+        self.assertEqual(
+            events,
+            [
+                "Item: 100",
+                "Item: 101",
+                "Item: 102",
+                "Item: 104",
+                # We "completed" the one which raised an exception,
+                # despite it not having output
+                "Completed 5",
+            ],
+        )
+
+
+def write_number(
+    output_dir: str, total_processes: int, fail: set[int], item: int
+) -> None:  # nocoverage
+    if item in fail:
+        raise Exception("Whoops")
+
+    with open(f"{output_dir}/{os.getpid()}.output", "a") as fh:
+        fh.write(f"{item}\n")
+    # We wait to exit until we see total_processes unique files in the
+    # output directory, so we ensure that every PID got a chance to
+    # run.
+    slept = 0
+    while len(glob.glob(f"{output_dir}/*.output")) < total_processes and slept < 5:
+        time.sleep(1)
+        slept += 1
+
+
+def db_query(output_dir: str, total_processes: int, item: int) -> None:  # nocoverage
+    connection.connect()
+    with open(f"{output_dir}/{os.getpid()}.output", "a") as fh:
+        fh.write(f"{Realm.objects.count()}\n")
+    slept = 0
+    while len(glob.glob(f"{output_dir}/*.output")) < total_processes and slept < 5:
+        time.sleep(1)
+        slept += 1
+
+
+class RunParallelTest(ZulipTestCase):
+    def skip_in_parallel_harness(self) -> None:
+        if current_process().daemon:
+            self.skipTest("Testing of parallel pool is skipped under the parallel test harness")
+
+    def test_parallel(self) -> None:  # nocoverage
+        self.skip_in_parallel_harness()
+
+        output_dir = tempfile.mkdtemp()
+        report_lines = []
+        try:
+            run_parallel(
+                partial(write_number, output_dir, 4, set()),
+                range(100, 110),
+                processes=4,
+                report_every=3,
+                report=lambda n: report_lines.append(f"Completed {n}"),
+            )
+
+            files = glob.glob(f"{output_dir}/*.output")
+            self.assert_length(files, 4)
+            all_lines: Counter[str] = Counter()
+            for output_path in files:
+                with open(output_path) as output_file:
+                    file_lines = output_file.readlines()
+                    self.assertGreater(len(file_lines), 0)
+                    self.assertLessEqual(len(file_lines), 10 - (4 - 1))
+                    self.assertEqual(sorted(file_lines), file_lines)
+                    all_lines.update(file_lines)
+
+            self.assertEqual(all_lines.total(), 10)
+            self.assertEqual(sorted(all_lines.keys()), [f"{n}\n" for n in range(100, 110)])
+
+            self.assertEqual(report_lines, ["Completed 3", "Completed 6", "Completed 9"])
+        finally:
+            shutil.rmtree(output_dir)
+
+    def test_parallel_throw(self) -> None:  # nocoverage
+        self.skip_in_parallel_harness()
+        output_dir = tempfile.mkdtemp()
+        report_lines = []
+        try:
+            with self.assertRaisesMessage(Exception, "Whoops"):
+                run_parallel(
+                    partial(write_number, output_dir, 4, {103}),
+                    range(100, 105),
+                    processes=2,
+                    report_every=5,
+                    report=lambda n: report_lines.append(f"Completed {n}"),
+                )
+            output_files = glob.glob(f"{output_dir}/*.output")
+            self.assert_length(output_files, 2)
+            all_lines: set[int] = set()
+            for output_path in output_files:
+                with open(output_path) as output_file:
+                    all_lines.update(int(line) for line in output_file)
+            self.assertIn(100, all_lines)
+            self.assertIn(101, all_lines)
+            self.assertNotIn(103, all_lines)
+            self.assertEqual(report_lines, [])
+        finally:
+            shutil.rmtree(output_dir)
+
+    def test_parallel_catch(self) -> None:  # nocoverage
+        self.skip_in_parallel_harness()
+        output_dir = tempfile.mkdtemp()
+        report_lines = []
+
+        def set_file_logger(output_dir: str) -> None:
+            # In each worker process, we set up the logger to write to
+            # a (pid).error file.
+            logging.basicConfig(
+                filename=f"{output_dir}/{os.getpid()}.error",
+                level=logging.INFO,
+                filemode="w",
+                force=True,
+            )
+
+        try:
+            run_parallel(
+                partial(write_number, output_dir, 4, {103}),
+                range(100, 105),
+                processes=2,
+                report_every=5,
+                report=lambda n: report_lines.append(f"Completed {n}"),
+                catch=True,
+                initializer=set_file_logger,
+                initargs=(output_dir,),
+            )
+            output_files = glob.glob(f"{output_dir}/*.output")
+            self.assert_length(output_files, 2)
+            all_lines: set[int] = set()
+            for output_path in output_files:
+                with open(output_path) as output_file:
+                    all_lines.update(int(line) for line in output_file)
+            self.assertEqual(sorted(all_lines), [100, 101, 102, 104])
+            self.assertEqual(report_lines, ["Completed 5"])
+
+            error_files = glob.glob(f"{output_dir}/*.error")
+            error_lines = []
+            self.assert_length(error_files, 2)
+            for error_path in error_files:
+                with open(error_path) as error_file:
+                    error_lines.extend(error_file.readlines())
+            self.assertEqual(error_lines[0], "ERROR:root:Error processing item: 103\n")
+        finally:
+            shutil.rmtree(output_dir)
+
+    def test_parallel_reconnect(self) -> None:  # nocoverage
+        self.skip_in_parallel_harness()
+        output_dir = tempfile.mkdtemp()
+        run_parallel(
+            partial(db_query, output_dir, 2),
+            range(100, 105),
+            processes=2,
+        )
+        output_files = glob.glob(f"{output_dir}/*.output")
+        self.assert_length(output_files, 2)
+        all_lines: set[int] = set()
+        for output_path in output_files:
+            with open(output_path) as output_file:
+                all_lines.update(int(line) for line in output_file)
+        self.assertEqual(all_lines, {4})
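Assuming a standard Zulip development environment, the new suite can be run on its own with the usual backend test runner (command shown for illustration):

    ./tools/test-backend zerver.tests.test_parallel

Note that the RunParallelTest cases skip themselves when the current process is a daemonized worker of the parallel test harness, since a daemonic process cannot fork the ProcessPoolExecutor pool.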
zerver/tests/test_slack_importer.py

@@ -1929,7 +1929,7 @@ by Pieter
         with self.assertLogs(level="INFO"), self.settings(EXTERNAL_HOST="zulip.example.com"):
             # We need to mock EXTERNAL_HOST to be a valid domain because Slack's importer
             # uses it to generate email addresses for users without an email specified.
-            do_convert_zipfile(test_slack_zip_file, output_dir, token)
+            do_convert_zipfile(test_slack_zip_file, output_dir, token, threads=1)
 
         self.assertTrue(os.path.exists(output_dir))
         self.assertTrue(os.path.exists(output_dir + "/realm.json"))
@@ -2138,7 +2138,7 @@ by Pieter
         with self.assertLogs(level="INFO"), self.settings(EXTERNAL_HOST="zulip.example.com"):
             # We need to mock EXTERNAL_HOST to be a valid domain because Slack's importer
             # uses it to generate email addresses for users without an email specified.
-            do_convert_zipfile(test_slack_zip_file, output_dir, token)
+            do_convert_zipfile(test_slack_zip_file, output_dir, token, threads=1)
 
     @mock.patch("zerver.data_import.slack.check_slack_token_access")
     @responses.activate