export-search: Use background workers to download attachments.

This commit is contained in:
Alex Vandiver
2025-08-22 03:57:24 +00:00
committed by Tim Abbott
parent 7714ca3ff9
commit d15f2fb831

View File

@@ -1,5 +1,6 @@
import csv import csv
import os import os
import queue
import shutil import shutil
from argparse import ArgumentParser from argparse import ArgumentParser
from collections.abc import Iterator from collections.abc import Iterator
@@ -7,7 +8,8 @@ from datetime import datetime, timezone
from email.headerregistry import Address from email.headerregistry import Address
from functools import lru_cache, reduce from functools import lru_cache, reduce
from operator import or_ from operator import or_
from typing import Any from threading import Lock, Thread
from typing import Any, NoReturn, Union
import orjson import orjson
from django.core.management.base import CommandError from django.core.management.base import CommandError
@@ -17,20 +19,39 @@ from typing_extensions import override
from zerver.lib.management import ZulipBaseCommand from zerver.lib.management import ZulipBaseCommand
from zerver.lib.soft_deactivation import reactivate_user_if_soft_deactivated from zerver.lib.soft_deactivation import reactivate_user_if_soft_deactivated
from zerver.lib.upload import save_attachment_contents from zerver.lib.upload import save_attachment_contents
from zerver.models import AbstractUserMessage, Attachment, Message, Recipient, Stream, UserProfile from zerver.models import AbstractUserMessage, Message, Recipient, Stream, UserProfile
from zerver.models.recipients import get_direct_message_group, get_or_create_direct_message_group from zerver.models.recipients import get_direct_message_group, get_or_create_direct_message_group
from zerver.models.streams import get_stream from zerver.models.streams import get_stream
from zerver.models.users import get_user_by_delivery_email from zerver.models.users import get_user_by_delivery_email
check_lock = Lock()
download_queue: queue.Queue[str] = queue.Queue()
BATCH_SIZE = 1000 BATCH_SIZE = 1000
def write_attachment(base_path: str, attachment: Attachment) -> None: def write_attachment(base_path: str, path_id: str, file_lock: Union["Lock", None] = None) -> None:
dir_path_id = os.path.dirname(attachment.path_id) dir_path_id = os.path.dirname(path_id)
assert "../" not in dir_path_id assert "../" not in dir_path_id
os.makedirs(base_path + "/" + dir_path_id, exist_ok=True) os.makedirs(base_path + "/" + dir_path_id, exist_ok=True)
with open(base_path + "/" + attachment.path_id, "wb") as attachment_file: with open(base_path + "/" + path_id, "wb") as attachment_file:
save_attachment_contents(attachment.path_id, attachment_file) if file_lock:
file_lock.release()
save_attachment_contents(path_id, attachment_file)
def download_worker(base_path: str) -> NoReturn:
while True:
path_id = download_queue.get()
check_lock.acquire()
if os.path.exists(base_path + "/" + path_id):
check_lock.release()
download_queue.task_done()
continue
print(f"({download_queue.qsize()} Downloading {path_id}")
write_attachment(base_path, path_id, check_lock)
download_queue.task_done()
class Command(ZulipBaseCommand): class Command(ZulipBaseCommand):
@@ -57,6 +78,7 @@ This is most often used for legal compliance.
parser.add_argument( parser.add_argument(
"--force", action="store_true", help="Overwrite the output file if it exists already" "--force", action="store_true", help="Overwrite the output file if it exists already"
) )
parser.add_argument("--threads", default=5, type=int)
parser.add_argument( parser.add_argument(
"--file", "--file",
@@ -208,7 +230,6 @@ This is most often used for legal compliance.
channels = [get_stream(n.lstrip("#"), realm) for n in options["channel"]] channels = [get_stream(n.lstrip("#"), realm) for n in options["channel"]]
limits &= Q(recipient__in=[s.recipient_id for s in channels]) limits &= Q(recipient__in=[s.recipient_id for s in channels])
attachments_written: set[str] = set()
messages_query = ( messages_query = (
Message.objects.filter(limits, realm=realm) Message.objects.filter(limits, realm=realm)
.select_related("sender") .select_related("sender")
@@ -228,6 +249,12 @@ This is most often used for legal compliance.
if need_distinct: if need_distinct:
messages_query = messages_query.distinct("id") messages_query = messages_query.distinct("id")
if options["write_attachments"]:
for i in range(options["threads"]):
Thread(
target=download_worker, daemon=True, args=(options["write_attachments"],)
).start()
@lru_cache(maxsize=1000) @lru_cache(maxsize=1000)
def format_sender(full_name: str, delivery_email: str) -> str: def format_sender(full_name: str, delivery_email: str) -> str:
return str(Address(display_name=full_name, addr_spec=delivery_email)) return str(Address(display_name=full_name, addr_spec=delivery_email))
@@ -272,10 +299,7 @@ This is most often used for legal compliance.
attachments = message.attachment_set.all() attachments = message.attachment_set.all()
row["attachments"] = " ".join(a.path_id for a in attachments) row["attachments"] = " ".join(a.path_id for a in attachments)
for attachment in attachments: for attachment in attachments:
if attachment.path_id in attachments_written: download_queue.put(attachment.path_id)
continue
write_attachment(options["write_attachments"], attachment)
attachments_written.add(attachment.path_id)
else: else:
row["attachments"] = "" row["attachments"] = ""
return row return row
@@ -329,3 +353,4 @@ This is most often used for legal compliance.
for batch in chunked_results(): for batch in chunked_results():
print(".") print(".")
csvwriter.writerows(batch) csvwriter.writerows(batch)
download_queue.join()