zerver: Migrate some files to typed_endpoint.

Migrates `invite.py`, `registration.py` and
`email_mirror.py` to use `typed_endpoint`.
This commit is contained in:
Kenneth Rodrigues
2024-07-14 23:09:20 +05:30
committed by Tim Abbott
parent 16abd82fa5
commit 6815cded83
7 changed files with 131 additions and 99 deletions

View File

@@ -1,11 +1,14 @@
import zoneinfo
from collections.abc import Collection from collections.abc import Collection
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.validators import URLValidator from django.core.validators import URLValidator
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from pydantic import AfterValidator from pydantic import AfterValidator, BeforeValidator, NonNegativeInt
from pydantic_core import PydanticCustomError from pydantic_core import PydanticCustomError
from zerver.lib.timezone import canonicalize_timezone
# The Pydantic.StringConstraints does not have validation for the string to be # The Pydantic.StringConstraints does not have validation for the string to be
# of the specified length. So, we need to create a custom validator for that. # of the specified length. So, we need to create a custom validator for that.
@@ -49,3 +52,34 @@ def check_url(val: str) -> str:
return val return val
except ValidationError: except ValidationError:
raise ValueError(_("Not a URL")) raise ValueError(_("Not a URL"))
def to_timezone_or_empty(s: str) -> str:
try:
s = canonicalize_timezone(s)
zoneinfo.ZoneInfo(s)
except (ValueError, zoneinfo.ZoneInfoNotFoundError):
return ""
else:
return s
def timezone_or_empty_validator() -> AfterValidator:
return AfterValidator(lambda s: to_timezone_or_empty(s))
def to_non_negative_int_or_none(s: str) -> NonNegativeInt | None:
try:
i = int(s)
if i < 0:
return None
return i
except ValueError:
return None
# We use BeforeValidator, not AfterValidator, here, because the int
# type conversion will raise a ValueError if the string is not a valid
# integer, and we want to return None in that case.
def non_negative_int_or_none_validator() -> BeforeValidator:
return BeforeValidator(lambda s: to_non_negative_int_or_none(s))

View File

@@ -30,7 +30,7 @@ for any particular type of object.
import re import re
import zoneinfo import zoneinfo
from collections.abc import Callable, Collection, Container, Iterator from collections.abc import Collection, Container, Iterator
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, NoReturn, TypeVar, cast, overload from typing import Any, NoReturn, TypeVar, cast, overload
@@ -589,28 +589,6 @@ def to_float(var_name: str, s: str) -> float:
return float(s) return float(s)
def to_timezone_or_empty(var_name: str, s: str) -> str:
try:
s = canonicalize_timezone(s)
zoneinfo.ZoneInfo(s)
except (ValueError, zoneinfo.ZoneInfoNotFoundError):
return ""
else:
return s
def to_converted_or_fallback(
sub_converter: Callable[[str, str], ResultT], default: ResultT
) -> Callable[[str, str], ResultT]:
def converter(var_name: str, s: str) -> ResultT:
try:
return sub_converter(var_name, s)
except ValueError:
return default
return converter
def check_string_or_int_list(var_name: str, val: object) -> str | list[int]: def check_string_or_int_list(var_name: str, val: object) -> str | list[int]:
if isinstance(val, str): if isinstance(val, str):
return val return val

View File

@@ -731,7 +731,9 @@ class InviteUserTest(InviteUserBase):
self.login("iago") self.login("iago")
invitee = self.nonreg_email("alice") invitee = self.nonreg_email("alice")
response = self.invite(invitee, ["Denmark"], invite_as=10) response = self.invite(invitee, ["Denmark"], invite_as=10)
self.assert_json_error(response, "Invalid invite_as") self.assert_json_error(
response, "Invalid invite_as: Value error, Not in the list of possible values"
)
def test_successful_invite_user_as_guest_from_normal_account(self) -> None: def test_successful_invite_user_as_guest_from_normal_account(self) -> None:
self.login("hamlet") self.login("hamlet")
@@ -2953,4 +2955,6 @@ class MultiuseInviteTest(ZulipTestCase):
"invite_expires_in_minutes": 2 * 24 * 60, "invite_expires_in_minutes": 2 * 24 * 60,
}, },
) )
self.assert_json_error(result, "Invalid invite_as") self.assert_json_error(
result, "Invalid invite_as: Value error, Not in the list of possible values"
)

