export: Move all queries, when possible, to iterators.
This reduces overall memory usage for large exports.
committed by Tim Abbott
parent 0ffc0e810c
commit 755cb7d854
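The pattern applied throughout, sketched here with hypothetical names (an illustration, not code from this patch): instead of materializing every database row into a list before serializing, rows are produced lazily by generators and streamed to the output file, so peak memory stays proportional to one row rather than one table.

import json
from collections.abc import Iterator


def fetch_rows() -> Iterator[dict[str, int]]:
    # Stand-in for a database query; yields one row at a time.
    for i in range(100_000):
        yield {"id": i}


def write_records(path: str, records: Iterator[dict[str, int]]) -> None:
    # Streams records to disk; the full list never exists in memory.
    with open(path, "w") as f:
        f.write("[")
        for n, record in enumerate(records):
            if n > 0:
                f.write(",")
            f.write(json.dumps(record))
        f.write("]")


write_records("records.json", fetch_rows())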
zerver/lib/export.py

@@ -19,7 +19,7 @@ from collections.abc import Callable, Iterable, Iterator, Mapping
 from datetime import datetime
 from email.headerregistry import Address
 from functools import cache
-from itertools import islice
+from itertools import chain, islice
 from typing import TYPE_CHECKING, Any, Optional, TypeAlias, TypedDict, TypeVar, cast
 from urllib.parse import urlsplit
 
@@ -27,7 +27,7 @@ import orjson
 from django.apps import apps
 from django.conf import settings
 from django.db import connection
-from django.db.models import Exists, Model, OuterRef, Q
+from django.db.models import Exists, Model, OuterRef, Q, QuerySet
 from django.forms.models import model_to_dict
 from django.utils.timezone import is_naive as timezone_is_naive
 from django.utils.timezone import now as timezone_now
@@ -98,7 +98,7 @@ if TYPE_CHECKING:
 # Custom mypy types follow:
 Record: TypeAlias = dict[str, Any]
 TableName = str
-TableData: TypeAlias = dict[TableName, list[Record]]
+TableData: TypeAlias = dict[TableName, Iterator[Record] | list[Record]]
 Field = str
 Path = str
 Context: TypeAlias = dict[str, Any]
@@ -108,11 +108,11 @@ SourceFilter: TypeAlias = Callable[[Record], bool]
 
 CustomFetch: TypeAlias = Callable[[TableData, Context], None]
 CustomReturnIds: TypeAlias = Callable[[TableData], set[int]]
-CustomProcessResults: TypeAlias = Callable[[list[Record], Context], list[Record]]
+CustomProcessResults: TypeAlias = Callable[[Iterable[Record], Context], Iterator[Record]]
 
 
 class MessagePartial(TypedDict):
-    zerver_message: list[Record]
+    zerver_message: Iterable[Record]
     zerver_userprofile_ids: list[int]
     realm_id: int
@@ -516,12 +516,11 @@ def write_records_json_file(output_dir: str, records: Iterable[dict[str, Any]])
     logging.info("Finished writing %s", output_file)
 
 
-def make_raw(query: Any, exclude: list[Field] | None = None) -> list[Record]:
+def make_raw(query: Iterable[Any], exclude: list[Field] | None = None) -> Iterator[Record]:
     """
     Takes a Django query and returns a JSONable list
     of dictionaries corresponding to the database rows.
     """
-    rows = []
     for instance in query:
         data = model_to_dict(instance, exclude=exclude)
         """
@@ -536,20 +535,19 @@ def make_raw(query: Any, exclude: list[Field] | None = None) -> list[Record]:
             value = data[field.name]
             data[field.name] = [row.id for row in value]
 
-        rows.append(data)
-
-    return rows
+        yield data
 
 
-def floatify_datetime_fields(data: TableData, table: TableName) -> None:
-    for item in data[table]:
-        for field in DATE_FIELDS[table]:
-            dt = item[field]
-            if dt is None:
-                continue
-            assert isinstance(dt, datetime)
-            assert not timezone_is_naive(dt)
-            item[field] = dt.timestamp()
+def floatify_datetime_fields(item: Record, table: TableName) -> Record:
+    updates = {}
+    for field in DATE_FIELDS[table]:
+        dt = item[field]
+        if dt is None:
+            continue
+        assert isinstance(dt, datetime)
+        assert not timezone_is_naive(dt)
+        updates[field] = dt.timestamp()
+    return {**item, **updates}
 
 
 class Config:
@@ -587,6 +585,7 @@ class Config:
         exclude: list[Field] | None = None,
         limit_to_consenting_users: bool | None = None,
         collect_client_ids: bool = False,
+        use_iterator: bool = True,
     ) -> None:
         assert table or custom_tables
         self.table = table
@@ -698,11 +697,17 @@ class Config:
            assert include_rows in ["user_profile_id__in", "user_id__in", "bot_profile_id__in"]
            assert normal_parent is not None and normal_parent.table == "zerver_userprofile"
 
+        if self.collect_client_ids or self.is_seeded:
+            self.use_iterator = False
+        else:
+            self.use_iterator = use_iterator
+
     def return_ids(self, response: TableData) -> set[int]:
         if self.custom_return_ids is not None:
             return self.custom_return_ids(response)
         else:
             assert self.table is not None
+            assert not self.use_iterator, self.table
             return {row["id"] for row in response[self.table]}
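One caveat the use_iterator flag above manages (the demonstration below is mine, not from the patch): a generator can be traversed only once, so any table whose rows are read again later — return_ids() being the obvious consumer — must stay a list, and the new assert guards against streaming such a table.

def gen():
    yield from (1, 2, 3)


g = gen()
assert list(g) == [1, 2, 3]
assert list(g) == []  # a second pass over an exhausted generator yields nothing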
@@ -730,7 +735,8 @@ def export_from_config(
     for t in exported_tables:
         logging.info("Exporting via export_from_config: %s", t)
 
-    rows = None
+    rows: Iterable[Any] | None = None
+    query: QuerySet[Any] | None = None
     if config.is_seeded:
         rows = [seed_object]
 
@@ -748,13 +754,13 @@ def export_from_config(
             # When we concat_and_destroy, we are working with
             # temporary "tables" that are lists of records that
             # should already be ready to export.
-            data: list[Record] = []
-            for t in config.concat_and_destroy:
-                data += response[t]
-                del response[t]
-                logging.info("Deleted temporary %s", t)
+            # We pop them off of the response and store them in a local
+            # which the iterable closes over
+            tables = {t: response.pop(t) for t in config.concat_and_destroy}
             assert table is not None
-            response[table] = data
+            response[table] = chain.from_iterable(tables[t] for t in config.concat_and_destroy)
+            for t in config.concat_and_destroy:
+                logging.info("Deleted temporary %s", t)
 
         elif config.normal_parent:
             # In this mode, our current model is figuratively Article,
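The chain.from_iterable rewrite keeps the old concatenation semantics but lazily: the temporary tables are popped into a local dict that the generator expression closes over, and no combined list is ever built. A small sketch of the behavior (illustrative values only):

from itertools import chain

tables = {"a": iter([1, 2]), "b": iter([3])}
combined = chain.from_iterable(tables[t] for t in ["a", "b"])
assert list(combined) == [1, 2, 3]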
@@ -818,7 +824,7 @@ def export_from_config(
 
         assert model is not None
         try:
-            query = model.objects.filter(**filter_params)
+            query = model.objects.filter(**filter_params).order_by("id")
         except Exception:
             print(
                 f"""
@@ -833,8 +839,6 @@ def export_from_config(
             )
             raise
 
-        rows = list(query)
-
     elif config.id_source:
         # In this mode, we are the figurative Blog, and we now
         # need to look at the current response to get all the
@@ -843,6 +847,8 @@ def export_from_config(
         assert model is not None
         # This will be a tuple of the form ('zerver_article', 'blog').
         (child_table, field) = config.id_source
+        assert config.virtual_parent is not None
+        assert not config.virtual_parent.use_iterator
         child_rows = response[child_table]
         if config.source_filter:
             child_rows = [r for r in child_rows if config.source_filter(r)]
@@ -851,12 +857,17 @@ def export_from_config(
         if config.filter_args:
             filter_params.update(config.filter_args)
         query = model.objects.filter(**filter_params)
-        rows = list(query)
 
+    if query is not None:
+        rows = query.iterator()
+
     if rows is not None:
         assert table is not None  # Hint for mypy
         response[table] = make_raw(rows, exclude=config.exclude)
     if config.collect_client_ids and "collected_client_ids_set" in context:
+        # If we need to collect the client-ids, we can't just stream the results
+        response[table] = list(response[table])
+
         model = cast(type[Model], model)
         assert issubclass(model, Model)
         client_id_field_name = get_fk_field_name(model, Client)
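query.iterator() is what actually avoids Django's result caching: iterating a plain QuerySet caches every row it yields, while iterator() streams rows in chunks and retains nothing. The cost, handled just above for client-id collection, is that the stream can be consumed only once. A rough pure-Python analogue of the two behaviors:

from collections.abc import Iterator


def rows() -> Iterator[int]:
    yield from range(3)


cached = list(rows())  # like a plain QuerySet: rows are kept and reusable
streamed = rows()      # like QuerySet.iterator(): rows are not retained
assert sum(cached) == 3 and sum(cached) == 3
assert sum(streamed) == 3 and sum(streamed) == 0  # second pass is empty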
@@ -873,8 +884,10 @@ def export_from_config(
             # of the exported data for the tables - e.g. to strip out private data.
             response[t] = custom_process_results(response[t], context)
         if t in DATE_FIELDS:
-            floatify_datetime_fields(response, t)
+            response[t] = (floatify_datetime_fields(r, t) for r in response[t])
 
+        if not config.use_iterator:
+            response[t] = list(response[t])
     # Now walk our children. It's extremely important to respect
     # the order of children here.
     for child_config in config.children:
@@ -998,6 +1011,7 @@ def get_realm_config() -> Config:
         table="zerver_userprofile",
         virtual_parent=realm_config,
         custom_fetch=custom_fetch_user_profile,
+        use_iterator=False,
     )
 
     user_groups_config = Config(
@@ -1006,6 +1020,7 @@ def get_realm_config() -> Config:
         normal_parent=realm_config,
         include_rows="realm_id__in",
         exclude=["direct_members", "direct_subgroups"],
+        use_iterator=False,
     )
 
     Config(
@@ -1078,6 +1093,7 @@ def get_realm_config() -> Config:
         # It is just "glue" data for internal data model consistency purposes
         # with no user-specific information.
         limit_to_consenting_users=False,
+        use_iterator=False,
     )
 
     Config(
@@ -1092,6 +1108,7 @@ def get_realm_config() -> Config:
         model=Stream,
         normal_parent=realm_config,
         include_rows="realm_id__in",
+        use_iterator=False,
     )
 
     stream_recipient_config = Config(
@@ -1100,6 +1117,7 @@ def get_realm_config() -> Config:
         normal_parent=stream_config,
         include_rows="type_id__in",
         filter_args={"type": Recipient.STREAM},
+        use_iterator=False,
     )
 
     Config(
@@ -1131,6 +1149,7 @@ def get_realm_config() -> Config:
             "_stream_recipient",
             "_huddle_recipient",
         ],
+        use_iterator=False,
     )
 
     Config(
@@ -1157,11 +1176,12 @@ def get_realm_config() -> Config:
 
 
 def custom_process_subscription_in_realm_config(
-    subscriptions: list[Record], context: Context
-) -> list[Record]:
+    subscriptions: Iterable[Record], context: Context
+) -> Iterator[Record]:
     export_type = context["export_type"]
     if export_type == RealmExport.EXPORT_FULL_WITHOUT_CONSENT:
-        return subscriptions
+        yield from subscriptions
+        return
 
     exportable_user_ids_from_context = context["exportable_user_ids"]
     if export_type == RealmExport.EXPORT_FULL_WITH_CONSENT:
@@ -1172,9 +1192,10 @@ def custom_process_subscription_in_realm_config(
         assert exportable_user_ids_from_context is None
         consented_user_ids = set()
 
-    def scrub_subscription_if_needed(subscription: Record) -> Record:
+    for subscription in subscriptions:
         if subscription["user_profile"] in consented_user_ids:
-            return subscription
+            yield subscription
+            continue
         # We create a replacement Subscription, setting only the essential fields,
         # while allowing all the other ones to fall back to the defaults
         # defined in the model.
@@ -1190,10 +1211,7 @@ def custom_process_subscription_in_realm_config(
             color=random.choice(STREAM_ASSIGNMENT_COLORS),
         )
         subscription_dict = model_to_dict(scrubbed_subscription)
-        return subscription_dict
-
-    processed_rows = map(scrub_subscription_if_needed, subscriptions)
-    return list(processed_rows)
+        yield subscription_dict
 
 
 def add_user_profile_child_configs(user_profile_config: Config) -> None:
@@ -1410,7 +1428,7 @@ def custom_fetch_user_profile(response: TableData, context: Context) -> None:
 
 def custom_fetch_user_profile_cross_realm(response: TableData, context: Context) -> None:
     realm = context["realm"]
-    response["zerver_userprofile_crossrealm"] = []
+    crossrealm_bots = []
 
     bot_name_to_default_email = {
         "NOTIFICATION_BOT": "notification-bot@zulip.com",
@@ -1432,13 +1450,14 @@ def custom_fetch_user_profile_cross_realm(response: TableData, context: Context)
         bot_user_id = get_system_bot(bot_email, internal_realm.id).id
 
         recipient_id = Recipient.objects.get(type_id=bot_user_id, type=Recipient.PERSONAL).id
-        response["zerver_userprofile_crossrealm"].append(
+        crossrealm_bots.append(
             dict(
                 email=bot_default_email,
                 id=bot_user_id,
                 recipient_id=recipient_id,
-            )
+            ),
         )
 
+    response["zerver_userprofile_crossrealm"] = crossrealm_bots
 
 
 def fetch_attachment_data(
@@ -1448,10 +1467,21 @@ def fetch_attachment_data(
         Attachment.objects.filter(
             Q(messages__in=message_ids) | Q(scheduled_messages__in=scheduled_message_ids),
             realm_id=realm_id,
-        ).distinct()
-    )
-    response["zerver_attachment"] = make_raw(attachments)
-    floatify_datetime_fields(response, "zerver_attachment")
+        )
+        .distinct("path_id")
+        .order_by("path_id")
+    )
+
+    def postprocess_attachment(row: Record) -> Record:
+        row = floatify_datetime_fields(row, "zerver_attachment")
+        filtered_message_ids = set(row["messages"]).intersection(message_ids)
+        row["messages"] = sorted(filtered_message_ids)
+
+        filtered_scheduled_message_ids = set(row["scheduled_messages"]).intersection(
+            scheduled_message_ids
+        )
+        row["scheduled_messages"] = sorted(filtered_scheduled_message_ids)
+        return row
 
     """
     We usually export most messages for the realm, but not
@@ -1461,14 +1491,7 @@ def fetch_attachment_data(
 
     Same reasoning applies to scheduled_messages.
     """
-    for row in response["zerver_attachment"]:
-        filtered_message_ids = set(row["messages"]).intersection(message_ids)
-        row["messages"] = sorted(filtered_message_ids)
-
-        filtered_scheduled_message_ids = set(row["scheduled_messages"]).intersection(
-            scheduled_message_ids
-        )
-        row["scheduled_messages"] = sorted(filtered_scheduled_message_ids)
+    response["zerver_attachment"] = (postprocess_attachment(r) for r in make_raw(attachments))
 
     return attachments
@@ -1480,18 +1503,17 @@ def custom_fetch_realm_audit_logs_for_user(response: TableData, context: Context
     """
     user = context["user"]
     query = RealmAuditLog.objects.filter(Q(modified_user_id=user.id) | Q(acting_user_id=user.id))
-    rows = make_raw(list(query))
-    response["zerver_realmauditlog"] = rows
+    response["zerver_realmauditlog"] = make_raw(query.iterator())
 
 
 def fetch_reaction_data(response: TableData, message_ids: set[int]) -> None:
     query = Reaction.objects.filter(message_id__in=list(message_ids))
-    response["zerver_reaction"] = make_raw(list(query))
+    response["zerver_reaction"] = make_raw(query.iterator())
 
 
 def fetch_client_data(response: TableData, client_ids: set[int]) -> None:
     query = Client.objects.filter(id__in=list(client_ids))
-    response["zerver_client"] = make_raw(list(query))
+    response["zerver_client"] = make_raw(query.iterator())
 
 
 def custom_fetch_direct_message_groups(response: TableData, context: Context) -> None:
@@ -1509,7 +1531,9 @@ def custom_fetch_direct_message_groups(response: TableData, context: Context) ->
         consented_user_ids = set()
 
     user_profile_ids = {
-        r["id"] for r in response["zerver_userprofile"] + response["zerver_userprofile_mirrordummy"]
+        r["id"]
+        for r in list(response["zerver_userprofile"])
+        + list(response["zerver_userprofile_mirrordummy"])
     }
 
     recipient_filter = Q()
@@ -1575,12 +1599,6 @@ def custom_fetch_direct_message_groups(response: TableData, context: Context) ->
     response["zerver_huddle"] = make_raw(
         DirectMessageGroup.objects.filter(id__in=direct_message_group_ids)
     )
-    if export_type == RealmExport.EXPORT_PUBLIC and any(
-        response[t] for t in ["_huddle_recipient", "_huddle_subscription", "zerver_huddle"]
-    ):
-        raise AssertionError(
-            "Public export should not result in exporting any data in _huddle tables"
-        )
 
 
 def custom_fetch_scheduled_messages(response: TableData, context: Context) -> None:
@@ -1591,7 +1609,7 @@ def custom_fetch_scheduled_messages(response: TableData, context: Context) -> No
     exportable_scheduled_message_ids = context["exportable_scheduled_message_ids"]
 
     query = ScheduledMessage.objects.filter(realm=realm, id__in=exportable_scheduled_message_ids)
-    rows = make_raw(list(query))
+    rows = make_raw(query)
 
     response["zerver_scheduledmessage"] = rows
@@ -1654,14 +1672,15 @@ def custom_fetch_realm_audit_logs_for_realm(response: TableData, context: Contex
 
 def custom_fetch_onboarding_usermessage(response: TableData, context: Context) -> None:
     realm = context["realm"]
-    response["zerver_onboardingusermessage"] = []
+    onboarding = []
 
     onboarding_usermessage_query = OnboardingUserMessage.objects.filter(realm=realm)
     for onboarding_usermessage in onboarding_usermessage_query:
         onboarding_usermessage_obj = model_to_dict(onboarding_usermessage)
         onboarding_usermessage_obj["flags_mask"] = onboarding_usermessage.flags.mask
         del onboarding_usermessage_obj["flags"]
-        response["zerver_onboardingusermessage"].append(onboarding_usermessage_obj)
+        onboarding.append(onboarding_usermessage_obj)
+    response["zerver_onboardingusermessage"] = onboarding
 
 
 def fetch_usermessages(
@@ -1710,7 +1729,8 @@ def export_usermessages_batch(
     with open(input_path, "rb") as input_file:
         input_data: MessagePartial = orjson.loads(input_file.read())
 
-    message_ids = {item["id"] for item in input_data["zerver_message"]}
+    messages = list(input_data["zerver_message"])
+    message_ids = {item["id"] for item in messages}
     user_profile_ids = set(input_data["zerver_userprofile_ids"])
     realm = Realm.objects.get(id=input_data["realm_id"])
     zerver_usermessage_data = fetch_usermessages(
@@ -1723,7 +1743,7 @@ def export_usermessages_batch(
     )
 
     output_data: TableData = dict(
-        zerver_message=input_data["zerver_message"],
+        zerver_message=messages,
         zerver_usermessage=zerver_usermessage_data,
     )
     write_table_data(output_path, output_data)
@@ -1764,9 +1784,9 @@ def export_partial_message_files(
         response["zerver_userprofile"],
     )
     ids_of_our_possible_senders = get_ids(
-        response["zerver_userprofile"]
-        + response["zerver_userprofile_mirrordummy"]
-        + response["zerver_userprofile_crossrealm"]
+        list(response["zerver_userprofile"])
+        + list(response["zerver_userprofile_mirrordummy"])
+        + list(response["zerver_userprofile_crossrealm"])
     )
 
     consented_user_ids: set[int] = set()
@@ -1913,7 +1933,9 @@ def write_message_partials(
     for message_id_chunk in message_id_chunks:
         # Uses index: zerver_message_pkey
         actual_query = Message.objects.filter(id__in=message_id_chunk).order_by("id")
-        message_chunk = make_raw(actual_query)
+        message_chunk = [
+            floatify_datetime_fields(r, "zerver_message") for r in make_raw(actual_query.iterator())
+        ]
 
         for row in message_chunk:
             collected_client_ids.add(row["sending_client"])
@@ -1923,16 +1945,11 @@ def write_message_partials(
         message_filename += ".partial"
         logging.info("Fetched messages for %s", message_filename)
 
-        # Clean up our messages.
-        table_data: TableData = {}
-        table_data["zerver_message"] = message_chunk
-        floatify_datetime_fields(table_data, "zerver_message")
-
         # Build up our output for the .partial file, which needs
         # a list of user_profile_ids to search for (as well as
         # the realm id).
         output: MessagePartial = dict(
-            zerver_message=table_data["zerver_message"],
+            zerver_message=message_chunk,
             zerver_userprofile_ids=list(user_profile_ids),
             realm_id=realm.id,
         )
@@ -1945,7 +1962,7 @@ def write_message_partials(
 def export_uploads_and_avatars(
     realm: Realm,
     *,
-    attachments: list[Attachment] | None = None,
+    attachments: Iterable[Attachment] | None = None,
     user: UserProfile | None,
     output_dir: Path,
 ) -> None:
@@ -1974,7 +1991,7 @@ def export_uploads_and_avatars(
     else:
         handle_system_bots = False
         users = [user]
-        attachments = list(Attachment.objects.filter(owner_id=user.id))
+        attachments = list(Attachment.objects.filter(owner_id=user.id).order_by("path_id"))
         realm_emojis = list(RealmEmoji.objects.filter(author_id=user.id))
 
     if settings.LOCAL_UPLOADS_DIR:
@@ -2164,7 +2181,6 @@ def export_files_from_s3(
     processing_emoji = flavor == "emoji"
 
     bucket = get_bucket(bucket_name)
-    records = []
 
     logging.info("Downloading %s files from %s", flavor, bucket_name)
@@ -2175,85 +2191,88 @@ def export_files_from_s3(
     email_gateway_bot = get_system_bot(settings.EMAIL_GATEWAY_BOT, internal_realm.id)
     user_ids.add(email_gateway_bot.id)
 
-    count = 0
-    for bkey in bucket.objects.filter(Prefix=object_prefix):
-        if valid_hashes is not None and bkey.Object().key not in valid_hashes:
-            continue
+    def iterate_attachments() -> Iterator[Record]:
+        count = 0
+        for bkey in bucket.objects.filter(Prefix=object_prefix):
+            # This is promised to be iterated in sorted filename order.
 
-        key = bucket.Object(bkey.key)
+            if valid_hashes is not None and bkey.Object().key not in valid_hashes:
+                continue
 
-        """
-        For very old realms we may not have proper metadata. If you really need
-        an export to bypass these checks, flip the following flag.
-        """
-        checking_metadata = True
-        if checking_metadata:
-            if "realm_id" not in key.metadata:
-                raise AssertionError(f"Missing realm_id in key metadata: {key.metadata}")
+            key = bucket.Object(bkey.key)
 
-            if "user_profile_id" not in key.metadata:
-                raise AssertionError(f"Missing user_profile_id in key metadata: {key.metadata}")
+            """
+            For very old realms we may not have proper metadata. If you really need
+            an export to bypass these checks, flip the following flag.
+            """
+            checking_metadata = True
+            if checking_metadata:
+                if "realm_id" not in key.metadata:
+                    raise AssertionError(f"Missing realm_id in key metadata: {key.metadata}")
 
-            if int(key.metadata["user_profile_id"]) not in user_ids:
-                continue
+                if "user_profile_id" not in key.metadata:
+                    raise AssertionError(f"Missing user_profile_id in key metadata: {key.metadata}")
 
-            # This can happen if an email address has moved realms
-            if key.metadata["realm_id"] != str(realm.id):
-                if email_gateway_bot is None or key.metadata["user_profile_id"] != str(
-                    email_gateway_bot.id
-                ):
-                    raise AssertionError(
-                        f"Key metadata problem: {key.key} / {key.metadata} / {realm.id}"
-                    )
-                # Email gateway bot sends messages, potentially including attachments, cross-realm.
-                print(f"File uploaded by email gateway bot: {key.key} / {key.metadata}")
+                if int(key.metadata["user_profile_id"]) not in user_ids:
+                    continue
 
-        record = _get_exported_s3_record(bucket_name, key, processing_emoji)
+                # This can happen if an email address has moved realms
+                if key.metadata["realm_id"] != str(realm.id):
+                    if email_gateway_bot is None or key.metadata["user_profile_id"] != str(
+                        email_gateway_bot.id
+                    ):
+                        raise AssertionError(
+                            f"Key metadata problem: {key.key} / {key.metadata} / {realm.id}"
+                        )
+                    # Email gateway bot sends messages, potentially including attachments, cross-realm.
+                    print(f"File uploaded by email gateway bot: {key.key} / {key.metadata}")
 
-        record["path"] = key.key
-        _save_s3_object_to_file(key, output_dir, processing_uploads)
+            record = _get_exported_s3_record(bucket_name, key, processing_emoji)
 
-        records.append(record)
-        count += 1
+            record["path"] = key.key
+            _save_s3_object_to_file(key, output_dir, processing_uploads)
 
-        if count % 100 == 0:
-            logging.info("Finished %s", count)
+            yield record
+            count += 1
 
-    write_records_json_file(output_dir, records)
+            if count % 100 == 0:
+                logging.info("Finished %s", count)
+
+    write_records_json_file(output_dir, iterate_attachments())
 
 
 def export_uploads_from_local(
-    realm: Realm, local_dir: Path, output_dir: Path, attachments: list[Attachment]
+    realm: Realm, local_dir: Path, output_dir: Path, attachments: Iterable[Attachment]
 ) -> None:
-    records = []
-    for count, attachment in enumerate(attachments, 1):
-        # Use 'mark_sanitized' to work around false positive caused by Pysa
-        # thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
-        # are user controlled
-        path_id = mark_sanitized(attachment.path_id)
+    def iterate_attachments() -> Iterator[Record]:
+        for count, attachment in enumerate(attachments, 1):
+            # Use 'mark_sanitized' to work around false positive caused by Pysa
+            # thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
+            # are user controlled
+            path_id = mark_sanitized(attachment.path_id)
 
-        local_path = os.path.join(local_dir, path_id)
-        output_path = os.path.join(output_dir, path_id)
+            local_path = os.path.join(local_dir, path_id)
+            output_path = os.path.join(output_dir, path_id)
 
-        os.makedirs(os.path.dirname(output_path), exist_ok=True)
-        shutil.copy2(local_path, output_path)
-        stat = os.stat(local_path)
-        record = dict(
-            realm_id=attachment.realm_id,
-            user_profile_id=attachment.owner.id,
-            user_profile_email=attachment.owner.email,
-            s3_path=path_id,
-            path=path_id,
-            size=stat.st_size,
-            last_modified=stat.st_mtime,
-            content_type=None,
-        )
-        records.append(record)
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
+            shutil.copy2(local_path, output_path)
+            stat = os.stat(local_path)
+            record = dict(
+                realm_id=attachment.realm_id,
+                user_profile_id=attachment.owner.id,
+                user_profile_email=attachment.owner.email,
+                s3_path=path_id,
+                path=path_id,
+                size=stat.st_size,
+                last_modified=stat.st_mtime,
+                content_type=None,
+            )
+            yield record
 
-        if count % 100 == 0:
-            logging.info("Finished %s", count)
+            if count % 100 == 0:
+                logging.info("Finished %s", count)
 
-    write_records_json_file(output_dir, records)
+    write_records_json_file(output_dir, iterate_attachments())
 
 
 def export_avatars_from_local(
@@ -2337,52 +2356,44 @@ def get_emoji_path(realm_emoji: RealmEmoji) -> str:
 
 
 def export_emoji_from_local(
-    realm: Realm, local_dir: Path, output_dir: Path, realm_emojis: list[RealmEmoji]
+    realm: Realm, local_dir: Path, output_dir: Path, realm_emojis: Iterable[RealmEmoji]
 ) -> None:
-    records = []
-
-    realm_emoji_helper_tuples: list[tuple[RealmEmoji, str]] = []
-    for realm_emoji in realm_emojis:
-        realm_emoji_path = get_emoji_path(realm_emoji)
-
-        # Use 'mark_sanitized' to work around false positive caused by Pysa
-        # thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
-        # are user controlled
-        realm_emoji_path = mark_sanitized(realm_emoji_path)
-
-        realm_emoji_path_original = realm_emoji_path + ".original"
-
-        realm_emoji_helper_tuples.append((realm_emoji, realm_emoji_path))
-        realm_emoji_helper_tuples.append((realm_emoji, realm_emoji_path_original))
-
-    for count, realm_emoji_helper_tuple in enumerate(realm_emoji_helper_tuples, 1):
-        realm_emoji_object, emoji_path = realm_emoji_helper_tuple
-
-        local_path = os.path.join(local_dir, emoji_path)
-        output_path = os.path.join(output_dir, emoji_path)
-
-        os.makedirs(os.path.dirname(output_path), exist_ok=True)
-        shutil.copy2(local_path, output_path)
-        # Realm emoji author is optional.
-        author = realm_emoji_object.author
-        author_id = None
-        if author:
-            author_id = author.id
-        record = dict(
-            realm_id=realm.id,
-            author=author_id,
-            path=emoji_path,
-            s3_path=emoji_path,
-            file_name=realm_emoji_object.file_name,
-            name=realm_emoji_object.name,
-            deactivated=realm_emoji_object.deactivated,
-        )
-        records.append(record)
-
-        if count % 100 == 0:
-            logging.info("Finished %s", count)
-
-    write_records_json_file(output_dir, records)
+    def emoji_path_tuples() -> Iterator[tuple[RealmEmoji, str]]:
+        for realm_emoji in realm_emojis:
+            # Use 'mark_sanitized' to work around false positive caused by Pysa
+            # thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
+            # are user controlled
+            realm_emoji_path = mark_sanitized(get_emoji_path(realm_emoji))
+
+            yield (realm_emoji, realm_emoji_path)
+            yield (realm_emoji, realm_emoji_path + ".original")
+
+    def iterate_emoji(
+        realm_emoji_helper_tuples: Iterator[tuple[RealmEmoji, str]],
+    ) -> Iterator[Record]:
+        for count, realm_emoji_helper_tuple in enumerate(realm_emoji_helper_tuples, 1):
+            realm_emoji_object, emoji_path = realm_emoji_helper_tuple
+
+            local_path = os.path.join(local_dir, emoji_path)
+            output_path = os.path.join(output_dir, emoji_path)
+
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
+            shutil.copy2(local_path, output_path)
+            # Realm emoji author is optional.
+            author = realm_emoji_object.author
+            author_id = author.id if author else None
+            record = dict(
+                realm_id=realm.id,
+                author=author_id,
+                path=emoji_path,
+                s3_path=emoji_path,
+                file_name=realm_emoji_object.file_name,
+                name=realm_emoji_object.name,
+                deactivated=realm_emoji_object.deactivated,
+            )
+            yield record
+
+            if count % 100 == 0:
+                logging.info("Finished %s", count)
+
+    write_records_json_file(output_dir, iterate_emoji(emoji_path_tuples()))
 
 
 def do_write_stats_file_for_realm_export(output_dir: Path) -> dict[str, int | dict[str, int]]:
@@ -2502,17 +2513,13 @@ def do_export_realm(
     )
     logging.info("%d messages were exported", len(message_ids))
 
-    # zerver_reaction
-    zerver_reaction: TableData = {}
-    fetch_reaction_data(response=zerver_reaction, message_ids=message_ids)
-    response.update(zerver_reaction)
+    fetch_reaction_data(response=response, message_ids=message_ids)
 
-    zerver_client: TableData = {}
-    fetch_client_data(response=zerver_client, client_ids=collected_client_ids)
-    response.update(zerver_client)
+    fetch_client_data(response=response, client_ids=collected_client_ids)
 
     # Override the "deactivated" flag on the realm
     if export_as_active is not None:
+        assert isinstance(response["zerver_realm"], list)
         response["zerver_realm"][0]["deactivated"] = not export_as_active
 
     response["import_source"] = "zulip"  # type: ignore[assignment]  # this is an extra info field, not TableData
@@ -2619,6 +2626,8 @@ def do_export_user(user_profile: UserProfile, output_dir: Path) -> None:
     export_file = os.path.join(output_dir, "user.json")
     write_table_data(output_file=export_file, data=response)
 
+    # Double-check that this is a list, and not an iterator, so we can run over it again
+    assert isinstance(response["zerver_reaction"], list)
     reaction_message_ids: set[int] = {row["message"] for row in response["zerver_reaction"]}
 
     logging.info("Exporting messages")
@@ -2661,6 +2670,7 @@ def get_single_user_config() -> Config:
         # Exports with consent are not relevant in the context of exporting
         # a single user.
         limit_to_consenting_users=False,
+        use_iterator=False,
     )
 
     # zerver_recipient
@@ -2669,6 +2679,7 @@ def get_single_user_config() -> Config:
         model=Recipient,
         virtual_parent=subscription_config,
         id_source=("zerver_subscription", "recipient"),
+        use_iterator=False,
     )
 
     # zerver_stream
@@ -2713,6 +2724,7 @@ def get_single_user_config() -> Config:
         normal_parent=user_profile_config,
         include_rows="user_profile_id__in",
         limit_to_consenting_users=False,
+        use_iterator=False,
     )
 
     add_user_profile_child_configs(user_profile_config)
@@ -2823,25 +2835,23 @@ def export_messages_single_user(
         .order_by("message_id")
     )
 
-    user_message_chunk = list(fat_query)
-
-    message_chunk = []
-    for user_message in user_message_chunk:
+    def process_row(user_message: UserMessage) -> Record:
         item = model_to_dict(user_message.message)
         item["flags"] = user_message.flags_list()
         item["flags_mask"] = user_message.flags.mask
         # Add a few nice, human-readable details
         item["sending_client_name"] = user_message.message.sending_client.name
         item["recipient_name"] = get_recipient(user_message.message.recipient_id)
-        message_chunk.append(item)
+        return floatify_datetime_fields(item, "zerver_message")
 
     message_filename = os.path.join(output_dir, f"messages-{dump_file_id:06}.json")
+    write_table_data(
+        message_filename,
+        {
+            "zerver_message": (process_row(um) for um in fat_query.iterator()),
+        },
+    )
     logging.info("Fetched messages for %s", message_filename)
 
-    output = {"zerver_message": message_chunk}
-    floatify_datetime_fields(output, "zerver_message")
-
-    write_table_data(message_filename, output)
     dump_file_id += 1
zerver/lib/import_realm.py

@@ -5,7 +5,7 @@ import shutil
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from datetime import datetime, timedelta, timezone
 from difflib import unified_diff
-from typing import Any
+from typing import Any, TypeAlias
 
 import bmemcached
 import orjson
@@ -33,7 +33,7 @@ from zerver.actions.realm_settings import (
 from zerver.actions.user_settings import do_change_avatar_fields
 from zerver.lib.avatar_hash import user_avatar_base_path_from_ids
 from zerver.lib.bulk_create import bulk_set_users_or_streams_recipient_fields
-from zerver.lib.export import DATE_FIELDS, Field, Path, Record, TableData, TableName
+from zerver.lib.export import DATE_FIELDS, Field, Path, Record, TableName
 from zerver.lib.markdown import markdown_convert
 from zerver.lib.markdown import version as markdown_version
 from zerver.lib.message import get_last_message_id
@@ -118,6 +118,8 @@ from zerver.models.recipients import get_direct_message_group_hash
 from zerver.models.users import get_system_bot, get_user_profile_by_id
 from zproject.backends import AUTH_BACKEND_NAME_MAP
 
+ImportedTableData: TypeAlias = dict[str, list[Record]]
+
 realm_tables = [
     ("zerver_realmauthenticationmethod", RealmAuthenticationMethod, "realmauthenticationmethod"),
     ("zerver_defaultstream", DefaultStream, "defaultstream"),
@@ -208,7 +210,7 @@ message_id_to_attachments: dict[str, dict[int, list[str]]] = {
 }
 
 
-def map_messages_to_attachments(data: TableData) -> None:
+def map_messages_to_attachments(data: ImportedTableData) -> None:
     for attachment in data["zerver_attachment"]:
         for message_id in attachment["messages"]:
             message_id_to_attachments["zerver_message"][message_id].append(attachment["path_id"])
@@ -231,14 +233,14 @@ def update_id_map(table: TableName, old_id: int, new_id: int) -> None:
     ID_MAP[table][old_id] = new_id
 
 
-def fix_datetime_fields(data: TableData, table: TableName) -> None:
+def fix_datetime_fields(data: ImportedTableData, table: TableName) -> None:
     for item in data[table]:
         for field_name in DATE_FIELDS[table]:
             if item[field_name] is not None:
                 item[field_name] = datetime.fromtimestamp(item[field_name], tz=timezone.utc)
 
 
-def fix_upload_links(data: TableData, message_table: TableName) -> None:
+def fix_upload_links(data: ImportedTableData, message_table: TableName) -> None:
     """
     Because the URLs for uploaded files encode the realm ID of the
     organization being imported (which is only determined at import
@@ -260,7 +262,7 @@ def fix_upload_links(data: TableData, message_table: TableName) -> None:
 
 
 def fix_stream_permission_group_settings(
-    data: TableData, system_groups_name_dict: dict[str, NamedUserGroup]
+    data: ImportedTableData, system_groups_name_dict: dict[str, NamedUserGroup]
 ) -> None:
     table = get_db_table(Stream)
     for stream in data[table]:
@@ -289,7 +291,7 @@ def fix_stream_permission_group_settings(
     )
 
 
-def create_subscription_events(data: TableData, realm_id: int) -> None:
+def create_subscription_events(data: ImportedTableData, realm_id: int) -> None:
     """
     When the export data doesn't contain the table `zerver_realmauditlog`,
     this function creates RealmAuditLog objects for `subscription_created`
@@ -332,7 +334,7 @@ def create_subscription_events(data: TableData, realm_id: int) -> None:
     RealmAuditLog.objects.bulk_create(all_subscription_logs)
 
 
-def fix_service_tokens(data: TableData, table: TableName) -> None:
+def fix_service_tokens(data: ImportedTableData, table: TableName) -> None:
     """
     The tokens in the services are created by 'generate_api_key'.
     As the tokens are unique, they should be re-created for the imports.
@@ -341,7 +343,7 @@ def fix_service_tokens(data: TableData, table: TableName) -> None:
         item["token"] = generate_api_key()
 
 
-def process_direct_message_group_hash(data: TableData, table: TableName) -> None:
+def process_direct_message_group_hash(data: ImportedTableData, table: TableName) -> None:
     """
     Build new direct message group hashes with the updated ids of the users
     """
@@ -350,7 +352,7 @@ def process_direct_message_group_hash(data: TableData, table: TableName) -> None
         direct_message_group["huddle_hash"] = get_direct_message_group_hash(user_id_list)
 
 
-def get_direct_message_groups_from_subscription(data: TableData, table: TableName) -> None:
+def get_direct_message_groups_from_subscription(data: ImportedTableData, table: TableName) -> None:
     """
     Extract the IDs of the user_profiles involved in a direct message group from
     the subscription object
@@ -369,7 +371,7 @@ def get_direct_message_groups_from_subscription(data: TableData, table: TableNam
     )
 
 
-def fix_customprofilefield(data: TableData) -> None:
+def fix_customprofilefield(data: ImportedTableData) -> None:
     """
     In CustomProfileField with 'field_type' like 'USER', the IDs need to be
     re-mapped.
@@ -530,7 +532,7 @@ def fix_message_edit_history(
         message["edit_history"] = orjson.dumps(edit_history).decode()
 
 
-def current_table_ids(data: TableData, table: TableName) -> list[int]:
+def current_table_ids(data: ImportedTableData, table: TableName) -> list[int]:
     """
     Returns the ids present in the current table
     """
@@ -560,7 +562,7 @@ def allocate_ids(model_class: Any, count: int) -> list[int]:
     return [item[0] for item in query]
 
 
-def convert_to_id_fields(data: TableData, table: TableName, field_name: Field) -> None:
+def convert_to_id_fields(data: ImportedTableData, table: TableName, field_name: Field) -> None:
     """
     When Django gives us dict objects via model_to_dict, the foreign
     key fields are `foo`, but we want `foo_id` for the bulk insert.
@@ -574,7 +576,7 @@ def convert_to_id_fields(data: TableData, table: TableName, field_name: Field) -
 
 
 def re_map_foreign_keys(
-    data: TableData,
+    data: ImportedTableData,
     table: TableName,
     field_name: Field,
     related_table: TableName,
@@ -585,7 +587,7 @@ def re_map_foreign_keys(
     """
     This is a wrapper function for all the realm data tables
     and only avatar and attachment records need to be passed through the internal function
-    because of the difference in data format (TableData corresponding to realm data tables
+    because of the difference in data format (ImportedTableData corresponding to realm data tables
     and List[Record] corresponding to the avatar and attachment records)
     """
 
@@ -653,7 +655,7 @@ def re_map_foreign_keys_internal(
             item[field_name] = new_id
 
 
-def re_map_realm_emoji_codes(data: TableData, *, table_name: str) -> None:
+def re_map_realm_emoji_codes(data: ImportedTableData, *, table_name: str) -> None:
     """
     Some tables, including Reaction and UserStatus, contain a form of
     foreign key reference to the RealmEmoji table in the form of
@@ -683,7 +685,7 @@ def re_map_realm_emoji_codes(data: TableData, *, table_name: str) -> None:
 
 
 def re_map_foreign_keys_many_to_many(
-    data: TableData,
+    data: ImportedTableData,
     table: TableName,
     field_name: Field,
     related_table: TableName,
@@ -733,13 +735,13 @@ def re_map_foreign_keys_many_to_many_internal(
     return new_id_list
 
 
-def fix_bitfield_keys(data: TableData, table: TableName, field_name: Field) -> None:
+def fix_bitfield_keys(data: ImportedTableData, table: TableName, field_name: Field) -> None:
     for item in data[table]:
         item[field_name] = item[field_name + "_mask"]
         del item[field_name + "_mask"]
 
 
-def remove_denormalized_recipient_column_from_data(data: TableData) -> None:
+def remove_denormalized_recipient_column_from_data(data: ImportedTableData) -> None:
     """
     The recipient column shouldn't be imported, we'll set the correct values
     when Recipient table gets imported.
@@ -762,7 +764,7 @@ def get_db_table(model_class: Any) -> str:
     return model_class._meta.db_table
 
 
-def update_model_ids(model: Any, data: TableData, related_table: TableName) -> None:
+def update_model_ids(model: Any, data: ImportedTableData, related_table: TableName) -> None:
     table = get_db_table(model)
 
     # Important: remapping usermessage rows is
@@ -777,7 +779,7 @@ def update_model_ids(model: Any, data: TableData, related_table: TableName) -> N
     re_map_foreign_keys(data, table, "id", related_table=related_table, id_field=True)
 
 
-def bulk_import_user_message_data(data: TableData, dump_file_id: int) -> None:
+def bulk_import_user_message_data(data: ImportedTableData, dump_file_id: int) -> None:
     model = UserMessage
     table = "zerver_usermessage"
     lst = data[table]
@@ -810,7 +812,7 @@ def bulk_import_user_message_data(data: TableData, dump_file_id: int) -> None:
     logging.info("Successfully imported %s from %s[%s].", model, table, dump_file_id)
 
 
-def bulk_import_model(data: TableData, model: Any, dump_file_id: str | None = None) -> None:
+def bulk_import_model(data: ImportedTableData, model: Any, dump_file_id: str | None = None) -> None:
     table = get_db_table(model)
     # TODO, deprecate dump_file_id
     model.objects.bulk_create(model(**item) for item in data[table])
@@ -820,7 +822,7 @@ def bulk_import_model(data: TableData, model: Any, dump_file_id: str | None = No
     logging.info("Successfully imported %s from %s[%s].", model, table, dump_file_id)
 
 
-def bulk_import_named_user_groups(data: TableData) -> None:
+def bulk_import_named_user_groups(data: ImportedTableData) -> None:
     vals = [
         (
             group["usergroup_ptr_id"],
@@ -854,7 +856,7 @@ def bulk_import_named_user_groups(data: TableData) -> None:
 # correctly import multiple realms into the same server, we need to
 # check if a Client object already exists, and so we need to support
 # remap all Client IDs to the values in the new DB.
-def bulk_import_client(data: TableData, model: Any, table: TableName) -> None:
+def bulk_import_client(data: ImportedTableData, model: Any, table: TableName) -> None:
     for item in data[table]:
         try:
             client = Client.objects.get(name=item["name"])
@@ -880,7 +882,7 @@ def set_subscriber_count_for_channels(realm: Realm) -> None:
 
 
 def fix_subscriptions_is_user_active_column(
-    data: TableData, user_profiles: list[UserProfile], crossrealm_user_ids: set[int]
+    data: ImportedTableData, user_profiles: list[UserProfile], crossrealm_user_ids: set[int]
 ) -> None:
     table = get_db_table(Subscription)
     user_id_to_active_status = {user.id: user.is_active for user in user_profiles}
@@ -1156,7 +1158,7 @@ def import_uploads(
             future.result()
 
 
-def disable_restricted_authentication_methods(data: TableData) -> None:
+def disable_restricted_authentication_methods(data: ImportedTableData) -> None:
     """
     Should run only with settings.BILLING_ENABLED. Ensures that we only
     enable authentication methods that are available without needing a plan.
@@ -2105,7 +2107,7 @@ def import_message_data(realm: Realm, sender_map: dict[int, Record], import_dir:
         dump_file_id += 1
 
 
-def import_attachments(data: TableData) -> None:
+def import_attachments(data: ImportedTableData) -> None:
     # Clean up the data in zerver_attachment that is not
     # relevant to our many-to-many import.
     fix_datetime_fields(data, "zerver_attachment")
@@ -2146,7 +2148,7 @@ def import_attachments(data: TableData) -> None:
     ]
 
     # Create our table data for insert.
-    m2m_data: TableData = {m2m_table_name: m2m_rows}
+    m2m_data: ImportedTableData = {m2m_table_name: m2m_rows}
     convert_to_id_fields(m2m_data, m2m_table_name, parent_singular)
     convert_to_id_fields(m2m_data, m2m_table_name, child_singular)
     m2m_rows = m2m_data[m2m_table_name]
@@ -2193,7 +2195,7 @@ def import_attachments(data: TableData) -> None:
     logging.info("Successfully imported M2M table %s", m2m_table_name)
 
 
-def fix_attachments_data(attachment_data: TableData) -> None:
+def fix_attachments_data(attachment_data: ImportedTableData) -> None:
     for attachment in attachment_data["zerver_attachment"]:
         attachment["path_id"] = path_maps["old_attachment_path_to_new_path"][attachment["path_id"]]
@@ -2205,7 +2207,7 @@ def fix_attachments_data(attachment_data: TableData) -> None:
         attachment["content_type"] = guessed_content_type
 
 
-def create_image_attachments(realm: Realm, attachment_data: TableData) -> None:
+def create_image_attachments(realm: Realm, attachment_data: ImportedTableData) -> None:
     for attachment in attachment_data["zerver_attachment"]:
         if attachment["content_type"] not in THUMBNAIL_ACCEPT_IMAGE_TYPES:
             continue
zerver/tests/test_import_export.py

@@ -849,11 +849,13 @@ class RealmImportExportTest(ExportFile):
         # Consented users:
         hamlet = self.example_user("hamlet")
         othello = self.example_user("othello")
+        cordelia = self.example_user("cordelia")
         # Iago will be non-consenting.
         iago = self.example_user("iago")
 
         do_change_user_setting(hamlet, "allow_private_data_export", True, acting_user=None)
         do_change_user_setting(othello, "allow_private_data_export", True, acting_user=None)
+        do_change_user_setting(cordelia, "allow_private_data_export", True, acting_user=None)
         do_change_user_setting(iago, "allow_private_data_export", False, acting_user=None)
 
         # Despite both hamlet and othello having consent enabled, in a public export
@@ -865,6 +867,9 @@ class RealmImportExportTest(ExportFile):
         a_message.sending_client = private_client
         a_message.save()
 
+        # Verify that a group DM between consenting users is not exported
+        self.send_group_direct_message(hamlet, [othello, cordelia])
+
         # SavedSnippets are private content - so in a public export, despite
         # hamlet having consent enabled, such objects should not be exported.
         saved_snippet = do_create_saved_snippet("test", "test", hamlet)
@@ -888,6 +893,9 @@ class RealmImportExportTest(ExportFile):
         exported_user_presence_ids = self.get_set(realm_data["zerver_userpresence"], "id")
         self.assertIn(iago_presence.id, exported_user_presence_ids)
 
+        exported_huddle_ids = self.get_set(realm_data["zerver_huddle"], "id")
+        self.assertEqual(exported_huddle_ids, set())
+
     def test_export_realm_with_member_consent(self) -> None:
         realm = Realm.objects.get(string_id="zulip")