typing: Import ValuesQuerySet alias from django_stubs_ext.

This saves us from using a conditional import.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li
2022-09-19 15:48:53 -04:00
committed by Tim Abbott
parent a4eaa770f0
commit 7fd8d77ce0
8 changed files with 24 additions and 51 deletions

View File

@@ -3,7 +3,6 @@ import logging
from collections import defaultdict from collections import defaultdict
from email.headerregistry import Address from email.headerregistry import Address
from typing import ( from typing import (
TYPE_CHECKING,
AbstractSet, AbstractSet,
Any, Any,
Callable, Callable,
@@ -27,6 +26,7 @@ from django.utils.html import escape
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django.utils.translation import override as override_language from django.utils.translation import override as override_language
from django_stubs_ext import ValuesQuerySet
from zerver.actions.uploads import do_claim_attachments from zerver.actions.uploads import do_claim_attachments
from zerver.lib.addressee import Addressee from zerver.lib.addressee import Addressee
@@ -89,9 +89,6 @@ from zerver.models import (
) )
from zerver.tornado.django_api import send_event from zerver.tornado.django_api import send_event
if TYPE_CHECKING:
from django.db.models.query import _QuerySet as ValuesQuerySet
def compute_irc_user_fullname(email: str) -> str: def compute_irc_user_fullname(email: str) -> str:
return Address(addr_spec=email).username + " (IRC)" return Address(addr_spec=email).username + " (IRC)"

View File

@@ -1,18 +1,7 @@
import hashlib import hashlib
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import Any, Collection, Dict, Iterable, List, Mapping, Optional, Set, Tuple
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
)
import orjson import orjson
from django.conf import settings from django.conf import settings
@@ -20,6 +9,7 @@ from django.db import transaction
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django.utils.translation import override as override_language from django.utils.translation import override as override_language
from django_stubs_ext import ValuesQuerySet
from zerver.actions.default_streams import ( from zerver.actions.default_streams import (
do_remove_default_stream, do_remove_default_stream,
@@ -75,9 +65,6 @@ from zerver.models import (
) )
from zerver.tornado.django_api import send_event from zerver.tornado.django_api import send_event
if TYPE_CHECKING:
from django.db.models.query import _QuerySet as ValuesQuerySet
@transaction.atomic(savepoint=False) @transaction.atomic(savepoint=False)
def do_deactivate_stream( def do_deactivate_stream(
@@ -220,7 +207,7 @@ def merge_streams(
def get_subscriber_ids( def get_subscriber_ids(
stream: Stream, requesting_user: Optional[UserProfile] = None stream: Stream, requesting_user: Optional[UserProfile] = None
) -> "ValuesQuerySet[Subscription, int]": ) -> ValuesQuerySet[Subscription, int]:
subscriptions_query = get_subscribers_query(stream, requesting_user) subscriptions_query = get_subscribers_query(stream, requesting_user)
return subscriptions_query.values_list("user_profile_id", flat=True) return subscriptions_query.values_list("user_profile_id", flat=True)

View File

@@ -1,12 +1,13 @@
# See https://zulip.readthedocs.io/en/latest/subsystems/caching.html for docs # See https://zulip.readthedocs.io/en/latest/subsystems/caching.html for docs
import datetime import datetime
import logging import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Tuple from typing import Any, Callable, Dict, Iterable, Tuple
from django.conf import settings from django.conf import settings
from django.contrib.sessions.models import Session from django.contrib.sessions.models import Session
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from django_stubs_ext import ValuesQuerySet
# This file needs to be different from cache.py because cache.py # This file needs to be different from cache.py because cache.py
# cannot import anything from zerver.models or we'd have an import # cannot import anything from zerver.models or we'd have an import
@@ -32,9 +33,6 @@ from zerver.models import (
huddle_hash_cache_key, huddle_hash_cache_key,
) )
if TYPE_CHECKING:
from django.db.models.query import _QuerySet as ValuesQuerySet
def user_cache_items( def user_cache_items(
items_for_remote_cache: Dict[str, Tuple[UserProfile]], user_profile: UserProfile items_for_remote_cache: Dict[str, Tuple[UserProfile]], user_profile: UserProfile
@@ -73,7 +71,7 @@ def session_cache_items(
items_for_remote_cache[store.cache_key] = store.decode(session.session_data) items_for_remote_cache[store.cache_key] = store.decode(session.session_data)
def get_active_realm_ids() -> "ValuesQuerySet[RealmCount, int]": def get_active_realm_ids() -> ValuesQuerySet[RealmCount, int]:
"""For installations like Zulip Cloud hosting a lot of realms, it only makes """For installations like Zulip Cloud hosting a lot of realms, it only makes
sense to do cache-filling work for realms that have any currently sense to do cache-filling work for realms that have any currently
active users/clients. Otherwise, we end up with every single-user active users/clients. Otherwise, we end up with every single-user

View File

@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, TypedDict from typing import Dict, List, Optional, Set, Tuple, TypedDict
from django_stubs_ext import ValuesQuerySet
from zerver.lib.cache import ( from zerver.lib.cache import (
bulk_cached_fetch, bulk_cached_fetch,
@@ -10,9 +12,6 @@ from zerver.lib.cache import (
from zerver.lib.types import DisplayRecipientT, UserDisplayRecipient from zerver.lib.types import DisplayRecipientT, UserDisplayRecipient
from zerver.models import Recipient, Stream, UserProfile, bulk_get_huddle_user_ids from zerver.models import Recipient, Stream, UserProfile, bulk_get_huddle_user_ids
if TYPE_CHECKING:
from django.db.models.query import _QuerySet as ValuesQuerySet
display_recipient_fields = [ display_recipient_fields = [
"id", "id",
"email", "email",
@@ -101,7 +100,7 @@ def bulk_fetch_display_recipients(
def stream_query_function( def stream_query_function(
recipient_ids: List[int], recipient_ids: List[int],
) -> "ValuesQuerySet[Stream, TinyStreamResult]": ) -> ValuesQuerySet[Stream, TinyStreamResult]:
stream_ids = [ stream_ids = [
recipient_id_to_type_pair_dict[recipient_id][1] for recipient_id in recipient_ids recipient_id_to_type_pair_dict[recipient_id][1] for recipient_id in recipient_ids
] ]

View File

@@ -3,7 +3,6 @@ import datetime
import zlib import zlib
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Collection, Collection,
Dict, Dict,
@@ -24,6 +23,7 @@ from django.db import connection
from django.db.models import Max, Sum from django.db.models import Max, Sum
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django_stubs_ext import ValuesQuerySet
from psycopg2.sql import SQL from psycopg2.sql import SQL
from analytics.lib.counts import COUNT_STATS from analytics.lib.counts import COUNT_STATS
@@ -69,9 +69,6 @@ from zerver.models import (
query_for_ids, query_for_ids,
) )
if TYPE_CHECKING:
from django.db.models.query import _QuerySet as ValuesQuerySet
class MessageDetailsDict(TypedDict, total=False): class MessageDetailsDict(TypedDict, total=False):
type: str type: str
@@ -891,7 +888,7 @@ def bulk_access_messages(
def bulk_access_messages_expect_usermessage( def bulk_access_messages_expect_usermessage(
user_profile_id: int, message_ids: Sequence[int] user_profile_id: int, message_ids: Sequence[int]
) -> "ValuesQuerySet[UserMessage, int]": ) -> ValuesQuerySet[UserMessage, int]:
""" """
Like bulk_access_messages, but faster and potentially stricter. Like bulk_access_messages, but faster and potentially stricter.

View File

@@ -2,12 +2,10 @@ import itertools
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from operator import itemgetter from operator import itemgetter
from typing import TYPE_CHECKING, AbstractSet, Any, Collection, Dict, List, Optional, Set from typing import AbstractSet, Any, Collection, Dict, List, Optional, Set
from django.db.models import Q, QuerySet from django.db.models import Q, QuerySet
from django_stubs_ext import ValuesQuerySet
if TYPE_CHECKING:
from django.db.models.query import _QuerySet as ValuesQuerySet
from zerver.models import AlertWord, Realm, Recipient, Stream, Subscription, UserProfile from zerver.models import AlertWord, Realm, Recipient, Stream, Subscription, UserProfile
@@ -53,7 +51,7 @@ def get_active_subscriptions_for_stream_ids(stream_ids: Set[int]) -> QuerySet[Su
def get_subscribed_stream_ids_for_user( def get_subscribed_stream_ids_for_user(
user_profile: UserProfile, user_profile: UserProfile,
) -> "ValuesQuerySet[Subscription, int]": ) -> ValuesQuerySet[Subscription, int]:
return Subscription.objects.filter( return Subscription.objects.filter(
user_profile_id=user_profile, user_profile_id=user_profile,
recipient__type=Recipient.STREAM, recipient__type=Recipient.STREAM,
@@ -63,7 +61,7 @@ def get_subscribed_stream_ids_for_user(
def get_subscribed_stream_recipient_ids_for_user( def get_subscribed_stream_recipient_ids_for_user(
user_profile: UserProfile, user_profile: UserProfile,
) -> "ValuesQuerySet[Subscription, int]": ) -> ValuesQuerySet[Subscription, int]:
return Subscription.objects.filter( return Subscription.objects.filter(
user_profile_id=user_profile, user_profile_id=user_profile,
recipient__type=Recipient.STREAM, recipient__type=Recipient.STREAM,

View File

@@ -1,16 +1,14 @@
from typing import TYPE_CHECKING, Dict, Iterable, List, Sequence, TypedDict from typing import Dict, Iterable, List, Sequence, TypedDict
from django.db import transaction from django.db import transaction
from django.db.models import QuerySet from django.db.models import QuerySet
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django_cte import With from django_cte import With
from django_stubs_ext import ValuesQuerySet
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
from zerver.models import GroupGroupMembership, Realm, UserGroup, UserGroupMembership, UserProfile from zerver.models import GroupGroupMembership, Realm, UserGroup, UserGroupMembership, UserProfile
if TYPE_CHECKING:
from django.db.models.query import _QuerySet as ValuesQuerySet
class UserGroupDict(TypedDict): class UserGroupDict(TypedDict):
id: int id: int
@@ -128,7 +126,7 @@ def create_user_group(
def get_user_group_direct_member_ids( def get_user_group_direct_member_ids(
user_group: UserGroup, user_group: UserGroup,
) -> "ValuesQuerySet[UserGroupMembership, int]": ) -> ValuesQuerySet[UserGroupMembership, int]:
return UserGroupMembership.objects.filter(user_group=user_group).values_list( return UserGroupMembership.objects.filter(user_group=user_group).values_list(
"user_profile_id", flat=True "user_profile_id", flat=True
) )

View File

@@ -50,7 +50,7 @@ from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django.utils.translation import gettext_lazy from django.utils.translation import gettext_lazy
from django_cte import CTEManager from django_cte import CTEManager
from django_stubs_ext import StrPromise from django_stubs_ext import StrPromise, ValuesQuerySet
from confirmation import settings as confirmation_settings from confirmation import settings as confirmation_settings
from zerver.lib import cache from zerver.lib import cache
@@ -119,7 +119,6 @@ STREAM_NAMES = TypeVar("STREAM_NAMES", Sequence[str], AbstractSet[str])
if TYPE_CHECKING: if TYPE_CHECKING:
# We use ModelBackend only for typing. Importing it otherwise causes circular dependency. # We use ModelBackend only for typing. Importing it otherwise causes circular dependency.
from django.contrib.auth.backends import ModelBackend from django.contrib.auth.backends import ModelBackend
from django.db.models.query import _QuerySet as ValuesQuerySet
class EmojiInfo(TypedDict): class EmojiInfo(TypedDict):
@@ -160,10 +159,10 @@ RowT = TypeVar("RowT")
def query_for_ids( def query_for_ids(
query: "ValuesQuerySet[ModelT, RowT]", query: ValuesQuerySet[ModelT, RowT],
user_ids: List[int], user_ids: List[int],
field: str, field: str,
) -> "ValuesQuerySet[ModelT, RowT]": ) -> ValuesQuerySet[ModelT, RowT]:
""" """
This function optimizes searches of the form This function optimizes searches of the form
`user_profile_id in (1, 2, 3, 4)` by quickly `user_profile_id in (1, 2, 3, 4)` by quickly
@@ -2789,7 +2788,7 @@ def get_huddle_recipient(user_profile_ids: Set[int]) -> Recipient:
return huddle.recipient return huddle.recipient
def get_huddle_user_ids(recipient: Recipient) -> "ValuesQuerySet[Subscription, int]": def get_huddle_user_ids(recipient: Recipient) -> ValuesQuerySet["Subscription", int]:
assert recipient.type == Recipient.HUDDLE assert recipient.type == Recipient.HUDDLE
return ( return (