mirror of
https://github.com/zulip/zulip.git
synced 2025-10-29 02:53:52 +00:00
We wrap methods of the django test client for the test suite, and type keyword variadic arguments as `ClientArg` as they might be called with a mix of `bool` and `str`. This is problematic when we call the original methods on the test client, because we then unpack the dictionary of keyword arguments, which gives no type guarantee that keys the test client requires to be bool will actually be bool. For example, you can call `self.client_post(url, info, follow="invalid")` without getting a mypy error even though the django test client requires `follow: bool`. Once django-stubs is added, these unsafely typed keyword variadic arguments lead to errors within the body of the wrapped test client functions, where we call `django_client.post` with `**kwargs`, making it necessary to refactor these wrappers for type safety. The approach here minimizes the need to refactor callers: we keep `kwargs` variadic while changing its type from `ClientArg` to `str`, after explicitly defining all the possible `bool` arguments that might previously have appeared in `kwargs`. We also copy the defaults from the django test client, as they are unlikely to change. The tornado test cases are also refactored because the signature of `set_http_headers` changed, with `skip_user_agent` added as a keyword argument. We want to unconditionally set this flag to `True` because `HTTP_USER_AGENT` is not supported. This also removes an unnecessary duplication of an argument. This is a part of the django-stubs refactorings. Signed-off-by: Zixuan James Li <p359101898@gmail.com>
160 lines
5.8 KiB
Python
160 lines
5.8 KiB
Python
import asyncio
|
|
import urllib.parse
|
|
from functools import wraps
|
|
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar
|
|
from unittest import TestResult, mock
|
|
|
|
import orjson
|
|
from asgiref.sync import async_to_sync, sync_to_async
|
|
from django.conf import settings
|
|
from django.core import signals
|
|
from django.db import close_old_connections
|
|
from django.test import override_settings
|
|
from tornado.httpclient import HTTPResponse
|
|
from tornado.ioloop import IOLoop
|
|
from tornado.platform.asyncio import AsyncIOMainLoop
|
|
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase
|
|
from tornado.web import Application
|
|
from typing_extensions import ParamSpec
|
|
|
|
from zerver.lib.test_classes import ZulipTestCase
|
|
from zerver.tornado import event_queue
|
|
from zerver.tornado.application import create_tornado_application
|
|
from zerver.tornado.event_queue import process_event
|
|
|
|
# Type variables used by the async/sync bridging helpers below:
# P captures a wrapped callable's parameter list, T its return type.
P = ParamSpec("P")

T = TypeVar("T")
|
|
|
|
|
|
def async_to_sync_decorator(f: Callable[P, Awaitable[T]]) -> Callable[P, T]:
    """Expose the coroutine function ``f`` as a plain synchronous callable.

    The returned function has ``f``'s parameters (via ``P``) and returns the
    awaited result of ``f`` (type ``T``), blocking until it completes.
    ``functools.wraps`` copies ``f``'s name, docstring and other metadata onto
    the wrapper, so it can stand in for ``f`` (e.g. as a unittest method).
    """

    @wraps(f)
    def run_blocking(*args: P.args, **kwargs: P.kwargs) -> T:
        # A new async_to_sync wrapper is created for every invocation rather
        # than once at decoration time.  NOTE(review): this appears deliberate,
        # since some asgiref versions look up the event loop when the wrapper
        # is constructed; keep it per-call unless that is confirmed irrelevant.
        blocking_call = async_to_sync(f)
        return blocking_call(*args, **kwargs)

    return run_blocking
|
|
|
|
|
|
async def in_django_thread(f: Callable[[], T]) -> T:
    """Run the zero-argument synchronous callable ``f`` via ``sync_to_async``.

    Returns whatever ``f`` returns.  The resulting coroutine is scheduled as
    its own asyncio Task, which (per asyncio) runs in a copy of the current
    contextvars context.  NOTE(review): that isolation is presumably why a
    Task is used instead of awaiting the coroutine directly — confirm.
    """
    pending = sync_to_async(f)()
    task = asyncio.create_task(pending)
    return await task
