diff --git a/analytics/views/support.py b/analytics/views/support.py index 09f1f43919..18c7bbf8d9 100644 --- a/analytics/views/support.py +++ b/analytics/views/support.py @@ -54,10 +54,8 @@ if settings.ZILENCER_ENABLED: if settings.BILLING_ENABLED: from corporate.lib.stripe import approve_sponsorship as do_approve_sponsorship from corporate.lib.stripe import ( - attach_discount_to_realm, downgrade_at_the_end_of_billing_cycle, downgrade_now_without_creating_additional_invoices, - get_discount_for_realm, get_latest_seat_count, make_end_of_cycle_updates_if_needed, switch_realm_from_standard_to_plus_plan, @@ -65,6 +63,7 @@ if settings.BILLING_ENABLED: update_sponsorship_status, void_all_open_invoices, ) + from corporate.lib.support import attach_discount_to_realm, get_discount_for_realm from corporate.models import ( Customer, CustomerPlan, diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index a062c4770e..1f03d09077 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -340,7 +340,9 @@ class BillingSession(ABC): pass @abstractmethod - def update_or_create_customer(self, stripe_customer_id: str) -> Customer: + def update_or_create_customer( + self, stripe_customer_id: Optional[str] = None, *, defaults: Optional[Dict[str, Any]] = None + ) -> Customer: pass @catch_stripe_errors @@ -399,11 +401,39 @@ class BillingSession(ABC): self.replace_payment_method(customer.stripe_customer_id, payment_method, True) return customer + def attach_discount_to_customer(self, discount: Decimal) -> None: + customer = self.get_customer() + old_discount: Optional[Decimal] = None + if customer is not None: + old_discount = customer.default_discount + customer.default_discount = discount + customer.save(update_fields=["default_discount"]) + else: + customer = self.update_or_create_customer(defaults={"default_discount": discount}) + plan = get_current_plan_by_customer(customer) + if plan is not None: + plan.price_per_license = get_price_per_license( + plan.tier, plan.billing_schedule, discount + ) + plan.discount = discount + plan.save(update_fields=["price_per_license", "discount"]) + self.write_to_audit_log( + event_type=AbstractRealmAuditLog.REALM_DISCOUNT_CHANGED, + event_time=timezone_now(), + extra_data={"old_discount": old_discount, "new_discount": discount}, + ) + class RealmBillingSession(BillingSession): - def __init__(self, user: UserProfile) -> None: + def __init__(self, user: UserProfile, realm: Optional[Realm] = None) -> None: self.user = user - self.realm = user.realm + if realm is not None: + assert user.is_staff + self.realm = realm + self.support_session = True + else: + self.realm = user.realm + self.support_session = False @override def get_customer(self) -> Optional[Customer]: @@ -431,6 +461,8 @@ class RealmBillingSession(BillingSession): @override def get_data_for_stripe_customer(self) -> StripeCustomerData: + # Support requests do not set any stripe billing information. + assert self.support_session is False metadata: Dict[str, Any] = {} metadata["realm_id"] = self.realm.id metadata["realm_str"] = self.realm.string_id @@ -442,14 +474,24 @@ class RealmBillingSession(BillingSession): return realm_stripe_customer_data @override - def update_or_create_customer(self, stripe_customer_id: str) -> Customer: - customer, created = Customer.objects.update_or_create( - realm=self.realm, defaults={"stripe_customer_id": stripe_customer_id} - ) - from zerver.actions.users import do_make_user_billing_admin + def update_or_create_customer( + self, stripe_customer_id: Optional[str] = None, *, defaults: Optional[Dict[str, Any]] = None + ) -> Customer: + if stripe_customer_id is not None: + # Support requests do not set any stripe billing information. + assert self.support_session is False + customer, created = Customer.objects.update_or_create( + realm=self.realm, defaults={"stripe_customer_id": stripe_customer_id} + ) + from zerver.actions.users import do_make_user_billing_admin - do_make_user_billing_admin(self.user) - return customer + do_make_user_billing_admin(self.user) + return customer + else: + customer, created = Customer.objects.update_or_create( + realm=self.realm, defaults=defaults + ) + return customer def stripe_customer_has_credit_card_as_default_payment_method( @@ -1002,31 +1044,6 @@ def is_realm_on_free_trial(realm: Realm) -> bool: return plan is not None and plan.is_free_trial() -def attach_discount_to_realm( - realm: Realm, discount: Decimal, *, acting_user: Optional[UserProfile] -) -> None: - customer = get_customer_by_realm(realm) - old_discount: Optional[Decimal] = None - if customer is not None: - old_discount = customer.default_discount - customer.default_discount = discount - customer.save(update_fields=["default_discount"]) - else: - Customer.objects.create(realm=realm, default_discount=discount) - plan = get_current_plan_by_realm(realm) - if plan is not None: - plan.price_per_license = get_price_per_license(plan.tier, plan.billing_schedule, discount) - plan.discount = discount - plan.save(update_fields=["price_per_license", "discount"]) - RealmAuditLog.objects.create( - realm=realm, - acting_user=acting_user, - event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED, - event_time=timezone_now(), - extra_data={"old_discount": old_discount, "new_discount": discount}, - ) - - def update_sponsorship_status( realm: Realm, sponsorship_pending: bool, *, acting_user: Optional[UserProfile] ) -> None: @@ -1079,13 +1096,6 @@ def is_sponsored_realm(realm: Realm) -> bool: return realm.plan_type == Realm.PLAN_TYPE_STANDARD_FREE -def get_discount_for_realm(realm: Realm) -> Optional[Decimal]: - customer = get_customer_by_realm(realm) - if customer is not None: - return customer.default_discount - return None - - def do_change_plan_status(plan: CustomerPlan, status: int) -> None: plan.status = status plan.save(update_fields=["status"]) diff --git a/corporate/lib/support.py b/corporate/lib/support.py index cbffbc389a..e8855598ed 100644 --- a/corporate/lib/support.py +++ b/corporate/lib/support.py @@ -1,9 +1,13 @@ +from decimal import Decimal +from typing import Optional from urllib.parse import urlencode, urljoin, urlunsplit from django.conf import settings from django.urls import reverse -from zerver.models import Realm, get_realm +from corporate.lib.stripe import RealmBillingSession +from corporate.models import get_customer_by_realm +from zerver.models import Realm, UserProfile, get_realm def get_support_url(realm: Realm) -> str: @@ -13,3 +17,15 @@ def get_support_url(realm: Realm) -> str: urlunsplit(("", "", reverse("support"), urlencode({"q": realm.string_id}), "")), ) return support_url + + +def get_discount_for_realm(realm: Realm) -> Optional[Decimal]: + customer = get_customer_by_realm(realm) + if customer is not None: + return customer.default_discount + return None + + +def attach_discount_to_realm(realm: Realm, discount: Decimal, *, acting_user: UserProfile) -> None: + billing_session = RealmBillingSession(acting_user, realm) + billing_session.attach_discount_to_customer(discount) diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index fa32055daf..6ebfc0d6b9 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -48,14 +48,12 @@ from corporate.lib.stripe import ( StripeCardError, add_months, approve_sponsorship, - attach_discount_to_realm, catch_stripe_errors, compute_plan_parameters, customer_has_credit_card_as_default_payment_method, do_change_remote_server_plan_type, do_deactivate_remote_server, downgrade_small_realms_behind_on_payments_as_needed, - get_discount_for_realm, get_latest_seat_count, get_plan_renewal_or_end_date, get_price_per_license, @@ -79,6 +77,7 @@ from corporate.lib.stripe import ( update_sponsorship_status, void_all_open_invoices, ) +from corporate.lib.support import attach_discount_to_realm, get_discount_for_realm from corporate.models import ( Customer, CustomerPlan, @@ -2470,75 +2469,6 @@ class StripeTest(StripeTestCase): # card on file, and should show it # TODO - @mock_stripe() - def test_attach_discount_to_realm(self, *mocks: Mock) -> None: - # Attach discount before Stripe customer exists - user = self.example_user("hamlet") - attach_discount_to_realm(user.realm, Decimal(85), acting_user=user) - realm_audit_log = RealmAuditLog.objects.filter( - event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED - ).last() - assert realm_audit_log is not None - expected_extra_data = {"old_discount": None, "new_discount": str(Decimal("85"))} - self.assertEqual(realm_audit_log.extra_data, expected_extra_data) - self.login_user(user) - # Check that the discount appears in page_params - self.assert_in_success_response(["85"], self.client_get("/upgrade/")) - # Check that the customer was charged the discounted amount - self.upgrade() - customer = Customer.objects.first() - assert customer is not None - [charge] = stripe.Charge.list(customer=customer.stripe_customer_id) - self.assertEqual(1200 * self.seat_count, charge.amount) - stripe_customer_id = customer.stripe_customer_id - assert stripe_customer_id is not None - [invoice] = stripe.Invoice.list(customer=stripe_customer_id) - self.assertEqual( - [1200 * self.seat_count, -1200 * self.seat_count], - [item.amount for item in invoice.lines], - ) - # Check CustomerPlan reflects the discount - plan = CustomerPlan.objects.get(price_per_license=1200, discount=Decimal(85)) - - # Attach discount to existing Stripe customer - plan.status = CustomerPlan.ENDED - plan.save(update_fields=["status"]) - attach_discount_to_realm(user.realm, Decimal(25), acting_user=user) - with patch("corporate.lib.stripe.timezone_now", return_value=self.now): - self.upgrade(license_management="automatic", billing_modality="charge_automatically") - [charge, _] = stripe.Charge.list(customer=customer.stripe_customer_id) - self.assertEqual(6000 * self.seat_count, charge.amount) - stripe_customer_id = customer.stripe_customer_id - assert stripe_customer_id is not None - [invoice, _] = stripe.Invoice.list(customer=stripe_customer_id) - self.assertEqual( - [6000 * self.seat_count, -6000 * self.seat_count], - [item.amount for item in invoice.lines], - ) - plan = CustomerPlan.objects.get(price_per_license=6000, discount=Decimal(25)) - - attach_discount_to_realm(user.realm, Decimal(50), acting_user=user) - plan.refresh_from_db() - self.assertEqual(plan.price_per_license, 4000) - self.assertEqual(plan.discount, 50) - customer.refresh_from_db() - self.assertEqual(customer.default_discount, 50) - invoice_plans_as_needed(self.next_year + timedelta(days=10)) - stripe_customer_id = customer.stripe_customer_id - assert stripe_customer_id is not None - [invoice, _, _] = stripe.Invoice.list(customer=stripe_customer_id) - self.assertEqual([4000 * self.seat_count], [item.amount for item in invoice.lines]) - realm_audit_log = RealmAuditLog.objects.filter( - event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED - ).last() - assert realm_audit_log is not None - expected_extra_data = { - "old_discount": str(Decimal("25.0000")), - "new_discount": str(Decimal("50")), - } - self.assertEqual(realm_audit_log.extra_data, expected_extra_data) - self.assertEqual(realm_audit_log.acting_user, user) - def test_approve_sponsorship(self) -> None: user = self.example_user("hamlet") approve_sponsorship(user.realm, acting_user=user) @@ -2572,13 +2502,6 @@ class StripeTest(StripeTestCase): self.assertEqual(realm_audit_log.extra_data, expected_extra_data) self.assertEqual(realm_audit_log.acting_user, iago) - def test_get_discount_for_realm(self) -> None: - user = self.example_user("hamlet") - self.assertEqual(get_discount_for_realm(user.realm), None) - - attach_discount_to_realm(user.realm, Decimal(85), acting_user=None) - self.assertEqual(get_discount_for_realm(user.realm), 85) - @mock_stripe() def test_replace_payment_method(self, *mocks: Mock) -> None: user = self.example_user("hamlet") @@ -3750,7 +3673,7 @@ class StripeTest(StripeTestCase): users_to_create=1, create_stripe_customer=False, create_plan=False ) # To create local Customer object but no Stripe customer. - attach_discount_to_realm(realm, Decimal(20), acting_user=None) + attach_discount_to_realm(realm, Decimal(20), acting_user=self.example_user("iago")) rows.append(Row(realm, Realm.PLAN_TYPE_SELF_HOSTED, None, None, False, False)) realm, _, _, _ = create_realm( @@ -5040,3 +4963,83 @@ class TestTestClasses(ZulipTestCase): realm.refresh_from_db() self.assertEqual(realm.plan_type, Realm.PLAN_TYPE_STANDARD) + + +class TestSupportBillingHelpers(StripeTestCase): + def test_get_discount_for_realm(self) -> None: + iago = self.example_user("iago") + user = self.example_user("hamlet") + self.assertEqual(get_discount_for_realm(user.realm), None) + + attach_discount_to_realm(user.realm, Decimal(85), acting_user=iago) + self.assertEqual(get_discount_for_realm(user.realm), 85) + + @mock_stripe() + def test_attach_discount_to_realm(self, *mocks: Mock) -> None: + # Attach discount before Stripe customer exists + support_admin = self.example_user("iago") + user = self.example_user("hamlet") + attach_discount_to_realm(user.realm, Decimal(85), acting_user=support_admin) + realm_audit_log = RealmAuditLog.objects.filter( + event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED + ).last() + assert realm_audit_log is not None + expected_extra_data = {"old_discount": None, "new_discount": str(Decimal("85"))} + self.assertEqual(realm_audit_log.extra_data, expected_extra_data) + self.login_user(user) + # Check that the discount appears in page_params + self.assert_in_success_response(["85"], self.client_get("/upgrade/")) + # Check that the customer was charged the discounted amount + self.upgrade() + customer = Customer.objects.first() + assert customer is not None + [charge] = stripe.Charge.list(customer=customer.stripe_customer_id) + self.assertEqual(1200 * self.seat_count, charge.amount) + stripe_customer_id = customer.stripe_customer_id + assert stripe_customer_id is not None + [invoice] = stripe.Invoice.list(customer=stripe_customer_id) + self.assertEqual( + [1200 * self.seat_count, -1200 * self.seat_count], + [item.amount for item in invoice.lines], + ) + # Check CustomerPlan reflects the discount + plan = CustomerPlan.objects.get(price_per_license=1200, discount=Decimal(85)) + + # Attach discount to existing Stripe customer + plan.status = CustomerPlan.ENDED + plan.save(update_fields=["status"]) + attach_discount_to_realm(user.realm, Decimal(25), acting_user=support_admin) + with patch("corporate.lib.stripe.timezone_now", return_value=self.now): + self.upgrade(license_management="automatic", billing_modality="charge_automatically") + [charge, _] = stripe.Charge.list(customer=customer.stripe_customer_id) + self.assertEqual(6000 * self.seat_count, charge.amount) + stripe_customer_id = customer.stripe_customer_id + assert stripe_customer_id is not None + [invoice, _] = stripe.Invoice.list(customer=stripe_customer_id) + self.assertEqual( + [6000 * self.seat_count, -6000 * self.seat_count], + [item.amount for item in invoice.lines], + ) + plan = CustomerPlan.objects.get(price_per_license=6000, discount=Decimal(25)) + + attach_discount_to_realm(user.realm, Decimal(50), acting_user=support_admin) + plan.refresh_from_db() + self.assertEqual(plan.price_per_license, 4000) + self.assertEqual(plan.discount, 50) + customer.refresh_from_db() + self.assertEqual(customer.default_discount, 50) + invoice_plans_as_needed(self.next_year + timedelta(days=10)) + stripe_customer_id = customer.stripe_customer_id + assert stripe_customer_id is not None + [invoice, _, _] = stripe.Invoice.list(customer=stripe_customer_id) + self.assertEqual([4000 * self.seat_count], [item.amount for item in invoice.lines]) + realm_audit_log = RealmAuditLog.objects.filter( + event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED + ).last() + assert realm_audit_log is not None + expected_extra_data = { + "old_discount": str(Decimal("25.0000")), + "new_discount": str(Decimal("50")), + } + self.assertEqual(realm_audit_log.extra_data, expected_extra_data) + self.assertEqual(realm_audit_log.acting_user, support_admin)