mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-04 05:53:43 +00:00 
			
		
		
		
	rate_limiter: Handle edge case where rules list may be empty.
This commit is contained in:
		
				
					committed by
					
						
						Tim Abbott
					
				
			
			
				
	
			
			
			
						parent
						
							b577366a05
						
					
				
				
					commit
					5f9da3053d
				
			@@ -39,7 +39,7 @@ class RateLimitedObject(ABC):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def rate_limit(self) -> Tuple[bool, float]:
 | 
					    def rate_limit(self) -> Tuple[bool, float]:
 | 
				
			||||||
        # Returns (ratelimited, secs_to_freedom)
 | 
					        # Returns (ratelimited, secs_to_freedom)
 | 
				
			||||||
        return self.backend.rate_limit_entity(self.key(), self.rules(),
 | 
					        return self.backend.rate_limit_entity(self.key(), self.get_rules(),
 | 
				
			||||||
                                              self.max_api_calls(),
 | 
					                                              self.max_api_calls(),
 | 
				
			||||||
                                              self.max_api_window())
 | 
					                                              self.max_api_window())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -76,11 +76,11 @@ class RateLimitedObject(ABC):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def max_api_calls(self) -> int:
 | 
					    def max_api_calls(self) -> int:
 | 
				
			||||||
        "Returns the API rate limit for the highest limit"
 | 
					        "Returns the API rate limit for the highest limit"
 | 
				
			||||||
        return self.rules()[-1][1]
 | 
					        return self.get_rules()[-1][1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def max_api_window(self) -> int:
 | 
					    def max_api_window(self) -> int:
 | 
				
			||||||
        "Returns the API time window for the highest limit"
 | 
					        "Returns the API time window for the highest limit"
 | 
				
			||||||
        return self.rules()[-1][0]
 | 
					        return self.get_rules()[-1][0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def api_calls_left(self) -> Tuple[int, float]:
 | 
					    def api_calls_left(self) -> Tuple[int, float]:
 | 
				
			||||||
        """Returns how many API calls in this range this client has, as well as when
 | 
					        """Returns how many API calls in this range this client has, as well as when
 | 
				
			||||||
@@ -89,6 +89,16 @@ class RateLimitedObject(ABC):
 | 
				
			|||||||
        max_calls = self.max_api_calls()
 | 
					        max_calls = self.max_api_calls()
 | 
				
			||||||
        return self.backend.get_api_calls_left(self.key(), max_window, max_calls)
 | 
					        return self.backend.get_api_calls_left(self.key(), max_window, max_calls)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_rules(self) -> List[Tuple[int, int]]:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        This is a simple wrapper meant to protect against having to deal with
 | 
				
			||||||
 | 
					        an empty list of rules, as it would require fiddling with that special case
 | 
				
			||||||
 | 
					        all around this system. "9999 max request per seconds" should be a good proxy
 | 
				
			||||||
 | 
					        for "no rules".
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        rules_list = self.rules()
 | 
				
			||||||
 | 
					        return rules_list or [(1, 9999), ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @abstractmethod
 | 
					    @abstractmethod
 | 
				
			||||||
    def key(self) -> str:
 | 
					    def key(self) -> str:
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
@@ -270,8 +280,7 @@ class TornadoInMemoryRateLimiterBackend(RateLimiterBackend):
 | 
				
			|||||||
            else:
 | 
					            else:
 | 
				
			||||||
                del cls.timestamps_blocked_until[entity_key]
 | 
					                del cls.timestamps_blocked_until[entity_key]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if len(rules) == 0:
 | 
					        assert rules
 | 
				
			||||||
            return False, 0
 | 
					 | 
				
			||||||
        for time_window, max_count in rules:
 | 
					        for time_window, max_count in rules:
 | 
				
			||||||
            ratelimited, time_till_free = cls.need_to_limit(entity_key, time_window, max_count)
 | 
					            ratelimited, time_till_free = cls.need_to_limit(entity_key, time_window, max_count)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -338,6 +347,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
 | 
				
			|||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]:
 | 
					    def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]:
 | 
				
			||||||
        "Returns a tuple of (rate_limited, time_till_free)"
 | 
					        "Returns a tuple of (rate_limited, time_till_free)"
 | 
				
			||||||
 | 
					        assert rules
 | 
				
			||||||
        list_key, set_key, blocking_key = cls.get_keys(entity_key)
 | 
					        list_key, set_key, blocking_key = cls.get_keys(entity_key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Go through the rules from shortest to longest,
 | 
					        # Go through the rules from shortest to longest,
 | 
				
			||||||
@@ -365,9 +375,6 @@ class RedisRateLimiterBackend(RateLimiterBackend):
 | 
				
			|||||||
                blocking_ttl = int(blocking_ttl_b)
 | 
					                blocking_ttl = int(blocking_ttl_b)
 | 
				
			||||||
            return True, blocking_ttl
 | 
					            return True, blocking_ttl
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if len(rules) == 0:
 | 
					 | 
				
			||||||
            return False, 0.0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        now = time.time()
 | 
					        now = time.time()
 | 
				
			||||||
        for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules):
 | 
					        for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules):
 | 
				
			||||||
            # Check if the nth timestamp is newer than the associated rule. If so,
 | 
					            # Check if the nth timestamp is newer than the associated rule. If so,
 | 
				
			||||||
