mirror of
				https://github.com/zulip/zulip.git
				synced 2025-10-31 03:53:50 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			254 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			254 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
 | |
| import os
 | |
| 
 | |
| from typing import Any, Iterator, List, Optional, Tuple, Text
 | |
| 
 | |
| from django.conf import settings
 | |
| from zerver.lib.redis_utils import get_redis_client
 | |
| 
 | |
| from zerver.models import UserProfile
 | |
| 
 | |
| import redis
 | |
| import time
 | |
| import logging
 | |
| 
 | |
| # Implement a rate-limiting scheme inspired by the one described here, but heavily modified
 | |
| # http://blog.domaintools.com/2013/04/rate-limiting-with-redis/
 | |
| 
 | |
# Module-wide redis connection shared by every function below.
client = get_redis_client()
# Default rules as (range_seconds, num_requests) pairs; add_ratelimit_rule
# keeps this list sorted by range_seconds.
rules = settings.RATE_LIMITING_RULES  # type: List[Tuple[int, int]]

# Prefix prepended to every redis key built by RateLimitedObject.get_keys;
# rebound by bounce_redis_key_prefix_for_testing.
KEY_PREFIX = ''
 | |
| 
 | |
class RateLimitedObject:
    """Base class for anything whose API usage is tracked in redis.

    Subclasses must implement ``key_fragment`` and ``rules``.
    """

    def get_keys(self):
        # type: () -> List[Text]
        """Return this object's redis key names, ordered list, zset, block."""
        fragment = self.key_fragment()
        keys = []  # type: List[Text]
        for suffix in ('list', 'zset', 'block'):
            keys.append("{}ratelimit:{}:{}".format(KEY_PREFIX, fragment, suffix))
        return keys

    def key_fragment(self):
        # type: () -> Text
        """Return the string identifying this object inside its redis keys."""
        raise NotImplementedError()

    def rules(self):
        # type: () -> List[Tuple[int, int]]
        """Return the (range_seconds, num_requests) rules for this object."""
        raise NotImplementedError()
 | |
| 
 | |
class RateLimitedUser(RateLimitedObject):
    """Rate-limiting handle for one user within a named domain."""

    def __init__(self, user, domain='all'):
        # type: (UserProfile, Text) -> None
        self.user = user
        self.domain = domain

    def key_fragment(self):
        # type: () -> Text
        """Build the key fragment from the user's type, id and the domain."""
        user = self.user
        return "{}:{}:{}".format(type(user), user.id, self.domain)

    def rules(self):
        # type: () -> List[Tuple[int, int]]
        """Return the user's own limits if set, else the module-wide rules.

        ``user.rate_limits`` is parsed as comma-separated
        ``seconds:requests`` entries; an empty string means "use the
        module-level ``rules``".
        """
        custom_limits = self.user.rate_limits
        if custom_limits == "":
            return rules

        parsed = []  # type: List[Tuple[int, int]]
        for entry in custom_limits.split(','):
            seconds, requests = entry.split(':', 2)
            parsed.append((int(seconds), int(requests)))
        return parsed
 | |
| 
 | |
def bounce_redis_key_prefix_for_testing(test_name):
    # type: (Text) -> None
    """Rebind KEY_PREFIX to '<test_name>:<pid>:' using the current process id."""
    global KEY_PREFIX
    pid_part = str(os.getpid())
    KEY_PREFIX = test_name + ':' + pid_part + ':'
 | |
| 
 | |
def max_api_calls(entity):
    # type: (RateLimitedObject) -> int
    """Return the request count of the last rule in entity.rules()."""
    last_rule = entity.rules()[-1]
    return last_rule[1]
 | |
| 
 | |
def max_api_window(entity):
    # type: (RateLimitedObject) -> int
    """Return the time window (seconds) of the last rule in entity.rules()."""
    last_rule = entity.rules()[-1]
    return last_rule[0]
 | |
| 
 | |
def add_ratelimit_rule(range_seconds, num_requests):
    # type: (int , int) -> None
    """Append a (range_seconds, num_requests) rule to the module-wide rules.

    The list is mutated in place and then re-sorted by window length, so
    the rule with the longest window always ends up last.
    """
    new_rule = (range_seconds, num_requests)
    rules.append(new_rule)
    rules.sort(key=lambda rule: rule[0])
 | |
| 
 | |
def remove_ratelimit_rule(range_seconds, num_requests):
    # type: (int , int) -> None
    """Remove the rule matching both range_seconds and num_requests.

    Rules that share only the window or only the request count with the
    target are kept. The module-level ``rules`` name is rebound to a new
    list rather than mutated in place.
    """
    global rules
    # Keep a rule unless BOTH fields match; the previous `and` filter also
    # dropped unrelated rules that merely shared one of the two values.
    rules = [x for x in rules if x[0] != range_seconds or x[1] != num_requests]
 | |
| 
 | |
def block_access(entity, seconds):
    # type: (RateLimitedObject, int) -> None
    """Manually block an entity for the given number of seconds.

    Sets the entity's blocking key to 1 and gives it a TTL of ``seconds``,
    both queued on one pipeline.
    """
    _list_key, _set_key, blocking_key = entity.get_keys()
    with client.pipeline() as pipe:
        pipe.set(blocking_key, 1)
        pipe.expire(blocking_key, seconds)
        pipe.execute()
 | |
| 
 | |
def unblock_access(entity):
    # type: (RateLimitedObject) -> None
    """Lift a manual block by deleting the entity's blocking key."""
    _list_key, _set_key, blocking_key = entity.get_keys()
    client.delete(blocking_key)
 | |
| 
 | |
