mirror of
https://github.com/zulip/zulip.git
synced 2025-11-06 15:03:34 +00:00
db: Fix types to accept psycopg2.sql.Composable queries, avoid Any.
Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
committed by
Tim Abbott
parent
d0b40cd7a3
commit
cebac3f35a
@@ -1,18 +1,21 @@
|
|||||||
import time
|
import time
|
||||||
from psycopg2.extensions import cursor, connection
|
from psycopg2.extensions import cursor, connection
|
||||||
|
from psycopg2.sql import Composable
|
||||||
|
|
||||||
from typing import Callable, Optional, Iterable, Any, Dict, List, Union, TypeVar, \
|
from typing import Callable, Optional, Iterable, Any, Dict, List, Union, TypeVar, \
|
||||||
Mapping
|
Mapping, Sequence
|
||||||
|
|
||||||
CursorObj = TypeVar('CursorObj', bound=cursor)
|
CursorObj = TypeVar('CursorObj', bound=cursor)
|
||||||
ParamsT = Union[Iterable[Any], Mapping[str, Any]]
|
Query = Union[str, Composable]
|
||||||
|
Params = Union[Sequence[object], Mapping[str, object]]
|
||||||
|
ParamsT = TypeVar('ParamsT')
|
||||||
|
|
||||||
# Similar to the tracking done in Django's CursorDebugWrapper, but done at the
|
# Similar to the tracking done in Django's CursorDebugWrapper, but done at the
|
||||||
# psycopg2 cursor level so it works with SQLAlchemy.
|
# psycopg2 cursor level so it works with SQLAlchemy.
|
||||||
def wrapper_execute(self: CursorObj,
|
def wrapper_execute(self: CursorObj,
|
||||||
action: Callable[[str, Optional[ParamsT]], CursorObj],
|
action: Callable[[Query, ParamsT], CursorObj],
|
||||||
sql: str,
|
sql: Query,
|
||||||
params: Optional[ParamsT]=()) -> CursorObj:
|
params: ParamsT) -> CursorObj:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
try:
|
try:
|
||||||
return action(sql, params)
|
return action(sql, params)
|
||||||
@@ -26,12 +29,12 @@ def wrapper_execute(self: CursorObj,
|
|||||||
class TimeTrackingCursor(cursor):
|
class TimeTrackingCursor(cursor):
|
||||||
"""A psycopg2 cursor class that tracks the time spent executing queries."""
|
"""A psycopg2 cursor class that tracks the time spent executing queries."""
|
||||||
|
|
||||||
def execute(self, query: str,
|
def execute(self, query: Query,
|
||||||
vars: Optional[ParamsT]=None) -> 'TimeTrackingCursor':
|
vars: Optional[Params]=None) -> 'TimeTrackingCursor':
|
||||||
return wrapper_execute(self, super().execute, query, vars)
|
return wrapper_execute(self, super().execute, query, vars)
|
||||||
|
|
||||||
def executemany(self, query: str,
|
def executemany(self, query: Query,
|
||||||
vars: Iterable[Any]) -> 'TimeTrackingCursor':
|
vars: Iterable[Params]) -> 'TimeTrackingCursor':
|
||||||
return wrapper_execute(self, super().executemany, query, vars)
|
return wrapper_execute(self, super().executemany, query, vars)
|
||||||
|
|
||||||
class TimeTrackingConnection(connection):
|
class TimeTrackingConnection(connection):
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from zerver.lib.actions import do_set_realm_property
|
|||||||
from zerver.lib.upload import S3UploadBackend, LocalUploadBackend
|
from zerver.lib.upload import S3UploadBackend, LocalUploadBackend
|
||||||
from zerver.lib.avatar import avatar_url
|
from zerver.lib.avatar import avatar_url
|
||||||
from zerver.lib.cache import get_cache_backend
|
from zerver.lib.cache import get_cache_backend
|
||||||
from zerver.lib.db import TimeTrackingCursor
|
from zerver.lib.db import Params, ParamsT, Query, TimeTrackingCursor
|
||||||
from zerver.lib import cache
|
from zerver.lib import cache
|
||||||
from zerver.tornado import event_queue
|
from zerver.tornado import event_queue
|
||||||
from zerver.tornado.handlers import allocate_handler_id
|
from zerver.tornado.handlers import allocate_handler_id
|
||||||
@@ -147,9 +147,9 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
|
|||||||
queries: List[Dict[str, Union[str, bytes]]] = []
|
queries: List[Dict[str, Union[str, bytes]]] = []
|
||||||
|
|
||||||
def wrapper_execute(self: TimeTrackingCursor,
|
def wrapper_execute(self: TimeTrackingCursor,
|
||||||
action: Callable[[str, Iterable[Any]], None],
|
action: Callable[[str, ParamsT], None],
|
||||||
sql: str,
|
sql: Query,
|
||||||
params: Iterable[Any]=()) -> None:
|
params: ParamsT) -> None:
|
||||||
cache = get_cache_backend(None)
|
cache = get_cache_backend(None)
|
||||||
cache.clear()
|
cache.clear()
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@@ -158,7 +158,7 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
|
|||||||
finally:
|
finally:
|
||||||
stop = time.time()
|
stop = time.time()
|
||||||
duration = stop - start
|
duration = stop - start
|
||||||
if include_savepoints or ('SAVEPOINT' not in sql):
|
if include_savepoints or not isinstance(sql, str) or 'SAVEPOINT' not in sql:
|
||||||
queries.append({
|
queries.append({
|
||||||
'sql': self.mogrify(sql, params).decode('utf-8'),
|
'sql': self.mogrify(sql, params).decode('utf-8'),
|
||||||
'time': "%.3f" % (duration,),
|
'time': "%.3f" % (duration,),
|
||||||
@@ -167,13 +167,13 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
|
|||||||
old_execute = TimeTrackingCursor.execute
|
old_execute = TimeTrackingCursor.execute
|
||||||
old_executemany = TimeTrackingCursor.executemany
|
old_executemany = TimeTrackingCursor.executemany
|
||||||
|
|
||||||
def cursor_execute(self: TimeTrackingCursor, sql: str,
|
def cursor_execute(self: TimeTrackingCursor, sql: Query,
|
||||||
params: Iterable[Any]=()) -> None:
|
params: Optional[Params]=None) -> None:
|
||||||
return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params)
|
return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params)
|
||||||
TimeTrackingCursor.execute = cursor_execute # type: ignore[assignment] # https://github.com/JukkaL/mypy/issues/1167
|
TimeTrackingCursor.execute = cursor_execute # type: ignore[assignment] # https://github.com/JukkaL/mypy/issues/1167
|
||||||
|
|
||||||
def cursor_executemany(self: TimeTrackingCursor, sql: str,
|
def cursor_executemany(self: TimeTrackingCursor, sql: Query,
|
||||||
params: Iterable[Any]=()) -> None:
|
params: Iterable[Params]) -> None:
|
||||||
return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params) # nocoverage -- doesn't actually get used in tests
|
return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params) # nocoverage -- doesn't actually get used in tests
|
||||||
TimeTrackingCursor.executemany = cursor_executemany # type: ignore[assignment] # https://github.com/JukkaL/mypy/issues/1167
|
TimeTrackingCursor.executemany = cursor_executemany # type: ignore[assignment] # https://github.com/JukkaL/mypy/issues/1167
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user