typing: Use assertions for function arguments.

Utilize the assert_is_not_None helper to eliminate errors of
'Argument x to "Foo" has incompatible type "Optional[Bar]"...'
This commit is contained in:
PIG208
2021-07-25 22:31:12 +08:00
committed by Tim Abbott
parent 8a91d1c2b1
commit 7d1c475f69
9 changed files with 80 additions and 48 deletions

View File

@@ -51,6 +51,7 @@ from zerver.lib.exceptions import InvitationError
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.timestamp import TimezoneNotUTCException, floor_to_day from zerver.lib.timestamp import TimezoneNotUTCException, floor_to_day
from zerver.lib.topic import DB_TOPIC_NAME from zerver.lib.topic import DB_TOPIC_NAME
from zerver.lib.utils import assert_is_not_none
from zerver.models import ( from zerver.models import (
Client, Client,
Huddle, Huddle,
@@ -1388,11 +1389,13 @@ class TestLoggingCountStats(AnalyticsTestCase):
assertInviteCountEquals(5) assertInviteCountEquals(5)
# Revoking invite should not give you credit # Revoking invite should not give you credit
do_revoke_user_invite(PreregistrationUser.objects.filter(realm=user.realm).first()) do_revoke_user_invite(
assert_is_not_none(PreregistrationUser.objects.filter(realm=user.realm).first())
)
assertInviteCountEquals(5) assertInviteCountEquals(5)
# Resending invite should cost you # Resending invite should cost you
do_resend_user_invite_email(PreregistrationUser.objects.first()) do_resend_user_invite_email(assert_is_not_none(PreregistrationUser.objects.first()))
assertInviteCountEquals(6) assertInviteCountEquals(6)
def test_messages_read_hour(self) -> None: def test_messages_read_hour(self) -> None:
@@ -1423,7 +1426,7 @@ class TestLoggingCountStats(AnalyticsTestCase):
self.send_stream_message(user1, stream.name) self.send_stream_message(user1, stream.name)
self.send_stream_message(user1, stream.name) self.send_stream_message(user1, stream.name)
do_mark_stream_messages_as_read(user2, stream.recipient_id) do_mark_stream_messages_as_read(user2, assert_is_not_none(stream.recipient_id))
self.assertEqual( self.assertEqual(
3, 3,
UserCount.objects.filter(property=read_count_property).aggregate(Sum("value"))[ UserCount.objects.filter(property=read_count_property).aggregate(Sum("value"))[

View File

@@ -28,6 +28,7 @@ from zerver.lib.actions import (
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
from zerver.lib.realm_icon import realm_icon_url from zerver.lib.realm_icon import realm_icon_url
from zerver.lib.subdomains import get_subdomain_from_hostname from zerver.lib.subdomains import get_subdomain_from_hostname
from zerver.lib.utils import assert_is_not_none
from zerver.models import ( from zerver.models import (
MultiuseInvite, MultiuseInvite,
PreregistrationUser, PreregistrationUser,
@@ -119,24 +120,24 @@ def support(request: HttpRequest) -> HttpResponse:
if len(keys) != 2: if len(keys) != 2:
raise JsonableError(_("Invalid parameters")) raise JsonableError(_("Invalid parameters"))
realm_id = request.POST.get("realm_id") realm_id: str = assert_is_not_none(request.POST.get("realm_id"))
realm = Realm.objects.get(id=realm_id) realm = Realm.objects.get(id=realm_id)
if request.POST.get("plan_type", None) is not None: if request.POST.get("plan_type", None) is not None:
new_plan_type = int(request.POST.get("plan_type")) new_plan_type = int(assert_is_not_none(request.POST.get("plan_type")))
current_plan_type = realm.plan_type current_plan_type = realm.plan_type
do_change_plan_type(realm, new_plan_type, acting_user=request.user) do_change_plan_type(realm, new_plan_type, acting_user=request.user)
msg = f"Plan type of {realm.string_id} changed from {get_plan_name(current_plan_type)} to {get_plan_name(new_plan_type)} " msg = f"Plan type of {realm.string_id} changed from {get_plan_name(current_plan_type)} to {get_plan_name(new_plan_type)} "
context["success_message"] = msg context["success_message"] = msg
elif request.POST.get("discount", None) is not None: elif request.POST.get("discount", None) is not None:
new_discount = Decimal(request.POST.get("discount")) new_discount = Decimal(assert_is_not_none(request.POST.get("discount")))
current_discount = get_discount_for_realm(realm) or 0 current_discount = get_discount_for_realm(realm) or 0
attach_discount_to_realm(realm, new_discount, acting_user=request.user) attach_discount_to_realm(realm, new_discount, acting_user=request.user)
context[ context[
"success_message" "success_message"
] = f"Discount of {realm.string_id} changed to {new_discount}% from {current_discount}%." ] = f"Discount of {realm.string_id} changed to {new_discount}% from {current_discount}%."
elif request.POST.get("new_subdomain", None) is not None: elif request.POST.get("new_subdomain", None) is not None:
new_subdomain = request.POST.get("new_subdomain") new_subdomain: str = assert_is_not_none(request.POST.get("new_subdomain"))
old_subdomain = realm.string_id old_subdomain = realm.string_id
try: try:
check_subdomain_available(new_subdomain) check_subdomain_available(new_subdomain)

View File

@@ -5,7 +5,7 @@ import secrets
from datetime import datetime, timedelta from datetime import datetime, timedelta
from decimal import Decimal from decimal import Decimal
from functools import wraps from functools import wraps
from typing import Callable, Dict, Generator, Optional, Tuple, TypeVar, cast from typing import Any, Callable, Dict, Generator, Optional, Tuple, TypeVar, cast
import orjson import orjson
import stripe import stripe
@@ -30,6 +30,7 @@ from zerver.lib.exceptions import JsonableError
from zerver.lib.logging_util import log_to_file from zerver.lib.logging_util import log_to_file
from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_realm_owners from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_realm_owners
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import assert_is_not_none
from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot
from zproject.config import get_secret from zproject.config import get_secret
@@ -516,7 +517,9 @@ def compute_plan_parameters(
if automanage_licenses: if automanage_licenses:
next_invoice_date = add_months(billing_cycle_anchor, 1) next_invoice_date = add_months(billing_cycle_anchor, 1)
if free_trial: if free_trial:
period_end = billing_cycle_anchor + timedelta(days=settings.FREE_TRIAL_DAYS) period_end = billing_cycle_anchor + timedelta(
days=assert_is_not_none(settings.FREE_TRIAL_DAYS)
)
next_invoice_date = period_end next_invoice_date = period_end
return billing_cycle_anchor, next_invoice_date, period_end, price_per_license return billing_cycle_anchor, next_invoice_date, period_end, price_per_license
@@ -943,10 +946,12 @@ def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverag
def get_realms_to_default_discount_dict() -> Dict[str, Decimal]: def get_realms_to_default_discount_dict() -> Dict[str, Decimal]:
realms_to_default_discount = {} realms_to_default_discount: Dict[str, Any] = {}
customers = Customer.objects.exclude(default_discount=None).exclude(default_discount=0) customers = Customer.objects.exclude(default_discount=None).exclude(default_discount=0)
for customer in customers: for customer in customers:
realms_to_default_discount[customer.realm.string_id] = customer.default_discount realms_to_default_discount[customer.realm.string_id] = assert_is_not_none(
customer.default_discount
)
return realms_to_default_discount return realms_to_default_discount

View File

@@ -90,6 +90,7 @@ from zerver.lib.actions import (
) )
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import assert_is_not_none
from zerver.models import ( from zerver.models import (
Message, Message,
Realm, Realm,
@@ -532,7 +533,7 @@ class StripeTest(StripeTestCase):
# Check that we correctly created a Customer object in Stripe # Check that we correctly created a Customer object in Stripe
stripe_customer = stripe_get_customer( stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
) )
self.assertEqual(stripe_customer.default_source.id[:5], "card_") self.assertEqual(stripe_customer.default_source.id[:5], "card_")
self.assertTrue(stripe_customer_has_credit_card_as_default_source(stripe_customer)) self.assertTrue(stripe_customer_has_credit_card_as_default_source(stripe_customer))
@@ -642,9 +643,11 @@ class StripeTest(StripeTestCase):
self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED) self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
self.assertEqual( self.assertEqual(
orjson.loads( orjson.loads(
assert_is_not_none(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True) .values_list("extra_data", flat=True)
.first() .first()
)
)["automanage_licenses"], )["automanage_licenses"],
True, True,
) )
@@ -694,7 +697,7 @@ class StripeTest(StripeTestCase):
self.upgrade(invoice=True) self.upgrade(invoice=True)
# Check that we correctly created a Customer in Stripe # Check that we correctly created a Customer in Stripe
stripe_customer = stripe_get_customer( stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
) )
self.assertFalse(stripe_customer_has_credit_card_as_default_source(stripe_customer)) self.assertFalse(stripe_customer_has_credit_card_as_default_source(stripe_customer))
# It can take a second for Stripe to attach the source to the customer, and in # It can take a second for Stripe to attach the source to the customer, and in
@@ -781,9 +784,11 @@ class StripeTest(StripeTestCase):
self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED) self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
self.assertEqual( self.assertEqual(
orjson.loads( orjson.loads(
assert_is_not_none(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True) .values_list("extra_data", flat=True)
.first() .first()
)
)["automanage_licenses"], )["automanage_licenses"],
False, False,
) )
@@ -834,7 +839,7 @@ class StripeTest(StripeTestCase):
self.upgrade() self.upgrade()
stripe_customer = stripe_get_customer( stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
) )
self.assertEqual(stripe_customer.default_source.id[:5], "card_") self.assertEqual(stripe_customer.default_source.id[:5], "card_")
self.assertEqual(stripe_customer.description, "zulip (Zulip Dev)") self.assertEqual(stripe_customer.description, "zulip (Zulip Dev)")
@@ -894,9 +899,11 @@ class StripeTest(StripeTestCase):
self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED) self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
self.assertEqual( self.assertEqual(
orjson.loads( orjson.loads(
assert_is_not_none(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True) .values_list("extra_data", flat=True)
.first() .first()
)
)["automanage_licenses"], )["automanage_licenses"],
True, True,
) )
@@ -1040,7 +1047,7 @@ class StripeTest(StripeTestCase):
self.upgrade(invoice=True) self.upgrade(invoice=True)
stripe_customer = stripe_get_customer( stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
) )
self.assertEqual(stripe_customer.discount, None) self.assertEqual(stripe_customer.discount, None)
self.assertEqual(stripe_customer.email, user.delivery_email) self.assertEqual(stripe_customer.email, user.delivery_email)
@@ -1093,9 +1100,11 @@ class StripeTest(StripeTestCase):
self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED) self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
self.assertEqual( self.assertEqual(
orjson.loads( orjson.loads(
assert_is_not_none(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True) .values_list("extra_data", flat=True)
.first() .first()
)
)["automanage_licenses"], )["automanage_licenses"],
False, False,
) )
@@ -1218,7 +1227,7 @@ class StripeTest(StripeTestCase):
self.upgrade() self.upgrade()
customer = Customer.objects.first() customer = Customer.objects.first()
assert customer is not None assert customer is not None
stripe_customer_id = customer.stripe_customer_id stripe_customer_id: str = assert_is_not_none(customer.stripe_customer_id)
# Check that the Charge used the old quantity, not new_seat_count # Check that the Charge used the old quantity, not new_seat_count
[charge] = stripe.Charge.list(customer=stripe_customer_id) [charge] = stripe.Charge.list(customer=stripe_customer_id)
self.assertEqual(8000 * self.seat_count, charge.amount) self.assertEqual(8000 * self.seat_count, charge.amount)
@@ -2037,9 +2046,10 @@ class StripeTest(StripeTestCase):
audit_log = RealmAuditLog.objects.get( audit_log = RealmAuditLog.objects.get(
event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_MONTHLY_TO_ANNUAL_PLAN event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_MONTHLY_TO_ANNUAL_PLAN
) )
extra_data: str = assert_is_not_none(audit_log.extra_data)
self.assertEqual(audit_log.realm, user.realm) self.assertEqual(audit_log.realm, user.realm)
self.assertEqual(orjson.loads(audit_log.extra_data)["monthly_plan_id"], monthly_plan.id) self.assertEqual(orjson.loads(extra_data)["monthly_plan_id"], monthly_plan.id)
self.assertEqual(orjson.loads(audit_log.extra_data)["annual_plan_id"], annual_plan.id) self.assertEqual(orjson.loads(extra_data)["annual_plan_id"], annual_plan.id)
invoice_plans_as_needed(self.next_month) invoice_plans_as_needed(self.next_month)
@@ -2468,7 +2478,7 @@ class StripeTest(StripeTestCase):
self.assert_json_success(result) self.assert_json_success(result)
invoice_plans_as_needed(self.next_year) invoice_plans_as_needed(self.next_year)
stripe_customer = stripe_get_customer( stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
) )
[invoice, _] = stripe.Invoice.list(customer=stripe_customer.id) [invoice, _] = stripe.Invoice.list(customer=stripe_customer.id)
invoice_params = { invoice_params = {
@@ -2518,7 +2528,7 @@ class StripeTest(StripeTestCase):
self.assert_json_success(result) self.assert_json_success(result)
invoice_plans_as_needed(self.next_year + timedelta(days=365)) invoice_plans_as_needed(self.next_year + timedelta(days=365))
stripe_customer = stripe_get_customer( stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
) )
[invoice, _, _] = stripe.Invoice.list(customer=stripe_customer.id) [invoice, _, _] = stripe.Invoice.list(customer=stripe_customer.id)
invoice_params = { invoice_params = {
@@ -3460,7 +3470,7 @@ class InvoiceTest(StripeTestCase):
plan.invoicing_status = CustomerPlan.STARTED plan.invoicing_status = CustomerPlan.STARTED
plan.save(update_fields=["invoicing_status"]) plan.save(update_fields=["invoicing_status"])
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
invoice_plan(CustomerPlan.objects.first(), self.now) invoice_plan(assert_is_not_none(CustomerPlan.objects.first()), self.now)
def test_invoice_plan_without_stripe_customer(self) -> None: def test_invoice_plan_without_stripe_customer(self) -> None:
self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL) self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL)

View File

@@ -10,6 +10,7 @@ from django.utils.timezone import now as timezone_now
from sentry_sdk import capture_exception from sentry_sdk import capture_exception
from zerver.lib.logging_util import log_to_file from zerver.lib.logging_util import log_to_file
from zerver.lib.utils import assert_is_not_none
from zerver.models import ( from zerver.models import (
Message, Message,
Realm, Realm,
@@ -63,12 +64,14 @@ def filter_by_subscription_history(
# check belongs in this inner loop, not the outer loop. # check belongs in this inner loop, not the outer loop.
break break
event_last_message_id = assert_is_not_none(log_entry.event_last_message_id)
if log_entry.event_type == RealmAuditLog.SUBSCRIPTION_DEACTIVATED: if log_entry.event_type == RealmAuditLog.SUBSCRIPTION_DEACTIVATED:
# If the event shows the user was unsubscribed after # If the event shows the user was unsubscribed after
# event_last_message_id, we know they must have been # event_last_message_id, we know they must have been
# subscribed immediately before the event. # subscribed immediately before the event.
for stream_message in stream_messages: for stream_message in stream_messages:
if stream_message["id"] <= log_entry.event_last_message_id: if stream_message["id"] <= event_last_message_id:
store_user_message_to_insert(stream_message) store_user_message_to_insert(stream_message)
else: else:
break break
@@ -78,12 +81,12 @@ def filter_by_subscription_history(
): ):
initial_msg_count = len(stream_messages) initial_msg_count = len(stream_messages)
for i, stream_message in enumerate(stream_messages): for i, stream_message in enumerate(stream_messages):
if stream_message["id"] > log_entry.event_last_message_id: if stream_message["id"] > event_last_message_id:
stream_messages = stream_messages[i:] stream_messages = stream_messages[i:]
break break
final_msg_count = len(stream_messages) final_msg_count = len(stream_messages)
if initial_msg_count == final_msg_count: if initial_msg_count == final_msg_count:
if stream_messages[-1]["id"] <= log_entry.event_last_message_id: if stream_messages[-1]["id"] <= event_last_message_id:
stream_messages = [] stream_messages = []
else: else:
raise AssertionError(f"{log_entry.event_type} is not a subscription event.") raise AssertionError(f"{log_entry.event_type} is not a subscription event.")
@@ -172,7 +175,7 @@ def add_missing_messages(user_profile: UserProfile) -> None:
all_stream_subscription_logs: DefaultDict[int, List[RealmAuditLog]] = defaultdict(list) all_stream_subscription_logs: DefaultDict[int, List[RealmAuditLog]] = defaultdict(list)
for log in subscription_logs: for log in subscription_logs:
all_stream_subscription_logs[log.modified_stream_id].append(log) all_stream_subscription_logs[assert_is_not_none(log.modified_stream_id)].append(log)
recipient_ids = [] recipient_ids = []
for sub in all_stream_subs: for sub in all_stream_subs:

View File

@@ -31,6 +31,7 @@ from PIL.Image import DecompressionBombError
from zerver.lib.avatar_hash import user_avatar_path from zerver.lib.avatar_hash import user_avatar_path
from zerver.lib.exceptions import ErrorCode, JsonableError from zerver.lib.exceptions import ErrorCode, JsonableError
from zerver.lib.utils import assert_is_not_none
from zerver.models import Attachment, Message, Realm, RealmEmoji, UserProfile from zerver.models import Attachment, Message, Realm, RealmEmoji, UserProfile
DEFAULT_AVATAR_SIZE = 100 DEFAULT_AVATAR_SIZE = 100
@@ -729,7 +730,7 @@ class S3UploadBackend(ZulipUploadBackend):
def write_local_file(type: str, path: str, file_data: bytes) -> None: def write_local_file(type: str, path: str, file_data: bytes) -> None:
file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path) file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path)
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
@@ -737,13 +738,13 @@ def write_local_file(type: str, path: str, file_data: bytes) -> None:
def read_local_file(type: str, path: str) -> bytes: def read_local_file(type: str, path: str) -> bytes:
file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path) file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path)
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
return f.read() return f.read()
def delete_local_file(type: str, path: str) -> bool: def delete_local_file(type: str, path: str) -> bool:
file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path) file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path)
if os.path.isfile(file_path): if os.path.isfile(file_path):
# This removes the file but the empty folders still remain. # This removes the file but the empty folders still remain.
os.remove(file_path) os.remove(file_path)
@@ -754,7 +755,7 @@ def delete_local_file(type: str, path: str) -> bool:
def get_local_file_path(path_id: str) -> Optional[str]: def get_local_file_path(path_id: str) -> Optional[str]:
local_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "files", path_id) local_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "files", path_id)
if os.path.isfile(local_path): if os.path.isfile(local_path):
return local_path return local_path
else: else:
@@ -897,12 +898,16 @@ class LocalUploadBackend(ZulipUploadBackend):
file_path = user_avatar_path(user_profile) file_path = user_avatar_path(user_profile)
output_path = os.path.join( output_path = os.path.join(
settings.LOCAL_UPLOADS_DIR, "avatars", file_path + file_extension assert_is_not_none(settings.LOCAL_UPLOADS_DIR),
"avatars",
file_path + file_extension,
) )
if os.path.isfile(output_path): if os.path.isfile(output_path):
return return
image_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "avatars", file_path + ".original") image_path = os.path.join(
assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "avatars", file_path + ".original"
)
with open(image_path, "rb") as f: with open(image_path, "rb") as f:
image_data = f.read() image_data = f.read()
if is_medium: if is_medium:
@@ -942,7 +947,7 @@ class LocalUploadBackend(ZulipUploadBackend):
secrets.token_urlsafe(18), secrets.token_urlsafe(18),
os.path.basename(tarball_path), os.path.basename(tarball_path),
) )
abs_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "avatars", path) abs_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "avatars", path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True) os.makedirs(os.path.dirname(abs_path), exist_ok=True)
shutil.copy(tarball_path, abs_path) shutil.copy(tarball_path, abs_path)
public_url = realm.uri + "/user_avatars/" + path public_url = realm.uri + "/user_avatars/" + path

