diff --git a/zerver/lib/upload.py b/zerver/lib/upload.py index b5c8d731b6..fac36e531a 100644 --- a/zerver/lib/upload.py +++ b/zerver/lib/upload.py @@ -385,7 +385,7 @@ def get_file_info(request: HttpRequest, user_file: File) -> Tuple[str, int, Opti return uploaded_file_name, uploaded_file_size, content_type -def get_signed_upload_url(path: str) -> str: +def get_signed_upload_url(path: str, download: bool = False) -> str: client = boto3.client( "s3", aws_access_key_id=settings.S3_KEY, @@ -393,9 +393,16 @@ def get_signed_upload_url(path: str) -> str: region_name=settings.S3_REGION, endpoint_url=settings.S3_ENDPOINT_URL, ) + params = { + "Bucket": settings.S3_AUTH_UPLOADS_BUCKET, + "Key": path, + } + if download: + params["ResponseContentDisposition"] = "attachment" + return client.generate_presigned_url( ClientMethod="get_object", - Params={"Bucket": settings.S3_AUTH_UPLOADS_BUCKET, "Key": path}, + Params=params, ExpiresIn=SIGNED_UPLOAD_URL_DURATION, HttpMethod="GET", ) diff --git a/zerver/tests/test_upload.py b/zerver/tests/test_upload.py index 811cc1bf9c..8db2138e7a 100644 --- a/zerver/tests/test_upload.py +++ b/zerver/tests/test_upload.py @@ -210,6 +210,12 @@ class FileUploadTest(UploadSerializeMixin, ZulipTestCase): # requests; they will be first authenticated and redirected self.assert_streaming_content(self.client_get(uri), b"zulip!") + # Check the download endpoint + download_uri = uri.replace("/user_uploads/", "/user_uploads/download/") + result = self.client_get(download_uri) + self.assert_streaming_content(result, b"zulip!") + self.assertIn("attachment;", result.headers["Content-Disposition"]) + # check if DB has attachment marked as unclaimed entry = Attachment.objects.get(file_name="zulip.txt") self.assertEqual(entry.is_claimed(), False) @@ -815,7 +821,10 @@ class FileUploadTest(UploadSerializeMixin, ZulipTestCase): def test_serve_local(self) -> None: def check_xsend_links( - name: str, name_str_for_test: str, content_disposition: str = "" + name: str, + name_str_for_test: str, + content_disposition: str = "", + download: bool = False, ) -> None: with self.settings(SENDFILE_BACKEND="django_sendfile.backends.nginx"): _get_sendfile.cache_clear() # To clearout cached version of backend from djangosendfile @@ -826,6 +835,8 @@ class FileUploadTest(UploadSerializeMixin, ZulipTestCase): uri = result.json()["uri"] fp_path_id = re.sub("/user_uploads/", "", uri) fp_path = os.path.split(fp_path_id)[0] + if download: + uri = uri.replace("/user_uploads/", "/user_uploads/download/") response = self.client_get(uri) _get_sendfile.cache_clear() assert settings.LOCAL_UPLOADS_DIR is not None @@ -852,6 +863,9 @@ class FileUploadTest(UploadSerializeMixin, ZulipTestCase): check_xsend_links("zulip.html", "zulip.html", 'filename="zulip.html"') check_xsend_links("zulip.sh", "zulip.sh", 'filename="zulip.sh"') check_xsend_links("zulip.jpeg", "zulip.jpeg") + check_xsend_links( + "zulip.jpeg", "zulip.jpeg", download=True, content_disposition='filename="zulip.jpeg"' + ) check_xsend_links("áéБД.pdf", "%C3%A1%C3%A9%D0%91%D0%94.pdf") check_xsend_links("zulip", "zulip", 'filename="zulip"') @@ -1935,6 +1949,15 @@ class S3Test(ZulipTestCase): key = path[1:] self.assertEqual(b"zulip!", bucket.Object(key).get()["Body"].read()) + # Check the download endpoint + download_uri = uri.replace("/user_uploads/", "/user_uploads/download/") + response = self.client_get(download_uri) + redirect_url = response["Location"] + path = urllib.parse.urlparse(redirect_url).path + assert path.startswith("/") + key = path[1:] + self.assertEqual(b"zulip!", bucket.Object(key).get()["Body"].read()) + # Now try the endpoint that's supposed to return a temporary URL for access # to the file. result = self.client_get("/json" + uri) diff --git a/zerver/views/upload.py b/zerver/views/upload.py index fac2b4fbfb..d018002497 100644 --- a/zerver/views/upload.py +++ b/zerver/views/upload.py @@ -21,15 +21,19 @@ from zerver.lib.upload import ( from zerver.models import UserProfile, validate_attachment_request -def serve_s3(request: HttpRequest, url_path: str, url_only: bool) -> HttpResponse: - url = get_signed_upload_url(url_path) +def serve_s3( + request: HttpRequest, url_path: str, url_only: bool, download: bool = False +) -> HttpResponse: + url = get_signed_upload_url(url_path, download=download) if url_only: return json_success(request, data=dict(url=url)) return redirect(url) -def serve_local(request: HttpRequest, path_id: str, url_only: bool) -> HttpResponse: +def serve_local( + request: HttpRequest, path_id: str, url_only: bool, download: bool = False +) -> HttpResponse: local_path = get_local_file_path(path_id) if local_path is None: return HttpResponseNotFound("

File not found

") @@ -56,7 +60,7 @@ def serve_local(request: HttpRequest, path_id: str, url_only: bool) -> HttpRespo # and filename, see the below docs: # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition mimetype, encoding = guess_type(local_path) - attachment = mimetype not in INLINE_MIME_TYPES + attachment = download or mimetype not in INLINE_MIME_TYPES response = sendfile( request, local_path, attachment=attachment, mimetype=mimetype, encoding=encoding @@ -65,6 +69,12 @@ def serve_local(request: HttpRequest, path_id: str, url_only: bool) -> HttpRespo return response +def serve_file_download_backend( + request: HttpRequest, user_profile: UserProfile, realm_id_str: str, filename: str +) -> HttpRequest: + return serve_file(request, user_profile, realm_id_str, filename, url_only=False, download=True) + + def serve_file_backend( request: HttpRequest, user_profile: UserProfile, realm_id_str: str, filename: str ) -> HttpResponse: @@ -88,6 +98,7 @@ def serve_file( realm_id_str: str, filename: str, url_only: bool = False, + download: bool = False, ) -> HttpResponse: path_id = f"{realm_id_str}/{filename}" is_authorized = validate_attachment_request(user_profile, path_id) @@ -97,9 +108,9 @@ def serve_file( if not is_authorized: return HttpResponseForbidden(_("

You are not authorized to view this file.

")) if settings.LOCAL_UPLOADS_DIR is not None: - return serve_local(request, path_id, url_only) + return serve_local(request, path_id, url_only, download=download) - return serve_s3(request, path_id, url_only) + return serve_s3(request, path_id, url_only, download=download) def serve_local_file_unauthed(request: HttpRequest, token: str, filename: str) -> HttpResponse: diff --git a/zproject/urls.py b/zproject/urls.py index 0049be3578..3f41ba2be1 100644 --- a/zproject/urls.py +++ b/zproject/urls.py @@ -167,6 +167,7 @@ from zerver.views.typing import send_notification_backend from zerver.views.unsubscribe import email_unsubscribe from zerver.views.upload import ( serve_file_backend, + serve_file_download_backend, serve_file_url_backend, serve_local_file_unauthed, upload_file_backend, @@ -669,6 +670,10 @@ urls += [ serve_local_file_unauthed, name="local_file_unauthed", ), + rest_path( + "user_uploads/download//", + GET=(serve_file_download_backend, {"override_api_url_scheme"}), + ), rest_path( "user_uploads//", GET=(serve_file_backend, {"override_api_url_scheme"}),