request: Support converter or json_validator with argument_type="body".

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg
2021-12-14 20:30:39 -05:00
committed by Tim Abbott
parent 970f22380a
commit 04d772b582
3 changed files with 49 additions and 31 deletions

View File

@@ -31,6 +31,7 @@ import zerver.tornado.handlers as handlers
from zerver.lib.exceptions import ErrorCode, InvalidJSONError, JsonableError from zerver.lib.exceptions import ErrorCode, InvalidJSONError, JsonableError
from zerver.lib.notes import BaseNotes from zerver.lib.notes import BaseNotes
from zerver.lib.types import Validator, ViewFuncT from zerver.lib.types import Validator, ViewFuncT
from zerver.lib.validator import check_anything
from zerver.models import Client, Realm from zerver.models import Client, Realm
@@ -168,6 +169,10 @@ class _REQ(Generic[ResultT]):
""" """
if argument_type == "body" and converter is None and json_validator is None:
# legacy behavior
json_validator = cast(Callable[[str, object], ResultT], check_anything)
self.post_var_name = whence self.post_var_name = whence
self.func_var_name: Optional[str] = None self.func_var_name: Optional[str] = None
self.converter = converter self.converter = converter
@@ -206,6 +211,7 @@ def REQ(
*, *,
converter: Callable[[str, str], ResultT], converter: Callable[[str, str], ResultT],
default: ResultT = ..., default: ResultT = ...,
argument_type: Optional[Literal["body"]] = ...,
intentionally_undocumented: bool = ..., intentionally_undocumented: bool = ...,
documentation_pending: bool = ..., documentation_pending: bool = ...,
aliases: Sequence[str] = ..., aliases: Sequence[str] = ...,
@@ -221,6 +227,7 @@ def REQ(
*, *,
default: ResultT = ..., default: ResultT = ...,
json_validator: Validator[ResultT], json_validator: Validator[ResultT],
argument_type: Optional[Literal["body"]] = ...,
intentionally_undocumented: bool = ..., intentionally_undocumented: bool = ...,
documentation_pending: bool = ..., documentation_pending: bool = ...,
aliases: Sequence[str] = ..., aliases: Sequence[str] = ...,
@@ -368,46 +375,45 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT:
continue continue
assert func_var_name is not None assert func_var_name is not None
post_var_name: Optional[str]
if param.argument_type == "body": if param.argument_type == "body":
post_var_name = "request"
try: try:
val = orjson.loads(request.body) val = request.body.decode(request.encoding or "utf-8")
except orjson.JSONDecodeError: except UnicodeDecodeError:
raise InvalidJSONError(_("Malformed JSON")) raise JsonableError(_("Malformed payload"))
kwargs[func_var_name] = val
continue
else: else:
# This is a view bug, not a user error, and thus should throw a 500. # This is a view bug, not a user error, and thus should throw a 500.
assert param.argument_type is None, "Invalid argument type" assert param.argument_type is None, "Invalid argument type"
post_var_names = [param.post_var_name] post_var_names = [param.post_var_name]
post_var_names += param.aliases post_var_names += param.aliases
post_var_name = None
post_var_name: Optional[str] = None for req_var in post_var_names:
for req_var in post_var_names:
assert req_var is not None
if req_var in request.POST:
val = request.POST[req_var]
request_notes.processed_parameters.add(req_var)
elif req_var in request.GET:
val = request.GET[req_var]
request_notes.processed_parameters.add(req_var)
else:
# This is covered by test_REQ_aliases, but coverage.py
# fails to recognize this for some reason.
continue # nocoverage
if post_var_name is not None:
assert req_var is not None assert req_var is not None
raise RequestConfusingParmsError(post_var_name, req_var) if req_var in request.POST:
post_var_name = req_var val = request.POST[req_var]
request_notes.processed_parameters.add(req_var)
elif req_var in request.GET:
val = request.GET[req_var]
request_notes.processed_parameters.add(req_var)
else:
# This is covered by test_REQ_aliases, but coverage.py
# fails to recognize this for some reason.
continue # nocoverage
if post_var_name is not None:
raise RequestConfusingParmsError(post_var_name, req_var)
post_var_name = req_var
if post_var_name is None: if post_var_name is None:
post_var_name = param.post_var_name post_var_name = param.post_var_name
assert post_var_name is not None assert post_var_name is not None
if param.default is _REQ.NotSpecified: if param.default is _REQ.NotSpecified:
raise RequestVariableMissingError(post_var_name) raise RequestVariableMissingError(post_var_name)
kwargs[func_var_name] = param.default kwargs[func_var_name] = param.default
continue continue
if param.converter is not None: if param.converter is not None:
try: try:
@@ -422,6 +428,8 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT:
try: try:
val = orjson.loads(val) val = orjson.loads(val)
except orjson.JSONDecodeError: except orjson.JSONDecodeError:
if param.argument_type == "body":
raise InvalidJSONError(_("Malformed JSON"))
raise JsonableError(_('Argument "{}" is not valid JSON.').format(post_var_name)) raise JsonableError(_('Argument "{}" is not valid JSON.').format(post_var_name))
try: try:

View File

@@ -58,6 +58,10 @@ from zerver.lib.types import ProfileFieldData, Validator
ResultT = TypeVar("ResultT") ResultT = TypeVar("ResultT")
def check_anything(var_name: str, val: object) -> object:
return val
def check_string(var_name: str, val: object) -> str: def check_string(var_name: str, val: object) -> str:
if not isinstance(val, str): if not isinstance(val, str):
raise ValidationError(_("{var_name} is not a string").format(var_name=var_name)) raise ValidationError(_("{var_name} is not a string").format(var_name=var_name))

View File

@@ -248,6 +248,12 @@ class DecoratorTestCase(ZulipTestCase):
) -> HttpResponse: ) -> HttpResponse:
return json_response(data={"payload": payload}) return json_response(data={"payload": payload})
request = HostRequestMock()
request.body = b"\xde\xad\xbe\xef"
with self.assertRaises(JsonableError) as cm:
get_payload(request)
self.assertEqual(str(cm.exception), "Malformed payload")
request = HostRequestMock() request = HostRequestMock()
request.body = b"notjson" request.body = b"notjson"
with self.assertRaises(JsonableError) as cm: with self.assertRaises(JsonableError) as cm: