mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-04 05:53:43 +00:00 
			
		
		
		
	zilencer: Only apply rate limit to remote server.
This refactors the test case alongside, since normal views accessed by remote server do not get rate limited by remote server anymore. Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
		
				
					committed by
					
						
						Tim Abbott
					
				
			
			
				
	
			
			
			
						parent
						
							79e86471e7
						
					
				
				
					commit
					29bad25f83
				
			@@ -646,14 +646,9 @@ def rate_limit(request: HttpRequest) -> None:
 | 
				
			|||||||
    if not should_rate_limit(request):
 | 
					    if not should_rate_limit(request):
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    from zerver.lib.request import RequestNotes
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    user = request.user
 | 
					    user = request.user
 | 
				
			||||||
    remote_server = RequestNotes.get_notes(request).remote_server
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if settings.ZILENCER_ENABLED and remote_server is not None:
 | 
					    if not user.is_authenticated:
 | 
				
			||||||
        rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
 | 
					 | 
				
			||||||
    elif not user.is_authenticated:
 | 
					 | 
				
			||||||
        rate_limit_request_by_ip(request, domain="api_by_ip")
 | 
					        rate_limit_request_by_ip(request, domain="api_by_ip")
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        assert isinstance(user, UserProfile)
 | 
					        assert isinstance(user, UserProfile)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -723,6 +723,7 @@ class RateLimitTestCase(ZulipTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @skipUnless(settings.ZILENCER_ENABLED, "requires zilencer")
 | 
					    @skipUnless(settings.ZILENCER_ENABLED, "requires zilencer")
 | 
				
			||||||
    def test_rate_limiting_happens_if_remote_server(self) -> None:
 | 
					    def test_rate_limiting_happens_if_remote_server(self) -> None:
 | 
				
			||||||
 | 
					        user = self.example_user("hamlet")
 | 
				
			||||||
        server_uuid = str(uuid.uuid4())
 | 
					        server_uuid = str(uuid.uuid4())
 | 
				
			||||||
        server = RemoteZulipServer(
 | 
					        server = RemoteZulipServer(
 | 
				
			||||||
            uuid=server_uuid,
 | 
					            uuid=server_uuid,
 | 
				
			||||||
@@ -730,16 +731,18 @@ class RateLimitTestCase(ZulipTestCase):
 | 
				
			|||||||
            hostname="demo.example.com",
 | 
					            hostname="demo.example.com",
 | 
				
			||||||
            last_updated=timezone_now(),
 | 
					            last_updated=timezone_now(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        META = {"REMOTE_ADDR": "3.3.3.3"}
 | 
					        server.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        req = HostRequestMock(client_name="external", remote_server=server, meta_data=META)
 | 
					        with self.settings(RATE_LIMITING=True), mock.patch(
 | 
				
			||||||
 | 
					            "zerver.lib.rate_limiter.rate_limit_remote_server"
 | 
				
			||||||
        f = self.get_ratelimited_view()
 | 
					        ) as rate_limit_mock:
 | 
				
			||||||
 | 
					            result = self.uuid_post(
 | 
				
			||||||
        with self.settings(RATE_LIMITING=True):
 | 
					                server_uuid,
 | 
				
			||||||
            with mock.patch("zerver.lib.rate_limiter.rate_limit_remote_server") as rate_limit_mock:
 | 
					                "/api/v1/remotes/push/unregister/all",
 | 
				
			||||||
                with self.errors_disallowed():
 | 
					                {"user_id": user.id},
 | 
				
			||||||
                    self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value")
 | 
					                subdomain="",
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assert_json_success(result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertTrue(rate_limit_mock.called)
 | 
					        self.assertTrue(rate_limit_mock.called)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,7 +15,7 @@ from zerver.lib.exceptions import (
 | 
				
			|||||||
    RemoteServerDeactivatedError,
 | 
					    RemoteServerDeactivatedError,
 | 
				
			||||||
    UnauthorizedError,
 | 
					    UnauthorizedError,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from zerver.lib.rate_limiter import rate_limit
 | 
					from zerver.lib.rate_limiter import rate_limit_remote_server, should_rate_limit
 | 
				
			||||||
from zerver.lib.request import RequestNotes
 | 
					from zerver.lib.request import RequestNotes
 | 
				
			||||||
from zerver.lib.rest import get_target_view_function_or_response
 | 
					from zerver.lib.rest import get_target_view_function_or_response
 | 
				
			||||||
from zerver.lib.subdomains import get_subdomain
 | 
					from zerver.lib.subdomains import get_subdomain
 | 
				
			||||||
@@ -80,7 +80,8 @@ def authenticated_remote_server_view(
 | 
				
			|||||||
        except JsonableError as e:
 | 
					        except JsonableError as e:
 | 
				
			||||||
            raise UnauthorizedError(e.msg)
 | 
					            raise UnauthorizedError(e.msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        rate_limit(request)
 | 
					        if should_rate_limit(request):
 | 
				
			||||||
 | 
					            rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
 | 
				
			||||||
        return view_func(request, remote_server, *args, **kwargs)
 | 
					        return view_func(request, remote_server, *args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return _wrapped_view_func
 | 
					    return _wrapped_view_func
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user