diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 1f03d09077..7719fc01ea 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta from decimal import Decimal +from enum import Enum from functools import wraps from typing import Any, Callable, Dict, Generator, Optional, Tuple, TypeVar, Union @@ -33,7 +34,7 @@ from zerver.lib.logging_util import log_to_file from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_realm_owners from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.utils import assert_is_not_none -from zerver.models import AbstractRealmAuditLog, Realm, RealmAuditLog, UserProfile, get_system_bot +from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot from zilencer.models import RemoteZulipServer, RemoteZulipServerAuditLog from zproject.config import get_secret @@ -324,14 +325,35 @@ class StripeCustomerData: metadata: Dict[str, Any] +class AuditLogEventType(Enum): + STRIPE_CUSTOMER_CREATED = 1 + STRIPE_CARD_CHANGED = 2 + CUSTOMER_PLAN_CREATED = 3 + DISCOUNT_CHANGED = 4 + + +class BillingSessionAuditLogEventError(Exception): + def __init__(self, event_type: AuditLogEventType) -> None: + self.message = f"Unknown audit log event type: {event_type}" + super().__init__(self.message) + + class BillingSession(ABC): @abstractmethod def get_customer(self) -> Optional[Customer]: pass + @abstractmethod + def get_audit_log_event(self, event_type: AuditLogEventType) -> int: + pass + @abstractmethod def write_to_audit_log( - self, event_type: int, event_time: datetime, *, extra_data: Optional[Dict[str, Any]] = None + self, + event_type: AuditLogEventType, + event_time: datetime, + *, + extra_data: Optional[Dict[str, Any]] = None, ) -> None: pass @@ -359,9 +381,9 @@ class BillingSession(ABC): ) event_time = timestamp_to_datetime(stripe_customer.created) with transaction.atomic(): - self.write_to_audit_log(AbstractRealmAuditLog.STRIPE_CUSTOMER_CREATED, event_time) + self.write_to_audit_log(AuditLogEventType.STRIPE_CUSTOMER_CREATED, event_time) if payment_method is not None: - self.write_to_audit_log(AbstractRealmAuditLog.STRIPE_CARD_CHANGED, event_time) + self.write_to_audit_log(AuditLogEventType.STRIPE_CARD_CHANGED, event_time) customer = self.update_or_create_customer(stripe_customer.id) return customer @@ -372,7 +394,7 @@ class BillingSession(ABC): stripe.Customer.modify( stripe_customer_id, invoice_settings={"default_payment_method": payment_method} ) - self.write_to_audit_log(AbstractRealmAuditLog.STRIPE_CARD_CHANGED, timezone_now()) + self.write_to_audit_log(AuditLogEventType.STRIPE_CARD_CHANGED, timezone_now()) if pay_invoices: for stripe_invoice in stripe.Invoice.list( collection_method="charge_automatically", @@ -418,7 +440,7 @@ class BillingSession(ABC): plan.discount = discount plan.save(update_fields=["price_per_license", "discount"]) self.write_to_audit_log( - event_type=AbstractRealmAuditLog.REALM_DISCOUNT_CHANGED, + event_type=AuditLogEventType.DISCOUNT_CHANGED, event_time=timezone_now(), extra_data={"old_discount": old_discount, "new_discount": discount}, ) @@ -439,15 +461,33 @@ class RealmBillingSession(BillingSession): def get_customer(self) -> Optional[Customer]: return get_customer_by_realm(self.realm) + @override + def get_audit_log_event(self, event_type: AuditLogEventType) -> int: + if event_type is AuditLogEventType.STRIPE_CUSTOMER_CREATED: + return RealmAuditLog.STRIPE_CUSTOMER_CREATED + elif event_type is AuditLogEventType.STRIPE_CARD_CHANGED: + return RealmAuditLog.STRIPE_CARD_CHANGED + elif event_type is AuditLogEventType.CUSTOMER_PLAN_CREATED: + return RealmAuditLog.CUSTOMER_PLAN_CREATED + elif event_type is AuditLogEventType.DISCOUNT_CHANGED: + return RealmAuditLog.REALM_DISCOUNT_CHANGED + else: + raise BillingSessionAuditLogEventError(event_type) + @override def write_to_audit_log( - self, event_type: int, event_time: datetime, *, extra_data: Optional[Dict[str, Any]] = None + self, + event_type: AuditLogEventType, + event_time: datetime, + *, + extra_data: Optional[Dict[str, Any]] = None, ) -> None: + audit_log_event = self.get_audit_log_event(event_type) if extra_data: RealmAuditLog.objects.create( realm=self.realm, acting_user=self.user, - event_type=event_type, + event_type=audit_log_event, event_time=event_time, extra_data=extra_data, ) @@ -455,7 +495,7 @@ class RealmBillingSession(BillingSession): RealmAuditLog.objects.create( realm=self.realm, acting_user=self.user, - event_type=event_type, + event_type=audit_log_event, event_time=event_time, ) @@ -825,7 +865,7 @@ def process_initial_upgrade( plan.invoiced_through = ledger_entry plan.save(update_fields=["invoiced_through"]) billing_session.write_to_audit_log( - event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED, + event_type=AuditLogEventType.CUSTOMER_PLAN_CREATED, event_time=billing_cycle_anchor, extra_data=plan_params, ) diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 6ebfc0d6b9..1b222632d1 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -4,6 +4,7 @@ import os import random import re import sys +import typing import uuid from dataclasses import dataclass from datetime import datetime, timedelta, timezone @@ -41,7 +42,9 @@ from corporate.lib.stripe import ( MAX_INVOICED_LICENSES, MIN_INVOICED_LICENSES, STRIPE_API_VERSION, + AuditLogEventType, BillingError, + BillingSessionAuditLogEventError, InvalidBillingScheduleError, InvalidTierError, RealmBillingSession, @@ -4965,6 +4968,17 @@ class TestTestClasses(ZulipTestCase): self.assertEqual(realm.plan_type, Realm.PLAN_TYPE_STANDARD) +class TestRealmBillingSession(StripeTestCase): + def test_get_audit_log_error(self) -> None: + user = self.example_user("hamlet") + billing_session = RealmBillingSession(user) + fake_audit_log = typing.cast(AuditLogEventType, 0) + with self.assertRaisesRegex( + BillingSessionAuditLogEventError, "Unknown audit log event type: 0" + ): + billing_session.get_audit_log_event(event_type=fake_audit_log) + + class TestSupportBillingHelpers(StripeTestCase): def test_get_discount_for_realm(self) -> None: iago = self.example_user("iago")