db: Fix types to accept psycopg2.sql.Composable queries, avoid Any.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg
2020-05-03 17:36:15 -07:00
committed by Tim Abbott
parent d0b40cd7a3
commit cebac3f35a
2 changed files with 21 additions and 18 deletions

View File

@@ -17,7 +17,7 @@ from zerver.lib.actions import do_set_realm_property
from zerver.lib.upload import S3UploadBackend, LocalUploadBackend
from zerver.lib.avatar import avatar_url
from zerver.lib.cache import get_cache_backend
from zerver.lib.db import TimeTrackingCursor
from zerver.lib.db import Params, ParamsT, Query, TimeTrackingCursor
from zerver.lib import cache
from zerver.tornado import event_queue
from zerver.tornado.handlers import allocate_handler_id
@@ -147,9 +147,9 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
queries: List[Dict[str, Union[str, bytes]]] = []
def wrapper_execute(self: TimeTrackingCursor,
action: Callable[[str, Iterable[Any]], None],
sql: str,
params: Iterable[Any]=()) -> None:
action: Callable[[str, ParamsT], None],
sql: Query,
params: ParamsT) -> None:
cache = get_cache_backend(None)
cache.clear()
start = time.time()
@@ -158,7 +158,7 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
finally:
stop = time.time()
duration = stop - start
if include_savepoints or ('SAVEPOINT' not in sql):
if include_savepoints or not isinstance(sql, str) or 'SAVEPOINT' not in sql:
queries.append({
'sql': self.mogrify(sql, params).decode('utf-8'),
'time': "%.3f" % (duration,),
@@ -167,13 +167,13 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
old_execute = TimeTrackingCursor.execute
old_executemany = TimeTrackingCursor.executemany
def cursor_execute(self: TimeTrackingCursor, sql: str,
params: Iterable[Any]=()) -> None:
def cursor_execute(self: TimeTrackingCursor, sql: Query,
params: Optional[Params]=None) -> None:
return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params)
TimeTrackingCursor.execute = cursor_execute # type: ignore[assignment] # https://github.com/JukkaL/mypy/issues/1167
def cursor_executemany(self: TimeTrackingCursor, sql: str,
params: Iterable[Any]=()) -> None:
def cursor_executemany(self: TimeTrackingCursor, sql: Query,
params: Iterable[Params]) -> None:
return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params) # nocoverage -- doesn't actually get used in tests
TimeTrackingCursor.executemany = cursor_executemany # type: ignore[assignment] # https://github.com/JukkaL/mypy/issues/1167