diff --git a/zerver/lib/rate_limiter.py b/zerver/lib/rate_limiter.py index a0f0ce2584..8f411c0dad 100644 --- a/zerver/lib/rate_limiter.py +++ b/zerver/lib/rate_limiter.py @@ -1,6 +1,7 @@ import logging import time from abc import ABC, abstractmethod +from ipaddress import IPv6Network, ip_network from typing import Optional, cast import orjson @@ -138,8 +139,11 @@ class RateLimitedUser(RateLimitedObject): class RateLimitedIPAddr(RateLimitedObject): - def __init__(self, ip_addr: str, domain: str = "api_by_ip") -> None: + def __init__( + self, ip_addr: str, domain: str = "api_by_ip", ipv6_network_prefix: int = 64 + ) -> None: self.ip_addr = ip_addr + self.ipv6_network_prefix = ipv6_network_prefix self.domain = domain if settings.RUNNING_INSIDE_TORNADO and domain in settings.RATE_LIMITING_DOMAINS_FOR_TORNADO: backend: type[RateLimiterBackend] | None = TornadoInMemoryRateLimiterBackend @@ -149,8 +153,20 @@ class RateLimitedIPAddr(RateLimitedObject): @override def key(self) -> str: - # The angle brackets are important since IPv6 addresses contain :. - return f"{type(self).__name__}:<{self.ip_addr}>:{self.domain}" + if self.ip_addr != "tor-exit-node" and isinstance( + network := ip_network(self.ip_addr), IPv6Network + ): + # For IPv6 we use the network portion of that IPv6. + # This essentially tells us which bucket should this IPv6 belong to. + # For example: + # The network portion of 2001:0db8:ce1:12::8a2e:0370 + # is 2001:db8:ce1:12::/64 + ip_addr_key = str(network.supernet(new_prefix=self.ipv6_network_prefix)) + else: + ip_addr_key = self.ip_addr + + # The angle brackets are important since an IPv6 address contains : + return f"{type(self).__name__}:<{ip_addr_key}>:{self.domain}" @override def rules(self) -> list[tuple[int, int]]: diff --git a/zerver/tests/test_rate_limiter.py b/zerver/tests/test_rate_limiter.py index fddc78937c..4d611690d3 100644 --- a/zerver/tests/test_rate_limiter.py +++ b/zerver/tests/test_rate_limiter.py @@ -200,18 +200,26 @@ class TornadoInMemoryRateLimiterBackendTest(RateLimiterBackendBase): def test_used_in_tornado(self) -> None: user_profile = self.example_user("hamlet") - ip_addr = "192.168.0.123" + ipv4_addr = "192.168.0.123" + ipv6_addr = "2002:DB8::21f:5bff:febf:ce22:1111" + with self.settings(RUNNING_INSIDE_TORNADO=True): user_obj = RateLimitedUser(user_profile, domain="api_by_user") - ip_obj = RateLimitedIPAddr(ip_addr, domain="api_by_ip") + ipv4_obj = RateLimitedIPAddr(ipv4_addr, domain="api_by_ip") + ipv6_obj = RateLimitedIPAddr(ipv6_addr, domain="api_by_ip") + self.assertEqual(user_obj.backend, TornadoInMemoryRateLimiterBackend) - self.assertEqual(ip_obj.backend, TornadoInMemoryRateLimiterBackend) + self.assertEqual(ipv4_obj.backend, TornadoInMemoryRateLimiterBackend) + self.assertEqual(ipv6_obj.backend, TornadoInMemoryRateLimiterBackend) with self.settings(RUNNING_INSIDE_TORNADO=True): user_obj = RateLimitedUser(user_profile, domain="some_domain") - ip_obj = RateLimitedIPAddr(ip_addr, domain="some_domain") + ipv4_obj = RateLimitedIPAddr(ipv4_addr, domain="some_domain") + ipv6_obj = RateLimitedIPAddr(ipv6_addr, domain="some_domain") + self.assertEqual(user_obj.backend, RedisRateLimiterBackend) - self.assertEqual(ip_obj.backend, RedisRateLimiterBackend) + self.assertEqual(ipv4_obj.backend, RedisRateLimiterBackend) + self.assertEqual(ipv6_obj.backend, RedisRateLimiterBackend) def test_block_access(self) -> None: obj = self.create_object("test", [(2, 5)]) @@ -249,6 +257,89 @@ class RateLimitedObjectsTest(ZulipTestCase): obj = RateLimitedTestObject("test", rules=[], backend=RedisRateLimiterBackend) self.assertEqual(obj.get_rules(), [(1, 9999)]) + def test_ip_bucket_key(self) -> None: + ipv6 = "2001:0db8::ce1:12:8a2e:0370" + ipv4 = "192.168.0.123" + domain = "api_by_ip" + + self.assertEqual( + RateLimitedIPAddr(ipv4, domain=domain).key(), f"RateLimitedIPAddr:<{ipv4}>:{domain}" + ) + + # Here we check that each bucket key, given a different network prefix, + # is as expected. + # Although /64 is the only used prefix, we still test other prefixes + # to ensure correctness and also in case we decide to use smaller prefixes + # in future. + + # Note that the leading zero in :0db8: is omitted + self.assertEqual( + RateLimitedIPAddr(ipv6, domain=domain, ipv6_network_prefix=64).key(), + "RateLimitedIPAddr:<2001:db8::/64>:api_by_ip", + ) + self.assertEqual( + RateLimitedIPAddr(ipv6, domain=domain, ipv6_network_prefix=56).key(), + "RateLimitedIPAddr:<2001:db8::/56>:api_by_ip", + ) + self.assertEqual( + RateLimitedIPAddr(ipv6, domain=domain, ipv6_network_prefix=48).key(), + "RateLimitedIPAddr:<2001:db8::/48>:api_by_ip", + ) + + # Two IPv6 with the SAME network portion (identified by the prefix) + # should belong to the SAME bucket. + self.assertEqual( + RateLimitedIPAddr( + "2001:0db8:ce1:12::8a2e:0370", domain=domain, ipv6_network_prefix=64 + ).key(), + RateLimitedIPAddr( + "2001:0db8:ce1:12::8a2e:045f", domain=domain, ipv6_network_prefix=64 + ).key(), + ) + self.assertEqual( + RateLimitedIPAddr( + "2001:0db8:7a2e:ccd1::0370", domain=domain, ipv6_network_prefix=56 + ).key(), + RateLimitedIPAddr( + "2001:0db8:7a2e:ccf2::045f", domain=domain, ipv6_network_prefix=56 + ).key(), + ) + self.assertEqual( + RateLimitedIPAddr( + "2001:0db8:ce1:12::8a2e:0370", domain=domain, ipv6_network_prefix=48 + ).key(), + RateLimitedIPAddr( + "2001:0db8:ce1:13::8a2e:045f", domain=domain, ipv6_network_prefix=48 + ).key(), + ) + + # Two IPv6 with DIFFERENT network portions (identified by the prefix) + # should belong to DIFFERENT buckets. + self.assertNotEqual( + RateLimitedIPAddr( + "2001:0db8:ce1:12::8a2e:0370", domain=domain, ipv6_network_prefix=64 + ).key(), + RateLimitedIPAddr( + "2001:0db8:ce1:13::8a2e:045f", domain=domain, ipv6_network_prefix=64 + ).key(), + ) + self.assertNotEqual( + RateLimitedIPAddr( + "2001:0db8:7a2e:ccd1::0370", domain=domain, ipv6_network_prefix=56 + ).key(), + RateLimitedIPAddr( + "2001:0db8:7a2e:c1f2::045f", domain=domain, ipv6_network_prefix=56 + ).key(), + ) + self.assertNotEqual( + RateLimitedIPAddr( + "2001:0db8:12::8a2e:0370", domain=domain, ipv6_network_prefix=48 + ).key(), + RateLimitedIPAddr( + "2001:0db8:13::8a2e:045f", domain=domain, ipv6_network_prefix=48 + ).key(), + ) + # Don't load the base class as a test: https://bugs.python.org/issue17519. del RateLimiterBackendBase diff --git a/zproject/default_settings.py b/zproject/default_settings.py index 5c45392d1d..085d94be21 100644 --- a/zproject/default_settings.py +++ b/zproject/default_settings.py @@ -269,8 +269,10 @@ DEFAULT_RATE_LIMITING_RULES = { ], # Limits total number of unauthenticated API requests (primarily # used by the public access option). Since these are - # unauthenticated requests, each IP address is a separate bucket. + # unauthenticated requests, each IPv4 address is a separate bucket. + # For IPv6, one bucket is used for each /64 subnet. "api_by_ip": [ + # 100 requests per minute. (60, 100), ], # Limits total requests to the Mobile Push Notifications Service