validator: Disable WildValue equality comparison.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg
2025-06-17 14:46:00 -07:00
committed by Tim Abbott
parent 2747127e6c
commit 10705e0db3
8 changed files with 28 additions and 19 deletions

View File

@@ -569,7 +569,7 @@ def check_string_or_int(var_name: str, val: object) -> str | int:
raise ValidationError(_("{var_name} is not a string or integer").format(var_name=var_name))
@dataclass
@dataclass(eq=False)
class WildValue:
var_name: str
value: object
@@ -590,7 +590,7 @@ class WildValue:
@override
def __eq__(self, other: object) -> bool:
return self.value == other
raise TypeError("cannot compare WildValue")
def __len__(self) -> int:
if not isinstance(self.value, dict | list | str):

View File

@@ -7,6 +7,7 @@ from zerver.lib.exceptions import InvalidJSONError
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.types import Validator
from zerver.lib.validator import (
check_anything,
check_bool,
check_capped_string,
check_dict,
@@ -323,19 +324,24 @@ class ValidatorTestCase(ZulipTestCase):
def test_wild_value(self) -> None:
x = to_wild_value("x", '{"a": 1, "b": ["c", false, null]}')
self.assertEqual(x, x)
with self.assertRaisesRegex(TypeError, r"^cannot compare WildValue$"):
self.assertEqual(x, x)
self.assertTrue(x)
self.assertEqual(len(x), 2)
self.assertEqual(list(x.keys()), ["a", "b"])
self.assertEqual(list(x.values()), [1, ["c", False, None]])
self.assertEqual(list(x.items()), [("a", 1), ("b", ["c", False, None])])
self.assertEqual([v.tame(check_anything) for v in x.values()], [1, ["c", False, None]])
self.assertEqual(
[(k, v.tame(check_anything)) for k, v in x.items()],
[("a", 1), ("b", ["c", False, None])],
)
self.assertTrue("a" in x)
self.assertEqual(x["a"], 1)
self.assertEqual(x.get("a"), 1)
self.assertEqual(x.get("z"), None)
self.assertEqual(x["a"].tame(check_int), 1)
self.assertEqual(x.get("a").tame(check_int), 1)
self.assertEqual(x.get("z").tame(check_none_or(check_int)), None)
self.assertEqual(x.get("z", x["a"]).tame(check_int), 1)
self.assertEqual(x["a"].tame(check_int), 1)
self.assertEqual(x["b"], x["b"])
self.assertEqual(x["b"].tame(check_anything), x["b"].tame(check_anything))
self.assertTrue(x["b"])
self.assertEqual(len(x["b"]), 3)
self.assert_length(list(x["b"]), 3)

View File

@@ -62,7 +62,7 @@ def get_body(payload: WildValue) -> str:
}
)
if payload["status"] == "successful":
if payload["status"].tame(check_string) == "successful":
status = "was successful"
else:
status = "failed"

View File

@@ -78,7 +78,10 @@ def api_circleci_webhook(
# We currently don't support projects using VCS providers other than GitHub,
# BitBucket and GitLab.
pipeline = payload["pipeline"]
if "trigger_parameters" in pipeline and pipeline["trigger"]["type"] != "gitlab":
if (
"trigger_parameters" in pipeline
and pipeline["trigger"]["type"].tame(check_string) != "gitlab"
):
raise JsonableError(
_("Projects using this version control system provider aren't supported")
) # nocoverage
@@ -102,7 +105,7 @@ def get_commit_details(payload: WildValue) -> str:
revision = payload["pipeline"]["vcs"]["revision"].tame(check_string)
commit_id = get_short_sha(revision)
if payload["pipeline"]["vcs"]["provider_name"] == "github":
if payload["pipeline"]["vcs"]["provider_name"].tame(check_string) == "github":
commit_link = GITHUB_COMMIT_LINK.format(
target_repository_url=payload["pipeline"]["vcs"]["target_repository_url"].tame(
check_url

View File

@@ -77,7 +77,7 @@ STORY_UPDATE_BATCH_ADD_REMOVE_TEMPLATE = "{operation} with {entity}"
def get_action_with_primary_id(payload: WildValue) -> WildValue:
for action in payload["actions"]:
if payload["primary_id"] == action["id"]:
if payload["primary_id"].tame(check_int) == action["id"].tame(check_int):
action_with_primary_id = action
return action_with_primary_id
@@ -176,9 +176,9 @@ def get_comment_added_body(entity: str, payload: WildValue, ignored_action: Wild
actions = payload["actions"]
kwargs = {"entity": entity}
for action in actions:
if action["id"] == payload["primary_id"]:
if action["id"].tame(check_int) == payload["primary_id"].tame(check_int):
kwargs["text"] = action["text"].tame(check_string)
elif action["entity_type"] == entity:
elif action["entity_type"].tame(check_string) == entity:
name_template = get_name_template(entity).format(
name=action["name"].tame(check_string),
app_url=action.get("app_url").tame(check_none_or(check_string)),
@@ -663,7 +663,7 @@ def get_story_update_batch_body(payload: WildValue, action: WildValue) -> str |
def get_entity_name(entity: str, payload: WildValue, action: WildValue) -> str | None:
name = action["name"].tame(check_string) if "name" in action else None
if name is None or action["entity_type"] == "branch":
if name is None or action["entity_type"].tame(check_string) == "branch":
for other_action in payload["actions"]:
if other_action["entity_type"].tame(check_string) == entity:
name = other_action["name"].tame(check_string)

View File

@@ -25,7 +25,7 @@ def api_crashlytics_webhook(
*,
payload: JsonBodyPayload[WildValue],
) -> HttpResponse:
event = payload["event"]
event = payload["event"].tame(check_string)
if event == VERIFICATION_EVENT:
topic_name = CRASHLYTICS_SETUP_TOPIC_TEMPLATE
body = CRASHLYTICS_SETUP_MESSAGE_TEMPLATE

View File

@@ -225,7 +225,7 @@ def handle_updated_issue_event(payload: WildValue, user_profile: UserProfile) ->
else:
verb = "deleted a comment from"
if payload.get("webhookEvent") == "comment_created":
if payload.get("webhookEvent").tame(check_none_or(check_string)) == "comment_created":
author = payload["comment"]["author"]["displayName"].tame(check_string)
else:
author = get_issue_author(payload)

View File

@@ -271,7 +271,7 @@ def api_slack_webhook(
# for how to add support for this type of payload.
raise UnsupportedWebhookEventTypeError(
"integration bot message"
if event_dict["subtype"] == "bot_message"
if event_dict["subtype"].tame(check_string) == "bot_message"
else "unknown Slack event"
)
sender = get_slack_sender_name(user_id, slack_app_token)