test_helpers: Fix logging in cursor_executemany mock.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg
2023-11-15 13:25:00 -08:00
committed by Tim Abbott
parent cb9a04d3e3
commit a688e753de
2 changed files with 24 additions and 20 deletions

View File

@@ -1,4 +1,5 @@
import collections
import itertools
import os
import re
import sys
@@ -45,7 +46,7 @@ from zerver.actions.user_settings import do_change_user_setting
from zerver.lib import cache
from zerver.lib.avatar import avatar_url
from zerver.lib.cache import get_cache_backend
from zerver.lib.db import Params, ParamsT, Query, TimeTrackingCursor
from zerver.lib.db import Params, Query, TimeTrackingCursor
from zerver.lib.integrations import WEBHOOK_INTEGRATIONS
from zerver.lib.per_request_cache import flush_per_request_caches
from zerver.lib.rate_limiter import RateLimitedIPAddr, rules
@@ -150,35 +151,38 @@ def queries_captured(
queries: List[CapturedQuery] = []
def wrapper_execute(
self: TimeTrackingCursor,
action: Callable[[Query, ParamsT], None],
sql: Query,
params: ParamsT,
) -> None:
def cursor_execute(self: TimeTrackingCursor, sql: Query, vars: Optional[Params] = None) -> None:
start = time.time()
try:
return action(sql, params)
return super(TimeTrackingCursor, self).execute(sql, vars)
finally:
stop = time.time()
duration = stop - start
if include_savepoints or not isinstance(sql, str) or "SAVEPOINT" not in sql:
queries.append(
CapturedQuery(
sql=self.mogrify(sql, params).decode(),
sql=self.mogrify(sql, vars).decode(),
time=f"{duration:.3f}",
)
)
def cursor_execute(
self: TimeTrackingCursor, sql: Query, params: Optional[Params] = None
) -> None:
return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params)
def cursor_executemany(self: TimeTrackingCursor, sql: Query, params: Iterable[Params]) -> None:
return wrapper_execute(
self, super(TimeTrackingCursor, self).executemany, sql, params
) # nocoverage -- doesn't actually get used in tests
def cursor_executemany(
self: TimeTrackingCursor, sql: Query, vars_list: Iterable[Params]
) -> None: # nocoverage -- doesn't actually get used in tests
vars_list, vars_list1 = itertools.tee(vars_list)
start = time.time()
try:
return super(TimeTrackingCursor, self).executemany(sql, vars_list)
finally:
stop = time.time()
duration = stop - start
queries.extend(
CapturedQuery(
sql=self.mogrify(sql, vars).decode(),
time=f"{duration:.3f}",
)
for vars in vars_list1
)
if not keep_cache_warm:
cache = get_cache_backend(None)