corporate: Move attach_realm_discount to BillingSession class.

This moves the logic for `attach_realm_discount`, which is used in
the support view, to be in the BillingSession class.

Updates the function name to be `attach_discount_to_customer` so
that the context is generalized vs realm specific.

Updates RealmBillingSession implementation to account for actions
that are initiated by a support admin user.

Also moves the helper function `get_discount_for_realm` that is
only used in support views to `corporate/lib/support.py`.
This commit is contained in:
Lauryn Menard
2023-10-31 19:22:55 +01:00
committed by Tim Abbott
parent 63abf063b7
commit ee19a9c274
4 changed files with 152 additions and 124 deletions

View File

@@ -54,10 +54,8 @@ if settings.ZILENCER_ENABLED:
if settings.BILLING_ENABLED: if settings.BILLING_ENABLED:
from corporate.lib.stripe import approve_sponsorship as do_approve_sponsorship from corporate.lib.stripe import approve_sponsorship as do_approve_sponsorship
from corporate.lib.stripe import ( from corporate.lib.stripe import (
attach_discount_to_realm,
downgrade_at_the_end_of_billing_cycle, downgrade_at_the_end_of_billing_cycle,
downgrade_now_without_creating_additional_invoices, downgrade_now_without_creating_additional_invoices,
get_discount_for_realm,
get_latest_seat_count, get_latest_seat_count,
make_end_of_cycle_updates_if_needed, make_end_of_cycle_updates_if_needed,
switch_realm_from_standard_to_plus_plan, switch_realm_from_standard_to_plus_plan,
@@ -65,6 +63,7 @@ if settings.BILLING_ENABLED:
update_sponsorship_status, update_sponsorship_status,
void_all_open_invoices, void_all_open_invoices,
) )
from corporate.lib.support import attach_discount_to_realm, get_discount_for_realm
from corporate.models import ( from corporate.models import (
Customer, Customer,
CustomerPlan, CustomerPlan,

View File

@@ -340,7 +340,9 @@ class BillingSession(ABC):
pass pass
@abstractmethod @abstractmethod
def update_or_create_customer(self, stripe_customer_id: str) -> Customer: def update_or_create_customer(
self, stripe_customer_id: Optional[str] = None, *, defaults: Optional[Dict[str, Any]] = None
) -> Customer:
pass pass
@catch_stripe_errors @catch_stripe_errors
@@ -399,11 +401,39 @@ class BillingSession(ABC):
self.replace_payment_method(customer.stripe_customer_id, payment_method, True) self.replace_payment_method(customer.stripe_customer_id, payment_method, True)
return customer return customer
def attach_discount_to_customer(self, discount: Decimal) -> None:
customer = self.get_customer()
old_discount: Optional[Decimal] = None
if customer is not None:
old_discount = customer.default_discount
customer.default_discount = discount
customer.save(update_fields=["default_discount"])
else:
customer = self.update_or_create_customer(defaults={"default_discount": discount})
plan = get_current_plan_by_customer(customer)
if plan is not None:
plan.price_per_license = get_price_per_license(
plan.tier, plan.billing_schedule, discount
)
plan.discount = discount
plan.save(update_fields=["price_per_license", "discount"])
self.write_to_audit_log(
event_type=AbstractRealmAuditLog.REALM_DISCOUNT_CHANGED,
event_time=timezone_now(),
extra_data={"old_discount": old_discount, "new_discount": discount},
)
class RealmBillingSession(BillingSession): class RealmBillingSession(BillingSession):
def __init__(self, user: UserProfile) -> None: def __init__(self, user: UserProfile, realm: Optional[Realm] = None) -> None:
self.user = user self.user = user
self.realm = user.realm if realm is not None:
assert user.is_staff
self.realm = realm
self.support_session = True
else:
self.realm = user.realm
self.support_session = False
@override @override
def get_customer(self) -> Optional[Customer]: def get_customer(self) -> Optional[Customer]:
@@ -431,6 +461,8 @@ class RealmBillingSession(BillingSession):
@override @override
def get_data_for_stripe_customer(self) -> StripeCustomerData: def get_data_for_stripe_customer(self) -> StripeCustomerData:
# Support requests do not set any stripe billing information.
assert self.support_session is False
metadata: Dict[str, Any] = {} metadata: Dict[str, Any] = {}
metadata["realm_id"] = self.realm.id metadata["realm_id"] = self.realm.id
metadata["realm_str"] = self.realm.string_id metadata["realm_str"] = self.realm.string_id
@@ -442,14 +474,24 @@ class RealmBillingSession(BillingSession):
return realm_stripe_customer_data return realm_stripe_customer_data
@override @override
def update_or_create_customer(self, stripe_customer_id: str) -> Customer: def update_or_create_customer(
customer, created = Customer.objects.update_or_create( self, stripe_customer_id: Optional[str] = None, *, defaults: Optional[Dict[str, Any]] = None
realm=self.realm, defaults={"stripe_customer_id": stripe_customer_id} ) -> Customer:
) if stripe_customer_id is not None:
from zerver.actions.users import do_make_user_billing_admin # Support requests do not set any stripe billing information.
assert self.support_session is False
customer, created = Customer.objects.update_or_create(
realm=self.realm, defaults={"stripe_customer_id": stripe_customer_id}
)
from zerver.actions.users import do_make_user_billing_admin
do_make_user_billing_admin(self.user) do_make_user_billing_admin(self.user)
return customer return customer
else:
customer, created = Customer.objects.update_or_create(
realm=self.realm, defaults=defaults
)
return customer
def stripe_customer_has_credit_card_as_default_payment_method( def stripe_customer_has_credit_card_as_default_payment_method(
@@ -1002,31 +1044,6 @@ def is_realm_on_free_trial(realm: Realm) -> bool:
return plan is not None and plan.is_free_trial() return plan is not None and plan.is_free_trial()
def attach_discount_to_realm(
realm: Realm, discount: Decimal, *, acting_user: Optional[UserProfile]
) -> None:
customer = get_customer_by_realm(realm)
old_discount: Optional[Decimal] = None
if customer is not None:
old_discount = customer.default_discount
customer.default_discount = discount
customer.save(update_fields=["default_discount"])
else:
Customer.objects.create(realm=realm, default_discount=discount)
plan = get_current_plan_by_realm(realm)
if plan is not None:
plan.price_per_license = get_price_per_license(plan.tier, plan.billing_schedule, discount)
plan.discount = discount
plan.save(update_fields=["price_per_license", "discount"])
RealmAuditLog.objects.create(
realm=realm,
acting_user=acting_user,
event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED,
event_time=timezone_now(),
extra_data={"old_discount": old_discount, "new_discount": discount},
)
def update_sponsorship_status( def update_sponsorship_status(
realm: Realm, sponsorship_pending: bool, *, acting_user: Optional[UserProfile] realm: Realm, sponsorship_pending: bool, *, acting_user: Optional[UserProfile]
) -> None: ) -> None:
@@ -1079,13 +1096,6 @@ def is_sponsored_realm(realm: Realm) -> bool:
return realm.plan_type == Realm.PLAN_TYPE_STANDARD_FREE return realm.plan_type == Realm.PLAN_TYPE_STANDARD_FREE
def get_discount_for_realm(realm: Realm) -> Optional[Decimal]:
customer = get_customer_by_realm(realm)
if customer is not None:
return customer.default_discount
return None
def do_change_plan_status(plan: CustomerPlan, status: int) -> None: def do_change_plan_status(plan: CustomerPlan, status: int) -> None:
plan.status = status plan.status = status
plan.save(update_fields=["status"]) plan.save(update_fields=["status"])

View File

@@ -1,9 +1,13 @@
from decimal import Decimal
from typing import Optional
from urllib.parse import urlencode, urljoin, urlunsplit from urllib.parse import urlencode, urljoin, urlunsplit
from django.conf import settings from django.conf import settings
from django.urls import reverse from django.urls import reverse
from zerver.models import Realm, get_realm from corporate.lib.stripe import RealmBillingSession
from corporate.models import get_customer_by_realm
from zerver.models import Realm, UserProfile, get_realm
def get_support_url(realm: Realm) -> str: def get_support_url(realm: Realm) -> str:
@@ -13,3 +17,15 @@ def get_support_url(realm: Realm) -> str:
urlunsplit(("", "", reverse("support"), urlencode({"q": realm.string_id}), "")), urlunsplit(("", "", reverse("support"), urlencode({"q": realm.string_id}), "")),
) )
return support_url return support_url
def get_discount_for_realm(realm: Realm) -> Optional[Decimal]:
customer = get_customer_by_realm(realm)
if customer is not None:
return customer.default_discount
return None
def attach_discount_to_realm(realm: Realm, discount: Decimal, *, acting_user: UserProfile) -> None:
billing_session = RealmBillingSession(acting_user, realm)
billing_session.attach_discount_to_customer(discount)

View File

@@ -48,14 +48,12 @@ from corporate.lib.stripe import (
StripeCardError, StripeCardError,
add_months, add_months,
approve_sponsorship, approve_sponsorship,
attach_discount_to_realm,
catch_stripe_errors, catch_stripe_errors,
compute_plan_parameters, compute_plan_parameters,
customer_has_credit_card_as_default_payment_method, customer_has_credit_card_as_default_payment_method,
do_change_remote_server_plan_type, do_change_remote_server_plan_type,
do_deactivate_remote_server, do_deactivate_remote_server,
downgrade_small_realms_behind_on_payments_as_needed, downgrade_small_realms_behind_on_payments_as_needed,
get_discount_for_realm,
get_latest_seat_count, get_latest_seat_count,
get_plan_renewal_or_end_date, get_plan_renewal_or_end_date,
get_price_per_license, get_price_per_license,
@@ -79,6 +77,7 @@ from corporate.lib.stripe import (
update_sponsorship_status, update_sponsorship_status,
void_all_open_invoices, void_all_open_invoices,
) )
from corporate.lib.support import attach_discount_to_realm, get_discount_for_realm
from corporate.models import ( from corporate.models import (
Customer, Customer,
CustomerPlan, CustomerPlan,
@@ -2470,75 +2469,6 @@ class StripeTest(StripeTestCase):
# card on file, and should show it # card on file, and should show it
# TODO # TODO
@mock_stripe()
def test_attach_discount_to_realm(self, *mocks: Mock) -> None:
# Attach discount before Stripe customer exists
user = self.example_user("hamlet")
attach_discount_to_realm(user.realm, Decimal(85), acting_user=user)
realm_audit_log = RealmAuditLog.objects.filter(
event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED
).last()
assert realm_audit_log is not None
expected_extra_data = {"old_discount": None, "new_discount": str(Decimal("85"))}
self.assertEqual(realm_audit_log.extra_data, expected_extra_data)
self.login_user(user)
# Check that the discount appears in page_params
self.assert_in_success_response(["85"], self.client_get("/upgrade/"))
# Check that the customer was charged the discounted amount
self.upgrade()
customer = Customer.objects.first()
assert customer is not None
[charge] = stripe.Charge.list(customer=customer.stripe_customer_id)
self.assertEqual(1200 * self.seat_count, charge.amount)
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual(
[1200 * self.seat_count, -1200 * self.seat_count],
[item.amount for item in invoice.lines],
)
# Check CustomerPlan reflects the discount
plan = CustomerPlan.objects.get(price_per_license=1200, discount=Decimal(85))
# Attach discount to existing Stripe customer
plan.status = CustomerPlan.ENDED
plan.save(update_fields=["status"])
attach_discount_to_realm(user.realm, Decimal(25), acting_user=user)
with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
self.upgrade(license_management="automatic", billing_modality="charge_automatically")
[charge, _] = stripe.Charge.list(customer=customer.stripe_customer_id)
self.assertEqual(6000 * self.seat_count, charge.amount)
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice, _] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual(
[6000 * self.seat_count, -6000 * self.seat_count],
[item.amount for item in invoice.lines],
)
plan = CustomerPlan.objects.get(price_per_license=6000, discount=Decimal(25))
attach_discount_to_realm(user.realm, Decimal(50), acting_user=user)
plan.refresh_from_db()
self.assertEqual(plan.price_per_license, 4000)
self.assertEqual(plan.discount, 50)
customer.refresh_from_db()
self.assertEqual(customer.default_discount, 50)
invoice_plans_as_needed(self.next_year + timedelta(days=10))
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice, _, _] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual([4000 * self.seat_count], [item.amount for item in invoice.lines])
realm_audit_log = RealmAuditLog.objects.filter(
event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED
).last()
assert realm_audit_log is not None
expected_extra_data = {
"old_discount": str(Decimal("25.0000")),
"new_discount": str(Decimal("50")),
}
self.assertEqual(realm_audit_log.extra_data, expected_extra_data)
self.assertEqual(realm_audit_log.acting_user, user)
def test_approve_sponsorship(self) -> None: def test_approve_sponsorship(self) -> None:
user = self.example_user("hamlet") user = self.example_user("hamlet")
approve_sponsorship(user.realm, acting_user=user) approve_sponsorship(user.realm, acting_user=user)
@@ -2572,13 +2502,6 @@ class StripeTest(StripeTestCase):
self.assertEqual(realm_audit_log.extra_data, expected_extra_data) self.assertEqual(realm_audit_log.extra_data, expected_extra_data)
self.assertEqual(realm_audit_log.acting_user, iago) self.assertEqual(realm_audit_log.acting_user, iago)
def test_get_discount_for_realm(self) -> None:
user = self.example_user("hamlet")
self.assertEqual(get_discount_for_realm(user.realm), None)
attach_discount_to_realm(user.realm, Decimal(85), acting_user=None)
self.assertEqual(get_discount_for_realm(user.realm), 85)
@mock_stripe() @mock_stripe()
def test_replace_payment_method(self, *mocks: Mock) -> None: def test_replace_payment_method(self, *mocks: Mock) -> None:
user = self.example_user("hamlet") user = self.example_user("hamlet")
@@ -3750,7 +3673,7 @@ class StripeTest(StripeTestCase):
users_to_create=1, create_stripe_customer=False, create_plan=False users_to_create=1, create_stripe_customer=False, create_plan=False
) )
# To create local Customer object but no Stripe customer. # To create local Customer object but no Stripe customer.
attach_discount_to_realm(realm, Decimal(20), acting_user=None) attach_discount_to_realm(realm, Decimal(20), acting_user=self.example_user("iago"))
rows.append(Row(realm, Realm.PLAN_TYPE_SELF_HOSTED, None, None, False, False)) rows.append(Row(realm, Realm.PLAN_TYPE_SELF_HOSTED, None, None, False, False))
realm, _, _, _ = create_realm( realm, _, _, _ = create_realm(
@@ -5040,3 +4963,83 @@ class TestTestClasses(ZulipTestCase):
realm.refresh_from_db() realm.refresh_from_db()
self.assertEqual(realm.plan_type, Realm.PLAN_TYPE_STANDARD) self.assertEqual(realm.plan_type, Realm.PLAN_TYPE_STANDARD)
class TestSupportBillingHelpers(StripeTestCase):
def test_get_discount_for_realm(self) -> None:
iago = self.example_user("iago")
user = self.example_user("hamlet")
self.assertEqual(get_discount_for_realm(user.realm), None)
attach_discount_to_realm(user.realm, Decimal(85), acting_user=iago)
self.assertEqual(get_discount_for_realm(user.realm), 85)
@mock_stripe()
def test_attach_discount_to_realm(self, *mocks: Mock) -> None:
# Attach discount before Stripe customer exists
support_admin = self.example_user("iago")
user = self.example_user("hamlet")
attach_discount_to_realm(user.realm, Decimal(85), acting_user=support_admin)
realm_audit_log = RealmAuditLog.objects.filter(
event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED
).last()
assert realm_audit_log is not None
expected_extra_data = {"old_discount": None, "new_discount": str(Decimal("85"))}
self.assertEqual(realm_audit_log.extra_data, expected_extra_data)
self.login_user(user)
# Check that the discount appears in page_params
self.assert_in_success_response(["85"], self.client_get("/upgrade/"))
# Check that the customer was charged the discounted amount
self.upgrade()
customer = Customer.objects.first()
assert customer is not None
[charge] = stripe.Charge.list(customer=customer.stripe_customer_id)
self.assertEqual(1200 * self.seat_count, charge.amount)
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual(
[1200 * self.seat_count, -1200 * self.seat_count],
[item.amount for item in invoice.lines],
)
# Check CustomerPlan reflects the discount
plan = CustomerPlan.objects.get(price_per_license=1200, discount=Decimal(85))
# Attach discount to existing Stripe customer
plan.status = CustomerPlan.ENDED
plan.save(update_fields=["status"])
attach_discount_to_realm(user.realm, Decimal(25), acting_user=support_admin)
with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
self.upgrade(license_management="automatic", billing_modality="charge_automatically")
[charge, _] = stripe.Charge.list(customer=customer.stripe_customer_id)
self.assertEqual(6000 * self.seat_count, charge.amount)
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice, _] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual(
[6000 * self.seat_count, -6000 * self.seat_count],
[item.amount for item in invoice.lines],
)
plan = CustomerPlan.objects.get(price_per_license=6000, discount=Decimal(25))
attach_discount_to_realm(user.realm, Decimal(50), acting_user=support_admin)
plan.refresh_from_db()
self.assertEqual(plan.price_per_license, 4000)
self.assertEqual(plan.discount, 50)
customer.refresh_from_db()
self.assertEqual(customer.default_discount, 50)
invoice_plans_as_needed(self.next_year + timedelta(days=10))
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice, _, _] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual([4000 * self.seat_count], [item.amount for item in invoice.lines])
realm_audit_log = RealmAuditLog.objects.filter(
event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED
).last()
assert realm_audit_log is not None
expected_extra_data = {
"old_discount": str(Decimal("25.0000")),
"new_discount": str(Decimal("50")),
}
self.assertEqual(realm_audit_log.extra_data, expected_extra_data)
self.assertEqual(realm_audit_log.acting_user, support_admin)