zerver: Use Python 3 syntax for typing.

Tweaked by tabbott to fix some minor whitespace errors.
This commit is contained in:
rht
2017-11-27 06:33:05 +00:00
committed by Tim Abbott
parent 0ec2a9d259
commit a1cc720860
10 changed files with 267 additions and 496 deletions

View File

@@ -7,8 +7,7 @@ from django.conf import settings
from django.core.cache import cache from django.core.cache import cache
from django.db.models.signals import post_migrate from django.db.models.signals import post_migrate
def flush_cache(sender, **kwargs): def flush_cache(sender: AppConfig, **kwargs: Any) -> None:
# type: (AppConfig, **Any) -> None
logging.info("Clearing memcached cache after migrations") logging.info("Clearing memcached cache after migrations")
cache.clear() cache.clear()
@@ -16,8 +15,7 @@ def flush_cache(sender, **kwargs):
class ZerverConfig(AppConfig): class ZerverConfig(AppConfig):
name = "zerver" # type: str name = "zerver" # type: str
def ready(self): def ready(self) -> None:
# type: () -> None
import zerver.signals import zerver.signals
if settings.POST_MIGRATION_CACHE_FLUSHING: if settings.POST_MIGRATION_CACHE_FLUSHING:

View File

@@ -22,8 +22,7 @@ from zerver.lib.realm_icon import get_realm_icon_url
from version import ZULIP_VERSION from version import ZULIP_VERSION
def common_context(user): def common_context(user: UserProfile) -> Dict[str, Any]:
# type: (UserProfile) -> Dict[str, Any]
"""Common context used for things like outgoing emails that don't """Common context used for things like outgoing emails that don't
have a request. have a request.
""" """
@@ -34,15 +33,13 @@ def common_context(user):
'external_host': settings.EXTERNAL_HOST, 'external_host': settings.EXTERNAL_HOST,
} }
def get_realm_from_request(request): def get_realm_from_request(request: HttpRequest) -> Optional[Realm]:
# type: (HttpRequest) -> Optional[Realm]
if hasattr(request, "user") and hasattr(request.user, "realm"): if hasattr(request, "user") and hasattr(request.user, "realm"):
return request.user.realm return request.user.realm
subdomain = get_subdomain(request) subdomain = get_subdomain(request)
return get_realm(subdomain) return get_realm(subdomain)
def zulip_default_context(request): def zulip_default_context(request: HttpRequest) -> Dict[str, Any]:
# type: (HttpRequest) -> Dict[str, Any]
"""Context available to all Zulip Jinja2 templates that have a request """Context available to all Zulip Jinja2 templates that have a request
passed in. Designed to provide the long list of variables at the passed in. Designed to provide the long list of variables at the
bottom of this function in a wide range of situations: logged-in bottom of this function in a wide range of situations: logged-in
@@ -141,8 +138,7 @@ def zulip_default_context(request):
} }
def add_metrics(request): def add_metrics(request: HttpRequest) -> Dict[str, str]:
# type: (HttpRequest) -> Dict[str, str]
return { return {
'dropboxAppKey': settings.DROPBOX_APP_KEY 'dropboxAppKey': settings.DROPBOX_APP_KEY
} }

View File

