refactor: Simplify call to bulk_get_subscriber_user_ids.

The way we were computing the dictionary was very
convoluted--all we need is a set of subscribed user
ids.
This commit is contained in:
Steve Howell
2020-10-18 15:08:51 +00:00
committed by Tim Abbott
parent b58152abda
commit 4dce34ab8b
2 changed files with 11 additions and 11 deletions

View File

@@ -2628,16 +2628,19 @@ def validate_user_access_to_subscribers_helper(
def bulk_get_subscriber_user_ids(
stream_dicts: Iterable[Mapping[str, Any]],
user_profile: UserProfile,
sub_dict: Mapping[int, bool],
subscribed_stream_ids: Set[int],
) -> Dict[int, List[int]]:
"""sub_dict maps stream_id => whether the user is subscribed to that stream."""
target_stream_dicts = []
for stream_dict in stream_dicts:
stream_id = stream_dict["id"]
is_subscribed = stream_id in subscribed_stream_ids
try:
validate_user_access_to_subscribers_helper(
user_profile,
stream_dict,
lambda user_profile: sub_dict[stream_dict["id"]],
lambda user_profile: is_subscribed,
)
except JsonableError:
continue
@@ -5023,19 +5026,16 @@ def gather_subscriptions_helper(user_profile: UserProfile,
unsubscribed = []
never_subscribed = []
# Deactivated streams aren't in stream_hash.
streams = [stream_hash[sub["stream_id"]] for sub in sub_dicts
if sub["stream_id"] in stream_hash]
streams_subscribed_map = {sub["stream_id"]: sub["active"] for sub in sub_dicts}
# Add never subscribed streams to streams_subscribed_map
streams_subscribed_map.update({stream['id']: False for stream in all_streams if stream not in streams})
# The highly optimized bulk_get_subscriber_user_ids wants to know which
# streams we are subscribed to, for validation purposes, and it uses that
# info to know if it's allowed to find OTHER subscribers.
subscribed_stream_ids = {sub["stream_id"] for sub in sub_dicts if sub["active"]}
if include_subscribers:
subscriber_map: Mapping[int, Optional[List[int]]] = bulk_get_subscriber_user_ids(
all_streams,
user_profile,
streams_subscribed_map,
subscribed_stream_ids,
)
else:
# If we're not including subscribers, always return None,

View File

@@ -101,7 +101,7 @@ class TestMiscStuff(ZulipTestCase):
result = bulk_get_subscriber_user_ids(
stream_dicts=[],
user_profile=user_profile,
sub_dict={},
subscribed_stream_ids=set(),
)
self.assertEqual(result, {})