decorator: Extract public_json_view.

This refactoring is necessary to separate the expected type annotation
for view functions with different authentication methods. Currently the
signature aren't actually check against view functions because
`rest_path` does not support type checking parameter types, but it will
become useful once we do.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li
2022-08-01 14:46:23 -04:00
committed by Tim Abbott
parent 299f3442ff
commit f54ecad6cd
3 changed files with 51 additions and 12 deletions

View File

@@ -836,13 +836,11 @@ def process_as_post(
return _wrapped_view_func
# Checks if the user is logged in. If not, return an error (the
# @login_required behavior of redirecting to a login page doesn't make
# sense for json views)
def authenticated_json_view(
view_func: Callable[Concatenate[HttpRequest, UserProfile, ParamT], HttpResponse],
def public_json_view(
view_func: Callable[
Concatenate[HttpRequest, Union[UserProfile, AnonymousUser], ParamT], HttpResponse
],
skip_rate_limiting: bool = False,
allow_unauthenticated: bool = False,
) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]:
@wraps(view_func)
def _wrapped_view_func(
@@ -855,9 +853,6 @@ def authenticated_json_view(
rate_limit(request)
if not request.user.is_authenticated:
if not allow_unauthenticated:
raise UnauthorizedError()
process_client(
request,
is_browser_view=True,
@@ -865,6 +860,33 @@ def authenticated_json_view(
)
return view_func(request, request.user, *args, **kwargs)
# Fall back to authenticated_json_view if the user is authenticated.
# Since we have done rate limiting earlier is no need to do it again.
return authenticated_json_view(view_func, skip_rate_limiting=True)(request, *args, **kwargs)
return _wrapped_view_func
# Checks if the user is logged in. If not, return an error (the
# @login_required behavior of redirecting to a login page doesn't make
# sense for json views)
def authenticated_json_view(
view_func: Callable[Concatenate[HttpRequest, UserProfile, ParamT], HttpResponse],
skip_rate_limiting: bool = False,
) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]:
@wraps(view_func)
def _wrapped_view_func(
request: HttpRequest,
/,
*args: ParamT.args,
**kwargs: ParamT.kwargs,
) -> HttpResponse:
if not skip_rate_limiting:
rate_limit(request)
if not request.user.is_authenticated:
raise UnauthorizedError()
user_profile = request.user
validate_account_and_subdomain(request, user_profile)

View File

@@ -12,6 +12,7 @@ from zerver.decorator import (
authenticated_rest_api_view,
authenticated_uploads_api_view,
process_as_post,
public_json_view,
)
from zerver.lib.exceptions import MissingAuthenticationError
from zerver.lib.request import RequestNotes
@@ -150,8 +151,7 @@ def rest_dispatch(request: HttpRequest, **kwargs: Any) -> HttpResponse:
):
# For endpoints that support anonymous web access, we do that.
# TODO: Allow /api calls when this is stable enough.
auth_kwargs = dict(allow_unauthenticated=True)
target_function = csrf_protect(authenticated_json_view(target_function, **auth_kwargs))
target_function = csrf_protect(public_json_view(target_function))
else:
# Otherwise, throw an authentication error; our middleware
# will generate the appropriate HTTP response.

View File

@@ -4,11 +4,12 @@ import re
import uuid
from collections import defaultdict
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from unittest import mock, skipUnless
import orjson
from django.conf import settings
from django.contrib.auth.models import AnonymousUser
from django.core.exceptions import ValidationError
from django.http import HttpRequest, HttpResponse
from django.utils.timezone import now as timezone_now
@@ -27,6 +28,7 @@ from zerver.decorator import (
authenticated_rest_api_view,
authenticated_uploads_api_view,
internal_notify_view,
public_json_view,
return_success_on_head_request,
validate_api_key,
webhook_view,
@@ -1837,6 +1839,21 @@ class TestAuthenticatedJsonViewDecorator(ZulipTestCase):
return self.client_post(r"/accounts/webathena_kerberos_login/", data)
class TestPublicJsonViewDecorator(ZulipTestCase):
def test_access_public_json_view_when_logged_in(self) -> None:
hamlet = self.example_user("hamlet")
@public_json_view
def public_view(
request: HttpRequest, maybe_user_profile: Union[UserProfile, AnonymousUser]
) -> HttpResponse:
self.assertEqual(maybe_user_profile, hamlet)
return json_success(request)
result = public_view(HostRequestMock(host="zulip.testserver", user_profile=hamlet))
self.assert_json_success(result)
class TestZulipLoginRequiredDecorator(ZulipTestCase):
def test_zulip_login_required_if_subdomain_is_invalid(self) -> None:
self.login("hamlet")