zilencer: Only apply rate limit to remote server.

This refactors the test case alongside, since normal views accessed by
remote server do not get rate limited by remote server anymore.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li
2022-08-14 10:16:36 -04:00
committed by Tim Abbott
parent 79e86471e7
commit 29bad25f83
3 changed files with 16 additions and 17 deletions

View File

@@ -646,14 +646,9 @@ def rate_limit(request: HttpRequest) -> None:
if not should_rate_limit(request): if not should_rate_limit(request):
return return
from zerver.lib.request import RequestNotes
user = request.user user = request.user
remote_server = RequestNotes.get_notes(request).remote_server
if settings.ZILENCER_ENABLED and remote_server is not None: if not user.is_authenticated:
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
elif not user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip") rate_limit_request_by_ip(request, domain="api_by_ip")
else: else:
assert isinstance(user, UserProfile) assert isinstance(user, UserProfile)

View File

@@ -723,6 +723,7 @@ class RateLimitTestCase(ZulipTestCase):
@skipUnless(settings.ZILENCER_ENABLED, "requires zilencer") @skipUnless(settings.ZILENCER_ENABLED, "requires zilencer")
def test_rate_limiting_happens_if_remote_server(self) -> None: def test_rate_limiting_happens_if_remote_server(self) -> None:
user = self.example_user("hamlet")
server_uuid = str(uuid.uuid4()) server_uuid = str(uuid.uuid4())
server = RemoteZulipServer( server = RemoteZulipServer(
uuid=server_uuid, uuid=server_uuid,
@@ -730,16 +731,18 @@ class RateLimitTestCase(ZulipTestCase):
hostname="demo.example.com", hostname="demo.example.com",
last_updated=timezone_now(), last_updated=timezone_now(),
) )
META = {"REMOTE_ADDR": "3.3.3.3"} server.save()
req = HostRequestMock(client_name="external", remote_server=server, meta_data=META) with self.settings(RATE_LIMITING=True), mock.patch(
"zerver.lib.rate_limiter.rate_limit_remote_server"
f = self.get_ratelimited_view() ) as rate_limit_mock:
result = self.uuid_post(
with self.settings(RATE_LIMITING=True): server_uuid,
with mock.patch("zerver.lib.rate_limiter.rate_limit_remote_server") as rate_limit_mock: "/api/v1/remotes/push/unregister/all",
with self.errors_disallowed(): {"user_id": user.id},
self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value") subdomain="",
)
self.assert_json_success(result)
self.assertTrue(rate_limit_mock.called) self.assertTrue(rate_limit_mock.called)

View File

@@ -15,7 +15,7 @@ from zerver.lib.exceptions import (
RemoteServerDeactivatedError, RemoteServerDeactivatedError,
UnauthorizedError, UnauthorizedError,
) )
from zerver.lib.rate_limiter import rate_limit from zerver.lib.rate_limiter import rate_limit_remote_server, should_rate_limit
from zerver.lib.request import RequestNotes from zerver.lib.request import RequestNotes
from zerver.lib.rest import get_target_view_function_or_response from zerver.lib.rest import get_target_view_function_or_response
from zerver.lib.subdomains import get_subdomain from zerver.lib.subdomains import get_subdomain
@@ -80,7 +80,8 @@ def authenticated_remote_server_view(
except JsonableError as e: except JsonableError as e:
raise UnauthorizedError(e.msg) raise UnauthorizedError(e.msg)
rate_limit(request) if should_rate_limit(request):
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
return view_func(request, remote_server, *args, **kwargs) return view_func(request, remote_server, *args, **kwargs)
return _wrapped_view_func return _wrapped_view_func