mirror of
https://github.com/zulip/zulip.git
synced 2025-11-22 15:31:20 +00:00
drafts: Migrate drafts to use @typed_endpoint.
This demonstrates the use of BaseModel to replace a check_dict_only validator. We also add support to referring to $defs in the OpenAPI tests. In the future, we can descend down each object instead of mapping them to dict for more accurate checks.
This commit is contained in:
committed by
Tim Abbott
parent
4701f290f7
commit
910f69465c
@@ -1,11 +1,12 @@
|
|||||||
import time
|
import time
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, Dict, List, Set
|
from typing import Any, Callable, Dict, List, Literal, Union
|
||||||
|
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
from typing_extensions import Concatenate, ParamSpec
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from typing_extensions import Annotated, Concatenate, ParamSpec
|
||||||
|
|
||||||
from zerver.lib.addressee import get_user_profiles_by_ids
|
from zerver.lib.addressee import get_user_profiles_by_ids
|
||||||
from zerver.lib.exceptions import JsonableError, ResourceNotFoundError
|
from zerver.lib.exceptions import JsonableError, ResourceNotFoundError
|
||||||
@@ -13,48 +14,36 @@ from zerver.lib.message import normalize_body, truncate_topic
|
|||||||
from zerver.lib.recipient_users import recipient_for_user_profiles
|
from zerver.lib.recipient_users import recipient_for_user_profiles
|
||||||
from zerver.lib.streams import access_stream_by_id
|
from zerver.lib.streams import access_stream_by_id
|
||||||
from zerver.lib.timestamp import timestamp_to_datetime
|
from zerver.lib.timestamp import timestamp_to_datetime
|
||||||
from zerver.lib.validator import (
|
from zerver.lib.typed_endpoint import RequiredStringConstraint
|
||||||
check_dict_only,
|
|
||||||
check_float,
|
|
||||||
check_int,
|
|
||||||
check_list,
|
|
||||||
check_required_string,
|
|
||||||
check_string,
|
|
||||||
check_string_in,
|
|
||||||
check_union,
|
|
||||||
)
|
|
||||||
from zerver.models import Draft, UserProfile
|
from zerver.models import Draft, UserProfile
|
||||||
from zerver.tornado.django_api import send_event
|
from zerver.tornado.django_api import send_event
|
||||||
|
|
||||||
ParamT = ParamSpec("ParamT")
|
ParamT = ParamSpec("ParamT")
|
||||||
VALID_DRAFT_TYPES: Set[str] = {"", "private", "stream"}
|
|
||||||
|
|
||||||
# A validator to verify if the structure (syntax) of a dictionary
|
|
||||||
# meets the requirements to be a draft dictionary:
|
class DraftData(BaseModel):
|
||||||
draft_dict_validator = check_dict_only(
|
model_config = ConfigDict(extra="forbid")
|
||||||
required_keys=[
|
|
||||||
("type", check_string_in(VALID_DRAFT_TYPES)),
|
type: Literal["private", "stream", ""]
|
||||||
("to", check_list(check_int)), # The ID of the stream to send to, or a list of user IDs.
|
to: List[int]
|
||||||
("topic", check_string), # This string can simply be empty for private type messages.
|
topic: str
|
||||||
("content", check_required_string),
|
content: Annotated[str, RequiredStringConstraint()]
|
||||||
],
|
timestamp: Union[int, float, None] = None
|
||||||
optional_keys=[
|
|
||||||
("timestamp", check_union([check_int, check_float])), # A Unix timestamp.
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def further_validated_draft_dict(
|
def further_validated_draft_dict(
|
||||||
draft_dict: Dict[str, Any], user_profile: UserProfile
|
draft_dict: DraftData, user_profile: UserProfile
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Take a draft_dict that was already validated by draft_dict_validator then
|
"""Take a draft_dict that was already validated by draft_dict_validator then
|
||||||
further sanitize, validate, and transform it. Ultimately return this "further
|
further sanitize, validate, and transform it. Ultimately return this "further
|
||||||
validated" draft dict. It will have a slightly different set of keys the values
|
validated" draft dict. It will have a slightly different set of keys the values
|
||||||
for which can be used to directly create a Draft object."""
|
for which can be used to directly create a Draft object."""
|
||||||
|
|
||||||
content = normalize_body(draft_dict["content"])
|
content = normalize_body(draft_dict.content)
|
||||||
|
|
||||||
timestamp = draft_dict.get("timestamp", time.time())
|
timestamp = draft_dict.timestamp
|
||||||
|
if timestamp is None:
|
||||||
|
timestamp = time.time()
|
||||||
timestamp = round(timestamp, 6)
|
timestamp = round(timestamp, 6)
|
||||||
if timestamp < 0:
|
if timestamp < 0:
|
||||||
# While it's not exactly an invalid timestamp, it's not something
|
# While it's not exactly an invalid timestamp, it's not something
|
||||||
@@ -64,16 +53,16 @@ def further_validated_draft_dict(
|
|||||||
|
|
||||||
topic = ""
|
topic = ""
|
||||||
recipient_id = None
|
recipient_id = None
|
||||||
to = draft_dict["to"]
|
to = draft_dict.to
|
||||||
if draft_dict["type"] == "stream":
|
if draft_dict.type == "stream":
|
||||||
topic = truncate_topic(draft_dict["topic"])
|
topic = truncate_topic(draft_dict.topic)
|
||||||
if "\0" in topic:
|
if "\0" in topic:
|
||||||
raise JsonableError(_("Topic must not contain null bytes"))
|
raise JsonableError(_("Topic must not contain null bytes"))
|
||||||
if len(to) != 1:
|
if len(to) != 1:
|
||||||
raise JsonableError(_("Must specify exactly 1 stream ID for stream messages"))
|
raise JsonableError(_("Must specify exactly 1 stream ID for stream messages"))
|
||||||
stream, sub = access_stream_by_id(user_profile, to[0])
|
stream, sub = access_stream_by_id(user_profile, to[0])
|
||||||
recipient_id = stream.recipient_id
|
recipient_id = stream.recipient_id
|
||||||
elif draft_dict["type"] == "private" and len(to) != 0:
|
elif draft_dict.type == "private" and len(to) != 0:
|
||||||
to_users = get_user_profiles_by_ids(set(to), user_profile.realm)
|
to_users = get_user_profiles_by_ids(set(to), user_profile.realm)
|
||||||
try:
|
try:
|
||||||
recipient_id = recipient_for_user_profiles(to_users, False, None, user_profile).id
|
recipient_id = recipient_for_user_profiles(to_users, False, None, user_profile).id
|
||||||
@@ -106,14 +95,14 @@ def draft_endpoint(
|
|||||||
return draft_view_func
|
return draft_view_func
|
||||||
|
|
||||||
|
|
||||||
def do_create_drafts(draft_dicts: List[Dict[str, Any]], user_profile: UserProfile) -> List[Draft]:
|
def do_create_drafts(drafts: List[DraftData], user_profile: UserProfile) -> List[Draft]:
|
||||||
"""Create drafts in bulk for a given user based on the draft dicts. Since
|
"""Create drafts in bulk for a given user based on the draft dicts. Since
|
||||||
currently, the only place this method is being used (apart from tests) is from
|
currently, the only place this method is being used (apart from tests) is from
|
||||||
the create_draft view, we assume that the drafts_dicts are syntactically valid
|
the create_draft view, we assume that the drafts_dicts are syntactically valid
|
||||||
(i.e. they satisfy the draft_dict_validator)."""
|
(i.e. they satisfy the draft_dict_validator)."""
|
||||||
draft_objects = []
|
draft_objects = []
|
||||||
for draft_dict in draft_dicts:
|
for draft in drafts:
|
||||||
valid_draft_dict = further_validated_draft_dict(draft_dict, user_profile)
|
valid_draft_dict = further_validated_draft_dict(draft, user_profile)
|
||||||
draft_objects.append(
|
draft_objects.append(
|
||||||
Draft(
|
Draft(
|
||||||
user_profile=user_profile,
|
user_profile=user_profile,
|
||||||
@@ -136,7 +125,7 @@ def do_create_drafts(draft_dicts: List[Dict[str, Any]], user_profile: UserProfil
|
|||||||
return created_draft_objects
|
return created_draft_objects
|
||||||
|
|
||||||
|
|
||||||
def do_edit_draft(draft_id: int, draft_dict: Dict[str, Any], user_profile: UserProfile) -> None:
|
def do_edit_draft(draft_id: int, draft: DraftData, user_profile: UserProfile) -> None:
|
||||||
"""Edit/update a single draft for a given user. Since the only place this method is being
|
"""Edit/update a single draft for a given user. Since the only place this method is being
|
||||||
used from (apart from tests) is the edit_draft view, we assume that the drafts_dict is
|
used from (apart from tests) is the edit_draft view, we assume that the drafts_dict is
|
||||||
syntactically valid (i.e. it satisfies the draft_dict_validator)."""
|
syntactically valid (i.e. it satisfies the draft_dict_validator)."""
|
||||||
@@ -144,7 +133,7 @@ def do_edit_draft(draft_id: int, draft_dict: Dict[str, Any], user_profile: UserP
|
|||||||
draft_object = Draft.objects.get(id=draft_id, user_profile=user_profile)
|
draft_object = Draft.objects.get(id=draft_id, user_profile=user_profile)
|
||||||
except Draft.DoesNotExist:
|
except Draft.DoesNotExist:
|
||||||
raise ResourceNotFoundError(_("Draft does not exist"))
|
raise ResourceNotFoundError(_("Draft does not exist"))
|
||||||
valid_draft_dict = further_validated_draft_dict(draft_dict, user_profile)
|
valid_draft_dict = further_validated_draft_dict(draft, user_profile)
|
||||||
draft_object.content = valid_draft_dict["content"]
|
draft_object.content = valid_draft_dict["content"]
|
||||||
draft_object.topic = valid_draft_dict["topic"]
|
draft_object.topic = valid_draft_dict["topic"]
|
||||||
draft_object.recipient_id = valid_draft_dict["recipient_id"]
|
draft_object.recipient_id = valid_draft_dict["recipient_id"]
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ from zerver.actions.users import (
|
|||||||
do_update_outgoing_webhook_service,
|
do_update_outgoing_webhook_service,
|
||||||
)
|
)
|
||||||
from zerver.actions.video_calls import do_set_zoom_token
|
from zerver.actions.video_calls import do_set_zoom_token
|
||||||
from zerver.lib.drafts import do_create_drafts, do_delete_draft, do_edit_draft
|
from zerver.lib.drafts import DraftData, do_create_drafts, do_delete_draft, do_edit_draft
|
||||||
from zerver.lib.event_schema import (
|
from zerver.lib.event_schema import (
|
||||||
check_alert_words,
|
check_alert_words,
|
||||||
check_attachment_add,
|
check_attachment_add,
|
||||||
@@ -3497,39 +3497,39 @@ class DraftActionTest(BaseAction):
|
|||||||
|
|
||||||
def test_draft_create_event(self) -> None:
|
def test_draft_create_event(self) -> None:
|
||||||
self.do_enable_drafts_synchronization(self.user_profile)
|
self.do_enable_drafts_synchronization(self.user_profile)
|
||||||
dummy_draft = {
|
dummy_draft = DraftData(
|
||||||
"type": "draft",
|
type="",
|
||||||
"to": "",
|
to=[],
|
||||||
"topic": "",
|
topic="",
|
||||||
"content": "Sample draft content",
|
content="Sample draft content",
|
||||||
"timestamp": 1596820995,
|
timestamp=1596820995,
|
||||||
}
|
)
|
||||||
action = lambda: do_create_drafts([dummy_draft], self.user_profile)
|
action = lambda: do_create_drafts([dummy_draft], self.user_profile)
|
||||||
self.verify_action(action)
|
self.verify_action(action)
|
||||||
|
|
||||||
def test_draft_edit_event(self) -> None:
|
def test_draft_edit_event(self) -> None:
|
||||||
self.do_enable_drafts_synchronization(self.user_profile)
|
self.do_enable_drafts_synchronization(self.user_profile)
|
||||||
dummy_draft = {
|
dummy_draft = DraftData(
|
||||||
"type": "draft",
|
type="",
|
||||||
"to": "",
|
to=[],
|
||||||
"topic": "",
|
topic="",
|
||||||
"content": "Sample draft content",
|
content="Sample draft content",
|
||||||
"timestamp": 1596820995,
|
timestamp=1596820995,
|
||||||
}
|
)
|
||||||
draft_id = do_create_drafts([dummy_draft], self.user_profile)[0].id
|
draft_id = do_create_drafts([dummy_draft], self.user_profile)[0].id
|
||||||
dummy_draft["content"] = "Some more sample draft content"
|
dummy_draft.content = "Some more sample draft content"
|
||||||
action = lambda: do_edit_draft(draft_id, dummy_draft, self.user_profile)
|
action = lambda: do_edit_draft(draft_id, dummy_draft, self.user_profile)
|
||||||
self.verify_action(action)
|
self.verify_action(action)
|
||||||
|
|
||||||
def test_draft_delete_event(self) -> None:
|
def test_draft_delete_event(self) -> None:
|
||||||
self.do_enable_drafts_synchronization(self.user_profile)
|
self.do_enable_drafts_synchronization(self.user_profile)
|
||||||
dummy_draft = {
|
dummy_draft = DraftData(
|
||||||
"type": "draft",
|
type="",
|
||||||
"to": "",
|
to=[],
|
||||||
"topic": "",
|
topic="",
|
||||||
"content": "Sample draft content",
|
content="Sample draft content",
|
||||||
"timestamp": 1596820995,
|
timestamp=1596820995,
|
||||||
}
|
)
|
||||||
draft_id = do_create_drafts([dummy_draft], self.user_profile)[0].id
|
draft_id = do_create_drafts([dummy_draft], self.user_profile)[0].id
|
||||||
action = lambda: do_delete_draft(draft_id, self.user_profile)
|
action = lambda: do_delete_draft(draft_id, self.user_profile)
|
||||||
self.verify_action(action)
|
self.verify_action(action)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
@@ -57,17 +58,21 @@ VARMAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def schema_type(schema: Dict[str, Any]) -> Union[type, Tuple[type, object]]:
|
def schema_type(
|
||||||
|
schema: Dict[str, Any], defs: Mapping[str, Any] = {}
|
||||||
|
) -> Union[type, Tuple[type, object]]:
|
||||||
if "oneOf" in schema:
|
if "oneOf" in schema:
|
||||||
# Hack: Just use the type of the first value
|
# Hack: Just use the type of the first value
|
||||||
# Ideally, we'd turn this into a Union type.
|
# Ideally, we'd turn this into a Union type.
|
||||||
return schema_type(schema["oneOf"][0])
|
return schema_type(schema["oneOf"][0], defs)
|
||||||
elif "anyOf" in schema:
|
elif "anyOf" in schema:
|
||||||
return schema_type(schema["anyOf"][0])
|
return schema_type(schema["anyOf"][0], defs)
|
||||||
elif schema.get("contentMediaType") == "application/json":
|
elif schema.get("contentMediaType") == "application/json":
|
||||||
return schema_type(schema["contentSchema"])
|
return schema_type(schema["contentSchema"], defs)
|
||||||
|
elif "$ref" in schema:
|
||||||
|
return schema_type(defs[schema["$ref"]], defs)
|
||||||
elif schema["type"] == "array":
|
elif schema["type"] == "array":
|
||||||
return (list, schema_type(schema["items"]))
|
return (list, schema_type(schema["items"], defs))
|
||||||
else:
|
else:
|
||||||
return VARMAP[schema["type"]]
|
return VARMAP[schema["type"]]
|
||||||
|
|
||||||
@@ -439,7 +444,10 @@ do not match the types declared in the implementation of {function.__name__}.\n"
|
|||||||
openapi_params.add((expected_request_var_name, schema_type(expected_param_schema)))
|
openapi_params.add((expected_request_var_name, schema_type(expected_param_schema)))
|
||||||
|
|
||||||
for actual_param in parse_view_func_signature(function).parameters:
|
for actual_param in parse_view_func_signature(function).parameters:
|
||||||
actual_param_schema = TypeAdapter(actual_param.param_type).json_schema()
|
actual_param_schema = TypeAdapter(actual_param.param_type).json_schema(
|
||||||
|
ref_template="{model}"
|
||||||
|
)
|
||||||
|
defs_mapping = actual_param_schema.get("$defs", {})
|
||||||
# The content type of the JSON schema generated from the
|
# The content type of the JSON schema generated from the
|
||||||
# function parameter type annotation should have content type
|
# function parameter type annotation should have content type
|
||||||
# matching that of our OpenAPI spec. If not so, hint that the
|
# matching that of our OpenAPI spec. If not so, hint that the
|
||||||
@@ -467,7 +475,9 @@ do not match the types declared in the implementation of {function.__name__}.\n"
|
|||||||
(int, bool),
|
(int, bool),
|
||||||
f'\nUnexpected content type {actual_param_schema["contentMediaType"]} on function parameter {actual_param.param_name}, which does not match the OpenAPI definition.',
|
f'\nUnexpected content type {actual_param_schema["contentMediaType"]} on function parameter {actual_param.param_name}, which does not match the OpenAPI definition.',
|
||||||
)
|
)
|
||||||
function_params.add((actual_param.request_var_name, schema_type(actual_param_schema)))
|
function_params.add(
|
||||||
|
(actual_param.request_var_name, schema_type(actual_param_schema, defs_mapping))
|
||||||
|
)
|
||||||
|
|
||||||
diff = openapi_params - function_params
|
diff = openapi_params - function_params
|
||||||
if diff: # nocoverage
|
if diff: # nocoverage
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import List
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
|
from pydantic import Json
|
||||||
|
|
||||||
from zerver.lib.drafts import (
|
from zerver.lib.drafts import (
|
||||||
|
DraftData,
|
||||||
do_create_drafts,
|
do_create_drafts,
|
||||||
do_delete_draft,
|
do_delete_draft,
|
||||||
do_edit_draft,
|
do_edit_draft,
|
||||||
draft_dict_validator,
|
|
||||||
draft_endpoint,
|
draft_endpoint,
|
||||||
)
|
)
|
||||||
from zerver.lib.request import REQ, has_request_variables
|
|
||||||
from zerver.lib.response import json_success
|
from zerver.lib.response import json_success
|
||||||
from zerver.lib.validator import check_list
|
from zerver.lib.typed_endpoint import PathOnly, typed_endpoint
|
||||||
from zerver.models import Draft, UserProfile
|
from zerver.models import Draft, UserProfile
|
||||||
|
|
||||||
|
|
||||||
@@ -23,32 +23,32 @@ def fetch_drafts(request: HttpRequest, user_profile: UserProfile) -> HttpRespons
|
|||||||
|
|
||||||
|
|
||||||
@draft_endpoint
|
@draft_endpoint
|
||||||
@has_request_variables
|
@typed_endpoint
|
||||||
def create_drafts(
|
def create_drafts(
|
||||||
request: HttpRequest,
|
request: HttpRequest,
|
||||||
user_profile: UserProfile,
|
user_profile: UserProfile,
|
||||||
draft_dicts: List[Dict[str, Any]] = REQ(
|
*,
|
||||||
"drafts", json_validator=check_list(draft_dict_validator)
|
drafts: Json[List[DraftData]],
|
||||||
),
|
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
created_draft_objects = do_create_drafts(draft_dicts, user_profile)
|
created_draft_objects = do_create_drafts(drafts, user_profile)
|
||||||
draft_ids = [draft_object.id for draft_object in created_draft_objects]
|
draft_ids = [draft_object.id for draft_object in created_draft_objects]
|
||||||
return json_success(request, data={"ids": draft_ids})
|
return json_success(request, data={"ids": draft_ids})
|
||||||
|
|
||||||
|
|
||||||
@draft_endpoint
|
@draft_endpoint
|
||||||
@has_request_variables
|
@typed_endpoint
|
||||||
def edit_draft(
|
def edit_draft(
|
||||||
request: HttpRequest,
|
request: HttpRequest,
|
||||||
user_profile: UserProfile,
|
user_profile: UserProfile,
|
||||||
draft_id: int,
|
*,
|
||||||
draft_dict: Dict[str, Any] = REQ("draft", json_validator=draft_dict_validator),
|
draft_id: PathOnly[int],
|
||||||
|
draft: Json[DraftData],
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
do_edit_draft(draft_id, draft_dict, user_profile)
|
do_edit_draft(draft_id, draft, user_profile)
|
||||||
return json_success(request)
|
return json_success(request)
|
||||||
|
|
||||||
|
|
||||||
@draft_endpoint
|
@draft_endpoint
|
||||||
def delete_draft(request: HttpRequest, user_profile: UserProfile, draft_id: int) -> HttpResponse:
|
def delete_draft(request: HttpRequest, user_profile: UserProfile, *, draft_id: int) -> HttpResponse:
|
||||||
do_delete_draft(draft_id, user_profile)
|
do_delete_draft(draft_id, user_profile)
|
||||||
return json_success(request)
|
return json_success(request)
|
||||||
|
|||||||
Reference in New Issue
Block a user