mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-04 05:53:43 +00:00 
			
		
		
		
	user_message: Use INSERT ... ON CONFLICT for historical UM creation.
Rather than use a bulk insert via Django, use the faster `bulk_insert_all_ums` that we already have. This also adds a `ON CONFLICT` clause, to make the insert resilient to race conditions. There are currently two callsites, with different desired `ON CONFLICT` behaviours: - For `notify_reaction_update`, if the `UserMessage` had already been created, we would have done nothing to change it. - For `do_update_message_flags`, we would have ensured a specific bit was (un)set. Extend `create_historical_user_messages` and `bulk_insert_all_ums` to support `ON CONFLICT (...) UPDATE SET flags = ...`.
This commit is contained in:
		
				
					committed by
					
						
						Tim Abbott
					
				
			
			
				
	
			
			
			
						parent
						
							52e3c8e1b2
						
					
				
				
					commit
					7988aad159
				
			@@ -338,8 +338,9 @@ def do_update_message_flags(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            create_historical_user_messages(
 | 
					            create_historical_user_messages(
 | 
				
			||||||
                user_id=user_profile.id,
 | 
					                user_id=user_profile.id,
 | 
				
			||||||
                message_ids=historical_message_ids,
 | 
					                message_ids=list(historical_message_ids),
 | 
				
			||||||
                flags=(DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target,
 | 
					                flagattr=flagattr,
 | 
				
			||||||
 | 
					                flag_target=flag_target,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        to_update = UserMessage.objects.filter(
 | 
					        to_update = UserMessage.objects.filter(
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,8 +1,8 @@
 | 
				
			|||||||
from typing import Iterable, List
 | 
					from typing import List, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.db import connection
 | 
					from django.db import connection
 | 
				
			||||||
from psycopg2.extras import execute_values
 | 
					from psycopg2.extras import execute_values
 | 
				
			||||||
from psycopg2.sql import SQL
 | 
					from psycopg2.sql import SQL, Composable, Literal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from zerver.models import UserMessage
 | 
					from zerver.models import UserMessage
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -27,7 +27,11 @@ DEFAULT_HISTORICAL_FLAGS = UserMessage.flags.historical | UserMessage.flags.read
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def create_historical_user_messages(
 | 
					def create_historical_user_messages(
 | 
				
			||||||
    *, user_id: int, message_ids: Iterable[int], flags: int = DEFAULT_HISTORICAL_FLAGS
 | 
					    *,
 | 
				
			||||||
 | 
					    user_id: int,
 | 
				
			||||||
 | 
					    message_ids: List[int],
 | 
				
			||||||
 | 
					    flagattr: Optional[int] = None,
 | 
				
			||||||
 | 
					    flag_target: Optional[int] = None,
 | 
				
			||||||
) -> None:
 | 
					) -> None:
 | 
				
			||||||
    # Users can see and interact with messages sent to streams with
 | 
					    # Users can see and interact with messages sent to streams with
 | 
				
			||||||
    # public history for which they do not have a UserMessage because
 | 
					    # public history for which they do not have a UserMessage because
 | 
				
			||||||
@@ -36,10 +40,15 @@ def create_historical_user_messages(
 | 
				
			|||||||
    # those messages, we create UserMessage objects for those messages;
 | 
					    # those messages, we create UserMessage objects for those messages;
 | 
				
			||||||
    # these have the special historical flag which keeps track of the
 | 
					    # these have the special historical flag which keeps track of the
 | 
				
			||||||
    # fact that the user did not receive the message at the time it was sent.
 | 
					    # fact that the user did not receive the message at the time it was sent.
 | 
				
			||||||
    UserMessage.objects.bulk_create(
 | 
					    if flagattr is not None and flag_target is not None:
 | 
				
			||||||
        UserMessage(user_profile_id=user_id, message_id=message_id, flags=flags)
 | 
					        conflict = SQL(
 | 
				
			||||||
        for message_id in message_ids
 | 
					            "(user_profile_id, message_id) DO UPDATE SET flags = excluded.flags & ~ {mask} | {attr}"
 | 
				
			||||||
    )
 | 
					        ).format(mask=Literal(flagattr), attr=Literal(flag_target))
 | 
				
			||||||
 | 
					        flags = (DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        conflict = None
 | 
				
			||||||
 | 
					        flags = DEFAULT_HISTORICAL_FLAGS
 | 
				
			||||||
 | 
					    bulk_insert_all_ums([user_id], message_ids, flags, conflict)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
 | 
					def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
 | 
				
			||||||
@@ -66,7 +75,9 @@ def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
 | 
				
			|||||||
        execute_values(cursor.cursor, query, vals)
 | 
					        execute_values(cursor.cursor, query, vals)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int) -> None:
 | 
					def bulk_insert_all_ums(
 | 
				
			||||||
 | 
					    user_ids: List[int], message_ids: List[int], flags: int, conflict: Optional[Composable] = None
 | 
				
			||||||
 | 
					) -> None:
 | 
				
			||||||
    if not user_ids or not message_ids:
 | 
					    if not user_ids or not message_ids:
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -76,9 +87,9 @@ def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int)
 | 
				
			|||||||
        SELECT user_profile_id, message_id, %s AS flags
 | 
					        SELECT user_profile_id, message_id, %s AS flags
 | 
				
			||||||
          FROM UNNEST(%s) user_profile_id
 | 
					          FROM UNNEST(%s) user_profile_id
 | 
				
			||||||
          CROSS JOIN UNNEST(%s) message_id
 | 
					          CROSS JOIN UNNEST(%s) message_id
 | 
				
			||||||
        ON CONFLICT DO NOTHING
 | 
					        ON CONFLICT {conflict}
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
    )
 | 
					    ).format(conflict=conflict if conflict is not None else SQL("DO NOTHING"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with connection.cursor() as cursor:
 | 
					    with connection.cursor() as cursor:
 | 
				
			||||||
        cursor.execute(query, [flags, user_ids, message_ids])
 | 
					        cursor.execute(query, [flags, user_ids, message_ids])
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
from typing import TYPE_CHECKING, Any, List, Set
 | 
					from typing import TYPE_CHECKING, Any, List, Optional, Set
 | 
				
			||||||
from unittest import mock
 | 
					from unittest import mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import orjson
 | 
					import orjson
 | 
				
			||||||
@@ -26,6 +26,7 @@ from zerver.lib.message_cache import MessageDict
 | 
				
			|||||||
from zerver.lib.test_classes import ZulipTestCase
 | 
					from zerver.lib.test_classes import ZulipTestCase
 | 
				
			||||||
from zerver.lib.test_helpers import get_subscription, timeout_mock
 | 
					from zerver.lib.test_helpers import get_subscription, timeout_mock
 | 
				
			||||||
from zerver.lib.timeout import TimeoutExpiredError
 | 
					from zerver.lib.timeout import TimeoutExpiredError
 | 
				
			||||||
 | 
					from zerver.lib.user_message import DEFAULT_HISTORICAL_FLAGS, create_historical_user_messages
 | 
				
			||||||
from zerver.models import (
 | 
					from zerver.models import (
 | 
				
			||||||
    Message,
 | 
					    Message,
 | 
				
			||||||
    Recipient,
 | 
					    Recipient,
 | 
				
			||||||
@@ -233,6 +234,102 @@ class UnreadCountTests(ZulipTestCase):
 | 
				
			|||||||
            elif msg["id"] == self.unread_msg_ids[1]:
 | 
					            elif msg["id"] == self.unread_msg_ids[1]:
 | 
				
			||||||
                check_flags(msg["flags"], set())
 | 
					                check_flags(msg["flags"], set())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_update_flags_race(self) -> None:
 | 
				
			||||||
 | 
					        user = self.example_user("hamlet")
 | 
				
			||||||
 | 
					        self.login_user(user)
 | 
				
			||||||
 | 
					        self.unsubscribe(user, "Verona")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        first_message_id = self.send_stream_message(
 | 
				
			||||||
 | 
					            self.example_user("cordelia"),
 | 
				
			||||||
 | 
					            "Verona",
 | 
				
			||||||
 | 
					            topic_name="testing",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertFalse(
 | 
				
			||||||
 | 
					            UserMessage.objects.filter(
 | 
				
			||||||
 | 
					                user_profile_id=user.id, message_id=first_message_id
 | 
				
			||||||
 | 
					            ).exists()
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # When adjusting flags of messages that we did not receive, we
 | 
				
			||||||
 | 
					        # create UserMessage rows.
 | 
				
			||||||
 | 
					        with mock.patch(
 | 
				
			||||||
 | 
					            "zerver.actions.message_flags.create_historical_user_messages",
 | 
				
			||||||
 | 
					            wraps=create_historical_user_messages,
 | 
				
			||||||
 | 
					        ) as mock_backfill:
 | 
				
			||||||
 | 
					            result = self.client_post(
 | 
				
			||||||
 | 
					                "/json/messages/flags",
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "messages": orjson.dumps([first_message_id]).decode(),
 | 
				
			||||||
 | 
					                    "op": "add",
 | 
				
			||||||
 | 
					                    "flag": "starred",
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assert_json_success(result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            mock_backfill.assert_called_once_with(
 | 
				
			||||||
 | 
					                user_id=user.id,
 | 
				
			||||||
 | 
					                message_ids=[first_message_id],
 | 
				
			||||||
 | 
					                flagattr=UserMessage.flags.starred,
 | 
				
			||||||
 | 
					                flag_target=UserMessage.flags.starred,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=first_message_id)
 | 
				
			||||||
 | 
					            self.assertEqual(
 | 
				
			||||||
 | 
					                int(um_row.flags),
 | 
				
			||||||
 | 
					                UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # That creation may race with other things which also create
 | 
				
			||||||
 | 
					        # the UserMessage rows (e.g. reactions); ensure the end result
 | 
				
			||||||
 | 
					        # is correct still.
 | 
				
			||||||
 | 
					        def race_creation(
 | 
				
			||||||
 | 
					            *,
 | 
				
			||||||
 | 
					            user_id: int,
 | 
				
			||||||
 | 
					            message_ids: List[int],
 | 
				
			||||||
 | 
					            flagattr: Optional[int] = None,
 | 
				
			||||||
 | 
					            flag_target: Optional[int] = None,
 | 
				
			||||||
 | 
					        ) -> None:
 | 
				
			||||||
 | 
					            UserMessage.objects.create(
 | 
				
			||||||
 | 
					                user_profile_id=user_id, message_id=message_ids[0], flags=DEFAULT_HISTORICAL_FLAGS
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            create_historical_user_messages(
 | 
				
			||||||
 | 
					                user_id=user_id, message_ids=message_ids, flagattr=flagattr, flag_target=flag_target
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        second_message_id = self.send_stream_message(
 | 
				
			||||||
 | 
					            self.example_user("cordelia"),
 | 
				
			||||||
 | 
					            "Verona",
 | 
				
			||||||
 | 
					            topic_name="testing",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertFalse(
 | 
				
			||||||
 | 
					            UserMessage.objects.filter(
 | 
				
			||||||
 | 
					                user_profile_id=user.id, message_id=second_message_id
 | 
				
			||||||
 | 
					            ).exists()
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        with mock.patch(
 | 
				
			||||||
 | 
					            "zerver.actions.message_flags.create_historical_user_messages", wraps=race_creation
 | 
				
			||||||
 | 
					        ) as mock_backfill:
 | 
				
			||||||
 | 
					            result = self.client_post(
 | 
				
			||||||
 | 
					                "/json/messages/flags",
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "messages": orjson.dumps([second_message_id]).decode(),
 | 
				
			||||||
 | 
					                    "op": "add",
 | 
				
			||||||
 | 
					                    "flag": "starred",
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assert_json_success(result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            mock_backfill.assert_called_once_with(
 | 
				
			||||||
 | 
					                user_id=user.id,
 | 
				
			||||||
 | 
					                message_ids=[second_message_id],
 | 
				
			||||||
 | 
					                flagattr=UserMessage.flags.starred,
 | 
				
			||||||
 | 
					                flag_target=UserMessage.flags.starred,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=second_message_id)
 | 
				
			||||||
 | 
					            self.assertEqual(
 | 
				
			||||||
 | 
					                int(um_row.flags),
 | 
				
			||||||
 | 
					                UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_update_flags_for_narrow(self) -> None:
 | 
					    def test_update_flags_for_narrow(self) -> None:
 | 
				
			||||||
        user = self.example_user("hamlet")
 | 
					        user = self.example_user("hamlet")
 | 
				
			||||||
        self.login_user(user)
 | 
					        self.login_user(user)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user