@@ -383,16 +390,11 @@ class RedisRateLimiterBackend(RateLimiterBackend):
 | 
				
			|||||||
        return False, 0.0
 | 
					        return False, 0.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def incr_ratelimit(cls, entity_key: str, rules: List[Tuple[int, int]],
 | 
					    def incr_ratelimit(cls, entity_key: str, max_api_calls: int, max_api_window: int) -> None:
 | 
				
			||||||
                       max_api_calls: int, max_api_window: int) -> None:
 | 
					 | 
				
			||||||
        """Increases the rate-limit for the specified entity"""
 | 
					        """Increases the rate-limit for the specified entity"""
 | 
				
			||||||
        list_key, set_key, _ = cls.get_keys(entity_key)
 | 
					        list_key, set_key, _ = cls.get_keys(entity_key)
 | 
				
			||||||
        now = time.time()
 | 
					        now = time.time()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # If we have no rules, we don't store anything
 | 
					 | 
				
			||||||
        if len(rules) == 0:
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Start redis transaction
 | 
					        # Start redis transaction
 | 
				
			||||||
        with client.pipeline() as pipe:
 | 
					        with client.pipeline() as pipe:
 | 
				
			||||||
            count = 0
 | 
					            count = 0
 | 
				
			||||||
@@ -451,7 +453,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                cls.incr_ratelimit(entity_key, rules, max_api_calls, max_api_window)
 | 
					                cls.incr_ratelimit(entity_key, max_api_calls, max_api_window)
 | 
				
			||||||
            except RateLimiterLockingException:
 | 
					            except RateLimiterLockingException:
 | 
				
			||||||
                logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,))
 | 
					                logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,))
 | 
				
			||||||
                # rate-limit users who are hitting the API so hard we can't update our stats.
 | 
					                # rate-limit users who are hitting the API so hard we can't update our stats.
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -69,7 +69,7 @@ class RateLimiterBackendBase(ZulipTestCase):
 | 
				
			|||||||
        self.assertEqual(expected_time_till_reset, time_till_reset)
 | 
					        self.assertEqual(expected_time_till_reset, time_till_reset)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def expected_api_calls_left(self, obj: RateLimitedTestObject, now: float) -> Tuple[int, float]:
 | 
					    def expected_api_calls_left(self, obj: RateLimitedTestObject, now: float) -> Tuple[int, float]:
 | 
				
			||||||
        longest_rule = obj.rules()[-1]
 | 
					        longest_rule = obj.get_rules()[-1]
 | 
				
			||||||
        max_window, max_calls = longest_rule
 | 
					        max_window, max_calls = longest_rule
 | 
				
			||||||
        history = self.requests_record.get(obj.key())
 | 
					        history = self.requests_record.get(obj.key())
 | 
				
			||||||
        if history is None:
 | 
					        if history is None:
 | 
				
			||||||
@@ -198,13 +198,13 @@ class TornadoInMemoryRateLimiterBackendTest(RateLimiterBackendBase):
 | 
				
			|||||||
        with mock.patch('time.time', return_value=(start_time + 1.01)):
 | 
					        with mock.patch('time.time', return_value=(start_time + 1.01)):
 | 
				
			||||||
            self.make_request(obj, expect_ratelimited=False, verify_api_calls_left=False)
 | 
					            self.make_request(obj, expect_ratelimited=False, verify_api_calls_left=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class RateLimitedUserTest(ZulipTestCase):
 | 
					class RateLimitedObjectsTest(ZulipTestCase):
 | 
				
			||||||
    def test_user_rate_limits(self) -> None:
 | 
					    def test_user_rate_limits(self) -> None:
 | 
				
			||||||
        user_profile = self.example_user("hamlet")
 | 
					        user_profile = self.example_user("hamlet")
 | 
				
			||||||
        user_profile.rate_limits = "1:3,2:4"
 | 
					        user_profile.rate_limits = "1:3,2:4"
 | 
				
			||||||
        obj = RateLimitedUser(user_profile)
 | 
					        obj = RateLimitedUser(user_profile)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertEqual(obj.rules(), [(1, 3), (2, 4)])
 | 
					        self.assertEqual(obj.get_rules(), [(1, 3), (2, 4)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_add_remove_rule(self) -> None:
 | 
					    def test_add_remove_rule(self) -> None:
 | 
				
			||||||
        user_profile = self.example_user("hamlet")
 | 
					        user_profile = self.example_user("hamlet")
 | 
				
			||||||
@@ -213,9 +213,13 @@ class RateLimitedUserTest(ZulipTestCase):
 | 
				
			|||||||
        add_ratelimit_rule(10, 100, domain='some_new_domain')
 | 
					        add_ratelimit_rule(10, 100, domain='some_new_domain')
 | 
				
			||||||
        obj = RateLimitedUser(user_profile)
 | 
					        obj = RateLimitedUser(user_profile)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertEqual(obj.rules(), [(1, 2), ])
 | 
					        self.assertEqual(obj.get_rules(), [(1, 2), ])
 | 
				
			||||||
        obj.domain = 'some_new_domain'
 | 
					        obj.domain = 'some_new_domain'
 | 
				
			||||||
        self.assertEqual(obj.rules(), [(4, 5), (10, 100)])
 | 
					        self.assertEqual(obj.get_rules(), [(4, 5), (10, 100)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        remove_ratelimit_rule(10, 100, domain='some_new_domain')
 | 
					        remove_ratelimit_rule(10, 100, domain='some_new_domain')
 | 
				
			||||||
        self.assertEqual(obj.rules(), [(4, 5), ])
 | 
					        self.assertEqual(obj.get_rules(), [(4, 5), ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_empty_rules_edge_case(self) -> None:
 | 
				
			||||||
 | 
					        obj = RateLimitedTestObject("test", rules=[], backend=RedisRateLimiterBackend)
 | 
				
			||||||
 | 
					        self.assertEqual(obj.get_rules(), [(1, 9999), ])
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user