corporate: Use enum value for type of plan tier change.

Updates do_change_plan_to_new_tier in BillingSession to use an
enum for the value returned when checking for a valid change
between two plan tier types. This makes it more explicit that
the implementation for a valid upgrade in plan tier will be
different from a valid downgrade in plan tier.
This commit is contained in:
Lauryn Menard
2023-11-30 17:11:41 +01:00
committed by Tim Abbott
parent 4eea4d4717
commit 2c34dcf7dc

View File

@@ -530,6 +530,12 @@ class AuditLogEventType(Enum):
CUSTOMER_SWITCHED_FROM_ANNUAL_TO_MONTHLY_PLAN = 9
class PlanTierChangeType(Enum):
INVALID = 1
UPGRADE = 2
DOWNGRADE = 3
class BillingSessionAuditLogEventError(Exception):
def __init__(self, event_type: AuditLogEventType) -> None:
self.message = f"Unknown audit log event type: {event_type}"
@@ -684,7 +690,9 @@ class BillingSession(ABC):
pass
@abstractmethod
def is_valid_plan_tier_switch(self, current_plan_tier: int, new_plan_tier: int) -> bool:
def get_type_of_plan_tier_change(
self, current_plan_tier: int, new_plan_tier: int
) -> PlanTierChangeType:
pass
@abstractmethod
@@ -1800,31 +1808,40 @@ class BillingSession(ABC):
if not current_plan.customer.stripe_customer_id:
raise BillingError("Organization missing Stripe customer.")
if not self.is_valid_plan_tier_switch(current_plan.tier, new_plan_tier):
type_of_tier_change = self.get_type_of_plan_tier_change(current_plan.tier, new_plan_tier)
if type_of_tier_change == PlanTierChangeType.INVALID:
raise BillingError("Invalid change of customer plan tier.")
plan_switch_time = timezone_now()
if type_of_tier_change == PlanTierChangeType.UPGRADE:
plan_switch_time = timezone_now()
current_plan.status = CustomerPlan.SWITCH_PLAN_TIER_NOW
current_plan.next_invoice_date = plan_switch_time
current_plan.save(update_fields=["status", "next_invoice_date"])
current_plan.status = CustomerPlan.SWITCH_PLAN_TIER_NOW
current_plan.next_invoice_date = plan_switch_time
current_plan.save(update_fields=["status", "next_invoice_date"])
self.do_change_plan_type(tier=new_plan_tier)
self.do_change_plan_type(tier=new_plan_tier)
amount_to_credit_for_early_termination = get_amount_to_credit_for_plan_tier_change(
current_plan, plan_switch_time
)
stripe.Customer.create_balance_transaction(
current_plan.customer.stripe_customer_id,
amount=-1 * amount_to_credit_for_early_termination,
currency="usd",
description="Credit from early termination of active plan",
)
self.switch_plan_tier(current_plan, new_plan_tier)
self.invoice_plan(current_plan, plan_switch_time)
new_plan = get_current_plan_by_customer(customer)
assert new_plan is not None # for mypy
self.invoice_plan(new_plan, plan_switch_time)
return
amount_to_credit_for_early_termination = get_amount_to_credit_for_plan_tier_change(
current_plan, plan_switch_time
)
stripe.Customer.create_balance_transaction(
current_plan.customer.stripe_customer_id,
amount=-1 * amount_to_credit_for_early_termination,
currency="usd",
description="Credit from early termination of active plan",
)
self.switch_plan_tier(current_plan, new_plan_tier)
self.invoice_plan(current_plan, plan_switch_time)
new_plan = get_current_plan_by_customer(customer)
assert new_plan is not None # for mypy
self.invoice_plan(new_plan, plan_switch_time)
# TODO: Implement downgrade that is a change from and to a paid plan
# tier. This should keep the same billing cycle schedule and change
# the plan when it's next invoiced vs immediately. Note this will need
# new CustomerPlan.status value, e.g. SWITCH_PLAN_TIER_AT_END_OF_CYCLE.
assert type_of_tier_change == PlanTierChangeType.DOWNGRADE # nocoverage
def get_event_status(self, event_status_request: EventStatusRequest) -> Dict[str, Any]:
customer = self.get_customer()
@@ -2199,12 +2216,25 @@ class RealmBillingSession(BillingSession):
)
@override
def is_valid_plan_tier_switch(self, current_plan_tier: int, new_plan_tier: int) -> bool:
if current_plan_tier == CustomerPlan.TIER_CLOUD_STANDARD:
return new_plan_tier == CustomerPlan.TIER_CLOUD_PLUS
def get_type_of_plan_tier_change(
self, current_plan_tier: int, new_plan_tier: int
) -> PlanTierChangeType:
valid_plan_tiers = [CustomerPlan.TIER_CLOUD_STANDARD, CustomerPlan.TIER_CLOUD_PLUS]
if (
current_plan_tier not in valid_plan_tiers
or new_plan_tier not in valid_plan_tiers
or current_plan_tier == new_plan_tier
):
return PlanTierChangeType.INVALID
if (
current_plan_tier == CustomerPlan.TIER_CLOUD_STANDARD
and new_plan_tier == CustomerPlan.TIER_CLOUD_PLUS
):
return PlanTierChangeType.UPGRADE
else: # nocoverage, not currently implemented
assert current_plan_tier == CustomerPlan.TIER_CLOUD_PLUS
return new_plan_tier == CustomerPlan.TIER_CLOUD_STANDARD
assert new_plan_tier == CustomerPlan.TIER_CLOUD_STANDARD
return PlanTierChangeType.DOWNGRADE
@override
def has_billing_access(self) -> bool:
@@ -2450,9 +2480,11 @@ class RemoteRealmBillingSession(BillingSession): # nocoverage
plan.save(update_fields=["status"])
@override
def is_valid_plan_tier_switch(self, current_plan_tier: int, new_plan_tier: int) -> bool:
def get_type_of_plan_tier_change(
self, current_plan_tier: int, new_plan_tier: int
) -> PlanTierChangeType:
# TBD
return False
return PlanTierChangeType.INVALID
@override
def has_billing_access(self) -> bool:
@@ -2696,9 +2728,11 @@ class RemoteServerBillingSession(BillingSession): # nocoverage
)
@override
def is_valid_plan_tier_switch(self, current_plan_tier: int, new_plan_tier: int) -> bool:
def get_type_of_plan_tier_change(
self, current_plan_tier: int, new_plan_tier: int
) -> PlanTierChangeType:
# TBD
return False
return PlanTierChangeType.INVALID
@override
def has_billing_access(self) -> bool: