message_flags: Filter msgs having (or not) the flag before updating.

We were blindly adding / removing flag from UserMessages without
check if they even need to be updated.

This caused server to repeatedly update flags for messages which
already had been updated, creating a confusion for other clients
like mobile.

Fixes #22164
This commit is contained in:
Aman Agrawal
2022-06-02 11:36:42 +00:00
committed by Tim Abbott
parent 0ad282c11e
commit 40fcf5a633
3 changed files with 44 additions and 13 deletions

View File

@@ -1,7 +1,8 @@
from collections import defaultdict from collections import defaultdict
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import List, Optional, Set from typing import List, Optional, Set, Tuple
from django.db import transaction
from django.db.models import F from django.db.models import F
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
@@ -229,7 +230,7 @@ def do_clear_mobile_push_notifications_for_ids(
def do_update_message_flags( def do_update_message_flags(
user_profile: UserProfile, operation: str, flag: str, messages: List[int] user_profile: UserProfile, operation: str, flag: str, messages: List[int]
) -> int: ) -> Tuple[int, List[int]]:
valid_flags = [item for item in UserMessage.flags if item not in UserMessage.NON_API_FLAGS] valid_flags = [item for item in UserMessage.flags if item not in UserMessage.NON_API_FLAGS]
if flag not in valid_flags: if flag not in valid_flags:
raise JsonableError(_("Invalid flag: '{}'").format(flag)) raise JsonableError(_("Invalid flag: '{}'").format(flag))
@@ -250,18 +251,33 @@ def do_update_message_flags(
# And then create historical UserMessage records. See the called function for more context. # And then create historical UserMessage records. See the called function for more context.
create_historical_user_messages(user_id=user_profile.id, message_ids=historical_message_ids) create_historical_user_messages(user_id=user_profile.id, message_ids=historical_message_ids)
with transaction.atomic():
if operation == "add":
msgs = (
msgs.select_for_update()
.order_by("message_id")
.extra(where=[UserMessage.where_flag_is_absent(flagattr)])
)
updated_message_ids = [um.message_id for um in msgs]
msgs.filter(message_id__in=updated_message_ids).update(flags=F("flags").bitor(flagattr))
elif operation == "remove":
msgs = (
msgs.select_for_update()
.order_by("message_id")
.extra(where=[UserMessage.where_flag_is_present(flagattr)])
)
updated_message_ids = [um.message_id for um in msgs]
msgs.filter(message_id__in=updated_message_ids).update(
flags=F("flags").bitand(~flagattr)
)
if operation == "add": count = len(updated_message_ids)
count = msgs.update(flags=F("flags").bitor(flagattr))
elif operation == "remove":
count = msgs.update(flags=F("flags").bitand(~flagattr))
event = { event = {
"type": "update_message_flags", "type": "update_message_flags",
"op": operation, "op": operation,
"operation": operation, "operation": operation,
"flag": flag, "flag": flag,
"messages": messages, "messages": updated_message_ids,
"all": False, "all": False,
} }
@@ -270,14 +286,14 @@ def do_update_message_flags(
# unread), extend the event with an additional object with # unread), extend the event with an additional object with
# details on the messages required to update the client's # details on the messages required to update the client's
# `unread_msgs` data structure. # `unread_msgs` data structure.
raw_unread_data = get_raw_unread_data(user_profile, messages) raw_unread_data = get_raw_unread_data(user_profile, updated_message_ids)
event["message_details"] = format_unread_message_details(user_profile.id, raw_unread_data) event["message_details"] = format_unread_message_details(user_profile.id, raw_unread_data)
send_event(user_profile.realm, event, [user_profile.id]) send_event(user_profile.realm, event, [user_profile.id])
if flag == "read" and operation == "add": if flag == "read" and operation == "add":
event_time = timezone_now() event_time = timezone_now()
do_clear_mobile_push_notifications_for_ids([user_profile.id], messages) do_clear_mobile_push_notifications_for_ids([user_profile.id], updated_message_ids)
do_increment_logging_stat( do_increment_logging_stat(
user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count
@@ -290,4 +306,4 @@ def do_update_message_flags(
increment=min(1, count), increment=min(1, count),
) )
return count return count, updated_message_ids

View File

@@ -651,12 +651,27 @@ class NormalActionsTest(BaseAction):
state_change_expected=True, state_change_expected=True,
) )
check_update_message_flags_add("events[0]", events[0]) check_update_message_flags_add("events[0]", events[0])
self.assert_length(events[0]["messages"], 1)
# No message_id is returned from the server if the flag is already preset.
events = self.verify_action(
lambda: do_update_message_flags(user_profile, "add", "starred", [message]),
state_change_expected=False,
)
self.assert_length(events[0]["messages"], 0)
events = self.verify_action( events = self.verify_action(
lambda: do_update_message_flags(user_profile, "remove", "starred", [message]), lambda: do_update_message_flags(user_profile, "remove", "starred", [message]),
state_change_expected=True, state_change_expected=True,
) )
check_update_message_flags_remove("events[0]", events[0]) check_update_message_flags_remove("events[0]", events[0])
self.assert_length(events[0]["messages"], 1)
# No message_id is returned from the server if the flag is already absent.
events = self.verify_action(
lambda: do_update_message_flags(user_profile, "remove", "starred", [message]),
state_change_expected=False,
)
self.assert_length(events[0]["messages"], 0)
def test_update_read_flag_removes_unread_msg_ids(self) -> None: def test_update_read_flag_removes_unread_msg_ids(self) -> None:

View File

@@ -38,13 +38,13 @@ def update_message_flags(
request_notes = RequestNotes.get_notes(request) request_notes = RequestNotes.get_notes(request)
assert request_notes.log_data is not None assert request_notes.log_data is not None
count = do_update_message_flags(user_profile, operation, flag, messages) count, updated_message_ids = do_update_message_flags(user_profile, operation, flag, messages)
target_count_str = str(len(messages)) target_count_str = str(len(messages))
log_data_str = f"[{operation} {flag}/{target_count_str}] actually {count}" log_data_str = f"[{operation} {flag}/{target_count_str}] actually {count}"
request_notes.log_data["extra"] = log_data_str request_notes.log_data["extra"] = log_data_str
return json_success(request, data={"messages": messages}) return json_success(request, data={"messages": updated_message_ids})
@has_request_variables @has_request_variables