test_helpers: Convert TypedDict from queries_captured to dataclass.

An implicit coercion from an untyped dict to the TypedDict was hiding
a type error: CapturedQuery.sql was really str, not bytes.  We should
always prefer dataclass over TypedDict to prevent such errors.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
(cherry picked from commit 92db6eba78)
This commit is contained in:
Anders Kaseorg
2023-06-06 14:54:19 -07:00
committed by Alex Vandiver
parent 9628cc9278
commit 201cab601a
4 changed files with 25 additions and 24 deletions

View File

@@ -1189,7 +1189,7 @@ Output:
if actual_count != count: # nocoverage if actual_count != count: # nocoverage
print("\nITEMS:\n") print("\nITEMS:\n")
for index, query in enumerate(queries): for index, query in enumerate(queries):
print(f"#{index + 1}\nsql: {str(query['sql'])}\ntime: {query['time']}\n") print(f"#{index + 1}\nsql: {str(query.sql)}\ntime: {query.time}\n")
print(f"expected count: {count}\nactual count: {actual_count}") print(f"expected count: {count}\nactual count: {actual_count}")
raise AssertionError( raise AssertionError(
f""" f"""

View File

@@ -4,6 +4,7 @@ import re
import sys import sys
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass
from typing import ( from typing import (
IO, IO,
TYPE_CHECKING, TYPE_CHECKING,
@@ -16,7 +17,6 @@ from typing import (
Mapping, Mapping,
Optional, Optional,
Tuple, Tuple,
TypedDict,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@@ -133,21 +133,22 @@ def simulated_empty_cache() -> Iterator[List[Tuple[str, Union[str, List[str]], O
yield cache_queries yield cache_queries
class CapturedQueryDict(TypedDict): @dataclass
sql: bytes class CapturedQuery:
sql: str
time: str time: str
@contextmanager @contextmanager
def queries_captured( def queries_captured(
include_savepoints: bool = False, keep_cache_warm: bool = False include_savepoints: bool = False, keep_cache_warm: bool = False
) -> Iterator[List[CapturedQueryDict]]: ) -> Iterator[List[CapturedQuery]]:
""" """
Allow a user to capture just the queries executed during Allow a user to capture just the queries executed during
the with statement. the with statement.
""" """
queries: List[CapturedQueryDict] = [] queries: List[CapturedQuery] = []
def wrapper_execute( def wrapper_execute(
self: TimeTrackingCursor, self: TimeTrackingCursor,
@@ -163,10 +164,10 @@ def queries_captured(
duration = stop - start duration = stop - start
if include_savepoints or not isinstance(sql, str) or "SAVEPOINT" not in sql: if include_savepoints or not isinstance(sql, str) or "SAVEPOINT" not in sql:
queries.append( queries.append(
{ CapturedQuery(
"sql": self.mogrify(sql, params).decode(), sql=self.mogrify(sql, params).decode(),
"time": f"{duration:.3f}", time=f"{duration:.3f}",
} )
) )
def cursor_execute( def cursor_execute(

View File

@@ -76,7 +76,7 @@ class EditMessageTestCase(ZulipTestCase):
self.assert_length(queries, 1) self.assert_length(queries, 1)
for query in queries: for query in queries:
self.assertNotIn("message", query["sql"]) self.assertNotIn("message", query.sql)
self.assertEqual( self.assertEqual(
fetch_message_dict[TOPIC_NAME], fetch_message_dict[TOPIC_NAME],

View File

@@ -3289,7 +3289,7 @@ class GetOldMessagesTest(ZulipTestCase):
get_messages_backend(request, user_profile) get_messages_backend(request, user_profile)
for query in queries: for query in queries:
sql = str(query["sql"]) sql = str(query.sql)
if "/* get_messages */" in sql: if "/* get_messages */" in sql:
sql = sql.replace(" /* get_messages */", "") sql = sql.replace(" /* get_messages */", "")
self.assertEqual(sql, expected) self.assertEqual(sql, expected)
@@ -3470,9 +3470,9 @@ class GetOldMessagesTest(ZulipTestCase):
get_messages_backend(request, user_profile) get_messages_backend(request, user_profile)
# Verify the query for old messages looks correct. # Verify the query for old messages looks correct.
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])] queries = [q for q in all_queries if "/* get_messages */" in q.sql]
self.assert_length(queries, 1) self.assert_length(queries, 1)
sql = queries[0]["sql"] sql = queries[0].sql
self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql) self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql)
self.assertIn("ORDER BY message_id ASC", sql) self.assertIn("ORDER BY message_id ASC", sql)
@@ -3516,9 +3516,9 @@ class GetOldMessagesTest(ZulipTestCase):
with queries_captured() as all_queries: with queries_captured() as all_queries:
get_messages_backend(request, user_profile) get_messages_backend(request, user_profile)
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])] queries = [q for q in all_queries if "/* get_messages */" in q.sql]
self.assert_length(queries, 1) self.assert_length(queries, 1)
sql = queries[0]["sql"] sql = queries[0].sql
self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql) self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql)
self.assertIn("ORDER BY message_id ASC", sql) self.assertIn("ORDER BY message_id ASC", sql)
cond = f"WHERE user_profile_id = {user_profile.id} AND message_id <= {first_unread_message_id - 1}" cond = f"WHERE user_profile_id = {user_profile.id} AND message_id <= {first_unread_message_id - 1}"
@@ -3540,10 +3540,10 @@ class GetOldMessagesTest(ZulipTestCase):
with queries_captured() as all_queries: with queries_captured() as all_queries:
get_messages_backend(request, user_profile) get_messages_backend(request, user_profile)
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])] queries = [q for q in all_queries if "/* get_messages */" in q.sql]
self.assert_length(queries, 1) self.assert_length(queries, 1)
sql = queries[0]["sql"] sql = queries[0].sql
self.assertNotIn("AND message_id <=", sql) self.assertNotIn("AND message_id <=", sql)
self.assertNotIn("AND message_id >=", sql) self.assertNotIn("AND message_id >=", sql)
@@ -3553,8 +3553,8 @@ class GetOldMessagesTest(ZulipTestCase):
with first_visible_id_as(first_visible_message_id): with first_visible_id_as(first_visible_message_id):
with queries_captured() as all_queries: with queries_captured() as all_queries:
get_messages_backend(request, user_profile) get_messages_backend(request, user_profile)
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])] queries = [q for q in all_queries if "/* get_messages */" in q.sql]
sql = queries[0]["sql"] sql = queries[0].sql
self.assertNotIn("AND message_id <=", sql) self.assertNotIn("AND message_id <=", sql)
self.assertNotIn("AND message_id >=", sql) self.assertNotIn("AND message_id >=", sql)
@@ -3595,20 +3595,20 @@ class GetOldMessagesTest(ZulipTestCase):
# Do some tests on the main query, to verify the muting logic # Do some tests on the main query, to verify the muting logic
# runs on this code path. # runs on this code path.
queries = [q for q in all_queries if str(q["sql"]).startswith("SELECT message_id, flags")] queries = [q for q in all_queries if q.sql.startswith("SELECT message_id, flags")]
self.assert_length(queries, 1) self.assert_length(queries, 1)
stream = get_stream("Scotland", realm) stream = get_stream("Scotland", realm)
assert stream.recipient is not None assert stream.recipient is not None
recipient_id = stream.recipient.id recipient_id = stream.recipient.id
cond = f"AND NOT (recipient_id = {recipient_id} AND upper(subject) = upper('golf'))" cond = f"AND NOT (recipient_id = {recipient_id} AND upper(subject) = upper('golf'))"
self.assertIn(cond, queries[0]["sql"]) self.assertIn(cond, queries[0].sql)
# Next, verify the use_first_unread_anchor setting invokes # Next, verify the use_first_unread_anchor setting invokes
# the `message_id = LARGER_THAN_MAX_MESSAGE_ID` hack. # the `message_id = LARGER_THAN_MAX_MESSAGE_ID` hack.
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])] queries = [q for q in all_queries if "/* get_messages */" in q.sql]
self.assert_length(queries, 1) self.assert_length(queries, 1)
self.assertIn(f"AND zerver_message.id = {LARGER_THAN_MAX_MESSAGE_ID}", queries[0]["sql"]) self.assertIn(f"AND zerver_message.id = {LARGER_THAN_MAX_MESSAGE_ID}", queries[0].sql)
def test_exclude_muting_conditions(self) -> None: def test_exclude_muting_conditions(self) -> None:
realm = get_realm("zulip") realm = get_realm("zulip")