mirror of
				https://github.com/zulip/zulip.git
				synced 2025-10-30 19:43:47 +00:00 
			
		
		
		
	rate_limiter: Create a general rate_limit_request_by_entity function.
This commit is contained in:
		
				
					committed by
					
						 Tim Abbott
						Tim Abbott
					
				
			
			
				
	
			
			
			
						parent
						
							22d0cd9696
						
					
				
				
					commit
					f73600c82c
				
			| @@ -22,13 +22,12 @@ from zerver.lib.queue import queue_json_publish | |||||||
| from zerver.lib.subdomains import get_subdomain, user_matches_subdomain | from zerver.lib.subdomains import get_subdomain, user_matches_subdomain | ||||||
| from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime | from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime | ||||||
| from zerver.lib.utils import statsd, is_remote_server | from zerver.lib.utils import statsd, is_remote_server | ||||||
| from zerver.lib.exceptions import RateLimited, JsonableError, ErrorCode, \ | from zerver.lib.exceptions import JsonableError, ErrorCode, \ | ||||||
|     InvalidJSONError, InvalidAPIKeyError |     InvalidJSONError, InvalidAPIKeyError | ||||||
| from zerver.lib.types import ViewFuncT | from zerver.lib.types import ViewFuncT | ||||||
| from zerver.lib.validator import to_non_negative_int | from zerver.lib.validator import to_non_negative_int | ||||||
|  |  | ||||||
| from zerver.lib.rate_limiter import rate_limit_entity, \ | from zerver.lib.rate_limiter import rate_limit_request_by_entity, RateLimitedUser | ||||||
|     api_calls_left, RateLimitedUser |  | ||||||
| from zerver.lib.request import REQ, has_request_variables | from zerver.lib.request import REQ, has_request_variables | ||||||
|  |  | ||||||
| from functools import wraps | from functools import wraps | ||||||
| @@ -746,18 +745,7 @@ def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> Non | |||||||
|     the rate limit information""" |     the rate limit information""" | ||||||
|  |  | ||||||
|     entity = RateLimitedUser(user, domain=domain) |     entity = RateLimitedUser(user, domain=domain) | ||||||
|     ratelimited, time = rate_limit_entity(entity) |     rate_limit_request_by_entity(request, entity) | ||||||
|     request._ratelimit_applied_limits = True |  | ||||||
|     request._ratelimit_secs_to_freedom = time |  | ||||||
|     request._ratelimit_over_limit = ratelimited |  | ||||||
|     # Abort this request if the user is over their rate limits |  | ||||||
|     if ratelimited: |  | ||||||
|         raise RateLimited() |  | ||||||
|  |  | ||||||
|     calls_remaining, time_reset = api_calls_left(entity) |  | ||||||
|  |  | ||||||
|     request._ratelimit_remaining = calls_remaining |  | ||||||
|     request._ratelimit_secs_to_freedom = time_reset |  | ||||||
|  |  | ||||||
| def rate_limit(domain: str='all') -> Callable[[ViewFuncT], ViewFuncT]: | def rate_limit(domain: str='all') -> Callable[[ViewFuncT], ViewFuncT]: | ||||||
|     """Rate-limits a view. Takes an optional 'domain' param if you wish to |     """Rate-limits a view. Takes an optional 'domain' param if you wish to | ||||||
|   | |||||||
| @@ -4,6 +4,8 @@ import os | |||||||
| from typing import List, Optional, Tuple | from typing import List, Optional, Tuple | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
|  | from django.http import HttpRequest | ||||||
|  | from zerver.lib.exceptions import RateLimited | ||||||
| from zerver.lib.redis_utils import get_redis_client | from zerver.lib.redis_utils import get_redis_client | ||||||
| from zerver.lib.utils import statsd | from zerver.lib.utils import statsd | ||||||
|  |  | ||||||
| @@ -261,3 +263,23 @@ def rate_limit_entity(entity: RateLimitedObject) -> Tuple[bool, float]: | |||||||
|             ratelimited = True |             ratelimited = True | ||||||
|  |  | ||||||
|     return ratelimited, time |     return ratelimited, time | ||||||
|  |  | ||||||
|  | def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject) -> None: | ||||||
|  |     ratelimited, time = rate_limit_entity(entity) | ||||||
|  |  | ||||||
|  |     entity_type = type(entity).__name__ | ||||||
|  |     if not hasattr(request, '_ratelimit'): | ||||||
|  |         request._ratelimit = {} | ||||||
|  |     request._ratelimit[entity_type] = {} | ||||||
|  |     request._ratelimit[entity_type]['applied_limits'] = True | ||||||
|  |     request._ratelimit[entity_type]['secs_to_freedom'] = time | ||||||
|  |     request._ratelimit[entity_type]['over_limit'] = ratelimited | ||||||
|  |     # Abort this request if the user is over their rate limits | ||||||
|  |     if ratelimited: | ||||||
|  |         # Pass information about what kind of entity got limited in the exception: | ||||||
|  |         raise RateLimited(entity_type) | ||||||
|  |  | ||||||
|  |     calls_remaining, time_reset = api_calls_left(entity) | ||||||
|  |  | ||||||
|  |     request._ratelimit[entity_type]['remaining'] = calls_remaining | ||||||
|  |     request._ratelimit[entity_type]['secs_to_freedom'] = time_reset | ||||||
|   | |||||||
| @@ -326,23 +326,25 @@ class RateLimitMiddleware(MiddlewareMixin): | |||||||
|  |  | ||||||
|         from zerver.lib.rate_limiter import max_api_calls, RateLimitedUser |         from zerver.lib.rate_limiter import max_api_calls, RateLimitedUser | ||||||
|         # Add X-RateLimit-*** headers |         # Add X-RateLimit-*** headers | ||||||
|         if hasattr(request, '_ratelimit_applied_limits'): |         if hasattr(request, '_ratelimit'): | ||||||
|  |             # Right now, the only kind of limiting requests is user-based. | ||||||
|  |             ratelimit_user_results = request._ratelimit['RateLimitedUser'] | ||||||
|             entity = RateLimitedUser(request.user) |             entity = RateLimitedUser(request.user) | ||||||
|             response['X-RateLimit-Limit'] = str(max_api_calls(entity)) |             response['X-RateLimit-Limit'] = str(max_api_calls(entity)) | ||||||
|             if hasattr(request, '_ratelimit_secs_to_freedom'): |             response['X-RateLimit-Reset'] = str(int(time.time() + ratelimit_user_results['secs_to_freedom'])) | ||||||
|                 response['X-RateLimit-Reset'] = str(int(time.time() + request._ratelimit_secs_to_freedom)) |             if 'remaining' in ratelimit_user_results: | ||||||
|             if hasattr(request, '_ratelimit_remaining'): |                 response['X-RateLimit-Remaining'] = str(ratelimit_user_results['remaining']) | ||||||
|                 response['X-RateLimit-Remaining'] = str(request._ratelimit_remaining) |  | ||||||
|         return response |         return response | ||||||
|  |  | ||||||
|     def process_exception(self, request: HttpRequest, exception: Exception) -> Optional[HttpResponse]: |     def process_exception(self, request: HttpRequest, exception: Exception) -> Optional[HttpResponse]: | ||||||
|         if isinstance(exception, RateLimited): |         if isinstance(exception, RateLimited): | ||||||
|  |             entity_type = str(exception)  # entity type is passed to RateLimited when raising | ||||||
|             resp = json_error( |             resp = json_error( | ||||||
|                 _("API usage exceeded rate limit"), |                 _("API usage exceeded rate limit"), | ||||||
|                 data={'retry-after': request._ratelimit_secs_to_freedom}, |                 data={'retry-after': request._ratelimit[entity_type]['secs_to_freedom']}, | ||||||
|                 status=429 |                 status=429 | ||||||
|             ) |             ) | ||||||
|             resp['Retry-After'] = request._ratelimit_secs_to_freedom |             resp['Retry-After'] = request._ratelimit[entity_type]['secs_to_freedom'] | ||||||
|             return resp |             return resp | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user