validator: Disable WildValue equality comparison.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg
2025-06-17 14:46:00 -07:00
committed by Tim Abbott
parent 2747127e6c
commit 10705e0db3
8 changed files with 28 additions and 19 deletions

View File

@@ -569,7 +569,7 @@ def check_string_or_int(var_name: str, val: object) -> str | int:
raise ValidationError(_("{var_name} is not a string or integer").format(var_name=var_name)) raise ValidationError(_("{var_name} is not a string or integer").format(var_name=var_name))
@dataclass @dataclass(eq=False)
class WildValue: class WildValue:
var_name: str var_name: str
value: object value: object
@@ -590,7 +590,7 @@ class WildValue:
@override @override
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
return self.value == other raise TypeError("cannot compare WildValue")
def __len__(self) -> int: def __len__(self) -> int:
if not isinstance(self.value, dict | list | str): if not isinstance(self.value, dict | list | str):

View File

@@ -7,6 +7,7 @@ from zerver.lib.exceptions import InvalidJSONError
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.types import Validator from zerver.lib.types import Validator
from zerver.lib.validator import ( from zerver.lib.validator import (
check_anything,
check_bool, check_bool,
check_capped_string, check_capped_string,
check_dict, check_dict,
@@ -323,19 +324,24 @@ class ValidatorTestCase(ZulipTestCase):
def test_wild_value(self) -> None: def test_wild_value(self) -> None:
x = to_wild_value("x", '{"a": 1, "b": ["c", false, null]}') x = to_wild_value("x", '{"a": 1, "b": ["c", false, null]}')
self.assertEqual(x, x) with self.assertRaisesRegex(TypeError, r"^cannot compare WildValue$"):
self.assertEqual(x, x)
self.assertTrue(x) self.assertTrue(x)
self.assertEqual(len(x), 2) self.assertEqual(len(x), 2)
self.assertEqual(list(x.keys()), ["a", "b"]) self.assertEqual(list(x.keys()), ["a", "b"])
self.assertEqual(list(x.values()), [1, ["c", False, None]]) self.assertEqual([v.tame(check_anything) for v in x.values()], [1, ["c", False, None]])
self.assertEqual(list(x.items()), [("a", 1), ("b", ["c", False, None])]) self.assertEqual(
[(k, v.tame(check_anything)) for k, v in x.items()],
[("a", 1), ("b", ["c", False, None])],
)
self.assertTrue("a" in x) self.assertTrue("a" in x)
self.assertEqual(x["a"], 1) self.assertEqual(x["a"].tame(check_int), 1)
self.assertEqual(x.get("a"), 1) self.assertEqual(x.get("a").tame(check_int), 1)
self.assertEqual(x.get("z"), None) self.assertEqual(x.get("z").tame(check_none_or(check_int)), None)
self.assertEqual(x.get("z", x["a"]).tame(check_int), 1) self.assertEqual(x.get("z", x["a"]).tame(check_int), 1)
self.assertEqual(x["a"].tame(check_int), 1) self.assertEqual(x["a"].tame(check_int), 1)
self.assertEqual(x["b"], x["b"]) self.assertEqual(x["b"].tame(check_anything), x["b"].tame(check_anything))
self.assertTrue(x["b"]) self.assertTrue(x["b"])
self.assertEqual(len(x["b"]), 3) self.assertEqual(len(x["b"]), 3)
self.assert_length(list(x["b"]), 3) self.assert_length(list(x["b"]), 3)

View File

@@ -62,7 +62,7 @@ def get_body(payload: WildValue) -> str:
} }
) )
if payload["status"] == "successful": if payload["status"].tame(check_string) == "successful":
status = "was successful" status = "was successful"
else: else:
status = "failed" status = "failed"

View File

