mirror of
https://github.com/zulip/zulip.git
synced 2025-11-04 14:03:30 +00:00
type().__name__ is sufficient, and much more readable than type(), so it's better to use the former for keys. We also make the classes consistent in forming the keys in the format type(self).__name__:identifier and adjust logger.warning and statsd to take advantage of that and simply log the key().
370 lines
13 KiB
Python
370 lines
13 KiB
Python
import os
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
from django.conf import settings
|
|
from django.http import HttpRequest
|
|
from zerver.lib.exceptions import RateLimited
|
|
from zerver.lib.redis_utils import get_redis_client
|
|
from zerver.lib.utils import statsd
|
|
|
|
from zerver.models import UserProfile
|
|
|
|
import logging
|
|
import redis
|
|
import time
|
|
|
|
# Implement a rate-limiting scheme inspired by the one described here, but heavily modified
# http://blog.domaintools.com/2013/04/rate-limiting-with-redis/

logger = logging.getLogger(__name__)

# Shared redis connection used by RedisRateLimiterBackend.
client = get_redis_client()

# Rate-limiting rules per domain: domain name -> list of
# (window_seconds, max_requests) tuples.  Modified by
# add_ratelimit_rule / remove_ratelimit_rule.
rules = settings.RATE_LIMITING_RULES  # type: Dict[str, List[Tuple[int, int]]]

# Prepended to every redis key built by RateLimitedObject.get_keys();
# bounce_redis_key_prefix_for_testing() rewrites it.
KEY_PREFIX = ''
|
|
|
|
class RateLimiterLockingException(Exception):
    """Raised when recording a request keeps losing redis WATCH races and
    the backend gives up retrying the transaction."""
