diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 0a18bad162..e9dd00bfaf 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -2069,6 +2069,23 @@ class BillingSession(ABC): return None, context + def min_licenses_for_flat_discount_to_self_hosted_basic_plan( + self, customer: Optional[Customer] + ) -> int: + # Since monthly and annual TIER_SELF_HOSTED_BASIC plans have same per user price we only need to do this calculation once. + # If we decided to apply this for other tiers, then we will have to do this calculation based on billing schedule selected by the user. + price_per_license = get_price_per_license( + CustomerPlan.TIER_SELF_HOSTED_BASIC, CustomerPlan.BILLING_SCHEDULE_MONTHLY + ) + if customer is None: + return ( + Customer._meta.get_field("flat_discount").get_default() // price_per_license + ) + 1 + elif customer.flat_discounted_months > 0: + return (customer.flat_discount // price_per_license) + 1 + # If flat discount is not applied. + return 1 + def min_licenses_for_plan(self, tier: int) -> int: customer = self.get_customer() if customer is not None and customer.minimum_licenses: @@ -2076,7 +2093,7 @@ class BillingSession(ABC): return customer.minimum_licenses if tier == CustomerPlan.TIER_SELF_HOSTED_BASIC: - return 10 + return min(self.min_licenses_for_flat_discount_to_self_hosted_basic_plan(customer), 10) if tier == CustomerPlan.TIER_SELF_HOSTED_BUSINESS: return 25 return 1 diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 7ad184d7de..bb4faab0a3 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -5851,7 +5851,7 @@ class TestRemoteRealmBillingFlow(StripeTestCase, RemoteRealmBillingTestCase): min_licenses = self.billing_session.min_licenses_for_plan( CustomerPlan.TIER_SELF_HOSTED_BASIC ) - self.assertEqual(min_licenses, 10) + self.assertEqual(min_licenses, 6) flat_discount, flat_discounted_months = self.billing_session.get_flat_discount_info() self.assertEqual(flat_discounted_months, 12) @@ -5919,7 +5919,7 @@ class TestRemoteRealmBillingFlow(StripeTestCase, RemoteRealmBillingTestCase): f"{self.billing_session.billing_base_url}/billing/", subdomain="selfhosting" ) - self.assertEqual(latest_ledger.licenses, 20) + self.assertEqual(latest_ledger.licenses, min_licenses + 10) for substring in [ "Zulip Basic", "Number of licenses", @@ -5932,6 +5932,13 @@ class TestRemoteRealmBillingFlow(StripeTestCase, RemoteRealmBillingTestCase): ]: self.assert_in_response(substring, response) + # Check minimum licenses is 0 after flat discounted months is over. + customer.flat_discounted_months = 0 + customer.save(update_fields=["flat_discounted_months"]) + self.assertEqual( + self.billing_session.min_licenses_for_plan(CustomerPlan.TIER_SELF_HOSTED_BASIC), 1 + ) + @responses.activate def test_request_sponsorship(self) -> None: self.login("hamlet") @@ -6529,7 +6536,7 @@ class TestRemoteServerBillingFlow(StripeTestCase, RemoteServerTestCase): min_licenses = self.billing_session.min_licenses_for_plan( CustomerPlan.TIER_SELF_HOSTED_BASIC ) - self.assertEqual(min_licenses, 10) + self.assertEqual(min_licenses, 6) flat_discount, flat_discounted_months = self.billing_session.get_flat_discount_info() self.assertEqual(flat_discounted_months, 12) @@ -6571,7 +6578,7 @@ class TestRemoteServerBillingFlow(StripeTestCase, RemoteServerTestCase): self.assertEqual(LicenseLedger.objects.count(), 1) with time_machine.travel(self.now + timedelta(days=2), tick=False): - for count in range(realm_user_count, min_licenses + 10): + for count in range(realm_user_count, realm_user_count + 10): do_create_user( f"email {count}", f"password {count}", @@ -6586,18 +6593,18 @@ class TestRemoteServerBillingFlow(StripeTestCase, RemoteServerTestCase): self.assertEqual( RemoteRealmAuditLog.objects.count(), - min_licenses + 10 - realm_user_count + audit_log_count, + audit_log_count + 10, ) latest_ledger = LicenseLedger.objects.last() assert latest_ledger is not None - self.assertEqual(latest_ledger.licenses, min_licenses + 10) + self.assertEqual(latest_ledger.licenses, 28) with time_machine.travel(self.now + timedelta(days=1), tick=False): response = self.client_get( f"{self.billing_session.billing_base_url}/billing/", subdomain="selfhosting" ) - self.assertEqual(latest_ledger.licenses, 20) + self.assertEqual(latest_ledger.licenses, 28) for substring in [ "Zulip Basic", "Number of licenses", @@ -6610,6 +6617,13 @@ class TestRemoteServerBillingFlow(StripeTestCase, RemoteServerTestCase): ]: self.assert_in_response(substring, response) + # Check minimum licenses is 0 after flat discounted months is over. + customer.flat_discounted_months = 0 + customer.save(update_fields=["flat_discounted_months"]) + self.assertEqual( + self.billing_session.min_licenses_for_plan(CustomerPlan.TIER_SELF_HOSTED_BASIC), 1 + ) + def test_deactivate_registration_with_push_notification_service(self) -> None: self.login("hamlet") hamlet = self.example_user("hamlet")