diff --git a/zerver/tests/test_drafts.py b/zerver/tests/test_drafts.py index cf6b1e678c..cf3ce486dc 100644 --- a/zerver/tests/test_drafts.py +++ b/zerver/tests/test_drafts.py @@ -54,6 +54,14 @@ class DraftCreationTests(ZulipTestCase): # conditions. self.assertEqual(Draft.objects.count(), 0) + def test_require_enable_drafts_synchronization(self) -> None: + hamlet = self.example_user("hamlet") + hamlet.enable_drafts_synchronization = False + hamlet.save() + payload = {"drafts": "[]"} + resp = self.api_post(hamlet, "/api/v1/drafts", payload) + self.assert_json_error(resp, "User has disabled synchronizing drafts.") + def test_create_one_stream_draft_properly(self) -> None: hamlet = self.example_user("hamlet") visible_stream_name = self.get_streams(hamlet)[0] @@ -301,6 +309,13 @@ class DraftCreationTests(ZulipTestCase): class DraftEditTests(ZulipTestCase): + def test_require_enable_drafts_synchronization(self) -> None: + hamlet = self.example_user("hamlet") + hamlet.enable_drafts_synchronization = False + hamlet.save() + resp = self.api_patch(hamlet, "/api/v1/drafts/1", {"draft": {}}) + self.assert_json_error(resp, "User has disabled synchronizing drafts.") + def test_edit_draft_successfully(self) -> None: hamlet = self.example_user("hamlet") visible_streams = self.get_streams(hamlet) @@ -405,6 +420,13 @@ class DraftEditTests(ZulipTestCase): class DraftDeleteTests(ZulipTestCase): + def test_require_enable_drafts_synchronization(self) -> None: + hamlet = self.example_user("hamlet") + hamlet.enable_drafts_synchronization = False + hamlet.save() + resp = self.api_delete(hamlet, "/api/v1/drafts/1") + self.assert_json_error(resp, "User has disabled synchronizing drafts.") + def test_delete_draft_successfully(self) -> None: hamlet = self.example_user("hamlet") visible_streams = self.get_streams(hamlet) @@ -488,6 +510,13 @@ class DraftDeleteTests(ZulipTestCase): class DraftFetchTest(ZulipTestCase): + def test_require_enable_drafts_synchronization(self) -> None: + hamlet = self.example_user("hamlet") + hamlet.enable_drafts_synchronization = False + hamlet.save() + resp = self.api_get(hamlet, "/api/v1/drafts") + self.assert_json_error(resp, "User has disabled synchronizing drafts.") + def test_fetch_drafts(self) -> None: self.assertEqual(Draft.objects.count(), 0) diff --git a/zerver/views/drafts.py b/zerver/views/drafts.py index 4c2549ebc8..1af1dd960a 100644 --- a/zerver/views/drafts.py +++ b/zerver/views/drafts.py @@ -1,5 +1,6 @@ import time -from typing import Any, Dict, List, Set +from functools import wraps +from typing import Any, Dict, List, Set, cast from django.core.exceptions import ValidationError from django.http import HttpRequest, HttpResponse @@ -13,6 +14,7 @@ from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_success from zerver.lib.streams import access_stream_by_id from zerver.lib.timestamp import timestamp_to_datetime +from zerver.lib.types import ViewFuncT from zerver.lib.validator import ( check_dict_only, check_float, @@ -86,12 +88,26 @@ def further_validated_draft_dict( } +def draft_endpoint(view_func: ViewFuncT) -> ViewFuncT: + @wraps(view_func) + def draft_view_func( + request: HttpRequest, user_profile: UserProfile, *args: object, **kwargs: object + ) -> HttpResponse: + if not user_profile.enable_drafts_synchronization: + raise JsonableError(_("User has disabled synchronizing drafts.")) + return view_func(request, user_profile, *args, **kwargs) + + return cast(ViewFuncT, draft_view_func) # https://github.com/python/mypy/issues/1927 + + +@draft_endpoint def fetch_drafts(request: HttpRequest, user_profile: UserProfile) -> HttpResponse: user_drafts = Draft.objects.filter(user_profile=user_profile).order_by("last_edit_time") draft_dicts = [draft.to_dict() for draft in user_drafts] return json_success({"count": user_drafts.count(), "drafts": draft_dicts}) +@draft_endpoint @has_request_variables def create_drafts( request: HttpRequest, @@ -118,6 +134,7 @@ def create_drafts( return json_success({"ids": draft_ids}) +@draft_endpoint @has_request_variables def edit_draft( request: HttpRequest, @@ -140,6 +157,7 @@ def edit_draft( return json_success() +@draft_endpoint def delete_draft(request: HttpRequest, user_profile: UserProfile, draft_id: int) -> HttpResponse: try: draft_object = Draft.objects.get(id=draft_id, user_profile=user_profile)