diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 69300460a7..f7aee427d8 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -21,6 +21,7 @@ from typing import ( Tuple, TypeVar, ) +from unittest import mock from unittest.mock import Mock, patch import orjson @@ -681,6 +682,14 @@ class StripeTestCase(ZulipTestCase): licenses, automanage_licenses, billing_schedule, charge_automatically, free_trial ) + def setup_mocked_stripe(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Mock: + with patch.multiple("stripe", Invoice=mock.DEFAULT, InvoiceItem=mock.DEFAULT) as mocked: + mocked["Invoice"].create.return_value = None + mocked["Invoice"].finalize_invoice.return_value = None + mocked["InvoiceItem"].create.return_value = None + callback(*args, **kwargs) + return mocked + class StripeTest(StripeTestCase): def test_catch_stripe_errors(self) -> None: @@ -2242,14 +2251,10 @@ class StripeTest(StripeTestCase): ) # Verify that we invoice them for the additional users - from stripe import Invoice - - Invoice.create = lambda **args: None # type: ignore[assignment] # cleaner than mocking - Invoice.finalize_invoice = lambda *args: None # type: ignore[assignment] # cleaner than mocking - with patch("stripe.InvoiceItem.create") as mocked: - invoice_plans_as_needed(self.next_month) - mocked.assert_called_once() - mocked.reset_mock() + mocked = self.setup_mocked_stripe(invoice_plans_as_needed, self.next_month) + mocked["InvoiceItem"].create.assert_called_once() + mocked["Invoice"].finalize_invoice.assert_called_once() + mocked["Invoice"].create.assert_called_once() # Check that we downgrade properly if the cycle is over with patch("corporate.lib.stripe.get_latest_seat_count", return_value=30): @@ -2280,10 +2285,14 @@ class StripeTest(StripeTestCase): plan = CustomerPlan.objects.first() assert plan is not None self.assertIsNotNone(plan.next_invoice_date) - with patch("stripe.InvoiceItem.create") as mocked: - invoice_plans_as_needed(self.next_year + timedelta(days=32)) - mocked.assert_not_called() - mocked.reset_mock() + + mocked = self.setup_mocked_stripe( + invoice_plans_as_needed, self.next_year + timedelta(days=32) + ) + mocked["InvoiceItem"].create.assert_not_called() + mocked["Invoice"].finalize_invoice.assert_not_called() + mocked["Invoice"].create.assert_not_called() + # Check that we updated next_invoice_date in invoice_plan plan = CustomerPlan.objects.first() assert plan is not None @@ -2292,9 +2301,13 @@ class StripeTest(StripeTestCase): # Check that we don't call invoice_plan after that final call with patch("corporate.lib.stripe.get_latest_seat_count", return_value=50): update_license_ledger_if_needed(user.realm, self.next_year + timedelta(days=80)) - with patch("corporate.lib.stripe.invoice_plan") as mocked: - invoice_plans_as_needed(self.next_year + timedelta(days=400)) - mocked.assert_not_called() + + mocked = self.setup_mocked_stripe( + invoice_plans_as_needed, self.next_year + timedelta(days=400) + ) + mocked["InvoiceItem"].create.assert_not_called() + mocked["Invoice"].finalize_invoice.assert_not_called() + mocked["Invoice"].create.assert_not_called() @mock_stripe() def test_switch_from_monthly_plan_to_annual_plan_for_automatic_license_management( @@ -3059,13 +3072,10 @@ class StripeTest(StripeTestCase): ) # Verify that we don't invoice them for the additional users during free trial. - from stripe import Invoice - - Invoice.create = lambda **args: None # type: ignore[assignment] # cleaner than mocking - Invoice.finalize_invoice = lambda *args: None # type: ignore[assignment] # cleaner than mocking - with patch("stripe.InvoiceItem.create") as mocked: - invoice_plans_as_needed(self.next_month) - mocked.assert_not_called() + mocked = self.setup_mocked_stripe(invoice_plans_as_needed, self.next_month) + mocked["InvoiceItem"].create.assert_not_called() + mocked["Invoice"].finalize_invoice.assert_not_called() + mocked["Invoice"].create.assert_not_called() # Check that we downgrade properly if the cycle is over with patch("corporate.lib.stripe.get_latest_seat_count", return_value=30):