export: Move all queries, when possible, to iterators.
This reduces overall memory usage for large exports.
committed by Tim Abbott
parent 0ffc0e810c
commit 755cb7d854
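A quick sketch of the pattern this commit applies throughout the exporter (simplified from the diff below; only make_raw and floatify_datetime_fields are real names here): helpers that used to accumulate whole tables as Python lists now yield one record at a time, and per-table post-processing is mapped lazily over the stream.

    from collections.abc import Iterable, Iterator
    from typing import Any

    from django.forms.models import model_to_dict

    def make_raw(query: Iterable[Any], exclude: list[str] | None = None) -> Iterator[dict[str, Any]]:
        # Yield one JSONable dict per database row instead of building a
        # list, so a large table is never fully resident in memory.
        for instance in query:
            yield model_to_dict(instance, exclude=exclude)

    # Post-processing is applied per record as the stream is consumed:
    # response[table] = (floatify_datetime_fields(r, table) for r in response[table])

Combined with Django's QuerySet.iterator(), which fetches rows in chunks instead of caching the entire result set, the export pipeline only ever holds the records it is currently serializing.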
zerver/lib/export.py

@@ -19,7 +19,7 @@ from collections.abc import Callable, Iterable, Iterator, Mapping
 from datetime import datetime
 from email.headerregistry import Address
 from functools import cache
-from itertools import islice
+from itertools import chain, islice
 from typing import TYPE_CHECKING, Any, Optional, TypeAlias, TypedDict, TypeVar, cast
 from urllib.parse import urlsplit
 
@@ -27,7 +27,7 @@ import orjson
 from django.apps import apps
 from django.conf import settings
 from django.db import connection
-from django.db.models import Exists, Model, OuterRef, Q
+from django.db.models import Exists, Model, OuterRef, Q, QuerySet
 from django.forms.models import model_to_dict
 from django.utils.timezone import is_naive as timezone_is_naive
 from django.utils.timezone import now as timezone_now
@@ -98,7 +98,7 @@ if TYPE_CHECKING:
 # Custom mypy types follow:
 Record: TypeAlias = dict[str, Any]
 TableName = str
-TableData: TypeAlias = dict[TableName, list[Record]]
+TableData: TypeAlias = dict[TableName, Iterator[Record] | list[Record]]
 Field = str
 Path = str
 Context: TypeAlias = dict[str, Any]
@@ -108,11 +108,11 @@ SourceFilter: TypeAlias = Callable[[Record], bool]
 
 CustomFetch: TypeAlias = Callable[[TableData, Context], None]
 CustomReturnIds: TypeAlias = Callable[[TableData], set[int]]
-CustomProcessResults: TypeAlias = Callable[[list[Record], Context], list[Record]]
+CustomProcessResults: TypeAlias = Callable[[Iterable[Record], Context], Iterator[Record]]
 
 
 class MessagePartial(TypedDict):
-    zerver_message: list[Record]
+    zerver_message: Iterable[Record]
     zerver_userprofile_ids: list[int]
     realm_id: int
 
@@ -516,12 +516,11 @@ def write_records_json_file(output_dir: str, records: Iterable[dict[str, Any]])
     logging.info("Finished writing %s", output_file)
 
 
-def make_raw(query: Any, exclude: list[Field] | None = None) -> list[Record]:
+def make_raw(query: Iterable[Any], exclude: list[Field] | None = None) -> Iterator[Record]:
     """
     Takes a Django query and returns a JSONable list
     of dictionaries corresponding to the database rows.
     """
-    rows = []
     for instance in query:
         data = model_to_dict(instance, exclude=exclude)
         """
@@ -536,20 +535,19 @@ def make_raw(query: Any, exclude: list[Field] | None = None) -> list[Record]:
             value = data[field.name]
             data[field.name] = [row.id for row in value]
 
-        rows.append(data)
+        yield data
 
-    return rows
-
 
-def floatify_datetime_fields(data: TableData, table: TableName) -> None:
-    for item in data[table]:
-        for field in DATE_FIELDS[table]:
-            dt = item[field]
-            if dt is None:
-                continue
-            assert isinstance(dt, datetime)
-            assert not timezone_is_naive(dt)
-            item[field] = dt.timestamp()
+def floatify_datetime_fields(item: Record, table: TableName) -> Record:
+    updates = {}
+    for field in DATE_FIELDS[table]:
+        dt = item[field]
+        if dt is None:
+            continue
+        assert isinstance(dt, datetime)
+        assert not timezone_is_naive(dt)
+        updates[field] = dt.timestamp()
+    return {**item, **updates}
 
 
 class Config:
@@ -587,6 +585,7 @@ class Config:
         exclude: list[Field] | None = None,
         limit_to_consenting_users: bool | None = None,
         collect_client_ids: bool = False,
+        use_iterator: bool = True,
     ) -> None:
         assert table or custom_tables
         self.table = table
@@ -698,11 +697,17 @@ class Config:
            assert include_rows in ["user_profile_id__in", "user_id__in", "bot_profile_id__in"]
            assert normal_parent is not None and normal_parent.table == "zerver_userprofile"
 
+        if self.collect_client_ids or self.is_seeded:
+            self.use_iterator = False
+        else:
+            self.use_iterator = use_iterator
+
     def return_ids(self, response: TableData) -> set[int]:
         if self.custom_return_ids is not None:
             return self.custom_return_ids(response)
         else:
             assert self.table is not None
+            assert not self.use_iterator, self.table
             return {row["id"] for row in response[self.table]}
 
 
@@ -730,7 +735,8 @@ def export_from_config(
     for t in exported_tables:
         logging.info("Exporting via export_from_config: %s", t)
 
-    rows = None
+    rows: Iterable[Any] | None = None
+    query: QuerySet[Any] | None = None
    if config.is_seeded:
        rows = [seed_object]
 
@@ -748,13 +754,13 @@ def export_from_config(
         # When we concat_and_destroy, we are working with
         # temporary "tables" that are lists of records that
         # should already be ready to export.
-        data: list[Record] = []
-        for t in config.concat_and_destroy:
-            data += response[t]
-            del response[t]
-            logging.info("Deleted temporary %s", t)
         assert table is not None
-        response[table] = data
+        # We pop them off of the response and store them in a local
+        # which the iterable closes over
+        tables = {t: response.pop(t) for t in config.concat_and_destroy}
+        response[table] = chain.from_iterable(tables[t] for t in config.concat_and_destroy)
+        for t in config.concat_and_destroy:
+            logging.info("Deleted temporary %s", t)
 
    elif config.normal_parent:
        # In this mode, our current model is figuratively Article,
@@ -818,7 +824,7 @@ def export_from_config(
 
        assert model is not None
        try:
-            query = model.objects.filter(**filter_params)
+            query = model.objects.filter(**filter_params).order_by("id")
        except Exception:
            print(
                f"""
@@ -833,8 +839,6 @@ def export_from_config(
            )
            raise
 
-        rows = list(query)
-
    elif config.id_source:
        # In this mode, we are the figurative Blog, and we now
        # need to look at the current response to get all the
@@ -843,6 +847,8 @@ def export_from_config(
        assert model is not None
        # This will be a tuple of the form ('zerver_article', 'blog').
        (child_table, field) = config.id_source
+        assert config.virtual_parent is not None
+        assert not config.virtual_parent.use_iterator
        child_rows = response[child_table]
        if config.source_filter:
            child_rows = [r for r in child_rows if config.source_filter(r)]
@@ -851,12 +857,17 @@ def export_from_config(
        if config.filter_args:
            filter_params.update(config.filter_args)
        query = model.objects.filter(**filter_params)
-        rows = list(query)
+
+    if query is not None:
+        rows = query.iterator()
 
    if rows is not None:
        assert table is not None  # Hint for mypy
        response[table] = make_raw(rows, exclude=config.exclude)
    if config.collect_client_ids and "collected_client_ids_set" in context:
+        # If we need to collect the client-ids, we can't just stream the results
+        response[table] = list(response[table])
+
        model = cast(type[Model], model)
        assert issubclass(model, Model)
        client_id_field_name = get_fk_field_name(model, Client)
@@ -873,8 +884,10 @@ def export_from_config(
            # of the exported data for the tables - e.g. to strip out private data.
            response[t] = custom_process_results(response[t], context)
        if t in DATE_FIELDS:
-            floatify_datetime_fields(response, t)
+            response[t] = (floatify_datetime_fields(r, t) for r in response[t])
 
+        if not config.use_iterator:
+            response[t] = list(response[t])
    # Now walk our children. It's extremely important to respect
    # the order of children here.
    for child_config in config.children:
@@ -998,6 +1011,7 @@ def get_realm_config() -> Config:
        table="zerver_userprofile",
        virtual_parent=realm_config,
        custom_fetch=custom_fetch_user_profile,
+        use_iterator=False,
    )
 
    user_groups_config = Config(
@@ -1006,6 +1020,7 @@ def get_realm_config() -> Config:
        normal_parent=realm_config,
        include_rows="realm_id__in",
        exclude=["direct_members", "direct_subgroups"],
+        use_iterator=False,
    )
 
    Config(
@@ -1078,6 +1093,7 @@ def get_realm_config() -> Config:
        # It is just "glue" data for internal data model consistency purposes
        # with no user-specific information.
        limit_to_consenting_users=False,
+        use_iterator=False,
    )
 
    Config(
@@ -1092,6 +1108,7 @@ def get_realm_config() -> Config:
        model=Stream,
        normal_parent=realm_config,
        include_rows="realm_id__in",
+        use_iterator=False,
    )
 
    stream_recipient_config = Config(
@@ -1100,6 +1117,7 @@ def get_realm_config() -> Config:
        normal_parent=stream_config,
        include_rows="type_id__in",
        filter_args={"type": Recipient.STREAM},
+        use_iterator=False,
    )
 
    Config(
@@ -1131,6 +1149,7 @@ def get_realm_config() -> Config:
            "_stream_recipient",
            "_huddle_recipient",
        ],
+        use_iterator=False,
    )
 
    Config(
@@ -1157,11 +1176,12 @@ def get_realm_config() -> Config:
 
 
 def custom_process_subscription_in_realm_config(
-    subscriptions: list[Record], context: Context
-) -> list[Record]:
+    subscriptions: Iterable[Record], context: Context
+) -> Iterator[Record]:
    export_type = context["export_type"]
    if export_type == RealmExport.EXPORT_FULL_WITHOUT_CONSENT:
-        return subscriptions
+        yield from subscriptions
+        return
 
    exportable_user_ids_from_context = context["exportable_user_ids"]
    if export_type == RealmExport.EXPORT_FULL_WITH_CONSENT:
@@ -1172,9 +1192,10 @@ def custom_process_subscription_in_realm_config(
        assert exportable_user_ids_from_context is None
        consented_user_ids = set()
 
-    def scrub_subscription_if_needed(subscription: Record) -> Record:
+    for subscription in subscriptions:
        if subscription["user_profile"] in consented_user_ids:
-            return subscription
+            yield subscription
+            continue
        # We create a replacement Subscription, setting only the essential fields,
        # while allowing all the other ones to fall back to the defaults
        # defined in the model.
@@ -1190,10 +1211,7 @@ def custom_process_subscription_in_realm_config(
            color=random.choice(STREAM_ASSIGNMENT_COLORS),
        )
        subscription_dict = model_to_dict(scrubbed_subscription)
-        return subscription_dict
-
-    processed_rows = map(scrub_subscription_if_needed, subscriptions)
-    return list(processed_rows)
+        yield subscription_dict
 
 
 def add_user_profile_child_configs(user_profile_config: Config) -> None:
@@ -1410,7 +1428,7 @@ def custom_fetch_user_profile(response: TableData, context: Context) -> None:
 
 def custom_fetch_user_profile_cross_realm(response: TableData, context: Context) -> None:
    realm = context["realm"]
-    response["zerver_userprofile_crossrealm"] = []
+    crossrealm_bots = []
 
    bot_name_to_default_email = {
        "NOTIFICATION_BOT": "notification-bot@zulip.com",
@@ -1432,13 +1450,14 @@ def custom_fetch_user_profile_cross_realm(response: TableData, context: Context)
        bot_user_id = get_system_bot(bot_email, internal_realm.id).id
 
        recipient_id = Recipient.objects.get(type_id=bot_user_id, type=Recipient.PERSONAL).id
-        response["zerver_userprofile_crossrealm"].append(
+        crossrealm_bots.append(
            dict(
                email=bot_default_email,
                id=bot_user_id,
                recipient_id=recipient_id,
-            )
+            ),
        )
+    response["zerver_userprofile_crossrealm"] = crossrealm_bots
 
 
 def fetch_attachment_data(
@@ -1448,10 +1467,21 @@ def fetch_attachment_data(
        Attachment.objects.filter(
            Q(messages__in=message_ids) | Q(scheduled_messages__in=scheduled_message_ids),
            realm_id=realm_id,
-        ).distinct()
-    )
-    response["zerver_attachment"] = make_raw(attachments)
-    floatify_datetime_fields(response, "zerver_attachment")
+        )
+        .distinct("path_id")
+        .order_by("path_id")
+    )
+
+    def postprocess_attachment(row: Record) -> Record:
+        row = floatify_datetime_fields(row, "zerver_attachment")
+        filtered_message_ids = set(row["messages"]).intersection(message_ids)
+        row["messages"] = sorted(filtered_message_ids)
+
+        filtered_scheduled_message_ids = set(row["scheduled_messages"]).intersection(
+            scheduled_message_ids
+        )
+        row["scheduled_messages"] = sorted(filtered_scheduled_message_ids)
+        return row
 
    """
    We usually export most messages for the realm, but not
@@ -1461,14 +1491,7 @@ def fetch_attachment_data(
 
    Same reasoning applies to scheduled_messages.
    """
-    for row in response["zerver_attachment"]:
-        filtered_message_ids = set(row["messages"]).intersection(message_ids)
-        row["messages"] = sorted(filtered_message_ids)
-
-        filtered_scheduled_message_ids = set(row["scheduled_messages"]).intersection(
-            scheduled_message_ids
-        )
-        row["scheduled_messages"] = sorted(filtered_scheduled_message_ids)
+    response["zerver_attachment"] = (postprocess_attachment(r) for r in make_raw(attachments))
 
    return attachments
 
@@ -1480,18 +1503,17 @@ def custom_fetch_realm_audit_logs_for_user(response: TableData, context: Context
    """
    user = context["user"]
    query = RealmAuditLog.objects.filter(Q(modified_user_id=user.id) | Q(acting_user_id=user.id))
-    rows = make_raw(list(query))
-    response["zerver_realmauditlog"] = rows
+    response["zerver_realmauditlog"] = make_raw(query.iterator())
 
 
 def fetch_reaction_data(response: TableData, message_ids: set[int]) -> None:
    query = Reaction.objects.filter(message_id__in=list(message_ids))
-    response["zerver_reaction"] = make_raw(list(query))
+    response["zerver_reaction"] = make_raw(query.iterator())
 
 
 def fetch_client_data(response: TableData, client_ids: set[int]) -> None:
    query = Client.objects.filter(id__in=list(client_ids))
-    response["zerver_client"] = make_raw(list(query))
+    response["zerver_client"] = make_raw(query.iterator())
 
 
 def custom_fetch_direct_message_groups(response: TableData, context: Context) -> None:
@@ -1509,7 +1531,9 @@ def custom_fetch_direct_message_groups(response: TableData, context: Context) ->
        consented_user_ids = set()
 
    user_profile_ids = {
-        r["id"] for r in response["zerver_userprofile"] + response["zerver_userprofile_mirrordummy"]
+        r["id"]
+        for r in list(response["zerver_userprofile"])
+        + list(response["zerver_userprofile_mirrordummy"])
    }
 
    recipient_filter = Q()
@@ -1575,12 +1599,6 @@ def custom_fetch_direct_message_groups(response: TableData, context: Context) ->
    response["zerver_huddle"] = make_raw(
        DirectMessageGroup.objects.filter(id__in=direct_message_group_ids)
    )
-    if export_type == RealmExport.EXPORT_PUBLIC and any(
-        response[t] for t in ["_huddle_recipient", "_huddle_subscription", "zerver_huddle"]
-    ):
-        raise AssertionError(
-            "Public export should not result in exporting any data in _huddle tables"
-        )
 
 
 def custom_fetch_scheduled_messages(response: TableData, context: Context) -> None:
@@ -1591,7 +1609,7 @@ def custom_fetch_scheduled_messages(response: TableData, context: Context) -> No
    exportable_scheduled_message_ids = context["exportable_scheduled_message_ids"]
 
    query = ScheduledMessage.objects.filter(realm=realm, id__in=exportable_scheduled_message_ids)
-    rows = make_raw(list(query))
+    rows = make_raw(query)
 
    response["zerver_scheduledmessage"] = rows
 
@@ -1654,14 +1672,15 @@ def custom_fetch_realm_audit_logs_for_realm(response: TableData, context: Contex
 
 def custom_fetch_onboarding_usermessage(response: TableData, context: Context) -> None:
    realm = context["realm"]
-    response["zerver_onboardingusermessage"] = []
+    onboarding = []
 
    onboarding_usermessage_query = OnboardingUserMessage.objects.filter(realm=realm)
    for onboarding_usermessage in onboarding_usermessage_query:
        onboarding_usermessage_obj = model_to_dict(onboarding_usermessage)
        onboarding_usermessage_obj["flags_mask"] = onboarding_usermessage.flags.mask
        del onboarding_usermessage_obj["flags"]
-        response["zerver_onboardingusermessage"].append(onboarding_usermessage_obj)
+        onboarding.append(onboarding_usermessage_obj)
+    response["zerver_onboardingusermessage"] = onboarding
 
 
 def fetch_usermessages(
@@ -1710,7 +1729,8 @@ def export_usermessages_batch(
    with open(input_path, "rb") as input_file:
        input_data: MessagePartial = orjson.loads(input_file.read())
 
-    message_ids = {item["id"] for item in input_data["zerver_message"]}
+    messages = list(input_data["zerver_message"])
+    message_ids = {item["id"] for item in messages}
    user_profile_ids = set(input_data["zerver_userprofile_ids"])
    realm = Realm.objects.get(id=input_data["realm_id"])
    zerver_usermessage_data = fetch_usermessages(
@@ -1723,7 +1743,7 @@ def export_usermessages_batch(
    )
 
    output_data: TableData = dict(
-        zerver_message=input_data["zerver_message"],
+        zerver_message=messages,
        zerver_usermessage=zerver_usermessage_data,
    )
    write_table_data(output_path, output_data)
@@ -1764,9 +1784,9 @@ def export_partial_message_files(
        response["zerver_userprofile"],
    )
    ids_of_our_possible_senders = get_ids(
-        response["zerver_userprofile"]
-        + response["zerver_userprofile_mirrordummy"]
-        + response["zerver_userprofile_crossrealm"]
+        list(response["zerver_userprofile"])
+        + list(response["zerver_userprofile_mirrordummy"])
+        + list(response["zerver_userprofile_crossrealm"])
    )
 
    consented_user_ids: set[int] = set()
@@ -1913,7 +1933,9 @@ def write_message_partials(
    for message_id_chunk in message_id_chunks:
        # Uses index: zerver_message_pkey
        actual_query = Message.objects.filter(id__in=message_id_chunk).order_by("id")
-        message_chunk = make_raw(actual_query)
+        message_chunk = [
+            floatify_datetime_fields(r, "zerver_message") for r in make_raw(actual_query.iterator())
+        ]
 
        for row in message_chunk:
            collected_client_ids.add(row["sending_client"])
@@ -1923,16 +1945,11 @@ def write_message_partials(
        message_filename += ".partial"
        logging.info("Fetched messages for %s", message_filename)
 
-        # Clean up our messages.
-        table_data: TableData = {}
-        table_data["zerver_message"] = message_chunk
-        floatify_datetime_fields(table_data, "zerver_message")
-
        # Build up our output for the .partial file, which needs
        # a list of user_profile_ids to search for (as well as
        # the realm id).
        output: MessagePartial = dict(
-            zerver_message=table_data["zerver_message"],
+            zerver_message=message_chunk,
            zerver_userprofile_ids=list(user_profile_ids),
            realm_id=realm.id,
        )
@@ -1945,7 +1962,7 @@ def write_message_partials(
 def export_uploads_and_avatars(
    realm: Realm,
    *,
-    attachments: list[Attachment] | None = None,
+    attachments: Iterable[Attachment] | None = None,
    user: UserProfile | None,
    output_dir: Path,
 ) -> None:
@@ -1974,7 +1991,7 @@ def export_uploads_and_avatars(
    else:
        handle_system_bots = False
        users = [user]
-        attachments = list(Attachment.objects.filter(owner_id=user.id))
+        attachments = list(Attachment.objects.filter(owner_id=user.id).order_by("path_id"))
        realm_emojis = list(RealmEmoji.objects.filter(author_id=user.id))
 
    if settings.LOCAL_UPLOADS_DIR:
@@ -2164,7 +2181,6 @@ def export_files_from_s3(
    processing_emoji = flavor == "emoji"
 
    bucket = get_bucket(bucket_name)
-    records = []
 
    logging.info("Downloading %s files from %s", flavor, bucket_name)
 
@@ -2175,8 +2191,11 @@ def export_files_from_s3(
        email_gateway_bot = get_system_bot(settings.EMAIL_GATEWAY_BOT, internal_realm.id)
        user_ids.add(email_gateway_bot.id)
 
+    def iterate_attachments() -> Iterator[Record]:
        count = 0
        for bkey in bucket.objects.filter(Prefix=object_prefix):
+            # This is promised to be iterated in sorted filename order.
+
            if valid_hashes is not None and bkey.Object().key not in valid_hashes:
                continue
 
@@ -2213,19 +2232,19 @@ def export_files_from_s3(
            record["path"] = key.key
            _save_s3_object_to_file(key, output_dir, processing_uploads)
 
-            records.append(record)
+            yield record
            count += 1
 
            if count % 100 == 0:
                logging.info("Finished %s", count)
 
-    write_records_json_file(output_dir, records)
+    write_records_json_file(output_dir, iterate_attachments())
 
 
 def export_uploads_from_local(
-    realm: Realm, local_dir: Path, output_dir: Path, attachments: list[Attachment]
+    realm: Realm, local_dir: Path, output_dir: Path, attachments: Iterable[Attachment]
 ) -> None:
-    records = []
+    def iterate_attachments() -> Iterator[Record]:
        for count, attachment in enumerate(attachments, 1):
            # Use 'mark_sanitized' to work around false positive caused by Pysa
            # thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
@@ -2248,12 +2267,12 @@ def export_uploads_from_local(
                last_modified=stat.st_mtime,
                content_type=None,
            )
-            records.append(record)
+            yield record
 
            if count % 100 == 0:
                logging.info("Finished %s", count)
 
-    write_records_json_file(output_dir, records)
+    write_records_json_file(output_dir, iterate_attachments())
 
 
 def export_avatars_from_local(
@@ -2337,24 +2356,18 @@ def get_emoji_path(realm_emoji: RealmEmoji) -> str:
 
 
 def export_emoji_from_local(
-    realm: Realm, local_dir: Path, output_dir: Path, realm_emojis: list[RealmEmoji]
+    realm: Realm, local_dir: Path, output_dir: Path, realm_emojis: Iterable[RealmEmoji]
 ) -> None:
-    records = []
-
-    realm_emoji_helper_tuples: list[tuple[RealmEmoji, str]] = []
-    for realm_emoji in realm_emojis:
-        realm_emoji_path = get_emoji_path(realm_emoji)
-
-        # Use 'mark_sanitized' to work around false positive caused by Pysa
-        # thinking that 'realm' (and thus 'attachment' and 'attachment.path_id')
-        # are user controlled
-        realm_emoji_path = mark_sanitized(realm_emoji_path)
-
-        realm_emoji_path_original = realm_emoji_path + ".original"
-
-        realm_emoji_helper_tuples.append((realm_emoji, realm_emoji_path))
-        realm_emoji_helper_tuples.append((realm_emoji, realm_emoji_path_original))
+    def emoji_path_tuples() -> Iterator[tuple[RealmEmoji, str]]:
+        for realm_emoji in realm_emojis:
+            realm_emoji_path = mark_sanitized(get_emoji_path(realm_emoji))
+
+            yield (realm_emoji, realm_emoji_path)
+            yield (realm_emoji, realm_emoji_path + ".original")
 
+    def iterate_emoji(
+        realm_emoji_helper_tuples: Iterator[tuple[RealmEmoji, str]],
+    ) -> Iterator[Record]:
        for count, realm_emoji_helper_tuple in enumerate(realm_emoji_helper_tuples, 1):
            realm_emoji_object, emoji_path = realm_emoji_helper_tuple
 
@@ -2365,9 +2378,7 @@ def export_emoji_from_local(
            shutil.copy2(local_path, output_path)
            # Realm emoji author is optional.
            author = realm_emoji_object.author
-            author_id = None
-            if author:
-                author_id = author.id
+            author_id = author.id if author else None
            record = dict(
                realm_id=realm.id,
                author=author_id,
@@ -2377,12 +2388,12 @@ def export_emoji_from_local(
                name=realm_emoji_object.name,
                deactivated=realm_emoji_object.deactivated,
            )
-            records.append(record)
+            yield record
 
            if count % 100 == 0:
                logging.info("Finished %s", count)
 
-    write_records_json_file(output_dir, records)
+    write_records_json_file(output_dir, iterate_emoji(emoji_path_tuples()))
 
 
 def do_write_stats_file_for_realm_export(output_dir: Path) -> dict[str, int | dict[str, int]]:
@@ -2502,17 +2513,13 @@ def do_export_realm(
    )
    logging.info("%d messages were exported", len(message_ids))
 
-    # zerver_reaction
-    zerver_reaction: TableData = {}
-    fetch_reaction_data(response=zerver_reaction, message_ids=message_ids)
-    response.update(zerver_reaction)
+    fetch_reaction_data(response=response, message_ids=message_ids)
 
-    zerver_client: TableData = {}
-    fetch_client_data(response=zerver_client, client_ids=collected_client_ids)
-    response.update(zerver_client)
+    fetch_client_data(response=response, client_ids=collected_client_ids)
 
    # Override the "deactivated" flag on the realm
    if export_as_active is not None:
+        assert isinstance(response["zerver_realm"], list)
        response["zerver_realm"][0]["deactivated"] = not export_as_active
 
    response["import_source"] = "zulip"  # type: ignore[assignment] # this is an extra info field, not TableData
@@ -2619,6 +2626,8 @@ def do_export_user(user_profile: UserProfile, output_dir: Path) -> None:
    export_file = os.path.join(output_dir, "user.json")
    write_table_data(output_file=export_file, data=response)
 
+    # Double-check that this is a list, and not an iterator, so we can run over it again
+    assert isinstance(response["zerver_reaction"], list)
    reaction_message_ids: set[int] = {row["message"] for row in response["zerver_reaction"]}
 
    logging.info("Exporting messages")
@@ -2661,6 +2670,7 @@ def get_single_user_config() -> Config:
        # Exports with consent are not relevant in the context of exporting
        # a single user.
        limit_to_consenting_users=False,
+        use_iterator=False,
    )
 
    # zerver_recipient
@@ -2669,6 +2679,7 @@ def get_single_user_config() -> Config:
        model=Recipient,
        virtual_parent=subscription_config,
        id_source=("zerver_subscription", "recipient"),
+        use_iterator=False,
    )
 
    # zerver_stream
@@ -2713,6 +2724,7 @@ def get_single_user_config() -> Config:
        normal_parent=user_profile_config,
        include_rows="user_profile_id__in",
        limit_to_consenting_users=False,
+        use_iterator=False,
    )
 
    add_user_profile_child_configs(user_profile_config)
@@ -2823,25 +2835,23 @@ def export_messages_single_user(
        .order_by("message_id")
    )
 
-    user_message_chunk = list(fat_query)
-
-    message_chunk = []
-    for user_message in user_message_chunk:
+    def process_row(user_message: UserMessage) -> Record:
        item = model_to_dict(user_message.message)
        item["flags"] = user_message.flags_list()
        item["flags_mask"] = user_message.flags.mask
        # Add a few nice, human-readable details
        item["sending_client_name"] = user_message.message.sending_client.name
        item["recipient_name"] = get_recipient(user_message.message.recipient_id)
-        message_chunk.append(item)
+        return floatify_datetime_fields(item, "zerver_message")
 
    message_filename = os.path.join(output_dir, f"messages-{dump_file_id:06}.json")
+    write_table_data(
+        message_filename,
+        {
+            "zerver_message": (process_row(um) for um in fat_query.iterator()),
+        },
+    )
    logging.info("Fetched messages for %s", message_filename)
 
-    output = {"zerver_message": message_chunk}
-    floatify_datetime_fields(output, "zerver_message")
-
-    write_table_data(message_filename, output)
-
    dump_file_id += 1
 
|
|||||||
@@ -5,7 +5,7 @@ import shutil
|
|||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from difflib import unified_diff
|
from difflib import unified_diff
|
||||||
from typing import Any
|
from typing import Any, TypeAlias
|
||||||
|
|
||||||
import bmemcached
|
import bmemcached
|
||||||
import orjson
|
import orjson
|
||||||
@@ -33,7 +33,7 @@ from zerver.actions.realm_settings import (
|
|||||||
from zerver.actions.user_settings import do_change_avatar_fields
|
from zerver.actions.user_settings import do_change_avatar_fields
|
||||||
from zerver.lib.avatar_hash import user_avatar_base_path_from_ids
|
from zerver.lib.avatar_hash import user_avatar_base_path_from_ids
|
||||||
from zerver.lib.bulk_create import bulk_set_users_or_streams_recipient_fields
|
from zerver.lib.bulk_create import bulk_set_users_or_streams_recipient_fields
|
||||||
from zerver.lib.export import DATE_FIELDS, Field, Path, Record, TableData, TableName
|
from zerver.lib.export import DATE_FIELDS, Field, Path, Record, TableName
|
||||||
from zerver.lib.markdown import markdown_convert
|
from zerver.lib.markdown import markdown_convert
|
||||||
from zerver.lib.markdown import version as markdown_version
|
from zerver.lib.markdown import version as markdown_version
|
||||||
from zerver.lib.message import get_last_message_id
|
from zerver.lib.message import get_last_message_id
|
||||||
@@ -118,6 +118,8 @@ from zerver.models.recipients import get_direct_message_group_hash
|
|||||||
from zerver.models.users import get_system_bot, get_user_profile_by_id
|
from zerver.models.users import get_system_bot, get_user_profile_by_id
|
||||||
from zproject.backends import AUTH_BACKEND_NAME_MAP
|
from zproject.backends import AUTH_BACKEND_NAME_MAP
|
||||||
|
|
||||||
|
ImportedTableData: TypeAlias = dict[str, list[Record]]
|
||||||
|
|
||||||
realm_tables = [
|
realm_tables = [
|
||||||
("zerver_realmauthenticationmethod", RealmAuthenticationMethod, "realmauthenticationmethod"),
|
("zerver_realmauthenticationmethod", RealmAuthenticationMethod, "realmauthenticationmethod"),
|
||||||
("zerver_defaultstream", DefaultStream, "defaultstream"),
|
("zerver_defaultstream", DefaultStream, "defaultstream"),
|
||||||
@@ -208,7 +210,7 @@ message_id_to_attachments: dict[str, dict[int, list[str]]] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def map_messages_to_attachments(data: TableData) -> None:
|
def map_messages_to_attachments(data: ImportedTableData) -> None:
|
||||||
for attachment in data["zerver_attachment"]:
|
for attachment in data["zerver_attachment"]:
|
||||||
for message_id in attachment["messages"]:
|
for message_id in attachment["messages"]:
|
||||||
message_id_to_attachments["zerver_message"][message_id].append(attachment["path_id"])
|
message_id_to_attachments["zerver_message"][message_id].append(attachment["path_id"])
|
||||||
@@ -231,14 +233,14 @@ def update_id_map(table: TableName, old_id: int, new_id: int) -> None:
|
|||||||
ID_MAP[table][old_id] = new_id
|
ID_MAP[table][old_id] = new_id
|
||||||
|
|
||||||
|
|
||||||
def fix_datetime_fields(data: TableData, table: TableName) -> None:
|
def fix_datetime_fields(data: ImportedTableData, table: TableName) -> None:
|
||||||
for item in data[table]:
|
for item in data[table]:
|
||||||
for field_name in DATE_FIELDS[table]:
|
for field_name in DATE_FIELDS[table]:
|
||||||
if item[field_name] is not None:
|
if item[field_name] is not None:
|
||||||
item[field_name] = datetime.fromtimestamp(item[field_name], tz=timezone.utc)
|
item[field_name] = datetime.fromtimestamp(item[field_name], tz=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
def fix_upload_links(data: TableData, message_table: TableName) -> None:
|
def fix_upload_links(data: ImportedTableData, message_table: TableName) -> None:
|
||||||
"""
|
"""
|
||||||
Because the URLs for uploaded files encode the realm ID of the
|
Because the URLs for uploaded files encode the realm ID of the
|
||||||
organization being imported (which is only determined at import
|
organization being imported (which is only determined at import
|
||||||
@@ -260,7 +262,7 @@ def fix_upload_links(data: TableData, message_table: TableName) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def fix_stream_permission_group_settings(
|
def fix_stream_permission_group_settings(
|
||||||
data: TableData, system_groups_name_dict: dict[str, NamedUserGroup]
|
data: ImportedTableData, system_groups_name_dict: dict[str, NamedUserGroup]
|
||||||
) -> None:
|
) -> None:
|
||||||
table = get_db_table(Stream)
|
table = get_db_table(Stream)
|
||||||
for stream in data[table]:
|
for stream in data[table]:
|
||||||
@@ -289,7 +291,7 @@ def fix_stream_permission_group_settings(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_subscription_events(data: TableData, realm_id: int) -> None:
|
def create_subscription_events(data: ImportedTableData, realm_id: int) -> None:
|
||||||
"""
|
"""
|
||||||
When the export data doesn't contain the table `zerver_realmauditlog`,
|
When the export data doesn't contain the table `zerver_realmauditlog`,
|
||||||
this function creates RealmAuditLog objects for `subscription_created`
|
this function creates RealmAuditLog objects for `subscription_created`
|
||||||
@@ -332,7 +334,7 @@ def create_subscription_events(data: TableData, realm_id: int) -> None:
|
|||||||
RealmAuditLog.objects.bulk_create(all_subscription_logs)
|
RealmAuditLog.objects.bulk_create(all_subscription_logs)
|
||||||
|
|
||||||
|
|
||||||
def fix_service_tokens(data: TableData, table: TableName) -> None:
|
def fix_service_tokens(data: ImportedTableData, table: TableName) -> None:
|
||||||
"""
|
"""
|
||||||
The tokens in the services are created by 'generate_api_key'.
|
The tokens in the services are created by 'generate_api_key'.
|
||||||
As the tokens are unique, they should be re-created for the imports.
|
As the tokens are unique, they should be re-created for the imports.
|
||||||
@@ -341,7 +343,7 @@ def fix_service_tokens(data: TableData, table: TableName) -> None:
|
|||||||
item["token"] = generate_api_key()
|
item["token"] = generate_api_key()
|
||||||
|
|
||||||
|
|
||||||
def process_direct_message_group_hash(data: TableData, table: TableName) -> None:
|
def process_direct_message_group_hash(data: ImportedTableData, table: TableName) -> None:
|
||||||
"""
|
"""
|
||||||
Build new direct message group hashes with the updated ids of the users
|
Build new direct message group hashes with the updated ids of the users
|
||||||
"""
|
"""
|
||||||
@@ -350,7 +352,7 @@ def process_direct_message_group_hash(data: TableData, table: TableName) -> None
|
|||||||
direct_message_group["huddle_hash"] = get_direct_message_group_hash(user_id_list)
|
direct_message_group["huddle_hash"] = get_direct_message_group_hash(user_id_list)
|
||||||
|
|
||||||
|
|
||||||
def get_direct_message_groups_from_subscription(data: TableData, table: TableName) -> None:
|
def get_direct_message_groups_from_subscription(data: ImportedTableData, table: TableName) -> None:
|
||||||
"""
|
"""
|
||||||
Extract the IDs of the user_profiles involved in a direct message group from
|
Extract the IDs of the user_profiles involved in a direct message group from
|
||||||
the subscription object
|
the subscription object
|
||||||
@@ -369,7 +371,7 @@ def get_direct_message_groups_from_subscription(data: TableData, table: TableNam
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def fix_customprofilefield(data: TableData) -> None:
|
def fix_customprofilefield(data: ImportedTableData) -> None:
|
||||||
"""
|
"""
|
||||||
In CustomProfileField with 'field_type' like 'USER', the IDs need to be
|
In CustomProfileField with 'field_type' like 'USER', the IDs need to be
|
||||||
re-mapped.
|
re-mapped.
|
||||||
@@ -530,7 +532,7 @@ def fix_message_edit_history(
|
|||||||
message["edit_history"] = orjson.dumps(edit_history).decode()
|
message["edit_history"] = orjson.dumps(edit_history).decode()
|
||||||
|
|
||||||
|
|
||||||
def current_table_ids(data: TableData, table: TableName) -> list[int]:
|
def current_table_ids(data: ImportedTableData, table: TableName) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Returns the ids present in the current table
|
Returns the ids present in the current table
|
||||||
"""
|
"""
|
||||||
@@ -560,7 +562,7 @@ def allocate_ids(model_class: Any, count: int) -> list[int]:
|
|||||||
return [item[0] for item in query]
|
return [item[0] for item in query]
|
||||||
|
|
||||||
|
|
||||||
def convert_to_id_fields(data: TableData, table: TableName, field_name: Field) -> None:
|
def convert_to_id_fields(data: ImportedTableData, table: TableName, field_name: Field) -> None:
|
||||||
"""
|
"""
|
||||||
When Django gives us dict objects via model_to_dict, the foreign
|
When Django gives us dict objects via model_to_dict, the foreign
|
||||||
key fields are `foo`, but we want `foo_id` for the bulk insert.
|
key fields are `foo`, but we want `foo_id` for the bulk insert.
|
||||||
@@ -574,7 +576,7 @@ def convert_to_id_fields(data: TableData, table: TableName, field_name: Field) -
|
|||||||
|
|
||||||
|
|
||||||
def re_map_foreign_keys(
|
def re_map_foreign_keys(
|
||||||
data: TableData,
|
data: ImportedTableData,
|
||||||
table: TableName,
|
table: TableName,
|
||||||
field_name: Field,
|
field_name: Field,
|
||||||
related_table: TableName,
|
related_table: TableName,
|
||||||
@@ -585,7 +587,7 @@ def re_map_foreign_keys(
|
|||||||
"""
|
"""
|
||||||
This is a wrapper function for all the realm data tables
|
This is a wrapper function for all the realm data tables
|
||||||
and only avatar and attachment records need to be passed through the internal function
|
and only avatar and attachment records need to be passed through the internal function
|
||||||
because of the difference in data format (TableData corresponding to realm data tables
|
because of the difference in data format (ImportedTableData corresponding to realm data tables
|
||||||
and List[Record] corresponding to the avatar and attachment records)
|
and List[Record] corresponding to the avatar and attachment records)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -653,7 +655,7 @@ def re_map_foreign_keys_internal(
|
|||||||
item[field_name] = new_id
|
item[field_name] = new_id
|
||||||
|
|
||||||
|
|
||||||
def re_map_realm_emoji_codes(data: TableData, *, table_name: str) -> None:
|
def re_map_realm_emoji_codes(data: ImportedTableData, *, table_name: str) -> None:
|
||||||
"""
|
"""
|
||||||
Some tables, including Reaction and UserStatus, contain a form of
|
Some tables, including Reaction and UserStatus, contain a form of
|
||||||
foreign key reference to the RealmEmoji table in the form of
|
foreign key reference to the RealmEmoji table in the form of
|
||||||
@@ -683,7 +685,7 @@ def re_map_realm_emoji_codes(data: TableData, *, table_name: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def re_map_foreign_keys_many_to_many(
|
def re_map_foreign_keys_many_to_many(
|
||||||
data: TableData,
|
data: ImportedTableData,
|
||||||
table: TableName,
|
table: TableName,
|
||||||
field_name: Field,
|
field_name: Field,
|
||||||
related_table: TableName,
|
related_table: TableName,
|
||||||
@@ -733,13 +735,13 @@ def re_map_foreign_keys_many_to_many_internal(
|
|||||||
return new_id_list
|
return new_id_list
|
||||||
|
|
||||||
|
|
||||||
def fix_bitfield_keys(data: TableData, table: TableName, field_name: Field) -> None:
|
def fix_bitfield_keys(data: ImportedTableData, table: TableName, field_name: Field) -> None:
|
||||||
for item in data[table]:
|
for item in data[table]:
|
||||||
item[field_name] = item[field_name + "_mask"]
|
item[field_name] = item[field_name + "_mask"]
|
||||||
del item[field_name + "_mask"]
|
del item[field_name + "_mask"]
|
||||||
|
|
||||||
|
|
||||||
def remove_denormalized_recipient_column_from_data(data: TableData) -> None:
|
def remove_denormalized_recipient_column_from_data(data: ImportedTableData) -> None:
|
||||||
"""
|
"""
|
||||||
The recipient column shouldn't be imported, we'll set the correct values
|
The recipient column shouldn't be imported, we'll set the correct values
|
||||||
when Recipient table gets imported.
|
when Recipient table gets imported.
|
||||||
@@ -762,7 +764,7 @@ def get_db_table(model_class: Any) -> str:
|
|||||||
return model_class._meta.db_table
|
return model_class._meta.db_table
|
||||||
|
|
||||||
|
|
||||||
def update_model_ids(model: Any, data: TableData, related_table: TableName) -> None:
|
def update_model_ids(model: Any, data: ImportedTableData, related_table: TableName) -> None:
|
||||||
table = get_db_table(model)
|
table = get_db_table(model)
|
||||||
|
|
||||||
# Important: remapping usermessage rows is
|
# Important: remapping usermessage rows is
|
||||||
@@ -777,7 +779,7 @@ def update_model_ids(model: Any, data: TableData, related_table: TableName) -> N
|
|||||||
re_map_foreign_keys(data, table, "id", related_table=related_table, id_field=True)
|
re_map_foreign_keys(data, table, "id", related_table=related_table, id_field=True)
|
||||||
|
|
||||||
|
|
||||||
def bulk_import_user_message_data(data: TableData, dump_file_id: int) -> None:
|
def bulk_import_user_message_data(data: ImportedTableData, dump_file_id: int) -> None:
|
||||||
model = UserMessage
|
model = UserMessage
|
||||||
table = "zerver_usermessage"
|
table = "zerver_usermessage"
|
||||||
lst = data[table]
|
lst = data[table]
|
||||||
@@ -810,7 +812,7 @@ def bulk_import_user_message_data(data: TableData, dump_file_id: int) -> None:
|
|||||||
logging.info("Successfully imported %s from %s[%s].", model, table, dump_file_id)
|
logging.info("Successfully imported %s from %s[%s].", model, table, dump_file_id)
|
||||||
|
|
||||||
|
|
||||||
def bulk_import_model(data: TableData, model: Any, dump_file_id: str | None = None) -> None:
|
def bulk_import_model(data: ImportedTableData, model: Any, dump_file_id: str | None = None) -> None:
|
||||||
table = get_db_table(model)
|
table = get_db_table(model)
|
||||||
# TODO, deprecate dump_file_id
|
# TODO, deprecate dump_file_id
|
||||||
model.objects.bulk_create(model(**item) for item in data[table])
|
model.objects.bulk_create(model(**item) for item in data[table])
|
||||||
@@ -820,7 +822,7 @@ def bulk_import_model(data: TableData, model: Any, dump_file_id: str | None = None) -> None:
     logging.info("Successfully imported %s from %s[%s].", model, table, dump_file_id)
 
 
-def bulk_import_named_user_groups(data: TableData) -> None:
+def bulk_import_named_user_groups(data: ImportedTableData) -> None:
     vals = [
         (
             group["usergroup_ptr_id"],
@@ -854,7 +856,7 @@ def bulk_import_named_user_groups(data: TableData) -> None:
 # correctly import multiple realms into the same server, we need to
 # check if a Client object already exists, and so we need to support
 # remap all Client IDs to the values in the new DB.
-def bulk_import_client(data: TableData, model: Any, table: TableName) -> None:
+def bulk_import_client(data: ImportedTableData, model: Any, table: TableName) -> None:
     for item in data[table]:
         try:
             client = Client.objects.get(name=item["name"])
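The comment above is the crux: Client rows are server-wide rather than per-realm, so importing a second realm must reuse rows that already exist. A hedged sketch of the get-or-create-and-remap pattern (helper and map names invented; the real code records the mapping via its own ID-map machinery):

    from zerver.models import Client

    old_to_new_client_id: dict[int, int] = {}

    def import_clients(rows: list[dict]) -> None:
        for item in rows:
            # Reuse an existing Client with this name if the server
            # already has one; otherwise create it.
            client, _created = Client.objects.get_or_create(name=item["name"])
            # Record old ID -> new ID so later tables (e.g. messages'
            # sending_client_id) can be rewritten to the surviving row.
            old_to_new_client_id[item["id"]] = client.id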
@@ -880,7 +882,7 @@ def set_subscriber_count_for_channels(realm: Realm) -> None:
 
 
 def fix_subscriptions_is_user_active_column(
-    data: TableData, user_profiles: list[UserProfile], crossrealm_user_ids: set[int]
+    data: ImportedTableData, user_profiles: list[UserProfile], crossrealm_user_ids: set[int]
 ) -> None:
     table = get_db_table(Subscription)
     user_id_to_active_status = {user.id: user.is_active for user in user_profiles}
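A hedged sketch of what this denormalization fix plausibly does with the map built above: Subscription.is_user_active mirrors UserProfile.is_active, and cross-realm bots (which are not part of the exported realm) are treated as active. Names and exact handling are assumptions, not the verbatim implementation:

    def fill_is_user_active(rows, user_id_to_active_status, crossrealm_user_ids):
        # Each `sub` is an exported zerver_subscription record.
        for sub in rows:
            user_id = sub["user_profile_id"]
            if user_id in crossrealm_user_ids:
                # Cross-realm bots live outside the exported realm, so
                # there is no UserProfile row to consult; assume active.
                sub["is_user_active"] = True
            else:
                sub["is_user_active"] = user_id_to_active_status[user_id]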
@@ -1156,7 +1158,7 @@ def import_uploads(
                 future.result()
 
 
-def disable_restricted_authentication_methods(data: TableData) -> None:
+def disable_restricted_authentication_methods(data: ImportedTableData) -> None:
     """
     Should run only with settings.BILLING_ENABLED. Ensures that we only
     enable authentication methods that are available without needing a plan.
@@ -2105,7 +2107,7 @@ def import_message_data(realm: Realm, sender_map: dict[int, Record], import_dir:
         dump_file_id += 1
 
 
-def import_attachments(data: TableData) -> None:
+def import_attachments(data: ImportedTableData) -> None:
     # Clean up the data in zerver_attachment that is not
     # relevant to our many-to-many import.
     fix_datetime_fields(data, "zerver_attachment")
@@ -2146,7 +2148,7 @@ def import_attachments(data: TableData) -> None:
     ]
 
     # Create our table data for insert.
-    m2m_data: TableData = {m2m_table_name: m2m_rows}
+    m2m_data: ImportedTableData = {m2m_table_name: m2m_rows}
     convert_to_id_fields(m2m_data, m2m_table_name, parent_singular)
     convert_to_id_fields(m2m_data, m2m_table_name, child_singular)
     m2m_rows = m2m_data[m2m_table_name]
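After model_to_dict, a foreign key appears under its bare name (e.g. "attachment"), while the raw bulk insert needs the concrete "<name>_id" column. A minimal sketch of what a helper like convert_to_id_fields plausibly does, assuming it only renames the key (remapping old IDs to new ones is the job of the re_map_foreign_keys family):

    def convert_to_id_fields(data, table, field_name):
        # Rename model_to_dict's "foo" key to the "foo_id" database
        # column expected by the raw INSERT.
        for item in data[table]:
            item[field_name + "_id"] = item.pop(field_name)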
@@ -2193,7 +2195,7 @@ def import_attachments(data: TableData) -> None:
     logging.info("Successfully imported M2M table %s", m2m_table_name)
 
 
-def fix_attachments_data(attachment_data: TableData) -> None:
+def fix_attachments_data(attachment_data: ImportedTableData) -> None:
     for attachment in attachment_data["zerver_attachment"]:
         attachment["path_id"] = path_maps["old_attachment_path_to_new_path"][attachment["path_id"]]
 
@@ -2205,7 +2207,7 @@ def fix_attachments_data(attachment_data: TableData) -> None:
         attachment["content_type"] = guessed_content_type
 
 
-def create_image_attachments(realm: Realm, attachment_data: TableData) -> None:
+def create_image_attachments(realm: Realm, attachment_data: ImportedTableData) -> None:
     for attachment in attachment_data["zerver_attachment"]:
         if attachment["content_type"] not in THUMBNAIL_ACCEPT_IMAGE_TYPES:
             continue
@@ -849,11 +849,13 @@ class RealmImportExportTest(ExportFile):
         # Consented users:
         hamlet = self.example_user("hamlet")
         othello = self.example_user("othello")
+        cordelia = self.example_user("cordelia")
         # Iago will be non-consenting.
         iago = self.example_user("iago")
 
         do_change_user_setting(hamlet, "allow_private_data_export", True, acting_user=None)
         do_change_user_setting(othello, "allow_private_data_export", True, acting_user=None)
+        do_change_user_setting(cordelia, "allow_private_data_export", True, acting_user=None)
         do_change_user_setting(iago, "allow_private_data_export", False, acting_user=None)
 
         # Despite both hamlet and othello having consent enabled, in a public export
@@ -865,6 +867,9 @@ class RealmImportExportTest(ExportFile):
         a_message.sending_client = private_client
         a_message.save()
 
+        # Verify that a group DM between consenting users is not exported
+        self.send_group_direct_message(hamlet, [othello, cordelia])
+
         # SavedSnippets are private content - so in a public export, despite
         # hamlet having consent enabled, such objects should not be exported.
         saved_snippet = do_create_saved_snippet("test", "test", hamlet)
@@ -888,6 +893,9 @@ class RealmImportExportTest(ExportFile):
         exported_user_presence_ids = self.get_set(realm_data["zerver_userpresence"], "id")
         self.assertIn(iago_presence.id, exported_user_presence_ids)
 
+        exported_huddle_ids = self.get_set(realm_data["zerver_huddle"], "id")
+        self.assertEqual(exported_huddle_ids, set())
+
     def test_export_realm_with_member_consent(self) -> None:
         realm = Realm.objects.get(string_id="zulip")
 
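These assertions rely on a small test helper; a hedged sketch of get_set as used here, assuming it simply collects one field across the exported records:

    def get_set(records, field):
        # E.g. get_set(realm_data["zerver_huddle"], "id") -> the set of IDs
        # of the direct-message-group rows present in the export.
        return {record[field] for record in records}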