View File

@@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Set
from zerver.lib.cache import cache_with_key, get_muting_users_cache_key from zerver.lib.cache import cache_with_key, get_muting_users_cache_key
from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.utils import assert_is_not_none
from zerver.models import MutedUser, UserProfile from zerver.models import MutedUser, UserProfile
@@ -14,7 +15,7 @@ def get_user_mutes(user_profile: UserProfile) -> List[Dict[str, int]]:
return [ return [
{ {
"id": row["muted_user_id"], "id": row["muted_user_id"],
"timestamp": datetime_to_timestamp(row["date_muted"]), "timestamp": datetime_to_timestamp(assert_is_not_none(row["date_muted"])),
} }
for row in rows for row in rows
] ]

View File

@@ -13,6 +13,7 @@ from zerver.lib.exceptions import JsonableError
from zerver.lib.export import get_realm_exports_serialized from zerver.lib.export import get_realm_exports_serialized
from zerver.lib.queue import queue_json_publish from zerver.lib.queue import queue_json_publish
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.utils import assert_is_not_none
from zerver.models import RealmAuditLog, UserProfile from zerver.models import RealmAuditLog, UserProfile
@@ -90,7 +91,7 @@ def delete_realm_export(request: HttpRequest, user: UserProfile, export_id: int)
except RealmAuditLog.DoesNotExist: except RealmAuditLog.DoesNotExist:
raise JsonableError(_("Invalid data export ID")) raise JsonableError(_("Invalid data export ID"))
export_data = orjson.loads(audit_log_entry.extra_data) export_data = orjson.loads(assert_is_not_none(audit_log_entry.extra_data))
if "deleted_timestamp" in export_data: if "deleted_timestamp" in export_data:
raise JsonableError(_("Export already deleted")) raise JsonableError(_("Export already deleted"))
do_delete_realm_export(user, audit_log_entry) do_delete_realm_export(user, audit_log_entry)

