message_flags: Filter msgs having (or not) the flag before updating.

We were blindly adding / removing flag from UserMessages without
check if they even need to be updated.

This caused server to repeatedly update flags for messages which
already had been updated, creating a confusion for other clients
like mobile.

Fixes #22164
This commit is contained in:
Aman Agrawal
2022-06-02 11:36:42 +00:00
committed by Tim Abbott
parent 0ad282c11e
commit 40fcf5a633
3 changed files with 44 additions and 13 deletions

View File

@@ -1,7 +1,8 @@
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Set
from typing import List, Optional, Set, Tuple
from django.db import transaction
from django.db.models import F
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
@@ -229,7 +230,7 @@ def do_clear_mobile_push_notifications_for_ids(
def do_update_message_flags(
user_profile: UserProfile, operation: str, flag: str, messages: List[int]
) -> int:
) -> Tuple[int, List[int]]:
valid_flags = [item for item in UserMessage.flags if item not in UserMessage.NON_API_FLAGS]
if flag not in valid_flags:
raise JsonableError(_("Invalid flag: '{}'").format(flag))
@@ -250,18 +251,33 @@ def do_update_message_flags(
# And then create historical UserMessage records. See the called function for more context.
create_historical_user_messages(user_id=user_profile.id, message_ids=historical_message_ids)
with transaction.atomic():
if operation == "add":
msgs = (
msgs.select_for_update()
.order_by("message_id")
.extra(where=[UserMessage.where_flag_is_absent(flagattr)])
)
updated_message_ids = [um.message_id for um in msgs]
msgs.filter(message_id__in=updated_message_ids).update(flags=F("flags").bitor(flagattr))
elif operation == "remove":
msgs = (
msgs.select_for_update()
.order_by("message_id")
.extra(where=[UserMessage.where_flag_is_present(flagattr)])
)
updated_message_ids = [um.message_id for um in msgs]
msgs.filter(message_id__in=updated_message_ids).update(
flags=F("flags").bitand(~flagattr)
)
if operation == "add":
count = msgs.update(flags=F("flags").bitor(flagattr))
elif operation == "remove":
count = msgs.update(flags=F("flags").bitand(~flagattr))
count = len(updated_message_ids)
event = {
"type": "update_message_flags",
"op": operation,
"operation": operation,
"flag": flag,
"messages": messages,
"messages": updated_message_ids,
"all": False,
}
@@ -270,14 +286,14 @@ def do_update_message_flags(
# unread), extend the event with an additional object with
# details on the messages required to update the client's
# `unread_msgs` data structure.
raw_unread_data = get_raw_unread_data(user_profile, messages)
raw_unread_data = get_raw_unread_data(user_profile, updated_message_ids)
event["message_details"] = format_unread_message_details(user_profile.id, raw_unread_data)
send_event(user_profile.realm, event, [user_profile.id])
if flag == "read" and operation == "add":
event_time = timezone_now()
do_clear_mobile_push_notifications_for_ids([user_profile.id], messages)
do_clear_mobile_push_notifications_for_ids([user_profile.id], updated_message_ids)
do_increment_logging_stat(
user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count
@@ -290,4 +306,4 @@ def do_update_message_flags(
increment=min(1, count),
)
return count
return count, updated_message_ids

View File

@@ -651,12 +651,27 @@ class NormalActionsTest(BaseAction):
state_change_expected=True,
)
check_update_message_flags_add("events[0]", events[0])
self.assert_length(events[0]["messages"], 1)
# No message_id is returned from the server if the flag is already preset.
events = self.verify_action(
lambda: do_update_message_flags(user_profile, "add", "starred", [message]),
state_change_expected=False,
)
self.assert_length(events[0]["messages"], 0)
events = self.verify_action(
lambda: do_update_message_flags(user_profile, "remove", "starred", [message]),
state_change_expected=True,
)
check_update_message_flags_remove("events[0]", events[0])
self.assert_length(events[0]["messages"], 1)
# No message_id is returned from the server if the flag is already absent.
events = self.verify_action(
lambda: do_update_message_flags(user_profile, "remove", "starred", [message]),
state_change_expected=False,
)
self.assert_length(events[0]["messages"], 0)
def test_update_read_flag_removes_unread_msg_ids(self) -> None:

View File

@@ -38,13 +38,13 @@ def update_message_flags(
request_notes = RequestNotes.get_notes(request)
assert request_notes.log_data is not None
count = do_update_message_flags(user_profile, operation, flag, messages)
count, updated_message_ids = do_update_message_flags(user_profile, operation, flag, messages)
target_count_str = str(len(messages))
log_data_str = f"[{operation} {flag}/{target_count_str}] actually {count}"
request_notes.log_data["extra"] = log_data_str
return json_success(request, data={"messages": messages})
return json_success(request, data={"messages": updated_message_ids})
@has_request_variables