diff --git a/zerver/tests/test_message_fetch.py b/zerver/tests/test_message_fetch.py index cb6c44dc85..ebfb460586 100644 --- a/zerver/tests/test_message_fetch.py +++ b/zerver/tests/test_message_fetch.py @@ -3655,13 +3655,16 @@ class GetOldMessagesTest(ZulipTestCase): ) self.assertEqual(final_dict["content"], "

test content

") - def common_check_get_messages_query( - self, query_params: Dict[str, object], expected: str - ) -> None: + def common_check_get_messages_query(self, query_params: Dict[str, Any], expected: str) -> None: user_profile = self.example_user("hamlet") request = HostRequestMock(query_params, user_profile) with queries_captured() as queries: - get_messages_backend(request, user_profile) + get_messages_backend( + request, + user_profile, + num_before=query_params["num_before"], + num_after=query_params["num_after"], + ) for query in queries: sql = str(query.sql) @@ -3721,7 +3724,12 @@ class GetOldMessagesTest(ZulipTestCase): ) request = HostRequestMock(query_params, user_profile) - payload = get_messages_backend(request, user_profile) + payload = get_messages_backend( + request, + user_profile, + num_before=10, + num_after=10, + ) result = orjson.loads(payload.content) self.assertEqual(result["anchor"], first_message_id) self.assertEqual(result["found_newest"], True) @@ -3758,7 +3766,12 @@ class GetOldMessagesTest(ZulipTestCase): ) request = HostRequestMock(query_params, user_profile) - payload = get_messages_backend(request, user_profile) + payload = get_messages_backend( + request, + user_profile, + num_before=10, + num_after=10, + ) result = orjson.loads(payload.content) self.assertEqual(result["anchor"], first_message_id) @@ -3771,7 +3784,12 @@ class GetOldMessagesTest(ZulipTestCase): ) request = HostRequestMock(query_params, user_profile) - payload = get_messages_backend(request, user_profile) + payload = get_messages_backend( + request, + user_profile, + num_before=10, + num_after=10, + ) result = orjson.loads(payload.content) self.assertEqual(result["anchor"], 0) @@ -3785,7 +3803,12 @@ class GetOldMessagesTest(ZulipTestCase): ) request = HostRequestMock(query_params, user_profile) - payload = get_messages_backend(request, user_profile) + payload = get_messages_backend( + request, + user_profile, + num_before=10, + num_after=10, + ) result = orjson.loads(payload.content) self.assertEqual(result["anchor"], LARGER_THAN_MAX_MESSAGE_ID) @@ -3799,7 +3822,12 @@ class GetOldMessagesTest(ZulipTestCase): ) request = HostRequestMock(query_params, user_profile) - payload = get_messages_backend(request, user_profile) + payload = get_messages_backend( + request, + user_profile, + num_before=10, + num_after=10, + ) result = orjson.loads(payload.content) self.assertEqual(result["anchor"], 0) @@ -3813,7 +3841,12 @@ class GetOldMessagesTest(ZulipTestCase): ) request = HostRequestMock(query_params, user_profile) - payload = get_messages_backend(request, user_profile) + payload = get_messages_backend( + request, + user_profile, + num_before=10, + num_after=10, + ) result = orjson.loads(payload.content) self.assertEqual(result["anchor"], LARGER_THAN_MAX_MESSAGE_ID) @@ -3842,7 +3875,12 @@ class GetOldMessagesTest(ZulipTestCase): request = HostRequestMock(query_params, user_profile) with queries_captured() as all_queries: - get_messages_backend(request, user_profile) + get_messages_backend( + request, + user_profile, + num_before=10, + num_after=10, + ) # Verify the query for old messages looks correct. queries = [q for q in all_queries if "/* get_messages */" in q.sql] @@ -3889,7 +3927,12 @@ class GetOldMessagesTest(ZulipTestCase): first_visible_message_id = first_unread_message_id + 2 with first_visible_id_as(first_visible_message_id): with queries_captured() as all_queries: - get_messages_backend(request, user_profile) + get_messages_backend( + request, + user_profile, + num_before=10, + num_after=10, + ) queries = [q for q in all_queries if "/* get_messages */" in q.sql] self.assert_length(queries, 1) @@ -3913,7 +3956,12 @@ class GetOldMessagesTest(ZulipTestCase): request = HostRequestMock(query_params, user_profile) with queries_captured() as all_queries: - get_messages_backend(request, user_profile) + get_messages_backend( + request, + user_profile, + num_before=10, + num_after=10, + ) queries = [q for q in all_queries if "/* get_messages */" in q.sql] self.assert_length(queries, 1) @@ -3927,7 +3975,12 @@ class GetOldMessagesTest(ZulipTestCase): first_visible_message_id = 5 with first_visible_id_as(first_visible_message_id): with queries_captured() as all_queries: - get_messages_backend(request, user_profile) + get_messages_backend( + request, + user_profile, + num_before=10, + num_after=10, + ) queries = [q for q in all_queries if "/* get_messages */" in q.sql] sql = queries[0].sql self.assertNotIn("AND message_id <=", sql) @@ -3966,7 +4019,12 @@ class GetOldMessagesTest(ZulipTestCase): request = HostRequestMock(query_params, user_profile) with queries_captured() as all_queries: - get_messages_backend(request, user_profile) + get_messages_backend( + request, + user_profile, + num_before=0, + num_after=0, + ) # Do some tests on the main query, to verify the muting logic # runs on this code path. diff --git a/zilencer/management/commands/profile_request.py b/zilencer/management/commands/profile_request.py index 40f6071cd1..92e9e58714 100644 --- a/zilencer/management/commands/profile_request.py +++ b/zilencer/management/commands/profile_request.py @@ -21,9 +21,16 @@ class MockSession(SessionBase): self.modified = False -def profile_request(request: HttpRequest) -> HttpResponseBase: +def profile_request(request: HttpRequest, num_before: int, num_after: int) -> HttpResponseBase: def get_response(request: HttpRequest) -> HttpResponseBase: - return prof.runcall(get_messages_backend, request, request.user, apply_markdown=True) + return prof.runcall( + get_messages_backend, + request, + request.user, + num_before=num_before, + num_after=num_after, + apply_markdown=True, + ) prof = cProfile.Profile() with tempfile.NamedTemporaryFile(prefix="profile.data.", delete=False) as stats_file: @@ -58,4 +65,4 @@ class Command(ZulipBaseCommand): mock_request.session = MockSession() RequestNotes.get_notes(mock_request).log_data = None - profile_request(mock_request) + profile_request(mock_request, num_before=1200, num_after=200)