message: Refactor has_message_access parameters.

This commit is contained in:
Tim Abbott
2021-05-12 14:07:07 -07:00
committed by Tim Abbott
parent 7ef0d21fc2
commit c84ea01869
2 changed files with 41 additions and 16 deletions

View File

@@ -671,7 +671,7 @@ def access_message(
user_message = get_usermessage_by_message_id(user_profile, message_id) user_message = get_usermessage_by_message_id(user_profile, message_id)
if has_message_access(user_profile, message, user_message): if has_message_access(user_profile, message, has_user_message=user_message is not None):
return (message, user_message) return (message, user_message)
raise JsonableError(_("Invalid message(s)")) raise JsonableError(_("Invalid message(s)"))
@@ -679,8 +679,8 @@ def access_message(
def has_message_access( def has_message_access(
user_profile: UserProfile, user_profile: UserProfile,
message: Message, message: Message,
user_message: Optional[UserMessage],
*, *,
has_user_message: bool,
stream: Optional[Stream] = None, stream: Optional[Stream] = None,
is_subscribed: Optional[bool] = None, is_subscribed: Optional[bool] = None,
) -> bool: ) -> bool:
@@ -693,7 +693,7 @@ def has_message_access(
""" """
# If you have a user_message object, you have access. # If you have a user_message object, you have access.
if user_message is not None: if has_user_message:
return True return True
if message.recipient.type != Recipient.STREAM: if message.recipient.type != Recipient.STREAM:
@@ -731,7 +731,7 @@ def bulk_access_messages(user_profile: UserProfile, messages: Sequence[Message])
for message in messages: for message in messages:
user_message = get_usermessage_by_message_id(user_profile, message.id) user_message = get_usermessage_by_message_id(user_profile, message.id)
if has_message_access(user_profile, message, user_message): if has_message_access(user_profile, message, has_user_message=user_message is not None):
filtered_messages.append(message) filtered_messages.append(message)
return filtered_messages return filtered_messages

View File

@@ -1359,18 +1359,27 @@ class EditMessageTest(ZulipTestCase):
user_profile, old_stream.name, topic_name="test", content="fourth" user_profile, old_stream.name, topic_name="test", content="fourth"
) )
self.assertEqual(
has_message_access(guest_user, Message.objects.get(id=msg_id_to_test_acesss), None),
True,
)
self.assertEqual( self.assertEqual(
has_message_access( has_message_access(
guest_user, Message.objects.get(id=msg_id_to_test_acesss), None, stream=old_stream guest_user, Message.objects.get(id=msg_id_to_test_acesss), has_user_message=False
), ),
True, True,
) )
self.assertEqual( self.assertEqual(
has_message_access(non_guest_user, Message.objects.get(id=msg_id_to_test_acesss), None), has_message_access(
guest_user,
Message.objects.get(id=msg_id_to_test_acesss),
has_user_message=False,
stream=old_stream,
),
True,
)
self.assertEqual(
has_message_access(
non_guest_user,
Message.objects.get(id=msg_id_to_test_acesss),
has_user_message=False,
),
True, True,
) )
@@ -1386,11 +1395,19 @@ class EditMessageTest(ZulipTestCase):
self.assert_json_success(result) self.assert_json_success(result)
self.assertEqual( self.assertEqual(
has_message_access(guest_user, Message.objects.get(id=msg_id_to_test_acesss), None), has_message_access(
guest_user,
Message.objects.get(id=msg_id_to_test_acesss),
has_user_message=False,
),
False, False,
) )
self.assertEqual( self.assertEqual(
has_message_access(non_guest_user, Message.objects.get(id=msg_id_to_test_acesss), None), has_message_access(
non_guest_user,
Message.objects.get(id=msg_id_to_test_acesss),
has_user_message=False,
),
True, True,
) )
self.assertEqual( self.assertEqual(
@@ -1400,7 +1417,7 @@ class EditMessageTest(ZulipTestCase):
has_message_access( has_message_access(
guest_user, guest_user,
Message.objects.get(id=msg_id_to_test_acesss), Message.objects.get(id=msg_id_to_test_acesss),
None, has_user_message=False,
stream=new_stream, stream=new_stream,
is_subscribed=True, is_subscribed=True,
), ),
@@ -1409,14 +1426,20 @@ class EditMessageTest(ZulipTestCase):
self.assertEqual( self.assertEqual(
has_message_access( has_message_access(
guest_user, Message.objects.get(id=msg_id_to_test_acesss), None, stream=new_stream guest_user,
Message.objects.get(id=msg_id_to_test_acesss),
has_user_message=False,
stream=new_stream,
), ),
False, False,
) )
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
# Raises assertion if you pass an invalid stream. # Raises assertion if you pass an invalid stream.
has_message_access( has_message_access(
guest_user, Message.objects.get(id=msg_id_to_test_acesss), None, stream=old_stream guest_user,
Message.objects.get(id=msg_id_to_test_acesss),
has_user_message=False,
stream=old_stream,
) )
self.assertEqual( self.assertEqual(
@@ -1428,7 +1451,9 @@ class EditMessageTest(ZulipTestCase):
) )
self.assertEqual( self.assertEqual(
has_message_access( has_message_access(
self.example_user("iago"), Message.objects.get(id=msg_id_to_test_acesss), None self.example_user("iago"),
Message.objects.get(id=msg_id_to_test_acesss),
has_user_message=False,
), ),
True, True,
) )