billing: Apply a flat discount for self hosted plans.

This commit is contained in:
Aman Agrawal
2023-12-20 06:24:21 +00:00
committed by Tim Abbott
parent 0337c726d3
commit e192aef23d
95 changed files with 8706 additions and 388 deletions

View File

@@ -588,6 +588,8 @@ class UpgradePageParams(TypedDict):
seat_count: int
billing_base_url: str
tier: int
flat_discount: int
flat_discounted_months: int
class UpgradePageSessionTypeSpecificContext(TypedDict):
@@ -706,9 +708,11 @@ class BillingSession(ABC):
def get_data_for_stripe_payment_intent(
self,
customer: Customer,
price_per_license: int,
licenses: int,
plan_tier: int,
billing_schedule: int,
email: str,
) -> StripePaymentIntentData:
if hasattr(self, "support_session") and self.support_session: # nocoverage
@@ -721,6 +725,12 @@ class BillingSession(ABC):
plan_name = CustomerPlan.name_from_tier(plan_tier)
description = f"Upgrade to {plan_name}, ${price_per_license/100} x {licenses}"
if customer.flat_discounted_months > 0:
num_months = 12 if billing_schedule == CustomerPlan.BILLING_SCHEDULE_ANNUAL else 1
flat_discounted_months = min(customer.flat_discounted_months, num_months)
amount -= customer.flat_discount * flat_discounted_months
description += f" - ${customer.flat_discount/100} x {flat_discounted_months}"
return StripePaymentIntentData(
amount=amount,
description=description,
@@ -917,7 +927,12 @@ class BillingSession(ABC):
customer = self.get_customer()
assert customer is not None and customer.stripe_customer_id is not None
payment_intent_data = self.get_data_for_stripe_payment_intent(
price_per_license, licenses, metadata["plan_tier"], self.get_email()
customer,
price_per_license,
licenses,
metadata["plan_tier"],
metadata["billing_schedule"],
self.get_email(),
)
# Ensure customers have a default payment method set.
stripe_customer = stripe_get_customer(customer.stripe_customer_id)
@@ -1021,14 +1036,18 @@ class BillingSession(ABC):
plan.save(update_fields=["discount", "price_per_license"])
def attach_discount_to_customer(self, new_discount: Decimal) -> str:
# Remove flat discount if giving customer a percentage discount.
customer = self.get_customer()
old_discount = None
if customer is not None:
old_discount = customer.default_discount
customer.default_discount = new_discount
customer.save(update_fields=["default_discount"])
customer.flat_discounted_months = 0
customer.save(update_fields=["default_discount", "flat_discounted_months"])
else:
customer = self.update_or_create_customer(defaults={"default_discount": new_discount})
customer = self.update_or_create_customer(
defaults={"default_discount": new_discount, "flat_discounted_months": 0}
)
plan = get_current_plan_by_customer(customer)
if plan is not None:
self.apply_discount_to_plan(plan, new_discount)
@@ -1305,6 +1324,21 @@ class BillingSession(ABC):
unit_amount=price_per_license,
)
if customer.flat_discounted_months > 0:
num_months = 12 if billing_schedule == CustomerPlan.BILLING_SCHEDULE_ANNUAL else 1
flat_discounted_months = min(customer.flat_discounted_months, num_months)
discount = customer.flat_discount * flat_discounted_months
customer.flat_discounted_months -= flat_discounted_months
customer.save(update_fields=["flat_discounted_months"])
stripe.InvoiceItem.create(
currency="usd",
customer=customer.stripe_customer_id,
description=f"${customer.flat_discount}/month new customer discount",
# Negative value to apply discount.
amount=(-1 * discount),
)
if charge_automatically:
collection_method = "charge_automatically"
days_until_due = None
@@ -1717,18 +1751,23 @@ class BillingSession(ABC):
billing_frequency = CustomerPlan.BILLING_SCHEDULES[plan.billing_schedule]
if switch_to_annual_at_end_of_cycle:
num_months_next_cycle = 12
annual_price_per_license = get_price_per_license(
plan.tier, CustomerPlan.BILLING_SCHEDULE_ANNUAL, customer.default_discount
)
renewal_cents = annual_price_per_license * licenses_at_next_renewal
price_per_license = format_money(annual_price_per_license / 12)
elif switch_to_monthly_at_end_of_cycle:
num_months_next_cycle = 1
monthly_price_per_license = get_price_per_license(
plan.tier, CustomerPlan.BILLING_SCHEDULE_MONTHLY, customer.default_discount
)
renewal_cents = monthly_price_per_license * licenses_at_next_renewal
price_per_license = format_money(monthly_price_per_license)
else:
num_months_next_cycle = (
12 if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_ANNUAL else 1
)
renewal_cents = self.get_customer_plan_renewal_amount(plan, now, last_ledger_entry)
if plan.price_per_license is None:
@@ -1738,6 +1777,14 @@ class BillingSession(ABC):
else:
price_per_license = format_money(plan.price_per_license)
# TODO: Do this calculation in `invoice_plan` too.
pre_discount_renewal_cents = renewal_cents
flat_discount, flat_discounted_months = self.get_flat_discount_info(plan.customer)
if flat_discounted_months > 0:
flat_discounted_months = min(flat_discounted_months, num_months_next_cycle)
discount = flat_discount * flat_discounted_months
renewal_cents = renewal_cents - discount
charge_automatically = plan.charge_automatically
assert customer.stripe_customer_id is not None # for mypy
stripe_customer = stripe_get_customer(customer.stripe_customer_id)
@@ -1786,6 +1833,9 @@ class BillingSession(ABC):
"legacy_remote_server_next_plan_name": legacy_remote_server_next_plan_name,
"using_min_licenses_for_plan": using_min_licenses_for_plan,
"min_licenses_for_plan": min_licenses_for_plan,
"pre_discount_renewal_cents": cents_to_dollar_string(pre_discount_renewal_cents),
"flat_discount": format_money(customer.flat_discount),
"discounted_months_left": customer.flat_discounted_months,
}
return context
@@ -1821,12 +1871,30 @@ class BillingSession(ABC):
"price_per_license",
"discount_percent",
"using_min_licenses_for_plan",
"min_licenses_for_plan",
"pre_discount_renewal_cents",
]
for key in keys:
context[key] = next_plan_context[key]
return context
def get_flat_discount_info(self, customer: Optional[Customer] = None) -> Tuple[int, int]:
is_self_hosted_billing = not isinstance(self, RealmBillingSession)
flat_discount = 0
flat_discounted_months = 0
if is_self_hosted_billing and (customer is None or customer.flat_discounted_months > 0):
if customer is None:
temp_customer = Customer()
flat_discount = temp_customer.flat_discount
flat_discounted_months = 12
else:
flat_discount = customer.flat_discount
flat_discounted_months = customer.flat_discounted_months
assert isinstance(flat_discount, int)
assert isinstance(flat_discounted_months, int)
return flat_discount, flat_discounted_months
def get_initial_upgrade_context(
self, initial_upgrade_request: InitialUpgradeRequest
) -> Tuple[Optional[str], Optional[UpgradePageContext]]:
@@ -1893,6 +1961,7 @@ class BillingSession(ABC):
f"{free_trial_end:%B} {free_trial_end.day}, {free_trial_end.year}"
)
flat_discount, flat_discounted_months = self.get_flat_discount_info(customer)
context: UpgradePageContext = {
"customer_name": customer_specific_context["customer_name"],
"default_invoice_days_until_due": DEFAULT_INVOICE_DAYS_UNTIL_DUE,
@@ -1917,6 +1986,8 @@ class BillingSession(ABC):
"seat_count": seat_count,
"billing_base_url": self.billing_base_url,
"tier": tier,
"flat_discount": flat_discount,
"flat_discounted_months": flat_discounted_months,
},
"using_min_licenses_for_plan": using_min_licenses_for_plan,
"min_licenses_for_plan": min_licenses_for_plan,
@@ -3245,12 +3316,16 @@ class RemoteRealmBillingSession(BillingSession):
remote_realm=self.remote_realm,
defaults={"stripe_customer_id": stripe_customer_id},
)
return customer
else:
customer, created = Customer.objects.update_or_create(
remote_realm=self.remote_realm, defaults=defaults
)
return customer
if created and not customer.default_discount:
customer.flat_discounted_months = 12
customer.save(update_fields=["flat_discounted_months"])
return customer
@override
@transaction.atomic
@@ -3634,12 +3709,16 @@ class RemoteServerBillingSession(BillingSession):
remote_server=self.remote_server,
defaults={"stripe_customer_id": stripe_customer_id},
)
return customer
else:
customer, created = Customer.objects.update_or_create(
remote_server=self.remote_server, defaults=defaults
)
return customer
if created and not customer.default_discount:
customer.flat_discounted_months = 12
customer.save(update_fields=["flat_discounted_months"])
return customer
@override
@transaction.atomic