typing: Remove ViewFuncT.

This removes ViewFuncT and all the associated type casts with ParamSpec
and Concatenate. This provides more accurate type annotation for
decorators at the cost of making the concatenated parameters
positional-only. This change does not intend to introduce any other
behavioral difference. Note that we retype args in process_view as
List[object] because the view functions can not only be called with
arguments of type str.

Note that the first argument of rest_dispatch needs to be made
positional-only because of the presence of **kwargs.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li
2022-08-20 22:03:39 -04:00
committed by Tim Abbott
parent 1f286ab283
commit 21fd62427d
6 changed files with 57 additions and 29 deletions

View File

@@ -1,10 +1,11 @@
import time import time
from functools import wraps from functools import wraps
from typing import Any, Dict, List, Set, cast from typing import Any, Callable, Dict, List, Set
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from typing_extensions import Concatenate, ParamSpec
from zerver.lib.addressee import get_user_profiles_by_ids from zerver.lib.addressee import get_user_profiles_by_ids
from zerver.lib.exceptions import JsonableError, ResourceNotFoundError from zerver.lib.exceptions import JsonableError, ResourceNotFoundError
@@ -12,7 +13,6 @@ from zerver.lib.message import normalize_body, truncate_topic
from zerver.lib.recipient_users import recipient_for_user_profiles from zerver.lib.recipient_users import recipient_for_user_profiles
from zerver.lib.streams import access_stream_by_id from zerver.lib.streams import access_stream_by_id
from zerver.lib.timestamp import timestamp_to_datetime from zerver.lib.timestamp import timestamp_to_datetime
from zerver.lib.types import ViewFuncT
from zerver.lib.validator import ( from zerver.lib.validator import (
check_dict_only, check_dict_only,
check_float, check_float,
@@ -26,6 +26,7 @@ from zerver.lib.validator import (
from zerver.models import Draft, UserProfile from zerver.models import Draft, UserProfile
from zerver.tornado.django_api import send_event from zerver.tornado.django_api import send_event
ParamT = ParamSpec("ParamT")
VALID_DRAFT_TYPES: Set[str] = {"", "private", "stream"} VALID_DRAFT_TYPES: Set[str] = {"", "private", "stream"}
# A validator to verify if the structure (syntax) of a dictionary # A validator to verify if the structure (syntax) of a dictionary
@@ -87,16 +88,22 @@ def further_validated_draft_dict(
} }
def draft_endpoint(view_func: ViewFuncT) -> ViewFuncT: def draft_endpoint(
view_func: Callable[Concatenate[HttpRequest, UserProfile, ParamT], HttpResponse]
) -> Callable[Concatenate[HttpRequest, UserProfile, ParamT], HttpResponse]:
@wraps(view_func) @wraps(view_func)
def draft_view_func( def draft_view_func(
request: HttpRequest, user_profile: UserProfile, *args: object, **kwargs: object request: HttpRequest,
user_profile: UserProfile,
/,
*args: ParamT.args,
**kwargs: ParamT.kwargs,
) -> HttpResponse: ) -> HttpResponse:
if not user_profile.enable_drafts_synchronization: if not user_profile.enable_drafts_synchronization:
raise JsonableError(_("User has disabled synchronizing drafts.")) raise JsonableError(_("User has disabled synchronizing drafts."))
return view_func(request, user_profile, *args, **kwargs) return view_func(request, user_profile, *args, **kwargs)
return cast(ViewFuncT, draft_view_func) # https://github.com/python/mypy/issues/1927 return draft_view_func
def do_create_drafts(draft_dicts: List[Dict[str, Any]], user_profile: UserProfile) -> List[Draft]: def do_create_drafts(draft_dicts: List[Dict[str, Any]], user_profile: UserProfile) -> List[Draft]:

View File

@@ -1,11 +1,12 @@
from functools import wraps from functools import wraps
from typing import Callable, Dict, Set, Tuple, Union, cast from typing import Callable, Dict, Set, Tuple, Union
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.urls import path from django.urls import path
from django.urls.resolvers import URLPattern from django.urls.resolvers import URLPattern
from django.utils.cache import add_never_cache_headers from django.utils.cache import add_never_cache_headers
from django.views.decorators.csrf import csrf_exempt, csrf_protect from django.views.decorators.csrf import csrf_exempt, csrf_protect
from typing_extensions import Concatenate, ParamSpec
from zerver.decorator import ( from zerver.decorator import (
authenticated_json_view, authenticated_json_view,
@@ -17,12 +18,14 @@ from zerver.decorator import (
from zerver.lib.exceptions import MissingAuthenticationError from zerver.lib.exceptions import MissingAuthenticationError
from zerver.lib.request import RequestNotes from zerver.lib.request import RequestNotes
from zerver.lib.response import json_method_not_allowed from zerver.lib.response import json_method_not_allowed
from zerver.lib.types import ViewFuncT
ParamT = ParamSpec("ParamT")
METHODS = ("GET", "HEAD", "POST", "PUT", "DELETE", "PATCH") METHODS = ("GET", "HEAD", "POST", "PUT", "DELETE", "PATCH")
def default_never_cache_responses(view_func: ViewFuncT) -> ViewFuncT: def default_never_cache_responses(
view_func: Callable[Concatenate[HttpRequest, ParamT], HttpResponse]
) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]:
"""Patched version of the standard Django never_cache_responses """Patched version of the standard Django never_cache_responses
decorator that adds headers to a response so that it will never be decorator that adds headers to a response so that it will never be
cached, unless the view code has already set a Cache-Control cached, unless the view code has already set a Cache-Control
@@ -30,7 +33,9 @@ def default_never_cache_responses(view_func: ViewFuncT) -> ViewFuncT:
""" """
@wraps(view_func) @wraps(view_func)
def _wrapped_view_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse: def _wrapped_view_func(
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
) -> HttpResponse:
response = view_func(request, *args, **kwargs) response = view_func(request, *args, **kwargs)
if response.has_header("Cache-Control"): if response.has_header("Cache-Control"):
return response return response
@@ -38,7 +43,7 @@ def default_never_cache_responses(view_func: ViewFuncT) -> ViewFuncT:
add_never_cache_headers(response) add_never_cache_headers(response)
return response return response
return cast(ViewFuncT, _wrapped_view_func) # https://github.com/python/mypy/issues/1927 return _wrapped_view_func
def get_target_view_function_or_response( def get_target_view_function_or_response(
@@ -102,7 +107,7 @@ def get_target_view_function_or_response(
@default_never_cache_responses @default_never_cache_responses
@csrf_exempt @csrf_exempt
def rest_dispatch(request: HttpRequest, **kwargs: object) -> HttpResponse: def rest_dispatch(request: HttpRequest, /, **kwargs: object) -> HttpResponse:
"""Dispatch to a REST API endpoint. """Dispatch to a REST API endpoint.
Authentication is verified in the following ways: Authentication is verified in the following ways:

View File

@@ -2,12 +2,9 @@ import datetime
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
from django.http import HttpResponse
from django.utils.functional import Promise from django.utils.functional import Promise
from typing_extensions import NotRequired from typing_extensions import NotRequired
ViewFuncT = TypeVar("ViewFuncT", bound=Callable[..., HttpResponse])
# See zerver/lib/validator.py for more details of Validators, # See zerver/lib/validator.py for more details of Validators,
# including many examples # including many examples
ResultT = TypeVar("ResultT") ResultT = TypeVar("ResultT")

View File

@@ -22,6 +22,7 @@ from django_scim.middleware import SCIMAuthCheckMiddleware
from django_scim.settings import scim_settings from django_scim.settings import scim_settings
from sentry_sdk import capture_exception from sentry_sdk import capture_exception
from sentry_sdk.integrations.logging import ignore_logger from sentry_sdk.integrations.logging import ignore_logger
from typing_extensions import Concatenate, ParamSpec
from zerver.lib.cache import get_remote_cache_requests, get_remote_cache_time from zerver.lib.cache import get_remote_cache_requests, get_remote_cache_time
from zerver.lib.db import reset_queries from zerver.lib.db import reset_queries
@@ -38,11 +39,11 @@ from zerver.lib.response import (
json_unauthorized, json_unauthorized,
) )
from zerver.lib.subdomains import get_subdomain from zerver.lib.subdomains import get_subdomain
from zerver.lib.types import ViewFuncT
from zerver.lib.user_agent import parse_user_agent from zerver.lib.user_agent import parse_user_agent
from zerver.lib.utils import statsd from zerver.lib.utils import statsd
from zerver.models import Realm, SCIMClient, flush_per_request_caches, get_realm from zerver.models import Realm, SCIMClient, flush_per_request_caches, get_realm
ParamT = ParamSpec("ParamT")
logger = logging.getLogger("zulip.requests") logger = logging.getLogger("zulip.requests")
slow_query_logger = logging.getLogger("zulip.slow_queries") slow_query_logger = logging.getLogger("zulip.slow_queries")
@@ -368,8 +369,8 @@ class LogRequests(MiddlewareMixin):
def process_view( def process_view(
self, self,
request: HttpRequest, request: HttpRequest,
view_func: ViewFuncT, view_func: Callable[Concatenate[HttpRequest, ParamT], HttpResponseBase],
args: List[str], args: List[object],
kwargs: Dict[str, Any], kwargs: Dict[str, Any],
) -> None: ) -> None:
request_notes = RequestNotes.get_notes(request) request_notes = RequestNotes.get_notes(request)
@@ -475,7 +476,11 @@ class JsonErrorHandler(MiddlewareMixin):
class TagRequests(MiddlewareMixin): class TagRequests(MiddlewareMixin):
def process_view( def process_view(
self, request: HttpRequest, view_func: ViewFuncT, args: List[str], kwargs: Dict[str, Any] self,
request: HttpRequest,
view_func: Callable[Concatenate[HttpRequest, ParamT], HttpResponseBase],
args: List[object],
kwargs: Dict[str, Any],
) -> None: ) -> None:
self.process_request(request) self.process_request(request)

View File

@@ -2,29 +2,38 @@ import logging
import sys import sys
from functools import wraps from functools import wraps
from types import TracebackType from types import TracebackType
from typing import Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast from typing import Callable, Dict, Iterator, NoReturn, Optional, Tuple, Type, Union
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import AnonymousUser from django.contrib.auth.models import AnonymousUser
from django.http import HttpRequest from django.http import HttpRequest, HttpResponse
from django.utils.log import AdminEmailHandler from django.utils.log import AdminEmailHandler
from typing_extensions import Concatenate, ParamSpec
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import mock_queue_publish from zerver.lib.test_helpers import mock_queue_publish
from zerver.lib.types import ViewFuncT
from zerver.logging_handlers import AdminNotifyHandler, HasRequest from zerver.logging_handlers import AdminNotifyHandler, HasRequest
from zerver.models import UserProfile
ParamT = ParamSpec("ParamT")
captured_request: Optional[HttpRequest] = None captured_request: Optional[HttpRequest] = None
captured_exc_info: Optional[ captured_exc_info: Optional[
Union[Tuple[Type[BaseException], BaseException, TracebackType], Tuple[None, None, None]] Union[Tuple[Type[BaseException], BaseException, TracebackType], Tuple[None, None, None]]
] = None ] = None
def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT: def capture_and_throw(
view_func: Callable[Concatenate[HttpRequest, UserProfile, ParamT], HttpResponse]
) -> Callable[Concatenate[HttpRequest, ParamT], NoReturn]:
@wraps(view_func) @wraps(view_func)
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn: def wrapped_view(
request: HttpRequest,
/,
*args: ParamT.args,
**kwargs: ParamT.kwargs,
) -> NoReturn:
global captured_request global captured_request
captured_request = request captured_request = request
try: try:
@@ -34,7 +43,7 @@ def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT:
captured_exc_info = sys.exc_info() captured_exc_info = sys.exc_info()
raise e raise e
return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927 return wrapped_view
class AdminNotifyHandlerTest(ZulipTestCase): class AdminNotifyHandlerTest(ZulipTestCase):

View File

@@ -3,7 +3,7 @@ import secrets
import urllib import urllib
from email.headerregistry import Address from email.headerregistry import Address
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, cast from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, cast
from urllib.parse import urlencode from urllib.parse import urlencode
import jwt import jwt
@@ -29,6 +29,7 @@ from markupsafe import Markup as mark_safe
from social_django.utils import load_backend, load_strategy from social_django.utils import load_backend, load_strategy
from two_factor.forms import BackupTokenForm from two_factor.forms import BackupTokenForm
from two_factor.views import LoginView as BaseTwoFactorLoginView from two_factor.views import LoginView as BaseTwoFactorLoginView
from typing_extensions import Concatenate, ParamSpec
from confirmation.models import ( from confirmation.models import (
Confirmation, Confirmation,
@@ -65,7 +66,6 @@ from zerver.lib.request import REQ, RequestNotes, has_request_variables
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.sessions import set_expirable_session_var from zerver.lib.sessions import set_expirable_session_var
from zerver.lib.subdomains import get_subdomain, is_subdomain_root_or_alias from zerver.lib.subdomains import get_subdomain, is_subdomain_root_or_alias
from zerver.lib.types import ViewFuncT
from zerver.lib.url_encoding import append_url_query_string from zerver.lib.url_encoding import append_url_query_string
from zerver.lib.user_agent import parse_user_agent from zerver.lib.user_agent import parse_user_agent
from zerver.lib.users import get_api_key, is_2fa_verified from zerver.lib.users import get_api_key, is_2fa_verified
@@ -102,6 +102,7 @@ from zproject.backends import (
if TYPE_CHECKING: if TYPE_CHECKING:
from django.http.request import _ImmutableQueryDict from django.http.request import _ImmutableQueryDict
ParamT = ParamSpec("ParamT")
ExtraContext = Optional[Dict[str, Any]] ExtraContext = Optional[Dict[str, Any]]
@@ -554,16 +555,20 @@ def oauth_redirect_to_root(
return redirect(append_url_query_string(main_site_uri, urllib.parse.urlencode(params))) return redirect(append_url_query_string(main_site_uri, urllib.parse.urlencode(params)))
def handle_desktop_flow(func: ViewFuncT) -> ViewFuncT: def handle_desktop_flow(
func: Callable[Concatenate[HttpRequest, ParamT], HttpResponse]
) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]:
@wraps(func) @wraps(func)
def wrapper(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse: def wrapper(
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
) -> HttpResponse:
user_agent = parse_user_agent(request.headers.get("User-Agent", "Missing User-Agent")) user_agent = parse_user_agent(request.headers.get("User-Agent", "Missing User-Agent"))
if user_agent["name"] == "ZulipElectron": if user_agent["name"] == "ZulipElectron":
return render(request, "zerver/desktop_login.html") return render(request, "zerver/desktop_login.html")
return func(request, *args, **kwargs) return func(request, *args, **kwargs)
return cast(ViewFuncT, wrapper) # https://github.com/python/mypy/issues/1927 return wrapper
@handle_desktop_flow @handle_desktop_flow