request: Support converter or json_validator with argument_type="body".

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg
2021-12-14 20:30:39 -05:00
committed by Tim Abbott
parent 970f22380a
commit 04d772b582
3 changed files with 49 additions and 31 deletions

View File

@@ -31,6 +31,7 @@ import zerver.tornado.handlers as handlers
from zerver.lib.exceptions import ErrorCode, InvalidJSONError, JsonableError
from zerver.lib.notes import BaseNotes
from zerver.lib.types import Validator, ViewFuncT
from zerver.lib.validator import check_anything
from zerver.models import Client, Realm
@@ -168,6 +169,10 @@ class _REQ(Generic[ResultT]):
"""
if argument_type == "body" and converter is None and json_validator is None:
# legacy behavior
json_validator = cast(Callable[[str, object], ResultT], check_anything)
self.post_var_name = whence
self.func_var_name: Optional[str] = None
self.converter = converter
@@ -206,6 +211,7 @@ def REQ(
*,
converter: Callable[[str, str], ResultT],
default: ResultT = ...,
argument_type: Optional[Literal["body"]] = ...,
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Sequence[str] = ...,
@@ -221,6 +227,7 @@ def REQ(
*,
default: ResultT = ...,
json_validator: Validator[ResultT],
argument_type: Optional[Literal["body"]] = ...,
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Sequence[str] = ...,
@@ -368,46 +375,45 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT:
continue
assert func_var_name is not None
post_var_name: Optional[str]
if param.argument_type == "body":
post_var_name = "request"
try:
val = orjson.loads(request.body)
except orjson.JSONDecodeError:
raise InvalidJSONError(_("Malformed JSON"))
kwargs[func_var_name] = val
continue
val = request.body.decode(request.encoding or "utf-8")
except UnicodeDecodeError:
raise JsonableError(_("Malformed payload"))
else:
# This is a view bug, not a user error, and thus should throw a 500.
assert param.argument_type is None, "Invalid argument type"
post_var_names = [param.post_var_name]
post_var_names += param.aliases
post_var_names = [param.post_var_name]
post_var_names += param.aliases
post_var_name = None
post_var_name: Optional[str] = None
for req_var in post_var_names:
assert req_var is not None
if req_var in request.POST:
val = request.POST[req_var]
request_notes.processed_parameters.add(req_var)
elif req_var in request.GET:
val = request.GET[req_var]
request_notes.processed_parameters.add(req_var)
else:
# This is covered by test_REQ_aliases, but coverage.py
# fails to recognize this for some reason.
continue # nocoverage
if post_var_name is not None:
for req_var in post_var_names:
assert req_var is not None
raise RequestConfusingParmsError(post_var_name, req_var)
post_var_name = req_var
if req_var in request.POST:
val = request.POST[req_var]
request_notes.processed_parameters.add(req_var)
elif req_var in request.GET:
val = request.GET[req_var]
request_notes.processed_parameters.add(req_var)
else:
# This is covered by test_REQ_aliases, but coverage.py
# fails to recognize this for some reason.
continue # nocoverage
if post_var_name is not None:
raise RequestConfusingParmsError(post_var_name, req_var)
post_var_name = req_var
if post_var_name is None:
post_var_name = param.post_var_name
assert post_var_name is not None
if param.default is _REQ.NotSpecified:
raise RequestVariableMissingError(post_var_name)
kwargs[func_var_name] = param.default
continue
if post_var_name is None:
post_var_name = param.post_var_name
assert post_var_name is not None
if param.default is _REQ.NotSpecified:
raise RequestVariableMissingError(post_var_name)
kwargs[func_var_name] = param.default
continue
if param.converter is not None:
try:
@@ -422,6 +428,8 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT:
try:
val = orjson.loads(val)
except orjson.JSONDecodeError:
if param.argument_type == "body":
raise InvalidJSONError(_("Malformed JSON"))
raise JsonableError(_('Argument "{}" is not valid JSON.').format(post_var_name))
try:

View File

@@ -58,6 +58,10 @@ from zerver.lib.types import ProfileFieldData, Validator
ResultT = TypeVar("ResultT")
def check_anything(var_name: str, val: object) -> object:
return val
def check_string(var_name: str, val: object) -> str:
if not isinstance(val, str):
raise ValidationError(_("{var_name} is not a string").format(var_name=var_name))

View File

@@ -248,6 +248,12 @@ class DecoratorTestCase(ZulipTestCase):
) -> HttpResponse:
return json_response(data={"payload": payload})
request = HostRequestMock()
request.body = b"\xde\xad\xbe\xef"
with self.assertRaises(JsonableError) as cm:
get_payload(request)
self.assertEqual(str(cm.exception), "Malformed payload")
request = HostRequestMock()
request.body = b"notjson"
with self.assertRaises(JsonableError) as cm: