mirror of
https://github.com/zulip/zulip.git
synced 2025-11-23 07:52:35 +00:00
test_helpers: Convert TypedDict from queries_captured to dataclass.
An implicit coercion from an untyped dict to the TypedDict was hiding
a type error: CapturedQuery.sql was really str, not bytes. We should
always prefer dataclass over TypedDict to prevent such errors.
Signed-off-by: Anders Kaseorg <anders@zulip.com>
(cherry picked from commit 92db6eba78)
This commit is contained in:
committed by
Alex Vandiver
parent
9628cc9278
commit
201cab601a
@@ -1189,7 +1189,7 @@ Output:
|
|||||||
if actual_count != count: # nocoverage
|
if actual_count != count: # nocoverage
|
||||||
print("\nITEMS:\n")
|
print("\nITEMS:\n")
|
||||||
for index, query in enumerate(queries):
|
for index, query in enumerate(queries):
|
||||||
print(f"#{index + 1}\nsql: {str(query['sql'])}\ntime: {query['time']}\n")
|
print(f"#{index + 1}\nsql: {str(query.sql)}\ntime: {query.time}\n")
|
||||||
print(f"expected count: {count}\nactual count: {actual_count}")
|
print(f"expected count: {count}\nactual count: {actual_count}")
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
f"""
|
f"""
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
IO,
|
IO,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@@ -16,7 +17,6 @@ from typing import (
|
|||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
TypedDict,
|
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
@@ -133,21 +133,22 @@ def simulated_empty_cache() -> Iterator[List[Tuple[str, Union[str, List[str]], O
|
|||||||
yield cache_queries
|
yield cache_queries
|
||||||
|
|
||||||
|
|
||||||
class CapturedQueryDict(TypedDict):
|
@dataclass
|
||||||
sql: bytes
|
class CapturedQuery:
|
||||||
|
sql: str
|
||||||
time: str
|
time: str
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def queries_captured(
|
def queries_captured(
|
||||||
include_savepoints: bool = False, keep_cache_warm: bool = False
|
include_savepoints: bool = False, keep_cache_warm: bool = False
|
||||||
) -> Iterator[List[CapturedQueryDict]]:
|
) -> Iterator[List[CapturedQuery]]:
|
||||||
"""
|
"""
|
||||||
Allow a user to capture just the queries executed during
|
Allow a user to capture just the queries executed during
|
||||||
the with statement.
|
the with statement.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
queries: List[CapturedQueryDict] = []
|
queries: List[CapturedQuery] = []
|
||||||
|
|
||||||
def wrapper_execute(
|
def wrapper_execute(
|
||||||
self: TimeTrackingCursor,
|
self: TimeTrackingCursor,
|
||||||
@@ -163,10 +164,10 @@ def queries_captured(
|
|||||||
duration = stop - start
|
duration = stop - start
|
||||||
if include_savepoints or not isinstance(sql, str) or "SAVEPOINT" not in sql:
|
if include_savepoints or not isinstance(sql, str) or "SAVEPOINT" not in sql:
|
||||||
queries.append(
|
queries.append(
|
||||||
{
|
CapturedQuery(
|
||||||
"sql": self.mogrify(sql, params).decode(),
|
sql=self.mogrify(sql, params).decode(),
|
||||||
"time": f"{duration:.3f}",
|
time=f"{duration:.3f}",
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def cursor_execute(
|
def cursor_execute(
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ class EditMessageTestCase(ZulipTestCase):
|
|||||||
|
|
||||||
self.assert_length(queries, 1)
|
self.assert_length(queries, 1)
|
||||||
for query in queries:
|
for query in queries:
|
||||||
self.assertNotIn("message", query["sql"])
|
self.assertNotIn("message", query.sql)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
fetch_message_dict[TOPIC_NAME],
|
fetch_message_dict[TOPIC_NAME],
|
||||||
|
|||||||
@@ -3289,7 +3289,7 @@ class GetOldMessagesTest(ZulipTestCase):
|
|||||||
get_messages_backend(request, user_profile)
|
get_messages_backend(request, user_profile)
|
||||||
|
|
||||||
for query in queries:
|
for query in queries:
|
||||||
sql = str(query["sql"])
|
sql = str(query.sql)
|
||||||
if "/* get_messages */" in sql:
|
if "/* get_messages */" in sql:
|
||||||
sql = sql.replace(" /* get_messages */", "")
|
sql = sql.replace(" /* get_messages */", "")
|
||||||
self.assertEqual(sql, expected)
|
self.assertEqual(sql, expected)
|
||||||
@@ -3470,9 +3470,9 @@ class GetOldMessagesTest(ZulipTestCase):
|
|||||||
get_messages_backend(request, user_profile)
|
get_messages_backend(request, user_profile)
|
||||||
|
|
||||||
# Verify the query for old messages looks correct.
|
# Verify the query for old messages looks correct.
|
||||||
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
|
queries = [q for q in all_queries if "/* get_messages */" in q.sql]
|
||||||
self.assert_length(queries, 1)
|
self.assert_length(queries, 1)
|
||||||
sql = queries[0]["sql"]
|
sql = queries[0].sql
|
||||||
self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql)
|
self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql)
|
||||||
self.assertIn("ORDER BY message_id ASC", sql)
|
self.assertIn("ORDER BY message_id ASC", sql)
|
||||||
|
|
||||||
@@ -3516,9 +3516,9 @@ class GetOldMessagesTest(ZulipTestCase):
|
|||||||
with queries_captured() as all_queries:
|
with queries_captured() as all_queries:
|
||||||
get_messages_backend(request, user_profile)
|
get_messages_backend(request, user_profile)
|
||||||
|
|
||||||
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
|
queries = [q for q in all_queries if "/* get_messages */" in q.sql]
|
||||||
self.assert_length(queries, 1)
|
self.assert_length(queries, 1)
|
||||||
sql = queries[0]["sql"]
|
sql = queries[0].sql
|
||||||
self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql)
|
self.assertNotIn(f"AND message_id = {LARGER_THAN_MAX_MESSAGE_ID}", sql)
|
||||||
self.assertIn("ORDER BY message_id ASC", sql)
|
self.assertIn("ORDER BY message_id ASC", sql)
|
||||||
cond = f"WHERE user_profile_id = {user_profile.id} AND message_id <= {first_unread_message_id - 1}"
|
cond = f"WHERE user_profile_id = {user_profile.id} AND message_id <= {first_unread_message_id - 1}"
|
||||||
@@ -3540,10 +3540,10 @@ class GetOldMessagesTest(ZulipTestCase):
|
|||||||
with queries_captured() as all_queries:
|
with queries_captured() as all_queries:
|
||||||
get_messages_backend(request, user_profile)
|
get_messages_backend(request, user_profile)
|
||||||
|
|
||||||
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
|
queries = [q for q in all_queries if "/* get_messages */" in q.sql]
|
||||||
self.assert_length(queries, 1)
|
self.assert_length(queries, 1)
|
||||||
|
|
||||||
sql = queries[0]["sql"]
|
sql = queries[0].sql
|
||||||
|
|
||||||
self.assertNotIn("AND message_id <=", sql)
|
self.assertNotIn("AND message_id <=", sql)
|
||||||
self.assertNotIn("AND message_id >=", sql)
|
self.assertNotIn("AND message_id >=", sql)
|
||||||
@@ -3553,8 +3553,8 @@ class GetOldMessagesTest(ZulipTestCase):
|
|||||||
with first_visible_id_as(first_visible_message_id):
|
with first_visible_id_as(first_visible_message_id):
|
||||||
with queries_captured() as all_queries:
|
with queries_captured() as all_queries:
|
||||||
get_messages_backend(request, user_profile)
|
get_messages_backend(request, user_profile)
|
||||||
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
|
queries = [q for q in all_queries if "/* get_messages */" in q.sql]
|
||||||
sql = queries[0]["sql"]
|
sql = queries[0].sql
|
||||||
self.assertNotIn("AND message_id <=", sql)
|
self.assertNotIn("AND message_id <=", sql)
|
||||||
self.assertNotIn("AND message_id >=", sql)
|
self.assertNotIn("AND message_id >=", sql)
|
||||||
|
|
||||||
@@ -3595,20 +3595,20 @@ class GetOldMessagesTest(ZulipTestCase):
|
|||||||
|
|
||||||
# Do some tests on the main query, to verify the muting logic
|
# Do some tests on the main query, to verify the muting logic
|
||||||
# runs on this code path.
|
# runs on this code path.
|
||||||
queries = [q for q in all_queries if str(q["sql"]).startswith("SELECT message_id, flags")]
|
queries = [q for q in all_queries if q.sql.startswith("SELECT message_id, flags")]
|
||||||
self.assert_length(queries, 1)
|
self.assert_length(queries, 1)
|
||||||
|
|
||||||
stream = get_stream("Scotland", realm)
|
stream = get_stream("Scotland", realm)
|
||||||
assert stream.recipient is not None
|
assert stream.recipient is not None
|
||||||
recipient_id = stream.recipient.id
|
recipient_id = stream.recipient.id
|
||||||
cond = f"AND NOT (recipient_id = {recipient_id} AND upper(subject) = upper('golf'))"
|
cond = f"AND NOT (recipient_id = {recipient_id} AND upper(subject) = upper('golf'))"
|
||||||
self.assertIn(cond, queries[0]["sql"])
|
self.assertIn(cond, queries[0].sql)
|
||||||
|
|
||||||
# Next, verify the use_first_unread_anchor setting invokes
|
# Next, verify the use_first_unread_anchor setting invokes
|
||||||
# the `message_id = LARGER_THAN_MAX_MESSAGE_ID` hack.
|
# the `message_id = LARGER_THAN_MAX_MESSAGE_ID` hack.
|
||||||
queries = [q for q in all_queries if "/* get_messages */" in str(q["sql"])]
|
queries = [q for q in all_queries if "/* get_messages */" in q.sql]
|
||||||
self.assert_length(queries, 1)
|
self.assert_length(queries, 1)
|
||||||
self.assertIn(f"AND zerver_message.id = {LARGER_THAN_MAX_MESSAGE_ID}", queries[0]["sql"])
|
self.assertIn(f"AND zerver_message.id = {LARGER_THAN_MAX_MESSAGE_ID}", queries[0].sql)
|
||||||
|
|
||||||
def test_exclude_muting_conditions(self) -> None:
|
def test_exclude_muting_conditions(self) -> None:
|
||||||
realm = get_realm("zulip")
|
realm = get_realm("zulip")
|
||||||
|
|||||||
Reference in New Issue
Block a user