typing: Fix misuse of HttpResponse.

Amend usage of HttpResponse when appropriate.
This commit is contained in:
PIG208
2021-07-26 22:29:19 +08:00
parent ad8d9f2133
commit 8121d2d58d
7 changed files with 38 additions and 41 deletions

View File

@@ -30,7 +30,7 @@ class TestSupportEndpoint(ZulipTestCase):
reset_emails_in_zulip_realm() reset_emails_in_zulip_realm()
def assert_user_details_in_html_response( def assert_user_details_in_html_response(
html_response: str, full_name: str, email: str, role: str html_response: HttpResponse, full_name: str, email: str, role: str
) -> None: ) -> None:
self.assert_in_success_response( self.assert_in_success_response(
[ [

View File

@@ -53,7 +53,7 @@ def render_stats(
for_installation: bool = False, for_installation: bool = False,
remote: bool = False, remote: bool = False,
analytics_ready: bool = True, analytics_ready: bool = True,
) -> HttpRequest: ) -> HttpResponse:
page_params = dict( page_params = dict(
data_url_suffix=data_url_suffix, data_url_suffix=data_url_suffix,
for_installation=for_installation, for_installation=for_installation,

View File

@@ -379,7 +379,7 @@ def webhook_view(
# the subdomain validation happen elsewhere and switch to using the # the subdomain validation happen elsewhere and switch to using the
# stock Django version. # stock Django version.
def user_passes_test( def user_passes_test(
test_func: Callable[[HttpResponse], bool], test_func: Callable[[HttpRequest], bool],
login_url: Optional[str] = None, login_url: Optional[str] = None,
redirect_field_name: str = REDIRECT_FIELD_NAME, redirect_field_name: str = REDIRECT_FIELD_NAME,
) -> Callable[[ViewFuncT], ViewFuncT]: ) -> Callable[[ViewFuncT], ViewFuncT]:

View File

@@ -2,24 +2,13 @@ import cProfile
import logging import logging
import time import time
import traceback import traceback
from typing import ( from typing import Any, AnyStr, Callable, Dict, Iterable, List, MutableMapping, Optional, Tuple
Any,
AnyStr,
Callable,
Dict,
Iterable,
List,
MutableMapping,
Optional,
Tuple,
Union,
)
from django.conf import settings from django.conf import settings
from django.conf.urls.i18n import is_language_prefix_patterns_used from django.conf.urls.i18n import is_language_prefix_patterns_used
from django.core.handlers.wsgi import WSGIRequest
from django.db import connection from django.db import connection
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect, StreamingHttpResponse from django.http import HttpRequest, HttpResponse, HttpResponseRedirect, StreamingHttpResponse
from django.http.response import HttpResponseBase
from django.middleware.common import CommonMiddleware from django.middleware.common import CommonMiddleware
from django.middleware.locale import LocaleMiddleware as DjangoLocaleMiddleware from django.middleware.locale import LocaleMiddleware as DjangoLocaleMiddleware
from django.shortcuts import render from django.shortcuts import render
@@ -428,9 +417,7 @@ class LogRequests(MiddlewareMixin):
class JsonErrorHandler(MiddlewareMixin): class JsonErrorHandler(MiddlewareMixin):
def __init__( def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
self, get_response: Callable[[Any, WSGIRequest], Union[HttpResponse, BaseException]]
) -> None:
super().__init__(get_response) super().__init__(get_response)
ignore_logger("zerver.middleware.json_error_handler") ignore_logger("zerver.middleware.json_error_handler")
@@ -498,7 +485,9 @@ def csrf_failure(request: HttpRequest, reason: str = "") -> HttpResponse:
class LocaleMiddleware(DjangoLocaleMiddleware): class LocaleMiddleware(DjangoLocaleMiddleware):
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse: def process_response(
self, request: HttpRequest, response: HttpResponseBase
) -> HttpResponseBase:
# This is the same as the default LocaleMiddleware, minus the # This is the same as the default LocaleMiddleware, minus the
# logic that redirects 404's that lack a prefixed language in # logic that redirects 404's that lack a prefixed language in

View File

@@ -262,17 +262,19 @@ class DecoratorTestCase(ZulipTestCase):
def test_webhook_view(self) -> None: def test_webhook_view(self) -> None:
@webhook_view("ClientName") @webhook_view("ClientName")
def my_webhook(request: HttpRequest, user_profile: UserProfile) -> str: def my_webhook(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
return user_profile.email return json_response(msg=user_profile.email)
@webhook_view("ClientName") @webhook_view("ClientName")
def my_webhook_raises_exception(request: HttpRequest, user_profile: UserProfile) -> None: def my_webhook_raises_exception(
request: HttpRequest, user_profile: UserProfile
) -> HttpResponse:
raise Exception("raised by webhook function") raise Exception("raised by webhook function")
@webhook_view("ClientName") @webhook_view("ClientName")
def my_webhook_raises_exception_unsupported_event( def my_webhook_raises_exception_unsupported_event(
request: HttpRequest, user_profile: UserProfile request: HttpRequest, user_profile: UserProfile
) -> None: ) -> HttpResponse:
raise UnsupportedWebhookEventType("test_event") raise UnsupportedWebhookEventType("test_event")
webhook_bot_email = "webhook-bot@zulip.com" webhook_bot_email = "webhook-bot@zulip.com"
@@ -367,7 +369,7 @@ class DecoratorTestCase(ZulipTestCase):
with self.settings(RATE_LIMITING=True): with self.settings(RATE_LIMITING=True):
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock: with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
api_result = my_webhook(request) api_result = orjson.loads(my_webhook(request).content).get("msg")
# Verify rate limiting was attempted. # Verify rate limiting was attempted.
self.assertTrue(rate_limit_mock.called) self.assertTrue(rate_limit_mock.called)
@@ -393,11 +395,11 @@ class DecoratorTestCase(ZulipTestCase):
class SkipRateLimitingTest(ZulipTestCase): class SkipRateLimitingTest(ZulipTestCase):
def test_authenticated_rest_api_view(self) -> None: def test_authenticated_rest_api_view(self) -> None:
@authenticated_rest_api_view(skip_rate_limiting=False) @authenticated_rest_api_view(skip_rate_limiting=False)
def my_rate_limited_view(request: HttpRequest, user_profile: UserProfile) -> str: def my_rate_limited_view(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
return json_success() # nocoverage # mock prevents this from being called return json_success() # nocoverage # mock prevents this from being called
@authenticated_rest_api_view(skip_rate_limiting=True) @authenticated_rest_api_view(skip_rate_limiting=True)
def my_unlimited_view(request: HttpRequest, user_profile: UserProfile) -> str: def my_unlimited_view(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
return json_success() return json_success()
request = HostRequestMock(host="zulip.testserver") request = HostRequestMock(host="zulip.testserver")
@@ -416,11 +418,11 @@ class SkipRateLimitingTest(ZulipTestCase):
def test_authenticated_uploads_api_view(self) -> None: def test_authenticated_uploads_api_view(self) -> None:
@authenticated_uploads_api_view(skip_rate_limiting=False) @authenticated_uploads_api_view(skip_rate_limiting=False)
def my_rate_limited_view(request: HttpRequest, user_profile: UserProfile) -> str: def my_rate_limited_view(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
return json_success() # nocoverage # mock prevents this from being called return json_success() # nocoverage # mock prevents this from being called
@authenticated_uploads_api_view(skip_rate_limiting=True) @authenticated_uploads_api_view(skip_rate_limiting=True)
def my_unlimited_view(request: HttpRequest, user_profile: UserProfile) -> str: def my_unlimited_view(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
return json_success() return json_success()
request = HostRequestMock(host="zulip.testserver") request = HostRequestMock(host="zulip.testserver")
@@ -438,7 +440,7 @@ class SkipRateLimitingTest(ZulipTestCase):
self.assertTrue(rate_limit_mock.called) self.assertTrue(rate_limit_mock.called)
def test_authenticated_json_view(self) -> None: def test_authenticated_json_view(self) -> None:
def my_view(request: HttpRequest, user_profile: UserProfile) -> str: def my_view(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
return json_success() return json_success()
my_rate_limited_view = authenticated_json_view(my_view, skip_rate_limiting=False) my_rate_limited_view = authenticated_json_view(my_view, skip_rate_limiting=False)
@@ -462,7 +464,9 @@ class SkipRateLimitingTest(ZulipTestCase):
class DecoratorLoggingTestCase(ZulipTestCase): class DecoratorLoggingTestCase(ZulipTestCase):
def test_authenticated_rest_api_view_logging(self) -> None: def test_authenticated_rest_api_view_logging(self) -> None:
@authenticated_rest_api_view(webhook_client_name="ClientName") @authenticated_rest_api_view(webhook_client_name="ClientName")
def my_webhook_raises_exception(request: HttpRequest, user_profile: UserProfile) -> None: def my_webhook_raises_exception(
request: HttpRequest, user_profile: UserProfile
) -> HttpResponse:
raise Exception("raised by webhook function") raise Exception("raised by webhook function")
webhook_bot_email = "webhook-bot@zulip.com" webhook_bot_email = "webhook-bot@zulip.com"
@@ -483,7 +487,9 @@ class DecoratorLoggingTestCase(ZulipTestCase):
def test_authenticated_rest_api_view_logging_unsupported_event(self) -> None: def test_authenticated_rest_api_view_logging_unsupported_event(self) -> None:
@authenticated_rest_api_view(webhook_client_name="ClientName") @authenticated_rest_api_view(webhook_client_name="ClientName")
def my_webhook_raises_exception(request: HttpRequest, user_profile: UserProfile) -> None: def my_webhook_raises_exception(
request: HttpRequest, user_profile: UserProfile
) -> HttpResponse:
raise UnsupportedWebhookEventType("test_event") raise UnsupportedWebhookEventType("test_event")
webhook_bot_email = "webhook-bot@zulip.com" webhook_bot_email = "webhook-bot@zulip.com"
@@ -511,7 +517,7 @@ class DecoratorLoggingTestCase(ZulipTestCase):
@authenticated_rest_api_view() @authenticated_rest_api_view()
def non_webhook_view_raises_exception( def non_webhook_view_raises_exception(
request: HttpRequest, user_profile: UserProfile request: HttpRequest, user_profile: UserProfile
) -> None: ) -> HttpResponse:
raise Exception("raised by a non-webhook view") raise Exception("raised by a non-webhook view")
request = HostRequestMock() request = HostRequestMock()

View File

@@ -646,7 +646,7 @@ def redirect_and_log_into_subdomain(result: ExternalAuthResult) -> HttpResponse:
return redirect(subdomain_login_uri) return redirect(subdomain_login_uri)
def redirect_to_misconfigured_ldap_notice(request: HttpResponse, error_type: int) -> HttpResponse: def redirect_to_misconfigured_ldap_notice(request: HttpRequest, error_type: int) -> HttpResponse:
if error_type == ZulipLDAPAuthBackend.REALM_IS_NONE_ERROR: if error_type == ZulipLDAPAuthBackend.REALM_IS_NONE_ERROR:
return config_error(request, "ldap") return config_error(request, "ldap")
else: else:
@@ -789,6 +789,7 @@ def login_page(
# https://github.com/django/django/blob/master/django/template/response.py#L19. # https://github.com/django/django/blob/master/django/template/response.py#L19.
update_login_page_context(request, template_response.context_data) update_login_page_context(request, template_response.context_data)
assert isinstance(template_response, HttpResponse)
return template_response return template_response
@@ -957,12 +958,13 @@ def logout_then_login(request: HttpRequest, **kwargs: Any) -> HttpResponse:
def password_reset(request: HttpRequest) -> HttpResponse: def password_reset(request: HttpRequest) -> HttpResponse:
view_func = DjangoPasswordResetView.as_view( response = DjangoPasswordResetView.as_view(
template_name="zerver/reset.html", template_name="zerver/reset.html",
form_class=ZulipPasswordResetForm, form_class=ZulipPasswordResetForm,
success_url="/accounts/password/reset/done/", success_url="/accounts/password/reset/done/",
) )(request)
return view_func(request) assert isinstance(response, HttpResponse)
return response
@csrf_exempt @csrf_exempt

View File

@@ -167,7 +167,7 @@ def create_default_stream_group(
group_name: str = REQ(), group_name: str = REQ(),
description: str = REQ(), description: str = REQ(),
stream_names: List[str] = REQ(json_validator=check_list(check_string)), stream_names: List[str] = REQ(json_validator=check_list(check_string)),
) -> None: ) -> HttpResponse:
streams = [] streams = []
for stream_name in stream_names: for stream_name in stream_names:
(stream, sub) = access_stream_by_name(user_profile, stream_name) (stream, sub) = access_stream_by_name(user_profile, stream_name)
@@ -184,7 +184,7 @@ def update_default_stream_group_info(
group_id: int, group_id: int,
new_group_name: Optional[str] = REQ(default=None), new_group_name: Optional[str] = REQ(default=None),
new_description: Optional[str] = REQ(default=None), new_description: Optional[str] = REQ(default=None),
) -> None: ) -> HttpResponse:
if not new_group_name and not new_description: if not new_group_name and not new_description:
raise JsonableError(_('You must pass "new_description" or "new_group_name".')) raise JsonableError(_('You must pass "new_description" or "new_group_name".'))
@@ -204,7 +204,7 @@ def update_default_stream_group_streams(
group_id: int, group_id: int,
op: str = REQ(), op: str = REQ(),
stream_names: List[str] = REQ(json_validator=check_list(check_string)), stream_names: List[str] = REQ(json_validator=check_list(check_string)),
) -> None: ) -> HttpResponse:
group = access_default_stream_group_by_id(user_profile.realm, group_id) group = access_default_stream_group_by_id(user_profile.realm, group_id)
streams = [] streams = []
for stream_name in stream_names: for stream_name in stream_names:
@@ -224,7 +224,7 @@ def update_default_stream_group_streams(
@has_request_variables @has_request_variables
def remove_default_stream_group( def remove_default_stream_group(
request: HttpRequest, user_profile: UserProfile, group_id: int request: HttpRequest, user_profile: UserProfile, group_id: int
) -> None: ) -> HttpResponse:
group = access_default_stream_group_by_id(user_profile.realm, group_id) group = access_default_stream_group_by_id(user_profile.realm, group_id)
do_remove_default_stream_group(user_profile.realm, group) do_remove_default_stream_group(user_profile.realm, group)
return json_success() return json_success()