diff --git a/zerver/decorator.py b/zerver/decorator.py index c2f7aaa940..4dc97f9243 100644 --- a/zerver/decorator.py +++ b/zerver/decorator.py @@ -22,13 +22,12 @@ from zerver.lib.queue import queue_json_publish from zerver.lib.subdomains import get_subdomain, user_matches_subdomain from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.utils import statsd, is_remote_server -from zerver.lib.exceptions import RateLimited, JsonableError, ErrorCode, \ +from zerver.lib.exceptions import JsonableError, ErrorCode, \ InvalidJSONError, InvalidAPIKeyError from zerver.lib.types import ViewFuncT from zerver.lib.validator import to_non_negative_int -from zerver.lib.rate_limiter import rate_limit_entity, \ - api_calls_left, RateLimitedUser +from zerver.lib.rate_limiter import rate_limit_request_by_entity, RateLimitedUser from zerver.lib.request import REQ, has_request_variables from functools import wraps @@ -746,18 +745,7 @@ def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> Non the rate limit information""" entity = RateLimitedUser(user, domain=domain) - ratelimited, time = rate_limit_entity(entity) - request._ratelimit_applied_limits = True - request._ratelimit_secs_to_freedom = time - request._ratelimit_over_limit = ratelimited - # Abort this request if the user is over their rate limits - if ratelimited: - raise RateLimited() - - calls_remaining, time_reset = api_calls_left(entity) - - request._ratelimit_remaining = calls_remaining - request._ratelimit_secs_to_freedom = time_reset + rate_limit_request_by_entity(request, entity) def rate_limit(domain: str='all') -> Callable[[ViewFuncT], ViewFuncT]: """Rate-limits a view. Takes an optional 'domain' param if you wish to diff --git a/zerver/lib/rate_limiter.py b/zerver/lib/rate_limiter.py index 134d0ef4ca..32bfe73646 100644 --- a/zerver/lib/rate_limiter.py +++ b/zerver/lib/rate_limiter.py @@ -4,6 +4,8 @@ import os from typing import List, Optional, Tuple from django.conf import settings +from django.http import HttpRequest +from zerver.lib.exceptions import RateLimited from zerver.lib.redis_utils import get_redis_client from zerver.lib.utils import statsd @@ -261,3 +263,23 @@ def rate_limit_entity(entity: RateLimitedObject) -> Tuple[bool, float]: ratelimited = True return ratelimited, time + +def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject) -> None: + ratelimited, time = rate_limit_entity(entity) + + entity_type = type(entity).__name__ + if not hasattr(request, '_ratelimit'): + request._ratelimit = {} + request._ratelimit[entity_type] = {} + request._ratelimit[entity_type]['applied_limits'] = True + request._ratelimit[entity_type]['secs_to_freedom'] = time + request._ratelimit[entity_type]['over_limit'] = ratelimited + # Abort this request if the user is over their rate limits + if ratelimited: + # Pass information about what kind of entity got limited in the exception: + raise RateLimited(entity_type) + + calls_remaining, time_reset = api_calls_left(entity) + + request._ratelimit[entity_type]['remaining'] = calls_remaining + request._ratelimit[entity_type]['secs_to_freedom'] = time_reset diff --git a/zerver/middleware.py b/zerver/middleware.py index 3eb1de29f3..8dfe0fe4eb 100644 --- a/zerver/middleware.py +++ b/zerver/middleware.py @@ -326,23 +326,25 @@ class RateLimitMiddleware(MiddlewareMixin): from zerver.lib.rate_limiter import max_api_calls, RateLimitedUser # Add X-RateLimit-*** headers - if hasattr(request, '_ratelimit_applied_limits'): + if hasattr(request, '_ratelimit'): + # Right now, the only kind of limiting requests is user-based. + ratelimit_user_results = request._ratelimit['RateLimitedUser'] entity = RateLimitedUser(request.user) response['X-RateLimit-Limit'] = str(max_api_calls(entity)) - if hasattr(request, '_ratelimit_secs_to_freedom'): - response['X-RateLimit-Reset'] = str(int(time.time() + request._ratelimit_secs_to_freedom)) - if hasattr(request, '_ratelimit_remaining'): - response['X-RateLimit-Remaining'] = str(request._ratelimit_remaining) + response['X-RateLimit-Reset'] = str(int(time.time() + ratelimit_user_results['secs_to_freedom'])) + if 'remaining' in ratelimit_user_results: + response['X-RateLimit-Remaining'] = str(ratelimit_user_results['remaining']) return response def process_exception(self, request: HttpRequest, exception: Exception) -> Optional[HttpResponse]: if isinstance(exception, RateLimited): + entity_type = str(exception) # entity type is passed to RateLimited when raising resp = json_error( _("API usage exceeded rate limit"), - data={'retry-after': request._ratelimit_secs_to_freedom}, + data={'retry-after': request._ratelimit[entity_type]['secs_to_freedom']}, status=429 ) - resp['Retry-After'] = request._ratelimit_secs_to_freedom + resp['Retry-After'] = request._ratelimit[entity_type]['secs_to_freedom'] return resp return None