rate_limiter: Handle edge case where rules list may be empty.

This commit is contained in:
Mateusz Mandera
2020-04-02 22:23:20 +02:00
committed by Tim Abbott
parent b577366a05
commit 5f9da3053d
2 changed files with 27 additions and 21 deletions

View File

@@ -39,7 +39,7 @@ class RateLimitedObject(ABC):
def rate_limit(self) -> Tuple[bool, float]: def rate_limit(self) -> Tuple[bool, float]:
# Returns (ratelimited, secs_to_freedom) # Returns (ratelimited, secs_to_freedom)
return self.backend.rate_limit_entity(self.key(), self.rules(), return self.backend.rate_limit_entity(self.key(), self.get_rules(),
self.max_api_calls(), self.max_api_calls(),
self.max_api_window()) self.max_api_window())
@@ -76,11 +76,11 @@ class RateLimitedObject(ABC):
def max_api_calls(self) -> int: def max_api_calls(self) -> int:
"Returns the API rate limit for the highest limit" "Returns the API rate limit for the highest limit"
return self.rules()[-1][1] return self.get_rules()[-1][1]
def max_api_window(self) -> int: def max_api_window(self) -> int:
"Returns the API time window for the highest limit" "Returns the API time window for the highest limit"
return self.rules()[-1][0] return self.get_rules()[-1][0]
def api_calls_left(self) -> Tuple[int, float]: def api_calls_left(self) -> Tuple[int, float]:
"""Returns how many API calls in this range this client has, as well as when """Returns how many API calls in this range this client has, as well as when
@@ -89,6 +89,16 @@ class RateLimitedObject(ABC):
max_calls = self.max_api_calls() max_calls = self.max_api_calls()
return self.backend.get_api_calls_left(self.key(), max_window, max_calls) return self.backend.get_api_calls_left(self.key(), max_window, max_calls)
def get_rules(self) -> List[Tuple[int, int]]:
"""
This is a simple wrapper meant to protect against having to deal with
an empty list of rules, as it would require fiddling with that special case
all around this system. "9999 max request per seconds" should be a good proxy
for "no rules".
"""
rules_list = self.rules()
return rules_list or [(1, 9999), ]
@abstractmethod @abstractmethod
def key(self) -> str: def key(self) -> str:
pass pass
@@ -270,8 +280,7 @@ class TornadoInMemoryRateLimiterBackend(RateLimiterBackend):
else: else:
del cls.timestamps_blocked_until[entity_key] del cls.timestamps_blocked_until[entity_key]
if len(rules) == 0: assert rules
return False, 0
for time_window, max_count in rules: for time_window, max_count in rules:
ratelimited, time_till_free = cls.need_to_limit(entity_key, time_window, max_count) ratelimited, time_till_free = cls.need_to_limit(entity_key, time_window, max_count)
@@ -338,6 +347,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
@classmethod @classmethod
def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]: def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]:
"Returns a tuple of (rate_limited, time_till_free)" "Returns a tuple of (rate_limited, time_till_free)"
assert rules
list_key, set_key, blocking_key = cls.get_keys(entity_key) list_key, set_key, blocking_key = cls.get_keys(entity_key)
# Go through the rules from shortest to longest, # Go through the rules from shortest to longest,
@@ -365,9 +375,6 @@ class RedisRateLimiterBackend(RateLimiterBackend):
blocking_ttl = int(blocking_ttl_b) blocking_ttl = int(blocking_ttl_b)
return True, blocking_ttl return True, blocking_ttl
if len(rules) == 0:
return False, 0.0
now = time.time() now = time.time()
for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules): for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules):
# Check if the nth timestamp is newer than the associated rule. If so, # Check if the nth timestamp is newer than the associated rule. If so,
@@ -383,16 +390,11 @@ class RedisRateLimiterBackend(RateLimiterBackend):
return False, 0.0 return False, 0.0
@classmethod @classmethod
def incr_ratelimit(cls, entity_key: str, rules: List[Tuple[int, int]], def incr_ratelimit(cls, entity_key: str, max_api_calls: int, max_api_window: int) -> None:
max_api_calls: int, max_api_window: int) -> None:
"""Increases the rate-limit for the specified entity""" """Increases the rate-limit for the specified entity"""
list_key, set_key, _ = cls.get_keys(entity_key) list_key, set_key, _ = cls.get_keys(entity_key)
now = time.time() now = time.time()
# If we have no rules, we don't store anything
if len(rules) == 0:
return
# Start redis transaction # Start redis transaction
with client.pipeline() as pipe: with client.pipeline() as pipe:
count = 0 count = 0
@@ -451,7 +453,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
else: else:
try: try:
cls.incr_ratelimit(entity_key, rules, max_api_calls, max_api_window) cls.incr_ratelimit(entity_key, max_api_calls, max_api_window)
except RateLimiterLockingException: except RateLimiterLockingException:
logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,)) logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,))
# rate-limit users who are hitting the API so hard we can't update our stats. # rate-limit users who are hitting the API so hard we can't update our stats.

