mirror of
https://github.com/zulip/zulip.git
synced 2025-11-09 08:26:11 +00:00
types: Better types for API fields.
Signed-off-by: Zixuan James Li <359101898@qq.com>
This commit is contained in:
committed by
Tim Abbott
parent
e6e975b470
commit
44ecd66eae
@@ -4,6 +4,7 @@ from django.db import transaction
|
|||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
|
|
||||||
from zerver.lib.exceptions import JsonableError
|
from zerver.lib.exceptions import JsonableError
|
||||||
|
from zerver.lib.types import APIStreamDict
|
||||||
from zerver.models import (
|
from zerver.models import (
|
||||||
DefaultStream,
|
DefaultStream,
|
||||||
DefaultStreamGroup,
|
DefaultStreamGroup,
|
||||||
@@ -183,7 +184,7 @@ def get_default_streams_for_realm(realm_id: int) -> List[Stream]:
|
|||||||
|
|
||||||
|
|
||||||
# returns default streams in JSON serializable format
|
# returns default streams in JSON serializable format
|
||||||
def streams_to_dicts_sorted(streams: List[Stream]) -> List[Dict[str, Any]]:
|
def streams_to_dicts_sorted(streams: List[Stream]) -> List[APIStreamDict]:
|
||||||
return sorted((stream.to_dict() for stream in streams), key=lambda elt: elt["name"])
|
return sorted((stream.to_dict() for stream in streams), key=lambda elt: elt["name"])
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from zerver.lib.streams import (
|
|||||||
send_stream_creation_event,
|
send_stream_creation_event,
|
||||||
)
|
)
|
||||||
from zerver.lib.subscription_info import get_subscribers_query
|
from zerver.lib.subscription_info import get_subscribers_query
|
||||||
|
from zerver.lib.types import APISubscriptionDict
|
||||||
from zerver.models import (
|
from zerver.models import (
|
||||||
ArchivedAttachment,
|
ArchivedAttachment,
|
||||||
Attachment,
|
Attachment,
|
||||||
@@ -177,19 +178,47 @@ def send_subscription_add_events(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for user_id, sub_infos in info_by_user.items():
|
for user_id, sub_infos in info_by_user.items():
|
||||||
sub_dicts = []
|
sub_dicts: List[APISubscriptionDict] = []
|
||||||
for sub_info in sub_infos:
|
for sub_info in sub_infos:
|
||||||
stream = sub_info.stream
|
stream = sub_info.stream
|
||||||
stream_info = stream_info_dict[stream.id]
|
stream_info = stream_info_dict[stream.id]
|
||||||
subscription = sub_info.sub
|
subscription = sub_info.sub
|
||||||
sub_dict = stream.to_dict()
|
stream_dict = stream.to_dict()
|
||||||
for field_name in Subscription.API_FIELDS:
|
# This is verbose as we cannot unpack existing TypedDict
|
||||||
sub_dict[field_name] = getattr(subscription, field_name)
|
# to initialize another TypedDict while making mypy happy.
|
||||||
|
# https://github.com/python/mypy/issues/5382
|
||||||
|
sub_dict = APISubscriptionDict(
|
||||||
|
# Fields from Subscription.API_FIELDS
|
||||||
|
audible_notifications=subscription.audible_notifications,
|
||||||
|
color=subscription.color,
|
||||||
|
desktop_notifications=subscription.desktop_notifications,
|
||||||
|
email_notifications=subscription.email_notifications,
|
||||||
|
is_muted=subscription.is_muted,
|
||||||
|
pin_to_top=subscription.pin_to_top,
|
||||||
|
push_notifications=subscription.push_notifications,
|
||||||
|
role=subscription.role,
|
||||||
|
wildcard_mentions_notify=subscription.wildcard_mentions_notify,
|
||||||
|
# Computed fields not present in Subscription.API_FIELDS
|
||||||
|
email_address=stream_info.email_address,
|
||||||
|
in_home_view=not subscription.is_muted,
|
||||||
|
stream_weekly_traffic=stream_info.stream_weekly_traffic,
|
||||||
|
subscribers=stream_info.subscribers,
|
||||||
|
# Fields from Stream.API_FIELDS
|
||||||
|
date_created=stream_dict["date_created"],
|
||||||
|
description=stream_dict["description"],
|
||||||
|
first_message_id=stream_dict["first_message_id"],
|
||||||
|
history_public_to_subscribers=stream_dict["history_public_to_subscribers"],
|
||||||
|
invite_only=stream_dict["invite_only"],
|
||||||
|
is_web_public=stream_dict["is_web_public"],
|
||||||
|
message_retention_days=stream_dict["message_retention_days"],
|
||||||
|
name=stream_dict["name"],
|
||||||
|
rendered_description=stream_dict["rendered_description"],
|
||||||
|
stream_id=stream_dict["stream_id"],
|
||||||
|
stream_post_policy=stream_dict["stream_post_policy"],
|
||||||
|
# Computed fields not present in Stream.API_FIELDS
|
||||||
|
is_announcement_only=stream_dict["is_announcement_only"],
|
||||||
|
)
|
||||||
|
|
||||||
sub_dict["in_home_view"] = not subscription.is_muted
|
|
||||||
sub_dict["email_address"] = stream_info.email_address
|
|
||||||
sub_dict["stream_weekly_traffic"] = stream_info.stream_weekly_traffic
|
|
||||||
sub_dict["subscribers"] = stream_info.subscribers
|
|
||||||
sub_dicts.append(sub_dict)
|
sub_dicts.append(sub_dict)
|
||||||
|
|
||||||
# Send a notification to the user who subscribed.
|
# Send a notification to the user who subscribed.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Collection, Dict, List, Optional, Set, Tuple, TypedDict, Union
|
from typing import Collection, List, Optional, Set, Tuple, TypedDict, Union
|
||||||
|
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models import Exists, OuterRef, Q
|
from django.db.models import Exists, OuterRef, Q
|
||||||
@@ -18,6 +18,7 @@ from zerver.lib.stream_subscription import (
|
|||||||
get_subscribed_stream_ids_for_user,
|
get_subscribed_stream_ids_for_user,
|
||||||
)
|
)
|
||||||
from zerver.lib.string_validation import check_stream_name
|
from zerver.lib.string_validation import check_stream_name
|
||||||
|
from zerver.lib.types import APIStreamDict
|
||||||
from zerver.models import (
|
from zerver.models import (
|
||||||
DefaultStreamGroup,
|
DefaultStreamGroup,
|
||||||
Realm,
|
Realm,
|
||||||
@@ -785,7 +786,7 @@ def get_occupied_streams(realm: Realm) -> QuerySet:
|
|||||||
return occupied_streams
|
return occupied_streams
|
||||||
|
|
||||||
|
|
||||||
def get_web_public_streams(realm: Realm) -> List[Dict[str, Any]]: # nocoverage
|
def get_web_public_streams(realm: Realm) -> List[APIStreamDict]: # nocoverage
|
||||||
query = get_web_public_streams_queryset(realm)
|
query = get_web_public_streams_queryset(realm)
|
||||||
streams = Stream.get_client_data(query)
|
streams = Stream.get_client_data(query)
|
||||||
return streams
|
return streams
|
||||||
@@ -799,7 +800,7 @@ def do_get_streams(
|
|||||||
include_all_active: bool = False,
|
include_all_active: bool = False,
|
||||||
include_default: bool = False,
|
include_default: bool = False,
|
||||||
include_owner_subscribed: bool = False,
|
include_owner_subscribed: bool = False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[APIStreamDict]:
|
||||||
# This function is only used by API clients now.
|
# This function is only used by API clients now.
|
||||||
|
|
||||||
if include_all_active and not user_profile.is_realm_admin:
|
if include_all_active and not user_profile.is_realm_admin:
|
||||||
|
|||||||
@@ -221,6 +221,49 @@ class NeverSubscribedStreamDict(TypedDict):
|
|||||||
subscribers: NotRequired[List[int]]
|
subscribers: NotRequired[List[int]]
|
||||||
|
|
||||||
|
|
||||||
|
class APIStreamDict(TypedDict):
|
||||||
|
"""Stream information provided to Zulip clients as a dictionary via API.
|
||||||
|
It should contain all the fields specified in `zerver.models.Stream.API_FIELDS`
|
||||||
|
with few exceptions and possible additional fields.
|
||||||
|
"""
|
||||||
|
|
||||||
|
date_created: int
|
||||||
|
description: str
|
||||||
|
first_message_id: Optional[int]
|
||||||
|
history_public_to_subscribers: bool
|
||||||
|
invite_only: bool
|
||||||
|
is_web_public: bool
|
||||||
|
message_retention_days: Optional[int]
|
||||||
|
name: str
|
||||||
|
rendered_description: str
|
||||||
|
stream_id: int # `stream_id`` represents `id` of the `Stream` object in `API_FIELDS`
|
||||||
|
stream_post_policy: int
|
||||||
|
# Computed fields not specified in `Stream.API_FIELDS`
|
||||||
|
is_announcement_only: bool
|
||||||
|
is_default: NotRequired[bool]
|
||||||
|
|
||||||
|
|
||||||
|
class APISubscriptionDict(APIStreamDict):
|
||||||
|
"""Similar to StreamClientDict, it should contain all the fields specified in
|
||||||
|
`zerver.models.Subscription.API_FIELDS` and several additional fields.
|
||||||
|
"""
|
||||||
|
|
||||||
|
audible_notifications: Optional[bool]
|
||||||
|
color: str
|
||||||
|
desktop_notifications: Optional[bool]
|
||||||
|
email_notifications: Optional[bool]
|
||||||
|
is_muted: bool
|
||||||
|
pin_to_top: bool
|
||||||
|
push_notifications: Optional[bool]
|
||||||
|
role: int
|
||||||
|
wildcard_mentions_notify: Optional[bool]
|
||||||
|
# Computed fields not specified in `Subscription.API_FIELDS`
|
||||||
|
email_address: str
|
||||||
|
in_home_view: bool
|
||||||
|
stream_weekly_traffic: Optional[int]
|
||||||
|
subscribers: List[int]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SubscriptionInfo:
|
class SubscriptionInfo:
|
||||||
subscriptions: List[SubscriptionStreamDict]
|
subscriptions: List[SubscriptionStreamDict]
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ from zerver.lib.exceptions import JsonableError, RateLimited
|
|||||||
from zerver.lib.pysa import mark_sanitized
|
from zerver.lib.pysa import mark_sanitized
|
||||||
from zerver.lib.timestamp import datetime_to_timestamp
|
from zerver.lib.timestamp import datetime_to_timestamp
|
||||||
from zerver.lib.types import (
|
from zerver.lib.types import (
|
||||||
|
APIStreamDict,
|
||||||
DisplayRecipientT,
|
DisplayRecipientT,
|
||||||
ExtendedFieldElement,
|
ExtendedFieldElement,
|
||||||
ExtendedValidator,
|
ExtendedValidator,
|
||||||
@@ -2483,22 +2484,25 @@ class Stream(models.Model):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_client_data(query: QuerySet) -> List[Dict[str, Any]]:
|
def get_client_data(query: QuerySet) -> List[APIStreamDict]:
|
||||||
query = query.only(*Stream.API_FIELDS)
|
query = query.only(*Stream.API_FIELDS)
|
||||||
return [row.to_dict() for row in query]
|
return [row.to_dict() for row in query]
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> APIStreamDict:
|
||||||
result = {}
|
return APIStreamDict(
|
||||||
for field_name in self.API_FIELDS:
|
date_created=datetime_to_timestamp(self.date_created),
|
||||||
if field_name == "id":
|
description=self.description,
|
||||||
result["stream_id"] = self.id
|
first_message_id=self.first_message_id,
|
||||||
continue
|
history_public_to_subscribers=self.history_public_to_subscribers,
|
||||||
elif field_name == "date_created":
|
invite_only=self.invite_only,
|
||||||
result["date_created"] = datetime_to_timestamp(self.date_created)
|
is_web_public=self.is_web_public,
|
||||||
continue
|
message_retention_days=self.message_retention_days,
|
||||||
result[field_name] = getattr(self, field_name)
|
name=self.name,
|
||||||
result["is_announcement_only"] = self.stream_post_policy == Stream.STREAM_POST_POLICY_ADMINS
|
rendered_description=self.rendered_description,
|
||||||
return result
|
stream_id=self.id,
|
||||||
|
stream_post_policy=self.stream_post_policy,
|
||||||
|
is_announcement_only=self.stream_post_policy == Stream.STREAM_POST_POLICY_ADMINS,
|
||||||
|
)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
indexes = [
|
indexes = [
|
||||||
|
|||||||
@@ -75,7 +75,12 @@ from zerver.lib.test_helpers import (
|
|||||||
queries_captured,
|
queries_captured,
|
||||||
reset_emails_in_zulip_realm,
|
reset_emails_in_zulip_realm,
|
||||||
)
|
)
|
||||||
from zerver.lib.types import NeverSubscribedStreamDict, SubscriptionInfo
|
from zerver.lib.types import (
|
||||||
|
APIStreamDict,
|
||||||
|
APISubscriptionDict,
|
||||||
|
NeverSubscribedStreamDict,
|
||||||
|
SubscriptionInfo,
|
||||||
|
)
|
||||||
from zerver.models import (
|
from zerver.models import (
|
||||||
Attachment,
|
Attachment,
|
||||||
DefaultStream,
|
DefaultStream,
|
||||||
@@ -205,6 +210,32 @@ class TestMiscStuff(ZulipTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(streams, [])
|
self.assertEqual(streams, [])
|
||||||
|
|
||||||
|
def test_api_fields(self) -> None:
|
||||||
|
"""Verify that all the fields from `Stream.API_FIELDS` and `Subscription.API_FIELDS` present
|
||||||
|
in `APIStreamDict` and `APISubscriptionDict`, respectively.
|
||||||
|
"""
|
||||||
|
expected_fields = set(Stream.API_FIELDS) | {"stream_id"}
|
||||||
|
expected_fields -= {"id"}
|
||||||
|
|
||||||
|
stream_dict_fields = set(APIStreamDict.__annotations__.keys())
|
||||||
|
computed_fields = set(["is_announcement_only", "is_default"])
|
||||||
|
|
||||||
|
self.assertEqual(stream_dict_fields - computed_fields, expected_fields)
|
||||||
|
|
||||||
|
expected_fields = set(Subscription.API_FIELDS)
|
||||||
|
|
||||||
|
subscription_dict_fields = set(APISubscriptionDict.__annotations__.keys())
|
||||||
|
computed_fields = set(
|
||||||
|
["in_home_view", "email_address", "stream_weekly_traffic", "subscribers"]
|
||||||
|
)
|
||||||
|
# `APISubscriptionDict` is a subclass of `APIStreamDict`, therefore having all the
|
||||||
|
# fields in addition to the computed fields and `Subscription.API_FIELDS` that
|
||||||
|
# need to be excluded here.
|
||||||
|
self.assertEqual(
|
||||||
|
subscription_dict_fields - computed_fields - stream_dict_fields,
|
||||||
|
expected_fields,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestCreateStreams(ZulipTestCase):
|
class TestCreateStreams(ZulipTestCase):
|
||||||
def test_creating_streams(self) -> None:
|
def test_creating_streams(self) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user