mirror of
https://github.com/zulip/zulip.git
synced 2025-11-16 03:41:58 +00:00
billing: Use require_billing_access decorator in JSON endpoints.
This commit is contained in:
@@ -11,6 +11,7 @@ import json
|
|||||||
|
|
||||||
from django.core import signing
|
from django.core import signing
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
|
from django.core.urlresolvers import get_resolver
|
||||||
from django.http import HttpResponse
|
from django.http import HttpResponse
|
||||||
from django.utils.timezone import utc as timezone_utc
|
from django.utils.timezone import utc as timezone_utc
|
||||||
|
|
||||||
@@ -27,6 +28,7 @@ from corporate.lib.stripe import catch_stripe_errors, \
|
|||||||
get_next_billing_log_entry, run_billing_processor_one_step, \
|
get_next_billing_log_entry, run_billing_processor_one_step, \
|
||||||
BillingError, StripeCardError, StripeConnectionError, stripe_get_customer
|
BillingError, StripeCardError, StripeConnectionError, stripe_get_customer
|
||||||
from corporate.models import Customer, Plan, Coupon, BillingProcessor
|
from corporate.models import Customer, Plan, Coupon, BillingProcessor
|
||||||
|
import corporate.urls
|
||||||
|
|
||||||
CallableT = TypeVar('CallableT', bound=Callable[..., Any])
|
CallableT = TypeVar('CallableT', bound=Callable[..., Any])
|
||||||
|
|
||||||
@@ -561,26 +563,6 @@ class StripeTest(ZulipTestCase):
|
|||||||
self.assertEqual(ujson.loads(response.content)['error_description'], 'downgrade without subscription')
|
self.assertEqual(ujson.loads(response.content)['error_description'], 'downgrade without subscription')
|
||||||
mock_save_customer.assert_not_called()
|
mock_save_customer.assert_not_called()
|
||||||
|
|
||||||
def test_downgrade_permissions(self) -> None:
|
|
||||||
self.login(self.example_email('hamlet'))
|
|
||||||
response = self.client_post("/json/billing/downgrade", {})
|
|
||||||
self.assert_json_error_contains(response, "Access denied")
|
|
||||||
# billing admin but not realm admin
|
|
||||||
user = self.example_user('hamlet')
|
|
||||||
user.is_billing_admin = True
|
|
||||||
user.save(update_fields=['is_billing_admin'])
|
|
||||||
with patch('corporate.views.process_downgrade') as mocked1:
|
|
||||||
self.client_post("/json/billing/downgrade", {})
|
|
||||||
mocked1.assert_called()
|
|
||||||
# realm admin but not billing admin
|
|
||||||
user = self.example_user('hamlet')
|
|
||||||
user.is_billing_admin = False
|
|
||||||
user.is_realm_admin = True
|
|
||||||
user.save(update_fields=['is_billing_admin', 'is_realm_admin'])
|
|
||||||
with patch('corporate.views.process_downgrade') as mocked2:
|
|
||||||
self.client_post("/json/billing/downgrade", {})
|
|
||||||
mocked2.assert_called()
|
|
||||||
|
|
||||||
@patch("stripe.Subscription.delete")
|
@patch("stripe.Subscription.delete")
|
||||||
@patch("stripe.Customer.retrieve", side_effect=mock_customer_with_account_balance(1234))
|
@patch("stripe.Customer.retrieve", side_effect=mock_customer_with_account_balance(1234))
|
||||||
def test_downgrade_credits(self, mock_retrieve_customer: Mock,
|
def test_downgrade_credits(self, mock_retrieve_customer: Mock,
|
||||||
@@ -632,31 +614,6 @@ class StripeTest(ZulipTestCase):
|
|||||||
self.assertFalse(RealmAuditLog.objects.filter(
|
self.assertFalse(RealmAuditLog.objects.filter(
|
||||||
event_type=RealmAuditLog.STRIPE_CARD_CHANGED).exists())
|
event_type=RealmAuditLog.STRIPE_CARD_CHANGED).exists())
|
||||||
|
|
||||||
def test_update_payment_source_permissions(self) -> None:
|
|
||||||
# This can be removed / merged with e.g. test_downgrade_permissions
|
|
||||||
# once we have a decorator that handles billing page permissions
|
|
||||||
self.login(self.example_email('hamlet'))
|
|
||||||
response = self.client_post("/json/billing/sources/change",
|
|
||||||
{'stripe_token': ujson.dumps('token')})
|
|
||||||
self.assert_json_error_contains(response, "Access denied")
|
|
||||||
# billing admin but not realm admin
|
|
||||||
user = self.example_user('hamlet')
|
|
||||||
user.is_billing_admin = True
|
|
||||||
user.save(update_fields=['is_billing_admin'])
|
|
||||||
with patch('corporate.views.do_replace_payment_source') as mocked1:
|
|
||||||
self.client_post("/json/billing/sources/change",
|
|
||||||
{'stripe_token': ujson.dumps('token')})
|
|
||||||
mocked1.assert_called()
|
|
||||||
# realm admin but not billing admin
|
|
||||||
user = self.example_user('hamlet')
|
|
||||||
user.is_billing_admin = False
|
|
||||||
user.is_realm_admin = True
|
|
||||||
user.save(update_fields=['is_billing_admin', 'is_realm_admin'])
|
|
||||||
with patch('corporate.views.do_replace_payment_source') as mocked2:
|
|
||||||
self.client_post("/json/billing/sources/change",
|
|
||||||
{'stripe_token': ujson.dumps('token')})
|
|
||||||
mocked2.assert_called()
|
|
||||||
|
|
||||||
@patch("stripe.Customer.create", side_effect=mock_create_customer)
|
@patch("stripe.Customer.create", side_effect=mock_create_customer)
|
||||||
@patch("stripe.Subscription.create", side_effect=mock_create_subscription)
|
@patch("stripe.Subscription.create", side_effect=mock_create_subscription)
|
||||||
@patch("stripe.Customer.retrieve", side_effect=mock_customer_with_subscription)
|
@patch("stripe.Customer.retrieve", side_effect=mock_customer_with_subscription)
|
||||||
@@ -739,6 +696,53 @@ class RequiresBillingUpdateTest(ZulipTestCase):
|
|||||||
do_activate_user(user2)
|
do_activate_user(user2)
|
||||||
self.assertEqual(4, RealmAuditLog.objects.filter(requires_billing_update=True).count())
|
self.assertEqual(4, RealmAuditLog.objects.filter(requires_billing_update=True).count())
|
||||||
|
|
||||||
|
class RequiresBillingAccessTest(ZulipTestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
hamlet = self.example_user("hamlet")
|
||||||
|
hamlet.is_billing_admin = True
|
||||||
|
hamlet.save(update_fields=["is_billing_admin"])
|
||||||
|
|
||||||
|
# mocked_function_name will typically be something imported from
|
||||||
|
# stripe.py. In theory we could have endpoints that need to mock
|
||||||
|
# multiple functions, but we'll cross that bridge when we get there.
|
||||||
|
def _test_endpoint(self, url: str, mocked_function_name: str,
|
||||||
|
request_data: Optional[Dict[str, Any]]={}) -> None:
|
||||||
|
# Normal users do not have access
|
||||||
|
self.login(self.example_email('cordelia'))
|
||||||
|
response = self.client_post(url, request_data)
|
||||||
|
self.assert_json_error_contains(response, "Access denied")
|
||||||
|
|
||||||
|
# Billing admins have access
|
||||||
|
self.login(self.example_email('hamlet'))
|
||||||
|
with patch("corporate.views.{}".format(mocked_function_name)) as mocked1:
|
||||||
|
response = self.client_post(url, request_data)
|
||||||
|
self.assert_json_success(response)
|
||||||
|
mocked1.assert_called()
|
||||||
|
|
||||||
|
# Realm admins have access, even if they are not billing admins
|
||||||
|
self.login(self.example_email('iago'))
|
||||||
|
with patch("corporate.views.{}".format(mocked_function_name)) as mocked2:
|
||||||
|
response = self.client_post(url, request_data)
|
||||||
|
self.assert_json_success(response)
|
||||||
|
mocked2.assert_called()
|
||||||
|
|
||||||
|
def test_json_endpoints(self) -> None:
|
||||||
|
params = [
|
||||||
|
("/json/billing/sources/change", "do_replace_payment_source",
|
||||||
|
{'stripe_token': ujson.dumps('token')}),
|
||||||
|
("/json/billing/downgrade", "process_downgrade", {})
|
||||||
|
] # type: List[Tuple[str, str, Dict[str, Any]]]
|
||||||
|
|
||||||
|
for (url, mocked_function_name, data) in params:
|
||||||
|
self._test_endpoint(url, mocked_function_name, data)
|
||||||
|
|
||||||
|
# Make sure that we are testing all the JSON endpoints
|
||||||
|
# Quite a hack, but probably fine for now
|
||||||
|
string_with_all_endpoints = str(get_resolver('corporate.urls').reverse_dict)
|
||||||
|
json_endpoints = set([word.strip("\"'()[],$") for word in string_with_all_endpoints.split()
|
||||||
|
if 'json' in word])
|
||||||
|
self.assertEqual(len(json_endpoints), len(params))
|
||||||
|
|
||||||
class BillingProcessorTest(ZulipTestCase):
|
class BillingProcessorTest(ZulipTestCase):
|
||||||
def add_log_entry(self, realm: Realm=get_realm('zulip'),
|
def add_log_entry(self, realm: Realm=get_realm('zulip'),
|
||||||
event_type: str=RealmAuditLog.USER_CREATED,
|
event_type: str=RealmAuditLog.USER_CREATED,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ i18n_urlpatterns = [
|
|||||||
v1_api_and_json_patterns = [
|
v1_api_and_json_patterns = [
|
||||||
url(r'^billing/downgrade$', rest_dispatch,
|
url(r'^billing/downgrade$', rest_dispatch,
|
||||||
{'POST': 'corporate.views.downgrade'}),
|
{'POST': 'corporate.views.downgrade'}),
|
||||||
url(r'billing/sources/change', rest_dispatch,
|
url(r'^billing/sources/change', rest_dispatch,
|
||||||
{'POST': 'corporate.views.replace_payment_source'}),
|
{'POST': 'corporate.views.replace_payment_source'}),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from django.shortcuts import redirect, render
|
|||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
|
||||||
from zerver.decorator import zulip_login_required
|
from zerver.decorator import zulip_login_required, require_billing_access
|
||||||
from zerver.lib.request import REQ, has_request_variables
|
from zerver.lib.request import REQ, has_request_variables
|
||||||
from zerver.lib.response import json_error, json_success
|
from zerver.lib.response import json_error, json_success
|
||||||
from zerver.lib.validator import check_string
|
from zerver.lib.validator import check_string
|
||||||
@@ -144,20 +144,18 @@ def billing_home(request: HttpRequest) -> HttpResponse:
|
|||||||
|
|
||||||
return render(request, 'corporate/billing.html', context=context)
|
return render(request, 'corporate/billing.html', context=context)
|
||||||
|
|
||||||
|
@require_billing_access
|
||||||
def downgrade(request: HttpRequest, user: UserProfile) -> HttpResponse:
|
def downgrade(request: HttpRequest, user: UserProfile) -> HttpResponse:
|
||||||
if not user.is_realm_admin and not user.is_billing_admin:
|
|
||||||
return json_error(_('Access denied'))
|
|
||||||
try:
|
try:
|
||||||
process_downgrade(user)
|
process_downgrade(user)
|
||||||
except BillingError as e:
|
except BillingError as e:
|
||||||
return json_error(e.message, data={'error_description': e.description})
|
return json_error(e.message, data={'error_description': e.description})
|
||||||
return json_success()
|
return json_success()
|
||||||
|
|
||||||
|
@require_billing_access
|
||||||
@has_request_variables
|
@has_request_variables
|
||||||
def replace_payment_source(request: HttpRequest, user: UserProfile,
|
def replace_payment_source(request: HttpRequest, user: UserProfile,
|
||||||
stripe_token: str=REQ("stripe_token", validator=check_string)) -> HttpResponse:
|
stripe_token: str=REQ("stripe_token", validator=check_string)) -> HttpResponse:
|
||||||
if not user.is_realm_admin and not user.is_billing_admin:
|
|
||||||
return json_error(_("Access denied"))
|
|
||||||
try:
|
try:
|
||||||
do_replace_payment_source(user, stripe_token)
|
do_replace_payment_source(user, stripe_token)
|
||||||
except BillingError as e:
|
except BillingError as e:
|
||||||
|
|||||||
@@ -135,6 +135,14 @@ def require_realm_admin(func: ViewFuncT) -> ViewFuncT:
|
|||||||
return func(request, user_profile, *args, **kwargs)
|
return func(request, user_profile, *args, **kwargs)
|
||||||
return wrapper # type: ignore # https://github.com/python/mypy/issues/1927
|
return wrapper # type: ignore # https://github.com/python/mypy/issues/1927
|
||||||
|
|
||||||
|
def require_billing_access(func: ViewFuncT) -> ViewFuncT:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(request: HttpRequest, user_profile: UserProfile, *args: Any, **kwargs: Any) -> HttpResponse:
|
||||||
|
if not user_profile.is_realm_admin and not user_profile.is_billing_admin:
|
||||||
|
raise JsonableError(_("Access denied"))
|
||||||
|
return func(request, user_profile, *args, **kwargs)
|
||||||
|
return wrapper # type: ignore # https://github.com/python/mypy/issues/1927
|
||||||
|
|
||||||
from zerver.lib.user_agent import parse_user_agent
|
from zerver.lib.user_agent import parse_user_agent
|
||||||
|
|
||||||
def get_client_name(request: HttpRequest, is_browser_view: bool) -> str:
|
def get_client_name(request: HttpRequest, is_browser_view: bool) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user