corporate: Create AuditLogEventType enum for BillingSession audit logs.

Creates an enum class, AuditLogEventType, and an abstract method in
BillingSession, get_audit_log_event, so that we have an abstraction
for getting the audit log event type since it might be different for
Customer objects with a realm vs a remote_server.
This commit is contained in:
Lauryn Menard
2023-11-02 17:44:02 +01:00
committed by Tim Abbott
parent ee19a9c274
commit c8021925c8
2 changed files with 64 additions and 10 deletions

View File

@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
from decimal import Decimal
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, Generator, Optional, Tuple, TypeVar, Union
@@ -33,7 +34,7 @@ from zerver.lib.logging_util import log_to_file
from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_realm_owners
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import assert_is_not_none
from zerver.models import AbstractRealmAuditLog, Realm, RealmAuditLog, UserProfile, get_system_bot
from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot
from zilencer.models import RemoteZulipServer, RemoteZulipServerAuditLog
from zproject.config import get_secret
@@ -324,14 +325,35 @@ class StripeCustomerData:
metadata: Dict[str, Any]
class AuditLogEventType(Enum):
STRIPE_CUSTOMER_CREATED = 1
STRIPE_CARD_CHANGED = 2
CUSTOMER_PLAN_CREATED = 3
DISCOUNT_CHANGED = 4
class BillingSessionAuditLogEventError(Exception):
def __init__(self, event_type: AuditLogEventType) -> None:
self.message = f"Unknown audit log event type: {event_type}"
super().__init__(self.message)
class BillingSession(ABC):
@abstractmethod
def get_customer(self) -> Optional[Customer]:
pass
@abstractmethod
def get_audit_log_event(self, event_type: AuditLogEventType) -> int:
pass
@abstractmethod
def write_to_audit_log(
self, event_type: int, event_time: datetime, *, extra_data: Optional[Dict[str, Any]] = None
self,
event_type: AuditLogEventType,
event_time: datetime,
*,
extra_data: Optional[Dict[str, Any]] = None,
) -> None:
pass
@@ -359,9 +381,9 @@ class BillingSession(ABC):
)
event_time = timestamp_to_datetime(stripe_customer.created)
with transaction.atomic():
self.write_to_audit_log(AbstractRealmAuditLog.STRIPE_CUSTOMER_CREATED, event_time)
self.write_to_audit_log(AuditLogEventType.STRIPE_CUSTOMER_CREATED, event_time)
if payment_method is not None:
self.write_to_audit_log(AbstractRealmAuditLog.STRIPE_CARD_CHANGED, event_time)
self.write_to_audit_log(AuditLogEventType.STRIPE_CARD_CHANGED, event_time)
customer = self.update_or_create_customer(stripe_customer.id)
return customer
@@ -372,7 +394,7 @@ class BillingSession(ABC):
stripe.Customer.modify(
stripe_customer_id, invoice_settings={"default_payment_method": payment_method}
)
self.write_to_audit_log(AbstractRealmAuditLog.STRIPE_CARD_CHANGED, timezone_now())
self.write_to_audit_log(AuditLogEventType.STRIPE_CARD_CHANGED, timezone_now())
if pay_invoices:
for stripe_invoice in stripe.Invoice.list(
collection_method="charge_automatically",
@@ -418,7 +440,7 @@ class BillingSession(ABC):
plan.discount = discount
plan.save(update_fields=["price_per_license", "discount"])
self.write_to_audit_log(
event_type=AbstractRealmAuditLog.REALM_DISCOUNT_CHANGED,
event_type=AuditLogEventType.DISCOUNT_CHANGED,
event_time=timezone_now(),
extra_data={"old_discount": old_discount, "new_discount": discount},
)
@@ -439,15 +461,33 @@ class RealmBillingSession(BillingSession):
def get_customer(self) -> Optional[Customer]:
return get_customer_by_realm(self.realm)
@override
def get_audit_log_event(self, event_type: AuditLogEventType) -> int:
if event_type is AuditLogEventType.STRIPE_CUSTOMER_CREATED:
return RealmAuditLog.STRIPE_CUSTOMER_CREATED
elif event_type is AuditLogEventType.STRIPE_CARD_CHANGED:
return RealmAuditLog.STRIPE_CARD_CHANGED
elif event_type is AuditLogEventType.CUSTOMER_PLAN_CREATED:
return RealmAuditLog.CUSTOMER_PLAN_CREATED
elif event_type is AuditLogEventType.DISCOUNT_CHANGED:
return RealmAuditLog.REALM_DISCOUNT_CHANGED
else:
raise BillingSessionAuditLogEventError(event_type)
@override
def write_to_audit_log(
self, event_type: int, event_time: datetime, *, extra_data: Optional[Dict[str, Any]] = None
self,
event_type: AuditLogEventType,
event_time: datetime,
*,
extra_data: Optional[Dict[str, Any]] = None,
) -> None:
audit_log_event = self.get_audit_log_event(event_type)
if extra_data:
RealmAuditLog.objects.create(
realm=self.realm,
acting_user=self.user,
event_type=event_type,
event_type=audit_log_event,
event_time=event_time,
extra_data=extra_data,
)
@@ -455,7 +495,7 @@ class RealmBillingSession(BillingSession):
RealmAuditLog.objects.create(
realm=self.realm,
acting_user=self.user,
event_type=event_type,
event_type=audit_log_event,
event_time=event_time,
)
@@ -825,7 +865,7 @@ def process_initial_upgrade(
plan.invoiced_through = ledger_entry
plan.save(update_fields=["invoiced_through"])
billing_session.write_to_audit_log(
event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED,
event_type=AuditLogEventType.CUSTOMER_PLAN_CREATED,
event_time=billing_cycle_anchor,
extra_data=plan_params,
)