@@ -58,27 +58,25 @@ class _RespondAsynchronously:
# mode. # mode.
RespondAsynchronously = _RespondAsynchronously() RespondAsynchronously = _RespondAsynchronously()
def asynchronous(method): AsyncWrapperT = Callable[..., Union[HttpResponse, _RespondAsynchronously]]
# type: (Callable[..., Union[HttpResponse, _RespondAsynchronously]]) -> Callable[..., Union[HttpResponse, _RespondAsynchronously]] def asynchronous(method: Callable[..., Union[HttpResponse, _RespondAsynchronously]]) -> AsyncWrapperT:
# TODO: this should be the correct annotation when mypy gets fixed: type: # TODO: this should be the correct annotation when mypy gets fixed: type:
# (Callable[[HttpRequest, base.BaseHandler, Sequence[Any], Dict[str, Any]], # (Callable[[HttpRequest, base.BaseHandler, Sequence[Any], Dict[str, Any]],
# Union[HttpResponse, _RespondAsynchronously]]) -> # Union[HttpResponse, _RespondAsynchronously]]) ->
# Callable[[HttpRequest, Sequence[Any], Dict[str, Any]], Union[HttpResponse, _RespondAsynchronously]] # Callable[[HttpRequest, Sequence[Any], Dict[str, Any]], Union[HttpResponse, _RespondAsynchronously]]
# TODO: see https://github.com/python/mypy/issues/1655 # TODO: see https://github.com/python/mypy/issues/1655
@wraps(method) @wraps(method)
def wrapper(request, *args, **kwargs): def wrapper(request: HttpRequest, *args: Any,
# type: (HttpRequest, *Any, **Any) -> Union[HttpResponse, _RespondAsynchronously] **kwargs: Any) -> Union[HttpResponse, _RespondAsynchronously]:
return method(request, handler=request._tornado_handler, *args, **kwargs) return method(request, handler=request._tornado_handler, *args, **kwargs)
if getattr(method, 'csrf_exempt', False): if getattr(method, 'csrf_exempt', False):
wrapper.csrf_exempt = True # type: ignore # https://github.com/JukkaL/mypy/issues/1170 wrapper.csrf_exempt = True # type: ignore # https://github.com/JukkaL/mypy/issues/1170
return wrapper return wrapper
def cachify(method): def cachify(method: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
# type: (Callable[..., ReturnT]) -> Callable[..., ReturnT]
dct = {} # type: Dict[Tuple[Any, ...], ReturnT] dct = {} # type: Dict[Tuple[Any, ...], ReturnT]
def cache_wrapper(*args): def cache_wrapper(*args: Any) -> ReturnT:
# type: (*Any) -> ReturnT
tup = tuple(args) tup = tuple(args)
if tup in dct: if tup in dct:
return dct[tup] return dct[tup]
@@ -87,8 +85,8 @@ def cachify(method):
return result return result
return cache_wrapper return cache_wrapper
def update_user_activity(request, user_profile, query): def update_user_activity(request: HttpRequest, user_profile: UserProfile,
# type: (HttpRequest, UserProfile, Optional[str]) -> None query: Optional[str]) -> None:
# update_active_status also pushes to rabbitmq, and it seems # update_active_status also pushes to rabbitmq, and it seems
# redundant to log that here as well. # redundant to log that here as well.
if request.META["PATH_INFO"] == '/json/users/me/presence': if request.META["PATH_INFO"] == '/json/users/me/presence':
@@ -108,11 +106,9 @@ def update_user_activity(request, user_profile, query):
queue_json_publish("user_activity", event, lambda event: None) queue_json_publish("user_activity", event, lambda event: None)
# Based on django.views.decorators.http.require_http_methods # Based on django.views.decorators.http.require_http_methods
def require_post(func): def require_post(func: ViewFuncT) -> ViewFuncT:
# type: (ViewFuncT) -> ViewFuncT
@wraps(func) @wraps(func)
def wrapper(request, *args, **kwargs): def wrapper(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, *Any, **Any) -> HttpResponse
if (request.method != "POST" and if (request.method != "POST" and
not (request.method == "SOCKET" and not (request.method == "SOCKET" and
request.META['zulip.emulated_method'] == "POST")): request.META['zulip.emulated_method'] == "POST")):
@@ -126,11 +122,9 @@ def require_post(func):
return func(request, *args, **kwargs) return func(request, *args, **kwargs)
return wrapper # type: ignore # https://github.com/python/mypy/issues/1927 return wrapper # type: ignore # https://github.com/python/mypy/issues/1927
def require_realm_admin(func): def require_realm_admin(func: ViewFuncT) -> ViewFuncT:
# type: (ViewFuncT) -> ViewFuncT
@wraps(func) @wraps(func)
def wrapper(request, user_profile, *args, **kwargs): def wrapper(request: HttpRequest, user_profile: UserProfile, *args: Any, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, UserProfile, *Any, **Any) -> HttpResponse
if not user_profile.is_realm_admin: if not user_profile.is_realm_admin:
raise JsonableError(_("Must be a realm administrator")) raise JsonableError(_("Must be a realm administrator"))
return func(request, user_profile, *args, **kwargs) return func(request, user_profile, *args, **kwargs)
@@ -138,8 +132,7 @@ def require_realm_admin(func):
from zerver.lib.user_agent import parse_user_agent from zerver.lib.user_agent import parse_user_agent
def get_client_name(request, is_browser_view): def get_client_name(request: HttpRequest, is_browser_view: bool) -> Text:
# type: (HttpRequest, bool) -> Text
# If the API request specified a client in the request content, # If the API request specified a client in the request content,
# that has priority. Otherwise, extract the client from the # that has priority. Otherwise, extract the client from the
# User-Agent. # User-Agent.
@@ -184,19 +177,16 @@ class InvalidZulipServerError(JsonableError):
code = ErrorCode.INVALID_ZULIP_SERVER code = ErrorCode.INVALID_ZULIP_SERVER
data_fields = ['role'] data_fields = ['role']
def __init__(self, role): def __init__(self, role: Text) -> None:
# type: (Text) -> None
self.role = role # type: Text self.role = role # type: Text
@staticmethod @staticmethod
def msg_format(): def msg_format() -> Text:
# type: () -> Text
return "Zulip server auth failure: {role} is not registered" return "Zulip server auth failure: {role} is not registered"
class InvalidZulipServerKeyError(JsonableError): class InvalidZulipServerKeyError(JsonableError):
@staticmethod @staticmethod
def msg_format(): def msg_format() -> Text:
# type: () -> Text
return "Zulip server auth failure: key does not match role {role}" return "Zulip server auth failure: key does not match role {role}"
def validate_api_key(request, role, api_key, is_webhook=False, def validate_api_key(request, role, api_key, is_webhook=False,
@@ -233,8 +223,7 @@ def validate_api_key(request, role, api_key, is_webhook=False,
return user_profile return user_profile
def validate_account_and_subdomain(request, user_profile): def validate_account_and_subdomain(request: HttpRequest, user_profile: UserProfile) -> None:
# type: (HttpRequest, UserProfile) -> None
if not user_profile.is_active: if not user_profile.is_active:
raise JsonableError(_("Account not active")) raise JsonableError(_("Account not active"))
@@ -255,8 +244,7 @@ def validate_account_and_subdomain(request, user_profile):
user_profile.email, user_profile.realm.subdomain, get_subdomain(request))) user_profile.email, user_profile.realm.subdomain, get_subdomain(request)))
raise JsonableError(_("Account is not associated with this subdomain")) raise JsonableError(_("Account is not associated with this subdomain"))
def access_user_by_api_key(request, api_key, email=None): def access_user_by_api_key(request: HttpRequest, api_key: Text, email: Optional[Text]=None) -> UserProfile:
# type: (HttpRequest, Text, Optional[Text]) -> UserProfile
try: try:
user_profile = get_user_profile_by_api_key(api_key) user_profile = get_user_profile_by_api_key(api_key)
except UserProfile.DoesNotExist: except UserProfile.DoesNotExist:
@@ -272,12 +260,11 @@ def access_user_by_api_key(request, api_key, email=None):
return user_profile return user_profile
# Use this for webhook views that don't get an email passed in. # Use this for webhook views that don't get an email passed in.
def api_key_only_webhook_view(client_name): WrappedViewFuncT = Callable[[Callable[..., HttpResponse]], Callable[..., HttpResponse]]
# type: (Text) -> Callable[[Callable[..., HttpResponse]], Callable[..., HttpResponse]] def api_key_only_webhook_view(client_name: Text) -> WrappedViewFuncT:
# TODO The typing here could be improved by using the Extended Callable types: # TODO The typing here could be improved by using the Extended Callable types:
# https://mypy.readthedocs.io/en/latest/kinds_of_types.html#extended-callable-types # https://mypy.readthedocs.io/en/latest/kinds_of_types.html#extended-callable-types
def _wrapped_view_func(view_func): def _wrapped_view_func(view_func: Callable[..., HttpResponse]) -> Callable[..., HttpResponse]:
# type: (Callable[..., HttpResponse]) -> Callable[..., HttpResponse]
@csrf_exempt @csrf_exempt
@has_request_variables @has_request_variables
@wraps(view_func) @wraps(view_func)
@@ -341,18 +328,16 @@ def redirect_to_login(next, login_url=None,
return HttpResponseRedirect(urllib.parse.urlunparse(login_url_parts)) return HttpResponseRedirect(urllib.parse.urlunparse(login_url_parts))
# From Django 1.8 # From Django 1.8
def user_passes_test(test_func, login_url=None, redirect_field_name=REDIRECT_FIELD_NAME): def user_passes_test(test_func: Callable[[HttpResponse], bool], login_url: Optional[Text]=None,
# type: (Callable[[HttpResponse], bool], Optional[Text], Text) -> Callable[[Callable[..., HttpResponse]], Callable[..., HttpResponse]] redirect_field_name: Text=REDIRECT_FIELD_NAME) -> WrappedViewFuncT:
""" """
Decorator for views that checks that the user passes the given test, Decorator for views that checks that the user passes the given test,
redirecting to the log-in page if necessary. The test should be a callable redirecting to the log-in page if necessary. The test should be a callable
that takes the user object and returns True if the user passes. that takes the user object and returns True if the user passes.
""" """
def decorator(view_func): def decorator(view_func: Callable[..., HttpResponse]) -> Callable[..., HttpResponse]:
# type: (Callable[..., HttpResponse]) -> Callable[..., HttpResponse]
@wraps(view_func, assigned=available_attrs(view_func)) @wraps(view_func, assigned=available_attrs(view_func))
def _wrapped_view(request, *args, **kwargs): def _wrapped_view(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, *Any, **Any) -> HttpResponse
if test_func(request): if test_func(request):
return view_func(request, *args, **kwargs) return view_func(request, *args, **kwargs)
path = request.build_absolute_uri() path = request.build_absolute_uri()
@@ -369,8 +354,7 @@ def user_passes_test(test_func, login_url=None, redirect_field_name=REDIRECT_FIE
return _wrapped_view return _wrapped_view
return decorator return decorator
def logged_in_and_active(request): def logged_in_and_active(request: HttpRequest) -> bool:
# type: (HttpRequest) -> bool
if not request.user.is_authenticated: if not request.user.is_authenticated:
return False return False
if not request.user.is_active: if not request.user.is_active:
@@ -379,8 +363,7 @@ def logged_in_and_active(request):
return False return False
return user_matches_subdomain(get_subdomain(request), request.user) return user_matches_subdomain(get_subdomain(request), request.user)
def do_login(request, user_profile): def do_login(request: HttpRequest, user_profile: UserProfile) -> None:
# type: (HttpRequest, UserProfile) -> None
"""Creates a session, logging in the user, using the Django method, """Creates a session, logging in the user, using the Django method,
and also adds helpful data needed by our server logs. and also adds helpful data needed by our server logs.
""" """
@@ -388,31 +371,25 @@ def do_login(request, user_profile):
request._email = user_profile.email request._email = user_profile.email
process_client(request, user_profile, is_browser_view=True) process_client(request, user_profile, is_browser_view=True)
def log_view_func(view_func): def log_view_func(view_func: ViewFuncT) -> ViewFuncT:
# type: (ViewFuncT) -> ViewFuncT
@wraps(view_func) @wraps(view_func)
def _wrapped_view_func(request, *args, **kwargs): def _wrapped_view_func(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, *Any, **Any) -> HttpResponse
request._query = view_func.__name__ request._query = view_func.__name__
return view_func(request, *args, **kwargs) return view_func(request, *args, **kwargs)
return _wrapped_view_func # type: ignore # https://github.com/python/mypy/issues/1927 return _wrapped_view_func # type: ignore # https://github.com/python/mypy/issues/1927
def add_logging_data(view_func): def add_logging_data(view_func: ViewFuncT) -> ViewFuncT:
# type: (ViewFuncT) -> ViewFuncT
@wraps(view_func) @wraps(view_func)
def _wrapped_view_func(request, *args, **kwargs): def _wrapped_view_func(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, *Any, **Any) -> HttpResponse
request._email = request.user.email request._email = request.user.email
process_client(request, request.user, is_browser_view=True, process_client(request, request.user, is_browser_view=True,
query=view_func.__name__) query=view_func.__name__)
return rate_limit()(view_func)(request, *args, **kwargs) return rate_limit()(view_func)(request, *args, **kwargs)
return _wrapped_view_func # type: ignore # https://github.com/python/mypy/issues/1927 return _wrapped_view_func # type: ignore # https://github.com/python/mypy/issues/1927
def human_users_only(view_func): def human_users_only(view_func: ViewFuncT) -> ViewFuncT:
# type: (ViewFuncT) -> ViewFuncT
@wraps(view_func) @wraps(view_func)
def _wrapped_view_func(request, *args, **kwargs): def _wrapped_view_func(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, *Any, **Any) -> HttpResponse
if request.user.is_bot: if request.user.is_bot:
return json_error(_("This endpoint does not accept bot requests.")) return json_error(_("This endpoint does not accept bot requests."))
return view_func(request, *args, **kwargs) return view_func(request, *args, **kwargs)
@@ -433,12 +410,10 @@ def zulip_login_required(function=None,
return actual_decorator(add_logging_data(function)) return actual_decorator(add_logging_data(function))
return actual_decorator return actual_decorator
def require_server_admin(view_func): def require_server_admin(view_func: ViewFuncT) -> ViewFuncT:
# type: (ViewFuncT) -> ViewFuncT
@zulip_login_required @zulip_login_required
@wraps(view_func) @wraps(view_func)
def _wrapped_view_func(request, *args, **kwargs): def _wrapped_view_func(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, *Any, **Any) -> HttpResponse
if not request.user.is_staff: if not request.user.is_staff:
return HttpResponseRedirect(settings.HOME_NOT_LOGGED_IN) return HttpResponseRedirect(settings.HOME_NOT_LOGGED_IN)
@@ -449,10 +424,8 @@ def require_server_admin(view_func):
# user_profile to the view function's arguments list, since we have to # user_profile to the view function's arguments list, since we have to
# look it up anyway. It is deprecated in favor on the REST API # look it up anyway. It is deprecated in favor on the REST API
# versions. # versions.
def authenticated_api_view(is_webhook=False): def authenticated_api_view(is_webhook: bool=False) -> WrappedViewFuncT:
# type: (bool) -> Callable[[Callable[..., HttpResponse]], Callable[..., HttpResponse]] def _wrapped_view_func(view_func: Callable[..., HttpResponse]) -> Callable[..., HttpResponse]:
def _wrapped_view_func(view_func):
# type: (Callable[..., HttpResponse]) -> Callable[..., HttpResponse]
@csrf_exempt @csrf_exempt
@require_post @require_post
@has_request_variables @has_request_variables
@@ -474,14 +447,11 @@ def authenticated_api_view(is_webhook=False):
# A more REST-y authentication decorator, using, in particular, HTTP Basic # A more REST-y authentication decorator, using, in particular, HTTP Basic
# authentication. # authentication.
def authenticated_rest_api_view(is_webhook=False): def authenticated_rest_api_view(is_webhook: bool=False) -> WrappedViewFuncT:
# type: (bool) -> Callable[[Callable[..., HttpResponse]], Callable[..., HttpResponse]] def _wrapped_view_func(view_func: Callable[..., HttpResponse]) -> Callable[..., HttpResponse]:
def _wrapped_view_func(view_func):
# type: (Callable[..., HttpResponse]) -> Callable[..., HttpResponse]
@csrf_exempt @csrf_exempt
@wraps(view_func) @wraps(view_func)
def _wrapped_func_arguments(request, *args, **kwargs): def _wrapped_func_arguments(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, *Any, **Any) -> HttpResponse
# First try block attempts to get the credentials we need to do authentication # First try block attempts to get the credentials we need to do authentication
try: try:
# Grab the base64-encoded authentication string, decode it, and split it into # Grab the base64-encoded authentication string, decode it, and split it into
@@ -507,11 +477,9 @@ def authenticated_rest_api_view(is_webhook=False):
return _wrapped_func_arguments return _wrapped_func_arguments
return _wrapped_view_func return _wrapped_view_func
def process_as_post(view_func): def process_as_post(view_func: ViewFuncT) -> ViewFuncT:
# type: (ViewFuncT) -> ViewFuncT
@wraps(view_func) @wraps(view_func)
def _wrapped_view_func(request, *args, **kwargs): def _wrapped_view_func(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, *Any, **Any) -> HttpResponse
# Adapted from django/http/__init__.py. # Adapted from django/http/__init__.py.
# So by default Django doesn't populate request.POST for anything besides # So by default Django doesn't populate request.POST for anything besides
# POST requests. We want this dict populated for PATCH/PUT, so we have to # POST requests. We want this dict populated for PATCH/PUT, so we have to
@@ -539,8 +507,9 @@ def process_as_post(view_func):
return _wrapped_view_func # type: ignore # https://github.com/python/mypy/issues/1927 return _wrapped_view_func # type: ignore # https://github.com/python/mypy/issues/1927
def authenticate_log_and_execute_json(request, view_func, *args, **kwargs): def authenticate_log_and_execute_json(request: HttpRequest,
# type: (HttpRequest, Callable[..., HttpResponse], *Any, **Any) -> HttpResponse view_func: Callable[..., HttpResponse],
*args: Any, **kwargs: Any) -> HttpResponse:
if not request.user.is_authenticated: if not request.user.is_authenticated:
return json_error(_("Not logged in"), status=401) return json_error(_("Not logged in"), status=401)
user_profile = request.user user_profile = request.user
@@ -557,8 +526,7 @@ def authenticate_log_and_execute_json(request, view_func, *args, **kwargs):
# Checks if the request is a POST request and that the user is logged # Checks if the request is a POST request and that the user is logged
# in. If not, return an error (the @login_required behavior of # in. If not, return an error (the @login_required behavior of
# redirecting to a login page doesn't make sense for json views) # redirecting to a login page doesn't make sense for json views)
def authenticated_json_post_view(view_func): def authenticated_json_post_view(view_func: ViewFuncT) -> ViewFuncT:
# type: (ViewFuncT) -> ViewFuncT
@require_post @require_post
@has_request_variables @has_request_variables
@wraps(view_func) @wraps(view_func)
@@ -568,8 +536,7 @@ def authenticated_json_post_view(view_func):
return authenticate_log_and_execute_json(request, view_func, *args, **kwargs) return authenticate_log_and_execute_json(request, view_func, *args, **kwargs)
return _wrapped_view_func # type: ignore # https://github.com/python/mypy/issues/1927 return _wrapped_view_func # type: ignore # https://github.com/python/mypy/issues/1927
def authenticated_json_view(view_func): def authenticated_json_view(view_func: ViewFuncT) -> ViewFuncT:
# type: (ViewFuncT) -> ViewFuncT
@wraps(view_func) @wraps(view_func)
def _wrapped_view_func(request, def _wrapped_view_func(request,
*args, **kwargs): *args, **kwargs):
@@ -577,20 +544,17 @@ def authenticated_json_view(view_func):
return authenticate_log_and_execute_json(request, view_func, *args, **kwargs) return authenticate_log_and_execute_json(request, view_func, *args, **kwargs)
return _wrapped_view_func # type: ignore # https://github.com/python/mypy/issues/1927 return _wrapped_view_func # type: ignore # https://github.com/python/mypy/issues/1927
def is_local_addr(addr): def is_local_addr(addr: Text) -> bool:
# type: (Text) -> bool
return addr in ('127.0.0.1', '::1') return addr in ('127.0.0.1', '::1')
# These views are used by the main Django server to notify the Tornado server # These views are used by the main Django server to notify the Tornado server
# of events. We protect them from the outside world by checking a shared # of events. We protect them from the outside world by checking a shared
# secret, and also the originating IP (for now). # secret, and also the originating IP (for now).
def authenticate_notify(request): def authenticate_notify(request: HttpRequest) -> bool:
# type: (HttpRequest) -> bool
return (is_local_addr(request.META['REMOTE_ADDR']) and return (is_local_addr(request.META['REMOTE_ADDR']) and
request.POST.get('secret') == settings.SHARED_SECRET) request.POST.get('secret') == settings.SHARED_SECRET)
def client_is_exempt_from_rate_limiting(request): def client_is_exempt_from_rate_limiting(request: HttpRequest) -> bool:
# type: (HttpRequest) -> bool
# Don't rate limit requests from Django that come from our own servers, # Don't rate limit requests from Django that come from our own servers,
# and don't rate-limit dev instances # and don't rate-limit dev instances
@@ -598,20 +562,17 @@ def client_is_exempt_from_rate_limiting(request):
(is_local_addr(request.META['REMOTE_ADDR']) or (is_local_addr(request.META['REMOTE_ADDR']) or
settings.DEBUG_RATE_LIMITING)) settings.DEBUG_RATE_LIMITING))
def internal_notify_view(is_tornado_view): def internal_notify_view(is_tornado_view: bool) -> WrappedViewFuncT:
# type: (bool) -> Callable[[Callable[..., HttpResponse]], Callable[..., HttpResponse]]
# The typing here could be improved by using the Extended Callable types: # The typing here could be improved by using the Extended Callable types:
# https://mypy.readthedocs.io/en/latest/kinds_of_types.html#extended-callable-types # https://mypy.readthedocs.io/en/latest/kinds_of_types.html#extended-callable-types
"""Used for situations where something running on the Zulip server """Used for situations where something running on the Zulip server
needs to make a request to the (other) Django/Tornado processes running on needs to make a request to the (other) Django/Tornado processes running on
the server.""" the server."""
def _wrapped_view_func(view_func): def _wrapped_view_func(view_func: Callable[..., HttpResponse]) -> Callable[..., HttpResponse]:
# type: (Callable[..., HttpResponse]) -> Callable[..., HttpResponse]
@csrf_exempt @csrf_exempt
@require_post @require_post
@wraps(view_func) @wraps(view_func)
def _wrapped_func_arguments(request, *args, **kwargs): def _wrapped_func_arguments(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, *Any, **Any) -> HttpResponse
if not authenticate_notify(request): if not authenticate_notify(request):
return json_error(_('Access denied'), status=403) return json_error(_('Access denied'), status=403)
is_tornado_request = hasattr(request, '_tornado_handler') is_tornado_request = hasattr(request, '_tornado_handler')
@@ -627,52 +588,46 @@ def internal_notify_view(is_tornado_view):
return _wrapped_view_func return _wrapped_view_func
# Converter functions for use with has_request_variables # Converter functions for use with has_request_variables
def to_non_negative_int(s): def to_non_negative_int(s: Text) -> int:
# type: (Text) -> int
x = int(s) x = int(s)
if x < 0: if x < 0:
raise ValueError("argument is negative") raise ValueError("argument is negative")
return x return x
def to_not_negative_int_or_none(s): def to_not_negative_int_or_none(s: Text) -> Optional[int]:
# type: (Text) -> Optional[int]
if s: if s:
return to_non_negative_int(s) return to_non_negative_int(s)
return None return None
def flexible_boolean(boolean): def flexible_boolean(boolean: Text) -> bool:
# type: (Text) -> bool
"""Returns True for any of "1", "true", or "True". Returns False otherwise.""" """Returns True for any of "1", "true", or "True". Returns False otherwise."""
if boolean in ("1", "true", "True"): if boolean in ("1", "true", "True"):
return True return True
else: else:
return False return False
def to_utc_datetime(timestamp): def to_utc_datetime(timestamp: Text) -> datetime.datetime:
# type: (Text) -> datetime.datetime
return timestamp_to_datetime(float(timestamp)) return timestamp_to_datetime(float(timestamp))
WrapperT = Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]
def statsd_increment(counter, val=1): def statsd_increment(counter, val=1):
# type: (Text, int) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]] # type: (Text, int) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]
"""Increments a statsd counter on completion of the """Increments a statsd counter on completion of the
decorated function. decorated function.
Pass the name of the counter to this decorator-returning function.""" Pass the name of the counter to this decorator-returning function."""
def wrapper(func): def wrapper(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
# type: (Callable[..., ReturnT]) -> Callable[..., ReturnT]
@wraps(func) @wraps(func)
def wrapped_func(*args, **kwargs): def wrapped_func(*args: Any, **kwargs: Any) -> ReturnT:
# type: (*Any, **Any) -> ReturnT
ret = func(*args, **kwargs) ret = func(*args, **kwargs)
statsd.incr(counter, val) statsd.incr(counter, val)
return ret return ret
return wrapped_func return wrapped_func
return wrapper return wrapper
def rate_limit_user(request, user, domain): def rate_limit_user(request: HttpRequest, user: UserProfile, domain: Text) -> None:
# type: (HttpRequest, UserProfile, Text) -> None
"""Returns whether or not a user was rate limited. Will raise a RateLimited exception """Returns whether or not a user was rate limited. Will raise a RateLimited exception
if the user has been rate limited, otherwise returns and modifies request to contain if the user has been rate limited, otherwise returns and modifies request to contain
the rate limit information""" the rate limit information"""
@@ -693,17 +648,14 @@ def rate_limit_user(request, user, domain):
request._ratelimit_remaining = calls_remaining request._ratelimit_remaining = calls_remaining
request._ratelimit_secs_to_freedom = time_reset request._ratelimit_secs_to_freedom = time_reset
def rate_limit(domain='all'): def rate_limit(domain: Text='all') -> Callable[[Callable[..., HttpResponse]], Callable[..., HttpResponse]]:
# type: (Text) -> Callable[[Callable[..., HttpResponse]], Callable[..., HttpResponse]]
"""Rate-limits a view. Takes an optional 'domain' param if you wish to """Rate-limits a view. Takes an optional 'domain' param if you wish to
rate limit different types of API calls independently. rate limit different types of API calls independently.
Returns a decorator""" Returns a decorator"""
def wrapper(func): def wrapper(func: Callable[..., HttpResponse]) -> Callable[..., HttpResponse]:
# type: (Callable[..., HttpResponse]) -> Callable[..., HttpResponse]
@wraps(func) @wraps(func)
def wrapped_func(request, *args, **kwargs): def wrapped_func(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, *Any, **Any) -> HttpResponse
# It is really tempting to not even wrap our original function # It is really tempting to not even wrap our original function
# when settings.RATE_LIMITING is False, but it would make # when settings.RATE_LIMITING is False, but it would make

View File

@@ -5,8 +5,7 @@ from django.http import HttpRequest
from django.views.debug import SafeExceptionReporterFilter from django.views.debug import SafeExceptionReporterFilter
class ZulipExceptionReporterFilter(SafeExceptionReporterFilter): class ZulipExceptionReporterFilter(SafeExceptionReporterFilter):
def get_post_parameters(self, request): def get_post_parameters(self, request: HttpRequest) -> Dict[str, Any]:
# type: (HttpRequest) -> Dict[str, Any]
filtered_post = SafeExceptionReporterFilter.get_post_parameters(self, request).copy() filtered_post = SafeExceptionReporterFilter.get_post_parameters(self, request).copy()
filtered_vars = ['content', 'secret', 'password', 'key', 'api-key', 'subject', 'stream', filtered_vars = ['content', 'secret', 'password', 'key', 'api-key', 'subject', 'stream',
'subscriptions', 'to', 'csrfmiddlewaretoken', 'api_key'] 'subscriptions', 'to', 'csrfmiddlewaretoken', 'api_key']

View File

@@ -42,8 +42,7 @@ WRONG_SUBDOMAIN_ERROR = "Your Zulip account is not a member of the " + \
"organization associated with this subdomain. " + \ "organization associated with this subdomain. " + \
"Please contact %s with any questions!" % (FromAddress.SUPPORT,) "Please contact %s with any questions!" % (FromAddress.SUPPORT,)
def email_is_not_mit_mailing_list(email): def email_is_not_mit_mailing_list(email: Text) -> None:
# type: (Text) -> None
"""Prevent MIT mailing lists from signing up for Zulip""" """Prevent MIT mailing lists from signing up for Zulip"""
if "@mit.edu" in email: if "@mit.edu" in email:
username = email.rsplit("@", 1)[0] username = email.rsplit("@", 1)[0]
@@ -56,8 +55,7 @@ def email_is_not_mit_mailing_list(email):
else: else:
raise AssertionError("Unexpected DNS error") raise AssertionError("Unexpected DNS error")
def check_subdomain_available(subdomain): def check_subdomain_available(subdomain: str) -> None:
# type: (str) -> None
error_strings = { error_strings = {
'too short': _("Subdomain needs to have length 3 or greater."), 'too short': _("Subdomain needs to have length 3 or greater."),
'extremal dash': _("Subdomain cannot start or end with a '-'."), 'extremal dash': _("Subdomain cannot start or end with a '-'."),
@@ -86,9 +84,7 @@ class RegistrationForm(forms.Form):
password = forms.CharField(widget=forms.PasswordInput, max_length=MAX_PASSWORD_LENGTH) password = forms.CharField(widget=forms.PasswordInput, max_length=MAX_PASSWORD_LENGTH)
realm_subdomain = forms.CharField(max_length=Realm.MAX_REALM_SUBDOMAIN_LENGTH, required=False) realm_subdomain = forms.CharField(max_length=Realm.MAX_REALM_SUBDOMAIN_LENGTH, required=False)
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
# type: (*Any, **Any) -> None
# Since the superclass doesn't except random extra kwargs, we # Since the superclass doesn't except random extra kwargs, we
# remove it from the kwargs dict before initializing. # remove it from the kwargs dict before initializing.
self.realm_creation = kwargs['realm_creation'] self.realm_creation = kwargs['realm_creation']
@@ -101,15 +97,13 @@ class RegistrationForm(forms.Form):
max_length=Realm.MAX_REALM_NAME_LENGTH, max_length=Realm.MAX_REALM_NAME_LENGTH,
required=self.realm_creation) required=self.realm_creation)
def clean_full_name(self): def clean_full_name(self) -> Text:
# type: () -> Text
try: try:
return check_full_name(self.cleaned_data['full_name']) return check_full_name(self.cleaned_data['full_name'])
except JsonableError as e: except JsonableError as e:
raise ValidationError(e.msg) raise ValidationError(e.msg)
def clean_realm_subdomain(self): def clean_realm_subdomain(self) -> str:
# type: () -> str
if not self.realm_creation: if not self.realm_creation:
# This field is only used if realm_creation # This field is only used if realm_creation
return "" return ""
@@ -127,14 +121,12 @@ class ToSForm(forms.Form):
class HomepageForm(forms.Form): class HomepageForm(forms.Form):
email = forms.EmailField() email = forms.EmailField()
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
# type: (*Any, **Any) -> None
self.realm = kwargs.pop('realm', None) self.realm = kwargs.pop('realm', None)
self.from_multiuse_invite = kwargs.pop('from_multiuse_invite', False) self.from_multiuse_invite = kwargs.pop('from_multiuse_invite', False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def clean_email(self): def clean_email(self) -> str:
# type: () -> str
"""Returns the email if and only if the user's email address is """Returns the email if and only if the user's email address is
allowed to join the realm they are trying to join.""" allowed to join the realm they are trying to join."""
email = self.cleaned_data['email'] email = self.cleaned_data['email']
@@ -166,8 +158,7 @@ class HomepageForm(forms.Form):
return email return email
def email_is_not_disposable(email): def email_is_not_disposable(email: Text) -> None:
# type: (Text) -> None
if is_disposable_domain(email_to_domain(email)): if is_disposable_domain(email_to_domain(email)):
raise ValidationError(_("Please use your real email address.")) raise ValidationError(_("Please use your real email address."))
@@ -177,8 +168,7 @@ class RealmCreationForm(forms.Form):
email_is_not_disposable]) email_is_not_disposable])
class LoggingSetPasswordForm(SetPasswordForm): class LoggingSetPasswordForm(SetPasswordForm):
def save(self, commit=True): def save(self, commit: bool=True) -> UserProfile:
# type: (bool) -> UserProfile
do_change_password(self.user, self.cleaned_data['new_password1'], do_change_password(self.user, self.cleaned_data['new_password1'],
commit=commit) commit=commit)
return self.user return self.user
@@ -251,8 +241,7 @@ class CreateUserForm(forms.Form):
email = forms.EmailField() email = forms.EmailField()
class OurAuthenticationForm(AuthenticationForm): class OurAuthenticationForm(AuthenticationForm):
def clean(self): def clean(self) -> Dict[str, Any]:
# type: () -> Dict[str, Any]
username = self.cleaned_data.get('username') username = self.cleaned_data.get('username')
password = self.cleaned_data.get('password') password = self.cleaned_data.get('password')
@@ -300,16 +289,14 @@ class OurAuthenticationForm(AuthenticationForm):
return field_name return field_name
class MultiEmailField(forms.Field): class MultiEmailField(forms.Field):
def to_python(self, emails): def to_python(self, emails: Text) -> List[Text]:
# type: (Text) -> List[Text]
"""Normalize data to a list of strings.""" """Normalize data to a list of strings."""
if not emails: if not emails:
return [] return []
return [email.strip() for email in emails.split(',')] return [email.strip() for email in emails.split(',')]
def validate(self, emails): def validate(self, emails: List[Text]) -> None:
# type: (List[Text]) -> None
"""Check if value consists only of valid emails.""" """Check if value consists only of valid emails."""
super().validate(emails) super().validate(emails)
for email in emails: for email in emails:
@@ -319,8 +306,7 @@ class FindMyTeamForm(forms.Form):
emails = MultiEmailField( emails = MultiEmailField(
help_text=_("Add up to 10 comma-separated email addresses.")) help_text=_("Add up to 10 comma-separated email addresses."))
def clean_emails(self): def clean_emails(self) -> List[Text]:
# type: () -> List[Text]
emails = self.cleaned_data['emails'] emails = self.cleaned_data['emails']
if len(emails) > 10: if len(emails) > 10:
raise forms.ValidationError(_("Please enter at most 10 emails.")) raise forms.ValidationError(_("Please enter at most 10 emails."))

View File

@@ -13,8 +13,7 @@ from django.views.debug import ExceptionReporter, get_exception_reporter_filter
from zerver.lib.queue import queue_json_publish from zerver.lib.queue import queue_json_publish
def add_request_metadata(report, request): def add_request_metadata(report: Dict[str, Any], request: HttpRequest) -> None:
# type: (Dict[str, Any], HttpRequest) -> None
report['path'] = request.path report['path'] = request.path
report['method'] = request.method report['method'] = request.method
report['remote_addr'] = request.META.get('REMOTE_ADDR', None), report['remote_addr'] = request.META.get('REMOTE_ADDR', None),
@@ -60,12 +59,10 @@ class AdminZulipHandler(logging.Handler):
# adapted in part from django/utils/log.py # adapted in part from django/utils/log.py
def __init__(self): def __init__(self) -> None:
# type: () -> None
logging.Handler.__init__(self) logging.Handler.__init__(self)
def emit(self, record): def emit(self, record: logging.LogRecord) -> None:
# type: (logging.LogRecord) -> None
try: try:
if record.exc_info: if record.exc_info:
stack_trace = ''.join(traceback.format_exception(*record.exc_info)) # type: Optional[str] stack_trace = ''.join(traceback.format_exception(*record.exc_info)) # type: Optional[str]

View File

@@ -30,8 +30,7 @@ from zerver.models import Realm, flush_per_request_caches, get_realm
logger = logging.getLogger('zulip.requests') logger = logging.getLogger('zulip.requests')
def record_request_stop_data(log_data): def record_request_stop_data(log_data: MutableMapping[str, Any]) -> None:
# type: (MutableMapping[str, Any]) -> None
log_data['time_stopped'] = time.time() log_data['time_stopped'] = time.time()
log_data['remote_cache_time_stopped'] = get_remote_cache_time() log_data['remote_cache_time_stopped'] = get_remote_cache_time()
log_data['remote_cache_requests_stopped'] = get_remote_cache_requests() log_data['remote_cache_requests_stopped'] = get_remote_cache_requests()
@@ -40,12 +39,10 @@ def record_request_stop_data(log_data):
if settings.PROFILE_ALL_REQUESTS: if settings.PROFILE_ALL_REQUESTS:
log_data["prof"].disable() log_data["prof"].disable()
def async_request_stop(request): def async_request_stop(request: HttpRequest) -> None:
# type: (HttpRequest) -> None
record_request_stop_data(request._log_data) record_request_stop_data(request._log_data)
def record_request_restart_data(log_data): def record_request_restart_data(log_data: MutableMapping[str, Any]) -> None:
# type: (MutableMapping[str, Any]) -> None
if settings.PROFILE_ALL_REQUESTS: if settings.PROFILE_ALL_REQUESTS:
log_data["prof"].enable() log_data["prof"].enable()
log_data['time_restarted'] = time.time() log_data['time_restarted'] = time.time()
@@ -54,16 +51,14 @@ def record_request_restart_data(log_data):
log_data['bugdown_time_restarted'] = get_bugdown_time() log_data['bugdown_time_restarted'] = get_bugdown_time()
log_data['bugdown_requests_restarted'] = get_bugdown_requests() log_data['bugdown_requests_restarted'] = get_bugdown_requests()
def async_request_restart(request): def async_request_restart(request: HttpRequest) -> None:
# type: (HttpRequest) -> None
if "time_restarted" in request._log_data: if "time_restarted" in request._log_data:
# Don't destroy data when being called from # Don't destroy data when being called from
# finish_current_handler # finish_current_handler
return return
record_request_restart_data(request._log_data) record_request_restart_data(request._log_data)
def record_request_start_data(log_data): def record_request_start_data(log_data: MutableMapping[str, Any]) -> None:
# type: (MutableMapping[str, Any]) -> None
if settings.PROFILE_ALL_REQUESTS: if settings.PROFILE_ALL_REQUESTS:
log_data["prof"] = cProfile.Profile() log_data["prof"] = cProfile.Profile()
log_data["prof"].enable() log_data["prof"].enable()
@@ -74,18 +69,15 @@ def record_request_start_data(log_data):
log_data['bugdown_time_start'] = get_bugdown_time() log_data['bugdown_time_start'] = get_bugdown_time()
log_data['bugdown_requests_start'] = get_bugdown_requests() log_data['bugdown_requests_start'] = get_bugdown_requests()
def timedelta_ms(timedelta): def timedelta_ms(timedelta: float) -> float:
# type: (float) -> float
return timedelta * 1000 return timedelta * 1000
def format_timedelta(timedelta): def format_timedelta(timedelta: float) -> str:
# type: (float) -> str
if (timedelta >= 1): if (timedelta >= 1):
return "%.1fs" % (timedelta) return "%.1fs" % (timedelta)
return "%.0fms" % (timedelta_ms(timedelta),) return "%.0fms" % (timedelta_ms(timedelta),)
def is_slow_query(time_delta, path): def is_slow_query(time_delta: float, path: Text) -> bool:
# type: (float, Text) -> bool
if time_delta < 1.2: if time_delta < 1.2:
return False return False
is_exempt = \ is_exempt = \
@@ -232,16 +224,15 @@ class LogRequests(MiddlewareMixin):
# We primarily are doing logging using the process_view hook, but # We primarily are doing logging using the process_view hook, but
# for some views, process_view isn't run, so we call the start # for some views, process_view isn't run, so we call the start
# method here too # method here too
def process_request(self, request): def process_request(self, request: HttpRequest) -> None:
# type: (HttpRequest) -> None
maybe_tracemalloc_listen() maybe_tracemalloc_listen()
request._log_data = dict() request._log_data = dict()
record_request_start_data(request._log_data) record_request_start_data(request._log_data)
if connection.connection is not None: if connection.connection is not None:
connection.connection.queries = [] connection.connection.queries = []
def process_view(self, request, view_func, args, kwargs): def process_view(self, request: HttpRequest, view_func: Callable[..., HttpResponse],
# type: (HttpRequest, Callable[..., HttpResponse], List[str], Dict[str, Any]) -> None args: List[str], kwargs: Dict[str, Any]) -> None:
# process_request was already run; we save the initialization # process_request was already run; we save the initialization
# time (i.e. the time between receiving the request and # time (i.e. the time between receiving the request and
# figuring out which view function to call, which is primarily # figuring out which view function to call, which is primarily
@@ -253,8 +244,8 @@ class LogRequests(MiddlewareMixin):
if connection.connection is not None: if connection.connection is not None:
connection.connection.queries = [] connection.connection.queries = []
def process_response(self, request, response): def process_response(self, request: HttpRequest,
# type: (HttpRequest, StreamingHttpResponse) -> StreamingHttpResponse response: StreamingHttpResponse) -> StreamingHttpResponse:
# The reverse proxy might have sent us the real external IP # The reverse proxy might have sent us the real external IP
remote_ip = request.META.get('HTTP_X_REAL_IP') remote_ip = request.META.get('HTTP_X_REAL_IP')
if remote_ip is None: if remote_ip is None:
@@ -283,8 +274,7 @@ class LogRequests(MiddlewareMixin):
return response return response
class JsonErrorHandler(MiddlewareMixin): class JsonErrorHandler(MiddlewareMixin):
def process_exception(self, request, exception): def process_exception(self, request: HttpRequest, exception: Exception) -> Optional[HttpResponse]:
# type: (HttpRequest, Exception) -> Optional[HttpResponse]
if isinstance(exception, JsonableError): if isinstance(exception, JsonableError):
return json_response_from_error(exception) return json_response_from_error(exception)
if request.error_format == "JSON": if request.error_format == "JSON":
@@ -293,12 +283,11 @@ class JsonErrorHandler(MiddlewareMixin):
return None return None
class TagRequests(MiddlewareMixin): class TagRequests(MiddlewareMixin):
def process_view(self, request, view_func, args, kwargs): def process_view(self, request: HttpRequest, view_func: Callable[..., HttpResponse],
# type: (HttpRequest, Callable[..., HttpResponse], List[str], Dict[str, Any]) -> None args: List[str], kwargs: Dict[str, Any]) -> None:
self.process_request(request) self.process_request(request)
def process_request(self, request): def process_request(self, request: HttpRequest) -> None:
# type: (HttpRequest) -> None
if request.path.startswith("/api/") or request.path.startswith("/json/"): if request.path.startswith("/api/") or request.path.startswith("/json/"):
request.error_format = "JSON" request.error_format = "JSON"
else: else:
@@ -309,25 +298,21 @@ class CsrfFailureError(JsonableError):
code = ErrorCode.CSRF_FAILED code = ErrorCode.CSRF_FAILED
data_fields = ['reason'] data_fields = ['reason']
def __init__(self, reason): def __init__(self, reason: Text) -> None:
# type: (Text) -> None
self.reason = reason # type: Text self.reason = reason # type: Text
@staticmethod @staticmethod
def msg_format(): def msg_format() -> Text:
# type: () -> Text
return _("CSRF Error: {reason}") return _("CSRF Error: {reason}")
def csrf_failure(request, reason=""): def csrf_failure(request: HttpRequest, reason: Text="") -> HttpResponse:
# type: (HttpRequest, Text) -> HttpResponse
if request.error_format == "JSON": if request.error_format == "JSON":
return json_response_from_error(CsrfFailureError(reason)) return json_response_from_error(CsrfFailureError(reason))
else: else:
return html_csrf_failure(request, reason) return html_csrf_failure(request, reason)
class RateLimitMiddleware(MiddlewareMixin): class RateLimitMiddleware(MiddlewareMixin):
def process_response(self, request, response): def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
# type: (HttpRequest, HttpResponse) -> HttpResponse
if not settings.RATE_LIMITING: if not settings.RATE_LIMITING:
return response return response
@@ -342,8 +327,7 @@ class RateLimitMiddleware(MiddlewareMixin):
response['X-RateLimit-Remaining'] = str(request._ratelimit_remaining) response['X-RateLimit-Remaining'] = str(request._ratelimit_remaining)
return response return response
def process_exception(self, request, exception): def process_exception(self, request: HttpRequest, exception: Exception) -> Optional[HttpResponse]:
# type: (HttpRequest, Exception) -> Optional[HttpResponse]
if isinstance(exception, RateLimited): if isinstance(exception, RateLimited):
resp = json_error( resp = json_error(
_("API usage exceeded rate limit"), _("API usage exceeded rate limit"),
@@ -355,16 +339,14 @@ class RateLimitMiddleware(MiddlewareMixin):
return None return None
class FlushDisplayRecipientCache(MiddlewareMixin): class FlushDisplayRecipientCache(MiddlewareMixin):
def process_response(self, request, response): def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
# type: (HttpRequest, HttpResponse) -> HttpResponse
# We flush the per-request caches after every request, so they # We flush the per-request caches after every request, so they
# are not shared at all between requests. # are not shared at all between requests.
flush_per_request_caches() flush_per_request_caches()
return response return response
class SessionHostDomainMiddleware(SessionMiddleware): class SessionHostDomainMiddleware(SessionMiddleware):
def process_response(self, request, response): def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
# type: (HttpRequest, HttpResponse) -> HttpResponse
try: try:
request.get_host() request.get_host()
except DisallowedHost: except DisallowedHost:
@@ -431,8 +413,7 @@ class SetRemoteAddrFromForwardedFor(MiddlewareMixin):
is set in the request, then it has properly been set by NGINX. is set in the request, then it has properly been set by NGINX.
Therefore HTTP_X_FORWARDED_FOR's value is trusted. Therefore HTTP_X_FORWARDED_FOR's value is trusted.
""" """
def process_request(self, request): def process_request(self, request: HttpRequest) -> None:
# type: (HttpRequest) -> None
try: try:
real_ip = request.META['HTTP_X_FORWARDED_FOR'] real_ip = request.META['HTTP_X_FORWARDED_FOR']
except KeyError: except KeyError:

File diff suppressed because it is too large Load Diff

View File

@@ -12,8 +12,7 @@ from django.utils.timezone import now as timezone_now
from zerver.lib.send_email import FromAddress, send_email from zerver.lib.send_email import FromAddress, send_email
from zerver.models import UserProfile from zerver.models import UserProfile
def get_device_browser(user_agent): def get_device_browser(user_agent: str) -> Optional[str]:
# type: (str) -> Optional[str]
user_agent = user_agent.lower() user_agent = user_agent.lower()
if "zulip" in user_agent: if "zulip" in user_agent:
return "Zulip" return "Zulip"
@@ -35,8 +34,7 @@ def get_device_browser(user_agent):
return None return None
def get_device_os(user_agent): def get_device_os(user_agent: str) -> Optional[str]:
# type: (str) -> Optional[str]
user_agent = user_agent.lower() user_agent = user_agent.lower()
if "windows" in user_agent: if "windows" in user_agent:
return "Windows" return "Windows"
@@ -55,9 +53,7 @@ def get_device_os(user_agent):
@receiver(user_logged_in, dispatch_uid="only_on_login") @receiver(user_logged_in, dispatch_uid="only_on_login")
def email_on_new_login(sender, user, request, **kwargs): def email_on_new_login(sender: Any, user: UserProfile, request: Any, **kwargs: Any) -> None:
# type: (Any, UserProfile, Any, **Any) -> None
# We import here to minimize the dependencies of this module, # We import here to minimize the dependencies of this module,
# since it runs as part of `manage.py` initialization # since it runs as part of `manage.py` initialization
from zerver.context_processors import common_context from zerver.context_processors import common_context

View File

@@ -11,8 +11,8 @@ from pipeline.storage import PipelineMixin
from zerver.lib.str_utils import force_str from zerver.lib.str_utils import force_str
class AddHeaderMixin: class AddHeaderMixin:
def post_process(self, paths, dry_run=False, **kwargs): def post_process(self, paths: Dict[str, Tuple['ZulipStorage', str]], dry_run: bool=False,
# type: (Dict[str, Tuple[ZulipStorage, str]], bool, **Any) -> List[Tuple[str, str, bool]] **kwargs: Any) -> List[Tuple[str, str, bool]]:
if dry_run: if dry_run:
return [] return []
@@ -58,8 +58,8 @@ class AddHeaderMixin:
class RemoveUnminifiedFilesMixin: class RemoveUnminifiedFilesMixin:
def post_process(self, paths, dry_run=False, **kwargs): def post_process(self, paths: Dict[str, Tuple['ZulipStorage', str]], dry_run: bool=False,
# type: (Dict[str, Tuple[ZulipStorage, str]], bool, **Any) -> List[Tuple[str, str, bool]] **kwargs: Any) -> List[Tuple[str, str, bool]]:
if dry_run: if dry_run:
return [] return []
@@ -88,8 +88,7 @@ if settings.PRODUCTION:
"staticfiles.json") "staticfiles.json")
orig_path = ManifestStaticFilesStorage.path orig_path = ManifestStaticFilesStorage.path
def path(self, name): def path(self: Any, name: str) -> str:
# type: (Any, str) -> str
if name == ManifestStaticFilesStorage.manifest_name: if name == ManifestStaticFilesStorage.manifest_name:
return name return name
return orig_path(self, name) return orig_path(self, name)