mirror of
https://github.com/zulip/zulip.git
synced 2025-11-01 20:44:04 +00:00
exceptions: Make RateLimited into a subclass of JsonableError.
This simplifies the code, as it allows using the mechanism of converting JsonableErrors into a response instead of having separate, but ultimately similar, logic in RateLimitMiddleware. We don't touch tests here because "rate limited" error responses are already verified in test_external.py.
This commit is contained in:
committed by
Alex Vandiver
parent
92ce2d0e31
commit
43a0c60e96
@@ -349,7 +349,8 @@ class OurAuthenticationForm(AuthenticationForm):
|
|||||||
self.user_cache = authenticate(request=self.request, username=username, password=password,
|
self.user_cache = authenticate(request=self.request, username=username, password=password,
|
||||||
realm=realm, return_data=return_data)
|
realm=realm, return_data=return_data)
|
||||||
except RateLimited as e:
|
except RateLimited as e:
|
||||||
secs_to_freedom = int(float(str(e)))
|
assert e.secs_to_freedom is not None
|
||||||
|
secs_to_freedom = int(e.secs_to_freedom)
|
||||||
raise ValidationError(AUTHENTICATION_RATE_LIMITED_ERROR.format(secs_to_freedom))
|
raise ValidationError(AUTHENTICATION_RATE_LIMITED_ERROR.format(secs_to_freedom))
|
||||||
|
|
||||||
if return_data.get("inactive_realm"):
|
if return_data.get("inactive_realm"):
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ class ErrorCode(AbstractEnum):
|
|||||||
INVALID_ZOOM_TOKEN = ()
|
INVALID_ZOOM_TOKEN = ()
|
||||||
UNAUTHENTICATED_USER = ()
|
UNAUTHENTICATED_USER = ()
|
||||||
NONEXISTENT_SUBDOMAIN = ()
|
NONEXISTENT_SUBDOMAIN = ()
|
||||||
|
RATE_LIMIT_HIT = ()
|
||||||
|
|
||||||
class JsonableError(Exception):
|
class JsonableError(Exception):
|
||||||
'''A standardized error format we can turn into a nice JSON HTTP response.
|
'''A standardized error format we can turn into a nice JSON HTTP response.
|
||||||
@@ -111,6 +112,10 @@ class JsonableError(Exception):
|
|||||||
# at construction time.
|
# at construction time.
|
||||||
return '{_msg}'
|
return '{_msg}'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extra_headers(self) -> Dict[str, Any]:
|
||||||
|
return {}
|
||||||
|
|
||||||
#
|
#
|
||||||
# Infrastructure -- not intended to be overridden in subclasses.
|
# Infrastructure -- not intended to be overridden in subclasses.
|
||||||
#
|
#
|
||||||
@@ -179,9 +184,31 @@ class InvalidMarkdownIncludeStatement(JsonableError):
|
|||||||
def msg_format() -> str:
|
def msg_format() -> str:
|
||||||
return _("Invalid Markdown include statement: {include_statement}")
|
return _("Invalid Markdown include statement: {include_statement}")
|
||||||
|
|
||||||
class RateLimited(Exception):
|
class RateLimited(JsonableError):
|
||||||
def __init__(self, msg: str="") -> None:
|
code = ErrorCode.RATE_LIMIT_HIT
|
||||||
super().__init__(msg)
|
http_status_code = 429
|
||||||
|
|
||||||
|
def __init__(self, secs_to_freedom: Optional[float]=None) -> None:
|
||||||
|
self.secs_to_freedom = secs_to_freedom
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def msg_format() -> str:
|
||||||
|
return _("API usage exceeded rate limit")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extra_headers(self) -> Dict[str, Any]:
|
||||||
|
extra_headers_dict = super().extra_headers
|
||||||
|
if self.secs_to_freedom is not None:
|
||||||
|
extra_headers_dict["Retry-After"] = self.secs_to_freedom
|
||||||
|
|
||||||
|
return extra_headers_dict
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> Dict[str, Any]:
|
||||||
|
data_dict = super().data
|
||||||
|
data_dict['retry-after'] = self.secs_to_freedom
|
||||||
|
|
||||||
|
return data_dict
|
||||||
|
|
||||||
class InvalidJSONError(JsonableError):
|
class InvalidJSONError(JsonableError):
|
||||||
code = ErrorCode.INVALID_JSON
|
code = ErrorCode.INVALID_JSON
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class RateLimitedObject(ABC):
|
|||||||
# Abort this request if the user is over their rate limits
|
# Abort this request if the user is over their rate limits
|
||||||
if ratelimited:
|
if ratelimited:
|
||||||
# Pass information about what kind of entity got limited in the exception:
|
# Pass information about what kind of entity got limited in the exception:
|
||||||
raise RateLimited(str(time))
|
raise RateLimited(time)
|
||||||
|
|
||||||
calls_remaining, seconds_until_reset = self.api_calls_left()
|
calls_remaining, seconds_until_reset = self.api_calls_left()
|
||||||
|
|
||||||
|
|||||||
@@ -66,10 +66,15 @@ def json_response_from_error(exception: JsonableError) -> HttpResponse:
|
|||||||
middleware takes care of transforming it into a response by
|
middleware takes care of transforming it into a response by
|
||||||
calling this function.
|
calling this function.
|
||||||
'''
|
'''
|
||||||
return json_response('error',
|
response = json_response('error',
|
||||||
msg=exception.msg,
|
msg=exception.msg,
|
||||||
data=exception.data,
|
data=exception.data,
|
||||||
status=exception.http_status_code)
|
status=exception.http_status_code)
|
||||||
|
|
||||||
|
for header, value in exception.extra_headers.items():
|
||||||
|
response[header] = value
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
def json_error(msg: str, data: Mapping[str, Any]={}, status: int=400) -> HttpResponse:
|
def json_error(msg: str, data: Mapping[str, Any]={}, status: int=400) -> HttpResponse:
|
||||||
return json_response(res_type="error", msg=msg, data=data, status=status)
|
return json_response(res_type="error", msg=msg, data=data, status=status)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from sentry_sdk.integrations.logging import ignore_logger
|
|||||||
from zerver.lib.cache import get_remote_cache_requests, get_remote_cache_time
|
from zerver.lib.cache import get_remote_cache_requests, get_remote_cache_time
|
||||||
from zerver.lib.db import reset_queries
|
from zerver.lib.db import reset_queries
|
||||||
from zerver.lib.debug import maybe_tracemalloc_listen
|
from zerver.lib.debug import maybe_tracemalloc_listen
|
||||||
from zerver.lib.exceptions import ErrorCode, JsonableError, MissingAuthenticationError, RateLimited
|
from zerver.lib.exceptions import ErrorCode, JsonableError, MissingAuthenticationError
|
||||||
from zerver.lib.html_to_text import get_content_description
|
from zerver.lib.html_to_text import get_content_description
|
||||||
from zerver.lib.markdown import get_markdown_requests, get_markdown_time
|
from zerver.lib.markdown import get_markdown_requests, get_markdown_time
|
||||||
from zerver.lib.rate_limiter import RateLimitResult
|
from zerver.lib.rate_limiter import RateLimitResult
|
||||||
@@ -408,20 +408,6 @@ class RateLimitMiddleware(MiddlewareMixin):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def process_exception(self, request: HttpRequest,
|
|
||||||
exception: Exception) -> Optional[HttpResponse]:
|
|
||||||
if isinstance(exception, RateLimited):
|
|
||||||
# secs_to_freedom is passed to RateLimited when raising
|
|
||||||
secs_to_freedom = float(str(exception))
|
|
||||||
resp = json_error(
|
|
||||||
_("API usage exceeded rate limit"),
|
|
||||||
data={'retry-after': secs_to_freedom},
|
|
||||||
status=429,
|
|
||||||
)
|
|
||||||
resp['Retry-After'] = secs_to_freedom
|
|
||||||
return resp
|
|
||||||
return None
|
|
||||||
|
|
||||||
class FlushDisplayRecipientCache(MiddlewareMixin):
|
class FlushDisplayRecipientCache(MiddlewareMixin):
|
||||||
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
||||||
# We flush the per-request caches after every request, so they
|
# We flush the per-request caches after every request, so they
|
||||||
|
|||||||
@@ -98,7 +98,8 @@ def json_change_settings(request: HttpRequest, user_profile: UserProfile,
|
|||||||
realm=user_profile.realm, return_data=return_data):
|
realm=user_profile.realm, return_data=return_data):
|
||||||
return json_error(_("Wrong password!"))
|
return json_error(_("Wrong password!"))
|
||||||
except RateLimited as e:
|
except RateLimited as e:
|
||||||
secs_to_freedom = int(float(str(e)))
|
assert e.secs_to_freedom is not None
|
||||||
|
secs_to_freedom = int(e.secs_to_freedom)
|
||||||
return json_error(
|
return json_error(
|
||||||
_("You're making too many attempts! Try again in {} seconds.").format(secs_to_freedom),
|
_("You're making too many attempts! Try again in {} seconds.").format(secs_to_freedom),
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user