diff --git a/zerver/lib/export.py b/zerver/lib/export.py index 1438785a62..8b6cb33a41 100644 --- a/zerver/lib/export.py +++ b/zerver/lib/export.py @@ -17,13 +17,13 @@ from collections.abc import Callable, Iterable, Mapping from contextlib import suppress from datetime import datetime from functools import cache -from typing import TYPE_CHECKING, Any, Optional, TypeAlias, TypedDict +from typing import TYPE_CHECKING, Any, Optional, TypeAlias, TypedDict, cast from urllib.parse import urlsplit import orjson from django.apps import apps from django.conf import settings -from django.db.models import Exists, OuterRef, Q +from django.db.models import Exists, Model, OuterRef, Q from django.forms.models import model_to_dict from django.utils.timezone import is_naive as timezone_is_naive from django.utils.timezone import now as timezone_now @@ -37,6 +37,7 @@ from zerver.lib.migration_status import MigrationStatusJson, parse_migration_sta from zerver.lib.pysa import mark_sanitized from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.upload.s3 import get_bucket +from zerver.lib.utils import get_fk_field_name from zerver.models import ( AlertWord, Attachment, @@ -511,10 +512,10 @@ class Config: id_source: IdSource | None = None, source_filter: SourceFilter | None = None, include_rows: Field | None = None, - use_all: bool = False, is_seeded: bool = False, exclude: list[Field] | None = None, limit_to_consenting_users: bool | None = None, + collect_client_ids: bool = False, ) -> None: assert table or custom_tables self.table = table @@ -523,7 +524,6 @@ class Config: self.virtual_parent = virtual_parent self.filter_args = filter_args self.include_rows = include_rows - self.use_all = use_all self.is_seeded = is_seeded self.exclude = exclude self.custom_fetch = custom_fetch @@ -533,6 +533,7 @@ class Config: self.id_source = id_source self.source_filter = source_filter self.limit_to_consenting_users = limit_to_consenting_users + self.collect_client_ids = collect_client_ids self.children: list[Config] = [] if self.include_rows: @@ -549,6 +550,15 @@ class Config: """ ) + if self.collect_client_ids: + raise AssertionError( + """ + If you're using custom_fetch with collect_client_ids, you need to + extend the related logic to handle how to collect Client ids with your + customer fetcher. + """ + ) + if normal_parent is not None: self.parent: Config | None = normal_parent else: @@ -674,11 +684,6 @@ def export_from_config( assert table is not None response[table] = data - elif config.use_all: - assert model is not None - query = model.objects.all() - rows = list(query) - elif config.normal_parent: # In this mode, our current model is figuratively Article, # and normal_parent is figuratively Blog, and @@ -779,6 +784,14 @@ def export_from_config( if rows is not None: assert table is not None # Hint for mypy response[table] = make_raw(rows, exclude=config.exclude) + if config.collect_client_ids and "collected_client_ids_set" in context: + model = cast(type[Model], model) + assert issubclass(model, Model) + client_id_field_name = get_fk_field_name(model, Client) + assert client_id_field_name is not None + context["collected_client_ids_set"].update( + {row[client_id_field_name] for row in response[table]} + ) # Post-process rows for t in exported_tables: @@ -879,13 +892,6 @@ def get_realm_config() -> Config: include_rows="realm_id__in", ) - Config( - table="zerver_client", - model=Client, - virtual_parent=realm_config, - use_all=True, - ) - Config( table="zerver_realmuserdefault", model=RealmUserDefault, @@ -1127,6 +1133,7 @@ def add_user_profile_child_configs(user_profile_config: Config) -> None: normal_parent=user_profile_config, include_rows="user_profile_id__in", limit_to_consenting_users=True, + collect_client_ids=True, ) Config( @@ -1156,6 +1163,7 @@ def add_user_profile_child_configs(user_profile_config: Config) -> None: normal_parent=user_profile_config, include_rows="user_profile_id__in", limit_to_consenting_users=True, + collect_client_ids=True, ) Config( @@ -1291,6 +1299,11 @@ def fetch_reaction_data(response: TableData, message_ids: set[int]) -> None: response["zerver_reaction"] = make_raw(list(query)) +def fetch_client_data(response: TableData, client_ids: set[int]) -> None: + query = Client.objects.filter(id__in=list(client_ids)) + response["zerver_client"] = make_raw(list(query)) + + def custom_fetch_direct_message_groups(response: TableData, context: Context) -> None: realm = context["realm"] user_profile_ids = { @@ -1451,6 +1464,7 @@ def export_partial_message_files( realm: Realm, response: TableData, export_type: int, + collected_client_ids: set[int], chunk_size: int = MESSAGE_BATCH_CHUNK_SIZE, output_dir: Path | None = None, ) -> set[int]: @@ -1608,6 +1622,7 @@ def export_partial_message_files( message_id_chunks=message_id_chunks, output_dir=output_dir, user_profile_ids=user_ids_for_us, + collected_client_ids=collected_client_ids, ) return all_message_ids @@ -1619,6 +1634,7 @@ def write_message_partials( message_id_chunks: list[list[int]], output_dir: Path, user_profile_ids: set[int], + collected_client_ids: set[int], ) -> None: dump_file_id = 1 @@ -1627,6 +1643,9 @@ def write_message_partials( actual_query = Message.objects.filter(id__in=message_id_chunk).order_by("id") message_chunk = make_raw(actual_query) + for row in message_chunk: + collected_client_ids.add(row["sending_client"]) + # Figure out the name of our shard file. message_filename = os.path.join(output_dir, f"messages-{dump_file_id:06}.json") message_filename += ".partial" @@ -2164,6 +2183,12 @@ def do_export_realm( create_soft_link(source=output_dir, in_progress=True) exportable_scheduled_message_ids = get_exportable_scheduled_message_ids(realm, export_type) + collected_client_ids = set( + ScheduledMessage.objects.filter(id__in=exportable_scheduled_message_ids) + .order_by("sending_client_id") + .distinct("sending_client_id") + .values_list("sending_client_id", flat=True) + ) logging.info("Exporting data from get_realm_config()...") export_from_config( @@ -2175,6 +2200,7 @@ def do_export_realm( export_type=export_type, exportable_user_ids=exportable_user_ids, exportable_scheduled_message_ids=exportable_scheduled_message_ids, + collected_client_ids_set=collected_client_ids, ), ) logging.info("...DONE with get_realm_config() data") @@ -2192,6 +2218,7 @@ def do_export_realm( response, export_type=export_type, output_dir=output_dir, + collected_client_ids=collected_client_ids, ) logging.info("%d messages were exported", len(message_ids)) @@ -2200,6 +2227,10 @@ def do_export_realm( fetch_reaction_data(response=zerver_reaction, message_ids=message_ids) response.update(zerver_reaction) + zerver_client: TableData = {} + fetch_client_data(response=zerver_client, client_ids=collected_client_ids) + response.update(zerver_client) + # Override the "deactivated" flag on the realm if export_as_active is not None: response["zerver_realm"][0]["deactivated"] = not export_as_active diff --git a/zerver/lib/utils.py b/zerver/lib/utils.py index ea16c7d0e0..3d27b68be8 100644 --- a/zerver/lib/utils.py +++ b/zerver/lib/utils.py @@ -3,6 +3,8 @@ import secrets from collections.abc import Callable from typing import TypeVar +from django.db import models + T = TypeVar("T") @@ -41,3 +43,28 @@ def optional_bytes_to_mib(value: int | None) -> int | None: return None else: return value >> 20 + + +def get_fk_field_name(model: type[models.Model], related_model: type[models.Model]) -> str | None: + """ + Get the name of the foreign key field in model, pointing the related_model table. + Returns None if there is no such field. + + Example usage: + + get_fk_field_name(UserProfile, Realm) returns "realm" + """ + + fields = model._meta.get_fields() + foreign_key_fields_to_related_model = [ + field + for field in fields + if hasattr(field, "related_model") and field.related_model == related_model + ] + + if len(foreign_key_fields_to_related_model) == 0: + return None + + assert len(foreign_key_fields_to_related_model) == 1 + + return foreign_key_fields_to_related_model[0].name