View File

@@ -1,5 +1,10 @@
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.typed_endpoint_validators import check_int_in, check_string_in, check_url from zerver.lib.typed_endpoint_validators import (
check_int_in,
check_string_in,
check_url,
to_non_negative_int_or_none,
)
class ValidatorTestCase(ZulipTestCase): class ValidatorTestCase(ZulipTestCase):
@@ -17,3 +22,14 @@ class ValidatorTestCase(ZulipTestCase):
check_url("https://example.com") check_url("https://example.com")
with self.assertRaisesRegex(ValueError, "Not a URL"): with self.assertRaisesRegex(ValueError, "Not a URL"):
check_url("https://127.0.0..:5000") check_url("https://127.0.0..:5000")
def test_to_non_negative_int_or_none(self) -> None:
self.assertEqual(to_non_negative_int_or_none("3"), 3)
self.assertEqual(to_non_negative_int_or_none("-3"), None)
self.assertEqual(to_non_negative_int_or_none("a"), None)
self.assertEqual(to_non_negative_int_or_none("3.5"), None)
self.assertEqual(to_non_negative_int_or_none("3.0"), None)
self.assertEqual(to_non_negative_int_or_none("3.1"), None)
self.assertEqual(to_non_negative_int_or_none("3.9"), None)
self.assertEqual(to_non_negative_int_or_none("3.5"), None)
self.assertEqual(to_non_negative_int_or_none("foo"), None)

View File

@@ -3,16 +3,17 @@ from django.http import HttpRequest, HttpResponse
from zerver.decorator import internal_api_view from zerver.decorator import internal_api_view
from zerver.lib.email_mirror import mirror_email_message from zerver.lib.email_mirror import mirror_email_message
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.typed_endpoint import typed_endpoint
@internal_api_view(False) @internal_api_view(False)
@has_request_variables @typed_endpoint
def email_mirror_message( def email_mirror_message(
request: HttpRequest, request: HttpRequest,
rcpt_to: str = REQ(), *,
msg_base64: str = REQ(), rcpt_to: str,
msg_base64: str,
) -> HttpResponse: ) -> HttpResponse:
result = mirror_email_message(rcpt_to, msg_base64) result = mirror_email_message(rcpt_to, msg_base64)
if result["status"] == "error": if result["status"] == "error":

View File