|
|
|
|
class RateLimitedObject(ABC):
    """Something whose request rate is limited (e.g. a user within a domain).

    Subclasses provide key(), a string identifying the entity, and rules(),
    its list of (window_seconds, max_requests) limits.  The bookkeeping
    itself is delegated to self.backend.
    """

    def __init__(self) -> None:
        self.backend = RedisRateLimiterBackend

    def get_keys(self) -> List[str]:
        """Return the [list, zset, block] redis keys for this entity,
        namespaced by KEY_PREFIX."""
        key = self.key()
        return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key, keytype)
                for keytype in ['list', 'zset', 'block']]

    def rate_limit(self) -> Tuple[bool, float]:
        """Ask the backend whether this entity is limited.

        Returns (ratelimited, secs_to_freedom).
        """
        return self.backend.rate_limit_entity(self)

    def rate_limit_request(self, request: HttpRequest) -> None:
        """Apply this entity's limits to `request`.

        Stores a RateLimitResult in request._ratelimit, keyed by the entity's
        class name, for later inspection.

        Raises:
            RateLimited: if the entity is over its limit.
        """
        # Named so it does not shadow the module-level `time` import.
        ratelimited, secs_to_freedom = self.rate_limit()

        entity_type = type(self).__name__
        if not hasattr(request, '_ratelimit'):
            request._ratelimit = {}
        request._ratelimit[entity_type] = RateLimitResult(
            entity=self,
            secs_to_freedom=secs_to_freedom,
            over_limit=ratelimited
        )
        # Abort this request if the user is over their rate limits
        if ratelimited:
            # Pass information about what kind of entity got limited in the exception:
            raise RateLimited(entity_type)

        calls_remaining, time_reset = self.api_calls_left()

        request._ratelimit[entity_type].remaining = calls_remaining
        # NOTE(review): with the Redis backend, time_reset is an absolute
        # time.time() timestamp rather than a duration, despite the
        # attribute name; confirm consumers expect that.
        request._ratelimit[entity_type].secs_to_freedom = time_reset

    def block_access(self, seconds: int) -> None:
        "Manually blocks an entity for the desired number of seconds"
        self.backend.block_access(self, seconds)

    def unblock_access(self) -> None:
        """Lift a manual block placed by block_access()."""
        self.backend.unblock_access(self)

    def clear_history(self) -> None:
        '''
        This is only used by test code now, where it's very helpful in
        allowing us to run tests quickly, by giving a user a clean slate.
        '''
        self.backend.clear_history(self)

    def max_api_calls(self) -> int:
        """Return the request count of the last (longest-window) rule.

        Assumes rules() is ordered by window length, shortest first.
        """
        return self.rules()[-1][1]

    def max_api_window(self) -> int:
        """Return the window, in seconds, of the last (longest-window) rule.

        Assumes rules() is ordered by window length, shortest first.
        """
        return self.rules()[-1][0]

    def api_calls_left(self) -> Tuple[int, float]:
        """Returns how many API calls in this range this client has, as well as when
        the rate-limit will be reset to 0"""
        max_window = self.max_api_window()
        max_calls = self.max_api_calls()
        return self.backend.get_api_calls_left(self, max_window, max_calls)

    @abstractmethod
    def key(self) -> str:
        """Return a string uniquely identifying this entity in redis keys."""

    @abstractmethod
    def rules(self) -> List[Tuple[int, int]]:
        """Return this entity's (window_seconds, max_requests) rules."""

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable description of this entity."""
|
|
|
|
class RateLimitedUser(RateLimitedObject):
    """Rate-limited entity for a single user within a named domain."""

    def __init__(self, user: UserProfile, domain: str='api_by_user') -> None:
        self.user = user
        self.domain = domain
        super().__init__()

    def __str__(self) -> str:
        return "Id: {}".format(self.user.id)

    def key(self) -> str:
        # Format: <class name>:<user id>:<domain>
        return "{}:{}:{}".format(type(self).__name__, self.user.id, self.domain)

    def rules(self) -> List[Tuple[int, int]]:
        """Return this user's (window_seconds, max_requests) rules.

        A non-empty user.rate_limits string of the form
        "secs:reqs,secs:reqs,..." overrides the per-domain rules, but only for
        the 'api_by_user' domain.  Overrides are returned shortest window
        first.
        """
        # user.rate_limits are general limits, applicable to the domain 'api_by_user'
        if self.user.rate_limits != "" and self.domain == 'api_by_user':
            result = []  # type: List[Tuple[int, int]]
            for limit in self.user.rate_limits.split(','):
                (seconds, requests) = limit.split(':', 1)
                result.append((int(seconds), int(requests)))
            # Match the ordering add_ratelimit_rule keeps for the per-domain
            # rules.  max_api_calls()/max_api_window() read the last entry, and
            # the backend trims stored history to max_api_calls(), so an
            # unsorted override would silently disable its longer-window limits.
            result.sort(key=lambda rule: rule[0])
            return result
        return rules[self.domain]
|
|
|
|
def bounce_redis_key_prefix_for_testing(test_name: str) -> None:
    """Set KEY_PREFIX to "<test_name>:<pid>:".

    RateLimitedObject.get_keys() prepends KEY_PREFIX to every redis key, so
    this namespaces all subsequent rate-limit keys by test name and process.
    """
    global KEY_PREFIX
    KEY_PREFIX = ':'.join((test_name, str(os.getpid()), ''))
|
|
|
|
def add_ratelimit_rule(range_seconds: int, num_requests: int, domain: str='api_by_user') -> None:
    """Register a (range_seconds, num_requests) rule for `domain`.

    The domain's rule list is created on first use and kept sorted by
    window length, shortest first.
    """
    global rules

    domain_rules = rules.setdefault(domain, [])
    domain_rules.append((range_seconds, num_requests))
    domain_rules.sort(key=lambda rule: rule[0])
|
|
|
|
def remove_ratelimit_rule(range_seconds: int, num_requests: int, domain: str='api_by_user') -> None:
    """Remove the (range_seconds, num_requests) rule from `domain`.

    Only rules matching *both* values are dropped.  Rules that merely share
    the window length or the request count are kept.

    Raises:
        KeyError: if `domain` has no rules registered.
    """
    global rules
    # Keep a rule unless it matches on both fields.  The previous `and`
    # condition also discarded any rule sharing just one of the two values.
    rules[domain] = [rule for rule in rules[domain]
                     if rule[0] != range_seconds or rule[1] != num_requests]
|
|
|
|
class RateLimiterBackend(ABC):
    """Abstract storage interface for rate-limit bookkeeping.

    RateLimitedObject keeps a reference to the backend *class* and calls
    these classmethods directly.
    """

    @classmethod
    @abstractmethod
    def block_access(cls, entity: RateLimitedObject, seconds: int) -> None:
        """Manually block `entity` for `seconds` seconds."""

    @classmethod
    @abstractmethod
    def unblock_access(cls, entity: RateLimitedObject) -> None:
        """Remove any manual block on `entity`."""

    @classmethod
    @abstractmethod
    def clear_history(cls, entity: RateLimitedObject) -> None:
        """Forget everything recorded for `entity`.

        Only used by test code, where it gives each test a clean slate.
        """

    @classmethod
    @abstractmethod
    def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int,
                           max_calls: int) -> Tuple[int, float]:
        """Return (calls_left, reset_time) for a window of `range_seconds`."""

    @classmethod
    @abstractmethod
    def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
        """Check `entity` against its rules; return (ratelimited, secs_to_freedom)."""
|
|
|
|
class RedisRateLimiterBackend(RateLimiterBackend):
    """Rate-limiter backend that stores its bookkeeping in redis.

    Every entity owns three keys (built by RateLimitedObject.get_keys):
      * list  -- request timestamps, newest first, trimmed to max_api_calls()
      * zset  -- the same timestamps scored by time, used for range counts
      * block -- set only while the entity is manually blocked
    """

    @classmethod
    def block_access(cls, entity: RateLimitedObject, seconds: int) -> None:
        "Manually blocks an entity for the desired number of seconds"
        _, _, blocking_key = entity.get_keys()
        with client.pipeline() as pipe:
            pipe.set(blocking_key, 1)
            pipe.expire(blocking_key, seconds)
            pipe.execute()

    @classmethod
    def unblock_access(cls, entity: RateLimitedObject) -> None:
        """Lift a manual block placed by block_access()."""
        _, _, blocking_key = entity.get_keys()
        client.delete(blocking_key)

    @classmethod
    def clear_history(cls, entity: RateLimitedObject) -> None:
        '''
        This is only used by test code now, where it's very helpful in
        allowing us to run tests quickly, by giving a user a clean slate.
        '''
        for key in entity.get_keys():
            client.delete(key)

    @classmethod
    def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int,
                           max_calls: int) -> Tuple[int, float]:
        """Return (calls_left, time_reset) for a window of `range_seconds`.

        time_reset is an absolute time.time() timestamp: the moment the newest
        recorded call leaves the window, or now if no call is recorded.
        """
        list_key, set_key, _ = entity.get_keys()
        # Count the number of values in our sorted set
        # that are between now and the cutoff
        now = time.time()
        boundary = now - range_seconds

        with client.pipeline() as pipe:
            # Count how many API calls in our range have already been made
            pipe.zcount(set_key, boundary, now)
            # Get the newest call so we can calculate when the ratelimit
            # will reset to 0
            pipe.lindex(list_key, 0)

            results = pipe.execute()

        count = results[0]  # type: int
        newest_call = results[1]  # type: Optional[bytes]

        calls_left = max_calls - count
        if newest_call is not None:
            time_reset = now + (range_seconds - (now - float(newest_call)))
        else:
            time_reset = now

        return calls_left, time_reset

    @classmethod
    def is_ratelimited(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
        "Returns a tuple of (rate_limited, time_till_free)"
        list_key, _, blocking_key = entity.get_keys()

        # Local name chosen so it does not shadow the module-level `rules`.
        entity_rules = entity.rules()

        if len(entity_rules) == 0:
            return False, 0.0

        # Go through the rules from shortest to longest,
        # seeing if this user has violated any of them. First
        # get the timestamps for each nth items
        with client.pipeline() as pipe:
            for _, request_count in entity_rules:
                pipe.lindex(list_key, request_count - 1)  # 0-indexed list

            # Get blocking info
            pipe.get(blocking_key)
            pipe.ttl(blocking_key)

            rule_timestamps = pipe.execute()  # type: List[Optional[bytes]]

        # Check if there is a manual block on this API key
        blocking_ttl_b = rule_timestamps.pop()
        key_blocked = rule_timestamps.pop()

        if key_blocked is not None:
            # We are manually blocked. Report for how much longer we will be
            if blocking_ttl_b is None:
                # NOTE(review): presumably the block key has no usable TTL
                # here; confirm how the redis client reports that case.
                blocking_ttl = 0.5
            else:
                blocking_ttl = int(blocking_ttl_b)
            return True, blocking_ttl

        now = time.time()
        for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, entity_rules):
            # Check if the nth timestamp is newer than the associated rule. If so,
            # it means we've hit our limit for this rule
            if timestamp is None:
                continue

            boundary = float(timestamp) + range_seconds
            if boundary > now:
                free = boundary - now
                return True, free

        # No rule is currently exceeded
        return False, 0.0

    @classmethod
    def incr_ratelimit(cls, entity: RateLimitedObject) -> None:
        """Increases the rate-limit for the specified entity

        Raises:
            RateLimiterLockingException: if the redis transaction keeps being
                invalidated by concurrent writers.
        """
        list_key, set_key, _ = entity.get_keys()
        now = time.time()

        # If we have no rules, we don't store anything.  This must look at
        # the entity's own rules: the module-level `rules` dict says nothing
        # about this entity, and max_api_calls() below raises IndexError when
        # the entity has no rules.
        if not entity.rules():
            return

        max_calls = entity.max_api_calls()
        api_window = entity.max_api_window()

        # Start redis transaction
        with client.pipeline() as pipe:
            count = 0
            while True:
                try:
                    # To avoid a race condition between getting the element we might trim from our list
                    # and removing it from our associated set, we abort this whole transaction if
                    # another agent manages to change our list out from under us
                    # When watching a value, the pipeline is set to Immediate mode
                    pipe.watch(list_key)

                    # Get the last elem that we'll trim (so we can remove it from our sorted set)
                    last_val = pipe.lindex(list_key, max_calls - 1)

                    # Restart buffered execution
                    pipe.multi()

                    # Add this timestamp to our list
                    pipe.lpush(list_key, now)

                    # Trim our list to the oldest rule we have
                    pipe.ltrim(list_key, 0, max_calls - 1)

                    # Add our new value to the sorted set that we keep
                    # We need to put the score and val both as timestamp,
                    # as we sort by score but remove by value
                    pipe.zadd(set_key, {str(now): now})

                    # Remove the trimmed value from our sorted set, if there was one
                    if last_val is not None:
                        pipe.zrem(set_key, last_val)

                    # Set the TTL for our keys as well
                    pipe.expire(list_key, api_window)
                    pipe.expire(set_key, api_window)

                    pipe.execute()

                    # If no exception was raised in the execution, there were no transaction conflicts
                    break
                except redis.WatchError:
                    if count > 10:
                        raise RateLimiterLockingException()
                    count += 1

    @classmethod
    def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
        """Check `entity` and, if it is not limited, record this request.

        Returns (ratelimited, secs_to_freedom).
        """
        # Named so it does not shadow the module-level `time` import.
        ratelimited, secs_to_freedom = cls.is_ratelimited(entity)

        if ratelimited:
            statsd.incr("ratelimiter.limited.%s" % (entity.key(),))
        else:
            try:
                cls.incr_ratelimit(entity)
            except RateLimiterLockingException:
                logger.warning("Deadlock trying to incr_ratelimit for %s", entity.key())
                # rate-limit users who are hitting the API so hard we can't update our stats.
                ratelimited = True

        return ratelimited, secs_to_freedom
|
|
|
|
class RateLimitResult:
    """Outcome of applying one entity's rate limits to a request."""

    def __init__(self, entity: RateLimitedObject, secs_to_freedom: float, over_limit: bool,
                 remaining: Optional[int]=None) -> None:
        # A request that is over its limit must not report calls remaining.
        assert not (over_limit and remaining)

        self.entity = entity
        self.secs_to_freedom = secs_to_freedom
        self.over_limit = over_limit
        self.remaining = remaining
|
|
|
|
def get_rate_limit_result_from_request(request: HttpRequest, entity_type: str) -> RateLimitResult:
    """Return the RateLimitResult stored on `request` for `entity_type`.

    Raises AttributeError if no result was ever stored on the request, and
    KeyError if none was stored for this entity type.
    """
    results_by_type = request._ratelimit
    return results_by_type[entity_type]
|