@@ -78,7 +78,10 @@ def api_circleci_webhook(
# We currently don't support projects using VCS providers other than GitHub, # We currently don't support projects using VCS providers other than GitHub,
# BitBucket and GitLab. # BitBucket and GitLab.
pipeline = payload["pipeline"] pipeline = payload["pipeline"]
if "trigger_parameters" in pipeline and pipeline["trigger"]["type"] != "gitlab": if (
"trigger_parameters" in pipeline
and pipeline["trigger"]["type"].tame(check_string) != "gitlab"
):
raise JsonableError( raise JsonableError(
_("Projects using this version control system provider aren't supported") _("Projects using this version control system provider aren't supported")
) # nocoverage ) # nocoverage
@@ -102,7 +105,7 @@ def get_commit_details(payload: WildValue) -> str:
revision = payload["pipeline"]["vcs"]["revision"].tame(check_string) revision = payload["pipeline"]["vcs"]["revision"].tame(check_string)
commit_id = get_short_sha(revision) commit_id = get_short_sha(revision)
if payload["pipeline"]["vcs"]["provider_name"] == "github": if payload["pipeline"]["vcs"]["provider_name"].tame(check_string) == "github":
commit_link = GITHUB_COMMIT_LINK.format( commit_link = GITHUB_COMMIT_LINK.format(
target_repository_url=payload["pipeline"]["vcs"]["target_repository_url"].tame( target_repository_url=payload["pipeline"]["vcs"]["target_repository_url"].tame(
check_url check_url

View File

@@ -77,7 +77,7 @@ STORY_UPDATE_BATCH_ADD_REMOVE_TEMPLATE = "{operation} with {entity}"
def get_action_with_primary_id(payload: WildValue) -> WildValue: def get_action_with_primary_id(payload: WildValue) -> WildValue:
for action in payload["actions"]: for action in payload["actions"]:
if payload["primary_id"] == action["id"]: if payload["primary_id"].tame(check_int) == action["id"].tame(check_int):
action_with_primary_id = action action_with_primary_id = action
return action_with_primary_id return action_with_primary_id
@@ -176,9 +176,9 @@ def get_comment_added_body(entity: str, payload: WildValue, ignored_action: Wild
actions = payload["actions"] actions = payload["actions"]
kwargs = {"entity": entity} kwargs = {"entity": entity}
for action in actions: for action in actions:
if action["id"] == payload["primary_id"]: if action["id"].tame(check_int) == payload["primary_id"].tame(check_int):
kwargs["text"] = action["text"].tame(check_string) kwargs["text"] = action["text"].tame(check_string)
elif action["entity_type"] == entity: elif action["entity_type"].tame(check_string) == entity:
name_template = get_name_template(entity).format( name_template = get_name_template(entity).format(
name=action["name"].tame(check_string), name=action["name"].tame(check_string),
app_url=action.get("app_url").tame(check_none_or(check_string)), app_url=action.get("app_url").tame(check_none_or(check_string)),
@@ -663,7 +663,7 @@ def get_story_update_batch_body(payload: WildValue, action: WildValue) -> str |
def get_entity_name(entity: str, payload: WildValue, action: WildValue) -> str | None: def get_entity_name(entity: str, payload: WildValue, action: WildValue) -> str | None:
name = action["name"].tame(check_string) if "name" in action else None name = action["name"].tame(check_string) if "name" in action else None
if name is None or action["entity_type"] == "branch": if name is None or action["entity_type"].tame(check_string) == "branch":
for other_action in payload["actions"]: for other_action in payload["actions"]:
if other_action["entity_type"].tame(check_string) == entity: if other_action["entity_type"].tame(check_string) == entity:
name = other_action["name"].tame(check_string) name = other_action["name"].tame(check_string)

View File

@@ -25,7 +25,7 @@ def api_crashlytics_webhook(
*, *,
payload: JsonBodyPayload[WildValue], payload: JsonBodyPayload[WildValue],
) -> HttpResponse: ) -> HttpResponse:
event = payload["event"] event = payload["event"].tame(check_string)
if event == VERIFICATION_EVENT: if event == VERIFICATION_EVENT:
topic_name = CRASHLYTICS_SETUP_TOPIC_TEMPLATE topic_name = CRASHLYTICS_SETUP_TOPIC_TEMPLATE
body = CRASHLYTICS_SETUP_MESSAGE_TEMPLATE body = CRASHLYTICS_SETUP_MESSAGE_TEMPLATE

View File

@@ -225,7 +225,7 @@ def handle_updated_issue_event(payload: WildValue, user_profile: UserProfile) ->
else: else:
verb = "deleted a comment from" verb = "deleted a comment from"
if payload.get("webhookEvent") == "comment_created": if payload.get("webhookEvent").tame(check_none_or(check_string)) == "comment_created":
author = payload["comment"]["author"]["displayName"].tame(check_string) author = payload["comment"]["author"]["displayName"].tame(check_string)
else: else:
author = get_issue_author(payload) author = get_issue_author(payload)

View File

@@ -271,7 +271,7 @@ def api_slack_webhook(
# for how to add support for this type of payload. # for how to add support for this type of payload.
raise UnsupportedWebhookEventTypeError( raise UnsupportedWebhookEventTypeError(
"integration bot message" "integration bot message"
if event_dict["subtype"] == "bot_message" if event_dict["subtype"].tame(check_string) == "bot_message"
else "unknown Slack event" else "unknown Slack event"
) )
sender = get_slack_sender_name(user_id, slack_app_token) sender = get_slack_sender_name(user_id, slack_app_token)