decorator: Strengthen types of signature-preserving decorators.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg
2020-06-23 16:52:37 -07:00
parent a7bac82f2e
commit e582bbea4a
4 changed files with 37 additions and 33 deletions

View File

@@ -4,7 +4,7 @@ import logging
import urllib import urllib
from functools import wraps from functools import wraps
from io import BytesIO from io import BytesIO
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union, cast
import django_otp import django_otp
import ujson import ujson
@@ -61,8 +61,6 @@ else: # nocoverage # Hack here basically to make impossible code paths compile
get_remote_server_by_uuid = Mock() get_remote_server_by_uuid = Mock()
RemoteZulipServer = Mock() # type: ignore[misc] # https://github.com/JukkaL/mypy/issues/1188 RemoteZulipServer = Mock() # type: ignore[misc] # https://github.com/JukkaL/mypy/issues/1188
ReturnT = TypeVar('ReturnT')
webhook_logger = logging.getLogger("zulip.zerver.webhooks") webhook_logger = logging.getLogger("zulip.zerver.webhooks")
log_to_file(webhook_logger, settings.API_KEY_ONLY_WEBHOOK_LOG_PATH) log_to_file(webhook_logger, settings.API_KEY_ONLY_WEBHOOK_LOG_PATH)
@@ -70,17 +68,19 @@ webhook_unexpected_events_logger = logging.getLogger("zulip.zerver.lib.webhooks.
log_to_file(webhook_unexpected_events_logger, log_to_file(webhook_unexpected_events_logger,
settings.WEBHOOK_UNEXPECTED_EVENTS_LOG_PATH) settings.WEBHOOK_UNEXPECTED_EVENTS_LOG_PATH)
def cachify(method: Callable[..., ReturnT]) -> Callable[..., ReturnT]: FuncT = TypeVar('FuncT', bound=Callable[..., object])
dct: Dict[Tuple[Any, ...], ReturnT] = {}
def cache_wrapper(*args: Any) -> ReturnT: def cachify(method: FuncT) -> FuncT:
dct: Dict[Tuple[object, ...], object] = {}
def cache_wrapper(*args: object) -> object:
tup = tuple(args) tup = tuple(args)
if tup in dct: if tup in dct:
return dct[tup] return dct[tup]
result = method(*args) result = method(*args)
dct[tup] = result dct[tup] = result
return result return result
return cache_wrapper return cast(FuncT, cache_wrapper) # https://github.com/python/mypy/issues/1927
def update_user_activity(request: HttpRequest, user_profile: UserProfile, def update_user_activity(request: HttpRequest, user_profile: UserProfile,
query: Optional[str]) -> None: query: Optional[str]) -> None:
@@ -732,19 +732,18 @@ def internal_notify_view(is_tornado_view: bool) -> Callable[[ViewFuncT], ViewFun
def to_utc_datetime(timestamp: str) -> datetime.datetime: def to_utc_datetime(timestamp: str) -> datetime.datetime:
return timestamp_to_datetime(float(timestamp)) return timestamp_to_datetime(float(timestamp))
def statsd_increment(counter: str, val: int=1, def statsd_increment(counter: str, val: int=1) -> Callable[[FuncT], FuncT]:
) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]:
"""Increments a statsd counter on completion of the """Increments a statsd counter on completion of the
decorated function. decorated function.
Pass the name of the counter to this decorator-returning function.""" Pass the name of the counter to this decorator-returning function."""
def wrapper(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]: def wrapper(func: FuncT) -> FuncT:
@wraps(func) @wraps(func)
def wrapped_func(*args: Any, **kwargs: Any) -> ReturnT: def wrapped_func(*args: object, **kwargs: object) -> object:
ret = func(*args, **kwargs) ret = func(*args, **kwargs)
statsd.incr(counter, val) statsd.incr(counter, val)
return ret return ret
return wrapped_func return cast(FuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
return wrapper return wrapper
def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> None: def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> None:

View File

@@ -20,6 +20,7 @@ from typing import (
Sequence, Sequence,
Tuple, Tuple,
TypeVar, TypeVar,
cast,
) )
from django.conf import settings from django.conf import settings
@@ -39,7 +40,7 @@ if TYPE_CHECKING:
MEMCACHED_MAX_KEY_LENGTH = 250 MEMCACHED_MAX_KEY_LENGTH = 250
ReturnT = TypeVar('ReturnT') # Useful for matching return types via Callable[..., ReturnT] FuncT = TypeVar('FuncT', bound=Callable[..., object])
logger = logging.getLogger() logger = logging.getLogger()
@@ -127,15 +128,15 @@ def get_cache_backend(cache_name: Optional[str]) -> BaseCache:
def get_cache_with_key( def get_cache_with_key(
keyfunc: Callable[..., str], keyfunc: Callable[..., str],
cache_name: Optional[str]=None, cache_name: Optional[str]=None,
) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]: ) -> Callable[[FuncT], FuncT]:
""" """
The main goal of this function getting value from the cache like in the "cache_with_key". The main goal of this function getting value from the cache like in the "cache_with_key".
A cache value can contain any data including the "None", so A cache value can contain any data including the "None", so
here used exception for case if value isn't found in the cache. here used exception for case if value isn't found in the cache.
""" """
def decorator(func: Callable[..., ReturnT]) -> (Callable[..., ReturnT]): def decorator(func: FuncT) -> FuncT:
@wraps(func) @wraps(func)
def func_with_caching(*args: Any, **kwargs: Any) -> Callable[..., ReturnT]: def func_with_caching(*args: object, **kwargs: object) -> object:
key = keyfunc(*args, **kwargs) key = keyfunc(*args, **kwargs)
try: try:
val = cache_get(key, cache_name=cache_name) val = cache_get(key, cache_name=cache_name)
@@ -148,14 +149,14 @@ def get_cache_with_key(
return val[0] return val[0]
raise NotFoundInCache() raise NotFoundInCache()
return func_with_caching return cast(FuncT, func_with_caching) # https://github.com/python/mypy/issues/1927
return decorator return decorator
def cache_with_key( def cache_with_key(
keyfunc: Callable[..., str], cache_name: Optional[str]=None, keyfunc: Callable[..., str], cache_name: Optional[str]=None,
timeout: Optional[int]=None, with_statsd_key: Optional[str]=None, timeout: Optional[int]=None, with_statsd_key: Optional[str]=None,
) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]: ) -> Callable[[FuncT], FuncT]:
"""Decorator which applies Django caching to a function. """Decorator which applies Django caching to a function.
Decorator argument is a function which computes a cache key Decorator argument is a function which computes a cache key
@@ -163,9 +164,9 @@ def cache_with_key(
for avoiding collisions with other uses of this decorator or for avoiding collisions with other uses of this decorator or
other uses of caching.""" other uses of caching."""
def decorator(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]: def decorator(func: FuncT) -> FuncT:
@wraps(func) @wraps(func)
def func_with_caching(*args: Any, **kwargs: Any) -> ReturnT: def func_with_caching(*args: object, **kwargs: object) -> object:
key = keyfunc(*args, **kwargs) key = keyfunc(*args, **kwargs)
try: try:
@@ -198,7 +199,7 @@ def cache_with_key(
return val return val
return func_with_caching return cast(FuncT, func_with_caching) # https://github.com/python/mypy/issues/1927
return decorator return decorator

View File

@@ -1,10 +1,10 @@
import cProfile import cProfile
from functools import wraps from functools import wraps
from typing import Any, Callable, TypeVar from typing import Callable, TypeVar, cast
ReturnT = TypeVar('ReturnT') FuncT = TypeVar('FuncT', bound=Callable[..., object])
def profiled(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]: def profiled(func: FuncT) -> FuncT:
""" """
This decorator should obviously be used only in a dev environment. This decorator should obviously be used only in a dev environment.
It works best when surrounding a function that you expect to be It works best when surrounding a function that you expect to be
@@ -21,11 +21,13 @@ def profiled(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
./tools/show-profile-results test_ratelimit_decrease.profile ./tools/show-profile-results test_ratelimit_decrease.profile
""" """
func_: Callable[..., object] = func # work around https://github.com/python/mypy/issues/9075
@wraps(func) @wraps(func)
def wrapped_func(*args: Any, **kwargs: Any) -> ReturnT: def wrapped_func(*args: object, **kwargs: object) -> object:
fn = func.__name__ + ".profile" fn = func.__name__ + ".profile"
prof = cProfile.Profile() prof = cProfile.Profile()
retval: ReturnT = prof.runcall(func, *args, **kwargs) retval = prof.runcall(func_, *args, **kwargs)
prof.dump_stats(fn) prof.dump_stats(fn)
return retval return retval
return wrapped_func return cast(FuncT, wrapped_func) # https://github.com/python/mypy/issues/1927

View File

@@ -2,7 +2,7 @@ import json
import os import os
import sys import sys
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Optional, Set from typing import Any, Callable, Dict, Iterable, List, Optional, Set, TypeVar, cast
from zulip import Client from zulip import Client
@@ -12,27 +12,29 @@ from zerver.openapi.openapi import validate_against_openapi_schema
ZULIP_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ZULIP_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
TEST_FUNCTIONS: Dict[str, Callable[..., None]] = dict() TEST_FUNCTIONS: Dict[str, Callable[..., object]] = dict()
REGISTERED_TEST_FUNCTIONS: Set[str] = set() REGISTERED_TEST_FUNCTIONS: Set[str] = set()
CALLED_TEST_FUNCTIONS: Set[str] = set() CALLED_TEST_FUNCTIONS: Set[str] = set()
def openapi_test_function(endpoint: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]: FuncT = TypeVar("FuncT", bound=Callable[..., object])
def openapi_test_function(endpoint: str) -> Callable[[FuncT], FuncT]:
"""This decorator is used to register an openapi test function with """This decorator is used to register an openapi test function with
its endpoint. Example usage: its endpoint. Example usage:
@openapi_test_function("/messages/render:post") @openapi_test_function("/messages/render:post")
def ... def ...
""" """
def wrapper(test_func: Callable[..., Any]) -> Callable[..., Any]: def wrapper(test_func: FuncT) -> FuncT:
@wraps(test_func) @wraps(test_func)
def _record_calls_wrapper(*args: Any, **kwargs: Any) -> Any: def _record_calls_wrapper(*args: object, **kwargs: object) -> object:
CALLED_TEST_FUNCTIONS.add(test_func.__name__) CALLED_TEST_FUNCTIONS.add(test_func.__name__)
return test_func(*args, **kwargs) return test_func(*args, **kwargs)
REGISTERED_TEST_FUNCTIONS.add(test_func.__name__) REGISTERED_TEST_FUNCTIONS.add(test_func.__name__)
TEST_FUNCTIONS[endpoint] = _record_calls_wrapper TEST_FUNCTIONS[endpoint] = _record_calls_wrapper
return _record_calls_wrapper return cast(FuncT, _record_calls_wrapper) # https://github.com/python/mypy/issues/1927
return wrapper return wrapper
def ensure_users(ids_list: List[int], user_names: List[str]) -> None: def ensure_users(ids_list: List[int], user_names: List[str]) -> None: