mirror of
https://github.com/zulip/zulip.git
synced 2025-11-19 22:19:48 +00:00
test_helpers: Convert TypedDict from queries_captured to dataclass.
An implicit coercion from an untyped dict to the TypedDict was hiding
a type error: CapturedQuery.sql was really str, not bytes. We should
always prefer dataclass over TypedDict to prevent such errors.
Signed-off-by: Anders Kaseorg <anders@zulip.com>
(cherry picked from commit 92db6eba78)
This commit is contained in:
committed by
Alex Vandiver
parent
9628cc9278
commit
201cab601a
@@ -1189,7 +1189,7 @@ Output:
|
||||
if actual_count != count: # nocoverage
|
||||
print("\nITEMS:\n")
|
||||
for index, query in enumerate(queries):
|
||||
print(f"#{index + 1}\nsql: {str(query['sql'])}\ntime: {query['time']}\n")
|
||||
print(f"#{index + 1}\nsql: {str(query.sql)}\ntime: {query.time}\n")
|
||||
print(f"expected count: {count}\nactual count: {actual_count}")
|
||||
raise AssertionError(
|
||||
f"""
|
||||
|
||||
@@ -4,6 +4,7 @@ import re
|
||||
import sys
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
IO,
|
||||
TYPE_CHECKING,
|
||||
@@ -16,7 +17,6 @@ from typing import (
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@@ -133,21 +133,22 @@ def simulated_empty_cache() -> Iterator[List[Tuple[str, Union[str, List[str]], O
|
||||
yield cache_queries
|
||||
|
||||
|
||||
class CapturedQueryDict(TypedDict):
|
||||
sql: bytes
|
||||
@dataclass
|
||||
class CapturedQuery:
|
||||
sql: str
|
||||
time: str
|
||||
|
||||
|
||||
@contextmanager
|
||||
def queries_captured(
|
||||
include_savepoints: bool = False, keep_cache_warm: bool = False
|
||||
) -> Iterator[List[CapturedQueryDict]]:
|
||||
) -> Iterator[List[CapturedQuery]]:
|
||||
"""
|
||||
Allow a user to capture just the queries executed during
|
||||
the with statement.
|
||||
"""
|
||||
|
||||
queries: List[CapturedQueryDict] = []
|
||||
queries: List[CapturedQuery] = []
|
||||
|
||||
def wrapper_execute(
|
||||
self: TimeTrackingCursor,
|
||||
@@ -163,10 +164,10 @@ def queries_captured(
|
||||
duration = stop - start
|
||||
if include_savepoints or not isinstance(sql, str) or "SAVEPOINT" not in sql:
|
||||
queries.append(
|
||||
{
|
||||
"sql": self.mogrify(sql, params).decode(),
|
||||
"time": f"{duration:.3f}",
|
||||
}
|
||||
CapturedQuery(
|
||||
sql=self.mogrify(sql, params).decode(),
|
||||
time=f"{duration:.3f}",
|
||||
)
|
||||
)
|
||||
|
||||
def cursor_execute(
|
||||
|
||||
@@ -76,7 +76,7 @@ class EditMessageTestCase(ZulipTestCase):
|
||||
|
||||
self.assert_length(queries, 1)
|
||||
for query in queries:
|
||||
self.assertNotIn("message", query["sql"])
|
||||
self.assertNotIn("message", query.sql)
|
||||
|
||||
self.assertEqual(
|
||||
fetch_message_dict[TOPIC_NAME],
|
||||
|
||||
@@ -3289,7 +3289,7 @@ class GetOldMessagesTest(ZulipTestCase):
|
||||
get_messages_backend(request, user_profile)
|
||||
|
||||
for query in queries:
|
||||
sql = str(query["sql"])
|
||||
sql = str(query.sql)
|
||||
if "/* get_messages */" in sql:
|
||||
sql = sql.replace(" /* get_messages */", "")
|
||||
self.assertEqual(sql, expected)
|
||||
@@ -3470,9 +3470,9 @@ class GetOldMessagesTest(ZulipTestCase):
|
||||
get_messages_backend(request, user_profile)
|
||||
|
||||
# Verify the query for old messages looks correct.
|
||||
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
|
||||
queries = [q for q in all_queries if "/* get_messages */" in q.sql]
|
||||
self.assert_length(queries, 1)
|
||||
sql = queries[0]["sql"]
|
||||
sql = queries[0].sql
|
||||
self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql)
|
||||
self.assertIn("ORDER BY message_id ASC", sql)
|
||||
|
||||
@@ -3516,9 +3516,9 @@ class GetOldMessagesTest(ZulipTestCase):
|
||||
with queries_captured() as all_queries:
|
||||
get_messages_backend(request, user_profile)
|
||||
|
||||
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
|
||||
queries = [q for q in all_queries if "/* get_messages */" in q.sql]
|
||||
self.assert_length(queries, 1)
|
||||
sql = queries[0]["sql"]
|
||||
sql = queries[0].sql
|
||||
self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql)
|
||||
self.assertIn("ORDER BY message_id ASC", sql)
|
||||
cond = f"WHERE user_profile_id = {user_profile.id} AND message_id <= {first_unread_message_id - 1}"
|
||||
@@ -3540,10 +3540,10 @@ class GetOldMessagesTest(ZulipTestCase):
|
||||
with queries_captured() as all_queries:
|
||||
get_messages_backend(request, user_profile)
|
||||
|
||||
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
|
||||
queries = [q for q in all_queries if "/* get_messages */" in q.sql]
|
||||
self.assert_length(queries, 1)
|
||||
|
||||
sql = queries[0]["sql"]
|
||||
sql = queries[0].sql
|
||||
|
||||
self.assertNotIn("AND message_id <=", sql)
|
||||
self.assertNotIn("AND message_id >=", sql)
|
||||
@@ -3553,8 +3553,8 @@ class GetOldMessagesTest(ZulipTestCase):
|
||||
with first_visible_id_as(first_visible_message_id):
|
||||
with queries_captured() as all_queries:
|
||||
get_messages_backend(request, user_profile)
|
||||
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
|
||||
sql = queries[0]["sql"]
|
||||
queries = [q for q in all_queries if "/* get_messages */" in q.sql]
|
||||
sql = queries[0].sql
|
||||
self.assertNotIn("AND message_id <=", sql)
|
||||
self.assertNotIn("AND message_id >=", sql)
|
||||
|
||||
@@ -3595,20 +3595,20 @@ class GetOldMessagesTest(ZulipTestCase):
|
||||
|
||||
# Do some tests on the main query, to verify the muting logic
|
||||
# runs on this code path.
|
||||
queries = [q for q in all_queries if str(q["sql"]).startswith("SELECT message_id, flags")]
|
||||
queries = [q for q in all_queries if q.sql.startswith("SELECT message_id, flags")]
|
||||
self.assert_length(queries, 1)
|
||||
|
||||
stream = get_stream("Scotland", realm)
|
||||
assert stream.recipient is not None
|
||||
recipient_id = stream.recipient.id
|
||||
cond = f"AND NOT (recipient_id = {recipient_id} AND upper(subject) = upper('golf'))"
|
||||
self.assertIn(cond, queries[0]["sql"])
|
||||
self.assertIn(cond, queries[0].sql)
|
||||
|
||||
# Next, verify the use_first_unread_anchor setting invokes
|
||||
# the `message_id = LARGER_THAN_MAX_MESSAGE_ID` hack.
|
||||
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
|
||||
queries = [q for q in all_queries if "/* get_messages */" in q.sql]
|
||||
self.assert_length(queries, 1)
|
||||
self.assertIn(f"AND zerver_message.id = {LARGER_THAN_MAX_MESSAGE_ID}", queries[0]["sql"])
|
||||
self.assertIn(f"AND zerver_message.id = {LARGER_THAN_MAX_MESSAGE_ID}", queries[0].sql)
|
||||
|
||||
def test_exclude_muting_conditions(self) -> None:
|
||||
realm = get_realm("zulip")
|
||||
|
||||
Reference in New Issue
Block a user