diff --git a/zerver/lib/drafts.py b/zerver/lib/drafts.py index 11667abfc1..ca9ee11663 100644 --- a/zerver/lib/drafts.py +++ b/zerver/lib/drafts.py @@ -1,10 +1,11 @@ import time from functools import wraps -from typing import Any, Dict, List, Set, cast +from typing import Any, Callable, Dict, List, Set from django.core.exceptions import ValidationError from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _ +from typing_extensions import Concatenate, ParamSpec from zerver.lib.addressee import get_user_profiles_by_ids from zerver.lib.exceptions import JsonableError, ResourceNotFoundError @@ -12,7 +13,6 @@ from zerver.lib.message import normalize_body, truncate_topic from zerver.lib.recipient_users import recipient_for_user_profiles from zerver.lib.streams import access_stream_by_id from zerver.lib.timestamp import timestamp_to_datetime -from zerver.lib.types import ViewFuncT from zerver.lib.validator import ( check_dict_only, check_float, @@ -26,6 +26,7 @@ from zerver.lib.validator import ( from zerver.models import Draft, UserProfile from zerver.tornado.django_api import send_event +ParamT = ParamSpec("ParamT") VALID_DRAFT_TYPES: Set[str] = {"", "private", "stream"} # A validator to verify if the structure (syntax) of a dictionary @@ -87,16 +88,22 @@ def further_validated_draft_dict( } -def draft_endpoint(view_func: ViewFuncT) -> ViewFuncT: +def draft_endpoint( + view_func: Callable[Concatenate[HttpRequest, UserProfile, ParamT], HttpResponse] +) -> Callable[Concatenate[HttpRequest, UserProfile, ParamT], HttpResponse]: @wraps(view_func) def draft_view_func( - request: HttpRequest, user_profile: UserProfile, *args: object, **kwargs: object + request: HttpRequest, + user_profile: UserProfile, + /, + *args: ParamT.args, + **kwargs: ParamT.kwargs, ) -> HttpResponse: if not user_profile.enable_drafts_synchronization: raise JsonableError(_("User has disabled synchronizing drafts.")) return view_func(request, user_profile, *args, **kwargs) - return cast(ViewFuncT, draft_view_func) # https://github.com/python/mypy/issues/1927 + return draft_view_func def do_create_drafts(draft_dicts: List[Dict[str, Any]], user_profile: UserProfile) -> List[Draft]: diff --git a/zerver/lib/rest.py b/zerver/lib/rest.py index e99272584f..3545ce21ec 100644 --- a/zerver/lib/rest.py +++ b/zerver/lib/rest.py @@ -1,11 +1,12 @@ from functools import wraps -from typing import Callable, Dict, Set, Tuple, Union, cast +from typing import Callable, Dict, Set, Tuple, Union from django.http import HttpRequest, HttpResponse from django.urls import path from django.urls.resolvers import URLPattern from django.utils.cache import add_never_cache_headers from django.views.decorators.csrf import csrf_exempt, csrf_protect +from typing_extensions import Concatenate, ParamSpec from zerver.decorator import ( authenticated_json_view, @@ -17,12 +18,14 @@ from zerver.decorator import ( from zerver.lib.exceptions import MissingAuthenticationError from zerver.lib.request import RequestNotes from zerver.lib.response import json_method_not_allowed -from zerver.lib.types import ViewFuncT +ParamT = ParamSpec("ParamT") METHODS = ("GET", "HEAD", "POST", "PUT", "DELETE", "PATCH") -def default_never_cache_responses(view_func: ViewFuncT) -> ViewFuncT: +def default_never_cache_responses( + view_func: Callable[Concatenate[HttpRequest, ParamT], HttpResponse] +) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]: """Patched version of the standard Django never_cache_responses decorator that adds headers to a response so that it will never be cached, unless the view code has already set a Cache-Control @@ -30,7 +33,9 @@ def default_never_cache_responses(view_func: ViewFuncT) -> ViewFuncT: """ @wraps(view_func) - def _wrapped_view_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse: + def _wrapped_view_func( + request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs + ) -> HttpResponse: response = view_func(request, *args, **kwargs) if response.has_header("Cache-Control"): return response @@ -38,7 +43,7 @@ def default_never_cache_responses(view_func: ViewFuncT) -> ViewFuncT: add_never_cache_headers(response) return response - return cast(ViewFuncT, _wrapped_view_func) # https://github.com/python/mypy/issues/1927 + return _wrapped_view_func def get_target_view_function_or_response( @@ -102,7 +107,7 @@ def get_target_view_function_or_response( @default_never_cache_responses @csrf_exempt -def rest_dispatch(request: HttpRequest, **kwargs: object) -> HttpResponse: +def rest_dispatch(request: HttpRequest, /, **kwargs: object) -> HttpResponse: """Dispatch to a REST API endpoint. Authentication is verified in the following ways: diff --git a/zerver/lib/types.py b/zerver/lib/types.py index 890bdc3d95..6df7b614d1 100644 --- a/zerver/lib/types.py +++ b/zerver/lib/types.py @@ -2,12 +2,9 @@ import datetime from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union -from django.http import HttpResponse from django.utils.functional import Promise from typing_extensions import NotRequired -ViewFuncT = TypeVar("ViewFuncT", bound=Callable[..., HttpResponse]) - # See zerver/lib/validator.py for more details of Validators, # including many examples ResultT = TypeVar("ResultT") diff --git a/zerver/middleware.py b/zerver/middleware.py index 14d9ad5d78..9bf43b5e86 100644 --- a/zerver/middleware.py +++ b/zerver/middleware.py @@ -22,6 +22,7 @@ from django_scim.middleware import SCIMAuthCheckMiddleware from django_scim.settings import scim_settings from sentry_sdk import capture_exception from sentry_sdk.integrations.logging import ignore_logger +from typing_extensions import Concatenate, ParamSpec from zerver.lib.cache import get_remote_cache_requests, get_remote_cache_time from zerver.lib.db import reset_queries @@ -38,11 +39,11 @@ from zerver.lib.response import ( json_unauthorized, ) from zerver.lib.subdomains import get_subdomain -from zerver.lib.types import ViewFuncT from zerver.lib.user_agent import parse_user_agent from zerver.lib.utils import statsd from zerver.models import Realm, SCIMClient, flush_per_request_caches, get_realm +ParamT = ParamSpec("ParamT") logger = logging.getLogger("zulip.requests") slow_query_logger = logging.getLogger("zulip.slow_queries") @@ -368,8 +369,8 @@ class LogRequests(MiddlewareMixin): def process_view( self, request: HttpRequest, - view_func: ViewFuncT, - args: List[str], + view_func: Callable[Concatenate[HttpRequest, ParamT], HttpResponseBase], + args: List[object], kwargs: Dict[str, Any], ) -> None: request_notes = RequestNotes.get_notes(request) @@ -475,7 +476,11 @@ class JsonErrorHandler(MiddlewareMixin): class TagRequests(MiddlewareMixin): def process_view( - self, request: HttpRequest, view_func: ViewFuncT, args: List[str], kwargs: Dict[str, Any] + self, + request: HttpRequest, + view_func: Callable[Concatenate[HttpRequest, ParamT], HttpResponseBase], + args: List[object], + kwargs: Dict[str, Any], ) -> None: self.process_request(request) diff --git a/zerver/tests/test_logging_handlers.py b/zerver/tests/test_logging_handlers.py index ab1c20084f..4a9eb6fc75 100644 --- a/zerver/tests/test_logging_handlers.py +++ b/zerver/tests/test_logging_handlers.py @@ -2,29 +2,38 @@ import logging import sys from functools import wraps from types import TracebackType -from typing import Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast +from typing import Callable, Dict, Iterator, NoReturn, Optional, Tuple, Type, Union from unittest import mock from unittest.mock import MagicMock, patch from django.conf import settings from django.contrib.auth.models import AnonymousUser -from django.http import HttpRequest +from django.http import HttpRequest, HttpResponse from django.utils.log import AdminEmailHandler +from typing_extensions import Concatenate, ParamSpec from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_helpers import mock_queue_publish -from zerver.lib.types import ViewFuncT from zerver.logging_handlers import AdminNotifyHandler, HasRequest +from zerver.models import UserProfile +ParamT = ParamSpec("ParamT") captured_request: Optional[HttpRequest] = None captured_exc_info: Optional[ Union[Tuple[Type[BaseException], BaseException, TracebackType], Tuple[None, None, None]] ] = None -def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT: +def capture_and_throw( + view_func: Callable[Concatenate[HttpRequest, UserProfile, ParamT], HttpResponse] +) -> Callable[Concatenate[HttpRequest, ParamT], NoReturn]: @wraps(view_func) - def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn: + def wrapped_view( + request: HttpRequest, + /, + *args: ParamT.args, + **kwargs: ParamT.kwargs, + ) -> NoReturn: global captured_request captured_request = request try: @@ -34,7 +43,7 @@ def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT: captured_exc_info = sys.exc_info() raise e - return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927 + return wrapped_view class AdminNotifyHandlerTest(ZulipTestCase): diff --git a/zerver/views/auth.py b/zerver/views/auth.py index 5ef6728ae2..80db7aecc0 100644 --- a/zerver/views/auth.py +++ b/zerver/views/auth.py @@ -3,7 +3,7 @@ import secrets import urllib from email.headerregistry import Address from functools import wraps -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, cast from urllib.parse import urlencode import jwt @@ -29,6 +29,7 @@ from markupsafe import Markup as mark_safe from social_django.utils import load_backend, load_strategy from two_factor.forms import BackupTokenForm from two_factor.views import LoginView as BaseTwoFactorLoginView +from typing_extensions import Concatenate, ParamSpec from confirmation.models import ( Confirmation, @@ -65,7 +66,6 @@ from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.response import json_success from zerver.lib.sessions import set_expirable_session_var from zerver.lib.subdomains import get_subdomain, is_subdomain_root_or_alias -from zerver.lib.types import ViewFuncT from zerver.lib.url_encoding import append_url_query_string from zerver.lib.user_agent import parse_user_agent from zerver.lib.users import get_api_key, is_2fa_verified @@ -102,6 +102,7 @@ from zproject.backends import ( if TYPE_CHECKING: from django.http.request import _ImmutableQueryDict +ParamT = ParamSpec("ParamT") ExtraContext = Optional[Dict[str, Any]] @@ -554,16 +555,20 @@ def oauth_redirect_to_root( return redirect(append_url_query_string(main_site_uri, urllib.parse.urlencode(params))) -def handle_desktop_flow(func: ViewFuncT) -> ViewFuncT: +def handle_desktop_flow( + func: Callable[Concatenate[HttpRequest, ParamT], HttpResponse] +) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]: @wraps(func) - def wrapper(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse: + def wrapper( + request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs + ) -> HttpResponse: user_agent = parse_user_agent(request.headers.get("User-Agent", "Missing User-Agent")) if user_agent["name"] == "ZulipElectron": return render(request, "zerver/desktop_login.html") return func(request, *args, **kwargs) - return cast(ViewFuncT, wrapper) # https://github.com/python/mypy/issues/1927 + return wrapper @handle_desktop_flow