diff --git a/corporate/lib/decorator.py b/corporate/lib/decorator.py index 3f4dc53b89..941572d436 100644 --- a/corporate/lib/decorator.py +++ b/corporate/lib/decorator.py @@ -1,3 +1,4 @@ +import inspect from collections.abc import Callable from functools import wraps from typing import Concatenate @@ -65,7 +66,7 @@ def authenticated_remote_realm_management_endpoint( if not is_self_hosting_management_subdomain(request): # nocoverage return render(request, "404.html", status=404) - realm_uuid = kwargs.get("realm_uuid") + realm_uuid = kwargs.pop("realm_uuid") if realm_uuid is not None and not isinstance(realm_uuid, str): # nocoverage raise TypeError("realm_uuid must be a string or None") @@ -124,7 +125,16 @@ def authenticated_remote_realm_management_endpoint( billing_session = RemoteRealmBillingSession( remote_realm, remote_billing_user=remote_billing_user ) - return view_func(request, billing_session) + return view_func(request, billing_session, *args, **kwargs) + + signature = inspect.signature(view_func) + request_parameter, billing_session_parameter, *other_parameters = signature.parameters.values() + _wrapped_view_func.__signature__ = signature.replace( # type: ignore[attr-defined] # too magic + parameters=[request_parameter, *other_parameters] + ) + _wrapped_view_func.__annotations__ = { + k: v for k, v in view_func.__annotations__.items() if k != billing_session_parameter.name + } return _wrapped_view_func @@ -165,7 +175,7 @@ def authenticated_remote_server_management_endpoint( if not is_self_hosting_management_subdomain(request): # nocoverage return render(request, "404.html", status=404) - server_uuid = kwargs.get("server_uuid") + server_uuid = kwargs.pop("server_uuid") if not isinstance(server_uuid, str): raise TypeError("server_uuid must be a string") # nocoverage @@ -199,6 +209,15 @@ def authenticated_remote_server_management_endpoint( billing_session = RemoteServerBillingSession( remote_server, remote_billing_user=remote_billing_user ) - return view_func(request, billing_session) + return view_func(request, billing_session, *args, **kwargs) + + signature = inspect.signature(view_func) + request_parameter, billing_session_parameter, *other_parameters = signature.parameters.values() + _wrapped_view_func.__signature__ = signature.replace( # type: ignore[attr-defined] # too magic + parameters=[request_parameter, *other_parameters] + ) + _wrapped_view_func.__annotations__ = { + k: v for k, v in view_func.__annotations__.items() if k != billing_session_parameter.name + } return _wrapped_view_func