models: Tighten function signatures with generic QuerySet.

TODO: For now, we import `_QuerySet` as `ValuesQuerySet`. But there
is a convenient reexport of `ValuesQuerySet` in `django_stubs_ext`
that does the same thing. Once we get django-stubs integrated,
we should import `ValuesQuerySet` from `django_stubs_ext` instead.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li
2022-06-23 16:47:50 -04:00
committed by Tim Abbott
parent 81ab0b3615
commit 9d18845be5

View File

@@ -118,6 +118,7 @@ STREAM_NAMES = TypeVar("STREAM_NAMES", Sequence[str], AbstractSet[str])
if TYPE_CHECKING:
# We use ModelBackend only for typing. Importing it otherwise causes circular dependency.
from django.contrib.auth.backends import ModelBackend
from django.db.models.query import _QuerySet as ValuesQuerySet
class EmojiInfo(TypedDict):
@@ -799,7 +800,7 @@ class Realm(models.Model):
def get_admin_users_and_bots(
self, include_realm_owners: bool = True
) -> Sequence["UserProfile"]:
) -> QuerySet["UserProfile"]:
"""Use this in contexts where we want administrative users as well as
bots with administrator privileges, like send_event calls for
notifications to all administrator users.
@@ -809,14 +810,13 @@ class Realm(models.Model):
else:
roles = [UserProfile.ROLE_REALM_ADMINISTRATOR]
# TODO: Change return type to QuerySet[UserProfile]
return UserProfile.objects.filter(
realm=self,
is_active=True,
role__in=roles,
)
def get_human_admin_users(self, include_realm_owners: bool = True) -> QuerySet:
def get_human_admin_users(self, include_realm_owners: bool = True) -> QuerySet["UserProfile"]:
"""Use this in contexts where we want only human users with
administrative privileges, like sending an email to all of a
realm's administrators (bots don't have real email addresses).
@@ -826,7 +826,6 @@ class Realm(models.Model):
else:
roles = [UserProfile.ROLE_REALM_ADMINISTRATOR]
# TODO: Change return type to QuerySet[UserProfile]
return UserProfile.objects.filter(
realm=self,
is_bot=False,
@@ -834,7 +833,7 @@ class Realm(models.Model):
role__in=roles,
)
def get_human_billing_admin_and_realm_owner_users(self) -> QuerySet:
def get_human_billing_admin_and_realm_owner_users(self) -> QuerySet["UserProfile"]:
return UserProfile.objects.filter(
Q(role=UserProfile.ROLE_REALM_OWNER) | Q(is_billing_admin=True),
realm=self,
@@ -842,8 +841,7 @@ class Realm(models.Model):
is_active=True,
)
def get_active_users(self) -> Sequence["UserProfile"]:
# TODO: Change return type to QuerySet[UserProfile]
def get_active_users(self) -> QuerySet["UserProfile"]:
return UserProfile.objects.filter(realm=self, is_active=True).select_related()
def get_first_human_user(self) -> Optional["UserProfile"]:
@@ -857,7 +855,7 @@ class Realm(models.Model):
"""
return UserProfile.objects.filter(realm=self, is_bot=False).order_by("id").first()
def get_human_owner_users(self) -> QuerySet:
def get_human_owner_users(self) -> QuerySet["UserProfile"]:
return UserProfile.objects.filter(
realm=self, is_bot=False, role=UserProfile.ROLE_REALM_OWNER, is_active=True
)
@@ -2254,9 +2252,9 @@ class PreregistrationUser(models.Model):
def filter_to_valid_prereg_users(
query: QuerySet,
query: QuerySet[PreregistrationUser],
invite_expires_in_minutes: Union[Optional[int], UnspecifiedValue] = UnspecifiedValue(),
) -> QuerySet:
) -> QuerySet[PreregistrationUser]:
"""
If invite_expires_in_days is specified, we return only those PreregistrationUser
objects that were created at most that many days in the past.
@@ -2494,7 +2492,7 @@ class Stream(models.Model):
]
@staticmethod
def get_client_data(query: QuerySet) -> List[APIStreamDict]:
def get_client_data(query: QuerySet["Stream"]) -> List[APIStreamDict]:
query = query.only(*Stream.API_FIELDS)
return [row.to_dict() for row in query]
@@ -2641,16 +2639,14 @@ def get_realm_stream(stream_name: str, realm_id: int) -> Stream:
return Stream.objects.select_related().get(name__iexact=stream_name.strip(), realm_id=realm_id)
def get_active_streams(realm: Realm) -> QuerySet:
# TODO: Change return type to QuerySet[Stream]
# NOTE: Return value is used as a QuerySet, so cannot currently be Sequence[QuerySet]
def get_active_streams(realm: Realm) -> QuerySet[Stream]:
"""
Return all streams (including invite-only streams) that have not been deactivated.
"""
return Stream.objects.filter(realm=realm, deactivated=False)
def get_linkable_streams(realm_id: int) -> QuerySet:
def get_linkable_streams(realm_id: int) -> QuerySet[Stream]:
"""
This returns the streams that we are allowed to linkify using
something like "#frontend" in our markup. For now the business
@@ -2674,7 +2670,7 @@ def get_stream_by_id_in_realm(stream_id: int, realm: Realm) -> Stream:
def bulk_get_streams(realm: Realm, stream_names: STREAM_NAMES) -> Dict[str, Any]:
def fetch_streams_by_name(stream_names: List[str]) -> Sequence[Stream]:
def fetch_streams_by_name(stream_names: List[str]) -> QuerySet[Stream]:
#
# This should be just
#
@@ -2717,7 +2713,7 @@ def get_huddle_recipient(user_profile_ids: Set[int]) -> Recipient:
return huddle.recipient
def get_huddle_user_ids(recipient: Recipient) -> List[int]:
def get_huddle_user_ids(recipient: Recipient) -> "ValuesQuerySet[Subscription, int]":
assert recipient.type == Recipient.HUDDLE
return (
@@ -2926,8 +2922,7 @@ class Message(AbstractMessage):
]
def get_context_for_message(message: Message) -> Sequence[Message]:
# TODO: Change return type to QuerySet[Message]
def get_context_for_message(message: Message) -> QuerySet[Message]:
return Message.objects.filter(
recipient_id=message.recipient_id,
subject=message.subject,
@@ -3675,7 +3670,7 @@ def get_user_by_delivery_email(email: str, realm: Realm) -> UserProfile:
)
def get_users_by_delivery_email(emails: Set[str], realm: Realm) -> QuerySet:
def get_users_by_delivery_email(emails: Set[str], realm: Realm) -> QuerySet[UserProfile]:
"""This is similar to get_user_by_delivery_email, and
it has the same security caveats. It gets multiple
users and returns a QuerySet, since most callers
@@ -4059,7 +4054,7 @@ class DefaultStreamGroup(models.Model):
)
def get_default_stream_groups(realm: Realm) -> List[DefaultStreamGroup]:
def get_default_stream_groups(realm: Realm) -> QuerySet[DefaultStreamGroup]:
return DefaultStreamGroup.objects.filter(realm=realm)
@@ -4504,7 +4499,7 @@ class CustomProfileField(models.Model):
return f"<CustomProfileField: {self.realm} {self.name} {self.field_type} {self.order}>"
def custom_profile_fields_for_realm(realm_id: int) -> List[CustomProfileField]:
def custom_profile_fields_for_realm(realm_id: int) -> QuerySet[CustomProfileField]:
return CustomProfileField.objects.filter(realm=realm_id).order_by("order")