View File

@@ -69,7 +69,7 @@ class RateLimiterBackendBase(ZulipTestCase):
self.assertEqual(expected_time_till_reset, time_till_reset) self.assertEqual(expected_time_till_reset, time_till_reset)
def expected_api_calls_left(self, obj: RateLimitedTestObject, now: float) -> Tuple[int, float]: def expected_api_calls_left(self, obj: RateLimitedTestObject, now: float) -> Tuple[int, float]:
longest_rule = obj.rules()[-1] longest_rule = obj.get_rules()[-1]
max_window, max_calls = longest_rule max_window, max_calls = longest_rule
history = self.requests_record.get(obj.key()) history = self.requests_record.get(obj.key())
if history is None: if history is None:
@@ -198,13 +198,13 @@ class TornadoInMemoryRateLimiterBackendTest(RateLimiterBackendBase):
with mock.patch('time.time', return_value=(start_time + 1.01)): with mock.patch('time.time', return_value=(start_time + 1.01)):
self.make_request(obj, expect_ratelimited=False, verify_api_calls_left=False) self.make_request(obj, expect_ratelimited=False, verify_api_calls_left=False)
class RateLimitedUserTest(ZulipTestCase): class RateLimitedObjectsTest(ZulipTestCase):
def test_user_rate_limits(self) -> None: def test_user_rate_limits(self) -> None:
user_profile = self.example_user("hamlet") user_profile = self.example_user("hamlet")
user_profile.rate_limits = "1:3,2:4" user_profile.rate_limits = "1:3,2:4"
obj = RateLimitedUser(user_profile) obj = RateLimitedUser(user_profile)
self.assertEqual(obj.rules(), [(1, 3), (2, 4)]) self.assertEqual(obj.get_rules(), [(1, 3), (2, 4)])
def test_add_remove_rule(self) -> None: def test_add_remove_rule(self) -> None:
user_profile = self.example_user("hamlet") user_profile = self.example_user("hamlet")
@@ -213,9 +213,13 @@ class RateLimitedUserTest(ZulipTestCase):
add_ratelimit_rule(10, 100, domain='some_new_domain') add_ratelimit_rule(10, 100, domain='some_new_domain')
obj = RateLimitedUser(user_profile) obj = RateLimitedUser(user_profile)
self.assertEqual(obj.rules(), [(1, 2), ]) self.assertEqual(obj.get_rules(), [(1, 2), ])
obj.domain = 'some_new_domain' obj.domain = 'some_new_domain'
self.assertEqual(obj.rules(), [(4, 5), (10, 100)]) self.assertEqual(obj.get_rules(), [(4, 5), (10, 100)])
remove_ratelimit_rule(10, 100, domain='some_new_domain') remove_ratelimit_rule(10, 100, domain='some_new_domain')
self.assertEqual(obj.rules(), [(4, 5), ]) self.assertEqual(obj.get_rules(), [(4, 5), ])
def test_empty_rules_edge_case(self) -> None:
obj = RateLimitedTestObject("test", rules=[], backend=RedisRateLimiterBackend)
self.assertEqual(obj.get_rules(), [(1, 9999), ])