|
|
|
|
|
|
class TornadoWebTestCase(AsyncHTTPTestCase, ZulipTestCase):
    """Base test case that drives the Tornado application over real HTTP.

    Combines tornado's AsyncHTTPTestCase (which supplies ``http_client``,
    ``get_url`` and ``io_loop``) with ZulipTestCase.  Lifecycle hooks and
    tests are written as coroutines wrapped with ``async_to_sync_decorator``;
    tests in this module run Django calls such as ``example_user`` and
    ``login_user`` through ``in_django_thread``.
    """

    @async_to_sync_decorator
    async def setUp(self) -> None:
        super().setUp()
        # Stop Django from running close_old_connections at the start/end of
        # every request.  NOTE(review): presumably so that requests served
        # during the test don't close the DB connection the test itself is
        # using — confirm.
        signals.request_started.disconnect(close_old_connections)
        signals.request_finished.disconnect(close_old_connections)
        # Populated by login_user(); None means requests carry no session cookie.
        self.session_cookie: Optional[Dict[str, str]] = None

    @async_to_sync_decorator
    async def tearDown(self) -> None:
        # Skip tornado.testing.AsyncTestCase.tearDown because it tries to kill
        # the current task.
        super(AsyncTestCase, self).tearDown()

    def run(self, result: Optional[TestResult] = None) -> Optional[TestResult]:
        # Execute unittest's run() (setUp, the test, tearDown) through
        # sync_to_async with thread_sensitive=False, driven by a brand-new
        # event loop (force_new_loop=True) rather than any loop already running.
        return async_to_sync(
            sync_to_async(super().run, thread_sensitive=False), force_new_loop=True
        )(result)

    def get_new_ioloop(self) -> IOLoop:
        # Use a tornado IOLoop backed by the current asyncio event loop.
        return AsyncIOMainLoop()

    @override_settings(DEBUG=False)
    def get_app(self) -> Application:
        # The application under test is built with DEBUG forced off.
        return create_tornado_application()

    async def tornado_client_get(self, path: str, **kwargs: Any) -> HTTPResponse:
        """Send a GET request for *path*; kwargs are handled as in fetch_async."""
        return await self.fetch_async("GET", path, **kwargs)

    async def fetch_async(self, method: str, path: str, **kwargs: Any) -> HTTPResponse:
        """Send a *method* request for *path* to the test server.

        The session cookie (if logged in) is added to ``kwargs["headers"]``,
        ``set_http_headers`` (from ZulipTestCase) is allowed to adjust *kwargs*,
        and a Django-style ``HTTP_HOST`` entry is converted into a real
        ``Host`` header.  Everything left in *kwargs* is passed straight to
        ``self.http_client.fetch``.
        """
        self.add_session_cookie(kwargs)
        # HTTP_USER_AGENT is not supported on these tornado requests, so never
        # let set_http_headers add one.
        self.set_http_headers(kwargs, skip_user_agent=True)
        if "HTTP_HOST" in kwargs:
            # The remaining kwargs go to fetch(), which expects real header
            # names, so move HTTP_HOST into the headers dict.
            # add_session_cookie() above guarantees kwargs["headers"] exists.
            kwargs["headers"]["Host"] = kwargs.pop("HTTP_HOST")
        return await self.http_client.fetch(self.get_url(path), method=method, **kwargs)

    async def client_get_async(self, path: str, **kwargs: Any) -> HTTPResponse:
        """Send a GET request for *path* to the test server."""
        # NOTE(review): fetch_async() calls set_http_headers() again; this
        # extra call is kept to preserve existing behavior.  Drop it once
        # set_http_headers is confirmed to be idempotent.
        self.set_http_headers(kwargs, skip_user_agent=True)
        return await self.fetch_async("GET", path, **kwargs)

    def login_user(self, *args: Any, **kwargs: Any) -> None:
        """Log in via ZulipTestCase, then remember the Django session as a Cookie header.

        The Django test client's session key is reused so that tornado
        requests issued afterwards carry the same session.
        """
        super().login_user(*args, **kwargs)
        session_cookie = settings.SESSION_COOKIE_NAME
        session_key = self.client.session.session_key
        self.session_cookie = {
            "Cookie": f"{session_cookie}={session_key}",
        }

    def get_session_cookie(self) -> Dict[str, str]:
        """Return the Cookie header set by login_user(), or {} if not logged in."""
        return {} if self.session_cookie is None else self.session_cookie

    def add_session_cookie(self, kwargs: Dict[str, Any]) -> None:
        """Store a headers dict containing the session cookie in ``kwargs["headers"]``.

        Any headers already present in *kwargs* are kept.
        """
        # TODO: Currently only allows session cookie
        # Copy rather than update in place.  Otherwise a headers dict supplied
        # by the caller would be mutated: it would gain the Cookie header here
        # and the Host header later in fetch_async().
        headers = dict(kwargs.get("headers", {}))
        headers.update(self.get_session_cookie())
        kwargs["headers"] = headers

    async def create_queue(self, **kwargs: Any) -> str:
        """Register an event queue on the "zulip" subdomain and return its id.

        Requests ``/json/events?dont_block=true``, asserts a 200 response with
        an empty ``events`` list and a ``queue_id``, and returns that id.

        NOTE(review): *kwargs* is accepted but currently ignored.
        """
        response = await self.tornado_client_get(
            "/json/events?dont_block=true",
            subdomain="zulip",
        )
        self.assertEqual(response.code, 200)
        body = orjson.loads(response.body)
        self.assertEqual(body["events"], [])
        self.assertIn("queue_id", body)
        return body["queue_id"]
