openapi: Validate real requests and responses, not fictional mocks.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
(cherry picked from commit 029e765e20)
This commit is contained in:
Anders Kaseorg
2023-08-09 16:43:12 -07:00
committed by Tim Abbott
parent 9f2172c0f9
commit 3bf1934598
3 changed files with 3451 additions and 3304 deletions

View File

@@ -36,16 +36,18 @@ from django.db import connection
from django.db.migrations.executor import MigrationExecutor from django.db.migrations.executor import MigrationExecutor
from django.db.migrations.state import StateApps from django.db.migrations.state import StateApps
from django.db.utils import IntegrityError from django.db.utils import IntegrityError
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse, HttpResponseBase
from django.http.response import ResponseHeaders from django.http.response import ResponseHeaders
from django.test import Client as TestClient
from django.test import SimpleTestCase, TestCase, TransactionTestCase from django.test import SimpleTestCase, TestCase, TransactionTestCase
from django.test.client import BOUNDARY, MULTIPART_CONTENT, encode_multipart from django.test.client import BOUNDARY, MULTIPART_CONTENT, ClientHandler, encode_multipart
from django.test.testcases import SerializeMixin from django.test.testcases import SerializeMixin
from django.urls import resolve from django.urls import resolve
from django.utils import translation from django.utils import translation
from django.utils.module_loading import import_string from django.utils.module_loading import import_string
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from fakeldap import MockLDAP from fakeldap import MockLDAP
from openapi_core.contrib.django import DjangoOpenAPIRequest, DjangoOpenAPIResponse
from requests import PreparedRequest from requests import PreparedRequest
from two_factor.plugins.phonenumber.models import PhoneDevice from two_factor.plugins.phonenumber.models import PhoneDevice
from typing_extensions import override from typing_extensions import override
@@ -112,7 +114,7 @@ from zerver.models.groups import SystemGroups
from zerver.models.realms import clear_supported_auth_backends_cache, get_realm from zerver.models.realms import clear_supported_auth_backends_cache, get_realm
from zerver.models.streams import get_realm_stream, get_stream from zerver.models.streams import get_realm_stream, get_stream
from zerver.models.users import get_system_bot, get_user, get_user_by_delivery_email from zerver.models.users import get_system_bot, get_user, get_user_by_delivery_email
from zerver.openapi.openapi import validate_against_openapi_schema, validate_request from zerver.openapi.openapi import validate_test_request, validate_test_response
from zerver.tornado.event_queue import clear_client_event_queues_for_testing from zerver.tornado.event_queue import clear_client_event_queues_for_testing
if settings.ZILENCER_ENABLED: if settings.ZILENCER_ENABLED:
@@ -147,6 +149,36 @@ class UploadSerializeMixin(SerializeMixin):
super().setUpClass() super().setUpClass()
class ZulipClientHandler(ClientHandler):
@override
def get_response(self, request: HttpRequest) -> HttpResponseBase:
request.body # noqa: B018 # prevents RawPostDataException
response = super().get_response(request)
if (
request.method != "OPTIONS"
and isinstance(response, HttpResponse)
and not (
response.status_code == 302 and response.headers["Location"].startswith("/login/")
)
):
openapi_request = DjangoOpenAPIRequest(request)
openapi_response = DjangoOpenAPIResponse(response)
response_validated = validate_test_response(openapi_request, openapi_response)
if response_validated:
validate_test_request(
openapi_request,
str(response.status_code),
request.META.get("intentionally_undocumented", False),
)
return response
class ZulipTestClient(TestClient):
def __init__(self) -> None:
super().__init__()
self.handler = ZulipClientHandler(enforce_csrf_checks=False)
class ZulipTestCaseMixin(SimpleTestCase): class ZulipTestCaseMixin(SimpleTestCase):
# Ensure that the test system just shows us diffs # Ensure that the test system just shows us diffs
maxDiff: Optional[int] = None maxDiff: Optional[int] = None
@@ -154,6 +186,7 @@ class ZulipTestCaseMixin(SimpleTestCase):
# Override this to verify if the given extra console output matches the # Override this to verify if the given extra console output matches the
# expectation. # expectation.
expected_console_output: Optional[str] = None expected_console_output: Optional[str] = None
client_class = ZulipTestClient
@override @override
def setUp(self) -> None: def setUp(self) -> None:
@@ -266,63 +299,6 @@ Output:
elif "HTTP_USER_AGENT" not in extra: elif "HTTP_USER_AGENT" not in extra:
extra["HTTP_USER_AGENT"] = default_user_agent extra["HTTP_USER_AGENT"] = default_user_agent
def extract_api_suffix_url(self, url: str) -> Tuple[str, Dict[str, List[str]]]:
"""
Function that extracts the URL after `/api/v1` or `/json` and also
returns the query data in the URL, if there is any.
"""
url_split = url.split("?")
data = {}
if len(url_split) == 2:
data = parse_qs(url_split[1])
url = url_split[0]
url = url.replace("/json/", "/").replace("/api/v1/", "/")
return (url, data)
def validate_api_response_openapi(
self,
url: str,
method: str,
result: "TestHttpResponse",
data: Union[str, bytes, Mapping[str, Any]],
extra: Dict[str, str],
intentionally_undocumented: bool = False,
) -> None:
"""
Validates all API responses received by this test against Zulip's API documentation,
declared in zerver/openapi/zulip.yaml. This powerful test lets us use Zulip's
extensive test coverage of corner cases in the API to ensure that we've properly
documented those corner cases.
"""
if not url.startswith(("/json", "/api/v1")):
return
try:
content = orjson.loads(result.content)
except orjson.JSONDecodeError:
return
json_url = False
if url.startswith("/json"):
json_url = True
url, query_data = self.extract_api_suffix_url(url)
if len(query_data) != 0:
# In some cases the query parameters are defined in the URL itself. In such cases
# The `data` argument of our function is not used. Hence get `data` argument
# from url.
data = query_data
response_validated = validate_against_openapi_schema(
content, url, method, str(result.status_code)
)
if response_validated:
validate_request(
url,
method,
data,
extra,
json_url,
str(result.status_code),
intentionally_undocumented=intentionally_undocumented,
)
@instrument_url @instrument_url
def client_patch( def client_patch(
self, self,
@@ -342,18 +318,15 @@ Output:
extra["content_type"] = "application/x-www-form-urlencoded" extra["content_type"] = "application/x-www-form-urlencoded"
django_client = self.client # see WRAPPER_COMMENT django_client = self.client # see WRAPPER_COMMENT
self.set_http_headers(extra, skip_user_agent) self.set_http_headers(extra, skip_user_agent)
result = django_client.patch( return django_client.patch(
url, encoded, follow=follow, secure=secure, headers=headers, **extra
)
self.validate_api_response_openapi(
url, url,
"patch", encoded,
result, follow=follow,
info, secure=secure,
extra, headers=headers,
intentionally_undocumented=intentionally_undocumented, intentionally_undocumented=intentionally_undocumented,
**extra,
) )
return result
@instrument_url @instrument_url
def client_patch_multipart( def client_patch_multipart(
@@ -378,24 +351,16 @@ Output:
encoded = encode_multipart(BOUNDARY, dict(info)) encoded = encode_multipart(BOUNDARY, dict(info))
django_client = self.client # see WRAPPER_COMMENT django_client = self.client # see WRAPPER_COMMENT
self.set_http_headers(extra, skip_user_agent) self.set_http_headers(extra, skip_user_agent)
result = django_client.patch( return django_client.patch(
url, url,
encoded, encoded,
content_type=MULTIPART_CONTENT, content_type=MULTIPART_CONTENT,
follow=follow, follow=follow,
secure=secure, secure=secure,
headers=headers, headers=headers,
intentionally_undocumented=intentionally_undocumented,
**extra, **extra,
) )
self.validate_api_response_openapi(
url,
"patch",
result,
info,
extra,
intentionally_undocumented=intentionally_undocumented,
)
return result
def json_patch( def json_patch(
self, self,
@@ -477,7 +442,7 @@ Output:
extra["content_type"] = "application/x-www-form-urlencoded" extra["content_type"] = "application/x-www-form-urlencoded"
django_client = self.client # see WRAPPER_COMMENT django_client = self.client # see WRAPPER_COMMENT
self.set_http_headers(extra, skip_user_agent) self.set_http_headers(extra, skip_user_agent)
result = django_client.delete( return django_client.delete(
url, url,
encoded, encoded,
follow=follow, follow=follow,
@@ -486,17 +451,9 @@ Output:
"Content-Type": "application/x-www-form-urlencoded", # https://code.djangoproject.com/ticket/33230 "Content-Type": "application/x-www-form-urlencoded", # https://code.djangoproject.com/ticket/33230
**(headers or {}), **(headers or {}),
}, },
intentionally_undocumented=intentionally_undocumented,
**extra, **extra,
) )
self.validate_api_response_openapi(
url,
"delete",
result,
info,
extra,
intentionally_undocumented=intentionally_undocumented,
)
return result
@instrument_url @instrument_url
def client_options( def client_options(
@@ -554,7 +511,7 @@ Output:
encoded = urlencode(info, doseq=True) encoded = urlencode(info, doseq=True)
else: else:
content_type = MULTIPART_CONTENT content_type = MULTIPART_CONTENT
result = django_client.post( return django_client.post(
url, url,
encoded, encoded,
follow=follow, follow=follow,
@@ -564,17 +521,9 @@ Output:
**(headers or {}), **(headers or {}),
}, },
content_type=content_type, content_type=content_type,
intentionally_undocumented=intentionally_undocumented,
**extra, **extra,
) )
self.validate_api_response_openapi(
url,
"post",
result,
info,
extra,
intentionally_undocumented=intentionally_undocumented,
)
return result
@instrument_url @instrument_url
def client_post_request(self, url: str, req: Any) -> "TestHttpResponse": def client_post_request(self, url: str, req: Any) -> "TestHttpResponse":
@@ -604,13 +553,15 @@ Output:
) -> "TestHttpResponse": ) -> "TestHttpResponse":
django_client = self.client # see WRAPPER_COMMENT django_client = self.client # see WRAPPER_COMMENT
self.set_http_headers(extra, skip_user_agent) self.set_http_headers(extra, skip_user_agent)
result = django_client.get( return django_client.get(
url, info, follow=follow, secure=secure, headers=headers, **extra url,
info,
follow=follow,
secure=secure,
headers=headers,
intentionally_undocumented=intentionally_undocumented,
**extra,
) )
self.validate_api_response_openapi(
url, "get", result, info, extra, intentionally_undocumented=intentionally_undocumented
)
return result
example_user_map = dict( example_user_map = dict(
hamlet="hamlet@zulip.com", hamlet="hamlet@zulip.com",

View File

@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Literal, Mapping, Optional, Set, Tuple, Unio
import orjson import orjson
from openapi_core import OpenAPI from openapi_core import OpenAPI
from openapi_core.protocols import Request, Response
from openapi_core.testing import MockRequest, MockResponse from openapi_core.testing import MockRequest, MockResponse
from openapi_core.validation.exceptions import ValidationError as OpenAPIValidationError from openapi_core.validation.exceptions import ValidationError as OpenAPIValidationError
from pydantic import BaseModel from pydantic import BaseModel
@@ -423,10 +424,29 @@ def find_openapi_endpoint(path: str) -> Optional[str]:
def validate_against_openapi_schema( def validate_against_openapi_schema(
content: Dict[str, Any], path: str, method: str, status_code: str content: Dict[str, Any], path: str, method: str, status_code: str
) -> bool: ) -> bool:
mock_request = MockRequest("http://localhost:9991/", method, "/api/v1" + path)
mock_response = MockResponse(
orjson.dumps(content),
status_code=int(status_code),
)
return validate_test_response(mock_request, mock_response)
def validate_test_response(request: Request, response: Response) -> bool:
"""Compare a "content" dict with the defined schema for a specific method """Compare a "content" dict with the defined schema for a specific method
in an endpoint. Return true if validated and false if skipped. in an endpoint. Return true if validated and false if skipped.
""" """
if request.path.startswith("/json/"):
path = request.path[len("/json") :]
elif request.path.startswith("/api/v1/"):
path = request.path[len("/api/v1") :]
else:
return False
assert request.method is not None
method = request.method.lower()
status_code = str(response.status_code)
# This first set of checks are primarily training wheels that we # This first set of checks are primarily training wheels that we
# hope to eliminate over time as we improve our API documentation. # hope to eliminate over time as we improve our API documentation.
@@ -452,14 +472,8 @@ def validate_against_openapi_schema(
# response have been defined this should be removed. # response have been defined this should be removed.
return True return True
mock_request = MockRequest("http://localhost:9991/", method, "/api/v1" + path)
mock_response = MockResponse(
# TODO: Use original response content instead of re-serializing it.
orjson.dumps(content),
status_code=int(status_code),
)
try: try:
openapi_spec.spec().validate_response(mock_request, mock_response) openapi_spec.spec().validate_response(request, response)
except OpenAPIValidationError as error: except OpenAPIValidationError as error:
message = f"Response validation error at {method} /api/v1{path} ({status_code}):" message = f"Response validation error at {method} /api/v1{path} ({status_code}):"
message += f"\n\n{type(error).__name__}: {error}" message += f"\n\n{type(error).__name__}: {error}"
@@ -529,10 +543,33 @@ def validate_request(
status_code: str, status_code: str,
intentionally_undocumented: bool = False, intentionally_undocumented: bool = False,
) -> None: ) -> None:
assert isinstance(data, dict)
mock_request = MockRequest(
"http://localhost:9991/",
method,
"/api/v1" + url,
headers=http_headers,
args={k: str(v) for k, v in data.items()},
)
validate_test_request(mock_request, status_code, intentionally_undocumented)
def validate_test_request(
request: Request,
status_code: str,
intentionally_undocumented: bool = False,
) -> None:
assert request.method is not None
method = request.method.lower()
if request.path.startswith("/json/"):
url = request.path[len("/json") :]
# Some JSON endpoints have different parameters compared to # Some JSON endpoints have different parameters compared to
# their `/api/v1` counterparts. # their `/api/v1` counterparts.
if json_url and (url, method) in SKIP_JSON: if (url, method) in SKIP_JSON:
return return
else:
assert request.path.startswith("/api/v1/")
url = request.path[len("/api/v1") :]
# TODO: Add support for file upload endpoints that lack the /json/ # TODO: Add support for file upload endpoints that lack the /json/
# or /api/v1/ prefix. # or /api/v1/ prefix.
@@ -550,16 +587,8 @@ def validate_request(
# Now using the openapi_core APIs, validate the request schema # Now using the openapi_core APIs, validate the request schema
# against the OpenAPI documentation. # against the OpenAPI documentation.
assert isinstance(data, dict)
mock_request = MockRequest(
"http://localhost:9991/",
method,
"/api/v1" + url,
headers=http_headers,
args={k: str(v) for k, v in data.items()},
)
try: try:
openapi_spec.spec().validate_request(mock_request) openapi_spec.spec().validate_request(request)
except OpenAPIValidationError as error: except OpenAPIValidationError as error:
# Show a block error message explaining the options for fixing it. # Show a block error message explaining the options for fixing it.
msg = f""" msg = f"""

File diff suppressed because it is too large Load Diff