export: Move all queries, when possible, to iterators.

This reduces overall memory usage for large exports.
Author: Alex Vandiver
Date: 2025-10-01 15:03:15 +00:00
Committed by: Tim Abbott
parent 0ffc0e810c
commit 755cb7d854

3 changed files with 252 additions and 232 deletions
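The pattern applied throughout the diff below, in rough form: instead of materializing a Django query into a list of dicts, walk it with QuerySet.iterator() and yield one record at a time, so the JSON writer can drain rows as they arrive. A minimal sketch, not code from this commit (the function name is a placeholder):

from collections.abc import Iterator
from typing import Any

from django.db.models import QuerySet
from django.forms.models import model_to_dict


def stream_records(query: QuerySet[Any]) -> Iterator[dict[str, Any]]:
    # .iterator() fetches rows in chunks via a server-side cursor rather
    # than loading the whole result set into memory at once.
    for instance in query.iterator():
        yield model_to_dict(instance)

The rest of the commit threads this shape through TableData, whose values may now be Iterator[Record] as well as plain lists.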


@@ -19,7 +19,7 @@ from collections.abc import Callable, Iterable, Iterator, Mapping
from datetime import datetime
from email.headerregistry import Address
from functools import cache
from itertools import islice
from itertools import chain, islice
from typing import TYPE_CHECKING, Any, Optional, TypeAlias, TypedDict, TypeVar, cast
from urllib.parse import urlsplit
@@ -27,7 +27,7 @@ import orjson
from django.apps import apps
from django.conf import settings
from django.db import connection
from django.db.models import Exists, Model, OuterRef, Q
from django.db.models import Exists, Model, OuterRef, Q, QuerySet
from django.forms.models import model_to_dict
from django.utils.timezone import is_naive as timezone_is_naive
from django.utils.timezone import now as timezone_now
@@ -98,7 +98,7 @@ if TYPE_CHECKING:
# Custom mypy types follow:
Record: TypeAlias = dict[str, Any]
TableName = str
TableData: TypeAlias = dict[TableName, list[Record]]
TableData: TypeAlias = dict[TableName, Iterator[Record] | list[Record]]
Field = str
Path = str
Context: TypeAlias = dict[str, Any]
@@ -108,11 +108,11 @@ SourceFilter: TypeAlias = Callable[[Record], bool]
CustomFetch: TypeAlias = Callable[[TableData, Context], None]
CustomReturnIds: TypeAlias = Callable[[TableData], set[int]]
CustomProcessResults: TypeAlias = Callable[[list[Record], Context], list[Record]]
CustomProcessResults: TypeAlias = Callable[[Iterable[Record], Context], Iterator[Record]]
class MessagePartial(TypedDict):
zerver_message: list[Record]
zerver_message: Iterable[Record]
zerver_userprofile_ids: list[int]
realm_id: int
@@ -516,12 +516,11 @@ def write_records_json_file(output_dir: str, records: Iterable[dict[str, Any]])
logging.info("Finished writing %s", output_file)
def make_raw(query: Any, exclude: list[Field] | None = None) -> list[Record]:
def make_raw(query: Iterable[Any], exclude: list[Field] | None = None) -> Iterator[Record]:
"""
Takes a Django query and returns a JSONable list
of dictionaries corresponding to the database rows.
"""
rows = []
for instance in query:
data = model_to_dict(instance, exclude=exclude)
"""
@@ -536,20 +535,19 @@ def make_raw(query: Any, exclude: list[Field] | None = None) -> list[Record]:
value = data[field.name]
data[field.name] = [row.id for row in value]
rows.append(data)
return rows
yield data
def floatify_datetime_fields(data: TableData, table: TableName) -> None:
for item in data[table]:
for field in DATE_FIELDS[table]:
dt = item[field]
if dt is None:
continue
assert isinstance(dt, datetime)
assert not timezone_is_naive(dt)
item[field] = dt.timestamp()
def floatify_datetime_fields(item: Record, table: TableName) -> Record:
updates = {}
for field in DATE_FIELDS[table]:
dt = item[field]
if dt is None:
continue
assert isinstance(dt, datetime)
assert not timezone_is_naive(dt)
updates[field] = dt.timestamp()
return {**item, **updates}
class Config:
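With make_raw() now a generator and floatify_datetime_fields() a pure per-record transform, the two compose lazily; nothing touches the database until a consumer iterates. A rough sketch of that composition (the helper name is invented; make_raw, floatify_datetime_fields, and Record are the definitions above):

from collections.abc import Iterator
from typing import Any

from django.db.models import QuerySet


def lazy_table(query: QuerySet[Any], table: str) -> Iterator[Record]:
    # make_raw yields model_to_dict() rows; the generator expression then
    # converts datetime fields to float timestamps one record at a time.
    raw = make_raw(query.iterator())
    return (floatify_datetime_fields(row, table) for row in raw)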
@@ -587,6 +585,7 @@ class Config:
exclude: list[Field] | None = None,
limit_to_consenting_users: bool | None = None,
collect_client_ids: bool = False,
use_iterator: bool = True,
) -> None:
assert table or custom_tables
self.table = table
@@ -698,11 +697,17 @@ class Config:
assert include_rows in ["user_profile_id__in", "user_id__in", "bot_profile_id__in"]
assert normal_parent is not None and normal_parent.table == "zerver_userprofile"
if self.collect_client_ids or self.is_seeded:
self.use_iterator = False
else:
self.use_iterator = use_iterator
def return_ids(self, response: TableData) -> set[int]:
if self.custom_return_ids is not None:
return self.custom_return_ids(response)
else:
assert self.table is not None
assert not self.use_iterator, self.table
return {row["id"] for row in response[self.table]}
@@ -730,7 +735,8 @@ def export_from_config(
for t in exported_tables:
logging.info("Exporting via export_from_config: %s", t)
rows = None
rows: Iterable[Any] | None = None
query: QuerySet[Any] | None = None
if config.is_seeded:
rows = [seed_object]
@@ -748,13 +754,13 @@ def export_from_config(
# When we concat_and_destroy, we are working with
# temporary "tables" that are lists of records that
# should already be ready to export.
data: list[Record] = []
for t in config.concat_and_destroy:
data += response[t]
del response[t]
logging.info("Deleted temporary %s", t)
assert table is not None
response[table] = data
# We pop them off of the response and store them in a local
# which the iterable closes over
tables = {t: response.pop(t) for t in config.concat_and_destroy}
response[table] = chain.from_iterable(tables[t] for t in config.concat_and_destroy)
for t in config.concat_and_destroy:
logging.info("Deleted temporary %s", t)
elif config.normal_parent:
# In this mode, our current model is figuratively Article,
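Since the temporary "_huddle"-style tables may themselves be iterators, concat_and_destroy pops them into a local dict and concatenates them lazily with itertools.chain.from_iterable instead of appending lists. The pattern in isolation:

from itertools import chain

# Sketch: lazily concatenating several per-table iterators.
tables = {"_a": iter([{"id": 1}]), "_b": iter([{"id": 2}, {"id": 3}])}
combined = chain.from_iterable(tables[name] for name in ("_a", "_b"))
assert [row["id"] for row in combined] == [1, 2, 3]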
@@ -818,7 +824,7 @@ def export_from_config(
assert model is not None
try:
query = model.objects.filter(**filter_params)
query = model.objects.filter(**filter_params).order_by("id")
except Exception:
print(
f"""
@@ -833,8 +839,6 @@ def export_from_config(
)
raise
rows = list(query)
elif config.id_source:
# In this mode, we are the figurative Blog, and we now
# need to look at the current response to get all the
@@ -843,6 +847,8 @@ def export_from_config(
assert model is not None
# This will be a tuple of the form ('zerver_article', 'blog').
(child_table, field) = config.id_source
assert config.virtual_parent is not None
assert not config.virtual_parent.use_iterator
child_rows = response[child_table]
if config.source_filter:
child_rows = [r for r in child_rows if config.source_filter(r)]
@@ -851,12 +857,17 @@ def export_from_config(
if config.filter_args:
filter_params.update(config.filter_args)
query = model.objects.filter(**filter_params)
rows = list(query)
if query is not None:
rows = query.iterator()
if rows is not None:
assert table is not None # Hint for mypy
response[table] = make_raw(rows, exclude=config.exclude)
if config.collect_client_ids and "collected_client_ids_set" in context:
# If we need to collect the client-ids, we can't just stream the results
response[table] = list(response[table])
model = cast(type[Model], model)
assert issubclass(model, Model)
client_id_field_name = get_fk_field_name(model, Client)
@@ -873,8 +884,10 @@ def export_from_config(
# of the exported data for the tables - e.g. to strip out private data.
response[t] = custom_process_results(response[t], context)
if t in DATE_FIELDS:
floatify_datetime_fields(response, t)
response[t] = (floatify_datetime_fields(r, t) for r in response[t])
if not config.use_iterator:
response[t] = list(response[t])
# Now walk our children. It's extremely important to respect
# the order of children here.
for child_config in config.children:
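The per-table pipeline stays lazy only when the Config allows it; with use_iterator=False the table is materialized into a list so later steps (return_ids(), id_source lookups by child configs) can read it more than once. Roughly, as a standalone sketch:

from collections.abc import Iterator


def finalize_table(rows: Iterator[dict], use_iterator: bool) -> Iterator[dict] | list[dict]:
    # An iterator can be drained exactly once; configs whose rows must be
    # re-read (e.g. to collect ids for child configs) force a list here.
    return rows if use_iterator else list(rows)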
@@ -998,6 +1011,7 @@ def get_realm_config() -> Config:
table="zerver_userprofile",
virtual_parent=realm_config,
custom_fetch=custom_fetch_user_profile,
use_iterator=False,
)
user_groups_config = Config(
@@ -1006,6 +1020,7 @@ def get_realm_config() -> Config:
normal_parent=realm_config,
include_rows="realm_id__in",
exclude=["direct_members", "direct_subgroups"],
use_iterator=False,
)
Config(
@@ -1078,6 +1093,7 @@ def get_realm_config() -> Config:
# It is just "glue" data for internal data model consistency purposes
# with no user-specific information.
limit_to_consenting_users=False,
use_iterator=False,
)
Config(
@@ -1092,6 +1108,7 @@ def get_realm_config() -> Config:
model=Stream,
normal_parent=realm_config,
include_rows="realm_id__in",
use_iterator=False,
)
stream_recipient_config = Config(
@@ -1100,6 +1117,7 @@ def get_realm_config() -> Config:
normal_parent=stream_config,
include_rows="type_id__in",
filter_args={"type": Recipient.STREAM},
use_iterator=False,
)
Config(
@@ -1131,6 +1149,7 @@ def get_realm_config() -> Config:
"_stream_recipient",
"_huddle_recipient",
],
use_iterator=False,
)
Config(
@@ -1157,11 +1176,12 @@ def get_realm_config() -> Config:
def custom_process_subscription_in_realm_config(
subscriptions: list[Record], context: Context
) -> list[Record]:
subscriptions: Iterable[Record], context: Context
) -> Iterator[Record]:
export_type = context["export_type"]
if export_type == RealmExport.EXPORT_FULL_WITHOUT_CONSENT:
return subscriptions
yield from subscriptions
return
exportable_user_ids_from_context = context["exportable_user_ids"]
if export_type == RealmExport.EXPORT_FULL_WITH_CONSENT:
@@ -1172,9 +1192,10 @@ def custom_process_subscription_in_realm_config(
assert exportable_user_ids_from_context is None
consented_user_ids = set()
def scrub_subscription_if_needed(subscription: Record) -> Record:
for subscription in subscriptions:
if subscription["user_profile"] in consented_user_ids:
return subscription
yield subscription
continue
# We create a replacement Subscription, setting only the essential fields,
# while allowing all the other ones to fall back to the defaults
# defined in the model.
@@ -1190,10 +1211,7 @@ def custom_process_subscription_in_realm_config(
color=random.choice(STREAM_ASSIGNMENT_COLORS),
)
subscription_dict = model_to_dict(scrubbed_subscription)
return subscription_dict
processed_rows = map(scrub_subscription_if_needed, subscriptions)
return list(processed_rows)
yield subscription_dict
def add_user_profile_child_configs(user_profile_config: Config) -> None:
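A CustomProcessResults hook is now typed as taking any iterable of records and returning an iterator, so hooks like the subscription scrubber above are written as generators that yield rows as they decide whether to scrub them. A toy example of that shape (the context key and the scrubbing itself are illustrative, not the real logic):

from collections.abc import Iterable, Iterator


def scrub_non_consenting(rows: Iterable[dict], context: dict) -> Iterator[dict]:
    consented: set[int] = context["consented_user_ids"]  # assumed context key
    for row in rows:
        if row["user_profile"] in consented:
            yield row
        else:
            # Placeholder scrubbing: keep only the structural fields.
            yield {k: row[k] for k in ("id", "user_profile", "recipient")}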
@@ -1410,7 +1428,7 @@ def custom_fetch_user_profile(response: TableData, context: Context) -> None:
def custom_fetch_user_profile_cross_realm(response: TableData, context: Context) -> None:
realm = context["realm"]
response["zerver_userprofile_crossrealm"] = []
crossrealm_bots = []
bot_name_to_default_email = {
"NOTIFICATION_BOT": "notification-bot@zulip.com",
@@ -1432,13 +1450,14 @@ def custom_fetch_user_profile_cross_realm(response: TableData, context: Context)
bot_user_id = get_system_bot(bot_email, internal_realm.id).id
recipient_id = Recipient.objects.get(type_id=bot_user_id, type=Recipient.PERSONAL).id
response["zerver_userprofile_crossrealm"].append(
crossrealm_bots.append(
dict(
email=bot_default_email,
id=bot_user_id,
recipient_id=recipient_id,
)
),
)
response["zerver_userprofile_crossrealm"] = crossrealm_bots
def fetch_attachment_data(
@@ -1448,10 +1467,21 @@ def fetch_attachment_data(
Attachment.objects.filter(
Q(messages__in=message_ids) | Q(scheduled_messages__in=scheduled_message_ids),
realm_id=realm_id,
).distinct()
)
.distinct("path_id")
.order_by("path_id")
)
response["zerver_attachment"] = make_raw(attachments)
floatify_datetime_fields(response, "zerver_attachment")
def postprocess_attachment(row: Record) -> Record:
row = floatify_datetime_fields(row, "zerver_attachment")
filtered_message_ids = set(row["messages"]).intersection(message_ids)
row["messages"] = sorted(filtered_message_ids)
filtered_scheduled_message_ids = set(row["scheduled_messages"]).intersection(
scheduled_message_ids
)
row["scheduled_messages"] = sorted(filtered_scheduled_message_ids)
return row
"""
We usually export most messages for the realm, but not
@@ -1461,14 +1491,7 @@ def fetch_attachment_data(
Same reasoning applies to scheduled_messages.
"""
for row in response["zerver_attachment"]:
filtered_message_ids = set(row["messages"]).intersection(message_ids)
row["messages"] = sorted(filtered_message_ids)
filtered_scheduled_message_ids = set(row["scheduled_messages"]).intersection(
scheduled_message_ids
)
row["scheduled_messages"] = sorted(filtered_scheduled_message_ids)
response["zerver_attachment"] = (postprocess_attachment(r) for r in make_raw(attachments))
return attachments
@@ -1480,18 +1503,17 @@ def custom_fetch_realm_audit_logs_for_user(response: TableData, context: Context
"""
user = context["user"]
query = RealmAuditLog.objects.filter(Q(modified_user_id=user.id) | Q(acting_user_id=user.id))
rows = make_raw(list(query))
response["zerver_realmauditlog"] = rows
response["zerver_realmauditlog"] = make_raw(query.iterator())
def fetch_reaction_data(response: TableData, message_ids: set[int]) -> None:
query = Reaction.objects.filter(message_id__in=list(message_ids))
response["zerver_reaction"] = make_raw(list(query))
response["zerver_reaction"] = make_raw(query.iterator())
def fetch_client_data(response: TableData, client_ids: set[int]) -> None:
query = Client.objects.filter(id__in=list(client_ids))
response["zerver_client"] = make_raw(list(query))
response["zerver_client"] = make_raw(query.iterator())
def custom_fetch_direct_message_groups(response: TableData, context: Context) -> None:
@@ -1509,7 +1531,9 @@ def custom_fetch_direct_message_groups(response: TableData, context: Context) ->
consented_user_ids = set()
user_profile_ids = {
r["id"] for r in response["zerver_userprofile"] + response["zerver_userprofile_mirrordummy"]
r["id"]
for r in list(response["zerver_userprofile"])
+ list(response["zerver_userprofile_mirrordummy"])
}
recipient_filter = Q()
@@ -1575,12 +1599,6 @@ def custom_fetch_direct_message_groups(response: TableData, context: Context) ->
response["zerver_huddle"] = make_raw(
DirectMessageGroup.objects.filter(id__in=direct_message_group_ids)
)
if export_type == RealmExport.EXPORT_PUBLIC and any(
response[t] for t in ["_huddle_recipient", "_huddle_subscription", "zerver_huddle"]
):
raise AssertionError(
"Public export should not result in exporting any data in _huddle tables"
)
def custom_fetch_scheduled_messages(response: TableData, context: Context) -> None:
@@ -1591,7 +1609,7 @@ def custom_fetch_scheduled_messages(response: TableData, context: Context) -> No
exportable_scheduled_message_ids = context["exportable_scheduled_message_ids"]
query = ScheduledMessage.objects.filter(realm=realm, id__in=exportable_scheduled_message_ids)
rows = make_raw(list(query))
rows = make_raw(query)
response["zerver_scheduledmessage"] = rows
@@ -1654,14 +1672,15 @@ def custom_fetch_realm_audit_logs_for_realm(response: TableData, context: Contex
def custom_fetch_onboarding_usermessage(response: TableData, context: Context) -> None:
realm = context["realm"]
response["zerver_onboardingusermessage"] = []
onboarding = []
onboarding_usermessage_query = OnboardingUserMessage.objects.filter(realm=realm)
for onboarding_usermessage in onboarding_usermessage_query:
onboarding_usermessage_obj = model_to_dict(onboarding_usermessage)
onboarding_usermessage_obj["flags_mask"] = onboarding_usermessage.flags.mask
del onboarding_usermessage_obj["flags"]
response["zerver_onboardingusermessage"].append(onboarding_usermessage_obj)
onboarding.append(onboarding_usermessage_obj)
response["zerver_onboardingusermessage"] = onboarding
def fetch_usermessages(
@@ -1710,7 +1729,8 @@ def export_usermessages_batch(
with open(input_path, "rb") as input_file:
input_data: MessagePartial = orjson.loads(input_file.read())
message_ids = {item["id"] for item in input_data["zerver_message"]}
messages = list(input_data["zerver_message"])
message_ids = {item["id"] for item in messages}
user_profile_ids = set(input_data["zerver_userprofile_ids"])
realm = Realm.objects.get(id=input_data["realm_id"])
zerver_usermessage_data = fetch_usermessages(
@@ -1723,7 +1743,7 @@ def export_usermessages_batch(
)
output_data: TableData = dict(
zerver_message=input_data["zerver_message"],
zerver_message=messages,
zerver_usermessage=zerver_usermessage_data,
)
write_table_data(output_path, output_data)
@@ -1764,9 +1784,9 @@ def export_partial_message_files(
response["zerver_userprofile"],
)
ids_of_our_possible_senders = get_ids(
response["zerver_userprofile"]
+ response["zerver_userprofile_mirrordummy"]
+ response["zerver_userprofile_crossrealm"]
list(response["zerver_userprofile"])
+ list(response["zerver_userprofile_mirrordummy"])
+ list(response["zerver_userprofile_crossrealm"])
)
consented_user_ids: set[int] = set()
@@ -1913,7 +1933,9 @@ def write_message_partials(
for message_id_chunk in message_id_chunks:
# Uses index: zerver_message_pkey
actual_query = Message.objects.filter(id__in=message_id_chunk).order_by("id")
message_chunk = make_raw(actual_query)
message_chunk = [
floatify_datetime_fields(r, "zerver_message") for r in make_raw(actual_query.iterator())
]
for row in message_chunk:
collected_client_ids.add(row["sending_client"])
@@ -1923,16 +1945,11 @@ def write_message_partials(
message_filename += ".partial"
logging.info("Fetched messages for %s", message_filename)
# Clean up our messages.
table_data: TableData = {}
table_data["zerver_message"] = message_chunk
floatify_datetime_fields(table_data, "zerver_message")
# Build up our output for the .partial file, which needs
# a list of user_profile_ids to search for (as well as
# the realm id).
output: MessagePartial = dict(
zerver_message=table_data["zerver_message"],
zerver_message=message_chunk,
zerver_userprofile_ids=list(user_profile_ids),
realm_id=realm.id,
)
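Message partials are the exception to full streaming: each chunk is materialized as a list because it is traversed twice, once to collect sending_client ids and once when the partial file is written. A self-contained sketch of why the chunk cannot stay an iterator:

from collections.abc import Iterator


def materialize_chunk(rows: Iterator[dict], client_ids: set[int]) -> list[dict]:
    # Pass one collects client ids; pass two (the JSON writer) re-reads the
    # same rows, which a one-shot iterator cannot support.
    chunk = list(rows)
    client_ids.update(row["sending_client"] for row in chunk)
    return chunk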
@@ -1945,7 +1962,7 @@ def write_message_partials(
def export_uploads_and_avatars(
realm: Realm,
*,
attachments: list[Attachment] | None = None,
attachments: Iterable[Attachment] | None = None,
user: UserProfile | None,
output_dir: Path,
) -> None:
@@ -1974,7 +1991,7 @@ def export_uploads_and_avatars(
else:
handle_system_bots = False
users = [user]
attachments = list(Attachment.objects.filter(owner_id=user.id))
attachments = list(Attachment.objects.filter(owner_id=user.id).order_by("path_id"))
realm_emojis = list(RealmEmoji.objects.filter(author_id=user.id))
if settings.LOCAL_UPLOADS_DIR:
@@ -2164,7 +2181,6 @@ def export_files_from_s3(
processing_emoji = flavor == "emoji"
bucket = get_bucket(bucket_name)
records = []
logging.info("Downloading %s files from %s", flavor, bucket_name)
@@ -2175,85 +2191,88 @@ def export_files_from_s3(
email_gateway_bot = get_system_bot(settings.EMAIL_GATEWAY_BOT, internal_realm.id)
user_ids.add(email_gateway_bot.id)
count = 0
for bkey in bucket.objects.filter(Prefix=object_prefix):
if valid_hashes is not None and bkey.Object().key not in valid_hashes:
continue
def iterate_attachments() -> Iterator[Record]:
count = 0
for bkey in bucket.objects.filter(Prefix=object_prefix):
# This is promised to be iterated in sorted filename order.
key = bucket.Object(bkey.key)
"""
For very old realms we may not have proper metadata. If you really need
an export to bypass these checks, flip the following flag.
"""
checking_metadata = True
if checking_metadata:
if "realm_id" not in key.metadata:
raise AssertionError(f"Missing realm_id in key metadata: {key.metadata}")
if "user_profile_id" not in key.metadata:
raise AssertionError(f"Missing user_profile_id in key metadata: {key.metadata}")
if int(key.metadata["user_profile_id"]) not in user_ids:
if valid_hashes is not None and bkey.Object().key not in valid_hashes:
continue
# This can happen if an email address has moved realms
if key.metadata["realm_id"] != str(realm.id):
if email_gateway_bot is None or key.metadata["user_profile_id"] != str(
email_gateway_bot.id
):
raise AssertionError(
f"Key metadata problem: {key.key} / {key.metadata} / {realm.id}"
)
# Email gateway bot sends messages, potentially including attachments, cross-realm.
print(f"File uploaded by email gateway bot: {key.key} / {key.metadata}")
key = bucket.Object(bkey.key)
record = _get_exported_s3_record(bucket_name, key, processing_emoji)
"""
For very old realms we may not have proper metadata. If you really need
an export to bypass these checks, flip the following flag.
"""
checking_metadata = True
if checking_metadata:
if "realm_id" not in key.metadata:
raise AssertionError(f"Missing realm_id in key metadata: {key.metadata}")
record["path"] = key.key
_save_s3_object_to_file(key, output_dir, processing_uploads)
if "user_profile_id" not in key.metadata:
raise AssertionError(f"Missing user_profile_id in key metadata: {key.metadata}")
records.append(record)
count += 1
if int(key.metadata["user_profile_id"]) not in user_ids:
continue
if count % 100 == 0:
logging.info("Finished %s", count)
# This can happen if an email address has moved realms
if key.metadata["realm_id"] != str(realm.id):
if email_gateway_bot is None or key.metadata["user_profile_id"] != str(
email_gateway_bot.id
):
raise AssertionError(
f"Key metadata problem: {key.key} / {key.metadata} / {realm.id}"
)
# Email gateway bot sends messages, potentially including attachments, cross-realm.
print(f"File uploaded by email gateway bot: {key.key} / {key.metadata}")
write_records_json_file(output_dir, records)
record = _get_exported_s3_record(bucket_name, key, processing_emoji)
record["path"] = key.key
_save_s3_object_to_file(key, output_dir, processing_uploads)
yield record
count += 1
if count % 100 == 0:
logging.info("Finished %s", count)
write_records_json_file(output_dir, iterate_attachments())
def export_uploads_from_local(
realm: Realm, local_dir: Path, output_dir: Path, attachments: list[Attachment]
realm: Realm, local_dir: Path, output_dir: Path, attachments: Iterable[Attachment]
) -> None:
records = []
for count, attachment in enumerate(attachments, 1):
# Use 'mark_sanitized' to work around false positive caused by Pysa
# thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
# are user controlled
path_id = mark_sanitized(attachment.path_id)
def iterate_attachments() -> Iterator[Record]:
for count, attachment in enumerate(attachments, 1):
# Use 'mark_sanitized' to work around false positive caused by Pysa
# thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
# are user controlled
path_id = mark_sanitized(attachment.path_id)
local_path = os.path.join(local_dir, path_id)
output_path = os.path.join(output_dir, path_id)
local_path = os.path.join(local_dir, path_id)
output_path = os.path.join(output_dir, path_id)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
shutil.copy2(local_path, output_path)
stat = os.stat(local_path)
record = dict(
realm_id=attachment.realm_id,
user_profile_id=attachment.owner.id,
user_profile_email=attachment.owner.email,
s3_path=path_id,
path=path_id,
size=stat.st_size,
last_modified=stat.st_mtime,
content_type=None,
)
records.append(record)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
shutil.copy2(local_path, output_path)
stat = os.stat(local_path)
record = dict(
realm_id=attachment.realm_id,
user_profile_id=attachment.owner.id,
user_profile_email=attachment.owner.email,
s3_path=path_id,
path=path_id,
size=stat.st_size,
last_modified=stat.st_mtime,
content_type=None,
)
yield record
if count % 100 == 0:
logging.info("Finished %s", count)
if count % 100 == 0:
logging.info("Finished %s", count)
write_records_json_file(output_dir, records)
write_records_json_file(output_dir, iterate_attachments())
def export_avatars_from_local(
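Both the S3 and local-disk upload exporters now wrap their per-file loop in a nested generator and hand it straight to write_records_json_file, which (per its signature near the top of this file) accepts any Iterable of dicts, so the records JSON is streamed out rather than accumulated in a records list. A stripped-down version of that shape, with invented record fields:

import logging
import os
from collections.abc import Iterator


def export_local_files(paths: list[str], output_dir: str) -> None:
    def iterate_records() -> Iterator[dict]:
        for count, path in enumerate(paths, 1):
            stat = os.stat(path)
            # One record per exported file; copying the file itself is omitted here.
            yield {"path": path, "size": stat.st_size, "last_modified": stat.st_mtime}
            if count % 100 == 0:
                logging.info("Finished %s", count)

    write_records_json_file(output_dir, iterate_records())  # defined earlier in export.py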
@@ -2337,52 +2356,44 @@ def get_emoji_path(realm_emoji: RealmEmoji) -> str:
def export_emoji_from_local(
realm: Realm, local_dir: Path, output_dir: Path, realm_emojis: list[RealmEmoji]
realm: Realm, local_dir: Path, output_dir: Path, realm_emojis: Iterable[RealmEmoji]
) -> None:
records = []
def emoji_path_tuples() -> Iterator[tuple[RealmEmoji, str]]:
for realm_emoji in realm_emojis:
realm_emoji_path = mark_sanitized(get_emoji_path(realm_emoji))
realm_emoji_helper_tuples: list[tuple[RealmEmoji, str]] = []
for realm_emoji in realm_emojis:
realm_emoji_path = get_emoji_path(realm_emoji)
yield (realm_emoji, realm_emoji_path)
yield (realm_emoji, realm_emoji_path + ".original")
# Use 'mark_sanitized' to work around false positive caused by Pysa
# thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
# are user controlled
realm_emoji_path = mark_sanitized(realm_emoji_path)
def iterate_emoji(
realm_emoji_helper_tuples: Iterator[tuple[RealmEmoji, str]],
) -> Iterator[Record]:
for count, realm_emoji_helper_tuple in enumerate(realm_emoji_helper_tuples, 1):
realm_emoji_object, emoji_path = realm_emoji_helper_tuple
realm_emoji_path_original = realm_emoji_path + ".original"
local_path = os.path.join(local_dir, emoji_path)
output_path = os.path.join(output_dir, emoji_path)
realm_emoji_helper_tuples.append((realm_emoji, realm_emoji_path))
realm_emoji_helper_tuples.append((realm_emoji, realm_emoji_path_original))
os.makedirs(os.path.dirname(output_path), exist_ok=True)
shutil.copy2(local_path, output_path)
# Realm emoji author is optional.
author = realm_emoji_object.author
author_id = author.id if author else None
record = dict(
realm_id=realm.id,
author=author_id,
path=emoji_path,
s3_path=emoji_path,
file_name=realm_emoji_object.file_name,
name=realm_emoji_object.name,
deactivated=realm_emoji_object.deactivated,
)
yield record
for count, realm_emoji_helper_tuple in enumerate(realm_emoji_helper_tuples, 1):
realm_emoji_object, emoji_path = realm_emoji_helper_tuple
if count % 100 == 0:
logging.info("Finished %s", count)
local_path = os.path.join(local_dir, emoji_path)
output_path = os.path.join(output_dir, emoji_path)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
shutil.copy2(local_path, output_path)
# Realm emoji author is optional.
author = realm_emoji_object.author
author_id = None
if author:
author_id = author.id
record = dict(
realm_id=realm.id,
author=author_id,
path=emoji_path,
s3_path=emoji_path,
file_name=realm_emoji_object.file_name,
name=realm_emoji_object.name,
deactivated=realm_emoji_object.deactivated,
)
records.append(record)
if count % 100 == 0:
logging.info("Finished %s", count)
write_records_json_file(output_dir, records)
write_records_json_file(output_dir, iterate_emoji(emoji_path_tuples()))
def do_write_stats_file_for_realm_export(output_dir: Path) -> dict[str, int | dict[str, int]]:
@@ -2502,17 +2513,13 @@ def do_export_realm(
)
logging.info("%d messages were exported", len(message_ids))
# zerver_reaction
zerver_reaction: TableData = {}
fetch_reaction_data(response=zerver_reaction, message_ids=message_ids)
response.update(zerver_reaction)
fetch_reaction_data(response=response, message_ids=message_ids)
zerver_client: TableData = {}
fetch_client_data(response=zerver_client, client_ids=collected_client_ids)
response.update(zerver_client)
fetch_client_data(response=response, client_ids=collected_client_ids)
# Override the "deactivated" flag on the realm
if export_as_active is not None:
assert isinstance(response["zerver_realm"], list)
response["zerver_realm"][0]["deactivated"] = not export_as_active
response["import_source"] = "zulip" # type: ignore[assignment] # this is an extra info field, not TableData
@@ -2619,6 +2626,8 @@ def do_export_user(user_profile: UserProfile, output_dir: Path) -> None:
export_file = os.path.join(output_dir, "user.json")
write_table_data(output_file=export_file, data=response)
# Double-check that this is a list, and not an iterator, so we can run over it again
assert isinstance(response["zerver_reaction"], list)
reaction_message_ids: set[int] = {row["message"] for row in response["zerver_reaction"]}
logging.info("Exporting messages")
@@ -2661,6 +2670,7 @@ def get_single_user_config() -> Config:
# Exports with consent are not relevant in the context of exporting
# a single user.
limit_to_consenting_users=False,
use_iterator=False,
)
# zerver_recipient
@@ -2669,6 +2679,7 @@ def get_single_user_config() -> Config:
model=Recipient,
virtual_parent=subscription_config,
id_source=("zerver_subscription", "recipient"),
use_iterator=False,
)
# zerver_stream
@@ -2713,6 +2724,7 @@ def get_single_user_config() -> Config:
normal_parent=user_profile_config,
include_rows="user_profile_id__in",
limit_to_consenting_users=False,
use_iterator=False,
)
add_user_profile_child_configs(user_profile_config)
@@ -2823,25 +2835,23 @@ def export_messages_single_user(
.order_by("message_id")
)
user_message_chunk = list(fat_query)
message_chunk = []
for user_message in user_message_chunk:
def process_row(user_message: UserMessage) -> Record:
item = model_to_dict(user_message.message)
item["flags"] = user_message.flags_list()
item["flags_mask"] = user_message.flags.mask
# Add a few nice, human-readable details
item["sending_client_name"] = user_message.message.sending_client.name
item["recipient_name"] = get_recipient(user_message.message.recipient_id)
message_chunk.append(item)
return floatify_datetime_fields(item, "zerver_message")
message_filename = os.path.join(output_dir, f"messages-{dump_file_id:06}.json")
write_table_data(
message_filename,
{
"zerver_message": (process_row(um) for um in fat_query.iterator()),
},
)
logging.info("Fetched messages for %s", message_filename)
output = {"zerver_message": message_chunk}
floatify_datetime_fields(output, "zerver_message")
write_table_data(message_filename, output)
dump_file_id += 1


@@ -5,7 +5,7 @@ import shutil
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
from difflib import unified_diff
from typing import Any
from typing import Any, TypeAlias
import bmemcached
import orjson
@@ -33,7 +33,7 @@ from zerver.actions.realm_settings import (
from zerver.actions.user_settings import do_change_avatar_fields
from zerver.lib.avatar_hash import user_avatar_base_path_from_ids
from zerver.lib.bulk_create import bulk_set_users_or_streams_recipient_fields
from zerver.lib.export import DATE_FIELDS, Field, Path, Record, TableData, TableName
from zerver.lib.export import DATE_FIELDS, Field, Path, Record, TableName
from zerver.lib.markdown import markdown_convert
from zerver.lib.markdown import version as markdown_version
from zerver.lib.message import get_last_message_id
@@ -118,6 +118,8 @@ from zerver.models.recipients import get_direct_message_group_hash
from zerver.models.users import get_system_bot, get_user_profile_by_id
from zproject.backends import AUTH_BACKEND_NAME_MAP
ImportedTableData: TypeAlias = dict[str, list[Record]]
realm_tables = [
("zerver_realmauthenticationmethod", RealmAuthenticationMethod, "realmauthenticationmethod"),
("zerver_defaultstream", DefaultStream, "defaultstream"),
@@ -208,7 +210,7 @@ message_id_to_attachments: dict[str, dict[int, list[str]]] = {
}
def map_messages_to_attachments(data: TableData) -> None:
def map_messages_to_attachments(data: ImportedTableData) -> None:
for attachment in data["zerver_attachment"]:
for message_id in attachment["messages"]:
message_id_to_attachments["zerver_message"][message_id].append(attachment["path_id"])
@@ -231,14 +233,14 @@ def update_id_map(table: TableName, old_id: int, new_id: int) -> None:
ID_MAP[table][old_id] = new_id
def fix_datetime_fields(data: TableData, table: TableName) -> None:
def fix_datetime_fields(data: ImportedTableData, table: TableName) -> None:
for item in data[table]:
for field_name in DATE_FIELDS[table]:
if item[field_name] is not None:
item[field_name] = datetime.fromtimestamp(item[field_name], tz=timezone.utc)
def fix_upload_links(data: TableData, message_table: TableName) -> None:
def fix_upload_links(data: ImportedTableData, message_table: TableName) -> None:
"""
Because the URLs for uploaded files encode the realm ID of the
organization being imported (which is only determined at import
@@ -260,7 +262,7 @@ def fix_upload_links(data: TableData, message_table: TableName) -> None:
def fix_stream_permission_group_settings(
data: TableData, system_groups_name_dict: dict[str, NamedUserGroup]
data: ImportedTableData, system_groups_name_dict: dict[str, NamedUserGroup]
) -> None:
table = get_db_table(Stream)
for stream in data[table]:
@@ -289,7 +291,7 @@ def fix_stream_permission_group_settings(
)
def create_subscription_events(data: TableData, realm_id: int) -> None:
def create_subscription_events(data: ImportedTableData, realm_id: int) -> None:
"""
When the export data doesn't contain the table `zerver_realmauditlog`,
this function creates RealmAuditLog objects for `subscription_created`
@@ -332,7 +334,7 @@ def create_subscription_events(data: TableData, realm_id: int) -> None:
RealmAuditLog.objects.bulk_create(all_subscription_logs)
def fix_service_tokens(data: TableData, table: TableName) -> None:
def fix_service_tokens(data: ImportedTableData, table: TableName) -> None:
"""
The tokens in the services are created by 'generate_api_key'.
As the tokens are unique, they should be re-created for the imports.
@@ -341,7 +343,7 @@ def fix_service_tokens(data: TableData, table: TableName) -> None:
item["token"] = generate_api_key()
def process_direct_message_group_hash(data: TableData, table: TableName) -> None:
def process_direct_message_group_hash(data: ImportedTableData, table: TableName) -> None:
"""
Build new direct message group hashes with the updated ids of the users
"""
@@ -350,7 +352,7 @@ def process_direct_message_group_hash(data: TableData, table: TableName) -> None
direct_message_group["huddle_hash"] = get_direct_message_group_hash(user_id_list)
def get_direct_message_groups_from_subscription(data: TableData, table: TableName) -> None:
def get_direct_message_groups_from_subscription(data: ImportedTableData, table: TableName) -> None:
"""
Extract the IDs of the user_profiles involved in a direct message group from
the subscription object
@@ -369,7 +371,7 @@ def get_direct_message_groups_from_subscription(data: TableData, table: TableNam
)
def fix_customprofilefield(data: TableData) -> None:
def fix_customprofilefield(data: ImportedTableData) -> None:
"""
In CustomProfileField with 'field_type' like 'USER', the IDs need to be
re-mapped.
@@ -530,7 +532,7 @@ def fix_message_edit_history(
message["edit_history"] = orjson.dumps(edit_history).decode()
def current_table_ids(data: TableData, table: TableName) -> list[int]:
def current_table_ids(data: ImportedTableData, table: TableName) -> list[int]:
"""
Returns the ids present in the current table
"""
@@ -560,7 +562,7 @@ def allocate_ids(model_class: Any, count: int) -> list[int]:
return [item[0] for item in query]
def convert_to_id_fields(data: TableData, table: TableName, field_name: Field) -> None:
def convert_to_id_fields(data: ImportedTableData, table: TableName, field_name: Field) -> None:
"""
When Django gives us dict objects via model_to_dict, the foreign
key fields are `foo`, but we want `foo_id` for the bulk insert.
@@ -574,7 +576,7 @@ def convert_to_id_fields(data: TableData, table: TableName, field_name: Field) -
def re_map_foreign_keys(
data: TableData,
data: ImportedTableData,
table: TableName,
field_name: Field,
related_table: TableName,
@@ -585,7 +587,7 @@ def re_map_foreign_keys(
"""
This is a wrapper function for all the realm data tables
and only avatar and attachment records need to be passed through the internal function
because of the difference in data format (TableData corresponding to realm data tables
because of the difference in data format (ImportedTableData corresponding to realm data tables
and List[Record] corresponding to the avatar and attachment records)
"""
@@ -653,7 +655,7 @@ def re_map_foreign_keys_internal(
item[field_name] = new_id
def re_map_realm_emoji_codes(data: TableData, *, table_name: str) -> None:
def re_map_realm_emoji_codes(data: ImportedTableData, *, table_name: str) -> None:
"""
Some tables, including Reaction and UserStatus, contain a form of
foreign key reference to the RealmEmoji table in the form of
@@ -683,7 +685,7 @@ def re_map_realm_emoji_codes(data: TableData, *, table_name: str) -> None:
def re_map_foreign_keys_many_to_many(
data: TableData,
data: ImportedTableData,
table: TableName,
field_name: Field,
related_table: TableName,
@@ -733,13 +735,13 @@ def re_map_foreign_keys_many_to_many_internal(
return new_id_list
def fix_bitfield_keys(data: TableData, table: TableName, field_name: Field) -> None:
def fix_bitfield_keys(data: ImportedTableData, table: TableName, field_name: Field) -> None:
for item in data[table]:
item[field_name] = item[field_name + "_mask"]
del item[field_name + "_mask"]
def remove_denormalized_recipient_column_from_data(data: TableData) -> None:
def remove_denormalized_recipient_column_from_data(data: ImportedTableData) -> None:
"""
The recipient column shouldn't be imported, we'll set the correct values
when Recipient table gets imported.
@@ -762,7 +764,7 @@ def get_db_table(model_class: Any) -> str:
return model_class._meta.db_table
def update_model_ids(model: Any, data: TableData, related_table: TableName) -> None:
def update_model_ids(model: Any, data: ImportedTableData, related_table: TableName) -> None:
table = get_db_table(model)
# Important: remapping usermessage rows is
@@ -777,7 +779,7 @@ def update_model_ids(model: Any, data: TableData, related_table: TableName) -> N
re_map_foreign_keys(data, table, "id", related_table=related_table, id_field=True)
def bulk_import_user_message_data(data: TableData, dump_file_id: int) -> None:
def bulk_import_user_message_data(data: ImportedTableData, dump_file_id: int) -> None:
model = UserMessage
table = "zerver_usermessage"
lst = data[table]
@@ -810,7 +812,7 @@ def bulk_import_user_message_data(data: TableData, dump_file_id: int) -> None:
logging.info("Successfully imported %s from %s[%s].", model, table, dump_file_id)
def bulk_import_model(data: TableData, model: Any, dump_file_id: str | None = None) -> None:
def bulk_import_model(data: ImportedTableData, model: Any, dump_file_id: str | None = None) -> None:
table = get_db_table(model)
# TODO, deprecate dump_file_id
model.objects.bulk_create(model(**item) for item in data[table])
@@ -820,7 +822,7 @@ def bulk_import_model(data: TableData, model: Any, dump_file_id: str | None = No
logging.info("Successfully imported %s from %s[%s].", model, table, dump_file_id)
def bulk_import_named_user_groups(data: TableData) -> None:
def bulk_import_named_user_groups(data: ImportedTableData) -> None:
vals = [
(
group["usergroup_ptr_id"],
@@ -854,7 +856,7 @@ def bulk_import_named_user_groups(data: TableData) -> None:
# correctly import multiple realms into the same server, we need to
# check if a Client object already exists, and so we need to support
# remap all Client IDs to the values in the new DB.
def bulk_import_client(data: TableData, model: Any, table: TableName) -> None:
def bulk_import_client(data: ImportedTableData, model: Any, table: TableName) -> None:
for item in data[table]:
try:
client = Client.objects.get(name=item["name"])
@@ -880,7 +882,7 @@ def set_subscriber_count_for_channels(realm: Realm) -> None:
def fix_subscriptions_is_user_active_column(
data: TableData, user_profiles: list[UserProfile], crossrealm_user_ids: set[int]
data: ImportedTableData, user_profiles: list[UserProfile], crossrealm_user_ids: set[int]
) -> None:
table = get_db_table(Subscription)
user_id_to_active_status = {user.id: user.is_active for user in user_profiles}
@@ -1156,7 +1158,7 @@ def import_uploads(
future.result()
def disable_restricted_authentication_methods(data: TableData) -> None:
def disable_restricted_authentication_methods(data: ImportedTableData) -> None:
"""
Should run only with settings.BILLING_ENABLED. Ensures that we only
enable authentication methods that are available without needing a plan.
@@ -2105,7 +2107,7 @@ def import_message_data(realm: Realm, sender_map: dict[int, Record], import_dir:
dump_file_id += 1
def import_attachments(data: TableData) -> None:
def import_attachments(data: ImportedTableData) -> None:
# Clean up the data in zerver_attachment that is not
# relevant to our many-to-many import.
fix_datetime_fields(data, "zerver_attachment")
@@ -2146,7 +2148,7 @@ def import_attachments(data: TableData) -> None:
]
# Create our table data for insert.
m2m_data: TableData = {m2m_table_name: m2m_rows}
m2m_data: ImportedTableData = {m2m_table_name: m2m_rows}
convert_to_id_fields(m2m_data, m2m_table_name, parent_singular)
convert_to_id_fields(m2m_data, m2m_table_name, child_singular)
m2m_rows = m2m_data[m2m_table_name]
@@ -2193,7 +2195,7 @@ def import_attachments(data: TableData) -> None:
logging.info("Successfully imported M2M table %s", m2m_table_name)
def fix_attachments_data(attachment_data: TableData) -> None:
def fix_attachments_data(attachment_data: ImportedTableData) -> None:
for attachment in attachment_data["zerver_attachment"]:
attachment["path_id"] = path_maps["old_attachment_path_to_new_path"][attachment["path_id"]]
@@ -2205,7 +2207,7 @@ def fix_attachments_data(attachment_data: TableData) -> None:
attachment["content_type"] = guessed_content_type
def create_image_attachments(realm: Realm, attachment_data: TableData) -> None:
def create_image_attachments(realm: Realm, attachment_data: ImportedTableData) -> None:
for attachment in attachment_data["zerver_attachment"]:
if attachment["content_type"] not in THUMBNAIL_ACCEPT_IMAGE_TYPES:
continue


@@ -849,11 +849,13 @@ class RealmImportExportTest(ExportFile):
# Consented users:
hamlet = self.example_user("hamlet")
othello = self.example_user("othello")
cordelia = self.example_user("cordelia")
# Iago will be non-consenting.
iago = self.example_user("iago")
do_change_user_setting(hamlet, "allow_private_data_export", True, acting_user=None)
do_change_user_setting(othello, "allow_private_data_export", True, acting_user=None)
do_change_user_setting(cordelia, "allow_private_data_export", True, acting_user=None)
do_change_user_setting(iago, "allow_private_data_export", False, acting_user=None)
# Despite both hamlet and othello having consent enabled, in a public export
@@ -865,6 +867,9 @@ class RealmImportExportTest(ExportFile):
a_message.sending_client = private_client
a_message.save()
# Verify that a group DM between consenting users is not exported
self.send_group_direct_message(hamlet, [othello, cordelia])
# SavedSnippets are private content - so in a public export, despite
# hamlet having consent enabled, such objects should not be exported.
saved_snippet = do_create_saved_snippet("test", "test", hamlet)
@@ -888,6 +893,9 @@ class RealmImportExportTest(ExportFile):
exported_user_presence_ids = self.get_set(realm_data["zerver_userpresence"], "id")
self.assertIn(iago_presence.id, exported_user_presence_ids)
exported_huddle_ids = self.get_set(realm_data["zerver_huddle"], "id")
self.assertEqual(exported_huddle_ids, set())
def test_export_realm_with_member_consent(self) -> None:
realm = Realm.objects.get(string_id="zulip")