|
|
|
|
|
|
class EventsTestCase(TornadoWebTestCase):
    """Tests for the /json/events endpoint as served by Tornado."""

    @async_to_sync_decorator
    async def test_create_queue(self) -> None:
        # Once logged in as hamlet, a newly created queue must be registered
        # in event_queue.clients.
        def log_in_hamlet() -> None:
            self.login_user(self.example_user("hamlet"))

        await in_django_thread(log_in_hamlet)
        new_queue_id = await self.create_queue()
        self.assertIn(new_queue_id, event_queue.clients)

    @async_to_sync_decorator
    async def test_events_async(self) -> None:
        # Create a queue as hamlet, then GET /json/events.  While the view
        # handles that request, a "test" event for hamlet is injected, and it
        # must come back in the response.
        hamlet = await in_django_thread(lambda: self.example_user("hamlet"))
        await in_django_thread(lambda: self.login_user(hamlet))
        queue_id = await self.create_queue()

        query_string = urllib.parse.urlencode({"queue_id": queue_id, "last_event_id": -1})
        path = f"/json/events?{query_string}"

        def inject_test_event() -> None:
            process_event({"type": "test", "data": "test data"}, [hamlet.id])

        def fetch_then_inject(**query: Any) -> Dict[str, Any]:
            # Delegate to the real fetch_events, then schedule the event
            # injection on the IOLoop so it runs after this fetch returns.
            fetched = event_queue.fetch_events(**query)
            self.io_loop.add_callback(inject_test_event)
            return fetched

        with mock.patch("zerver.tornado.views.fetch_events", side_effect=fetch_then_inject):
            response = await self.client_get_async(path)

        self.assertEqual(response.headers["Vary"], "Accept-Language, Cookie")
        payload = orjson.loads(response.body)
        expected_events = [{"type": "test", "data": "test data", "id": 0}]
        self.assertEqual(payload["events"], expected_events)
        self.assertEqual(payload["result"], "success")
|