corporate: Move attach_realm_discount to BillingSession class.

This moves the logic for `attach_realm_discount`, which is used in
the support view, to be in the BillingSession class.

Updates the function name to be `attach_discount_to_customer` so
that the context is generalized vs realm specific.

Updates RealmBillingSession implementation to account for actions
that are initiated by a support admin user.

Also moves the helper function `get_discount_for_realm` that is
only used in support views to `corporate/lib/support.py`.
This commit is contained in:
Lauryn Menard
2023-10-31 19:22:55 +01:00
committed by Tim Abbott
parent 63abf063b7
commit ee19a9c274
4 changed files with 152 additions and 124 deletions

View File

@@ -54,10 +54,8 @@ if settings.ZILENCER_ENABLED:
if settings.BILLING_ENABLED:
from corporate.lib.stripe import approve_sponsorship as do_approve_sponsorship
from corporate.lib.stripe import (
attach_discount_to_realm,
downgrade_at_the_end_of_billing_cycle,
downgrade_now_without_creating_additional_invoices,
get_discount_for_realm,
get_latest_seat_count,
make_end_of_cycle_updates_if_needed,
switch_realm_from_standard_to_plus_plan,
@@ -65,6 +63,7 @@ if settings.BILLING_ENABLED:
update_sponsorship_status,
void_all_open_invoices,
)
from corporate.lib.support import attach_discount_to_realm, get_discount_for_realm
from corporate.models import (
Customer,
CustomerPlan,

View File

@@ -340,7 +340,9 @@ class BillingSession(ABC):
pass
@abstractmethod
def update_or_create_customer(self, stripe_customer_id: str) -> Customer:
def update_or_create_customer(
self, stripe_customer_id: Optional[str] = None, *, defaults: Optional[Dict[str, Any]] = None
) -> Customer:
pass
@catch_stripe_errors
@@ -399,11 +401,39 @@ class BillingSession(ABC):
self.replace_payment_method(customer.stripe_customer_id, payment_method, True)
return customer
def attach_discount_to_customer(self, discount: Decimal) -> None:
customer = self.get_customer()
old_discount: Optional[Decimal] = None
if customer is not None:
old_discount = customer.default_discount
customer.default_discount = discount
customer.save(update_fields=["default_discount"])
else:
customer = self.update_or_create_customer(defaults={"default_discount": discount})
plan = get_current_plan_by_customer(customer)
if plan is not None:
plan.price_per_license = get_price_per_license(
plan.tier, plan.billing_schedule, discount
)
plan.discount = discount
plan.save(update_fields=["price_per_license", "discount"])
self.write_to_audit_log(
event_type=AbstractRealmAuditLog.REALM_DISCOUNT_CHANGED,
event_time=timezone_now(),
extra_data={"old_discount": old_discount, "new_discount": discount},
)
class RealmBillingSession(BillingSession):
def __init__(self, user: UserProfile) -> None:
def __init__(self, user: UserProfile, realm: Optional[Realm] = None) -> None:
self.user = user
if realm is not None:
assert user.is_staff
self.realm = realm
self.support_session = True
else:
self.realm = user.realm
self.support_session = False
@override
def get_customer(self) -> Optional[Customer]:
@@ -431,6 +461,8 @@ class RealmBillingSession(BillingSession):
@override
def get_data_for_stripe_customer(self) -> StripeCustomerData:
# Support requests do not set any stripe billing information.
assert self.support_session is False
metadata: Dict[str, Any] = {}
metadata["realm_id"] = self.realm.id
metadata["realm_str"] = self.realm.string_id
@@ -442,7 +474,12 @@ class RealmBillingSession(BillingSession):
return realm_stripe_customer_data
@override
def update_or_create_customer(self, stripe_customer_id: str) -> Customer:
def update_or_create_customer(
self, stripe_customer_id: Optional[str] = None, *, defaults: Optional[Dict[str, Any]] = None
) -> Customer:
if stripe_customer_id is not None:
# Support requests do not set any stripe billing information.
assert self.support_session is False
customer, created = Customer.objects.update_or_create(
realm=self.realm, defaults={"stripe_customer_id": stripe_customer_id}
)
@@ -450,6 +487,11 @@ class RealmBillingSession(BillingSession):
do_make_user_billing_admin(self.user)
return customer
else:
customer, created = Customer.objects.update_or_create(
realm=self.realm, defaults=defaults
)
return customer
def stripe_customer_has_credit_card_as_default_payment_method(
@@ -1002,31 +1044,6 @@ def is_realm_on_free_trial(realm: Realm) -> bool:
return plan is not None and plan.is_free_trial()
def attach_discount_to_realm(
realm: Realm, discount: Decimal, *, acting_user: Optional[UserProfile]
) -> None:
customer = get_customer_by_realm(realm)
old_discount: Optional[Decimal] = None
if customer is not None:
old_discount = customer.default_discount
customer.default_discount = discount
customer.save(update_fields=["default_discount"])
else:
Customer.objects.create(realm=realm, default_discount=discount)
plan = get_current_plan_by_realm(realm)
if plan is not None:
plan.price_per_license = get_price_per_license(plan.tier, plan.billing_schedule, discount)
plan.discount = discount
plan.save(update_fields=["price_per_license", "discount"])
RealmAuditLog.objects.create(
realm=realm,
acting_user=acting_user,
event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED,
event_time=timezone_now(),
extra_data={"old_discount": old_discount, "new_discount": discount},
)
def update_sponsorship_status(
realm: Realm, sponsorship_pending: bool, *, acting_user: Optional[UserProfile]
) -> None:
@@ -1079,13 +1096,6 @@ def is_sponsored_realm(realm: Realm) -> bool:
return realm.plan_type == Realm.PLAN_TYPE_STANDARD_FREE
def get_discount_for_realm(realm: Realm) -> Optional[Decimal]:
customer = get_customer_by_realm(realm)
if customer is not None:
return customer.default_discount
return None
def do_change_plan_status(plan: CustomerPlan, status: int) -> None:
plan.status = status
plan.save(update_fields=["status"])

View File

@@ -1,9 +1,13 @@
from decimal import Decimal
from typing import Optional
from urllib.parse import urlencode, urljoin, urlunsplit
from django.conf import settings
from django.urls import reverse
from zerver.models import Realm, get_realm
from corporate.lib.stripe import RealmBillingSession
from corporate.models import get_customer_by_realm
from zerver.models import Realm, UserProfile, get_realm
def get_support_url(realm: Realm) -> str:
@@ -13,3 +17,15 @@ def get_support_url(realm: Realm) -> str:
urlunsplit(("", "", reverse("support"), urlencode({"q": realm.string_id}), "")),
)
return support_url
def get_discount_for_realm(realm: Realm) -> Optional[Decimal]:
customer = get_customer_by_realm(realm)
if customer is not None:
return customer.default_discount
return None
def attach_discount_to_realm(realm: Realm, discount: Decimal, *, acting_user: UserProfile) -> None:
billing_session = RealmBillingSession(acting_user, realm)
billing_session.attach_discount_to_customer(discount)

View File

@@ -48,14 +48,12 @@ from corporate.lib.stripe import (
StripeCardError,
add_months,
approve_sponsorship,
attach_discount_to_realm,
catch_stripe_errors,
compute_plan_parameters,
customer_has_credit_card_as_default_payment_method,
do_change_remote_server_plan_type,
do_deactivate_remote_server,
downgrade_small_realms_behind_on_payments_as_needed,
get_discount_for_realm,
get_latest_seat_count,
get_plan_renewal_or_end_date,
get_price_per_license,
@@ -79,6 +77,7 @@ from corporate.lib.stripe import (
update_sponsorship_status,
void_all_open_invoices,
)
from corporate.lib.support import attach_discount_to_realm, get_discount_for_realm
from corporate.models import (
Customer,
CustomerPlan,
@@ -2470,75 +2469,6 @@ class StripeTest(StripeTestCase):
# card on file, and should show it
# TODO
@mock_stripe()
def test_attach_discount_to_realm(self, *mocks: Mock) -> None:
# Attach discount before Stripe customer exists
user = self.example_user("hamlet")
attach_discount_to_realm(user.realm, Decimal(85), acting_user=user)
realm_audit_log = RealmAuditLog.objects.filter(
event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED
).last()
assert realm_audit_log is not None
expected_extra_data = {"old_discount": None, "new_discount": str(Decimal("85"))}
self.assertEqual(realm_audit_log.extra_data, expected_extra_data)
self.login_user(user)
# Check that the discount appears in page_params
self.assert_in_success_response(["85"], self.client_get("/upgrade/"))
# Check that the customer was charged the discounted amount
self.upgrade()
customer = Customer.objects.first()
assert customer is not None
[charge] = stripe.Charge.list(customer=customer.stripe_customer_id)
self.assertEqual(1200 * self.seat_count, charge.amount)
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual(
[1200 * self.seat_count, -1200 * self.seat_count],
[item.amount for item in invoice.lines],
)
# Check CustomerPlan reflects the discount
plan = CustomerPlan.objects.get(price_per_license=1200, discount=Decimal(85))
# Attach discount to existing Stripe customer
plan.status = CustomerPlan.ENDED
plan.save(update_fields=["status"])
attach_discount_to_realm(user.realm, Decimal(25), acting_user=user)
with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
self.upgrade(license_management="automatic", billing_modality="charge_automatically")
[charge, _] = stripe.Charge.list(customer=customer.stripe_customer_id)
self.assertEqual(6000 * self.seat_count, charge.amount)
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice, _] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual(
[6000 * self.seat_count, -6000 * self.seat_count],
[item.amount for item in invoice.lines],
)
plan = CustomerPlan.objects.get(price_per_license=6000, discount=Decimal(25))
attach_discount_to_realm(user.realm, Decimal(50), acting_user=user)
plan.refresh_from_db()
self.assertEqual(plan.price_per_license, 4000)
self.assertEqual(plan.discount, 50)
customer.refresh_from_db()
self.assertEqual(customer.default_discount, 50)
invoice_plans_as_needed(self.next_year + timedelta(days=10))
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice, _, _] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual([4000 * self.seat_count], [item.amount for item in invoice.lines])
realm_audit_log = RealmAuditLog.objects.filter(
event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED
).last()
assert realm_audit_log is not None
expected_extra_data = {
"old_discount": str(Decimal("25.0000")),
"new_discount": str(Decimal("50")),
}
self.assertEqual(realm_audit_log.extra_data, expected_extra_data)
self.assertEqual(realm_audit_log.acting_user, user)
def test_approve_sponsorship(self) -> None:
user = self.example_user("hamlet")
approve_sponsorship(user.realm, acting_user=user)
@@ -2572,13 +2502,6 @@ class StripeTest(StripeTestCase):
self.assertEqual(realm_audit_log.extra_data, expected_extra_data)
self.assertEqual(realm_audit_log.acting_user, iago)
def test_get_discount_for_realm(self) -> None:
user = self.example_user("hamlet")
self.assertEqual(get_discount_for_realm(user.realm), None)
attach_discount_to_realm(user.realm, Decimal(85), acting_user=None)
self.assertEqual(get_discount_for_realm(user.realm), 85)
@mock_stripe()
def test_replace_payment_method(self, *mocks: Mock) -> None:
user = self.example_user("hamlet")
@@ -3750,7 +3673,7 @@ class StripeTest(StripeTestCase):
users_to_create=1, create_stripe_customer=False, create_plan=False
)
# To create local Customer object but no Stripe customer.
attach_discount_to_realm(realm, Decimal(20), acting_user=None)
attach_discount_to_realm(realm, Decimal(20), acting_user=self.example_user("iago"))
rows.append(Row(realm, Realm.PLAN_TYPE_SELF_HOSTED, None, None, False, False))
realm, _, _, _ = create_realm(
@@ -5040,3 +4963,83 @@ class TestTestClasses(ZulipTestCase):
realm.refresh_from_db()
self.assertEqual(realm.plan_type, Realm.PLAN_TYPE_STANDARD)
class TestSupportBillingHelpers(StripeTestCase):
def test_get_discount_for_realm(self) -> None:
iago = self.example_user("iago")
user = self.example_user("hamlet")
self.assertEqual(get_discount_for_realm(user.realm), None)
attach_discount_to_realm(user.realm, Decimal(85), acting_user=iago)
self.assertEqual(get_discount_for_realm(user.realm), 85)
@mock_stripe()
def test_attach_discount_to_realm(self, *mocks: Mock) -> None:
# Attach discount before Stripe customer exists
support_admin = self.example_user("iago")
user = self.example_user("hamlet")
attach_discount_to_realm(user.realm, Decimal(85), acting_user=support_admin)
realm_audit_log = RealmAuditLog.objects.filter(
event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED
).last()
assert realm_audit_log is not None
expected_extra_data = {"old_discount": None, "new_discount": str(Decimal("85"))}
self.assertEqual(realm_audit_log.extra_data, expected_extra_data)
self.login_user(user)
# Check that the discount appears in page_params
self.assert_in_success_response(["85"], self.client_get("/upgrade/"))
# Check that the customer was charged the discounted amount
self.upgrade()
customer = Customer.objects.first()
assert customer is not None
[charge] = stripe.Charge.list(customer=customer.stripe_customer_id)
self.assertEqual(1200 * self.seat_count, charge.amount)
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual(
[1200 * self.seat_count, -1200 * self.seat_count],
[item.amount for item in invoice.lines],
)
# Check CustomerPlan reflects the discount
plan = CustomerPlan.objects.get(price_per_license=1200, discount=Decimal(85))
# Attach discount to existing Stripe customer
plan.status = CustomerPlan.ENDED
plan.save(update_fields=["status"])
attach_discount_to_realm(user.realm, Decimal(25), acting_user=support_admin)
with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
self.upgrade(license_management="automatic", billing_modality="charge_automatically")
[charge, _] = stripe.Charge.list(customer=customer.stripe_customer_id)
self.assertEqual(6000 * self.seat_count, charge.amount)
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice, _] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual(
[6000 * self.seat_count, -6000 * self.seat_count],
[item.amount for item in invoice.lines],
)
plan = CustomerPlan.objects.get(price_per_license=6000, discount=Decimal(25))
attach_discount_to_realm(user.realm, Decimal(50), acting_user=support_admin)
plan.refresh_from_db()
self.assertEqual(plan.price_per_license, 4000)
self.assertEqual(plan.discount, 50)
customer.refresh_from_db()
self.assertEqual(customer.default_discount, 50)
invoice_plans_as_needed(self.next_year + timedelta(days=10))
stripe_customer_id = customer.stripe_customer_id
assert stripe_customer_id is not None
[invoice, _, _] = stripe.Invoice.list(customer=stripe_customer_id)
self.assertEqual([4000 * self.seat_count], [item.amount for item in invoice.lines])
realm_audit_log = RealmAuditLog.objects.filter(
event_type=RealmAuditLog.REALM_DISCOUNT_CHANGED
).last()
assert realm_audit_log is not None
expected_extra_data = {
"old_discount": str(Decimal("25.0000")),
"new_discount": str(Decimal("50")),
}
self.assertEqual(realm_audit_log.extra_data, expected_extra_data)
self.assertEqual(realm_audit_log.acting_user, support_admin)