diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index a677c99461..a1a8c86c39 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -2299,6 +2299,17 @@ class StripeTest(StripeTestCase): self.assertEqual(old_plan.next_invoice_date, None) self.assertEqual(old_plan.status, CustomerPlan.ENDED) + def test_update_plan_with_invalid_status(self, *mocks: Mock) -> None: + with patch("corporate.lib.stripe.timezone_now", return_value=self.now): + self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, "token") + self.login_user(self.example_user("hamlet")) + + response = self.client_post( + "/json/billing/plan/change", + {"status": CustomerPlan.NEVER_STARTED}, + ) + self.assert_json_error_contains(response, "Invalid status") + def test_deactivate_realm(self) -> None: user = self.example_user("hamlet") with patch("corporate.lib.stripe.timezone_now", return_value=self.now): diff --git a/corporate/views.py b/corporate/views.py index 272bc2bbdf..8b77a3a58f 100644 --- a/corporate/views.py +++ b/corporate/views.py @@ -46,7 +46,7 @@ from zerver.decorator import ( from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_error, json_success from zerver.lib.send_email import FromAddress, send_email -from zerver.lib.validator import check_int, check_string_in +from zerver.lib.validator import check_int, check_int_in, check_string_in from zerver.models import UserProfile, get_realm billing_logger = logging.getLogger("corporate.stripe") @@ -357,14 +357,20 @@ def billing_home(request: HttpRequest) -> HttpResponse: @require_billing_access @has_request_variables def change_plan_status( - request: HttpRequest, user: UserProfile, status: int = REQ("status", json_validator=check_int) + request: HttpRequest, + user: UserProfile, + status: int = REQ( + "status", + json_validator=check_int_in( + [ + CustomerPlan.ACTIVE, + CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE, + CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE, + CustomerPlan.ENDED, + ] + ), + ), ) -> HttpResponse: - assert status in [ - CustomerPlan.ACTIVE, - CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE, - CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE, - CustomerPlan.ENDED, - ] plan = get_current_plan_by_realm(user.realm) assert plan is not None # for mypy