refactor: Introduce SubscriptionInfo dataclass.

We use this as the return type for
gather_subscriptions_helper and
get_web_public_subs, instead of tuples.
This commit is contained in:
Steve Howell
2021-01-14 20:44:56 +00:00
committed by Tim Abbott
parent 768117f0ff
commit f2586d2f9b
4 changed files with 87 additions and 51 deletions

View File

@@ -4,6 +4,7 @@ import logging
import os import os
import time import time
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
AbstractSet, AbstractSet,
@@ -242,6 +243,12 @@ if settings.BILLING_ENABLED:
update_license_ledger_if_needed, update_license_ledger_if_needed,
) )
@dataclass
class SubscriptionInfo:
subscriptions: List[Dict[str, Any]]
unsubscribed: List[Dict[str, Any]]
never_subscribed: List[Dict[str, Any]]
# This will be used to type annotate parameters in a function if the function # This will be used to type annotate parameters in a function if the function
# works on both str and unicode in python 2 but in python 3 it only works on str. # works on both str and unicode in python 2 but in python 3 it only works on str.
SizedTextIterable = Union[Sequence[str], AbstractSet[str]] SizedTextIterable = Union[Sequence[str], AbstractSet[str]]
@@ -4985,9 +4992,7 @@ def get_average_weekly_stream_traffic(stream_id: int, stream_date_created: datet
return round_to_2_significant_digits(average_weekly_traffic) return round_to_2_significant_digits(average_weekly_traffic)
SubHelperT = Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]] def get_web_public_subs(realm: Realm) -> SubscriptionInfo:
def get_web_public_subs(realm: Realm) -> SubHelperT:
color_idx = 0 color_idx = 0
def get_next_color() -> str: def get_next_color() -> str:
@@ -5016,7 +5021,11 @@ def get_web_public_subs(realm: Realm) -> SubHelperT:
stream_dict['email_address'] = '' stream_dict['email_address'] = ''
subscribed.append(stream_dict) subscribed.append(stream_dict)
return (subscribed, [], []) return SubscriptionInfo(
subscriptions=subscribed,
unsubscribed=[],
never_subscribed=[],
)
def build_stream_dict_for_sub( def build_stream_dict_for_sub(
user: UserProfile, user: UserProfile,
@@ -5087,8 +5096,10 @@ def build_stream_dict_for_never_sub(
# the code pretty ugly, but in this case, it has significant # the code pretty ugly, but in this case, it has significant
# performance impact for loading / for users with large numbers of # performance impact for loading / for users with large numbers of
# subscriptions, so it's worth optimizing. # subscriptions, so it's worth optimizing.
def gather_subscriptions_helper(user_profile: UserProfile, def gather_subscriptions_helper(
include_subscribers: bool=True) -> SubHelperT: user_profile: UserProfile,
include_subscribers: bool=True,
) -> SubscriptionInfo:
realm = user_profile.realm realm = user_profile.realm
all_streams = get_active_streams(realm).values( all_streams = get_active_streams(realm).values(
*Stream.API_FIELDS, *Stream.API_FIELDS,
@@ -5183,16 +5194,23 @@ def gather_subscriptions_helper(user_profile: UserProfile,
for sub in lst: for sub in lst:
sub["subscribers"] = subscriber_map[sub["stream_id"]] sub["subscribers"] = subscriber_map[sub["stream_id"]]
return (sorted(subscribed, key=lambda x: x['name']), return SubscriptionInfo(
sorted(unsubscribed, key=lambda x: x['name']), subscriptions=sorted(subscribed, key=lambda x: x['name']),
sorted(never_subscribed, key=lambda x: x['name'])) unsubscribed=sorted(unsubscribed, key=lambda x: x['name']),
never_subscribed=sorted(never_subscribed, key=lambda x: x['name']),
)
def gather_subscriptions( def gather_subscriptions(
user_profile: UserProfile, user_profile: UserProfile,
include_subscribers: bool=False, include_subscribers: bool=False,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
subscribed, unsubscribed, _ = gather_subscriptions_helper( helper_result = gather_subscriptions_helper(
user_profile, include_subscribers=include_subscribers) user_profile,
include_subscribers=include_subscribers,
)
subscribed = helper_result.subscriptions
unsubscribed = helper_result.unsubscribed
if include_subscribers: if include_subscribers:
user_ids = set() user_ids = set()

View File

@@ -336,13 +336,16 @@ def fetch_initial_state_data(
if want('subscription'): if want('subscription'):
if user_profile is not None: if user_profile is not None:
subscriptions, unsubscribed, never_subscribed = gather_subscriptions_helper( sub_info = gather_subscriptions_helper(
user_profile, include_subscribers=include_subscribers) user_profile,
include_subscribers=include_subscribers,
)
else: else:
subscriptions, unsubscribed, never_subscribed = get_web_public_subs(realm) sub_info = get_web_public_subs(realm)
state['subscriptions'] = subscriptions
state['unsubscribed'] = unsubscribed state['subscriptions'] = sub_info.subscriptions
state['never_subscribed'] = never_subscribed state['unsubscribed'] = sub_info.unsubscribed
state['never_subscribed'] = sub_info.never_subscribed
if want('update_message_flags') and want('message'): if want('update_message_flags') and want('message'):
# Keeping unread_msgs updated requires both message flag updates and # Keeping unread_msgs updated requires both message flag updates and
@@ -516,11 +519,14 @@ def apply_event(state: Dict[str, Any],
# current user changing roles, we should just do a # current user changing roles, we should just do a
# full refetch. # full refetch.
if 'never_subscribed' in state: if 'never_subscribed' in state:
subscriptions, unsubscribed, never_subscribed = gather_subscriptions_helper( sub_info = gather_subscriptions_helper(
user_profile, include_subscribers=include_subscribers) user_profile,
state['subscriptions'] = subscriptions include_subscribers=include_subscribers,
state['unsubscribed'] = unsubscribed )
state['never_subscribed'] = never_subscribed state['subscriptions'] = sub_info.subscriptions
state['unsubscribed'] = sub_info.unsubscribed
state['never_subscribed'] = sub_info.never_subscribed
if 'streams' in state: if 'streams' in state:
state['streams'] = do_get_streams(user_profile) state['streams'] = do_get_streams(user_profile)

View File

@@ -59,31 +59,30 @@ class GlobalPublicStreamTest(ZulipTestCase):
public_stream = public_streams[0] public_stream = public_streams[0]
self.assertEqual(public_stream['name'], "Rome") self.assertEqual(public_stream['name'], "Rome")
public_subs, public_unsubs, public_neversubs = get_web_public_subs(realm) info = get_web_public_subs(realm)
self.assert_length(public_subs, 1) self.assert_length(info.subscriptions, 1)
public_sub = public_subs[0] self.assertEqual(info.subscriptions[0]['name'], "Rome")
self.assertEqual(public_sub['name'], "Rome") self.assert_length(info.unsubscribed, 0)
self.assert_length(public_unsubs, 0) self.assert_length(info.never_subscribed, 0)
self.assert_length(public_neversubs, 0)
# Now add a second public stream # Now add a second public stream
test_stream = self.make_stream('Test Public Archives') test_stream = self.make_stream('Test Public Archives')
do_change_stream_web_public(test_stream, True) do_change_stream_web_public(test_stream, True)
public_streams = get_web_public_streams(realm) public_streams = get_web_public_streams(realm)
self.assert_length(public_streams, 2) self.assert_length(public_streams, 2)
public_subs, public_unsubs, public_neversubs = get_web_public_subs(realm) info = get_web_public_subs(realm)
self.assert_length(public_subs, 2) self.assert_length(info.subscriptions, 2)
self.assert_length(public_unsubs, 0) self.assert_length(info.unsubscribed, 0)
self.assert_length(public_neversubs, 0) self.assert_length(info.never_subscribed, 0)
self.assertNotEqual(public_subs[0]['color'], public_subs[1]['color']) self.assertNotEqual(info.subscriptions[0]['color'], info.subscriptions[1]['color'])
do_deactivate_stream(test_stream) do_deactivate_stream(test_stream)
public_streams = get_web_public_streams(realm) public_streams = get_web_public_streams(realm)
self.assert_length(public_streams, 1) self.assert_length(public_streams, 1)
public_subs, public_unsubs, public_neversubs = get_web_public_subs(realm) info = get_web_public_subs(realm)
self.assert_length(public_subs, 1) self.assert_length(info.subscriptions, 1)
self.assert_length(public_unsubs, 0) self.assert_length(info.unsubscribed, 0)
self.assert_length(public_neversubs, 0) self.assert_length(info.never_subscribed, 0)
class WebPublicTopicHistoryTest(ZulipTestCase): class WebPublicTopicHistoryTest(ZulipTestCase):
def test_non_existant_stream_id(self) -> None: def test_non_existant_stream_id(self) -> None:

View File

@@ -1800,7 +1800,12 @@ class DefaultStreamTest(ZulipTestCase):
# Get all the streams that Polonius has access to (subscribed + web public streams) # Get all the streams that Polonius has access to (subscribed + web public streams)
result = self.client_get("/json/streams", {"include_web_public": "true"}) result = self.client_get("/json/streams", {"include_web_public": "true"})
streams = result.json()['streams'] streams = result.json()['streams']
subscribed, unsubscribed, never_subscribed = gather_subscriptions_helper(user_profile) sub_info = gather_subscriptions_helper(user_profile)
subscribed = sub_info.subscriptions
unsubscribed = sub_info.unsubscribed
never_subscribed = sub_info.never_subscribed
self.assertEqual(len(streams), self.assertEqual(len(streams),
len(subscribed) + len(unsubscribed) + len(never_subscribed)) len(subscribed) + len(unsubscribed) + len(never_subscribed))
expected_streams = subscribed + unsubscribed + never_subscribed expected_streams = subscribed + unsubscribed + never_subscribed
@@ -2138,8 +2143,10 @@ class SubscriptionPropertiesTest(ZulipTestCase):
test_user = self.example_user("hamlet") test_user = self.example_user("hamlet")
self.login_user(test_user) self.login_user(test_user)
subscribed, unsubscribed, never_subscribed = gather_subscriptions_helper(test_user) sub_info = gather_subscriptions_helper(test_user)
not_subbed = unsubscribed + never_subscribed
not_subbed = sub_info.never_subscribed
result = self.api_post(test_user, "/api/v1/users/me/subscriptions/properties", result = self.api_post(test_user, "/api/v1/users/me/subscriptions/properties",
{"subscription_data": orjson.dumps([{"property": "color", {"subscription_data": orjson.dumps([{"property": "color",
"stream_id": not_subbed[0]["stream_id"], "stream_id": not_subbed[0]["stream_id"],
@@ -3767,10 +3774,10 @@ class SubscriptionAPITest(ZulipTestCase):
# Compare results - should be 1 stream less # Compare results - should be 1 stream less
self.assertTrue( self.assertTrue(
len(admin_before_delete[0]) == len(admin_after_delete[0]) + 1, len(admin_before_delete.subscriptions) == len(admin_after_delete.subscriptions) + 1,
'Expected exactly 1 less stream from gather_subscriptions_helper') 'Expected exactly 1 less stream from gather_subscriptions_helper')
self.assertTrue( self.assertTrue(
len(non_admin_before_delete[0]) == len(non_admin_after_delete[0]) + 1, len(non_admin_before_delete.subscriptions) == len(non_admin_after_delete.subscriptions) + 1,
'Expected exactly 1 less stream from gather_subscriptions_helper') 'Expected exactly 1 less stream from gather_subscriptions_helper')
def test_validate_user_access_to_subscribers_helper(self) -> None: def test_validate_user_access_to_subscribers_helper(self) -> None:
@@ -4303,7 +4310,7 @@ class GetSubscribersTest(ZulipTestCase):
def get_never_subscribed() -> List[Dict[str, Any]]: def get_never_subscribed() -> List[Dict[str, Any]]:
with queries_captured() as queries: with queries_captured() as queries:
sub_data = gather_subscriptions_helper(self.user_profile) sub_data = gather_subscriptions_helper(self.user_profile)
never_subscribed = sub_data[2] never_subscribed = sub_data.never_subscribed
self.assert_length(queries, 4) self.assert_length(queries, 4)
# Ignore old streams. # Ignore old streams.
@@ -4339,7 +4346,10 @@ class GetSubscribersTest(ZulipTestCase):
def test_guest_user_case() -> None: def test_guest_user_case() -> None:
self.user_profile.role = UserProfile.ROLE_GUEST self.user_profile.role = UserProfile.ROLE_GUEST
sub, unsub, never_sub = gather_subscriptions_helper(self.user_profile) helper_result = gather_subscriptions_helper(self.user_profile)
sub = helper_result.subscriptions
unsub = helper_result.unsubscribed
never_sub = helper_result.never_subscribed
# It's +1 because of the stream Rome. # It's +1 because of the stream Rome.
self.assertEqual(len(never_sub), len(web_public_streams) + 1) self.assertEqual(len(never_sub), len(web_public_streams) + 1)
@@ -4381,7 +4391,9 @@ class GetSubscribersTest(ZulipTestCase):
self.subscribe(normal_user, stream_name_unsub) self.subscribe(normal_user, stream_name_unsub)
self.subscribe(normal_user, stream_name_unsub) self.subscribe(normal_user, stream_name_unsub)
subs, unsubs, neversubs = gather_subscriptions_helper(guest_user) helper_result = gather_subscriptions_helper(guest_user)
subs = helper_result.subscriptions
neversubs = helper_result.never_subscribed
# Guest users get info about subscribed public stream's subscribers # Guest users get info about subscribed public stream's subscribers
expected_stream_exists = False expected_stream_exists = False
@@ -4417,18 +4429,18 @@ class GetSubscribersTest(ZulipTestCase):
# Test admin user gets previously subscribed private stream's subscribers. # Test admin user gets previously subscribed private stream's subscribers.
sub_data = gather_subscriptions_helper(admin_user) sub_data = gather_subscriptions_helper(admin_user)
unsubscribed_streams = sub_data[1] unsubscribed_streams = sub_data.unsubscribed
self.assertEqual(len(unsubscribed_streams), 1) self.assertEqual(len(unsubscribed_streams), 1)
self.assertEqual(len(unsubscribed_streams[0]["subscribers"]), 1) self.assertEqual(len(unsubscribed_streams[0]["subscribers"]), 1)
# Test non admin users cannot get previously subscribed private stream's subscribers. # Test non admin users cannot get previously subscribed private stream's subscribers.
sub_data = gather_subscriptions_helper(non_admin_user) sub_data = gather_subscriptions_helper(non_admin_user)
unsubscribed_streams = sub_data[1] unsubscribed_streams = sub_data.unsubscribed
self.assertEqual(len(unsubscribed_streams), 1) self.assertEqual(len(unsubscribed_streams), 1)
self.assertEqual(unsubscribed_streams[0]['subscribers'], []) self.assertEqual(unsubscribed_streams[0]['subscribers'], [])
sub_data = gather_subscriptions_helper(guest_user) sub_data = gather_subscriptions_helper(guest_user)
unsubscribed_streams = sub_data[1] unsubscribed_streams = sub_data.unsubscribed
self.assertEqual(len(unsubscribed_streams), 1) self.assertEqual(len(unsubscribed_streams), 1)
self.assertEqual(unsubscribed_streams[0]['subscribers'], []) self.assertEqual(unsubscribed_streams[0]['subscribers'], [])
@@ -4533,7 +4545,8 @@ class GetSubscribersTest(ZulipTestCase):
if they aren't subscribed or have never subscribed to that stream. if they aren't subscribed or have never subscribed to that stream.
""" """
guest_user = self.example_user("polonius") guest_user = self.example_user("polonius")
_, _, never_subscribed = gather_subscriptions_helper(guest_user, True) never_subscribed = gather_subscriptions_helper(guest_user, True).never_subscribed
# A guest user can only see never subscribed streams that are web-public. # A guest user can only see never subscribed streams that are web-public.
# For Polonius, the only web public stream that he is not subscribed at # For Polonius, the only web public stream that he is not subscribed at
# this point is Rome. # this point is Rome.
@@ -4708,10 +4721,10 @@ class NoRecipientIDsTest(ZulipTestCase):
user_profile = self.example_user('cordelia') user_profile = self.example_user('cordelia')
Subscription.objects.filter(user_profile=user_profile, recipient__type=Recipient.STREAM).delete() Subscription.objects.filter(user_profile=user_profile, recipient__type=Recipient.STREAM).delete()
subs = gather_subscriptions_helper(user_profile) subs = gather_subscriptions_helper(user_profile).subscriptions
# Checks that gather_subscriptions_helper will not return anything # Checks that gather_subscriptions_helper will not return anything
# since there will not be any recipients, without crashing. # since there will not be any recipients, without crashing.
# #
# This covers a rare corner case. # This covers a rare corner case.
self.assertEqual(len(subs[0]), 0) self.assertEqual(len(subs), 0)