View File

@@ -71,6 +71,7 @@ from zerver.lib.topic import (
messages_for_topic, messages_for_topic,
) )
from zerver.lib.types import Validator from zerver.lib.types import Validator
from zerver.lib.utils import assert_is_not_none
from zerver.lib.validator import ( from zerver.lib.validator import (
check_bool, check_bool,
check_capped_string, check_capped_string,
@@ -726,7 +727,9 @@ def get_topics_backend(
if is_web_public_query: if is_web_public_query:
realm = get_valid_realm_from_request(request) realm = get_valid_realm_from_request(request)
stream = access_web_public_stream(stream_id, realm) stream = access_web_public_stream(stream_id, realm)
result = get_topic_history_for_public_stream(recipient_id=stream.recipient_id) result = get_topic_history_for_public_stream(
recipient_id=assert_is_not_none(stream.recipient_id)
)
else: else:
assert user_profile is not None assert user_profile is not None
@@ -753,7 +756,7 @@ def delete_in_topic(
) -> HttpResponse: ) -> HttpResponse:
(stream, sub) = access_stream_by_id(user_profile, stream_id) (stream, sub) = access_stream_by_id(user_profile, stream_id)
messages = messages_for_topic(stream.recipient_id, topic_name) messages = messages_for_topic(assert_is_not_none(stream.recipient_id), topic_name)
if not stream.is_history_public_to_subscribers(): if not stream.is_history_public_to_subscribers():
# Don't allow the user to delete messages that they don't have access to. # Don't allow the user to delete messages that they don't have access to.
deletable_message_ids = UserMessage.objects.filter( deletable_message_ids = UserMessage.objects.filter(