endpoints: Remove the has_request_variables decorator.

All endpoints have been migrated to the typed_endpoint decorator,
therefore the has_request_variables decorator and the REQ function are
no longer needed and have been removed.
This commit is contained in:
Kenneth Rodrigues
2024-09-03 21:37:09 +05:30
committed by Tim Abbott
parent 88e1810733
commit dc32396180
10 changed files with 15 additions and 719 deletions

View File

@@ -1,4 +0,0 @@
# One of the ways user-controlled data enters the application is through the
# request variables framework. This model teaches Pysa that every instance of
# 'REQ()' in a view function is a source of UserControlled taint.
class zerver.lib.request._REQ(TaintSource[UserControlled]): ...

View File

@@ -354,10 +354,6 @@ python_rules = RuleList(
"exclude": {"zerver/tests", "zerver/views/development/"},
"description": "Argument to JsonableError should be a literal string enclosed by _()",
},
{
"pattern": r"""([a-zA-Z0-9_]+)=REQ\(['"]\1['"]""",
"description": "REQ's first argument already defaults to parameter name",
},
{
"pattern": r"self\.client\.(get|post|patch|put|delete)",
"description": """Do not call self.client directly for put/patch/post/get.

View File

@@ -179,24 +179,6 @@ rules:
severity: ERROR
message: "Prefer {named} fields over positional {} in translated strings"
- id: mutable-default-type
languages: [python]
pattern-either:
- pattern: |
def $F(..., $A: typing.List[...] = zerver.lib.request.REQ(..., default=[...], ...), ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.Optional[typing.List[...]] = zerver.lib.request.REQ(..., default=[...], ...), ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.Dict[...] = zerver.lib.request.REQ(..., default={}, ...), ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.Optional[typing.Dict[...]] = zerver.lib.request.REQ(..., default={}, ...), ...) -> ...:
...
severity: ERROR
message: "Guard mutable default with read-only type (Sequence, Mapping, AbstractSet)"
- id: percent-formatting
languages: [python]
pattern-either:

View File

@@ -1,23 +1,16 @@
from collections import defaultdict
from collections.abc import Callable, MutableMapping, Sequence
from collections.abc import MutableMapping
from dataclasses import dataclass, field
from functools import wraps
from types import FunctionType
from typing import Any, Concatenate, Generic, Literal, Optional, TypeVar, cast, overload
from typing import Any, Optional
import orjson
from django.conf import settings
from django.core.exceptions import ValidationError
from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext as _
from typing_extensions import ParamSpec, override
from typing_extensions import override
from zerver.lib import rate_limiter
from zerver.lib.exceptions import ErrorCode, InvalidJSONError, JsonableError
from zerver.lib.exceptions import ErrorCode, JsonableError
from zerver.lib.notes import BaseNotes
from zerver.lib.response import MutableJsonResponse
from zerver.lib.types import Validator
from zerver.lib.validator import check_anything
from zerver.models import Client, Realm
if settings.ZILENCER_ENABLED:
@@ -105,376 +98,4 @@ class RequestVariableConversionError(JsonableError):
return _("Bad value for '{var_name}': {bad_value}")
# Used in conjunction with @has_request_variables, below
ResultT = TypeVar("ResultT")
class _REQ(Generic[ResultT]):
# NotSpecified is a sentinel value for determining whether a
# default value was specified for a request variable. We can't
# use None because that could be a valid, user-specified default
class _NotSpecified:
pass
NotSpecified = _NotSpecified()
def __init__(
self,
whence: str | None = None,
*,
converter: Callable[[str, str], ResultT] | None = None,
default: _NotSpecified | ResultT | None = NotSpecified,
json_validator: Validator[ResultT] | None = None,
str_validator: Validator[ResultT] | None = None,
argument_type: str | None = None,
intentionally_undocumented: bool = False,
documentation_pending: bool = False,
aliases: Sequence[str] = [],
path_only: bool = False,
) -> None:
"""whence: the name of the request variable that should be used
for this parameter. Defaults to a request variable of the
same name as the parameter.
converter: a function that takes a string and returns a new
value. If specified, this will be called on the request
variable value before passing to the function
default: a value to be used for the argument if the parameter
is missing in the request
json_validator: similar to converter, but takes an already
parsed JSON data structure. If specified, we will parse the
JSON request variable value before passing to the function
str_validator: Like json_validator, but doesn't parse JSON
first.
argument_type: pass 'body' to extract the parsed JSON
corresponding to the request body
aliases: alternate names for the POST var
path_only: Used for parameters included in the URL that we still want
to validate via REQ's hooks.
"""
if argument_type == "body" and converter is None and json_validator is None:
# legacy behavior
json_validator = cast(Callable[[str, object], ResultT], check_anything)
self.post_var_name = whence
self.func_var_name: str | None = None
self.converter = converter
self.json_validator = json_validator
self.str_validator = str_validator
self.default = default
self.argument_type = argument_type
self.aliases = aliases
self.intentionally_undocumented = intentionally_undocumented
self.documentation_pending = documentation_pending
self.path_only = path_only
assert converter is None or (
json_validator is None and str_validator is None
), "converter and json_validator are mutually exclusive"
assert (
json_validator is None or str_validator is None
), "json_validator and str_validator are mutually exclusive"
# This factory function ensures that mypy can correctly analyze REQ.
#
# Note that REQ claims to return a type matching that of the parameter
# of which it is the default value, allowing type checking of view
# functions using has_request_variables. In reality, REQ returns an
# instance of class _REQ to enable the decorator to scan the parameter
# list for _REQ objects and patch the parameters as the true types.
#
# See also this documentation to learn how @overload helps here.
# https://zulip.readthedocs.io/en/latest/testing/mypy.html#using-overload-to-accurately-describe-variations
#
# Overload 1: converter
@overload
def REQ(
whence: str | None = ...,
*,
converter: Callable[[str, str], ResultT],
default: ResultT = ...,
argument_type: Literal["body"] | None = ...,
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Sequence[str] = ...,
path_only: bool = ...,
) -> ResultT: ...
# Overload 2: json_validator
@overload
def REQ(
whence: str | None = ...,
*,
default: ResultT = ...,
json_validator: Validator[ResultT],
argument_type: Literal["body"] | None = ...,
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Sequence[str] = ...,
path_only: bool = ...,
) -> ResultT: ...
# Overload 3: no converter/json_validator, default: str or unspecified, argument_type=None
@overload
def REQ(
whence: str | None = ...,
*,
default: str = ...,
str_validator: Validator[str] | None = ...,
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Sequence[str] = ...,
path_only: bool = ...,
) -> str: ...
# Overload 4: no converter/validator, default=None, argument_type=None
@overload
def REQ(
whence: str | None = ...,
*,
default: None,
str_validator: Validator[str] | None = ...,
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Sequence[str] = ...,
path_only: bool = ...,
) -> str | None: ...
# Overload 5: argument_type="body"
@overload
def REQ(
whence: str | None = ...,
*,
default: ResultT = ...,
argument_type: Literal["body"],
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Sequence[str] = ...,
path_only: bool = ...,
) -> ResultT: ...
# Implementation
def REQ(
whence: str | None = None,
*,
converter: Callable[[str, str], ResultT] | None = None,
default: _REQ._NotSpecified | ResultT = _REQ.NotSpecified,
json_validator: Validator[ResultT] | None = None,
str_validator: Validator[ResultT] | None = None,
argument_type: str | None = None,
intentionally_undocumented: bool = False,
documentation_pending: bool = False,
aliases: Sequence[str] = [],
path_only: bool = False,
) -> ResultT:
return cast(
ResultT,
_REQ(
whence,
converter=converter,
default=default,
json_validator=json_validator,
str_validator=str_validator,
argument_type=argument_type,
intentionally_undocumented=intentionally_undocumented,
documentation_pending=documentation_pending,
aliases=aliases,
path_only=path_only,
),
)
arguments_map: dict[str, list[str]] = defaultdict(list)
ParamT = ParamSpec("ParamT")
ReturnT = TypeVar("ReturnT")
# Extracts variables from the request object and passes them as
# named function arguments. The request object must be the first
# argument to the function.
#
# To use, assign a function parameter a default value that is an
# instance of the _REQ class. That parameter will then be automatically
# populated from the HTTP request. The request object must be the
# first argument to the decorated function.
#
# This should generally be the innermost (syntactically bottommost)
# decorator applied to a view, since other decorators won't preserve
# the default parameter values used by has_request_variables.
#
# Note that this can't be used in helper functions which are not
# expected to call json_success or raise JsonableError, as it uses JsonableError
# internally when it encounters an error
def has_request_variables(
req_func: Callable[Concatenate[HttpRequest, ParamT], ReturnT],
) -> Callable[Concatenate[HttpRequest, ParamT], ReturnT]:
num_params = req_func.__code__.co_argcount
default_param_values = cast(FunctionType, req_func).__defaults__
if default_param_values is None: # nocoverage # No users of this path.
default_param_values = ()
num_default_params = len(default_param_values)
default_param_names = req_func.__code__.co_varnames[num_params - num_default_params :]
post_params = []
view_func_full_name = f"{req_func.__module__}.{req_func.__name__}"
for name, value in zip(default_param_names, default_param_values, strict=False):
if isinstance(value, _REQ):
value.func_var_name = name
if value.post_var_name is None:
value.post_var_name = name
post_params.append(value)
# Record arguments that should be documented so that our
# automated OpenAPI docs tests can compare these against the code.
if (
not value.intentionally_undocumented
and not value.documentation_pending
and not value.path_only
):
arguments_map[view_func_full_name].append(value.post_var_name)
@wraps(req_func)
def _wrapped_req_func(
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
) -> ReturnT:
request_notes = RequestNotes.get_notes(request)
for param in post_params:
func_var_name = param.func_var_name
if param.path_only:
# For path_only parameters, they should already have
# been passed via the URL, so there's no need for REQ
# to do anything.
#
# TODO: Either run validators for path_only parameters
# or don't declare them using REQ.
# no coverage because has_request_variables will be removed once
# all the endpoints have been migrated to use typed_endpoint.
assert func_var_name in kwargs # nocoverage
if func_var_name in kwargs:
continue # nocoverage
assert func_var_name is not None
post_var_name: str | None
if param.argument_type == "body":
post_var_name = "request"
try:
val = request.body.decode(request.encoding or "utf-8")
except UnicodeDecodeError:
raise JsonableError(_("Malformed payload"))
else:
# This is a view bug, not a user error, and thus should throw a 500.
assert param.argument_type is None, "Invalid argument type"
post_var_names = [param.post_var_name]
post_var_names += param.aliases
post_var_name = None
for req_var in post_var_names:
assert req_var is not None
if req_var in request.POST:
val = request.POST[req_var]
request_notes.processed_parameters.add(req_var)
elif req_var in request.GET: # nocoverage # No users of this path
val = request.GET[req_var]
request_notes.processed_parameters.add(req_var)
else:
# This is covered by test_REQ_aliases, but coverage.py
# fails to recognize this for some reason.
continue # nocoverage
if post_var_name is not None:
raise RequestConfusingParamsError(post_var_name, req_var)
post_var_name = req_var
if post_var_name is None:
post_var_name = param.post_var_name
assert post_var_name is not None
if param.default is _REQ.NotSpecified:
raise RequestVariableMissingError(post_var_name)
kwargs[func_var_name] = param.default
continue
if param.converter is not None:
try:
val = param.converter(post_var_name, val)
except JsonableError:
raise
except Exception:
raise RequestVariableConversionError(post_var_name, val)
# json_validator is like converter, but doesn't handle JSON parsing; we do.
if param.json_validator is not None:
try:
val = orjson.loads(val)
except orjson.JSONDecodeError:
if param.argument_type == "body":
raise InvalidJSONError(_("Malformed JSON"))
raise JsonableError(
_('Argument "{name}" is not valid JSON.').format(name=post_var_name)
)
try:
val = param.json_validator(post_var_name, val)
except ValidationError as error:
raise JsonableError(error.message)
# str_validators is like json_validator, but for direct strings (no JSON parsing).
if param.str_validator is not None:
try:
val = param.str_validator(post_var_name, val)
except ValidationError as error:
raise JsonableError(error.message)
kwargs[func_var_name] = val
return_value = req_func(request, *args, **kwargs)
if (
isinstance(return_value, MutableJsonResponse)
and not request_notes.is_webhook_view
# Implemented only for 200 responses.
# TODO: Implement returning unsupported ignored parameters for 400
# JSON error responses. This is complex because has_request_variables
# can be called multiple times, so when an error response is raised,
# there may be supported parameters that have not yet been processed,
# which could lead to inaccurate output.
and 200 <= return_value.status_code < 300
):
ignored_parameters = {*request.POST, *request.GET}.difference(
request_notes.processed_parameters
)
# This will be called each time a function decorated with
# has_request_variables returns a MutableJsonResponse with a
# success status_code. Because a shared processed_parameters
# value is checked each time, the value for the
# ignored_parameters_unsupported key is either added/updated
# to the response data or it is removed in the case that all
# of the request parameters have been processed.
if ignored_parameters:
return_value.get_data()["ignored_parameters_unsupported"] = sorted(
ignored_parameters
)
else:
return_value.get_data().pop("ignored_parameters_unsupported", None)
return return_value
return _wrapped_req_func

View File

@@ -24,7 +24,6 @@ from typing_extensions import ParamSpec
from zerver.lib.exceptions import ApiParamValidationError, JsonableError
from zerver.lib.request import (
_REQ,
RequestConfusingParamsError,
RequestNotes,
RequestVariableMissingError,
@@ -461,9 +460,6 @@ def typed_endpoint(
view_func_name=endpoint_info.view_func_full_name
)
for func_param in endpoint_info.parameters:
assert not isinstance(
func_param.default, _REQ
), f"Unexpected REQ for parameter {func_param.param_name}; REQ is incompatible with typed_endpoint"
if func_param.path_only:
assert (
func_param.default is NotSpecified
@@ -579,6 +575,4 @@ def typed_endpoint(
return return_value
# TODO: Remove this once we replace has_request_variables with typed_endpoint.
_wrapped_view_func.use_endpoint = True # type: ignore[attr-defined] # Distinguish functions decorated with @typed_endpoint from those decorated with has_request_variables
return _wrapped_view_func

View File

@@ -28,7 +28,6 @@ for any particular type of object.
"""
import re
from collections.abc import Collection, Container, Iterator
from dataclasses import dataclass
from datetime import datetime, timezone
@@ -174,17 +173,6 @@ def check_bool(var_name: str, val: object) -> bool:
return val
def check_color(var_name: str, val: object) -> str:
s = check_string(var_name, val)
valid_color_pattern = re.compile(r"^#([a-fA-F0-9]{3,6})$")
matched_results = valid_color_pattern.match(s)
if not matched_results:
raise ValidationError(
_("{var_name} is not a valid hex color code").format(var_name=var_name)
)
return s
def check_none_or(sub_validator: Validator[ResultT]) -> Validator[ResultT | None]:
def f(var_name: str, val: object) -> ResultT | None:
if val is None:
@@ -562,16 +550,6 @@ def validate_todo_data(todo_data: object, is_widget_author: bool) -> None:
raise ValidationError(f"Unknown type for todo data: {todo_data['type']}")
# Converter functions for use with has_request_variables
def to_non_negative_int(var_name: str, s: str, max_int_size: int = 2**32 - 1) -> int:
x = int(s)
if x < 0:
raise ValueError("argument is negative")
if x > max_int_size:
raise ValueError(f"{x} is too large (max {max_int_size})")
return x
def check_string_or_int_list(var_name: str, val: object) -> str | list[int]:
if isinstance(val, str):
return val

View File

@@ -1,247 +0,0 @@
from collections.abc import Sequence
from typing import Any
import orjson
from django.http import HttpRequest, HttpResponse
from zerver.lib.exceptions import JsonableError
from zerver.lib.request import (
REQ,
RequestConfusingParamsError,
RequestVariableConversionError,
RequestVariableMissingError,
has_request_variables,
)
from zerver.lib.response import MutableJsonResponse, json_response, json_success
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import HostRequestMock
from zerver.lib.validator import check_bool, check_int, check_list, check_string_fixed_length
class REQTestCase(ZulipTestCase):
def test_REQ_aliases(self) -> None:
@has_request_variables
def double(
request: HttpRequest,
x: int = REQ(whence="number", aliases=["x", "n"], json_validator=check_int),
) -> HttpResponse:
return json_response(data={"number": x + x})
request = HostRequestMock(post_data={"bogus": "5555"})
with self.assertRaises(RequestVariableMissingError):
double(request)
request = HostRequestMock(post_data={"number": "3"})
self.assertEqual(orjson.loads(double(request).content).get("number"), 6)
request = HostRequestMock(post_data={"x": "4"})
self.assertEqual(orjson.loads(double(request).content).get("number"), 8)
request = HostRequestMock(post_data={"n": "5"})
self.assertEqual(orjson.loads(double(request).content).get("number"), 10)
request = HostRequestMock(post_data={"number": "6", "x": "7"})
with self.assertRaises(RequestConfusingParamsError) as cm:
double(request)
self.assertEqual(str(cm.exception), "Can't decide between 'number' and 'x' arguments")
def test_REQ_converter(self) -> None:
def my_converter(var_name: str, data: str) -> list[int]:
lst = orjson.loads(data)
if not isinstance(lst, list):
raise ValueError("not a list")
if 13 in lst:
raise JsonableError("13 is an unlucky number!")
return [int(elem) for elem in lst]
@has_request_variables
def get_total(
request: HttpRequest, numbers: Sequence[int] = REQ(converter=my_converter)
) -> HttpResponse:
return json_response(data={"number": sum(numbers)})
request = HostRequestMock()
with self.assertRaises(RequestVariableMissingError):
get_total(request)
request.POST["numbers"] = "bad_value"
with self.assertRaises(RequestVariableConversionError) as cm:
get_total(request)
self.assertEqual(str(cm.exception), "Bad value for 'numbers': bad_value")
request.POST["numbers"] = orjson.dumps("{fun: unfun}").decode()
with self.assertRaises(JsonableError) as jsonable_error_cm:
get_total(request)
self.assertEqual(
str(jsonable_error_cm.exception), "Bad value for 'numbers': \"{fun: unfun}\""
)
request.POST["numbers"] = orjson.dumps([2, 3, 5, 8, 13, 21]).decode()
with self.assertRaises(JsonableError) as jsonable_error_cm:
get_total(request)
self.assertEqual(str(jsonable_error_cm.exception), "13 is an unlucky number!")
request.POST["numbers"] = orjson.dumps([1, 2, 3, 4, 5, 6]).decode()
result = get_total(request)
self.assertEqual(orjson.loads(result.content).get("number"), 21)
def test_REQ_validator(self) -> None:
@has_request_variables
def get_total(
request: HttpRequest, numbers: Sequence[int] = REQ(json_validator=check_list(check_int))
) -> HttpResponse:
return json_response(data={"number": sum(numbers)})
request = HostRequestMock()
with self.assertRaises(RequestVariableMissingError):
get_total(request)
request.POST["numbers"] = "bad_value"
with self.assertRaises(JsonableError) as cm:
get_total(request)
self.assertEqual(str(cm.exception), 'Argument "numbers" is not valid JSON.')
request.POST["numbers"] = orjson.dumps([1, 2, "what?", 4, 5, 6]).decode()
with self.assertRaises(JsonableError) as cm:
get_total(request)
self.assertEqual(str(cm.exception), "numbers[2] is not an integer")
request.POST["numbers"] = orjson.dumps([1, 2, 3, 4, 5, 6]).decode()
result = get_total(request)
self.assertEqual(orjson.loads(result.content).get("number"), 21)
def test_REQ_str_validator(self) -> None:
@has_request_variables
def get_middle_characters(
request: HttpRequest, value: str = REQ(str_validator=check_string_fixed_length(5))
) -> HttpResponse:
return json_response(data={"value": value[1:-1]})
request = HostRequestMock()
with self.assertRaises(RequestVariableMissingError):
get_middle_characters(request)
request.POST["value"] = "long_value"
with self.assertRaises(JsonableError) as cm:
get_middle_characters(request)
self.assertEqual(str(cm.exception), "value has incorrect length 10; should be 5")
request.POST["value"] = "valid"
result = get_middle_characters(request)
self.assertEqual(orjson.loads(result.content).get("value"), "ali")
def test_REQ_argument_type(self) -> None:
@has_request_variables
def get_payload(
request: HttpRequest, payload: dict[str, Any] = REQ(argument_type="body")
) -> HttpResponse:
return json_response(data={"payload": payload})
request = HostRequestMock()
request._body = b"\xde\xad\xbe\xef"
with self.assertRaises(JsonableError) as cm:
get_payload(request)
self.assertEqual(str(cm.exception), "Malformed payload")
request = HostRequestMock()
request._body = b"notjson"
with self.assertRaises(JsonableError) as cm:
get_payload(request)
self.assertEqual(str(cm.exception), "Malformed JSON")
request._body = b'{"a": "b"}'
self.assertEqual(orjson.loads(get_payload(request).content).get("payload"), {"a": "b"})
class TestIgnoredParametersUnsupported(ZulipTestCase):
def test_ignored_parameters_json_success(self) -> None:
@has_request_variables
def test_view(
request: HttpRequest,
name: str | None = REQ(default=None),
age: int | None = 0,
) -> HttpResponse:
return json_success(request)
# ignored parameter (not processed through REQ)
request = HostRequestMock()
request.POST["age"] = "30"
result = test_view(request)
self.assert_json_success(result, ignored_parameters=["age"])
# valid parameter, returns no ignored parameters
request = HostRequestMock()
request.POST["name"] = "Hamlet"
result = test_view(request)
self.assert_json_success(result)
# both valid and ignored parameters
request = HostRequestMock()
request.POST["name"] = "Hamlet"
request.POST["age"] = "30"
request.POST["location"] = "Denmark"
request.POST["dies"] = "True"
result = test_view(request)
ignored_parameters = ["age", "dies", "location"]
json_result = self.assert_json_success(result, ignored_parameters=ignored_parameters)
# check that results are sorted
self.assertEqual(json_result["ignored_parameters_unsupported"], ignored_parameters)
# Because `has_request_variables` can be called multiple times on a request,
# here we test that parameters processed in separate, nested function calls
# are not returned in the `ignored parameters_unsupported` array.
def test_nested_has_request_variables(self) -> None:
@has_request_variables
def not_view_function_A(
request: HttpRequest, dies: bool = REQ(json_validator=check_bool)
) -> None:
return
@has_request_variables
def not_view_function_B(
request: HttpRequest, married: bool = REQ(json_validator=check_bool)
) -> None:
return
@has_request_variables
def view_B(request: HttpRequest, name: str = REQ()) -> MutableJsonResponse:
return json_success(request)
@has_request_variables
def view_A(
request: HttpRequest, age: int = REQ(json_validator=check_int)
) -> MutableJsonResponse:
not_view_function_A(request)
response = view_B(request)
not_view_function_B(request)
return response
# valid parameters, returns no ignored parameters
post_data = {"name": "Hamlet", "age": "30", "dies": "true", "married": "false"}
request = HostRequestMock(post_data)
result = view_A(request)
result_iter = list(iter(result))
self.assertEqual(result_iter, [b'{"result":"success","msg":""}\n'])
self.assert_json_success(result)
# ignored parameter
post_data = {
"name": "Hamlet",
"age": "30",
"dies": "true",
"married": "false",
"author": "William Shakespeare",
}
request = HostRequestMock(post_data)
result = view_A(request)
result_iter = list(iter(result))
self.assertEqual(
result_iter,
[b'{"result":"success","msg":"","ignored_parameters_unsupported":["author"]}\n'],
)
self.assert_json_success(result, ignored_parameters=["author"])

View File

@@ -438,16 +438,10 @@ do not match the types declared in the implementation of {function.__name__}.\n"
# Iterate through the decorators to find the original
# function, wrapped by typed_endpoint, so we can parse its
# arguments.
use_endpoint_decorator = False
while (wrapped := getattr(function, "__wrapped__", None)) is not None:
# TODO: Remove this check once we replace has_request_variables with
# typed_endpoint.
if getattr(function, "use_endpoint", False):
use_endpoint_decorator = True
function = wrapped
if len(openapi_parameters) > 0:
assert use_endpoint_decorator
return self.validate_json_schema(function, openapi_parameters)
def check_openapi_arguments_for_view(

View File

@@ -635,3 +635,14 @@ class ValidationErrorHandlingTest(ZulipTestCase):
self.assertEqual(m.exception.msg, subtest.error_message)
self.assertEqual(m.exception.error_type, subtest.error_type)
def test_response(self) -> None:
@typed_endpoint
def view(request: HttpRequest, *, foo: Json[int]) -> MutableJsonResponse:
return json_success(request, {"value": foo})
response = call_endpoint(view, HostRequestMock({"foo": orjson.dumps(42).decode()}))
for content in response:
decoded_content = content.decode()
self.assertIn("value", decoded_content)
self.assertIn("success", decoded_content)

View File

@@ -1,4 +1,3 @@
import re
from typing import Any
from django.conf import settings
@@ -10,7 +9,6 @@ from zerver.lib.types import Validator
from zerver.lib.validator import (
check_bool,
check_capped_string,
check_color,
check_dict,
check_dict_only,
check_float,
@@ -27,7 +25,6 @@ from zerver.lib.validator import (
check_union,
check_url,
equals,
to_non_negative_int,
to_wild_value,
)
@@ -118,17 +115,6 @@ class ValidatorTestCase(ZulipTestCase):
with self.assertRaisesRegex(ValidationError, r"x is not an integer"):
check_int("x", x)
def test_to_non_negative_int(self) -> None:
self.assertEqual(to_non_negative_int("x", "5"), 5)
with self.assertRaisesRegex(ValueError, "argument is negative"):
to_non_negative_int("x", "-1")
with self.assertRaisesRegex(ValueError, re.escape("5 is too large (max 4)")):
to_non_negative_int("x", "5", max_int_size=4)
with self.assertRaisesRegex(
ValueError, re.escape(f"{2**32} is too large (max {2**32 - 1})")
):
to_non_negative_int("x", str(2**32))
def test_check_float(self) -> None:
x: Any = 5.5
check_float("x", x)
@@ -141,21 +127,6 @@ class ValidatorTestCase(ZulipTestCase):
with self.assertRaisesRegex(ValidationError, r"x is not a float"):
check_float("x", x)
def test_check_color(self) -> None:
x = ["#000099", "#80ffaa", "#80FFAA", "#abcd12", "#ffff00", "#ff0", "#f00"] # valid
y = ["000099", "#80f_aa", "#80fraa", "#abcd1234", "blue"] # invalid
z = 5 # invalid
for hex_color in x:
check_color("color", hex_color)
for hex_color in y:
with self.assertRaisesRegex(ValidationError, r"color is not a valid hex color code"):
check_color("color", hex_color)
with self.assertRaisesRegex(ValidationError, r"color is not a string"):
check_color("color", z)
def test_check_list(self) -> None:
x: Any = 999
with self.assertRaisesRegex(ValidationError, r"x is not a list"):