rate_limiter: Create a general rate_limit_request_by_entity function.

This commit is contained in:
Mateusz Mandera
2019-04-01 20:11:56 +02:00
committed by Tim Abbott
parent 22d0cd9696
commit f73600c82c
3 changed files with 34 additions and 22 deletions

View File

@@ -22,13 +22,12 @@ from zerver.lib.queue import queue_json_publish
from zerver.lib.subdomains import get_subdomain, user_matches_subdomain from zerver.lib.subdomains import get_subdomain, user_matches_subdomain
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import statsd, is_remote_server from zerver.lib.utils import statsd, is_remote_server
from zerver.lib.exceptions import RateLimited, JsonableError, ErrorCode, \ from zerver.lib.exceptions import JsonableError, ErrorCode, \
InvalidJSONError, InvalidAPIKeyError InvalidJSONError, InvalidAPIKeyError
from zerver.lib.types import ViewFuncT from zerver.lib.types import ViewFuncT
from zerver.lib.validator import to_non_negative_int from zerver.lib.validator import to_non_negative_int
from zerver.lib.rate_limiter import rate_limit_entity, \ from zerver.lib.rate_limiter import rate_limit_request_by_entity, RateLimitedUser
api_calls_left, RateLimitedUser
from zerver.lib.request import REQ, has_request_variables from zerver.lib.request import REQ, has_request_variables
from functools import wraps from functools import wraps
@@ -746,18 +745,7 @@ def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> Non
the rate limit information""" the rate limit information"""
entity = RateLimitedUser(user, domain=domain) entity = RateLimitedUser(user, domain=domain)
ratelimited, time = rate_limit_entity(entity) rate_limit_request_by_entity(request, entity)
request._ratelimit_applied_limits = True
request._ratelimit_secs_to_freedom = time
request._ratelimit_over_limit = ratelimited
# Abort this request if the user is over their rate limits
if ratelimited:
raise RateLimited()
calls_remaining, time_reset = api_calls_left(entity)
request._ratelimit_remaining = calls_remaining
request._ratelimit_secs_to_freedom = time_reset
def rate_limit(domain: str='all') -> Callable[[ViewFuncT], ViewFuncT]: def rate_limit(domain: str='all') -> Callable[[ViewFuncT], ViewFuncT]:
"""Rate-limits a view. Takes an optional 'domain' param if you wish to """Rate-limits a view. Takes an optional 'domain' param if you wish to

View File

@@ -4,6 +4,8 @@ import os
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from django.conf import settings from django.conf import settings
from django.http import HttpRequest
from zerver.lib.exceptions import RateLimited
from zerver.lib.redis_utils import get_redis_client from zerver.lib.redis_utils import get_redis_client
from zerver.lib.utils import statsd from zerver.lib.utils import statsd
@@ -261,3 +263,23 @@ def rate_limit_entity(entity: RateLimitedObject) -> Tuple[bool, float]:
ratelimited = True ratelimited = True
return ratelimited, time return ratelimited, time
def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject) -> None:
ratelimited, time = rate_limit_entity(entity)
entity_type = type(entity).__name__
if not hasattr(request, '_ratelimit'):
request._ratelimit = {}
request._ratelimit[entity_type] = {}
request._ratelimit[entity_type]['applied_limits'] = True
request._ratelimit[entity_type]['secs_to_freedom'] = time
request._ratelimit[entity_type]['over_limit'] = ratelimited
# Abort this request if the user is over their rate limits
if ratelimited:
# Pass information about what kind of entity got limited in the exception:
raise RateLimited(entity_type)
calls_remaining, time_reset = api_calls_left(entity)
request._ratelimit[entity_type]['remaining'] = calls_remaining
request._ratelimit[entity_type]['secs_to_freedom'] = time_reset

View File

@@ -326,23 +326,25 @@ class RateLimitMiddleware(MiddlewareMixin):
from zerver.lib.rate_limiter import max_api_calls, RateLimitedUser from zerver.lib.rate_limiter import max_api_calls, RateLimitedUser
# Add X-RateLimit-*** headers # Add X-RateLimit-*** headers
if hasattr(request, '_ratelimit_applied_limits'): if hasattr(request, '_ratelimit'):
# Right now, the only kind of limiting requests is user-based.
ratelimit_user_results = request._ratelimit['RateLimitedUser']
entity = RateLimitedUser(request.user) entity = RateLimitedUser(request.user)
response['X-RateLimit-Limit'] = str(max_api_calls(entity)) response['X-RateLimit-Limit'] = str(max_api_calls(entity))
if hasattr(request, '_ratelimit_secs_to_freedom'): response['X-RateLimit-Reset'] = str(int(time.time() + ratelimit_user_results['secs_to_freedom']))
response['X-RateLimit-Reset'] = str(int(time.time() + request._ratelimit_secs_to_freedom)) if 'remaining' in ratelimit_user_results:
if hasattr(request, '_ratelimit_remaining'): response['X-RateLimit-Remaining'] = str(ratelimit_user_results['remaining'])
response['X-RateLimit-Remaining'] = str(request._ratelimit_remaining)
return response return response
def process_exception(self, request: HttpRequest, exception: Exception) -> Optional[HttpResponse]: def process_exception(self, request: HttpRequest, exception: Exception) -> Optional[HttpResponse]:
if isinstance(exception, RateLimited): if isinstance(exception, RateLimited):
entity_type = str(exception) # entity type is passed to RateLimited when raising
resp = json_error( resp = json_error(
_("API usage exceeded rate limit"), _("API usage exceeded rate limit"),
data={'retry-after': request._ratelimit_secs_to_freedom}, data={'retry-after': request._ratelimit[entity_type]['secs_to_freedom']},
status=429 status=429
) )
resp['Retry-After'] = request._ratelimit_secs_to_freedom resp['Retry-After'] = request._ratelimit[entity_type]['secs_to_freedom']
return resp return resp
return None return None