types: Better types for API fields.

Signed-off-by: Zixuan James Li <359101898@qq.com>
This commit is contained in:
Zixuan James Li
2022-05-25 20:51:35 -04:00
committed by Tim Abbott
parent e6e975b470
commit 44ecd66eae
6 changed files with 135 additions and 26 deletions

View File

@@ -4,6 +4,7 @@ from django.db import transaction
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
from zerver.lib.types import APIStreamDict
from zerver.models import ( from zerver.models import (
DefaultStream, DefaultStream,
DefaultStreamGroup, DefaultStreamGroup,
@@ -183,7 +184,7 @@ def get_default_streams_for_realm(realm_id: int) -> List[Stream]:
# returns default streams in JSON serializable format # returns default streams in JSON serializable format
def streams_to_dicts_sorted(streams: List[Stream]) -> List[Dict[str, Any]]: def streams_to_dicts_sorted(streams: List[Stream]) -> List[APIStreamDict]:
return sorted((stream.to_dict() for stream in streams), key=lambda elt: elt["name"]) return sorted((stream.to_dict() for stream in streams), key=lambda elt: elt["name"])

View File

@@ -46,6 +46,7 @@ from zerver.lib.streams import (
send_stream_creation_event, send_stream_creation_event,
) )
from zerver.lib.subscription_info import get_subscribers_query from zerver.lib.subscription_info import get_subscribers_query
from zerver.lib.types import APISubscriptionDict
from zerver.models import ( from zerver.models import (
ArchivedAttachment, ArchivedAttachment,
Attachment, Attachment,
@@ -177,19 +178,47 @@ def send_subscription_add_events(
) )
for user_id, sub_infos in info_by_user.items(): for user_id, sub_infos in info_by_user.items():
sub_dicts = [] sub_dicts: List[APISubscriptionDict] = []
for sub_info in sub_infos: for sub_info in sub_infos:
stream = sub_info.stream stream = sub_info.stream
stream_info = stream_info_dict[stream.id] stream_info = stream_info_dict[stream.id]
subscription = sub_info.sub subscription = sub_info.sub
sub_dict = stream.to_dict() stream_dict = stream.to_dict()
for field_name in Subscription.API_FIELDS: # This is verbose as we cannot unpack existing TypedDict
sub_dict[field_name] = getattr(subscription, field_name) # to initialize another TypedDict while making mypy happy.
# https://github.com/python/mypy/issues/5382
sub_dict = APISubscriptionDict(
# Fields from Subscription.API_FIELDS
audible_notifications=subscription.audible_notifications,
color=subscription.color,
desktop_notifications=subscription.desktop_notifications,
email_notifications=subscription.email_notifications,
is_muted=subscription.is_muted,
pin_to_top=subscription.pin_to_top,
push_notifications=subscription.push_notifications,
role=subscription.role,
wildcard_mentions_notify=subscription.wildcard_mentions_notify,
# Computed fields not present in Subscription.API_FIELDS
email_address=stream_info.email_address,
in_home_view=not subscription.is_muted,
stream_weekly_traffic=stream_info.stream_weekly_traffic,
subscribers=stream_info.subscribers,
# Fields from Stream.API_FIELDS
date_created=stream_dict["date_created"],
description=stream_dict["description"],
first_message_id=stream_dict["first_message_id"],
history_public_to_subscribers=stream_dict["history_public_to_subscribers"],
invite_only=stream_dict["invite_only"],
is_web_public=stream_dict["is_web_public"],
message_retention_days=stream_dict["message_retention_days"],
name=stream_dict["name"],
rendered_description=stream_dict["rendered_description"],
stream_id=stream_dict["stream_id"],
stream_post_policy=stream_dict["stream_post_policy"],
# Computed fields not present in Stream.API_FIELDS
is_announcement_only=stream_dict["is_announcement_only"],
)
sub_dict["in_home_view"] = not subscription.is_muted
sub_dict["email_address"] = stream_info.email_address
sub_dict["stream_weekly_traffic"] = stream_info.stream_weekly_traffic
sub_dict["subscribers"] = stream_info.subscribers
sub_dicts.append(sub_dict) sub_dicts.append(sub_dict)
# Send a notification to the user who subscribed. # Send a notification to the user who subscribed.

View File

@@ -1,4 +1,4 @@
from typing import Any, Collection, Dict, List, Optional, Set, Tuple, TypedDict, Union from typing import Collection, List, Optional, Set, Tuple, TypedDict, Union
from django.db import transaction from django.db import transaction
from django.db.models import Exists, OuterRef, Q from django.db.models import Exists, OuterRef, Q
@@ -18,6 +18,7 @@ from zerver.lib.stream_subscription import (
get_subscribed_stream_ids_for_user, get_subscribed_stream_ids_for_user,
) )
from zerver.lib.string_validation import check_stream_name from zerver.lib.string_validation import check_stream_name
from zerver.lib.types import APIStreamDict
from zerver.models import ( from zerver.models import (
DefaultStreamGroup, DefaultStreamGroup,
Realm, Realm,
@@ -785,7 +786,7 @@ def get_occupied_streams(realm: Realm) -> QuerySet:
return occupied_streams return occupied_streams
def get_web_public_streams(realm: Realm) -> List[Dict[str, Any]]: # nocoverage def get_web_public_streams(realm: Realm) -> List[APIStreamDict]: # nocoverage
query = get_web_public_streams_queryset(realm) query = get_web_public_streams_queryset(realm)
streams = Stream.get_client_data(query) streams = Stream.get_client_data(query)
return streams return streams
@@ -799,7 +800,7 @@ def do_get_streams(
include_all_active: bool = False, include_all_active: bool = False,
include_default: bool = False, include_default: bool = False,
include_owner_subscribed: bool = False, include_owner_subscribed: bool = False,
) -> List[Dict[str, Any]]: ) -> List[APIStreamDict]:
# This function is only used by API clients now. # This function is only used by API clients now.
if include_all_active and not user_profile.is_realm_admin: if include_all_active and not user_profile.is_realm_admin:

View File

@@ -221,6 +221,49 @@ class NeverSubscribedStreamDict(TypedDict):
subscribers: NotRequired[List[int]] subscribers: NotRequired[List[int]]
class APIStreamDict(TypedDict):
"""Stream information provided to Zulip clients as a dictionary via API.
It should contain all the fields specified in `zerver.models.Stream.API_FIELDS`
with few exceptions and possible additional fields.
"""
date_created: int
description: str
first_message_id: Optional[int]
history_public_to_subscribers: bool
invite_only: bool
is_web_public: bool
message_retention_days: Optional[int]
name: str
rendered_description: str
stream_id: int # `stream_id`` represents `id` of the `Stream` object in `API_FIELDS`
stream_post_policy: int
# Computed fields not specified in `Stream.API_FIELDS`
is_announcement_only: bool
is_default: NotRequired[bool]
class APISubscriptionDict(APIStreamDict):
"""Similar to StreamClientDict, it should contain all the fields specified in
`zerver.models.Subscription.API_FIELDS` and several additional fields.
"""
audible_notifications: Optional[bool]
color: str
desktop_notifications: Optional[bool]
email_notifications: Optional[bool]
is_muted: bool
pin_to_top: bool
push_notifications: Optional[bool]
role: int
wildcard_mentions_notify: Optional[bool]
# Computed fields not specified in `Subscription.API_FIELDS`
email_address: str
in_home_view: bool
stream_weekly_traffic: Optional[int]
subscribers: List[int]
@dataclass @dataclass
class SubscriptionInfo: class SubscriptionInfo:
subscriptions: List[SubscriptionStreamDict] subscriptions: List[SubscriptionStreamDict]

View File

@@ -83,6 +83,7 @@ from zerver.lib.exceptions import JsonableError, RateLimited
from zerver.lib.pysa import mark_sanitized from zerver.lib.pysa import mark_sanitized
from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.types import ( from zerver.lib.types import (
APIStreamDict,
DisplayRecipientT, DisplayRecipientT,
ExtendedFieldElement, ExtendedFieldElement,
ExtendedValidator, ExtendedValidator,
@@ -2483,22 +2484,25 @@ class Stream(models.Model):
] ]
@staticmethod @staticmethod
def get_client_data(query: QuerySet) -> List[Dict[str, Any]]: def get_client_data(query: QuerySet) -> List[APIStreamDict]:
query = query.only(*Stream.API_FIELDS) query = query.only(*Stream.API_FIELDS)
return [row.to_dict() for row in query] return [row.to_dict() for row in query]
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> APIStreamDict:
result = {} return APIStreamDict(
for field_name in self.API_FIELDS: date_created=datetime_to_timestamp(self.date_created),
if field_name == "id": description=self.description,
result["stream_id"] = self.id first_message_id=self.first_message_id,
continue history_public_to_subscribers=self.history_public_to_subscribers,
elif field_name == "date_created": invite_only=self.invite_only,
result["date_created"] = datetime_to_timestamp(self.date_created) is_web_public=self.is_web_public,
continue message_retention_days=self.message_retention_days,
result[field_name] = getattr(self, field_name) name=self.name,
result["is_announcement_only"] = self.stream_post_policy == Stream.STREAM_POST_POLICY_ADMINS rendered_description=self.rendered_description,
return result stream_id=self.id,
stream_post_policy=self.stream_post_policy,
is_announcement_only=self.stream_post_policy == Stream.STREAM_POST_POLICY_ADMINS,
)
class Meta: class Meta:
indexes = [ indexes = [

View File

@@ -75,7 +75,12 @@ from zerver.lib.test_helpers import (
queries_captured, queries_captured,
reset_emails_in_zulip_realm, reset_emails_in_zulip_realm,
) )
from zerver.lib.types import NeverSubscribedStreamDict, SubscriptionInfo from zerver.lib.types import (
APIStreamDict,
APISubscriptionDict,
NeverSubscribedStreamDict,
SubscriptionInfo,
)
from zerver.models import ( from zerver.models import (
Attachment, Attachment,
DefaultStream, DefaultStream,
@@ -205,6 +210,32 @@ class TestMiscStuff(ZulipTestCase):
) )
self.assertEqual(streams, []) self.assertEqual(streams, [])
def test_api_fields(self) -> None:
"""Verify that all the fields from `Stream.API_FIELDS` and `Subscription.API_FIELDS` present
in `APIStreamDict` and `APISubscriptionDict`, respectively.
"""
expected_fields = set(Stream.API_FIELDS) | {"stream_id"}
expected_fields -= {"id"}
stream_dict_fields = set(APIStreamDict.__annotations__.keys())
computed_fields = set(["is_announcement_only", "is_default"])
self.assertEqual(stream_dict_fields - computed_fields, expected_fields)
expected_fields = set(Subscription.API_FIELDS)
subscription_dict_fields = set(APISubscriptionDict.__annotations__.keys())
computed_fields = set(
["in_home_view", "email_address", "stream_weekly_traffic", "subscribers"]
)
# `APISubscriptionDict` is a subclass of `APIStreamDict`, therefore having all the
# fields in addition to the computed fields and `Subscription.API_FIELDS` that
# need to be excluded here.
self.assertEqual(
subscription_dict_fields - computed_fields - stream_dict_fields,
expected_fields,
)
class TestCreateStreams(ZulipTestCase): class TestCreateStreams(ZulipTestCase):
def test_creating_streams(self) -> None: def test_creating_streams(self) -> None: