mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-04 05:53:43 +00:00 
			
		
		
		
	exceptions: Make RateLimited into a subclass of JsonableError.
This simplifies the code, as it allows using the mechanism of converting JsonableErrors into a response instead of having separate, but ultimately similar, logic in RateLimitMiddleware. We don't touch tests here because "rate limited" error responses are already verified in test_external.py.
This commit is contained in:
		
				
					committed by
					
						
						Alex Vandiver
					
				
			
			
				
	
			
			
			
						parent
						
							92ce2d0e31
						
					
				
				
					commit
					43a0c60e96
				
			@@ -349,7 +349,8 @@ class OurAuthenticationForm(AuthenticationForm):
 | 
				
			|||||||
                self.user_cache = authenticate(request=self.request, username=username, password=password,
 | 
					                self.user_cache = authenticate(request=self.request, username=username, password=password,
 | 
				
			||||||
                                               realm=realm, return_data=return_data)
 | 
					                                               realm=realm, return_data=return_data)
 | 
				
			||||||
            except RateLimited as e:
 | 
					            except RateLimited as e:
 | 
				
			||||||
                secs_to_freedom = int(float(str(e)))
 | 
					                assert e.secs_to_freedom is not None
 | 
				
			||||||
 | 
					                secs_to_freedom = int(e.secs_to_freedom)
 | 
				
			||||||
                raise ValidationError(AUTHENTICATION_RATE_LIMITED_ERROR.format(secs_to_freedom))
 | 
					                raise ValidationError(AUTHENTICATION_RATE_LIMITED_ERROR.format(secs_to_freedom))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if return_data.get("inactive_realm"):
 | 
					            if return_data.get("inactive_realm"):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -47,6 +47,7 @@ class ErrorCode(AbstractEnum):
 | 
				
			|||||||
    INVALID_ZOOM_TOKEN = ()
 | 
					    INVALID_ZOOM_TOKEN = ()
 | 
				
			||||||
    UNAUTHENTICATED_USER = ()
 | 
					    UNAUTHENTICATED_USER = ()
 | 
				
			||||||
    NONEXISTENT_SUBDOMAIN = ()
 | 
					    NONEXISTENT_SUBDOMAIN = ()
 | 
				
			||||||
 | 
					    RATE_LIMIT_HIT = ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class JsonableError(Exception):
 | 
					class JsonableError(Exception):
 | 
				
			||||||
    '''A standardized error format we can turn into a nice JSON HTTP response.
 | 
					    '''A standardized error format we can turn into a nice JSON HTTP response.
 | 
				
			||||||
