python: Avoid relying on Collection supertype of QuerySet.

QuerySet doesn’t implement __contains__, so it can’t be a subtype of
Container or Collection (https://code.djangoproject.com/ticket/35154).
This incorrect subtyping annotation was removed in
https://github.com/typeddjango/django-stubs/pull/1925, so we need to
stop relying on it before upgrading to django-stubs 5.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg
2024-04-16 20:28:33 -07:00
committed by Tim Abbott
parent 5654d051f7
commit f31579a220
9 changed files with 28 additions and 24 deletions

View File

@@ -5,7 +5,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Tuple
 from django.conf import settings
 from django.contrib.contenttypes.models import ContentType
 from django.db import transaction
-from django.db.models import Q, Sum
+from django.db.models import Q, QuerySet, Sum
 from django.utils.timezone import now as timezone_now
 from django.utils.translation import gettext as _
 from zxcvbn import zxcvbn
@@ -70,7 +70,7 @@ def do_send_confirmation_email(
     return activation_url
-def estimate_recent_invites(realms: Collection[Realm], *, days: int) -> int:
+def estimate_recent_invites(realms: Collection[Realm] | QuerySet[Realm], *, days: int) -> int:
     """An upper bound on the number of invites sent in the last `days` days"""
     recent_invites = RealmCount.objects.filter(
         realm__in=realms,

View File

@@ -238,8 +238,8 @@ def get_recipient_info(
     if recipient.type == Recipient.PERSONAL:
         # The sender and recipient may be the same id, so
         # de-duplicate using a set.
-        message_to_user_ids: Collection[int] = list({recipient.type_id, sender_id})
-        assert len(message_to_user_ids) in [1, 2]
+        message_to_user_id_set = {recipient.type_id, sender_id}
+        assert len(message_to_user_id_set) in [1, 2]
     elif recipient.type == Recipient.STREAM:
         # Anybody calling us w/r/t a stream message needs to supply
@@ -302,9 +302,9 @@ def get_recipient_info(
             .order_by("user_profile_id")
         )
-        message_to_user_ids = list()
+        message_to_user_id_set = set()
         for row in subscription_rows:
-            message_to_user_ids.append(row["user_profile_id"])
+            message_to_user_id_set.add(row["user_profile_id"])
             # We store the 'sender_muted_stream' information here to avoid db query at
             # a later stage when we perform automatically unmute topic in muted stream operation.
             if row["user_profile_id"] == sender_id:
@@ -373,21 +373,18 @@ def get_recipient_info(
         )
     elif recipient.type == Recipient.DIRECT_MESSAGE_GROUP:
-        message_to_user_ids = get_huddle_user_ids(recipient)
+        message_to_user_id_set = set(get_huddle_user_ids(recipient))
     else:
         raise ValueError("Bad recipient type")
-    message_to_user_id_set = set(message_to_user_ids)
-    user_ids = set(message_to_user_id_set)
     # Important note: Because we haven't rendered Markdown yet, we
     # don't yet know which of these possibly-mentioned users was
     # actually mentioned in the message (in other words, the
     # mention syntax might have been in a code block or otherwise
     # escaped). `get_ids_for` will filter these extra user rows
     # for our data structures not related to bots
-    user_ids |= possibly_mentioned_user_ids
+    user_ids = message_to_user_id_set | possibly_mentioned_user_ids
     if user_ids:
         query: ValuesQuerySet[UserProfile, ActiveUserDict] = UserProfile.objects.filter(

View File

@@ -1,6 +1,6 @@
 from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
-from django.db.models import Model
+from django.db.models import Model, QuerySet
 from django.utils.timezone import now as timezone_now
 from zerver.lib.create_user import create_user_profile, get_display_email_address
@@ -163,7 +163,9 @@ def bulk_create_users(
 def bulk_set_users_or_streams_recipient_fields(
     model: Type[Model],
-    objects: Union[Collection[UserProfile], Collection[Stream]],
+    objects: Union[
+        Collection[UserProfile], QuerySet[UserProfile], Collection[Stream], QuerySet[Stream]
+    ],
     recipients: Optional[Iterable[Recipient]] = None,
 ) -> None:
     assert model in [UserProfile, Stream]

View File

@@ -7,7 +7,7 @@ from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Tuple
 from django.conf import settings
 from django.db import transaction
-from django.db.models import Exists, OuterRef
+from django.db.models import Exists, OuterRef, QuerySet
 from django.utils.timezone import now as timezone_now
 from django.utils.translation import gettext as _
 from typing_extensions import TypeAlias
@@ -331,7 +331,7 @@ def get_slim_stream_id_map(realm: Realm) -> Dict[int, Stream]:
 def bulk_get_digest_context(
-    users: Collection[UserProfile], cutoff: float
+    users: Collection[UserProfile] | QuerySet[UserProfile], cutoff: float
 ) -> Iterator[Tuple[UserProfile, Dict[str, Any]]]:
     # We expect a non-empty list of users all from the same realm.
     assert users

View File

@@ -420,7 +420,10 @@ def has_message_access(
 def bulk_access_messages(
-    user_profile: UserProfile, messages: Collection[Message], *, stream: Optional[Stream] = None
+    user_profile: UserProfile,
+    messages: Collection[Message] | QuerySet[Message],
+    *,
+    stream: Optional[Stream] = None,
 ) -> List[Message]:
     """This function does the full has_message_access check for each
     message. If stream is provided, it is used to avoid unnecessary

View File

@@ -2,7 +2,7 @@ import copy
 import zlib
 from datetime import datetime
 from email.headerregistry import Address
-from typing import Any, Collection, Dict, List, Optional, TypedDict
+from typing import Any, Dict, Iterable, List, Optional, TypedDict
 import orjson
@@ -78,7 +78,7 @@ def message_to_encoded_cache(message: Message, realm_id: Optional[int] = None) -
 def update_message_cache(
-    changed_messages: Collection[Message], realm_id: Optional[int] = None
+    changed_messages: Iterable[Message], realm_id: Optional[int] = None
 ) -> List[int]:
     """Updates the message as stored in the to_dict cache (for serving
     messages)."""
@@ -273,7 +273,7 @@ class MessageDict:
     @staticmethod
     def messages_to_encoded_cache(
-        messages: Collection[Message], realm_id: Optional[int] = None
+        messages: Iterable[Message], realm_id: Optional[int] = None
     ) -> Dict[int, bytes]:
         messages_dict = MessageDict.messages_to_encoded_cache_helper(messages, realm_id)
         encoded_messages = {msg["id"]: stringify_message_dict(msg) for msg in messages_dict}
@@ -281,7 +281,7 @@ class MessageDict:
     @staticmethod
     def messages_to_encoded_cache_helper(
-        messages: Collection[Message], realm_id: Optional[int] = None
+        messages: Iterable[Message], realm_id: Optional[int] = None
     ) -> List[Dict[str, Any]]:
         # Near duplicate of the build_message_dict + get_raw_db_rows
         # code path that accepts already fetched Message objects

View File

@@ -180,7 +180,7 @@ def get_users_for_streams(stream_ids: Set[int]) -> Dict[int, Set[UserProfile]]:
 def bulk_get_subscriber_peer_info(
     realm: Realm,
-    streams: Collection[Stream],
+    streams: Collection[Stream] | QuerySet[Stream],
 ) -> SubscriberPeerInfo:
     """
     Glossary:

View File

@@ -47,6 +47,7 @@ from django.urls import resolve
 from django.utils import translation
 from django.utils.module_loading import import_string
 from django.utils.timezone import now as timezone_now
+from django_stubs_ext import ValuesQuerySet
 from fakeldap import MockLDAP
 from openapi_core.contrib.django import DjangoOpenAPIRequest, DjangoOpenAPIResponse
 from requests import PreparedRequest
@@ -1245,7 +1246,7 @@ Output:
         """
         self.assertEqual(self.get_json_error(result, status_code=status_code), msg)
-    def assert_length(self, items: Collection[Any], count: int) -> None:
+    def assert_length(self, items: Collection[Any] | ValuesQuerySet[Any, Any], count: int) -> None:
         actual_count = len(items)
         if actual_count != count:  # nocoverage
             print("\nITEMS:\n")

View File

@@ -1,10 +1,11 @@
 import logging
 from argparse import ArgumentParser
-from typing import Any, Collection
+from typing import Any
 from django.conf import settings
 from django.core.management.base import CommandError
 from django.db import transaction
+from django.db.models import QuerySet
 from typing_extensions import override
 from zerver.lib.logging_util import log_to_file
@@ -20,7 +21,7 @@ log_to_file(logger, settings.LDAP_SYNC_LOG_PATH)
 # Run this on a cron job to pick up on name changes.
 @transaction.atomic
 def sync_ldap_user_data(
-    user_profiles: Collection[UserProfile], deactivation_protection: bool = True
+    user_profiles: QuerySet[UserProfile], deactivation_protection: bool = True
 ) -> None:
     logger.info("Starting update.")
     try: