mirror of
https://github.com/zulip/zulip.git
synced 2025-11-06 06:53:25 +00:00
rate_limiter: Handle edge case where rules list may be empty.
This commit is contained in:
committed by
Tim Abbott
parent
b577366a05
commit
5f9da3053d
@@ -39,7 +39,7 @@ class RateLimitedObject(ABC):
|
|||||||
|
|
||||||
def rate_limit(self) -> Tuple[bool, float]:
|
def rate_limit(self) -> Tuple[bool, float]:
|
||||||
# Returns (ratelimited, secs_to_freedom)
|
# Returns (ratelimited, secs_to_freedom)
|
||||||
return self.backend.rate_limit_entity(self.key(), self.rules(),
|
return self.backend.rate_limit_entity(self.key(), self.get_rules(),
|
||||||
self.max_api_calls(),
|
self.max_api_calls(),
|
||||||
self.max_api_window())
|
self.max_api_window())
|
||||||
|
|
||||||
@@ -76,11 +76,11 @@ class RateLimitedObject(ABC):
|
|||||||
|
|
||||||
def max_api_calls(self) -> int:
|
def max_api_calls(self) -> int:
|
||||||
"Returns the API rate limit for the highest limit"
|
"Returns the API rate limit for the highest limit"
|
||||||
return self.rules()[-1][1]
|
return self.get_rules()[-1][1]
|
||||||
|
|
||||||
def max_api_window(self) -> int:
|
def max_api_window(self) -> int:
|
||||||
"Returns the API time window for the highest limit"
|
"Returns the API time window for the highest limit"
|
||||||
return self.rules()[-1][0]
|
return self.get_rules()[-1][0]
|
||||||
|
|
||||||
def api_calls_left(self) -> Tuple[int, float]:
|
def api_calls_left(self) -> Tuple[int, float]:
|
||||||
"""Returns how many API calls in this range this client has, as well as when
|
"""Returns how many API calls in this range this client has, as well as when
|
||||||
@@ -89,6 +89,16 @@ class RateLimitedObject(ABC):
|
|||||||
max_calls = self.max_api_calls()
|
max_calls = self.max_api_calls()
|
||||||
return self.backend.get_api_calls_left(self.key(), max_window, max_calls)
|
return self.backend.get_api_calls_left(self.key(), max_window, max_calls)
|
||||||
|
|
||||||
|
def get_rules(self) -> List[Tuple[int, int]]:
|
||||||
|
"""
|
||||||
|
This is a simple wrapper meant to protect against having to deal with
|
||||||
|
an empty list of rules, as it would require fiddling with that special case
|
||||||
|
all around this system. "9999 max request per seconds" should be a good proxy
|
||||||
|
for "no rules".
|
||||||
|
"""
|
||||||
|
rules_list = self.rules()
|
||||||
|
return rules_list or [(1, 9999), ]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def key(self) -> str:
|
def key(self) -> str:
|
||||||
pass
|
pass
|
||||||
@@ -270,8 +280,7 @@ class TornadoInMemoryRateLimiterBackend(RateLimiterBackend):
|
|||||||
else:
|
else:
|
||||||
del cls.timestamps_blocked_until[entity_key]
|
del cls.timestamps_blocked_until[entity_key]
|
||||||
|
|
||||||
if len(rules) == 0:
|
assert rules
|
||||||
return False, 0
|
|
||||||
for time_window, max_count in rules:
|
for time_window, max_count in rules:
|
||||||
ratelimited, time_till_free = cls.need_to_limit(entity_key, time_window, max_count)
|
ratelimited, time_till_free = cls.need_to_limit(entity_key, time_window, max_count)
|
||||||
|
|
||||||
@@ -338,6 +347,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]:
|
def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]:
|
||||||
"Returns a tuple of (rate_limited, time_till_free)"
|
"Returns a tuple of (rate_limited, time_till_free)"
|
||||||
|
assert rules
|
||||||
list_key, set_key, blocking_key = cls.get_keys(entity_key)
|
list_key, set_key, blocking_key = cls.get_keys(entity_key)
|
||||||
|
|
||||||
# Go through the rules from shortest to longest,
|
# Go through the rules from shortest to longest,
|
||||||
@@ -365,9 +375,6 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||||||
blocking_ttl = int(blocking_ttl_b)
|
blocking_ttl = int(blocking_ttl_b)
|
||||||
return True, blocking_ttl
|
return True, blocking_ttl
|
||||||
|
|
||||||
if len(rules) == 0:
|
|
||||||
return False, 0.0
|
|
||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules):
|
for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules):
|
||||||
# Check if the nth timestamp is newer than the associated rule. If so,
|
# Check if the nth timestamp is newer than the associated rule. If so,
|
||||||
@@ -383,16 +390,11 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||||||
return False, 0.0
|
return False, 0.0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def incr_ratelimit(cls, entity_key: str, rules: List[Tuple[int, int]],
|
def incr_ratelimit(cls, entity_key: str, max_api_calls: int, max_api_window: int) -> None:
|
||||||
max_api_calls: int, max_api_window: int) -> None:
|
|
||||||
"""Increases the rate-limit for the specified entity"""
|
"""Increases the rate-limit for the specified entity"""
|
||||||
list_key, set_key, _ = cls.get_keys(entity_key)
|
list_key, set_key, _ = cls.get_keys(entity_key)
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
# If we have no rules, we don't store anything
|
|
||||||
if len(rules) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Start redis transaction
|
# Start redis transaction
|
||||||
with client.pipeline() as pipe:
|
with client.pipeline() as pipe:
|
||||||
count = 0
|
count = 0
|
||||||
@@ -451,7 +453,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
cls.incr_ratelimit(entity_key, rules, max_api_calls, max_api_window)
|
cls.incr_ratelimit(entity_key, max_api_calls, max_api_window)
|
||||||
except RateLimiterLockingException:
|
except RateLimiterLockingException:
|
||||||
logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,))
|
logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,))
|
||||||
# rate-limit users who are hitting the API so hard we can't update our stats.
|
# rate-limit users who are hitting the API so hard we can't update our stats.
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class RateLimiterBackendBase(ZulipTestCase):
|
|||||||
self.assertEqual(expected_time_till_reset, time_till_reset)
|
self.assertEqual(expected_time_till_reset, time_till_reset)
|
||||||
|
|
||||||
def expected_api_calls_left(self, obj: RateLimitedTestObject, now: float) -> Tuple[int, float]:
|
def expected_api_calls_left(self, obj: RateLimitedTestObject, now: float) -> Tuple[int, float]:
|
||||||
longest_rule = obj.rules()[-1]
|
longest_rule = obj.get_rules()[-1]
|
||||||
max_window, max_calls = longest_rule
|
max_window, max_calls = longest_rule
|
||||||
history = self.requests_record.get(obj.key())
|
history = self.requests_record.get(obj.key())
|
||||||
if history is None:
|
if history is None:
|
||||||
@@ -198,13 +198,13 @@ class TornadoInMemoryRateLimiterBackendTest(RateLimiterBackendBase):
|
|||||||
with mock.patch('time.time', return_value=(start_time + 1.01)):
|
with mock.patch('time.time', return_value=(start_time + 1.01)):
|
||||||
self.make_request(obj, expect_ratelimited=False, verify_api_calls_left=False)
|
self.make_request(obj, expect_ratelimited=False, verify_api_calls_left=False)
|
||||||
|
|
||||||
class RateLimitedUserTest(ZulipTestCase):
|
class RateLimitedObjectsTest(ZulipTestCase):
|
||||||
def test_user_rate_limits(self) -> None:
|
def test_user_rate_limits(self) -> None:
|
||||||
user_profile = self.example_user("hamlet")
|
user_profile = self.example_user("hamlet")
|
||||||
user_profile.rate_limits = "1:3,2:4"
|
user_profile.rate_limits = "1:3,2:4"
|
||||||
obj = RateLimitedUser(user_profile)
|
obj = RateLimitedUser(user_profile)
|
||||||
|
|
||||||
self.assertEqual(obj.rules(), [(1, 3), (2, 4)])
|
self.assertEqual(obj.get_rules(), [(1, 3), (2, 4)])
|
||||||
|
|
||||||
def test_add_remove_rule(self) -> None:
|
def test_add_remove_rule(self) -> None:
|
||||||
user_profile = self.example_user("hamlet")
|
user_profile = self.example_user("hamlet")
|
||||||
@@ -213,9 +213,13 @@ class RateLimitedUserTest(ZulipTestCase):
|
|||||||
add_ratelimit_rule(10, 100, domain='some_new_domain')
|
add_ratelimit_rule(10, 100, domain='some_new_domain')
|
||||||
obj = RateLimitedUser(user_profile)
|
obj = RateLimitedUser(user_profile)
|
||||||
|
|
||||||
self.assertEqual(obj.rules(), [(1, 2), ])
|
self.assertEqual(obj.get_rules(), [(1, 2), ])
|
||||||
obj.domain = 'some_new_domain'
|
obj.domain = 'some_new_domain'
|
||||||
self.assertEqual(obj.rules(), [(4, 5), (10, 100)])
|
self.assertEqual(obj.get_rules(), [(4, 5), (10, 100)])
|
||||||
|
|
||||||
remove_ratelimit_rule(10, 100, domain='some_new_domain')
|
remove_ratelimit_rule(10, 100, domain='some_new_domain')
|
||||||
self.assertEqual(obj.rules(), [(4, 5), ])
|
self.assertEqual(obj.get_rules(), [(4, 5), ])
|
||||||
|
|
||||||
|
def test_empty_rules_edge_case(self) -> None:
|
||||||
|
obj = RateLimitedTestObject("test", rules=[], backend=RedisRateLimiterBackend)
|
||||||
|
self.assertEqual(obj.get_rules(), [(1, 9999), ])
|
||||||
|
|||||||
Reference in New Issue
Block a user