From b3b9d2c3ccde1a4bca38d3de266246a64dd3219c Mon Sep 17 00:00:00 2001 From: Alex Vandiver Date: Fri, 19 Jan 2024 21:33:37 +0000 Subject: [PATCH] export: Use run_parallel for rewriting UserMessage data. --- zerver/lib/export.py | 82 ++++++++++--------- .../commands/export_usermessage_batch.py | 61 -------------- zerver/tests/test_import_export.py | 10 +-- 3 files changed, 46 insertions(+), 107 deletions(-) delete mode 100644 zerver/management/commands/export_usermessage_batch.py diff --git a/zerver/lib/export.py b/zerver/lib/export.py index b7e7954ace..178de7b087 100644 --- a/zerver/lib/export.py +++ b/zerver/lib/export.py @@ -16,6 +16,8 @@ import shutil import subprocess import tempfile from collections.abc import Callable, Iterable, Iterator, Mapping +from contextvars import ContextVar +from dataclasses import dataclass from datetime import datetime from email.headerregistry import Address from functools import cache @@ -38,6 +40,7 @@ from analytics.models import RealmCount, StreamCount, UserCount from version import ZULIP_VERSION from zerver.lib.avatar_hash import user_avatar_base_path_from_ids from zerver.lib.migration_status import MigrationStatusJson, parse_migration_status +from zerver.lib.parallel import run_parallel from zerver.lib.pysa import mark_sanitized from zerver.lib.stream_color import STREAM_ASSIGNMENT_COLORS from zerver.lib.timestamp import datetime_to_timestamp @@ -1713,18 +1716,18 @@ def fetch_usermessages( def export_usermessages_batch( input_path: Path, - output_path: Path, - export_full_with_consent: bool, - consented_user_ids: set[int] | None = None, ) -> None: """As part of the system for doing parallel exports, this runs on one batch of Message objects and adds the corresponding UserMessage - objects. (This is called by the export_usermessage_batch - management command). + objects. See write_message_partial_for_query for more context.""" - assert input_path.endswith((".partial", ".locked")) - assert output_path.endswith(".json") + context = usermessage_context.get() + export_full_with_consent = context.export_full_with_consent + consented_user_ids = context.consented_user_ids + + assert input_path.endswith(".partial") + output_path = input_path.replace(".json.partial", ".json") with open(input_path, "rb") as input_file: input_data: MessagePartial = orjson.loads(input_file.read()) @@ -2462,11 +2465,7 @@ def do_export_realm( # indicates a bug. assert export_type == RealmExport.EXPORT_FULL_WITH_CONSENT - # We need at least one process running to export - # UserMessage rows. The management command should - # enforce this for us. - if not settings.TEST_SUITE: - assert processes >= 1 + assert processes >= 1 realm_config = get_realm_config() @@ -2584,6 +2583,27 @@ def export_attachment_table( return attachments +@dataclass +class UserMessageProcessState: + export_full_with_consent: bool + consented_user_ids: set[int] | None + + +# We are not using the ContextVar for its thread-safety, here -- since +# we use processes, not threads, for parallelism. All we need is a +# global box which is serializable by pickle's dependency analysis, +# which we can set and get out of in the other process. We want it +# primarily for things which are large and don't change (the list of +# user-ids with consent) which we don't want to pass on every call. 
+usermessage_context: ContextVar[UserMessageProcessState] = ContextVar("usermessage_context") + + +def usermessage_process_initializer( + export_full_with_consent: bool, consented_user_ids: set[int] | None +) -> None: + usermessage_context.set(UserMessageProcessState(export_full_with_consent, consented_user_ids)) + + def launch_user_message_subprocesses( processes: int, output_dir: Path, @@ -2591,32 +2611,20 @@ def launch_user_message_subprocesses( exportable_user_ids: set[int] | None, ) -> None: logging.info("Launching %d PARALLEL subprocesses to export UserMessage rows", processes) - pids = {} - if export_full_with_consent: - assert exportable_user_ids is not None - consented_user_ids_filepath = os.path.join(output_dir, "consented_user_ids.json") - with open(consented_user_ids_filepath, "wb") as f: - f.write(orjson.dumps(list(exportable_user_ids))) - logging.info("Created consented_user_ids.json file.") - - for shard_id in range(processes): - arguments = [ - os.path.join(settings.DEPLOY_ROOT, "manage.py"), - "export_usermessage_batch", - f"--path={output_dir}", - f"--process={shard_id}", - ] - if export_full_with_consent: - arguments.append("--export-full-with-consent") - - process = subprocess.Popen(arguments) - pids[process.pid] = shard_id - - while pids: - pid, status = os.wait() - shard = pids.pop(pid) - print(f"Shard {shard} finished, status {status}") + files = glob.glob(os.path.join(output_dir, "messages-*.json.partial")) + run_parallel( + export_usermessages_batch, + files, + processes, + initializer=usermessage_process_initializer, + initargs=( + export_full_with_consent, + exportable_user_ids, + ), + report_every=10, + report=lambda count: logging.info("Successfully processed %s message files", count), + ) def do_export_user(user_profile: UserProfile, output_dir: Path) -> None: diff --git a/zerver/management/commands/export_usermessage_batch.py b/zerver/management/commands/export_usermessage_batch.py deleted file mode 100644 index a9dcbd8d59..0000000000 --- a/zerver/management/commands/export_usermessage_batch.py +++ /dev/null @@ -1,61 +0,0 @@ -import glob -import logging -import os -from argparse import ArgumentParser -from typing import Any - -import orjson -from typing_extensions import override - -from zerver.lib.export import export_usermessages_batch -from zerver.lib.management import ZulipBaseCommand - - -class Command(ZulipBaseCommand): - help = """UserMessage fetching helper for export.py""" - - @override - def add_arguments(self, parser: ArgumentParser) -> None: - parser.add_argument("--path", help="Path to find messages.json archives") - parser.add_argument("--process", help="Process identifier (used only for debug output)") - parser.add_argument( - "--export-full-with-consent", - action="store_true", - help="Whether to export private data of users who consented", - ) - - @override - def handle(self, *args: Any, **options: Any) -> None: - logging.info("Starting UserMessage batch process %s", options["process"]) - path = options["path"] - files = set(glob.glob(os.path.join(path, "messages-*.json.partial"))) - - export_full_with_consent = options["export_full_with_consent"] - consented_user_ids = None - if export_full_with_consent: - consented_user_ids_path = os.path.join(path, "consented_user_ids.json") - assert os.path.exists(consented_user_ids_path) - - with open(consented_user_ids_path, "rb") as f: - consented_user_ids = set(orjson.loads(f.read())) - - for partial_path in files: - locked_path = partial_path.replace(".json.partial", ".json.locked") - output_path = 
partial_path.replace(".json.partial", ".json") - try: - os.rename(partial_path, locked_path) - except FileNotFoundError: - # Already claimed by another process - continue - logging.info("Process %s processing %s", options["process"], output_path) - try: - export_usermessages_batch( - locked_path, - output_path, - export_full_with_consent, - consented_user_ids=consented_user_ids, - ) - except BaseException: - # Put the item back in the free pool when we fail - os.rename(locked_path, partial_path) - raise diff --git a/zerver/tests/test_import_export.py b/zerver/tests/test_import_export.py index 1dcb1fbcc6..549c337907 100644 --- a/zerver/tests/test_import_export.py +++ b/zerver/tests/test_import_export.py @@ -53,7 +53,6 @@ from zerver.lib.export import ( Record, do_export_realm, do_export_user, - export_usermessages_batch, get_consented_user_ids, ) from zerver.lib.import_realm import do_import_realm, get_incoming_message_ids @@ -416,7 +415,7 @@ class RealmImportExportTest(ExportFile): do_export_realm( realm=realm, output_dir=output_dir, - processes=0, + processes=1, export_type=export_type, exportable_user_ids=exportable_user_ids, ) @@ -427,13 +426,6 @@ class RealmImportExportTest(ExportFile): realm.uuid = uuid.uuid4() realm.save() - export_usermessages_batch( - input_path=os.path.join(output_dir, "messages-000001.json.partial"), - output_path=os.path.join(output_dir, "messages-000001.json"), - export_full_with_consent=export_type == RealmExport.EXPORT_FULL_WITH_CONSENT, - consented_user_ids=exportable_user_ids, - ) - def export_realm_and_create_auditlog( self, original_realm: Realm,
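
Note on the pattern introduced above: run_parallel's initializer/initargs hooks are used so that per-process state is set exactly once, instead of pickling the potentially large consented-user-id set into every task. Below is a minimal, self-contained sketch of that initializer-plus-global-state pattern. It is written against the standard library's multiprocessing.Pool, since the exact signature of zerver.lib.parallel.run_parallel is not shown in this diff; WorkerState, demo_initializer, and demo_worker are illustrative names and are not part of the patch.

import multiprocessing
from contextvars import ContextVar
from dataclasses import dataclass


@dataclass
class WorkerState:
    export_full_with_consent: bool
    consented_user_ids: set[int] | None


# Global "box" analogous to usermessage_context: set once per worker
# process by the initializer, read by every task in that process.
worker_state: ContextVar[WorkerState] = ContextVar("worker_state")


def demo_initializer(
    export_full_with_consent: bool, consented_user_ids: set[int] | None
) -> None:
    # Runs once in each worker process; stashes the large, unchanging
    # state so it is not re-sent with every individual task.
    worker_state.set(WorkerState(export_full_with_consent, consented_user_ids))


def demo_worker(input_path: str) -> str:
    state = worker_state.get()
    # A real worker would rewrite input_path's .json.partial contents here;
    # this sketch only reports what it would have done.
    mode = "with consent data" if state.export_full_with_consent else "public data only"
    return f"{input_path}: processed {mode}"


if __name__ == "__main__":
    files = [f"messages-{i:06}.json.partial" for i in range(1, 4)]
    with multiprocessing.Pool(
        processes=2,
        initializer=demo_initializer,
        initargs=(True, {1, 2, 3}),
    ) as pool:
        for result in pool.imap_unordered(demo_worker, files):
            print(result)

With either the fork or spawn start method, the initializer runs once per worker process, so the consent set crosses the process boundary once per worker rather than once per messages-NNNNNN.json.partial file; that is the same saving launch_user_message_subprocesses gets by passing exportable_user_ids through initargs.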