diff --git a/zerver/lib/export.py b/zerver/lib/export.py
index 5b90814c2e..f384dfe156 100644
--- a/zerver/lib/export.py
+++ b/zerver/lib/export.py
@@ -19,7 +19,7 @@ from collections.abc import Callable, Iterable, Iterator, Mapping
 from datetime import datetime
 from email.headerregistry import Address
 from functools import cache
-from itertools import islice
+from itertools import chain, islice
 from typing import TYPE_CHECKING, Any, Optional, TypeAlias, TypedDict, TypeVar, cast
 from urllib.parse import urlsplit
 
@@ -27,7 +27,7 @@ import orjson
 from django.apps import apps
 from django.conf import settings
 from django.db import connection
-from django.db.models import Exists, Model, OuterRef, Q
+from django.db.models import Exists, Model, OuterRef, Q, QuerySet
 from django.forms.models import model_to_dict
 from django.utils.timezone import is_naive as timezone_is_naive
 from django.utils.timezone import now as timezone_now
@@ -98,7 +98,7 @@ if TYPE_CHECKING:
 # Custom mypy types follow:
 Record: TypeAlias = dict[str, Any]
 TableName = str
-TableData: TypeAlias = dict[TableName, list[Record]]
+TableData: TypeAlias = dict[TableName, Iterator[Record] | list[Record]]
 Field = str
 Path = str
 Context: TypeAlias = dict[str, Any]
@@ -108,11 +108,11 @@ SourceFilter: TypeAlias = Callable[[Record], bool]
 
 CustomFetch: TypeAlias = Callable[[TableData, Context], None]
 CustomReturnIds: TypeAlias = Callable[[TableData], set[int]]
-CustomProcessResults: TypeAlias = Callable[[list[Record], Context], list[Record]]
+CustomProcessResults: TypeAlias = Callable[[Iterable[Record], Context], Iterator[Record]]
 
 
 class MessagePartial(TypedDict):
-    zerver_message: list[Record]
+    zerver_message: Iterable[Record]
     zerver_userprofile_ids: list[int]
     realm_id: int
 
@@ -516,12 +516,11 @@ def write_records_json_file(output_dir: str, records: Iterable[dict[str, Any]])
     logging.info("Finished writing %s", output_file)
 
 
-def make_raw(query: Any, exclude: list[Field] | None = None) -> list[Record]:
+def make_raw(query: Iterable[Any], exclude: list[Field] | None = None) -> Iterator[Record]:
     """
     Takes a Django query and returns a JSONable list
     of dictionaries corresponding to the database rows.
     """
-    rows = []
     for instance in query:
         data = model_to_dict(instance, exclude=exclude)
         """
@@ -536,20 +535,19 @@ def make_raw(query: Any, exclude: list[Field] | None = None) -> list[Record]:
             value = data[field.name]
             data[field.name] = [row.id for row in value]
 
-        rows.append(data)
-
-    return rows
+        yield data
 
 
-def floatify_datetime_fields(data: TableData, table: TableName) -> None:
-    for item in data[table]:
-        for field in DATE_FIELDS[table]:
-            dt = item[field]
-            if dt is None:
-                continue
-            assert isinstance(dt, datetime)
-            assert not timezone_is_naive(dt)
-            item[field] = dt.timestamp()
+def floatify_datetime_fields(item: Record, table: TableName) -> Record:
+    updates = {}
+    for field in DATE_FIELDS[table]:
+        dt = item[field]
+        if dt is None:
+            continue
+        assert isinstance(dt, datetime)
+        assert not timezone_is_naive(dt)
+        updates[field] = dt.timestamp()
+    return {**item, **updates}
 
 
 class Config:
@@ -587,6 +585,7 @@
         exclude: list[Field] | None = None,
         limit_to_consenting_users: bool | None = None,
        collect_client_ids: bool = False,
+        use_iterator: bool = True,
     ) -> None:
         assert table or custom_tables
         self.table = table
@@ -698,11 +697,17 @@
             assert include_rows in ["user_profile_id__in", "user_id__in", "bot_profile_id__in"]
             assert normal_parent is not None and normal_parent.table == "zerver_userprofile"
 
+        if self.collect_client_ids or self.is_seeded:
+            self.use_iterator = False
+        else:
+            self.use_iterator = use_iterator
+
     def return_ids(self, response: TableData) -> set[int]:
         if self.custom_return_ids is not None:
             return self.custom_return_ids(response)
         else:
             assert self.table is not None
+            assert not self.use_iterator, self.table
             return {row["id"] for row in response[self.table]}
 
 
@@ -730,7 +735,8 @@ def export_from_config(
     for t in exported_tables:
         logging.info("Exporting via export_from_config: %s", t)
 
-    rows = None
+    rows: Iterable[Any] | None = None
+    query: QuerySet[Any] | None = None
 
     if config.is_seeded:
         rows = [seed_object]
@@ -748,13 +754,13 @@
         # When we concat_and_destroy, we are working with
         # temporary "tables" that are lists of records that
         # should already be ready to export.
-        data: list[Record] = []
-        for t in config.concat_and_destroy:
-            data += response[t]
-            del response[t]
-            logging.info("Deleted temporary %s", t)
         assert table is not None
-        response[table] = data
+        # We pop them off of the response and store them in a local
+        # which the iterable closes over
+        tables = {t: response.pop(t) for t in config.concat_and_destroy}
+        response[table] = chain.from_iterable(tables[t] for t in config.concat_and_destroy)
+        for t in config.concat_and_destroy:
+            logging.info("Deleted temporary %s", t)
 
     elif config.normal_parent:
         # In this mode, our current model is figuratively Article,
@@ -818,7 +824,7 @@
         assert model is not None
 
         try:
-            query = model.objects.filter(**filter_params)
+            query = model.objects.filter(**filter_params).order_by("id")
         except Exception:
             print(
                 f"""
@@ -833,8 +839,6 @@
             )
             raise
 
-        rows = list(query)
-
     elif config.id_source:
         # In this mode, we are the figurative Blog, and we now
         # need to look at the current response to get all the
@@ -843,6 +847,8 @@
         assert model is not None
 
         # This will be a tuple of the form ('zerver_article', 'blog').
         (child_table, field) = config.id_source
+        assert config.virtual_parent is not None
+        assert not config.virtual_parent.use_iterator
         child_rows = response[child_table]
         if config.source_filter:
             child_rows = [r for r in child_rows if config.source_filter(r)]
@@ -851,12 +857,17 @@
         if config.filter_args:
             filter_params.update(config.filter_args)
         query = model.objects.filter(**filter_params)
-        rows = list(query)
+
+    if query is not None:
+        rows = query.iterator()
 
     if rows is not None:
         assert table is not None  # Hint for mypy
         response[table] = make_raw(rows, exclude=config.exclude)
 
     if config.collect_client_ids and "collected_client_ids_set" in context:
+        # If we need to collect the client-ids, we can't just stream the results
+        response[table] = list(response[table])
+
         model = cast(type[Model], model)
         assert issubclass(model, Model)
         client_id_field_name = get_fk_field_name(model, Client)
@@ -873,8 +884,10 @@
             # of the exported data for the tables - e.g. to strip out private data.
            response[t] = custom_process_results(response[t], context)
         if t in DATE_FIELDS:
-            floatify_datetime_fields(response, t)
+            response[t] = (floatify_datetime_fields(r, t) for r in response[t])
+        if not config.use_iterator:
+            response[t] = list(response[t])
 
     # Now walk our children. It's extremely important to respect
     # the order of children here.
     for child_config in config.children:
@@ -998,6 +1011,7 @@ def get_realm_config() -> Config:
         table="zerver_userprofile",
         virtual_parent=realm_config,
         custom_fetch=custom_fetch_user_profile,
+        use_iterator=False,
     )
 
     user_groups_config = Config(
@@ -1006,6 +1020,7 @@
         normal_parent=realm_config,
         include_rows="realm_id__in",
         exclude=["direct_members", "direct_subgroups"],
+        use_iterator=False,
     )
 
     Config(
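The `concat_and_destroy` rewrite above replaces eager list concatenation with a lazy chain over the popped temporary tables. A standalone sketch of the same pattern (editor's illustration; the table names are only examples):

```python
from itertools import chain

response = {
    "_user_recipient": [{"id": 1}],
    "_stream_recipient": [{"id": 2}],
}
temp_tables = ["_user_recipient", "_stream_recipient"]

# Pop the temporary tables into a local dict first: the generator handed
# to chain.from_iterable closes over `tables`, so the rows stay reachable
# even after they have been removed from `response`.
tables = {t: response.pop(t) for t in temp_tables}
response["zerver_recipient"] = chain.from_iterable(tables[t] for t in temp_tables)

print([r["id"] for r in response["zerver_recipient"]])  # [1, 2]
```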
@@ -1078,6 +1093,7 @@
         # It is just "glue" data for internal data model consistency purposes
         # with no user-specific information.
         limit_to_consenting_users=False,
+        use_iterator=False,
     )
 
     Config(
@@ -1092,6 +1108,7 @@
         model=Stream,
         normal_parent=realm_config,
         include_rows="realm_id__in",
+        use_iterator=False,
     )
 
     stream_recipient_config = Config(
@@ -1100,6 +1117,7 @@
         normal_parent=stream_config,
         include_rows="type_id__in",
         filter_args={"type": Recipient.STREAM},
+        use_iterator=False,
     )
 
     Config(
@@ -1131,6 +1149,7 @@
             "_stream_recipient",
             "_huddle_recipient",
         ],
+        use_iterator=False,
    )
 
     Config(
@@ -1157,11 +1176,12 @@
 
 
 def custom_process_subscription_in_realm_config(
-    subscriptions: list[Record], context: Context
-) -> list[Record]:
+    subscriptions: Iterable[Record], context: Context
+) -> Iterator[Record]:
     export_type = context["export_type"]
     if export_type == RealmExport.EXPORT_FULL_WITHOUT_CONSENT:
-        return subscriptions
+        yield from subscriptions
+        return
 
     exportable_user_ids_from_context = context["exportable_user_ids"]
     if export_type == RealmExport.EXPORT_FULL_WITH_CONSENT:
@@ -1172,9 +1192,10 @@
         assert exportable_user_ids_from_context is None
         consented_user_ids = set()
 
-    def scrub_subscription_if_needed(subscription: Record) -> Record:
+    for subscription in subscriptions:
         if subscription["user_profile"] in consented_user_ids:
-            return subscription
+            yield subscription
+            continue
 
         # We create a replacement Subscription, setting only the essential fields,
         # while allowing all the other ones to fall back to the defaults
         # defined in the model.
@@ -1190,10 +1211,7 @@
             color=random.choice(STREAM_ASSIGNMENT_COLORS),
         )
         subscription_dict = model_to_dict(scrubbed_subscription)
-        return subscription_dict
-
-    processed_rows = map(scrub_subscription_if_needed, subscriptions)
-    return list(processed_rows)
+        yield subscription_dict
 
 
 def add_user_profile_child_configs(user_profile_config: Config) -> None:
@@ -1410,7 +1428,7 @@ def custom_fetch_user_profile(response: TableData, context: Context) -> None:
 
 def custom_fetch_user_profile_cross_realm(response: TableData, context: Context) -> None:
     realm = context["realm"]
-    response["zerver_userprofile_crossrealm"] = []
+    crossrealm_bots = []
 
     bot_name_to_default_email = {
         "NOTIFICATION_BOT": "notification-bot@zulip.com",
@@ -1432,13 +1450,14 @@
         bot_user_id = get_system_bot(bot_email, internal_realm.id).id
         recipient_id = Recipient.objects.get(type_id=bot_user_id, type=Recipient.PERSONAL).id
 
-        response["zerver_userprofile_crossrealm"].append(
+        crossrealm_bots.append(
             dict(
                 email=bot_default_email,
                 id=bot_user_id,
                 recipient_id=recipient_id,
-            )
+            ),
         )
+    response["zerver_userprofile_crossrealm"] = crossrealm_bots
 
 
 def fetch_attachment_data(
@@ -1448,10 +1467,21 @@
         Attachment.objects.filter(
             Q(messages__in=message_ids) | Q(scheduled_messages__in=scheduled_message_ids),
             realm_id=realm_id,
-        ).distinct()
+        )
+        .distinct("path_id")
+        .order_by("path_id")
     )
-    response["zerver_attachment"] = make_raw(attachments)
-    floatify_datetime_fields(response, "zerver_attachment")
+
+    def postprocess_attachment(row: Record) -> Record:
+        row = floatify_datetime_fields(row, "zerver_attachment")
+        filtered_message_ids = set(row["messages"]).intersection(message_ids)
+        row["messages"] = sorted(filtered_message_ids)
+
+        filtered_scheduled_message_ids = set(row["scheduled_messages"]).intersection(
+            scheduled_message_ids
+        )
+        row["scheduled_messages"] = sorted(filtered_scheduled_message_ids)
+        return row
 
     """
     We usually export most messages for the realm, but not
@@ -1461,14 +1491,7 @@
     Same reasoning applies to scheduled_messages.
     """
-    for row in response["zerver_attachment"]:
-        filtered_message_ids = set(row["messages"]).intersection(message_ids)
-        row["messages"] = sorted(filtered_message_ids)
-
-        filtered_scheduled_message_ids = set(row["scheduled_messages"]).intersection(
-            scheduled_message_ids
-        )
-        row["scheduled_messages"] = sorted(filtered_scheduled_message_ids)
+    response["zerver_attachment"] = (postprocess_attachment(r) for r in make_raw(attachments))
 
     return attachments
 
@@ -1480,18 +1503,17 @@ def custom_fetch_realm_audit_logs_for_user(response: TableData, context: Context
     """
     user = context["user"]
     query = RealmAuditLog.objects.filter(Q(modified_user_id=user.id) | Q(acting_user_id=user.id))
-    rows = make_raw(list(query))
-    response["zerver_realmauditlog"] = rows
+    response["zerver_realmauditlog"] = make_raw(query.iterator())
 
 
 def fetch_reaction_data(response: TableData, message_ids: set[int]) -> None:
     query = Reaction.objects.filter(message_id__in=list(message_ids))
-    response["zerver_reaction"] = make_raw(list(query))
+    response["zerver_reaction"] = make_raw(query.iterator())
 
 
 def fetch_client_data(response: TableData, client_ids: set[int]) -> None:
     query = Client.objects.filter(id__in=list(client_ids))
-    response["zerver_client"] = make_raw(list(query))
+    response["zerver_client"] = make_raw(query.iterator())
 
 
 def custom_fetch_direct_message_groups(response: TableData, context: Context) -> None:
@@ -1509,7 +1531,9 @@
         consented_user_ids = set()
 
     user_profile_ids = {
-        r["id"] for r in response["zerver_userprofile"] + response["zerver_userprofile_mirrordummy"]
+        r["id"]
+        for r in list(response["zerver_userprofile"])
+        + list(response["zerver_userprofile_mirrordummy"])
     }
 
     recipient_filter = Q()
@@ -1575,12 +1599,6 @@
     response["zerver_huddle"] = make_raw(
         DirectMessageGroup.objects.filter(id__in=direct_message_group_ids)
     )
-    if export_type == RealmExport.EXPORT_PUBLIC and any(
-        response[t] for t in ["_huddle_recipient", "_huddle_subscription", "zerver_huddle"]
-    ):
-        raise AssertionError(
-            "Public export should not result in exporting any data in _huddle tables"
-        )
 
 
 def custom_fetch_scheduled_messages(response: TableData, context: Context) -> None:
@@ -1591,7 +1609,7 @@
     exportable_scheduled_message_ids = context["exportable_scheduled_message_ids"]
 
     query = ScheduledMessage.objects.filter(realm=realm, id__in=exportable_scheduled_message_ids)
-    rows = make_raw(list(query))
+    rows = make_raw(query)
     response["zerver_scheduledmessage"] = rows
 
 
@@ -1654,14 +1672,15 @@ def custom_fetch_realm_audit_logs_for_realm(response: TableData, context: Contex
 
 def custom_fetch_onboarding_usermessage(response: TableData, context: Context) -> None:
     realm = context["realm"]
-    response["zerver_onboardingusermessage"] = []
+    onboarding = []
 
     onboarding_usermessage_query = OnboardingUserMessage.objects.filter(realm=realm)
     for onboarding_usermessage in onboarding_usermessage_query:
         onboarding_usermessage_obj = model_to_dict(onboarding_usermessage)
         onboarding_usermessage_obj["flags_mask"] = onboarding_usermessage.flags.mask
         del onboarding_usermessage_obj["flags"]
-        response["zerver_onboardingusermessage"].append(onboarding_usermessage_obj)
+        onboarding.append(onboarding_usermessage_obj)
+    response["zerver_onboardingusermessage"] = onboarding
 
 
 def fetch_usermessages(
@@ -1710,7 +1729,8 @@ def export_usermessages_batch(
     with open(input_path, "rb") as input_file:
         input_data: MessagePartial = orjson.loads(input_file.read())
 
-    message_ids = {item["id"] for item in input_data["zerver_message"]}
+    messages = list(input_data["zerver_message"])
+    message_ids = {item["id"] for item in messages}
     user_profile_ids = set(input_data["zerver_userprofile_ids"])
     realm = Realm.objects.get(id=input_data["realm_id"])
     zerver_usermessage_data = fetch_usermessages(
@@ -1723,7 +1743,7 @@
     )
 
     output_data: TableData = dict(
-        zerver_message=input_data["zerver_message"],
+        zerver_message=messages,
         zerver_usermessage=zerver_usermessage_data,
     )
     write_table_data(output_path, output_data)
@@ -1764,9 +1784,9 @@ def export_partial_message_files(
         response["zerver_userprofile"],
     )
     ids_of_our_possible_senders = get_ids(
-        response["zerver_userprofile"]
-        + response["zerver_userprofile_mirrordummy"]
-        + response["zerver_userprofile_crossrealm"]
+        list(response["zerver_userprofile"])
+        + list(response["zerver_userprofile_mirrordummy"])
+        + list(response["zerver_userprofile_crossrealm"])
     )
 
     consented_user_ids: set[int] = set()
@@ -1913,7 +1933,9 @@
     for message_id_chunk in message_id_chunks:
        # Uses index: zerver_message_pkey
         actual_query = Message.objects.filter(id__in=message_id_chunk).order_by("id")
-        message_chunk = make_raw(actual_query)
+        message_chunk = [
+            floatify_datetime_fields(r, "zerver_message") for r in make_raw(actual_query.iterator())
+        ]
         for row in message_chunk:
             collected_client_ids.add(row["sending_client"])
 
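Two details in the hunks above deserve a note. First, in `fetch_attachment_data`, PostgreSQL's `DISTINCT ON` (which `.distinct("path_id")` compiles to) requires a matching leading `ORDER BY`, hence the paired `.order_by("path_id")`. Second, `export_usermessages_batch` now materializes its input up front precisely because it needs two passes over the same rows. A minimal sketch of that rule (editor's illustration; `batch_ids_then_write` is a hypothetical helper):

```python
from collections.abc import Iterable
from typing import Any

Record = dict[str, Any]


def batch_ids_then_write(rows: Iterable[Record]) -> tuple[set[int], list[Record]]:
    # `rows` may be a generator (MessagePartial's zerver_message is now
    # an Iterable); list() it once, up front, because we need two passes:
    # one to collect ids, and one to write the records back out.
    materialized = list(rows)
    return {r["id"] for r in materialized}, materialized


ids, messages = batch_ids_then_write({"id": i} for i in range(3))
assert ids == {0, 1, 2} and len(messages) == 3
```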
@@ -1923,16 +1945,11 @@
         message_filename += ".partial"
         logging.info("Fetched messages for %s", message_filename)
 
-        # Clean up our messages.
-        table_data: TableData = {}
-        table_data["zerver_message"] = message_chunk
-        floatify_datetime_fields(table_data, "zerver_message")
-
         # Build up our output for the .partial file, which needs
         # a list of user_profile_ids to search for (as well as
         # the realm id).
         output: MessagePartial = dict(
-            zerver_message=table_data["zerver_message"],
+            zerver_message=message_chunk,
             zerver_userprofile_ids=list(user_profile_ids),
             realm_id=realm.id,
         )
@@ -1945,7 +1962,7 @@
 def export_uploads_and_avatars(
     realm: Realm,
     *,
-    attachments: list[Attachment] | None = None,
+    attachments: Iterable[Attachment] | None = None,
     user: UserProfile | None,
     output_dir: Path,
 ) -> None:
@@ -1974,7 +1991,7 @@ def export_uploads_and_avatars(
     else:
         handle_system_bots = False
         users = [user]
-        attachments = list(Attachment.objects.filter(owner_id=user.id))
+        attachments = list(Attachment.objects.filter(owner_id=user.id).order_by("path_id"))
         realm_emojis = list(RealmEmoji.objects.filter(author_id=user.id))
 
     if settings.LOCAL_UPLOADS_DIR:
@@ -2164,7 +2181,6 @@
     processing_emoji = flavor == "emoji"
 
     bucket = get_bucket(bucket_name)
-    records = []
 
     logging.info("Downloading %s files from %s", flavor, bucket_name)
 
@@ -2175,85 +2191,88 @@
         email_gateway_bot = get_system_bot(settings.EMAIL_GATEWAY_BOT, internal_realm.id)
         user_ids.add(email_gateway_bot.id)
 
-    count = 0
-    for bkey in bucket.objects.filter(Prefix=object_prefix):
-        if valid_hashes is not None and bkey.Object().key not in valid_hashes:
-            continue
+    def iterate_attachments() -> Iterator[Record]:
+        count = 0
+        for bkey in bucket.objects.filter(Prefix=object_prefix):
+            # This is promised to be iterated in sorted filename order.
 
-        key = bucket.Object(bkey.key)
-
-        """
-        For very old realms we may not have proper metadata. If you really need
-        an export to bypass these checks, flip the following flag.
-        """
-        checking_metadata = True
-        if checking_metadata:
-            if "realm_id" not in key.metadata:
-                raise AssertionError(f"Missing realm_id in key metadata: {key.metadata}")
-
-            if "user_profile_id" not in key.metadata:
-                raise AssertionError(f"Missing user_profile_id in key metadata: {key.metadata}")
-
-            if int(key.metadata["user_profile_id"]) not in user_ids:
+            if valid_hashes is not None and bkey.Object().key not in valid_hashes:
                 continue
 
-        # This can happen if an email address has moved realms
-        if key.metadata["realm_id"] != str(realm.id):
-            if email_gateway_bot is None or key.metadata["user_profile_id"] != str(
-                email_gateway_bot.id
-            ):
-                raise AssertionError(
-                    f"Key metadata problem: {key.key} / {key.metadata} / {realm.id}"
-                )
-            # Email gateway bot sends messages, potentially including attachments, cross-realm.
-            print(f"File uploaded by email gateway bot: {key.key} / {key.metadata}")
+            key = bucket.Object(bkey.key)
 
-        record = _get_exported_s3_record(bucket_name, key, processing_emoji)
+            """
+            For very old realms we may not have proper metadata. If you really need
+            an export to bypass these checks, flip the following flag.
+            """
+            checking_metadata = True
+            if checking_metadata:
+                if "realm_id" not in key.metadata:
+                    raise AssertionError(f"Missing realm_id in key metadata: {key.metadata}")
 
-        record["path"] = key.key
-        _save_s3_object_to_file(key, output_dir, processing_uploads)
+                if "user_profile_id" not in key.metadata:
+                    raise AssertionError(f"Missing user_profile_id in key metadata: {key.metadata}")
 
-        records.append(record)
-        count += 1
+                if int(key.metadata["user_profile_id"]) not in user_ids:
+                    continue
 
-        if count % 100 == 0:
-            logging.info("Finished %s", count)
+            # This can happen if an email address has moved realms
+            if key.metadata["realm_id"] != str(realm.id):
+                if email_gateway_bot is None or key.metadata["user_profile_id"] != str(
+                    email_gateway_bot.id
+                ):
+                    raise AssertionError(
+                        f"Key metadata problem: {key.key} / {key.metadata} / {realm.id}"
+                    )
+                # Email gateway bot sends messages, potentially including attachments, cross-realm.
+                print(f"File uploaded by email gateway bot: {key.key} / {key.metadata}")
 
-    write_records_json_file(output_dir, records)
+            record = _get_exported_s3_record(bucket_name, key, processing_emoji)
+
+            record["path"] = key.key
+            _save_s3_object_to_file(key, output_dir, processing_uploads)
+
+            yield record
+            count += 1
+
+            if count % 100 == 0:
+                logging.info("Finished %s", count)
+
+    write_records_json_file(output_dir, iterate_attachments())
 
 
 def export_uploads_from_local(
-    realm: Realm, local_dir: Path, output_dir: Path, attachments: list[Attachment]
+    realm: Realm, local_dir: Path, output_dir: Path, attachments: Iterable[Attachment]
 ) -> None:
-    records = []
-    for count, attachment in enumerate(attachments, 1):
-        # Use 'mark_sanitized' to work around false positive caused by Pysa
-        # thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
-        # are user controlled
-        path_id = mark_sanitized(attachment.path_id)
+    def iterate_attachments() -> Iterator[Record]:
+        for count, attachment in enumerate(attachments, 1):
+            # Use 'mark_sanitized' to work around false positive caused by Pysa
+            # thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
+            # are user controlled
+            path_id = mark_sanitized(attachment.path_id)
 
-        local_path = os.path.join(local_dir, path_id)
-        output_path = os.path.join(output_dir, path_id)
+            local_path = os.path.join(local_dir, path_id)
+            output_path = os.path.join(output_dir, path_id)
 
-        os.makedirs(os.path.dirname(output_path), exist_ok=True)
-        shutil.copy2(local_path, output_path)
-        stat = os.stat(local_path)
-        record = dict(
-            realm_id=attachment.realm_id,
-            user_profile_id=attachment.owner.id,
-            user_profile_email=attachment.owner.email,
-            s3_path=path_id,
-            path=path_id,
-            size=stat.st_size,
-            last_modified=stat.st_mtime,
-            content_type=None,
-        )
-        records.append(record)
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
+            shutil.copy2(local_path, output_path)
+            stat = os.stat(local_path)
+            record = dict(
+                realm_id=attachment.realm_id,
+                user_profile_id=attachment.owner.id,
+                user_profile_email=attachment.owner.email,
+                s3_path=path_id,
+                path=path_id,
+                size=stat.st_size,
+                last_modified=stat.st_mtime,
+                content_type=None,
+            )
+            yield record
 
-        if count % 100 == 0:
-            logging.info("Finished %s", count)
+            if count % 100 == 0:
+                logging.info("Finished %s", count)
 
-    write_records_json_file(output_dir, records)
+    write_records_json_file(output_dir, iterate_attachments())
 
 
 def export_avatars_from_local(
@@ -2337,52 +2356,44 @@ def get_emoji_path(realm_emoji: RealmEmoji) -> str:
 
 def export_emoji_from_local(
-    realm: Realm, local_dir: Path, output_dir: Path, realm_emojis: list[RealmEmoji]
+    realm: Realm, local_dir: Path, output_dir: Path, realm_emojis: Iterable[RealmEmoji]
 ) -> None:
-    records = []
+    def emoji_path_tuples() -> Iterator[tuple[RealmEmoji, str]]:
+        for realm_emoji in realm_emojis:
+            realm_emoji_path = mark_sanitized(get_emoji_path(realm_emoji))
 
-    realm_emoji_helper_tuples: list[tuple[RealmEmoji, str]] = []
-    for realm_emoji in realm_emojis:
-        realm_emoji_path = get_emoji_path(realm_emoji)
+            yield (realm_emoji, realm_emoji_path)
+            yield (realm_emoji, realm_emoji_path + ".original")
 
-        # Use 'mark_sanitized' to work around false positive caused by Pysa
-        # thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
-        # are user controlled
-        realm_emoji_path = mark_sanitized(realm_emoji_path)
+    def iterate_emoji(
+        realm_emoji_helper_tuples: Iterator[tuple[RealmEmoji, str]],
+    ) -> Iterator[Record]:
+        for count, realm_emoji_helper_tuple in enumerate(realm_emoji_helper_tuples, 1):
+            realm_emoji_object, emoji_path = realm_emoji_helper_tuple
 
-        realm_emoji_path_original = realm_emoji_path + ".original"
+            local_path = os.path.join(local_dir, emoji_path)
+            output_path = os.path.join(output_dir, emoji_path)
 
-        realm_emoji_helper_tuples.append((realm_emoji, realm_emoji_path))
-        realm_emoji_helper_tuples.append((realm_emoji, realm_emoji_path_original))
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
+            shutil.copy2(local_path, output_path)
+            # Realm emoji author is optional.
+            author = realm_emoji_object.author
+            author_id = author.id if author else None
+            record = dict(
+                realm_id=realm.id,
+                author=author_id,
+                path=emoji_path,
+                s3_path=emoji_path,
+                file_name=realm_emoji_object.file_name,
+                name=realm_emoji_object.name,
+                deactivated=realm_emoji_object.deactivated,
+            )
+            yield record
 
-    for count, realm_emoji_helper_tuple in enumerate(realm_emoji_helper_tuples, 1):
-        realm_emoji_object, emoji_path = realm_emoji_helper_tuple
+            if count % 100 == 0:
+                logging.info("Finished %s", count)
 
-        local_path = os.path.join(local_dir, emoji_path)
-        output_path = os.path.join(output_dir, emoji_path)
-
-        os.makedirs(os.path.dirname(output_path), exist_ok=True)
-        shutil.copy2(local_path, output_path)
-        # Realm emoji author is optional.
-        author = realm_emoji_object.author
-        author_id = None
-        if author:
-            author_id = author.id
-        record = dict(
-            realm_id=realm.id,
-            author=author_id,
-            path=emoji_path,
-            s3_path=emoji_path,
-            file_name=realm_emoji_object.file_name,
-            name=realm_emoji_object.name,
-            deactivated=realm_emoji_object.deactivated,
-        )
-        records.append(record)
-
-        if count % 100 == 0:
-            logging.info("Finished %s", count)
-
-    write_records_json_file(output_dir, records)
+    write_records_json_file(output_dir, iterate_emoji(emoji_path_tuples()))
 
 
 def do_write_stats_file_for_realm_export(output_dir: Path) -> dict[str, int | dict[str, int]]:
@@ -2502,17 +2513,13 @@
     )
     logging.info("%d messages were exported", len(message_ids))
 
-    # zerver_reaction
-    zerver_reaction: TableData = {}
-    fetch_reaction_data(response=zerver_reaction, message_ids=message_ids)
-    response.update(zerver_reaction)
+    fetch_reaction_data(response=response, message_ids=message_ids)
 
-    zerver_client: TableData = {}
-    fetch_client_data(response=zerver_client, client_ids=collected_client_ids)
-    response.update(zerver_client)
+    fetch_client_data(response=response, client_ids=collected_client_ids)
 
     # Override the "deactivated" flag on the realm
     if export_as_active is not None:
+        assert isinstance(response["zerver_realm"], list)
         response["zerver_realm"][0]["deactivated"] = not export_as_active
 
     response["import_source"] = "zulip"  # type: ignore[assignment]  # this is an extra info field, not TableData
@@ -2619,6 +2626,8 @@ def do_export_user(user_profile: UserProfile, output_dir: Path) -> None:
     export_file = os.path.join(output_dir, "user.json")
     write_table_data(output_file=export_file, data=response)
 
+    # Double-check that this is a list, and not an iterator, so we can run over it again
+    assert isinstance(response["zerver_reaction"], list)
     reaction_message_ids: set[int] = {row["message"] for row in response["zerver_reaction"]}
 
     logging.info("Exporting messages")
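Several fetchers in this file now hand `query.iterator()` to `make_raw`. Unlike plain QuerySet iteration, `iterator()` skips the queryset's result cache and, on PostgreSQL, streams rows through a server-side cursor in chunks (`chunk_size` defaults to 2000), so consumed rows can be garbage-collected. A sketch under assumptions (editor's illustration; `Message` stands in for any Django model inside a configured project):

```python
from collections.abc import Iterator
from typing import Any

from django.forms.models import model_to_dict

from myapp.models import Message  # assumed model, for illustration only


def stream_messages() -> Iterator[dict[str, Any]]:
    # Rows arrive in chunks from a server-side cursor and are never all
    # resident in memory at once; the stable order_by makes the stream
    # deterministic, matching the .order_by("id") added in this patch.
    for message in Message.objects.order_by("id").iterator(chunk_size=2000):
        yield model_to_dict(message)
```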
@@ -2661,6 +2670,7 @@
         # Exports with consent are not relevant in the context of exporting
         # a single user.
         limit_to_consenting_users=False,
+        use_iterator=False,
     )
 
     # zerver_recipient
@@ -2669,6 +2679,7 @@
         model=Recipient,
         virtual_parent=subscription_config,
         id_source=("zerver_subscription", "recipient"),
+        use_iterator=False,
     )
 
     # zerver_stream
@@ -2713,6 +2724,7 @@
         normal_parent=user_profile_config,
         include_rows="user_profile_id__in",
         limit_to_consenting_users=False,
+        use_iterator=False,
     )
 
     add_user_profile_child_configs(user_profile_config)
@@ -2823,25 +2835,23 @@ def export_messages_single_user(
             .order_by("message_id")
         )
 
-        user_message_chunk = list(fat_query)
-
-        message_chunk = []
-        for user_message in user_message_chunk:
+        def process_row(user_message: UserMessage) -> Record:
            item = model_to_dict(user_message.message)
             item["flags"] = user_message.flags_list()
             item["flags_mask"] = user_message.flags.mask
             # Add a few nice, human-readable details
             item["sending_client_name"] = user_message.message.sending_client.name
             item["recipient_name"] = get_recipient(user_message.message.recipient_id)
-            message_chunk.append(item)
+            return floatify_datetime_fields(item, "zerver_message")
 
         message_filename = os.path.join(output_dir, f"messages-{dump_file_id:06}.json")
+        write_table_data(
+            message_filename,
+            {
+                "zerver_message": (process_row(um) for um in fat_query.iterator()),
+            },
+        )
         logging.info("Fetched messages for %s", message_filename)
-
-        output = {"zerver_message": message_chunk}
-        floatify_datetime_fields(output, "zerver_message")
-
-        write_table_data(message_filename, output)
 
         dump_file_id += 1
diff --git a/zerver/lib/import_realm.py b/zerver/lib/import_realm.py
index 7195754028..f00fe7ab2c 100644
--- a/zerver/lib/import_realm.py
+++ b/zerver/lib/import_realm.py
@@ -5,7 +5,7 @@ import shutil
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from datetime import datetime, timedelta, timezone
 from difflib import unified_diff
-from typing import Any
+from typing import Any, TypeAlias
 
 import bmemcached
 import orjson
@@ -33,7 +33,7 @@ from zerver.actions.realm_settings import (
 from zerver.actions.user_settings import do_change_avatar_fields
 from zerver.lib.avatar_hash import user_avatar_base_path_from_ids
 from zerver.lib.bulk_create import bulk_set_users_or_streams_recipient_fields
-from zerver.lib.export import DATE_FIELDS, Field, Path, Record, TableData, TableName
+from zerver.lib.export import DATE_FIELDS, Field, Path, Record, TableName
 from zerver.lib.markdown import markdown_convert
 from zerver.lib.markdown import version as markdown_version
 from zerver.lib.message import get_last_message_id
@@ -118,6 +118,8 @@ from zerver.models.recipients import get_direct_message_group_hash
 from zerver.models.users import get_system_bot, get_user_profile_by_id
 from zproject.backends import AUTH_BACKEND_NAME_MAP
 
+ImportedTableData: TypeAlias = dict[str, list[Record]]
+
 realm_tables = [
     ("zerver_realmauthenticationmethod", RealmAuthenticationMethod, "realmauthenticationmethod"),
     ("zerver_defaultstream", DefaultStream, "defaultstream"),
@@ -208,7 +210,7 @@
 }
 
 
-def map_messages_to_attachments(data: TableData) -> None:
+def map_messages_to_attachments(data: ImportedTableData) -> None:
     for attachment in data["zerver_attachment"]:
         for message_id in attachment["messages"]:
             message_id_to_attachments["zerver_message"][message_id].append(attachment["path_id"])
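From here on, the import-side changes are mechanical: import_realm.py keeps eager list-of-dict tables, so it gets its own `ImportedTableData` alias, while the export-side `TableData` may now hold iterators. A sketch of why the aliases diverge (editor's illustration, not Zulip code):

```python
from collections.abc import Iterator
from typing import Any, TypeAlias

Record: TypeAlias = dict[str, Any]
TableData: TypeAlias = dict[str, Iterator[Record] | list[Record]]  # export side
ImportedTableData: TypeAlias = dict[str, list[Record]]  # import side


def first_id(data: ImportedTableData, table: str) -> int:
    # Random access and repeated passes are fine on a list; under the
    # export-side alias this would be a type error, because an
    # Iterator[Record] supports neither indexing nor a second traversal.
    return data[table][0]["id"]
```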
@@ -231,14 +233,14 @@
     ID_MAP[table][old_id] = new_id
 
 
-def fix_datetime_fields(data: TableData, table: TableName) -> None:
+def fix_datetime_fields(data: ImportedTableData, table: TableName) -> None:
     for item in data[table]:
         for field_name in DATE_FIELDS[table]:
             if item[field_name] is not None:
                 item[field_name] = datetime.fromtimestamp(item[field_name], tz=timezone.utc)
 
 
-def fix_upload_links(data: TableData, message_table: TableName) -> None:
+def fix_upload_links(data: ImportedTableData, message_table: TableName) -> None:
     """
     Because the URLs for uploaded files encode the realm ID of the
     organization being imported (which is only determined at import
@@ -260,7 +262,7 @@
 
 
 def fix_stream_permission_group_settings(
-    data: TableData, system_groups_name_dict: dict[str, NamedUserGroup]
+    data: ImportedTableData, system_groups_name_dict: dict[str, NamedUserGroup]
 ) -> None:
     table = get_db_table(Stream)
     for stream in data[table]:
@@ -289,7 +291,7 @@
     )
 
 
-def create_subscription_events(data: TableData, realm_id: int) -> None:
+def create_subscription_events(data: ImportedTableData, realm_id: int) -> None:
     """
     When the export data doesn't contain the table `zerver_realmauditlog`,
     this function creates RealmAuditLog objects for `subscription_created`
@@ -332,7 +334,7 @@
     RealmAuditLog.objects.bulk_create(all_subscription_logs)
 
 
-def fix_service_tokens(data: TableData, table: TableName) -> None:
+def fix_service_tokens(data: ImportedTableData, table: TableName) -> None:
     """
     The tokens in the services are created by 'generate_api_key'.
     As the tokens are unique, they should be re-created for the imports.
@@ -341,7 +343,7 @@
         item["token"] = generate_api_key()
 
 
-def process_direct_message_group_hash(data: TableData, table: TableName) -> None:
+def process_direct_message_group_hash(data: ImportedTableData, table: TableName) -> None:
     """
     Build new direct message group hashes with the updated ids of the users
     """
@@ -350,7 +352,7 @@
         direct_message_group["huddle_hash"] = get_direct_message_group_hash(user_id_list)
 
 
-def get_direct_message_groups_from_subscription(data: TableData, table: TableName) -> None:
+def get_direct_message_groups_from_subscription(data: ImportedTableData, table: TableName) -> None:
     """
     Extract the IDs of the user_profiles involved in a direct message group
     from the subscription object
@@ -369,7 +371,7 @@
     )
 
 
-def fix_customprofilefield(data: TableData) -> None:
+def fix_customprofilefield(data: ImportedTableData) -> None:
     """
     In CustomProfileField with 'field_type' like 'USER', the IDs need to be
     re-mapped.
@@ -530,7 +532,7 @@
     message["edit_history"] = orjson.dumps(edit_history).decode()
 
 
-def current_table_ids(data: TableData, table: TableName) -> list[int]:
+def current_table_ids(data: ImportedTableData, table: TableName) -> list[int]:
     """
     Returns the ids present in the current table
     """
@@ -560,7 +562,7 @@
     return [item[0] for item in query]
 
 
-def convert_to_id_fields(data: TableData, table: TableName, field_name: Field) -> None:
+def convert_to_id_fields(data: ImportedTableData, table: TableName, field_name: Field) -> None:
     """
     When Django gives us dict objects via model_to_dict, the foreign
     key fields are `foo`, but we want `foo_id` for the bulk insert.
@@ -574,7 +576,7 @@
 
 
 def re_map_foreign_keys(
-    data: TableData,
+    data: ImportedTableData,
     table: TableName,
     field_name: Field,
     related_table: TableName,
@@ -585,7 +587,7 @@
     """
     This is a wrapper function for all the realm data tables
     and only avatar and attachment records need to be passed through the internal function
-    because of the difference in data format (TableData corresponding to realm data tables
+    because of the difference in data format (ImportedTableData corresponding to realm data tables
     and List[Record] corresponding to the avatar and attachment records)
     """
 
@@ -653,7 +655,7 @@
             item[field_name] = new_id
 
 
-def re_map_realm_emoji_codes(data: TableData, *, table_name: str) -> None:
+def re_map_realm_emoji_codes(data: ImportedTableData, *, table_name: str) -> None:
     """
     Some tables, including Reaction and UserStatus, contain a form
     of foreign key reference to the RealmEmoji table in the form of
@@ -683,7 +685,7 @@
 
 
 def re_map_foreign_keys_many_to_many(
-    data: TableData,
+    data: ImportedTableData,
     table: TableName,
     field_name: Field,
     related_table: TableName,
@@ -733,13 +735,13 @@
     return new_id_list
 
 
-def fix_bitfield_keys(data: TableData, table: TableName, field_name: Field) -> None:
+def fix_bitfield_keys(data: ImportedTableData, table: TableName, field_name: Field) -> None:
     for item in data[table]:
         item[field_name] = item[field_name + "_mask"]
         del item[field_name + "_mask"]
 
 
-def remove_denormalized_recipient_column_from_data(data: TableData) -> None:
+def remove_denormalized_recipient_column_from_data(data: ImportedTableData) -> None:
     """
     The recipient column shouldn't be imported, we'll set the correct
     values when Recipient table gets imported.
@@ -762,7 +764,7 @@
     return model_class._meta.db_table
 
 
-def update_model_ids(model: Any, data: TableData, related_table: TableName) -> None:
+def update_model_ids(model: Any, data: ImportedTableData, related_table: TableName) -> None:
     table = get_db_table(model)
 
     # Important: remapping usermessage rows is
@@ -777,7 +779,7 @@
     re_map_foreign_keys(data, table, "id", related_table=related_table, id_field=True)
 
 
-def bulk_import_user_message_data(data: TableData, dump_file_id: int) -> None:
+def bulk_import_user_message_data(data: ImportedTableData, dump_file_id: int) -> None:
     model = UserMessage
     table = "zerver_usermessage"
     lst = data[table]
@@ -810,7 +812,7 @@
     logging.info("Successfully imported %s from %s[%s].", model, table, dump_file_id)
 
 
-def bulk_import_model(data: TableData, model: Any, dump_file_id: str | None = None) -> None:
+def bulk_import_model(data: ImportedTableData, model: Any, dump_file_id: str | None = None) -> None:
     table = get_db_table(model)
     # TODO, deprecate dump_file_id
     model.objects.bulk_create(model(**item) for item in data[table])
@@ -820,7 +822,7 @@
     logging.info("Successfully imported %s from %s[%s].", model, table, dump_file_id)
 
 
-def bulk_import_named_user_groups(data: TableData) -> None:
+def bulk_import_named_user_groups(data: ImportedTableData) -> None:
     vals = [
         (
             group["usergroup_ptr_id"],
@@ -854,7 +856,7 @@
 # correctly import multiple realms into the same server, we need to
 # check if a Client object already exists, and so we need to support
 # remap all Client IDs to the values in the new DB.
-def bulk_import_client(data: TableData, model: Any, table: TableName) -> None:
+def bulk_import_client(data: ImportedTableData, model: Any, table: TableName) -> None:
     for item in data[table]:
         try:
             client = Client.objects.get(name=item["name"])
@@ -880,7 +882,7 @@ def set_subscriber_count_for_channels(realm: Realm) -> None:
 
 
 def fix_subscriptions_is_user_active_column(
-    data: TableData, user_profiles: list[UserProfile], crossrealm_user_ids: set[int]
+    data: ImportedTableData, user_profiles: list[UserProfile], crossrealm_user_ids: set[int]
 ) -> None:
     table = get_db_table(Subscription)
     user_id_to_active_status = {user.id: user.is_active for user in user_profiles}
@@ -1156,7 +1158,7 @@
             future.result()
 
 
-def disable_restricted_authentication_methods(data: TableData) -> None:
+def disable_restricted_authentication_methods(data: ImportedTableData) -> None:
     """
     Should run only with settings.BILLING_ENABLED. Ensures that we only
     enable authentication methods that are available without needing a plan.
@@ -2105,7 +2107,7 @@
         dump_file_id += 1
 
 
-def import_attachments(data: TableData) -> None:
+def import_attachments(data: ImportedTableData) -> None:
     # Clean up the data in zerver_attachment that is not
     # relevant to our many-to-many import.
     fix_datetime_fields(data, "zerver_attachment")
@@ -2146,7 +2148,7 @@
     ]
 
     # Create our table data for insert.
-    m2m_data: TableData = {m2m_table_name: m2m_rows}
+    m2m_data: ImportedTableData = {m2m_table_name: m2m_rows}
     convert_to_id_fields(m2m_data, m2m_table_name, parent_singular)
     convert_to_id_fields(m2m_data, m2m_table_name, child_singular)
     m2m_rows = m2m_data[m2m_table_name]
@@ -2193,7 +2195,7 @@
     logging.info("Successfully imported M2M table %s", m2m_table_name)
 
 
-def fix_attachments_data(attachment_data: TableData) -> None:
+def fix_attachments_data(attachment_data: ImportedTableData) -> None:
     for attachment in attachment_data["zerver_attachment"]:
         attachment["path_id"] = path_maps["old_attachment_path_to_new_path"][attachment["path_id"]]
 
@@ -2205,7 +2207,7 @@
         attachment["content_type"] = guessed_content_type
 
 
-def create_image_attachments(realm: Realm, attachment_data: TableData) -> None:
+def create_image_attachments(realm: Realm, attachment_data: ImportedTableData) -> None:
     for attachment in attachment_data["zerver_attachment"]:
         if attachment["content_type"] not in THUMBNAIL_ACCEPT_IMAGE_TYPES:
             continue
diff --git a/zerver/tests/test_import_export.py b/zerver/tests/test_import_export.py
index 472be41c14..9dd2d9a0a0 100644
--- a/zerver/tests/test_import_export.py
+++ b/zerver/tests/test_import_export.py
@@ -849,11 +849,13 @@ class RealmImportExportTest(ExportFile):
         # Consented users:
         hamlet = self.example_user("hamlet")
         othello = self.example_user("othello")
+        cordelia = self.example_user("cordelia")
         # Iago will be non-consenting.
         iago = self.example_user("iago")
 
         do_change_user_setting(hamlet, "allow_private_data_export", True, acting_user=None)
         do_change_user_setting(othello, "allow_private_data_export", True, acting_user=None)
+        do_change_user_setting(cordelia, "allow_private_data_export", True, acting_user=None)
         do_change_user_setting(iago, "allow_private_data_export", False, acting_user=None)
 
         # Despite both hamlet and othello having consent enabled, in a public export
@@ -865,6 +867,9 @@
         a_message.sending_client = private_client
         a_message.save()
 
+        # Verify that a group DM between consenting users is not exported
+        self.send_group_direct_message(hamlet, [othello, cordelia])
+
         # SavedSnippets are private content - so in a public export, despite
         # hamlet having consent enabled, such objects should not be exported.
         saved_snippet = do_create_saved_snippet("test", "test", hamlet)
@@ -888,6 +893,9 @@
         exported_user_presence_ids = self.get_set(realm_data["zerver_userpresence"], "id")
         self.assertIn(iago_presence.id, exported_user_presence_ids)
 
+        exported_huddle_ids = self.get_set(realm_data["zerver_huddle"], "id")
+        self.assertEqual(exported_huddle_ids, set())
+
     def test_export_realm_with_member_consent(self) -> None:
         realm = Realm.objects.get(string_id="zulip")
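Taken together, the patch trades re-iterability for bounded memory: exported tables now stream from the database cursor straight to JSON on disk unless a `Config` opts out with `use_iterator=False`. A standalone way to observe the effect (editor's illustration using `tracemalloc`; absolute numbers vary by machine):

```python
import tracemalloc


def rows():
    for i in range(100_000):
        yield {"id": i, "content": "x" * 50}


tracemalloc.start()
total = sum(1 for _ in rows())  # streaming: O(1) records resident at a time
streamed_peak = tracemalloc.get_traced_memory()[1]
tracemalloc.reset_peak()
total = len(list(rows()))  # materialized: O(n) records resident at once
materialized_peak = tracemalloc.get_traced_memory()[1]
print(streamed_peak, materialized_peak)  # the streamed peak is far smaller
```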