mirror of
				https://github.com/zulip/zulip.git
				synced 2025-10-30 11:33:51 +00:00 
			
		
		
		
	request: Support converter or json_validator with argument_type="body".
Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
		
				
					committed by
					
						 Tim Abbott
						Tim Abbott
					
				
			
			
				
	
			
			
			
						parent
						
							970f22380a
						
					
				
				
					commit
					04d772b582
				
			| @@ -31,6 +31,7 @@ import zerver.tornado.handlers as handlers | ||||
| from zerver.lib.exceptions import ErrorCode, InvalidJSONError, JsonableError | ||||
| from zerver.lib.notes import BaseNotes | ||||
| from zerver.lib.types import Validator, ViewFuncT | ||||
| from zerver.lib.validator import check_anything | ||||
| from zerver.models import Client, Realm | ||||
|  | ||||
|  | ||||
| @@ -168,6 +169,10 @@ class _REQ(Generic[ResultT]): | ||||
|  | ||||
|         """ | ||||
|  | ||||
|         if argument_type == "body" and converter is None and json_validator is None: | ||||
|             # legacy behavior | ||||
|             json_validator = cast(Callable[[str, object], ResultT], check_anything) | ||||
|  | ||||
|         self.post_var_name = whence | ||||
|         self.func_var_name: Optional[str] = None | ||||
|         self.converter = converter | ||||
| @@ -206,6 +211,7 @@ def REQ( | ||||
|     *, | ||||
|     converter: Callable[[str, str], ResultT], | ||||
|     default: ResultT = ..., | ||||
|     argument_type: Optional[Literal["body"]] = ..., | ||||
|     intentionally_undocumented: bool = ..., | ||||
|     documentation_pending: bool = ..., | ||||
|     aliases: Sequence[str] = ..., | ||||
| @@ -221,6 +227,7 @@ def REQ( | ||||
|     *, | ||||
|     default: ResultT = ..., | ||||
|     json_validator: Validator[ResultT], | ||||
|     argument_type: Optional[Literal["body"]] = ..., | ||||
|     intentionally_undocumented: bool = ..., | ||||
|     documentation_pending: bool = ..., | ||||
|     aliases: Sequence[str] = ..., | ||||
| @@ -368,46 +375,45 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: | ||||
|                 continue | ||||
|             assert func_var_name is not None | ||||
|  | ||||
|             post_var_name: Optional[str] | ||||
|  | ||||
|             if param.argument_type == "body": | ||||
|                 post_var_name = "request" | ||||
|                 try: | ||||
|                     val = orjson.loads(request.body) | ||||
|                 except orjson.JSONDecodeError: | ||||
|                     raise InvalidJSONError(_("Malformed JSON")) | ||||
|                 kwargs[func_var_name] = val | ||||
|                 continue | ||||
|                     val = request.body.decode(request.encoding or "utf-8") | ||||
|                 except UnicodeDecodeError: | ||||
|                     raise JsonableError(_("Malformed payload")) | ||||
|             else: | ||||
|                 # This is a view bug, not a user error, and thus should throw a 500. | ||||
|                 assert param.argument_type is None, "Invalid argument type" | ||||
|  | ||||
|             post_var_names = [param.post_var_name] | ||||
|             post_var_names += param.aliases | ||||
|                 post_var_names = [param.post_var_name] | ||||
|                 post_var_names += param.aliases | ||||
|                 post_var_name = None | ||||
|  | ||||
|             post_var_name: Optional[str] = None | ||||
|  | ||||
|             for req_var in post_var_names: | ||||
|                 assert req_var is not None | ||||
|                 if req_var in request.POST: | ||||
|                     val = request.POST[req_var] | ||||
|                     request_notes.processed_parameters.add(req_var) | ||||
|                 elif req_var in request.GET: | ||||
|                     val = request.GET[req_var] | ||||
|                     request_notes.processed_parameters.add(req_var) | ||||
|                 else: | ||||
|                     # This is covered by test_REQ_aliases, but coverage.py | ||||
|                     # fails to recognize this for some reason. | ||||
|                     continue  # nocoverage | ||||
|                 if post_var_name is not None: | ||||
|                 for req_var in post_var_names: | ||||
|                     assert req_var is not None | ||||
|                     raise RequestConfusingParmsError(post_var_name, req_var) | ||||
|                 post_var_name = req_var | ||||
|                     if req_var in request.POST: | ||||
|                         val = request.POST[req_var] | ||||
|                         request_notes.processed_parameters.add(req_var) | ||||
|                     elif req_var in request.GET: | ||||
|                         val = request.GET[req_var] | ||||
|                         request_notes.processed_parameters.add(req_var) | ||||
|                     else: | ||||
|                         # This is covered by test_REQ_aliases, but coverage.py | ||||
|                         # fails to recognize this for some reason. | ||||
|                         continue  # nocoverage | ||||
|                     if post_var_name is not None: | ||||
|                         raise RequestConfusingParmsError(post_var_name, req_var) | ||||
|                     post_var_name = req_var | ||||
|  | ||||
|             if post_var_name is None: | ||||
|                 post_var_name = param.post_var_name | ||||
|                 assert post_var_name is not None | ||||
|                 if param.default is _REQ.NotSpecified: | ||||
|                     raise RequestVariableMissingError(post_var_name) | ||||
|                 kwargs[func_var_name] = param.default | ||||
|                 continue | ||||
|                 if post_var_name is None: | ||||
|                     post_var_name = param.post_var_name | ||||
|                     assert post_var_name is not None | ||||
|                     if param.default is _REQ.NotSpecified: | ||||
|                         raise RequestVariableMissingError(post_var_name) | ||||
|                     kwargs[func_var_name] = param.default | ||||
|                     continue | ||||
|  | ||||
|             if param.converter is not None: | ||||
|                 try: | ||||
| @@ -422,6 +428,8 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: | ||||
|                 try: | ||||
|                     val = orjson.loads(val) | ||||
|                 except orjson.JSONDecodeError: | ||||
|                     if param.argument_type == "body": | ||||
|                         raise InvalidJSONError(_("Malformed JSON")) | ||||
|                     raise JsonableError(_('Argument "{}" is not valid JSON.').format(post_var_name)) | ||||
|  | ||||
|                 try: | ||||
|   | ||||
| @@ -58,6 +58,10 @@ from zerver.lib.types import ProfileFieldData, Validator | ||||
| ResultT = TypeVar("ResultT") | ||||
|  | ||||
|  | ||||
| def check_anything(var_name: str, val: object) -> object: | ||||
|     return val | ||||
|  | ||||
|  | ||||
| def check_string(var_name: str, val: object) -> str: | ||||
|     if not isinstance(val, str): | ||||
|         raise ValidationError(_("{var_name} is not a string").format(var_name=var_name)) | ||||
|   | ||||
| @@ -248,6 +248,12 @@ class DecoratorTestCase(ZulipTestCase): | ||||
|         ) -> HttpResponse: | ||||
|             return json_response(data={"payload": payload}) | ||||
|  | ||||
|         request = HostRequestMock() | ||||
|         request.body = b"\xde\xad\xbe\xef" | ||||
|         with self.assertRaises(JsonableError) as cm: | ||||
|             get_payload(request) | ||||
|         self.assertEqual(str(cm.exception), "Malformed payload") | ||||
|  | ||||
|         request = HostRequestMock() | ||||
|         request.body = b"notjson" | ||||
|         with self.assertRaises(JsonableError) as cm: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user