mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-04 14:03:30 +00:00 
			
		
		
		
	typing: Use assertions for function arguments.
Utilize the assert_is_not_None helper to eliminate errors of 'Argument x to "Foo" has incompatible type "Optional[Bar]"...'
This commit is contained in:
		@@ -51,6 +51,7 @@ from zerver.lib.exceptions import InvitationError
 | 
				
			|||||||
from zerver.lib.test_classes import ZulipTestCase
 | 
					from zerver.lib.test_classes import ZulipTestCase
 | 
				
			||||||
from zerver.lib.timestamp import TimezoneNotUTCException, floor_to_day
 | 
					from zerver.lib.timestamp import TimezoneNotUTCException, floor_to_day
 | 
				
			||||||
from zerver.lib.topic import DB_TOPIC_NAME
 | 
					from zerver.lib.topic import DB_TOPIC_NAME
 | 
				
			||||||
 | 
					from zerver.lib.utils import assert_is_not_none
 | 
				
			||||||
from zerver.models import (
 | 
					from zerver.models import (
 | 
				
			||||||
    Client,
 | 
					    Client,
 | 
				
			||||||
    Huddle,
 | 
					    Huddle,
 | 
				
			||||||
@@ -1388,11 +1389,13 @@ class TestLoggingCountStats(AnalyticsTestCase):
 | 
				
			|||||||
        assertInviteCountEquals(5)
 | 
					        assertInviteCountEquals(5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Revoking invite should not give you credit
 | 
					        # Revoking invite should not give you credit
 | 
				
			||||||
        do_revoke_user_invite(PreregistrationUser.objects.filter(realm=user.realm).first())
 | 
					        do_revoke_user_invite(
 | 
				
			||||||
 | 
					            assert_is_not_none(PreregistrationUser.objects.filter(realm=user.realm).first())
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        assertInviteCountEquals(5)
 | 
					        assertInviteCountEquals(5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Resending invite should cost you
 | 
					        # Resending invite should cost you
 | 
				
			||||||
        do_resend_user_invite_email(PreregistrationUser.objects.first())
 | 
					        do_resend_user_invite_email(assert_is_not_none(PreregistrationUser.objects.first()))
 | 
				
			||||||
        assertInviteCountEquals(6)
 | 
					        assertInviteCountEquals(6)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_messages_read_hour(self) -> None:
 | 
					    def test_messages_read_hour(self) -> None:
 | 
				
			||||||
@@ -1423,7 +1426,7 @@ class TestLoggingCountStats(AnalyticsTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self.send_stream_message(user1, stream.name)
 | 
					        self.send_stream_message(user1, stream.name)
 | 
				
			||||||
        self.send_stream_message(user1, stream.name)
 | 
					        self.send_stream_message(user1, stream.name)
 | 
				
			||||||
        do_mark_stream_messages_as_read(user2, stream.recipient_id)
 | 
					        do_mark_stream_messages_as_read(user2, assert_is_not_none(stream.recipient_id))
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            3,
 | 
					            3,
 | 
				
			||||||
            UserCount.objects.filter(property=read_count_property).aggregate(Sum("value"))[
 | 
					            UserCount.objects.filter(property=read_count_property).aggregate(Sum("value"))[
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -28,6 +28,7 @@ from zerver.lib.actions import (
 | 
				
			|||||||
from zerver.lib.exceptions import JsonableError
 | 
					from zerver.lib.exceptions import JsonableError
 | 
				
			||||||
from zerver.lib.realm_icon import realm_icon_url
 | 
					from zerver.lib.realm_icon import realm_icon_url
 | 
				
			||||||
from zerver.lib.subdomains import get_subdomain_from_hostname
 | 
					from zerver.lib.subdomains import get_subdomain_from_hostname
 | 
				
			||||||
 | 
					from zerver.lib.utils import assert_is_not_none
 | 
				
			||||||
from zerver.models import (
 | 
					from zerver.models import (
 | 
				
			||||||
    MultiuseInvite,
 | 
					    MultiuseInvite,
 | 
				
			||||||
    PreregistrationUser,
 | 
					    PreregistrationUser,
 | 
				
			||||||
@@ -119,24 +120,24 @@ def support(request: HttpRequest) -> HttpResponse:
 | 
				
			|||||||
        if len(keys) != 2:
 | 
					        if len(keys) != 2:
 | 
				
			||||||
            raise JsonableError(_("Invalid parameters"))
 | 
					            raise JsonableError(_("Invalid parameters"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        realm_id = request.POST.get("realm_id")
 | 
					        realm_id: str = assert_is_not_none(request.POST.get("realm_id"))
 | 
				
			||||||
        realm = Realm.objects.get(id=realm_id)
 | 
					        realm = Realm.objects.get(id=realm_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if request.POST.get("plan_type", None) is not None:
 | 
					        if request.POST.get("plan_type", None) is not None:
 | 
				
			||||||
            new_plan_type = int(request.POST.get("plan_type"))
 | 
					            new_plan_type = int(assert_is_not_none(request.POST.get("plan_type")))
 | 
				
			||||||
            current_plan_type = realm.plan_type
 | 
					            current_plan_type = realm.plan_type
 | 
				
			||||||
            do_change_plan_type(realm, new_plan_type, acting_user=request.user)
 | 
					            do_change_plan_type(realm, new_plan_type, acting_user=request.user)
 | 
				
			||||||
            msg = f"Plan type of {realm.string_id} changed from {get_plan_name(current_plan_type)} to {get_plan_name(new_plan_type)} "
 | 
					            msg = f"Plan type of {realm.string_id} changed from {get_plan_name(current_plan_type)} to {get_plan_name(new_plan_type)} "
 | 
				
			||||||
            context["success_message"] = msg
 | 
					            context["success_message"] = msg
 | 
				
			||||||
        elif request.POST.get("discount", None) is not None:
 | 
					        elif request.POST.get("discount", None) is not None:
 | 
				
			||||||
            new_discount = Decimal(request.POST.get("discount"))
 | 
					            new_discount = Decimal(assert_is_not_none(request.POST.get("discount")))
 | 
				
			||||||
            current_discount = get_discount_for_realm(realm) or 0
 | 
					            current_discount = get_discount_for_realm(realm) or 0
 | 
				
			||||||
            attach_discount_to_realm(realm, new_discount, acting_user=request.user)
 | 
					            attach_discount_to_realm(realm, new_discount, acting_user=request.user)
 | 
				
			||||||
            context[
 | 
					            context[
 | 
				
			||||||
                "success_message"
 | 
					                "success_message"
 | 
				
			||||||
            ] = f"Discount of {realm.string_id} changed to {new_discount}% from {current_discount}%."
 | 
					            ] = f"Discount of {realm.string_id} changed to {new_discount}% from {current_discount}%."
 | 
				
			||||||
        elif request.POST.get("new_subdomain", None) is not None:
 | 
					        elif request.POST.get("new_subdomain", None) is not None:
 | 
				
			||||||
            new_subdomain = request.POST.get("new_subdomain")
 | 
					            new_subdomain: str = assert_is_not_none(request.POST.get("new_subdomain"))
 | 
				
			||||||
            old_subdomain = realm.string_id
 | 
					            old_subdomain = realm.string_id
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                check_subdomain_available(new_subdomain)
 | 
					                check_subdomain_available(new_subdomain)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,7 +5,7 @@ import secrets
 | 
				
			|||||||
from datetime import datetime, timedelta
 | 
					from datetime import datetime, timedelta
 | 
				
			||||||
from decimal import Decimal
 | 
					from decimal import Decimal
 | 
				
			||||||
from functools import wraps
 | 
					from functools import wraps
 | 
				
			||||||
from typing import Callable, Dict, Generator, Optional, Tuple, TypeVar, cast
 | 
					from typing import Any, Callable, Dict, Generator, Optional, Tuple, TypeVar, cast
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import orjson
 | 
					import orjson
 | 
				
			||||||
import stripe
 | 
					import stripe
 | 
				
			||||||
@@ -30,6 +30,7 @@ from zerver.lib.exceptions import JsonableError
 | 
				
			|||||||
from zerver.lib.logging_util import log_to_file
 | 
					from zerver.lib.logging_util import log_to_file
 | 
				
			||||||
from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_realm_owners
 | 
					from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_realm_owners
 | 
				
			||||||
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
 | 
					from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
 | 
				
			||||||
 | 
					from zerver.lib.utils import assert_is_not_none
 | 
				
			||||||
from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot
 | 
					from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot
 | 
				
			||||||
from zproject.config import get_secret
 | 
					from zproject.config import get_secret
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -516,7 +517,9 @@ def compute_plan_parameters(
 | 
				
			|||||||
    if automanage_licenses:
 | 
					    if automanage_licenses:
 | 
				
			||||||
        next_invoice_date = add_months(billing_cycle_anchor, 1)
 | 
					        next_invoice_date = add_months(billing_cycle_anchor, 1)
 | 
				
			||||||
    if free_trial:
 | 
					    if free_trial:
 | 
				
			||||||
        period_end = billing_cycle_anchor + timedelta(days=settings.FREE_TRIAL_DAYS)
 | 
					        period_end = billing_cycle_anchor + timedelta(
 | 
				
			||||||
 | 
					            days=assert_is_not_none(settings.FREE_TRIAL_DAYS)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        next_invoice_date = period_end
 | 
					        next_invoice_date = period_end
 | 
				
			||||||
    return billing_cycle_anchor, next_invoice_date, period_end, price_per_license
 | 
					    return billing_cycle_anchor, next_invoice_date, period_end, price_per_license
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -943,10 +946,12 @@ def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]:  # nocoverag
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_realms_to_default_discount_dict() -> Dict[str, Decimal]:
 | 
					def get_realms_to_default_discount_dict() -> Dict[str, Decimal]:
 | 
				
			||||||
    realms_to_default_discount = {}
 | 
					    realms_to_default_discount: Dict[str, Any] = {}
 | 
				
			||||||
    customers = Customer.objects.exclude(default_discount=None).exclude(default_discount=0)
 | 
					    customers = Customer.objects.exclude(default_discount=None).exclude(default_discount=0)
 | 
				
			||||||
    for customer in customers:
 | 
					    for customer in customers:
 | 
				
			||||||
        realms_to_default_discount[customer.realm.string_id] = customer.default_discount
 | 
					        realms_to_default_discount[customer.realm.string_id] = assert_is_not_none(
 | 
				
			||||||
 | 
					            customer.default_discount
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
    return realms_to_default_discount
 | 
					    return realms_to_default_discount
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -90,6 +90,7 @@ from zerver.lib.actions import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
from zerver.lib.test_classes import ZulipTestCase
 | 
					from zerver.lib.test_classes import ZulipTestCase
 | 
				
			||||||
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
 | 
					from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
 | 
				
			||||||
 | 
					from zerver.lib.utils import assert_is_not_none
 | 
				
			||||||
from zerver.models import (
 | 
					from zerver.models import (
 | 
				
			||||||
    Message,
 | 
					    Message,
 | 
				
			||||||
    Realm,
 | 
					    Realm,
 | 
				
			||||||
@@ -532,7 +533,7 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        # Check that we correctly created a Customer object in Stripe
 | 
					        # Check that we correctly created a Customer object in Stripe
 | 
				
			||||||
        stripe_customer = stripe_get_customer(
 | 
					        stripe_customer = stripe_get_customer(
 | 
				
			||||||
            Customer.objects.get(realm=user.realm).stripe_customer_id
 | 
					            assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.assertEqual(stripe_customer.default_source.id[:5], "card_")
 | 
					        self.assertEqual(stripe_customer.default_source.id[:5], "card_")
 | 
				
			||||||
        self.assertTrue(stripe_customer_has_credit_card_as_default_source(stripe_customer))
 | 
					        self.assertTrue(stripe_customer_has_credit_card_as_default_source(stripe_customer))
 | 
				
			||||||
@@ -642,9 +643,11 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
        self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
 | 
					        self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            orjson.loads(
 | 
					            orjson.loads(
 | 
				
			||||||
 | 
					                assert_is_not_none(
 | 
				
			||||||
                    RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
 | 
					                    RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
 | 
				
			||||||
                    .values_list("extra_data", flat=True)
 | 
					                    .values_list("extra_data", flat=True)
 | 
				
			||||||
                    .first()
 | 
					                    .first()
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
            )["automanage_licenses"],
 | 
					            )["automanage_licenses"],
 | 
				
			||||||
            True,
 | 
					            True,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@@ -694,7 +697,7 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
            self.upgrade(invoice=True)
 | 
					            self.upgrade(invoice=True)
 | 
				
			||||||
        # Check that we correctly created a Customer in Stripe
 | 
					        # Check that we correctly created a Customer in Stripe
 | 
				
			||||||
        stripe_customer = stripe_get_customer(
 | 
					        stripe_customer = stripe_get_customer(
 | 
				
			||||||
            Customer.objects.get(realm=user.realm).stripe_customer_id
 | 
					            assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.assertFalse(stripe_customer_has_credit_card_as_default_source(stripe_customer))
 | 
					        self.assertFalse(stripe_customer_has_credit_card_as_default_source(stripe_customer))
 | 
				
			||||||
        # It can take a second for Stripe to attach the source to the customer, and in
 | 
					        # It can take a second for Stripe to attach the source to the customer, and in
 | 
				
			||||||
@@ -781,9 +784,11 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
        self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
 | 
					        self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            orjson.loads(
 | 
					            orjson.loads(
 | 
				
			||||||
 | 
					                assert_is_not_none(
 | 
				
			||||||
                    RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
 | 
					                    RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
 | 
				
			||||||
                    .values_list("extra_data", flat=True)
 | 
					                    .values_list("extra_data", flat=True)
 | 
				
			||||||
                    .first()
 | 
					                    .first()
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
            )["automanage_licenses"],
 | 
					            )["automanage_licenses"],
 | 
				
			||||||
            False,
 | 
					            False,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@@ -834,7 +839,7 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
                self.upgrade()
 | 
					                self.upgrade()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            stripe_customer = stripe_get_customer(
 | 
					            stripe_customer = stripe_get_customer(
 | 
				
			||||||
                Customer.objects.get(realm=user.realm).stripe_customer_id
 | 
					                assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            self.assertEqual(stripe_customer.default_source.id[:5], "card_")
 | 
					            self.assertEqual(stripe_customer.default_source.id[:5], "card_")
 | 
				
			||||||
            self.assertEqual(stripe_customer.description, "zulip (Zulip Dev)")
 | 
					            self.assertEqual(stripe_customer.description, "zulip (Zulip Dev)")
 | 
				
			||||||
@@ -894,9 +899,11 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
            self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
 | 
					            self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
 | 
				
			||||||
            self.assertEqual(
 | 
					            self.assertEqual(
 | 
				
			||||||
                orjson.loads(
 | 
					                orjson.loads(
 | 
				
			||||||
 | 
					                    assert_is_not_none(
 | 
				
			||||||
                        RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
 | 
					                        RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
 | 
				
			||||||
                        .values_list("extra_data", flat=True)
 | 
					                        .values_list("extra_data", flat=True)
 | 
				
			||||||
                        .first()
 | 
					                        .first()
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
                )["automanage_licenses"],
 | 
					                )["automanage_licenses"],
 | 
				
			||||||
                True,
 | 
					                True,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
@@ -1040,7 +1047,7 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
                self.upgrade(invoice=True)
 | 
					                self.upgrade(invoice=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            stripe_customer = stripe_get_customer(
 | 
					            stripe_customer = stripe_get_customer(
 | 
				
			||||||
                Customer.objects.get(realm=user.realm).stripe_customer_id
 | 
					                assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            self.assertEqual(stripe_customer.discount, None)
 | 
					            self.assertEqual(stripe_customer.discount, None)
 | 
				
			||||||
            self.assertEqual(stripe_customer.email, user.delivery_email)
 | 
					            self.assertEqual(stripe_customer.email, user.delivery_email)
 | 
				
			||||||
@@ -1093,9 +1100,11 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
            self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
 | 
					            self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
 | 
				
			||||||
            self.assertEqual(
 | 
					            self.assertEqual(
 | 
				
			||||||
                orjson.loads(
 | 
					                orjson.loads(
 | 
				
			||||||
 | 
					                    assert_is_not_none(
 | 
				
			||||||
                        RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
 | 
					                        RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
 | 
				
			||||||
                        .values_list("extra_data", flat=True)
 | 
					                        .values_list("extra_data", flat=True)
 | 
				
			||||||
                        .first()
 | 
					                        .first()
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
                )["automanage_licenses"],
 | 
					                )["automanage_licenses"],
 | 
				
			||||||
                False,
 | 
					                False,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
@@ -1218,7 +1227,7 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
            self.upgrade()
 | 
					            self.upgrade()
 | 
				
			||||||
        customer = Customer.objects.first()
 | 
					        customer = Customer.objects.first()
 | 
				
			||||||
        assert customer is not None
 | 
					        assert customer is not None
 | 
				
			||||||
        stripe_customer_id = customer.stripe_customer_id
 | 
					        stripe_customer_id: str = assert_is_not_none(customer.stripe_customer_id)
 | 
				
			||||||
        # Check that the Charge used the old quantity, not new_seat_count
 | 
					        # Check that the Charge used the old quantity, not new_seat_count
 | 
				
			||||||
        [charge] = stripe.Charge.list(customer=stripe_customer_id)
 | 
					        [charge] = stripe.Charge.list(customer=stripe_customer_id)
 | 
				
			||||||
        self.assertEqual(8000 * self.seat_count, charge.amount)
 | 
					        self.assertEqual(8000 * self.seat_count, charge.amount)
 | 
				
			||||||
@@ -2037,9 +2046,10 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
        audit_log = RealmAuditLog.objects.get(
 | 
					        audit_log = RealmAuditLog.objects.get(
 | 
				
			||||||
            event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_MONTHLY_TO_ANNUAL_PLAN
 | 
					            event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_MONTHLY_TO_ANNUAL_PLAN
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        extra_data: str = assert_is_not_none(audit_log.extra_data)
 | 
				
			||||||
        self.assertEqual(audit_log.realm, user.realm)
 | 
					        self.assertEqual(audit_log.realm, user.realm)
 | 
				
			||||||
        self.assertEqual(orjson.loads(audit_log.extra_data)["monthly_plan_id"], monthly_plan.id)
 | 
					        self.assertEqual(orjson.loads(extra_data)["monthly_plan_id"], monthly_plan.id)
 | 
				
			||||||
        self.assertEqual(orjson.loads(audit_log.extra_data)["annual_plan_id"], annual_plan.id)
 | 
					        self.assertEqual(orjson.loads(extra_data)["annual_plan_id"], annual_plan.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        invoice_plans_as_needed(self.next_month)
 | 
					        invoice_plans_as_needed(self.next_month)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2468,7 +2478,7 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
            self.assert_json_success(result)
 | 
					            self.assert_json_success(result)
 | 
				
			||||||
        invoice_plans_as_needed(self.next_year)
 | 
					        invoice_plans_as_needed(self.next_year)
 | 
				
			||||||
        stripe_customer = stripe_get_customer(
 | 
					        stripe_customer = stripe_get_customer(
 | 
				
			||||||
            Customer.objects.get(realm=user.realm).stripe_customer_id
 | 
					            assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        [invoice, _] = stripe.Invoice.list(customer=stripe_customer.id)
 | 
					        [invoice, _] = stripe.Invoice.list(customer=stripe_customer.id)
 | 
				
			||||||
        invoice_params = {
 | 
					        invoice_params = {
 | 
				
			||||||
@@ -2518,7 +2528,7 @@ class StripeTest(StripeTestCase):
 | 
				
			|||||||
            self.assert_json_success(result)
 | 
					            self.assert_json_success(result)
 | 
				
			||||||
        invoice_plans_as_needed(self.next_year + timedelta(days=365))
 | 
					        invoice_plans_as_needed(self.next_year + timedelta(days=365))
 | 
				
			||||||
        stripe_customer = stripe_get_customer(
 | 
					        stripe_customer = stripe_get_customer(
 | 
				
			||||||
            Customer.objects.get(realm=user.realm).stripe_customer_id
 | 
					            assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        [invoice, _, _] = stripe.Invoice.list(customer=stripe_customer.id)
 | 
					        [invoice, _, _] = stripe.Invoice.list(customer=stripe_customer.id)
 | 
				
			||||||
        invoice_params = {
 | 
					        invoice_params = {
 | 
				
			||||||
@@ -3460,7 +3470,7 @@ class InvoiceTest(StripeTestCase):
 | 
				
			|||||||
        plan.invoicing_status = CustomerPlan.STARTED
 | 
					        plan.invoicing_status = CustomerPlan.STARTED
 | 
				
			||||||
        plan.save(update_fields=["invoicing_status"])
 | 
					        plan.save(update_fields=["invoicing_status"])
 | 
				
			||||||
        with self.assertRaises(NotImplementedError):
 | 
					        with self.assertRaises(NotImplementedError):
 | 
				
			||||||
            invoice_plan(CustomerPlan.objects.first(), self.now)
 | 
					            invoice_plan(assert_is_not_none(CustomerPlan.objects.first()), self.now)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_invoice_plan_without_stripe_customer(self) -> None:
 | 
					    def test_invoice_plan_without_stripe_customer(self) -> None:
 | 
				
			||||||
        self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL)
 | 
					        self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,6 +10,7 @@ from django.utils.timezone import now as timezone_now
 | 
				
			|||||||
from sentry_sdk import capture_exception
 | 
					from sentry_sdk import capture_exception
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from zerver.lib.logging_util import log_to_file
 | 
					from zerver.lib.logging_util import log_to_file
 | 
				
			||||||
 | 
					from zerver.lib.utils import assert_is_not_none
 | 
				
			||||||
from zerver.models import (
 | 
					from zerver.models import (
 | 
				
			||||||
    Message,
 | 
					    Message,
 | 
				
			||||||
    Realm,
 | 
					    Realm,
 | 
				
			||||||
@@ -63,12 +64,14 @@ def filter_by_subscription_history(
 | 
				
			|||||||
                # check belongs in this inner loop, not the outer loop.
 | 
					                # check belongs in this inner loop, not the outer loop.
 | 
				
			||||||
                break
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            event_last_message_id = assert_is_not_none(log_entry.event_last_message_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if log_entry.event_type == RealmAuditLog.SUBSCRIPTION_DEACTIVATED:
 | 
					            if log_entry.event_type == RealmAuditLog.SUBSCRIPTION_DEACTIVATED:
 | 
				
			||||||
                # If the event shows the user was unsubscribed after
 | 
					                # If the event shows the user was unsubscribed after
 | 
				
			||||||
                # event_last_message_id, we know they must have been
 | 
					                # event_last_message_id, we know they must have been
 | 
				
			||||||
                # subscribed immediately before the event.
 | 
					                # subscribed immediately before the event.
 | 
				
			||||||
                for stream_message in stream_messages:
 | 
					                for stream_message in stream_messages:
 | 
				
			||||||
                    if stream_message["id"] <= log_entry.event_last_message_id:
 | 
					                    if stream_message["id"] <= event_last_message_id:
 | 
				
			||||||
                        store_user_message_to_insert(stream_message)
 | 
					                        store_user_message_to_insert(stream_message)
 | 
				
			||||||
                    else:
 | 
					                    else:
 | 
				
			||||||
                        break
 | 
					                        break
 | 
				
			||||||
@@ -78,12 +81,12 @@ def filter_by_subscription_history(
 | 
				
			|||||||
            ):
 | 
					            ):
 | 
				
			||||||
                initial_msg_count = len(stream_messages)
 | 
					                initial_msg_count = len(stream_messages)
 | 
				
			||||||
                for i, stream_message in enumerate(stream_messages):
 | 
					                for i, stream_message in enumerate(stream_messages):
 | 
				
			||||||
                    if stream_message["id"] > log_entry.event_last_message_id:
 | 
					                    if stream_message["id"] > event_last_message_id:
 | 
				
			||||||
                        stream_messages = stream_messages[i:]
 | 
					                        stream_messages = stream_messages[i:]
 | 
				
			||||||
                        break
 | 
					                        break
 | 
				
			||||||
                final_msg_count = len(stream_messages)
 | 
					                final_msg_count = len(stream_messages)
 | 
				
			||||||
                if initial_msg_count == final_msg_count:
 | 
					                if initial_msg_count == final_msg_count:
 | 
				
			||||||
                    if stream_messages[-1]["id"] <= log_entry.event_last_message_id:
 | 
					                    if stream_messages[-1]["id"] <= event_last_message_id:
 | 
				
			||||||
                        stream_messages = []
 | 
					                        stream_messages = []
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                raise AssertionError(f"{log_entry.event_type} is not a subscription event.")
 | 
					                raise AssertionError(f"{log_entry.event_type} is not a subscription event.")
 | 
				
			||||||
@@ -172,7 +175,7 @@ def add_missing_messages(user_profile: UserProfile) -> None:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    all_stream_subscription_logs: DefaultDict[int, List[RealmAuditLog]] = defaultdict(list)
 | 
					    all_stream_subscription_logs: DefaultDict[int, List[RealmAuditLog]] = defaultdict(list)
 | 
				
			||||||
    for log in subscription_logs:
 | 
					    for log in subscription_logs:
 | 
				
			||||||
        all_stream_subscription_logs[log.modified_stream_id].append(log)
 | 
					        all_stream_subscription_logs[assert_is_not_none(log.modified_stream_id)].append(log)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    recipient_ids = []
 | 
					    recipient_ids = []
 | 
				
			||||||
    for sub in all_stream_subs:
 | 
					    for sub in all_stream_subs:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -31,6 +31,7 @@ from PIL.Image import DecompressionBombError
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from zerver.lib.avatar_hash import user_avatar_path
 | 
					from zerver.lib.avatar_hash import user_avatar_path
 | 
				
			||||||
from zerver.lib.exceptions import ErrorCode, JsonableError
 | 
					from zerver.lib.exceptions import ErrorCode, JsonableError
 | 
				
			||||||
 | 
					from zerver.lib.utils import assert_is_not_none
 | 
				
			||||||
from zerver.models import Attachment, Message, Realm, RealmEmoji, UserProfile
 | 
					from zerver.models import Attachment, Message, Realm, RealmEmoji, UserProfile
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DEFAULT_AVATAR_SIZE = 100
 | 
					DEFAULT_AVATAR_SIZE = 100
 | 
				
			||||||
@@ -729,7 +730,7 @@ class S3UploadBackend(ZulipUploadBackend):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def write_local_file(type: str, path: str, file_data: bytes) -> None:
 | 
					def write_local_file(type: str, path: str, file_data: bytes) -> None:
 | 
				
			||||||
    file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path)
 | 
					    file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
 | 
					    os.makedirs(os.path.dirname(file_path), exist_ok=True)
 | 
				
			||||||
    with open(file_path, "wb") as f:
 | 
					    with open(file_path, "wb") as f:
 | 
				
			||||||
@@ -737,13 +738,13 @@ def write_local_file(type: str, path: str, file_data: bytes) -> None:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def read_local_file(type: str, path: str) -> bytes:
 | 
					def read_local_file(type: str, path: str) -> bytes:
 | 
				
			||||||
    file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path)
 | 
					    file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path)
 | 
				
			||||||
    with open(file_path, "rb") as f:
 | 
					    with open(file_path, "rb") as f:
 | 
				
			||||||
        return f.read()
 | 
					        return f.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def delete_local_file(type: str, path: str) -> bool:
 | 
					def delete_local_file(type: str, path: str) -> bool:
 | 
				
			||||||
    file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path)
 | 
					    file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path)
 | 
				
			||||||
    if os.path.isfile(file_path):
 | 
					    if os.path.isfile(file_path):
 | 
				
			||||||
        # This removes the file but the empty folders still remain.
 | 
					        # This removes the file but the empty folders still remain.
 | 
				
			||||||
        os.remove(file_path)
 | 
					        os.remove(file_path)
 | 
				
			||||||
@@ -754,7 +755,7 @@ def delete_local_file(type: str, path: str) -> bool:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_local_file_path(path_id: str) -> Optional[str]:
 | 
					def get_local_file_path(path_id: str) -> Optional[str]:
 | 
				
			||||||
    local_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "files", path_id)
 | 
					    local_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "files", path_id)
 | 
				
			||||||
    if os.path.isfile(local_path):
 | 
					    if os.path.isfile(local_path):
 | 
				
			||||||
        return local_path
 | 
					        return local_path
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
@@ -897,12 +898,16 @@ class LocalUploadBackend(ZulipUploadBackend):
 | 
				
			|||||||
        file_path = user_avatar_path(user_profile)
 | 
					        file_path = user_avatar_path(user_profile)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        output_path = os.path.join(
 | 
					        output_path = os.path.join(
 | 
				
			||||||
            settings.LOCAL_UPLOADS_DIR, "avatars", file_path + file_extension
 | 
					            assert_is_not_none(settings.LOCAL_UPLOADS_DIR),
 | 
				
			||||||
 | 
					            "avatars",
 | 
				
			||||||
 | 
					            file_path + file_extension,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        if os.path.isfile(output_path):
 | 
					        if os.path.isfile(output_path):
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        image_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "avatars", file_path + ".original")
 | 
					        image_path = os.path.join(
 | 
				
			||||||
 | 
					            assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "avatars", file_path + ".original"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        with open(image_path, "rb") as f:
 | 
					        with open(image_path, "rb") as f:
 | 
				
			||||||
            image_data = f.read()
 | 
					            image_data = f.read()
 | 
				
			||||||
        if is_medium:
 | 
					        if is_medium:
 | 
				
			||||||
@@ -942,7 +947,7 @@ class LocalUploadBackend(ZulipUploadBackend):
 | 
				
			|||||||
            secrets.token_urlsafe(18),
 | 
					            secrets.token_urlsafe(18),
 | 
				
			||||||
            os.path.basename(tarball_path),
 | 
					            os.path.basename(tarball_path),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        abs_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "avatars", path)
 | 
					        abs_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "avatars", path)
 | 
				
			||||||
        os.makedirs(os.path.dirname(abs_path), exist_ok=True)
 | 
					        os.makedirs(os.path.dirname(abs_path), exist_ok=True)
 | 
				
			||||||
        shutil.copy(tarball_path, abs_path)
 | 
					        shutil.copy(tarball_path, abs_path)
 | 
				
			||||||
        public_url = realm.uri + "/user_avatars/" + path
 | 
					        public_url = realm.uri + "/user_avatars/" + path
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Set
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from zerver.lib.cache import cache_with_key, get_muting_users_cache_key
 | 
					from zerver.lib.cache import cache_with_key, get_muting_users_cache_key
 | 
				
			||||||
from zerver.lib.timestamp import datetime_to_timestamp
 | 
					from zerver.lib.timestamp import datetime_to_timestamp
 | 
				
			||||||
 | 
					from zerver.lib.utils import assert_is_not_none
 | 
				
			||||||
from zerver.models import MutedUser, UserProfile
 | 
					from zerver.models import MutedUser, UserProfile
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -14,7 +15,7 @@ def get_user_mutes(user_profile: UserProfile) -> List[Dict[str, int]]:
 | 
				
			|||||||
    return [
 | 
					    return [
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            "id": row["muted_user_id"],
 | 
					            "id": row["muted_user_id"],
 | 
				
			||||||
            "timestamp": datetime_to_timestamp(row["date_muted"]),
 | 
					            "timestamp": datetime_to_timestamp(assert_is_not_none(row["date_muted"])),
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        for row in rows
 | 
					        for row in rows
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -13,6 +13,7 @@ from zerver.lib.exceptions import JsonableError
 | 
				
			|||||||
from zerver.lib.export import get_realm_exports_serialized
 | 
					from zerver.lib.export import get_realm_exports_serialized
 | 
				
			||||||
from zerver.lib.queue import queue_json_publish
 | 
					from zerver.lib.queue import queue_json_publish
 | 
				
			||||||
from zerver.lib.response import json_success
 | 
					from zerver.lib.response import json_success
 | 
				
			||||||
 | 
					from zerver.lib.utils import assert_is_not_none
 | 
				
			||||||
from zerver.models import RealmAuditLog, UserProfile
 | 
					from zerver.models import RealmAuditLog, UserProfile
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -90,7 +91,7 @@ def delete_realm_export(request: HttpRequest, user: UserProfile, export_id: int)
 | 
				
			|||||||
    except RealmAuditLog.DoesNotExist:
 | 
					    except RealmAuditLog.DoesNotExist:
 | 
				
			||||||
        raise JsonableError(_("Invalid data export ID"))
 | 
					        raise JsonableError(_("Invalid data export ID"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    export_data = orjson.loads(audit_log_entry.extra_data)
 | 
					    export_data = orjson.loads(assert_is_not_none(audit_log_entry.extra_data))
 | 
				
			||||||
    if "deleted_timestamp" in export_data:
 | 
					    if "deleted_timestamp" in export_data:
 | 
				
			||||||
        raise JsonableError(_("Export already deleted"))
 | 
					        raise JsonableError(_("Export already deleted"))
 | 
				
			||||||
    do_delete_realm_export(user, audit_log_entry)
 | 
					    do_delete_realm_export(user, audit_log_entry)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -71,6 +71,7 @@ from zerver.lib.topic import (
 | 
				
			|||||||
    messages_for_topic,
 | 
					    messages_for_topic,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from zerver.lib.types import Validator
 | 
					from zerver.lib.types import Validator
 | 
				
			||||||
 | 
					from zerver.lib.utils import assert_is_not_none
 | 
				
			||||||
from zerver.lib.validator import (
 | 
					from zerver.lib.validator import (
 | 
				
			||||||
    check_bool,
 | 
					    check_bool,
 | 
				
			||||||
    check_capped_string,
 | 
					    check_capped_string,
 | 
				
			||||||
@@ -726,7 +727,9 @@ def get_topics_backend(
 | 
				
			|||||||
    if is_web_public_query:
 | 
					    if is_web_public_query:
 | 
				
			||||||
        realm = get_valid_realm_from_request(request)
 | 
					        realm = get_valid_realm_from_request(request)
 | 
				
			||||||
        stream = access_web_public_stream(stream_id, realm)
 | 
					        stream = access_web_public_stream(stream_id, realm)
 | 
				
			||||||
        result = get_topic_history_for_public_stream(recipient_id=stream.recipient_id)
 | 
					        result = get_topic_history_for_public_stream(
 | 
				
			||||||
 | 
					            recipient_id=assert_is_not_none(stream.recipient_id)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        assert user_profile is not None
 | 
					        assert user_profile is not None
 | 
				
			||||||
@@ -753,7 +756,7 @@ def delete_in_topic(
 | 
				
			|||||||
) -> HttpResponse:
 | 
					) -> HttpResponse:
 | 
				
			||||||
    (stream, sub) = access_stream_by_id(user_profile, stream_id)
 | 
					    (stream, sub) = access_stream_by_id(user_profile, stream_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    messages = messages_for_topic(stream.recipient_id, topic_name)
 | 
					    messages = messages_for_topic(assert_is_not_none(stream.recipient_id), topic_name)
 | 
				
			||||||
    if not stream.is_history_public_to_subscribers():
 | 
					    if not stream.is_history_public_to_subscribers():
 | 
				
			||||||
        # Don't allow the user to delete messages that they don't have access to.
 | 
					        # Don't allow the user to delete messages that they don't have access to.
 | 
				
			||||||
        deletable_message_ids = UserMessage.objects.filter(
 | 
					        deletable_message_ids = UserMessage.objects.filter(
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user