From c84ea018699a34a2261e3d4de8ee83c12e55b828 Mon Sep 17 00:00:00 2001 From: Tim Abbott Date: Wed, 12 May 2021 14:07:07 -0700 Subject: [PATCH] message: Refactor has_message_access parameters. --- zerver/lib/message.py | 8 ++--- zerver/tests/test_message_edit.py | 49 +++++++++++++++++++++++-------- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/zerver/lib/message.py b/zerver/lib/message.py index f827c14463..989e8a71a5 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -671,7 +671,7 @@ def access_message( user_message = get_usermessage_by_message_id(user_profile, message_id) - if has_message_access(user_profile, message, user_message): + if has_message_access(user_profile, message, has_user_message=user_message is not None): return (message, user_message) raise JsonableError(_("Invalid message(s)")) @@ -679,8 +679,8 @@ def access_message( def has_message_access( user_profile: UserProfile, message: Message, - user_message: Optional[UserMessage], *, + has_user_message: bool, stream: Optional[Stream] = None, is_subscribed: Optional[bool] = None, ) -> bool: @@ -693,7 +693,7 @@ def has_message_access( """ # If you have a user_message object, you have access. - if user_message is not None: + if has_user_message: return True if message.recipient.type != Recipient.STREAM: @@ -731,7 +731,7 @@ def bulk_access_messages(user_profile: UserProfile, messages: Sequence[Message]) for message in messages: user_message = get_usermessage_by_message_id(user_profile, message.id) - if has_message_access(user_profile, message, user_message): + if has_message_access(user_profile, message, has_user_message=user_message is not None): filtered_messages.append(message) return filtered_messages diff --git a/zerver/tests/test_message_edit.py b/zerver/tests/test_message_edit.py index 383dcc39cb..96fb20c0da 100644 --- a/zerver/tests/test_message_edit.py +++ b/zerver/tests/test_message_edit.py @@ -1359,18 +1359,27 @@ class EditMessageTest(ZulipTestCase): user_profile, old_stream.name, topic_name="test", content="fourth" ) - self.assertEqual( - has_message_access(guest_user, Message.objects.get(id=msg_id_to_test_acesss), None), - True, - ) self.assertEqual( has_message_access( - guest_user, Message.objects.get(id=msg_id_to_test_acesss), None, stream=old_stream + guest_user, Message.objects.get(id=msg_id_to_test_acesss), has_user_message=False ), True, ) self.assertEqual( - has_message_access(non_guest_user, Message.objects.get(id=msg_id_to_test_acesss), None), + has_message_access( + guest_user, + Message.objects.get(id=msg_id_to_test_acesss), + has_user_message=False, + stream=old_stream, + ), + True, + ) + self.assertEqual( + has_message_access( + non_guest_user, + Message.objects.get(id=msg_id_to_test_acesss), + has_user_message=False, + ), True, ) @@ -1386,11 +1395,19 @@ class EditMessageTest(ZulipTestCase): self.assert_json_success(result) self.assertEqual( - has_message_access(guest_user, Message.objects.get(id=msg_id_to_test_acesss), None), + has_message_access( + guest_user, + Message.objects.get(id=msg_id_to_test_acesss), + has_user_message=False, + ), False, ) self.assertEqual( - has_message_access(non_guest_user, Message.objects.get(id=msg_id_to_test_acesss), None), + has_message_access( + non_guest_user, + Message.objects.get(id=msg_id_to_test_acesss), + has_user_message=False, + ), True, ) self.assertEqual( @@ -1400,7 +1417,7 @@ class EditMessageTest(ZulipTestCase): has_message_access( guest_user, Message.objects.get(id=msg_id_to_test_acesss), - None, + has_user_message=False, stream=new_stream, is_subscribed=True, ), @@ -1409,14 +1426,20 @@ class EditMessageTest(ZulipTestCase): self.assertEqual( has_message_access( - guest_user, Message.objects.get(id=msg_id_to_test_acesss), None, stream=new_stream + guest_user, + Message.objects.get(id=msg_id_to_test_acesss), + has_user_message=False, + stream=new_stream, ), False, ) with self.assertRaises(AssertionError): # Raises assertion if you pass an invalid stream. has_message_access( - guest_user, Message.objects.get(id=msg_id_to_test_acesss), None, stream=old_stream + guest_user, + Message.objects.get(id=msg_id_to_test_acesss), + has_user_message=False, + stream=old_stream, ) self.assertEqual( @@ -1428,7 +1451,9 @@ class EditMessageTest(ZulipTestCase): ) self.assertEqual( has_message_access( - self.example_user("iago"), Message.objects.get(id=msg_id_to_test_acesss), None + self.example_user("iago"), + Message.objects.get(id=msg_id_to_test_acesss), + has_user_message=False, ), True, )