typing: Apply trivial none-checks with assertions as necessary.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li
2022-06-14 23:17:23 -04:00
committed by Tim Abbott
parent 58e95cbfae
commit fd9a0f4274
7 changed files with 10 additions and 3 deletions

View File

@@ -165,6 +165,7 @@ def support(
if len(keys) != 2: if len(keys) != 2:
raise JsonableError(_("Invalid parameters")) raise JsonableError(_("Invalid parameters"))
assert realm_id is not None
realm = Realm.objects.get(id=realm_id) realm = Realm.objects.get(id=realm_id)
acting_user = request.user acting_user = request.user

View File

@@ -1178,6 +1178,7 @@ def switch_realm_from_standard_to_plus_plan(realm: Realm) -> None:
LicenseLedger.objects.filter(is_renewal=True, plan=standard_plan).order_by("id").last() LicenseLedger.objects.filter(is_renewal=True, plan=standard_plan).order_by("id").last()
) )
assert standard_plan_last_renewal_ledger is not None assert standard_plan_last_renewal_ledger is not None
assert standard_plan.price_per_license is not None
standard_plan_last_renewal_amount = ( standard_plan_last_renewal_amount = (
standard_plan_last_renewal_ledger.licenses * standard_plan.price_per_license standard_plan_last_renewal_ledger.licenses * standard_plan.price_per_license
) )

View File

@@ -266,6 +266,7 @@ def do_get_invites_controlled_by_user(user_profile: UserProfile) -> List[Dict[st
invites = [] invites = []
for invitee in prereg_users: for invitee in prereg_users:
assert invitee.referred_by is not None
invites.append( invites.append(
dict( dict(
email=invitee.email, email=invitee.email,

View File

@@ -6,6 +6,7 @@ from django.conf import settings
def generate_camo_url(url: str) -> str: def generate_camo_url(url: str) -> str:
encoded_url = url.encode() encoded_url = url.encode()
assert settings.CAMO_KEY is not None
encoded_camo_key = settings.CAMO_KEY.encode() encoded_camo_key = settings.CAMO_KEY.encode()
digest = hmac.new(encoded_camo_key, encoded_url, hashlib.sha1).hexdigest() digest = hmac.new(encoded_camo_key, encoded_url, hashlib.sha1).hexdigest()
return f"{digest}/{encoded_url.hex()}" return f"{digest}/{encoded_url.hex()}"

View File

@@ -31,6 +31,7 @@ from scripts.lib.zulip_tools import overwrite_symlink
from zerver.lib.avatar_hash import user_avatar_path_from_ids from zerver.lib.avatar_hash import user_avatar_path_from_ids
from zerver.lib.pysa import mark_sanitized from zerver.lib.pysa import mark_sanitized
from zerver.lib.upload import get_bucket from zerver.lib.upload import get_bucket
from zerver.lib.utils import assert_is_not_none
from zerver.models import ( from zerver.models import (
AlertWord, AlertWord,
Attachment, Attachment,
@@ -435,7 +436,7 @@ def floatify_datetime_fields(data: TableData, table: TableName) -> None:
dt = timezone_make_aware(orig_dt) dt = timezone_make_aware(orig_dt)
else: else:
dt = orig_dt dt = orig_dt
utc_naive = dt.replace(tzinfo=None) - dt.utcoffset() utc_naive = dt.replace(tzinfo=None) - assert_is_not_none(dt.utcoffset())
item[field] = (utc_naive - datetime.datetime(1970, 1, 1)).total_seconds() item[field] = (utc_naive - datetime.datetime(1970, 1, 1)).total_seconds()

View File

@@ -5,7 +5,7 @@ import threading
import traceback import traceback
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from logging import Logger from logging import Logger
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import orjson import orjson
from django.conf import settings from django.conf import settings
@@ -255,10 +255,11 @@ class ZulipWebhookFormatter(ZulipFormatter):
return super().format(record) return super().format(record)
if request.content_type == "application/json": if request.content_type == "application/json":
payload = request.body payload: Union[str, bytes, None] = request.body
else: else:
payload = request.POST.get("payload") payload = request.POST.get("payload")
assert payload is not None
try: try:
payload = orjson.dumps(orjson.loads(payload), option=orjson.OPT_INDENT_2).decode() payload = orjson.dumps(orjson.loads(payload), option=orjson.OPT_INDENT_2).decode()
except orjson.JSONDecodeError: except orjson.JSONDecodeError:

View File

@@ -93,6 +93,7 @@ def get_used_colors_for_user_ids(user_ids: List[int]) -> Dict[int, Set[str]]:
result: Dict[int, Set[str]] = defaultdict(set) result: Dict[int, Set[str]] = defaultdict(set)
for row in list(query): for row in list(query):
assert row["color"] is not None
result[row["user_profile_id"]].add(row["color"]) result[row["user_profile_id"]].add(row["color"])
return result return result