corporate: Create AuditLogEventType enum for BillingSession audit logs.

Creates an enum class, AuditLogEventType, and an abstract method in
BillingSession, get_audit_log_event, so that we have an abstraction
for getting the audit log event type since it might be different for
Customer objects with a realm vs a remote_server.
This commit is contained in:
Lauryn Menard
2023-11-02 17:44:02 +01:00
committed by Tim Abbott
parent ee19a9c274
commit c8021925c8
2 changed files with 64 additions and 10 deletions

View File

@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from decimal import Decimal from decimal import Decimal
from enum import Enum
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, Generator, Optional, Tuple, TypeVar, Union from typing import Any, Callable, Dict, Generator, Optional, Tuple, TypeVar, Union
@@ -33,7 +34,7 @@ from zerver.lib.logging_util import log_to_file
from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_realm_owners from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_realm_owners
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import assert_is_not_none from zerver.lib.utils import assert_is_not_none
from zerver.models import AbstractRealmAuditLog, Realm, RealmAuditLog, UserProfile, get_system_bot from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot
from zilencer.models import RemoteZulipServer, RemoteZulipServerAuditLog from zilencer.models import RemoteZulipServer, RemoteZulipServerAuditLog
from zproject.config import get_secret from zproject.config import get_secret
@@ -324,14 +325,35 @@ class StripeCustomerData:
metadata: Dict[str, Any] metadata: Dict[str, Any]
class AuditLogEventType(Enum):
STRIPE_CUSTOMER_CREATED = 1
STRIPE_CARD_CHANGED = 2
CUSTOMER_PLAN_CREATED = 3
DISCOUNT_CHANGED = 4
class BillingSessionAuditLogEventError(Exception):
def __init__(self, event_type: AuditLogEventType) -> None:
self.message = f"Unknown audit log event type: {event_type}"
super().__init__(self.message)
class BillingSession(ABC): class BillingSession(ABC):
@abstractmethod @abstractmethod
def get_customer(self) -> Optional[Customer]: def get_customer(self) -> Optional[Customer]:
pass pass
@abstractmethod
def get_audit_log_event(self, event_type: AuditLogEventType) -> int:
pass
@abstractmethod @abstractmethod
def write_to_audit_log( def write_to_audit_log(
self, event_type: int, event_time: datetime, *, extra_data: Optional[Dict[str, Any]] = None self,
event_type: AuditLogEventType,
event_time: datetime,
*,
extra_data: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
pass pass
@@ -359,9 +381,9 @@ class BillingSession(ABC):
) )
event_time = timestamp_to_datetime(stripe_customer.created) event_time = timestamp_to_datetime(stripe_customer.created)
with transaction.atomic(): with transaction.atomic():
self.write_to_audit_log(AbstractRealmAuditLog.STRIPE_CUSTOMER_CREATED, event_time) self.write_to_audit_log(AuditLogEventType.STRIPE_CUSTOMER_CREATED, event_time)
if payment_method is not None: if payment_method is not None:
self.write_to_audit_log(AbstractRealmAuditLog.STRIPE_CARD_CHANGED, event_time) self.write_to_audit_log(AuditLogEventType.STRIPE_CARD_CHANGED, event_time)
customer = self.update_or_create_customer(stripe_customer.id) customer = self.update_or_create_customer(stripe_customer.id)
return customer return customer
@@ -372,7 +394,7 @@ class BillingSession(ABC):
stripe.Customer.modify( stripe.Customer.modify(
stripe_customer_id, invoice_settings={"default_payment_method": payment_method} stripe_customer_id, invoice_settings={"default_payment_method": payment_method}
) )
self.write_to_audit_log(AbstractRealmAuditLog.STRIPE_CARD_CHANGED, timezone_now()) self.write_to_audit_log(AuditLogEventType.STRIPE_CARD_CHANGED, timezone_now())
if pay_invoices: if pay_invoices:
for stripe_invoice in stripe.Invoice.list( for stripe_invoice in stripe.Invoice.list(
collection_method="charge_automatically", collection_method="charge_automatically",
@@ -418,7 +440,7 @@ class BillingSession(ABC):
plan.discount = discount plan.discount = discount
plan.save(update_fields=["price_per_license", "discount"]) plan.save(update_fields=["price_per_license", "discount"])
self.write_to_audit_log( self.write_to_audit_log(
event_type=AbstractRealmAuditLog.REALM_DISCOUNT_CHANGED, event_type=AuditLogEventType.DISCOUNT_CHANGED,
event_time=timezone_now(), event_time=timezone_now(),
extra_data={"old_discount": old_discount, "new_discount": discount}, extra_data={"old_discount": old_discount, "new_discount": discount},
) )
@@ -439,15 +461,33 @@ class RealmBillingSession(BillingSession):
def get_customer(self) -> Optional[Customer]: def get_customer(self) -> Optional[Customer]:
return get_customer_by_realm(self.realm) return get_customer_by_realm(self.realm)
@override
def get_audit_log_event(self, event_type: AuditLogEventType) -> int:
if event_type is AuditLogEventType.STRIPE_CUSTOMER_CREATED:
return RealmAuditLog.STRIPE_CUSTOMER_CREATED
elif event_type is AuditLogEventType.STRIPE_CARD_CHANGED:
return RealmAuditLog.STRIPE_CARD_CHANGED
elif event_type is AuditLogEventType.CUSTOMER_PLAN_CREATED:
return RealmAuditLog.CUSTOMER_PLAN_CREATED
elif event_type is AuditLogEventType.DISCOUNT_CHANGED:
return RealmAuditLog.REALM_DISCOUNT_CHANGED
else:
raise BillingSessionAuditLogEventError(event_type)
@override @override
def write_to_audit_log( def write_to_audit_log(
self, event_type: int, event_time: datetime, *, extra_data: Optional[Dict[str, Any]] = None self,
event_type: AuditLogEventType,
event_time: datetime,
*,
extra_data: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
audit_log_event = self.get_audit_log_event(event_type)
if extra_data: if extra_data:
RealmAuditLog.objects.create( RealmAuditLog.objects.create(
realm=self.realm, realm=self.realm,
acting_user=self.user, acting_user=self.user,
event_type=event_type, event_type=audit_log_event,
event_time=event_time, event_time=event_time,
extra_data=extra_data, extra_data=extra_data,
) )
@@ -455,7 +495,7 @@ class RealmBillingSession(BillingSession):
RealmAuditLog.objects.create( RealmAuditLog.objects.create(
realm=self.realm, realm=self.realm,
acting_user=self.user, acting_user=self.user,
event_type=event_type, event_type=audit_log_event,
event_time=event_time, event_time=event_time,
) )
@@ -825,7 +865,7 @@ def process_initial_upgrade(
plan.invoiced_through = ledger_entry plan.invoiced_through = ledger_entry
plan.save(update_fields=["invoiced_through"]) plan.save(update_fields=["invoiced_through"])
billing_session.write_to_audit_log( billing_session.write_to_audit_log(
event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED, event_type=AuditLogEventType.CUSTOMER_PLAN_CREATED,
event_time=billing_cycle_anchor, event_time=billing_cycle_anchor,
extra_data=plan_params, extra_data=plan_params,
) )

View File

@@ -4,6 +4,7 @@ import os
import random import random
import re import re
import sys import sys
import typing
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
@@ -41,7 +42,9 @@ from corporate.lib.stripe import (
MAX_INVOICED_LICENSES, MAX_INVOICED_LICENSES,
MIN_INVOICED_LICENSES, MIN_INVOICED_LICENSES,
STRIPE_API_VERSION, STRIPE_API_VERSION,
AuditLogEventType,
BillingError, BillingError,
BillingSessionAuditLogEventError,
InvalidBillingScheduleError, InvalidBillingScheduleError,
InvalidTierError, InvalidTierError,
RealmBillingSession, RealmBillingSession,
@@ -4965,6 +4968,17 @@ class TestTestClasses(ZulipTestCase):
self.assertEqual(realm.plan_type, Realm.PLAN_TYPE_STANDARD) self.assertEqual(realm.plan_type, Realm.PLAN_TYPE_STANDARD)
class TestRealmBillingSession(StripeTestCase):
def test_get_audit_log_error(self) -> None:
user = self.example_user("hamlet")
billing_session = RealmBillingSession(user)
fake_audit_log = typing.cast(AuditLogEventType, 0)
with self.assertRaisesRegex(
BillingSessionAuditLogEventError, "Unknown audit log event type: 0"
):
billing_session.get_audit_log_event(event_type=fake_audit_log)
class TestSupportBillingHelpers(StripeTestCase): class TestSupportBillingHelpers(StripeTestCase):
def test_get_discount_for_realm(self) -> None: def test_get_discount_for_realm(self) -> None:
iago = self.example_user("iago") iago = self.example_user("iago")