rate_limiter: Store data in request._ratelimits_applied list.

The information used to be stored in a request._ratelimit dict, but
there's no need for that, and a list is a simpler structure, so this
allows us to simplify the plumbing somewhat.
This commit is contained in:
Mateusz Mandera
2020-04-01 13:31:20 +02:00
committed by Tim Abbott
parent 9911c6a0f0
commit e86cfbdbd7
3 changed files with 13 additions and 11 deletions

View File

@@ -41,14 +41,13 @@ class RateLimitedObject(ABC):
def rate_limit_request(self, request: HttpRequest) -> None: def rate_limit_request(self, request: HttpRequest) -> None:
ratelimited, time = self.rate_limit() ratelimited, time = self.rate_limit()
entity_type = type(self).__name__ if not hasattr(request, '_ratelimits_applied'):
if not hasattr(request, '_ratelimit'): request._ratelimits_applied = []
request._ratelimit = {} request._ratelimits_applied.append(RateLimitResult(
request._ratelimit[entity_type] = RateLimitResult(
entity=self, entity=self,
secs_to_freedom=time, secs_to_freedom=time,
over_limit=ratelimited over_limit=ratelimited
) ))
# Abort this request if the user is over their rate limits # Abort this request if the user is over their rate limits
if ratelimited: if ratelimited:
# Pass information about what kind of entity got limited in the exception: # Pass information about what kind of entity got limited in the exception:
@@ -56,8 +55,8 @@ class RateLimitedObject(ABC):
calls_remaining, time_reset = self.api_calls_left() calls_remaining, time_reset = self.api_calls_left()
request._ratelimit[entity_type].remaining = calls_remaining request._ratelimits_applied[-1].remaining = calls_remaining
request._ratelimit[entity_type].secs_to_freedom = time_reset request._ratelimits_applied[-1].secs_to_freedom = time_reset
def block_access(self, seconds: int) -> None: def block_access(self, seconds: int) -> None:
"Manually blocks an entity for the desired number of seconds" "Manually blocks an entity for the desired number of seconds"

View File

@@ -370,9 +370,8 @@ class RateLimitMiddleware(MiddlewareMixin):
return response return response
# Add X-RateLimit-*** headers # Add X-RateLimit-*** headers
if hasattr(request, '_ratelimit'): if hasattr(request, '_ratelimits_applied'):
rate_limit_results = list(request._ratelimit.values()) self.set_response_headers(response, request._ratelimits_applied)
self.set_response_headers(response, rate_limit_results)
return response return response

View File

@@ -192,7 +192,11 @@ def rate_limit_authentication_by_username(request: HttpRequest, username: str) -
RateLimitedAuthenticationByUsername(username).rate_limit_request(request) RateLimitedAuthenticationByUsername(username).rate_limit_request(request)
def auth_rate_limiting_already_applied(request: HttpRequest) -> bool: def auth_rate_limiting_already_applied(request: HttpRequest) -> bool:
return hasattr(request, '_ratelimit') and 'RateLimitedAuthenticationByUsername' in request._ratelimit if not hasattr(request, '_ratelimits_applied'):
return False
return any(isinstance(r.entity, RateLimitedAuthenticationByUsername)
for r in request._ratelimits_applied)
# Django's authentication mechanism uses introspection on the various authenticate() functions # Django's authentication mechanism uses introspection on the various authenticate() functions
# defined by backends, so we need a decorator that doesn't break function signatures. # defined by backends, so we need a decorator that doesn't break function signatures.