def clear_history(entity):
    # type: (RateLimitedObject) -> None
    """Delete every redis key belonging to the entity.

    This is only used by test code now, where giving a user a clean slate
    lets the tests run quickly.
    """
    entity_keys = entity.get_keys()
    for redis_key in entity_keys:
        client.delete(redis_key)
 | |
| 
 | |
def _get_api_calls_left(entity, range_seconds, max_calls):
    # type: (RateLimitedObject, int, int) -> Tuple[int, float]
    """Return (calls remaining, reset timestamp) for one window/limit pair.

    Calls remaining is ``max_calls`` minus the number of sorted-set entries
    scored within the last ``range_seconds``. The reset timestamp is derived
    from the head of the list, or is the current time if the list is empty.
    """
    list_key, set_key, _ = entity.get_keys()
    now = time.time()
    window_start = now - range_seconds

    with client.pipeline() as pipe:
        # How many calls fall between the window start and now
        pipe.zcount(set_key, window_start, now)
        # Head of the list; incr_ratelimit pushes new timestamps there
        pipe.lindex(list_key, 0)
        replies = pipe.execute()

    used = replies[0]  # type: int
    newest_call = replies[1]  # type: Optional[bytes]

    if newest_call is None:
        time_reset = now
    else:
        time_reset = now + (range_seconds - (now - float(newest_call)))

    return max_calls - used, time_reset
 | |
| 
 | |
def api_calls_left(entity):
    # type: (RateLimitedObject) -> Tuple[int, float]
    """Return how many API calls the entity has left under its last rule,
       together with the time at which the rate-limit resets."""
    return _get_api_calls_left(entity, max_api_window(entity), max_api_calls(entity))
 | |
| 
 | |
def is_ratelimited(entity):
    # type: (RateLimitedObject) -> Tuple[bool, float]
    """Return a tuple of (rate_limited, time_till_free).

    An entity with no rules is never reported as limited. Otherwise a
    present blocking key takes precedence, and after that each rule is
    checked against the timestamp of its n-th most recent call.
    """
    list_key, set_key, blocking_key = entity.get_keys()

    entity_rules = entity.rules()
    if len(entity_rules) == 0:
        return False, 0.0

    # One round trip: the n-th timestamp for every rule, followed by the
    # blocking key's value and its TTL.
    with client.pipeline() as pipe:
        for _, request_count in entity_rules:
            pipe.lindex(list_key, request_count - 1)  # 0-indexed list
        pipe.get(blocking_key)
        pipe.ttl(blocking_key)
        replies = pipe.execute()  # type: List[Optional[bytes]]

    # The last two replies are the blocking info; the rest line up with the rules
    rule_timestamps = replies[:-2]
    key_blocked, blocking_ttl_b = replies[-2:]

    if key_blocked is not None:
        # Manually blocked: report the remaining TTL, or 0.5 when redis
        # returned no TTL.
        blocking_ttl = 0.5 if blocking_ttl_b is None else int(blocking_ttl_b)
        return True, blocking_ttl

    now = time.time()
    for nth_timestamp, (range_seconds, num_requests) in zip(rule_timestamps, entity_rules):
        if nth_timestamp is not None:
            # The n-th most recent call is still inside the window, so the
            # limit for this rule has been reached.
            expires_at = float(nth_timestamp) + range_seconds
            if expires_at > now:
                return True, expires_at - now

    # No rule is currently violated
    return False, 0.0
 | |
| 
 | |
def incr_ratelimit(entity):
    # type: (RateLimitedObject) -> None
    """Record one API call for the entity in redis.

    The current timestamp is pushed onto the entity's list and added to its
    sorted set. The list is trimmed to the request count of the entity's
    last rule, and both keys expire after that rule's window. Nothing is
    stored when the entity has no rules.
    """
    list_key, set_key, _ = entity.get_keys()

    # Use the entity's own rules (as is_ratelimited does), not the
    # module-level defaults: an entity may carry custom limits while the
    # defaults are empty, or the other way round.
    entity_rules = entity.rules()
    if len(entity_rules) == 0:
        return

    now = time.time()

    # The last rule supplies the list length and the key TTL (the same
    # values max_api_calls / max_api_window return). They do not change
    # between retries, so look them up once.
    api_window, max_calls = entity_rules[-1][0], entity_rules[-1][1]

    # Start redis transaction
    with client.pipeline() as pipe:
        count = 0
        while True:
            try:
                # To avoid a race between reading the element we might trim
                # from the list and removing it from the sorted set, abort
                # the whole transaction if another agent changes the list
                # out from under us. While watching, the pipeline executes
                # commands immediately.
                pipe.watch(list_key)

                # The element that the trim below will drop, if any
                last_val = pipe.lindex(list_key, max_calls - 1)

                # Restart buffered execution
                pipe.multi()

                # Add this timestamp to our list
                pipe.lpush(list_key, now)

                # Trim the list to the request count of the last rule
                pipe.ltrim(list_key, 0, max_calls - 1)

                # Score and value are both the timestamp: we sort by score
                # but remove by value
                pipe.zadd(set_key, now, now)

                # Remove the trimmed value from our sorted set, if there was one
                if last_val is not None:
                    pipe.zrem(set_key, last_val)

                # Set the TTL for our keys as well
                pipe.expire(list_key, api_window)
                pipe.expire(set_key, api_window)

                pipe.execute()

                # No exception means there were no transaction conflicts
                break
            except redis.WatchError:
                # Give up once the retry counter exceeds 10
                if count > 10:
                    logging.error("Failed to complete incr_ratelimit transaction without "
                                  "interference 10 times in a row! Aborting rate-limit increment")
                    break
                count += 1

                continue
 |