From 8b9516fb0b88f7a9dc91956e22d41e8acd3cd99c Mon Sep 17 00:00:00 2001 From: Mateusz Mandera Date: Wed, 12 Mar 2025 00:11:42 +0800 Subject: [PATCH] export: Only export Client objects needed by the data being exported. We shouldn't export the entire Client table - it includes Clients for all the realms on the server, completely unrelated to the realm we're exporting. Since these contain parts of the UserAgents used by the users, we should treat these as private data and only export the Clients that the specific data we're exporting "knows" about. --- zerver/lib/export.py | 63 +++++++++++++++++++++++++++++++++----------- zerver/lib/utils.py | 27 +++++++++++++++++++ 2 files changed, 74 insertions(+), 16 deletions(-) diff --git a/zerver/lib/export.py b/zerver/lib/export.py index 1438785a62..8b6cb33a41 100644 --- a/zerver/lib/export.py +++ b/zerver/lib/export.py @@ -17,13 +17,13 @@ from collections.abc import Callable, Iterable, Mapping from contextlib import suppress from datetime import datetime from functools import cache -from typing import TYPE_CHECKING, Any, Optional, TypeAlias, TypedDict +from typing import TYPE_CHECKING, Any, Optional, TypeAlias, TypedDict, cast from urllib.parse import urlsplit import orjson from django.apps import apps from django.conf import settings -from django.db.models import Exists, OuterRef, Q +from django.db.models import Exists, Model, OuterRef, Q from django.forms.models import model_to_dict from django.utils.timezone import is_naive as timezone_is_naive from django.utils.timezone import now as timezone_now @@ -37,6 +37,7 @@ from zerver.lib.migration_status import MigrationStatusJson, parse_migration_sta from zerver.lib.pysa import mark_sanitized from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.upload.s3 import get_bucket +from zerver.lib.utils import get_fk_field_name from zerver.models import ( AlertWord, Attachment, @@ -511,10 +512,10 @@ class Config: id_source: IdSource | None = None, source_filter: SourceFilter | None = None, include_rows: Field | None = None, - use_all: bool = False, is_seeded: bool = False, exclude: list[Field] | None = None, limit_to_consenting_users: bool | None = None, + collect_client_ids: bool = False, ) -> None: assert table or custom_tables self.table = table @@ -523,7 +524,6 @@ class Config: self.virtual_parent = virtual_parent self.filter_args = filter_args self.include_rows = include_rows - self.use_all = use_all self.is_seeded = is_seeded self.exclude = exclude self.custom_fetch = custom_fetch @@ -533,6 +533,7 @@ class Config: self.id_source = id_source self.source_filter = source_filter self.limit_to_consenting_users = limit_to_consenting_users + self.collect_client_ids = collect_client_ids self.children: list[Config] = [] if self.include_rows: @@ -549,6 +550,15 @@ class Config: """ ) + if self.collect_client_ids: + raise AssertionError( + """ + If you're using custom_fetch with collect_client_ids, you need to + extend the related logic to handle how to collect Client ids with your + customer fetcher. + """ + ) + if normal_parent is not None: self.parent: Config | None = normal_parent else: @@ -674,11 +684,6 @@ def export_from_config( assert table is not None response[table] = data - elif config.use_all: - assert model is not None - query = model.objects.all() - rows = list(query) - elif config.normal_parent: # In this mode, our current model is figuratively Article, # and normal_parent is figuratively Blog, and @@ -779,6 +784,14 @@ def export_from_config( if rows is not None: assert table is not None # Hint for mypy response[table] = make_raw(rows, exclude=config.exclude) + if config.collect_client_ids and "collected_client_ids_set" in context: + model = cast(type[Model], model) + assert issubclass(model, Model) + client_id_field_name = get_fk_field_name(model, Client) + assert client_id_field_name is not None + context["collected_client_ids_set"].update( + {row[client_id_field_name] for row in response[table]} + ) # Post-process rows for t in exported_tables: @@ -879,13 +892,6 @@ def get_realm_config() -> Config: include_rows="realm_id__in", ) - Config( - table="zerver_client", - model=Client, - virtual_parent=realm_config, - use_all=True, - ) - Config( table="zerver_realmuserdefault", model=RealmUserDefault, @@ -1127,6 +1133,7 @@ def add_user_profile_child_configs(user_profile_config: Config) -> None: normal_parent=user_profile_config, include_rows="user_profile_id__in", limit_to_consenting_users=True, + collect_client_ids=True, ) Config( @@ -1156,6 +1163,7 @@ def add_user_profile_child_configs(user_profile_config: Config) -> None: normal_parent=user_profile_config, include_rows="user_profile_id__in", limit_to_consenting_users=True, + collect_client_ids=True, ) Config( @@ -1291,6 +1299,11 @@ def fetch_reaction_data(response: TableData, message_ids: set[int]) -> None: response["zerver_reaction"] = make_raw(list(query)) +def fetch_client_data(response: TableData, client_ids: set[int]) -> None: + query = Client.objects.filter(id__in=list(client_ids)) + response["zerver_client"] = make_raw(list(query)) + + def custom_fetch_direct_message_groups(response: TableData, context: Context) -> None: realm = context["realm"] user_profile_ids = { @@ -1451,6 +1464,7 @@ def export_partial_message_files( realm: Realm, response: TableData, export_type: int, + collected_client_ids: set[int], chunk_size: int = MESSAGE_BATCH_CHUNK_SIZE, output_dir: Path | None = None, ) -> set[int]: @@ -1608,6 +1622,7 @@ def export_partial_message_files( message_id_chunks=message_id_chunks, output_dir=output_dir, user_profile_ids=user_ids_for_us, + collected_client_ids=collected_client_ids, ) return all_message_ids @@ -1619,6 +1634,7 @@ def write_message_partials( message_id_chunks: list[list[int]], output_dir: Path, user_profile_ids: set[int], + collected_client_ids: set[int], ) -> None: dump_file_id = 1 @@ -1627,6 +1643,9 @@ def write_message_partials( actual_query = Message.objects.filter(id__in=message_id_chunk).order_by("id") message_chunk = make_raw(actual_query) + for row in message_chunk: + collected_client_ids.add(row["sending_client"]) + # Figure out the name of our shard file. message_filename = os.path.join(output_dir, f"messages-{dump_file_id:06}.json") message_filename += ".partial" @@ -2164,6 +2183,12 @@ def do_export_realm( create_soft_link(source=output_dir, in_progress=True) exportable_scheduled_message_ids = get_exportable_scheduled_message_ids(realm, export_type) + collected_client_ids = set( + ScheduledMessage.objects.filter(id__in=exportable_scheduled_message_ids) + .order_by("sending_client_id") + .distinct("sending_client_id") + .values_list("sending_client_id", flat=True) + ) logging.info("Exporting data from get_realm_config()...") export_from_config( @@ -2175,6 +2200,7 @@ def do_export_realm( export_type=export_type, exportable_user_ids=exportable_user_ids, exportable_scheduled_message_ids=exportable_scheduled_message_ids, + collected_client_ids_set=collected_client_ids, ), ) logging.info("...DONE with get_realm_config() data") @@ -2192,6 +2218,7 @@ def do_export_realm( response, export_type=export_type, output_dir=output_dir, + collected_client_ids=collected_client_ids, ) logging.info("%d messages were exported", len(message_ids)) @@ -2200,6 +2227,10 @@ def do_export_realm( fetch_reaction_data(response=zerver_reaction, message_ids=message_ids) response.update(zerver_reaction) + zerver_client: TableData = {} + fetch_client_data(response=zerver_client, client_ids=collected_client_ids) + response.update(zerver_client) + # Override the "deactivated" flag on the realm if export_as_active is not None: response["zerver_realm"][0]["deactivated"] = not export_as_active diff --git a/zerver/lib/utils.py b/zerver/lib/utils.py index ea16c7d0e0..3d27b68be8 100644 --- a/zerver/lib/utils.py +++ b/zerver/lib/utils.py @@ -3,6 +3,8 @@ import secrets from collections.abc import Callable from typing import TypeVar +from django.db import models + T = TypeVar("T") @@ -41,3 +43,28 @@ def optional_bytes_to_mib(value: int | None) -> int | None: return None else: return value >> 20 + + +def get_fk_field_name(model: type[models.Model], related_model: type[models.Model]) -> str | None: + """ + Get the name of the foreign key field in model, pointing the related_model table. + Returns None if there is no such field. + + Example usage: + + get_fk_field_name(UserProfile, Realm) returns "realm" + """ + + fields = model._meta.get_fields() + foreign_key_fields_to_related_model = [ + field + for field in fields + if hasattr(field, "related_model") and field.related_model == related_model + ] + + if len(foreign_key_fields_to_related_model) == 0: + return None + + assert len(foreign_key_fields_to_related_model) == 1 + + return foreign_key_fields_to_related_model[0].name