billing: Create StripeTestCase.

This commit is contained in:
Rishi Gupta
2019-01-27 12:16:02 -08:00
parent 83a7595feb
commit fe280fc38c

View File

@@ -211,7 +211,7 @@ class Kandra(object): # nocoverage: TODO
def __eq__(self, other: Any) -> bool:
return True
class StripeTest(ZulipTestCase):
class StripeTestCase(ZulipTestCase):
def setUp(self, *mocks: Mock) -> None:
# TODO
# Unfortunately this test suite is likely not robust to users being
@@ -280,6 +280,24 @@ class StripeTest(ZulipTestCase):
params[key] = ujson.dumps(value)
return self.client_post("/json/billing/upgrade", params, **host_args)
# Upgrade without talking to Stripe
def local_upgrade(self, *args: Any) -> None:
class StripeMock(object):
def __init__(self, depth: int=1):
self.id = 'id'
self.created = '1000'
self.last4 = '4242'
if depth == 1:
self.source = StripeMock(depth=2)
def upgrade_func(*args: Any) -> Any:
return process_initial_upgrade(self.example_user('hamlet'), *args[:4])
for mocked_function_name in MOCKED_STRIPE_FUNCTION_NAMES:
upgrade_func = patch(mocked_function_name, return_value=StripeMock())(upgrade_func)
upgrade_func(*args)
class StripeTest(StripeTestCase):
@patch("corporate.lib.stripe.billing_logger.error")
def test_catch_stripe_errors(self, mock_billing_logger_error: Mock) -> None:
@catch_stripe_errors
@@ -921,31 +939,7 @@ class BillingHelpersTest(ZulipTestCase):
mocked3.assert_not_called()
self.assertTrue(isinstance(customer, Customer))
# todo: Create a StripeTestCase, similar to AnalyticsTestCase
class LicenseLedgerTest(ZulipTestCase):
def setUp(self) -> None:
self.seat_count = get_seat_count(get_realm('zulip'))
self.now = datetime(2012, 1, 2, 3, 4, 5).replace(tzinfo=timezone_utc)
self.next_month = datetime(2012, 2, 2, 3, 4, 5).replace(tzinfo=timezone_utc)
self.next_year = datetime(2013, 1, 2, 3, 4, 5).replace(tzinfo=timezone_utc)
# Upgrade without talking to Stripe
def local_upgrade(self, *args: Any) -> None:
class StripeMock(object):
def __init__(self, depth: int=1):
self.id = 'id'
self.created = '1000'
self.last4 = '4242'
if depth == 1:
self.source = StripeMock(depth=2)
def upgrade_func(*args: Any) -> Any:
return process_initial_upgrade(self.example_user('hamlet'), *args[:4])
for mocked_function_name in MOCKED_STRIPE_FUNCTION_NAMES:
upgrade_func = patch(mocked_function_name, return_value=StripeMock())(upgrade_func)
upgrade_func(*args)
class LicenseLedgerTest(StripeTestCase):
def test_add_plan_renewal_if_needed(self) -> None:
with patch('corporate.lib.stripe.timezone_now', return_value=self.now):
self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, 'token')