From d0c276d8632b979be0eda213d67b6e086073a2dc Mon Sep 17 00:00:00 2001 From: Tim Abbott Date: Wed, 21 Feb 2024 11:44:46 -0800 Subject: [PATCH] corporate: Fix billing_session variable reuse confusion. The previous logic incorrectly used the server-level number of users even when a (presumably smaller) realm-level count was available. Fixes a bug introduced in 2e1ed4431adf8dfe1e8bcfeae63dfc327b425369. --- analytics/tests/test_counts.py | 2 +- corporate/lib/stripe.py | 43 +++++++++++++++++-------- zerver/tests/test_push_notifications.py | 14 ++++---- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/analytics/tests/test_counts.py b/analytics/tests/test_counts.py index b30e9df0b1..0b161ce2d4 100644 --- a/analytics/tests/test_counts.py +++ b/analytics/tests/test_counts.py @@ -1580,7 +1580,7 @@ class TestLoggingCountStats(AnalyticsTestCase): with time_machine.travel(now, tick=False), mock.patch( "zilencer.views.send_android_push_notification", return_value=1 ), mock.patch("zilencer.views.send_apple_push_notification", return_value=1), mock.patch( - "corporate.lib.stripe.RemoteServerBillingSession.current_count_for_billed_licenses", + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", return_value=10, ), self.assertLogs( "zilencer.views", level="INFO" diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 9bc3f2ab95..b4057093a5 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -4988,10 +4988,19 @@ def get_push_status_for_remote_request( # installation. customer = None current_plan = None + realm_billing_session: Optional[BillingSession] = None + server_billing_session: Optional[RemoteServerBillingSession] = None if remote_realm is not None: - billing_session: BillingSession = RemoteRealmBillingSession(remote_realm) - customer = billing_session.get_customer() + realm_billing_session = RemoteRealmBillingSession(remote_realm) + if realm_billing_session.is_sponsored(): + return PushNotificationsEnabledStatus( + can_push=True, + expected_end_timestamp=None, + message="Community plan", + ) + + customer = realm_billing_session.get_customer() if customer is not None: current_plan = get_current_plan_by_customer(customer) @@ -4999,22 +5008,28 @@ def get_push_status_for_remote_request( # takes precedence, but look for a current plan on the server if # there is a customer with only inactive/expired plans on the Realm. if customer is None or current_plan is None: - billing_session = RemoteServerBillingSession(remote_server) - customer = billing_session.get_customer() + server_billing_session = RemoteServerBillingSession(remote_server) + if server_billing_session.is_sponsored(): + return PushNotificationsEnabledStatus( + can_push=True, + expected_end_timestamp=None, + message="Community plan", + ) + + customer = server_billing_session.get_customer() if customer is not None: current_plan = get_current_plan_by_customer(customer) - if billing_session.is_sponsored(): - return PushNotificationsEnabledStatus( - can_push=True, - expected_end_timestamp=None, - message="Community plan", - ) + if realm_billing_session is not None: + user_count_billing_session: BillingSession = realm_billing_session + else: + assert server_billing_session is not None + user_count_billing_session = server_billing_session user_count: Optional[int] = None if current_plan is None: try: - user_count = billing_session.current_count_for_billed_licenses() + user_count = user_count_billing_session.current_count_for_billed_licenses() except MissingDataError: return PushNotificationsEnabledStatus( can_push=False, @@ -5047,7 +5062,7 @@ def get_push_status_for_remote_request( ) try: - user_count = billing_session.current_count_for_billed_licenses() + user_count = user_count_billing_session.current_count_for_billed_licenses() except MissingDataError: user_count = None @@ -5062,8 +5077,10 @@ def get_push_status_for_remote_request( message="Expiring plan few users", ) + # TODO: Move get_next_billing_cycle to be plan.get_next_billing_cycle + # to avoid this somewhat evil use of a possibly non-matching billing session. expected_end_timestamp = datetime_to_timestamp( - billing_session.get_next_billing_cycle(current_plan) + user_count_billing_session.get_next_billing_cycle(current_plan) ) return PushNotificationsEnabledStatus( can_push=True, diff --git a/zerver/tests/test_push_notifications.py b/zerver/tests/test_push_notifications.py index 5c1d6ea6fb..d9c5d8e442 100644 --- a/zerver/tests/test_push_notifications.py +++ b/zerver/tests/test_push_notifications.py @@ -591,7 +591,7 @@ class PushBouncerNotificationTest(BouncerTestCase): with mock.patch( "zilencer.views.send_android_push_notification", return_value=1 ), mock.patch("zilencer.views.send_apple_push_notification", return_value=1), mock.patch( - "corporate.lib.stripe.RemoteServerBillingSession.current_count_for_billed_licenses", + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", return_value=10, ), self.assertLogs( "zilencer.views", level="INFO" @@ -665,7 +665,7 @@ class PushBouncerNotificationTest(BouncerTestCase): ) as android_push, mock.patch( "zilencer.views.send_apple_push_notification", return_value=1 ) as apple_push, mock.patch( - "corporate.lib.stripe.RemoteServerBillingSession.current_count_for_billed_licenses", + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", return_value=10, ), time_machine.travel( time_sent, tick=False @@ -2233,7 +2233,7 @@ class AnalyticsBouncerTest(BouncerTestCase): "zilencer.views.RemoteRealmBillingSession.get_customer", return_value=None ) as m: with mock.patch( - "corporate.lib.stripe.RemoteServerBillingSession.current_count_for_billed_licenses", + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", return_value=11, ): send_server_data_to_push_bouncer(consider_usage_statistics=False) @@ -2266,7 +2266,7 @@ class AnalyticsBouncerTest(BouncerTestCase): "corporate.lib.stripe.get_current_plan_by_customer", return_value=None ) as m: with mock.patch( - "corporate.lib.stripe.RemoteServerBillingSession.current_count_for_billed_licenses", + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", return_value=11, ): send_server_data_to_push_bouncer(consider_usage_statistics=False) @@ -2707,7 +2707,7 @@ class HandlePushNotificationTest(PushNotificationTest): with time_machine.travel(time_received, tick=False), mock.patch( "zerver.lib.push_notifications.gcm_client" ) as mock_gcm, self.mock_apns() as (apns_context, send_notification), mock.patch( - "corporate.lib.stripe.RemoteServerBillingSession.current_count_for_billed_licenses", + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", return_value=10, ), self.assertLogs( "zerver.lib.push_notifications", level="INFO" @@ -2800,7 +2800,7 @@ class HandlePushNotificationTest(PushNotificationTest): "trigger": NotificationTriggers.DIRECT_MESSAGE, } with mock.patch( - "corporate.lib.stripe.RemoteServerBillingSession.current_count_for_billed_licenses", + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", return_value=100, ) as mock_current_count, self.assertLogs( "zerver.lib.push_notifications", level="INFO" @@ -2871,7 +2871,7 @@ class HandlePushNotificationTest(PushNotificationTest): with time_machine.travel(time_received, tick=False), mock.patch( "zerver.lib.push_notifications.gcm_client" ) as mock_gcm, self.mock_apns() as (apns_context, send_notification), mock.patch( - "corporate.lib.stripe.RemoteServerBillingSession.current_count_for_billed_licenses", + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", return_value=10, ), self.assertLogs( "zerver.lib.push_notifications", level="INFO"