mirror of
https://github.com/zulip/zulip.git
synced 2025-11-09 16:37:23 +00:00
python: Avoid relying on Collection supertype of QuerySet.
QuerySet doesn’t implement __contains__, so it can’t be a subtype of Container or Collection (https://code.djangoproject.com/ticket/35154). This incorrect subtyping annotation was removed in https://github.com/typeddjango/django-stubs/pull/1925, so we need to stop relying on it before upgrading to django-stubs 5. Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
committed by
Tim Abbott
parent
5654d051f7
commit
f31579a220
@@ -5,7 +5,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Tuple,
|
|||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.contenttypes.models import ContentType
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models import Q, Sum
|
from django.db.models import Q, QuerySet, Sum
|
||||||
from django.utils.timezone import now as timezone_now
|
from django.utils.timezone import now as timezone_now
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
from zxcvbn import zxcvbn
|
from zxcvbn import zxcvbn
|
||||||
@@ -70,7 +70,7 @@ def do_send_confirmation_email(
|
|||||||
return activation_url
|
return activation_url
|
||||||
|
|
||||||
|
|
||||||
def estimate_recent_invites(realms: Collection[Realm], *, days: int) -> int:
|
def estimate_recent_invites(realms: Collection[Realm] | QuerySet[Realm], *, days: int) -> int:
|
||||||
"""An upper bound on the number of invites sent in the last `days` days"""
|
"""An upper bound on the number of invites sent in the last `days` days"""
|
||||||
recent_invites = RealmCount.objects.filter(
|
recent_invites = RealmCount.objects.filter(
|
||||||
realm__in=realms,
|
realm__in=realms,
|
||||||
|
|||||||
@@ -238,8 +238,8 @@ def get_recipient_info(
|
|||||||
if recipient.type == Recipient.PERSONAL:
|
if recipient.type == Recipient.PERSONAL:
|
||||||
# The sender and recipient may be the same id, so
|
# The sender and recipient may be the same id, so
|
||||||
# de-duplicate using a set.
|
# de-duplicate using a set.
|
||||||
message_to_user_ids: Collection[int] = list({recipient.type_id, sender_id})
|
message_to_user_id_set = {recipient.type_id, sender_id}
|
||||||
assert len(message_to_user_ids) in [1, 2]
|
assert len(message_to_user_id_set) in [1, 2]
|
||||||
|
|
||||||
elif recipient.type == Recipient.STREAM:
|
elif recipient.type == Recipient.STREAM:
|
||||||
# Anybody calling us w/r/t a stream message needs to supply
|
# Anybody calling us w/r/t a stream message needs to supply
|
||||||
@@ -302,9 +302,9 @@ def get_recipient_info(
|
|||||||
.order_by("user_profile_id")
|
.order_by("user_profile_id")
|
||||||
)
|
)
|
||||||
|
|
||||||
message_to_user_ids = list()
|
message_to_user_id_set = set()
|
||||||
for row in subscription_rows:
|
for row in subscription_rows:
|
||||||
message_to_user_ids.append(row["user_profile_id"])
|
message_to_user_id_set.add(row["user_profile_id"])
|
||||||
# We store the 'sender_muted_stream' information here to avoid db query at
|
# We store the 'sender_muted_stream' information here to avoid db query at
|
||||||
# a later stage when we perform automatically unmute topic in muted stream operation.
|
# a later stage when we perform automatically unmute topic in muted stream operation.
|
||||||
if row["user_profile_id"] == sender_id:
|
if row["user_profile_id"] == sender_id:
|
||||||
@@ -373,21 +373,18 @@ def get_recipient_info(
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif recipient.type == Recipient.DIRECT_MESSAGE_GROUP:
|
elif recipient.type == Recipient.DIRECT_MESSAGE_GROUP:
|
||||||
message_to_user_ids = get_huddle_user_ids(recipient)
|
message_to_user_id_set = set(get_huddle_user_ids(recipient))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Bad recipient type")
|
raise ValueError("Bad recipient type")
|
||||||
|
|
||||||
message_to_user_id_set = set(message_to_user_ids)
|
|
||||||
|
|
||||||
user_ids = set(message_to_user_id_set)
|
|
||||||
# Important note: Because we haven't rendered Markdown yet, we
|
# Important note: Because we haven't rendered Markdown yet, we
|
||||||
# don't yet know which of these possibly-mentioned users was
|
# don't yet know which of these possibly-mentioned users was
|
||||||
# actually mentioned in the message (in other words, the
|
# actually mentioned in the message (in other words, the
|
||||||
# mention syntax might have been in a code block or otherwise
|
# mention syntax might have been in a code block or otherwise
|
||||||
# escaped). `get_ids_for` will filter these extra user rows
|
# escaped). `get_ids_for` will filter these extra user rows
|
||||||
# for our data structures not related to bots
|
# for our data structures not related to bots
|
||||||
user_ids |= possibly_mentioned_user_ids
|
user_ids = message_to_user_id_set | possibly_mentioned_user_ids
|
||||||
|
|
||||||
if user_ids:
|
if user_ids:
|
||||||
query: ValuesQuerySet[UserProfile, ActiveUserDict] = UserProfile.objects.filter(
|
query: ValuesQuerySet[UserProfile, ActiveUserDict] = UserProfile.objects.filter(
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
from django.db.models import Model
|
from django.db.models import Model, QuerySet
|
||||||
from django.utils.timezone import now as timezone_now
|
from django.utils.timezone import now as timezone_now
|
||||||
|
|
||||||
from zerver.lib.create_user import create_user_profile, get_display_email_address
|
from zerver.lib.create_user import create_user_profile, get_display_email_address
|
||||||
@@ -163,7 +163,9 @@ def bulk_create_users(
|
|||||||
|
|
||||||
def bulk_set_users_or_streams_recipient_fields(
|
def bulk_set_users_or_streams_recipient_fields(
|
||||||
model: Type[Model],
|
model: Type[Model],
|
||||||
objects: Union[Collection[UserProfile], Collection[Stream]],
|
objects: Union[
|
||||||
|
Collection[UserProfile], QuerySet[UserProfile], Collection[Stream], QuerySet[Stream]
|
||||||
|
],
|
||||||
recipients: Optional[Iterable[Recipient]] = None,
|
recipients: Optional[Iterable[Recipient]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert model in [UserProfile, Stream]
|
assert model in [UserProfile, Stream]
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Tuple
|
|||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models import Exists, OuterRef
|
from django.db.models import Exists, OuterRef, QuerySet
|
||||||
from django.utils.timezone import now as timezone_now
|
from django.utils.timezone import now as timezone_now
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
@@ -331,7 +331,7 @@ def get_slim_stream_id_map(realm: Realm) -> Dict[int, Stream]:
|
|||||||
|
|
||||||
|
|
||||||
def bulk_get_digest_context(
|
def bulk_get_digest_context(
|
||||||
users: Collection[UserProfile], cutoff: float
|
users: Collection[UserProfile] | QuerySet[UserProfile], cutoff: float
|
||||||
) -> Iterator[Tuple[UserProfile, Dict[str, Any]]]:
|
) -> Iterator[Tuple[UserProfile, Dict[str, Any]]]:
|
||||||
# We expect a non-empty list of users all from the same realm.
|
# We expect a non-empty list of users all from the same realm.
|
||||||
assert users
|
assert users
|
||||||
|
|||||||
@@ -420,7 +420,10 @@ def has_message_access(
|
|||||||
|
|
||||||
|
|
||||||
def bulk_access_messages(
|
def bulk_access_messages(
|
||||||
user_profile: UserProfile, messages: Collection[Message], *, stream: Optional[Stream] = None
|
user_profile: UserProfile,
|
||||||
|
messages: Collection[Message] | QuerySet[Message],
|
||||||
|
*,
|
||||||
|
stream: Optional[Stream] = None,
|
||||||
) -> List[Message]:
|
) -> List[Message]:
|
||||||
"""This function does the full has_message_access check for each
|
"""This function does the full has_message_access check for each
|
||||||
message. If stream is provided, it is used to avoid unnecessary
|
message. If stream is provided, it is used to avoid unnecessary
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import copy
|
|||||||
import zlib
|
import zlib
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from email.headerregistry import Address
|
from email.headerregistry import Address
|
||||||
from typing import Any, Collection, Dict, List, Optional, TypedDict
|
from typing import Any, Dict, Iterable, List, Optional, TypedDict
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
@@ -78,7 +78,7 @@ def message_to_encoded_cache(message: Message, realm_id: Optional[int] = None) -
|
|||||||
|
|
||||||
|
|
||||||
def update_message_cache(
|
def update_message_cache(
|
||||||
changed_messages: Collection[Message], realm_id: Optional[int] = None
|
changed_messages: Iterable[Message], realm_id: Optional[int] = None
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""Updates the message as stored in the to_dict cache (for serving
|
"""Updates the message as stored in the to_dict cache (for serving
|
||||||
messages)."""
|
messages)."""
|
||||||
@@ -273,7 +273,7 @@ class MessageDict:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def messages_to_encoded_cache(
|
def messages_to_encoded_cache(
|
||||||
messages: Collection[Message], realm_id: Optional[int] = None
|
messages: Iterable[Message], realm_id: Optional[int] = None
|
||||||
) -> Dict[int, bytes]:
|
) -> Dict[int, bytes]:
|
||||||
messages_dict = MessageDict.messages_to_encoded_cache_helper(messages, realm_id)
|
messages_dict = MessageDict.messages_to_encoded_cache_helper(messages, realm_id)
|
||||||
encoded_messages = {msg["id"]: stringify_message_dict(msg) for msg in messages_dict}
|
encoded_messages = {msg["id"]: stringify_message_dict(msg) for msg in messages_dict}
|
||||||
@@ -281,7 +281,7 @@ class MessageDict:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def messages_to_encoded_cache_helper(
|
def messages_to_encoded_cache_helper(
|
||||||
messages: Collection[Message], realm_id: Optional[int] = None
|
messages: Iterable[Message], realm_id: Optional[int] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
# Near duplicate of the build_message_dict + get_raw_db_rows
|
# Near duplicate of the build_message_dict + get_raw_db_rows
|
||||||
# code path that accepts already fetched Message objects
|
# code path that accepts already fetched Message objects
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ def get_users_for_streams(stream_ids: Set[int]) -> Dict[int, Set[UserProfile]]:
|
|||||||
|
|
||||||
def bulk_get_subscriber_peer_info(
|
def bulk_get_subscriber_peer_info(
|
||||||
realm: Realm,
|
realm: Realm,
|
||||||
streams: Collection[Stream],
|
streams: Collection[Stream] | QuerySet[Stream],
|
||||||
) -> SubscriberPeerInfo:
|
) -> SubscriberPeerInfo:
|
||||||
"""
|
"""
|
||||||
Glossary:
|
Glossary:
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ from django.urls import resolve
|
|||||||
from django.utils import translation
|
from django.utils import translation
|
||||||
from django.utils.module_loading import import_string
|
from django.utils.module_loading import import_string
|
||||||
from django.utils.timezone import now as timezone_now
|
from django.utils.timezone import now as timezone_now
|
||||||
|
from django_stubs_ext import ValuesQuerySet
|
||||||
from fakeldap import MockLDAP
|
from fakeldap import MockLDAP
|
||||||
from openapi_core.contrib.django import DjangoOpenAPIRequest, DjangoOpenAPIResponse
|
from openapi_core.contrib.django import DjangoOpenAPIRequest, DjangoOpenAPIResponse
|
||||||
from requests import PreparedRequest
|
from requests import PreparedRequest
|
||||||
@@ -1245,7 +1246,7 @@ Output:
|
|||||||
"""
|
"""
|
||||||
self.assertEqual(self.get_json_error(result, status_code=status_code), msg)
|
self.assertEqual(self.get_json_error(result, status_code=status_code), msg)
|
||||||
|
|
||||||
def assert_length(self, items: Collection[Any], count: int) -> None:
|
def assert_length(self, items: Collection[Any] | ValuesQuerySet[Any, Any], count: int) -> None:
|
||||||
actual_count = len(items)
|
actual_count = len(items)
|
||||||
if actual_count != count: # nocoverage
|
if actual_count != count: # nocoverage
|
||||||
print("\nITEMS:\n")
|
print("\nITEMS:\n")
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from typing import Any, Collection
|
from typing import Any
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.management.base import CommandError
|
from django.core.management.base import CommandError
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
|
from django.db.models import QuerySet
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from zerver.lib.logging_util import log_to_file
|
from zerver.lib.logging_util import log_to_file
|
||||||
@@ -20,7 +21,7 @@ log_to_file(logger, settings.LDAP_SYNC_LOG_PATH)
|
|||||||
# Run this on a cron job to pick up on name changes.
|
# Run this on a cron job to pick up on name changes.
|
||||||
@transaction.atomic
|
@transaction.atomic
|
||||||
def sync_ldap_user_data(
|
def sync_ldap_user_data(
|
||||||
user_profiles: Collection[UserProfile], deactivation_protection: bool = True
|
user_profiles: QuerySet[UserProfile], deactivation_protection: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info("Starting update.")
|
logger.info("Starting update.")
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user