diff --git a/zerver/lib/message.py b/zerver/lib/message.py index d93bf082be..1080ddd075 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -41,7 +41,8 @@ from zerver.models import ( Subscription, UserProfile, UserMessage, - Reaction + Reaction, + get_usermessage_by_message_id, ) from typing import Any, Dict, List, Optional, Set, Tuple, Union @@ -475,7 +476,7 @@ class ReactionDict: 'full_name': row['user_profile__full_name']}} -def access_message(user_profile: UserProfile, message_id: int) -> Tuple[Message, UserMessage]: +def access_message(user_profile: UserProfile, message_id: int) -> Tuple[Message, Optional[UserMessage]]: """You can access a message by ID in our APIs that either: (1) You received or have previously accessed via starring (aka have a UserMessage row for). @@ -489,11 +490,7 @@ def access_message(user_profile: UserProfile, message_id: int) -> Tuple[Message, except Message.DoesNotExist: raise JsonableError(_("Invalid message(s)")) - try: - user_message = UserMessage.objects.select_related().get(user_profile=user_profile, - message=message) - except UserMessage.DoesNotExist: - user_message = None + user_message = get_usermessage_by_message_id(user_profile, message_id) if user_message is None: if message.recipient.type != Recipient.STREAM: diff --git a/zerver/models.py b/zerver/models.py index ac7470bf67..c451c76641 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -1558,6 +1558,12 @@ class AbstractUserMessage(models.Model): class UserMessage(AbstractUserMessage): message = models.ForeignKey(Message, on_delete=CASCADE) # type: Message +def get_usermessage_by_message_id(user_profile: UserProfile, message_id: int) -> Optional[UserMessage]: + try: + return UserMessage.objects.select_related().get(user_profile=user_profile, + message__id=message_id) + except UserMessage.DoesNotExist: + return None class ArchivedUserMessage(AbstractUserMessage): """Used as a temporary holding place for deleted UserMessages objects diff --git a/zerver/views/home.py b/zerver/views/home.py index 7acd1e16cb..3335580844 100644 --- a/zerver/views/home.py +++ b/zerver/views/home.py @@ -15,7 +15,7 @@ from zerver.models import Message, UserProfile, Stream, Subscription, Huddle, \ Recipient, Realm, UserMessage, DefaultStream, RealmEmoji, RealmDomain, \ RealmFilter, PreregistrationUser, UserActivity, \ UserPresence, get_stream_recipient, name_changes_disabled, email_to_username, \ - get_realm_domains + get_realm_domains, get_usermessage_by_message_id from zerver.lib.events import do_events_register from zerver.lib.actions import update_user_presence, do_change_tos_version, \ do_update_pointer, realm_user_count @@ -157,13 +157,10 @@ def home_real(request: HttpRequest) -> HttpResponse: if user_profile.pointer == -1: latest_read = None else: - try: - latest_read = UserMessage.objects.get(user_profile=user_profile, - message__id=user_profile.pointer) - except UserMessage.DoesNotExist: + latest_read = get_usermessage_by_message_id(user_profile, user_profile.pointer) + if latest_read is None: # Don't completely fail if your saved pointer ID is invalid logging.warning("%s has invalid pointer %s" % (user_profile.email, user_profile.pointer)) - latest_read = None # We pick a language for the user as follows: # * First priority is the language in the URL, for debugging. diff --git a/zerver/views/pointer.py b/zerver/views/pointer.py index 8e9bcb4de9..19c57e0b30 100644 --- a/zerver/views/pointer.py +++ b/zerver/views/pointer.py @@ -6,7 +6,7 @@ from zerver.decorator import to_non_negative_int from zerver.lib.actions import do_update_pointer from zerver.lib.request import has_request_variables, JsonableError, REQ from zerver.lib.response import json_success -from zerver.models import UserProfile, UserMessage +from zerver.models import UserProfile, UserMessage, get_usermessage_by_message_id def get_pointer_backend(request: HttpRequest, user_profile: UserProfile) -> HttpResponse: return json_success({'pointer': user_profile.pointer}) @@ -17,12 +17,7 @@ def update_pointer_backend(request: HttpRequest, user_profile: UserProfile, if pointer <= user_profile.pointer: return json_success() - try: - UserMessage.objects.get( - user_profile=user_profile, - message__id=pointer - ) - except UserMessage.DoesNotExist: + if get_usermessage_by_message_id(user_profile, pointer) is None: raise JsonableError(_("Invalid message ID")) request._log_data["extra"] = "[%s]" % (pointer,)