diff --git a/zerver/lib/request.py b/zerver/lib/request.py index 84f74414ab..989f5818a4 100644 --- a/zerver/lib/request.py +++ b/zerver/lib/request.py @@ -31,6 +31,7 @@ import zerver.tornado.handlers as handlers from zerver.lib.exceptions import ErrorCode, InvalidJSONError, JsonableError from zerver.lib.notes import BaseNotes from zerver.lib.types import Validator, ViewFuncT +from zerver.lib.validator import check_anything from zerver.models import Client, Realm @@ -168,6 +169,10 @@ class _REQ(Generic[ResultT]): """ + if argument_type == "body" and converter is None and json_validator is None: + # legacy behavior + json_validator = cast(Callable[[str, object], ResultT], check_anything) + self.post_var_name = whence self.func_var_name: Optional[str] = None self.converter = converter @@ -206,6 +211,7 @@ def REQ( *, converter: Callable[[str, str], ResultT], default: ResultT = ..., + argument_type: Optional[Literal["body"]] = ..., intentionally_undocumented: bool = ..., documentation_pending: bool = ..., aliases: Sequence[str] = ..., @@ -221,6 +227,7 @@ def REQ( *, default: ResultT = ..., json_validator: Validator[ResultT], + argument_type: Optional[Literal["body"]] = ..., intentionally_undocumented: bool = ..., documentation_pending: bool = ..., aliases: Sequence[str] = ..., @@ -368,46 +375,45 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: continue assert func_var_name is not None + post_var_name: Optional[str] + if param.argument_type == "body": + post_var_name = "request" try: - val = orjson.loads(request.body) - except orjson.JSONDecodeError: - raise InvalidJSONError(_("Malformed JSON")) - kwargs[func_var_name] = val - continue + val = request.body.decode(request.encoding or "utf-8") + except UnicodeDecodeError: + raise JsonableError(_("Malformed payload")) else: # This is a view bug, not a user error, and thus should throw a 500. assert param.argument_type is None, "Invalid argument type" - post_var_names = [param.post_var_name] - post_var_names += param.aliases + post_var_names = [param.post_var_name] + post_var_names += param.aliases + post_var_name = None - post_var_name: Optional[str] = None - - for req_var in post_var_names: - assert req_var is not None - if req_var in request.POST: - val = request.POST[req_var] - request_notes.processed_parameters.add(req_var) - elif req_var in request.GET: - val = request.GET[req_var] - request_notes.processed_parameters.add(req_var) - else: - # This is covered by test_REQ_aliases, but coverage.py - # fails to recognize this for some reason. - continue # nocoverage - if post_var_name is not None: + for req_var in post_var_names: assert req_var is not None - raise RequestConfusingParmsError(post_var_name, req_var) - post_var_name = req_var + if req_var in request.POST: + val = request.POST[req_var] + request_notes.processed_parameters.add(req_var) + elif req_var in request.GET: + val = request.GET[req_var] + request_notes.processed_parameters.add(req_var) + else: + # This is covered by test_REQ_aliases, but coverage.py + # fails to recognize this for some reason. + continue # nocoverage + if post_var_name is not None: + raise RequestConfusingParmsError(post_var_name, req_var) + post_var_name = req_var - if post_var_name is None: - post_var_name = param.post_var_name - assert post_var_name is not None - if param.default is _REQ.NotSpecified: - raise RequestVariableMissingError(post_var_name) - kwargs[func_var_name] = param.default - continue + if post_var_name is None: + post_var_name = param.post_var_name + assert post_var_name is not None + if param.default is _REQ.NotSpecified: + raise RequestVariableMissingError(post_var_name) + kwargs[func_var_name] = param.default + continue if param.converter is not None: try: @@ -422,6 +428,8 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: try: val = orjson.loads(val) except orjson.JSONDecodeError: + if param.argument_type == "body": + raise InvalidJSONError(_("Malformed JSON")) raise JsonableError(_('Argument "{}" is not valid JSON.').format(post_var_name)) try: diff --git a/zerver/lib/validator.py b/zerver/lib/validator.py index c9138f71e4..7c080da2a3 100644 --- a/zerver/lib/validator.py +++ b/zerver/lib/validator.py @@ -58,6 +58,10 @@ from zerver.lib.types import ProfileFieldData, Validator ResultT = TypeVar("ResultT") +def check_anything(var_name: str, val: object) -> object: + return val + + def check_string(var_name: str, val: object) -> str: if not isinstance(val, str): raise ValidationError(_("{var_name} is not a string").format(var_name=var_name)) diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index 61b4500cab..1c0556e4f5 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -248,6 +248,12 @@ class DecoratorTestCase(ZulipTestCase): ) -> HttpResponse: return json_response(data={"payload": payload}) + request = HostRequestMock() + request.body = b"\xde\xad\xbe\xef" + with self.assertRaises(JsonableError) as cm: + get_payload(request) + self.assertEqual(str(cm.exception), "Malformed payload") + request = HostRequestMock() request.body = b"notjson" with self.assertRaises(JsonableError) as cm: