test_helpers: Convert queries_captured from TypedDict to dataclass.

An implicit coercion from an untyped dict to the TypedDict was hiding
a type error: CapturedQuery.sql was really str, not the bytes that the
old CapturedQueryDict declared.  We should always prefer a dataclass
over a TypedDict to prevent such errors.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
(cherry picked from commit 92db6eba78)
Anders Kaseorg, 2023-06-06 14:54:19 -07:00 (committed by Alex Vandiver)
parent 9628cc9278, commit 201cab601a
4 changed files with 25 additions and 24 deletions
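
To make the reasoning above concrete, here is a minimal sketch of the failure mode; the names are illustrative only, not the real Zulip helpers. A value typed as Any (for example, one built around an un-stubbed driver call) coerces silently into a TypedDict, so mypy never compares the stored values against the declared field types, whereas a dataclass has to be constructed explicitly, which gives mypy a concrete call site to check.

# Minimal sketch of the failure mode described above; all names here are
# illustrative, not the real Zulip helpers.
from dataclasses import dataclass
from typing import Any, TypedDict


class CapturedQueryDict(TypedDict):
    sql: bytes  # declared bytes, but nothing ever checks the values stored
    time: str


def build_row() -> Any:
    # Stands in for an untyped source (e.g. a dict built around an
    # un-stubbed driver call); its static type is Any.
    return {"sql": "SELECT 1", "time": "0.001"}


# Any coerces silently into the TypedDict, so mypy never notices that the
# stored "sql" value is a str rather than the declared bytes.
row: CapturedQueryDict = build_row()


@dataclass
class CapturedQuery:
    sql: str
    time: str


# A dataclass must be constructed field by field, which gives mypy a
# concrete place to compare each value against its annotation.
ok = CapturedQuery(sql="SELECT 1", time="0.001")
bad = CapturedQuery(sql=b"SELECT 1", time="0.001")  # mypy: bytes is not str

The diff below applies exactly this kind of change to queries_captured and its callers.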


@@ -1189,7 +1189,7 @@ Output:
         if actual_count != count:  # nocoverage
             print("\nITEMS:\n")
             for index, query in enumerate(queries):
-                print(f"#{index + 1}\nsql: {str(query['sql'])}\ntime: {query['time']}\n")
+                print(f"#{index + 1}\nsql: {str(query.sql)}\ntime: {query.time}\n")
             print(f"expected count: {count}\nactual count: {actual_count}")
             raise AssertionError(
                 f"""


@@ -4,6 +4,7 @@ import re
 import sys
 import time
 from contextlib import contextmanager
+from dataclasses import dataclass
 from typing import (
     IO,
     TYPE_CHECKING,
@@ -16,7 +17,6 @@ from typing import (
     Mapping,
     Optional,
     Tuple,
-    TypedDict,
     TypeVar,
     Union,
     cast,
@@ -133,21 +133,22 @@ def simulated_empty_cache() -> Iterator[List[Tuple[str, Union[str, List[str]], O
         yield cache_queries


-class CapturedQueryDict(TypedDict):
-    sql: bytes
+@dataclass
+class CapturedQuery:
+    sql: str
     time: str


 @contextmanager
 def queries_captured(
     include_savepoints: bool = False, keep_cache_warm: bool = False
-) -> Iterator[List[CapturedQueryDict]]:
+) -> Iterator[List[CapturedQuery]]:
     """
     Allow a user to capture just the queries executed during
     the with statement.
     """

-    queries: List[CapturedQueryDict] = []
+    queries: List[CapturedQuery] = []

     def wrapper_execute(
         self: TimeTrackingCursor,
@@ -163,10 +164,10 @@ def queries_captured(
             duration = stop - start
             if include_savepoints or not isinstance(sql, str) or "SAVEPOINT" not in sql:
                 queries.append(
-                    {
-                        "sql": self.mogrify(sql, params).decode(),
-                        "time": f"{duration:.3f}",
-                    }
+                    CapturedQuery(
+                        sql=self.mogrify(sql, params).decode(),
+                        time=f"{duration:.3f}",
+                    )
                 )

     def cursor_execute(
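
An aside on why the field is str rather than bytes: under Python 3, psycopg2's cursor.mogrify() returns the parameter-bound query as bytes, so the .decode() in the wrapper above always stores a str. A hedged illustration follows; it assumes psycopg2 is installed and a database is reachable, and the DSN and query are made up.

import psycopg2  # assumes psycopg2 is available, as in Zulip's backend

conn = psycopg2.connect("dbname=zulip_test")  # illustrative DSN
cur = conn.cursor()

# mogrify() renders the query with its parameters bound and returns bytes...
rendered = cur.mogrify("SELECT id FROM zerver_message WHERE id = %s", (42,))
assert isinstance(rendered, bytes)

# ...so decoding it, as queries_captured does, yields the str that gets stored.
assert isinstance(rendered.decode(), str)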


@@ -76,7 +76,7 @@ class EditMessageTestCase(ZulipTestCase):

         self.assert_length(queries, 1)
         for query in queries:
-            self.assertNotIn("message", query["sql"])
+            self.assertNotIn("message", query.sql)

         self.assertEqual(
             fetch_message_dict[TOPIC_NAME],


@@ -3289,7 +3289,7 @@ class GetOldMessagesTest(ZulipTestCase):
             get_messages_backend(request, user_profile)

         for query in queries:
-            sql = str(query["sql"])
+            sql = str(query.sql)
             if "/* get_messages */" in sql:
                 sql = sql.replace(" /* get_messages */", "")
                 self.assertEqual(sql, expected)
@@ -3470,9 +3470,9 @@ class GetOldMessagesTest(ZulipTestCase):
             get_messages_backend(request, user_profile)

         # Verify the query for old messages looks correct.
-        queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
+        queries = [q for q in all_queries if "/* get_messages */" in q.sql]
         self.assert_length(queries, 1)
-        sql = queries[0]["sql"]
+        sql = queries[0].sql
         self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql)
         self.assertIn("ORDER BY message_id ASC", sql)
@@ -3516,9 +3516,9 @@ class GetOldMessagesTest(ZulipTestCase):
         with queries_captured() as all_queries:
             get_messages_backend(request, user_profile)

-        queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
+        queries = [q for q in all_queries if "/* get_messages */" in q.sql]
         self.assert_length(queries, 1)
-        sql = queries[0]["sql"]
+        sql = queries[0].sql
         self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql)
         self.assertIn("ORDER BY message_id ASC", sql)
         cond = f"WHERE user_profile_id = {user_profile.id} AND message_id <= {first_unread_message_id - 1}"
@@ -3540,10 +3540,10 @@ class GetOldMessagesTest(ZulipTestCase):
         with queries_captured() as all_queries:
             get_messages_backend(request, user_profile)

-        queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
+        queries = [q for q in all_queries if "/* get_messages */" in q.sql]
         self.assert_length(queries, 1)

-        sql = queries[0]["sql"]
+        sql = queries[0].sql
         self.assertNotIn("AND message_id <=", sql)
         self.assertNotIn("AND message_id >=", sql)
@@ -3553,8 +3553,8 @@ class GetOldMessagesTest(ZulipTestCase):
         with first_visible_id_as(first_visible_message_id):
             with queries_captured() as all_queries:
                 get_messages_backend(request, user_profile)
-        queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
-        sql = queries[0]["sql"]
+        queries = [q for q in all_queries if "/* get_messages */" in q.sql]
+        sql = queries[0].sql
         self.assertNotIn("AND message_id <=", sql)
         self.assertNotIn("AND message_id >=", sql)
@@ -3595,20 +3595,20 @@ class GetOldMessagesTest(ZulipTestCase):

         # Do some tests on the main query, to verify the muting logic
         # runs on this code path.
-        queries = [q for q in all_queries if str(q["sql"]).startswith("SELECT message_id, flags")]
+        queries = [q for q in all_queries if q.sql.startswith("SELECT message_id, flags")]
         self.assert_length(queries, 1)

         stream = get_stream("Scotland", realm)
         assert stream.recipient is not None
         recipient_id = stream.recipient.id
         cond = f"AND NOT (recipient_id = {recipient_id} AND upper(subject) = upper('golf'))"
-        self.assertIn(cond, queries[0]["sql"])
+        self.assertIn(cond, queries[0].sql)

         # Next, verify the use_first_unread_anchor setting invokes
         # the `message_id = LARGER_THAN_MAX_MESSAGE_ID` hack.
-        queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
+        queries = [q for q in all_queries if "/* get_messages */" in q.sql]
         self.assert_length(queries, 1)
-        self.assertIn(f"AND zerver_message.id = {LARGER_THAN_MAX_MESSAGE_ID}", queries[0]["sql"])
+        self.assertIn(f"AND zerver_message.id = {LARGER_THAN_MAX_MESSAGE_ID}", queries[0].sql)

     def test_exclude_muting_conditions(self) -> None:
         realm = get_realm("zulip")