@@ -1,10 +1,11 @@
import re import re
from collections.abc import Sequence from typing import Annotated
from django.conf import settings from django.conf import settings
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from pydantic import Json
from confirmation import settings as confirmation_settings from confirmation import settings as confirmation_settings
from zerver.actions.invites import ( from zerver.actions.invites import (
@@ -17,10 +18,10 @@ from zerver.actions.invites import (
) )
from zerver.decorator import require_member_or_admin from zerver.decorator import require_member_or_admin
from zerver.lib.exceptions import InvitationError, JsonableError, OrganizationOwnerRequiredError from zerver.lib.exceptions import InvitationError, JsonableError, OrganizationOwnerRequiredError
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.streams import access_stream_by_id from zerver.lib.streams import access_stream_by_id
from zerver.lib.validator import check_bool, check_int, check_int_in, check_list, check_none_or from zerver.lib.typed_endpoint import ApiParamConfig, PathOnly, typed_endpoint
from zerver.lib.typed_endpoint_validators import check_int_in_validator
from zerver.models import MultiuseInvite, PreregistrationUser, Stream, UserProfile from zerver.models import MultiuseInvite, PreregistrationUser, Stream, UserProfile
# Convert INVITATION_LINK_VALIDITY_DAYS into minutes. # Convert INVITATION_LINK_VALIDITY_DAYS into minutes.
@@ -44,25 +45,20 @@ def check_role_based_permissions(
@require_member_or_admin @require_member_or_admin
@has_request_variables @typed_endpoint
def invite_users_backend( def invite_users_backend(
request: HttpRequest, request: HttpRequest,
user_profile: UserProfile, user_profile: UserProfile,
invitee_emails_raw: str = REQ("invitee_emails"), *,
invite_expires_in_minutes: int | None = REQ( invitee_emails_raw: Annotated[str, ApiParamConfig("invitee_emails")],
json_validator=check_none_or(check_int), default=INVITATION_LINK_VALIDITY_MINUTES invite_expires_in_minutes: Json[int | None] = INVITATION_LINK_VALIDITY_MINUTES,
), invite_as: Annotated[
invite_as: int = REQ( Json[int],
json_validator=check_int_in( check_int_in_validator(list(PreregistrationUser.INVITE_AS.values())),
list(PreregistrationUser.INVITE_AS.values()), ] = PreregistrationUser.INVITE_AS["MEMBER"],
), notify_referrer_on_join: Json[bool] = True,
default=PreregistrationUser.INVITE_AS["MEMBER"], stream_ids: Json[list[int]],
), include_realm_default_subscriptions: Json[bool] = False,
notify_referrer_on_join: bool = REQ(
"notify_referrer_on_join", json_validator=check_bool, default=True
),
stream_ids: list[int] = REQ(json_validator=check_list(check_int)),
include_realm_default_subscriptions: bool = REQ(json_validator=check_bool, default=False),
) -> HttpResponse: ) -> HttpResponse:
if not user_profile.can_invite_users_by_email(): if not user_profile.can_invite_users_by_email():
# Guest users case will not be handled here as it will # Guest users case will not be handled here as it will
@@ -140,9 +136,9 @@ def get_user_invites(request: HttpRequest, user_profile: UserProfile) -> HttpRes
@require_member_or_admin @require_member_or_admin
@has_request_variables @typed_endpoint
def revoke_user_invite( def revoke_user_invite(
request: HttpRequest, user_profile: UserProfile, invite_id: int request: HttpRequest, user_profile: UserProfile, *, invite_id: PathOnly[int]
) -> HttpResponse: ) -> HttpResponse:
try: try:
prereg_user = PreregistrationUser.objects.get(id=invite_id) prereg_user = PreregistrationUser.objects.get(id=invite_id)
@@ -160,9 +156,9 @@ def revoke_user_invite(
@require_member_or_admin @require_member_or_admin
@has_request_variables @typed_endpoint
def revoke_multiuse_invite( def revoke_multiuse_invite(
request: HttpRequest, user_profile: UserProfile, invite_id: int request: HttpRequest, user_profile: UserProfile, *, invite_id: PathOnly[int]
) -> HttpResponse: ) -> HttpResponse:
try: try:
invite = MultiuseInvite.objects.get(id=invite_id) invite = MultiuseInvite.objects.get(id=invite_id)
@@ -183,9 +179,9 @@ def revoke_multiuse_invite(
@require_member_or_admin @require_member_or_admin
@has_request_variables @typed_endpoint
def resend_user_invite_email( def resend_user_invite_email(
request: HttpRequest, user_profile: UserProfile, invite_id: int request: HttpRequest, user_profile: UserProfile, *, invite_id: PathOnly[int]
) -> HttpResponse: ) -> HttpResponse:
try: try:
prereg_user = PreregistrationUser.objects.get(id=invite_id) prereg_user = PreregistrationUser.objects.get(id=invite_id)
@@ -205,22 +201,21 @@ def resend_user_invite_email(
@require_member_or_admin @require_member_or_admin
@has_request_variables @typed_endpoint
def generate_multiuse_invite_backend( def generate_multiuse_invite_backend(
request: HttpRequest, request: HttpRequest,
user_profile: UserProfile, user_profile: UserProfile,
invite_expires_in_minutes: int | None = REQ( *,
json_validator=check_none_or(check_int), default=INVITATION_LINK_VALIDITY_MINUTES invite_expires_in_minutes: Json[int | None] = INVITATION_LINK_VALIDITY_MINUTES,
), invite_as: Annotated[
invite_as: int = REQ( Json[int],
json_validator=check_int_in( check_int_in_validator(list(PreregistrationUser.INVITE_AS.values())),
list(PreregistrationUser.INVITE_AS.values()), ] = PreregistrationUser.INVITE_AS["MEMBER"],
), stream_ids: Json[list[int]] | None = None,
default=PreregistrationUser.INVITE_AS["MEMBER"], include_realm_default_subscriptions: Json[bool] = False,
),
stream_ids: Sequence[int] = REQ(json_validator=check_list(check_int), default=[]),
include_realm_default_subscriptions: bool = REQ(json_validator=check_bool, default=False),
) -> HttpResponse: ) -> HttpResponse:
if stream_ids is None:
stream_ids = []
if not user_profile.can_create_multiuse_invite_to_realm(): if not user_profile.can_create_multiuse_invite_to_realm():
# Guest users case will not be handled here as it will # Guest users case will not be handled here as it will
# be handled by the decorator above. # be handled by the decorator above.

View File

@@ -1,7 +1,7 @@
import logging import logging
from collections.abc import Iterable from collections.abc import Iterable
from contextlib import suppress from contextlib import suppress
from typing import Any from typing import Annotated, Any
from urllib.parse import urlencode, urljoin from urllib.parse import urlencode, urljoin
import orjson import orjson
@@ -19,6 +19,7 @@ from django.urls import reverse
from django.utils.translation import get_language from django.utils.translation import get_language
from django.views.defaults import server_error from django.views.defaults import server_error
from django_auth_ldap.backend import LDAPBackend, _LDAPUser from django_auth_ldap.backend import LDAPBackend, _LDAPUser
from pydantic import Json, NonNegativeInt, StringConstraints
from confirmation.models import ( from confirmation.models import (
Confirmation, Confirmation,
@@ -59,19 +60,22 @@ from zerver.lib.i18n import (
) )
from zerver.lib.pysa import mark_sanitized from zerver.lib.pysa import mark_sanitized
from zerver.lib.rate_limiter import rate_limit_request_by_ip from zerver.lib.rate_limiter import rate_limit_request_by_ip
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.send_email import EmailNotDeliveredError, FromAddress, send_email from zerver.lib.send_email import EmailNotDeliveredError, FromAddress, send_email
from zerver.lib.sessions import get_expirable_session_var from zerver.lib.sessions import get_expirable_session_var
from zerver.lib.subdomains import get_subdomain from zerver.lib.subdomains import get_subdomain
from zerver.lib.typed_endpoint import (
ApiParamConfig,
PathOnly,
typed_endpoint,
typed_endpoint_without_parameters,
)
from zerver.lib.typed_endpoint_validators import (
check_int_in_validator,
non_negative_int_or_none_validator,
timezone_or_empty_validator,
)
from zerver.lib.url_encoding import append_url_query_string from zerver.lib.url_encoding import append_url_query_string
from zerver.lib.users import get_accounts_for_email from zerver.lib.users import get_accounts_for_email
from zerver.lib.validator import (
check_capped_string,
check_int_in,
to_converted_or_fallback,
to_non_negative_int,
to_timezone_or_empty,
)
from zerver.lib.zephyr import compute_mit_user_fullname from zerver.lib.zephyr import compute_mit_user_fullname
from zerver.models import ( from zerver.models import (
MultiuseInvite, MultiuseInvite,
@@ -119,9 +123,9 @@ if settings.BILLING_ENABLED:
from corporate.lib.stripe import LicenseLimitError from corporate.lib.stripe import LicenseLimitError
@has_request_variables @typed_endpoint
def get_prereg_key_and_redirect( def get_prereg_key_and_redirect(
request: HttpRequest, confirmation_key: str, full_name: str | None = REQ(default=None) request: HttpRequest, *, confirmation_key: PathOnly[str], full_name: str | None = None
) -> HttpResponse: ) -> HttpResponse:
""" """
The purpose of this little endpoint is primarily to take a GET The purpose of this little endpoint is primarily to take a GET
@@ -223,17 +227,16 @@ def accounts_register(*args: Any, **kwargs: Any) -> HttpResponse:
return registration_helper(*args, **kwargs) return registration_helper(*args, **kwargs)
@has_request_variables @typed_endpoint
def registration_helper( def registration_helper(
request: HttpRequest, request: HttpRequest,
key: str = REQ(default=""), *,
timezone: str = REQ(default="", converter=to_timezone_or_empty), key: str = "",
from_confirmation: str | None = REQ(default=None), timezone: Annotated[str, timezone_or_empty_validator()] = "",
form_full_name: str | None = REQ("full_name", default=None), from_confirmation: str | None = None,
source_realm_id: int | None = REQ( form_full_name: Annotated[str | None, ApiParamConfig("full_name")] = None,
default=None, converter=to_converted_or_fallback(to_non_negative_int, None) source_realm_id: Annotated[NonNegativeInt | None, non_negative_int_or_none_validator()] = None,
), form_is_demo_organization: Annotated[str | None, ApiParamConfig("is_demo_organization")] = None,
form_is_demo_organization: str | None = REQ("is_demo_organization", default=None),
) -> HttpResponse: ) -> HttpResponse:
try: try:
prereg_object, realm_creation = check_prereg_key(request, key) prereg_object, realm_creation = check_prereg_key(request, key)
@@ -958,8 +961,8 @@ def create_realm(request: HttpRequest, creation_key: str | None = None) -> HttpR
) )
@has_request_variables @typed_endpoint
def signup_send_confirm(request: HttpRequest, email: str = REQ("email")) -> HttpResponse: def signup_send_confirm(request: HttpRequest, *, email: str) -> HttpResponse:
try: try:
# Because we interpolate the email directly into the template # Because we interpolate the email directly into the template
# from the query parameter, do a simple validation that it # from the query parameter, do a simple validation that it
@@ -980,14 +983,15 @@ def signup_send_confirm(request: HttpRequest, email: str = REQ("email")) -> Http
@add_google_analytics @add_google_analytics
@has_request_variables @typed_endpoint
def new_realm_send_confirm( def new_realm_send_confirm(
request: HttpRequest, request: HttpRequest,
email: str = REQ("email"), *,
realm_name: str = REQ(str_validator=check_capped_string(Realm.MAX_REALM_NAME_LENGTH)), email: str,
realm_type: int = REQ(json_validator=check_int_in(Realm.ORG_TYPE_IDS)), realm_name: Annotated[str, StringConstraints(max_length=Realm.MAX_REALM_NAME_LENGTH)],
realm_default_language: str = REQ(str_validator=check_capped_string(MAX_LANGUAGE_ID_LENGTH)), realm_type: Annotated[Json[int], check_int_in_validator(Realm.ORG_TYPE_IDS)],
realm_subdomain: str = REQ(str_validator=check_capped_string(Realm.MAX_REALM_SUBDOMAIN_LENGTH)), realm_default_language: Annotated[str, StringConstraints(max_length=MAX_LANGUAGE_ID_LENGTH)],
realm_subdomain: Annotated[str, StringConstraints(max_length=Realm.MAX_REALM_SUBDOMAIN_LENGTH)],
) -> HttpResponse: ) -> HttpResponse:
return TemplateResponse( return TemplateResponse(
request, request,
@@ -1113,7 +1117,7 @@ def accounts_home_from_multiuse_invite(request: HttpRequest, confirmation_key: s
) )
@has_request_variables @typed_endpoint_without_parameters
def find_account(request: HttpRequest) -> HttpResponse: def find_account(request: HttpRequest) -> HttpResponse:
url = reverse("find_account") url = reverse("find_account")
form = FindMyTeamForm() form = FindMyTeamForm()
@@ -1217,8 +1221,8 @@ def find_account(request: HttpRequest) -> HttpResponse:
) )
@has_request_variables @typed_endpoint
def realm_redirect(request: HttpRequest, next: str = REQ(default="")) -> HttpResponse: def realm_redirect(request: HttpRequest, *, next: str = "") -> HttpResponse:
if request.method == "POST": if request.method == "POST":
form = RealmRedirectForm(request.POST) form = RealmRedirectForm(request.POST)
if form.is_valid(): if form.is_valid():