@@ -111,6 +112,10 @@ class JsonableError(Exception):
 | 
				
			|||||||
        # at construction time.
 | 
					        # at construction time.
 | 
				
			||||||
        return '{_msg}'
 | 
					        return '{_msg}'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def extra_headers(self) -> Dict[str, Any]:
 | 
				
			||||||
 | 
					        return {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #
 | 
					    #
 | 
				
			||||||
    # Infrastructure -- not intended to be overridden in subclasses.
 | 
					    # Infrastructure -- not intended to be overridden in subclasses.
 | 
				
			||||||
    #
 | 
					    #
 | 
				
			||||||
@@ -179,9 +184,31 @@ class InvalidMarkdownIncludeStatement(JsonableError):
 | 
				
			|||||||
    def msg_format() -> str:
 | 
					    def msg_format() -> str:
 | 
				
			||||||
        return _("Invalid Markdown include statement: {include_statement}")
 | 
					        return _("Invalid Markdown include statement: {include_statement}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class RateLimited(Exception):
 | 
					class RateLimited(JsonableError):
 | 
				
			||||||
    def __init__(self, msg: str="") -> None:
 | 
					    code = ErrorCode.RATE_LIMIT_HIT
 | 
				
			||||||
        super().__init__(msg)
 | 
					    http_status_code = 429
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, secs_to_freedom: Optional[float]=None) -> None:
 | 
				
			||||||
 | 
					        self.secs_to_freedom = secs_to_freedom
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def msg_format() -> str:
 | 
				
			||||||
 | 
					        return _("API usage exceeded rate limit")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def extra_headers(self) -> Dict[str, Any]:
 | 
				
			||||||
 | 
					        extra_headers_dict = super().extra_headers
 | 
				
			||||||
 | 
					        if self.secs_to_freedom is not None:
 | 
				
			||||||
 | 
					            extra_headers_dict["Retry-After"] = self.secs_to_freedom
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return extra_headers_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def data(self) -> Dict[str, Any]:
 | 
				
			||||||
 | 
					        data_dict = super().data
 | 
				
			||||||
 | 
					        data_dict['retry-after'] = self.secs_to_freedom
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return data_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class InvalidJSONError(JsonableError):
 | 
					class InvalidJSONError(JsonableError):
 | 
				
			||||||
    code = ErrorCode.INVALID_JSON
 | 
					    code = ErrorCode.INVALID_JSON
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -53,7 +53,7 @@ class RateLimitedObject(ABC):
 | 
				
			|||||||
        # Abort this request if the user is over their rate limits
 | 
					        # Abort this request if the user is over their rate limits
 | 
				
			||||||
        if ratelimited:
 | 
					        if ratelimited:
 | 
				
			||||||
            # Pass information about what kind of entity got limited in the exception:
 | 
					            # Pass information about what kind of entity got limited in the exception:
 | 
				
			||||||
            raise RateLimited(str(time))
 | 
					            raise RateLimited(time)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        calls_remaining, seconds_until_reset = self.api_calls_left()
 | 
					        calls_remaining, seconds_until_reset = self.api_calls_left()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -66,10 +66,15 @@ def json_response_from_error(exception: JsonableError) -> HttpResponse:
 | 
				
			|||||||
    middleware takes care of transforming it into a response by
 | 
					    middleware takes care of transforming it into a response by
 | 
				
			||||||
    calling this function.
 | 
					    calling this function.
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
    return json_response('error',
 | 
					    response = json_response('error',
 | 
				
			||||||
                         msg=exception.msg,
 | 
					                             msg=exception.msg,
 | 
				
			||||||
                         data=exception.data,
 | 
					                             data=exception.data,
 | 
				
			||||||
                         status=exception.http_status_code)
 | 
					                             status=exception.http_status_code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for header, value in exception.extra_headers.items():
 | 
				
			||||||
 | 
					        response[header] = value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def json_error(msg: str, data: Mapping[str, Any]={}, status: int=400) -> HttpResponse:
 | 
					def json_error(msg: str, data: Mapping[str, Any]={}, status: int=400) -> HttpResponse:
 | 
				
			||||||
    return json_response(res_type="error", msg=msg, data=data, status=status)
 | 
					    return json_response(res_type="error", msg=msg, data=data, status=status)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -23,7 +23,7 @@ from sentry_sdk.integrations.logging import ignore_logger
 | 
				
			|||||||
from zerver.lib.cache import get_remote_cache_requests, get_remote_cache_time
 | 
					from zerver.lib.cache import get_remote_cache_requests, get_remote_cache_time
 | 
				
			||||||
from zerver.lib.db import reset_queries
 | 
					from zerver.lib.db import reset_queries
 | 
				
			||||||
from zerver.lib.debug import maybe_tracemalloc_listen
 | 
					from zerver.lib.debug import maybe_tracemalloc_listen
 | 
				
			||||||
from zerver.lib.exceptions import ErrorCode, JsonableError, MissingAuthenticationError, RateLimited
 | 
					from zerver.lib.exceptions import ErrorCode, JsonableError, MissingAuthenticationError
 | 
				
			||||||
from zerver.lib.html_to_text import get_content_description
 | 
					from zerver.lib.html_to_text import get_content_description
 | 
				
			||||||
from zerver.lib.markdown import get_markdown_requests, get_markdown_time
 | 
					from zerver.lib.markdown import get_markdown_requests, get_markdown_time
 | 
				
			||||||
from zerver.lib.rate_limiter import RateLimitResult
 | 
					from zerver.lib.rate_limiter import RateLimitResult
 | 
				
			||||||
@@ -408,20 +408,6 @@ class RateLimitMiddleware(MiddlewareMixin):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        return response
 | 
					        return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def process_exception(self, request: HttpRequest,
 | 
					 | 
				
			||||||
                          exception: Exception) -> Optional[HttpResponse]:
 | 
					 | 
				
			||||||
        if isinstance(exception, RateLimited):
 | 
					 | 
				
			||||||
            # secs_to_freedom is passed to RateLimited when raising
 | 
					 | 
				
			||||||
            secs_to_freedom = float(str(exception))
 | 
					 | 
				
			||||||
            resp = json_error(
 | 
					 | 
				
			||||||
                _("API usage exceeded rate limit"),
 | 
					 | 
				
			||||||
                data={'retry-after': secs_to_freedom},
 | 
					 | 
				
			||||||
                status=429,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            resp['Retry-After'] = secs_to_freedom
 | 
					 | 
				
			||||||
            return resp
 | 
					 | 
				
			||||||
        return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class FlushDisplayRecipientCache(MiddlewareMixin):
 | 
					class FlushDisplayRecipientCache(MiddlewareMixin):
 | 
				
			||||||
    def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
 | 
					    def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
 | 
				
			||||||
        # We flush the per-request caches after every request, so they
 | 
					        # We flush the per-request caches after every request, so they
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -98,7 +98,8 @@ def json_change_settings(request: HttpRequest, user_profile: UserProfile,
 | 
				
			|||||||
                                realm=user_profile.realm, return_data=return_data):
 | 
					                                realm=user_profile.realm, return_data=return_data):
 | 
				
			||||||
                return json_error(_("Wrong password!"))
 | 
					                return json_error(_("Wrong password!"))
 | 
				
			||||||
        except RateLimited as e:
 | 
					        except RateLimited as e:
 | 
				
			||||||
            secs_to_freedom = int(float(str(e)))
 | 
					            assert e.secs_to_freedom is not None
 | 
				
			||||||
 | 
					            secs_to_freedom = int(e.secs_to_freedom)
 | 
				
			||||||
            return json_error(
 | 
					            return json_error(
 | 
				
			||||||
                _("You're making too many attempts! Try again in {} seconds.").format(secs_to_freedom),
 | 
					                _("You're making too many attempts! Try again in {} seconds.").format(secs_to_freedom),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user