mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-03 21:43:21 +00:00 
			
		
		
		
	This commit adds `durable=True` to the short independent db transaction in `UserGroupRaceConditionTestCase.tearDown` method. This is as a part of our plan to explicitly mark all the transaction.atomic calls with either 'savepoint=False' or 'durable=True' as required.
		
			
				
	
	
		
			218 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			218 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import threading
 | 
						|
from typing import Any
 | 
						|
from unittest import mock
 | 
						|
 | 
						|
import orjson
 | 
						|
from django.db import OperationalError, connections, transaction
 | 
						|
from django.http import HttpRequest
 | 
						|
from typing_extensions import override
 | 
						|
 | 
						|
from zerver.actions.user_groups import add_subgroups_to_user_group, check_add_user_group
 | 
						|
from zerver.lib.exceptions import JsonableError
 | 
						|
from zerver.lib.test_classes import ZulipTransactionTestCase
 | 
						|
from zerver.lib.test_helpers import HostRequestMock
 | 
						|
from zerver.lib.user_groups import access_user_group_for_update
 | 
						|
from zerver.models import NamedUserGroup, Realm, UserGroup, UserProfile
 | 
						|
from zerver.models.realms import get_realm
 | 
						|
from zerver.views.user_groups import update_subgroups_of_user_group
 | 
						|
 | 
						|
BARRIER: threading.Barrier | None = None
 | 
						|
 | 
						|
 | 
						|
def dev_update_subgroups(
 | 
						|
    request: HttpRequest,
 | 
						|
    user_profile: UserProfile,
 | 
						|
    user_group_id: int,
 | 
						|
) -> str | None:
 | 
						|
    # The test is expected to set up the barrier before accessing this endpoint.
 | 
						|
    assert BARRIER is not None
 | 
						|
    try:
 | 
						|
        with (
 | 
						|
            mock.patch("zerver.lib.user_groups.access_user_group_for_update") as m,
 | 
						|
        ):
 | 
						|
 | 
						|
            def wait_after_recursive_query(*args: Any, **kwargs: Any) -> UserGroup:
 | 
						|
                # When updating the subgroups, we access the supergroup group
 | 
						|
                # only after finishing the recursive query.
 | 
						|
                BARRIER.wait()
 | 
						|
                return access_user_group_for_update(*args, **kwargs)
 | 
						|
 | 
						|
            m.side_effect = wait_after_recursive_query
 | 
						|
 | 
						|
            update_subgroups_of_user_group(request, user_profile, user_group_id=user_group_id)
 | 
						|
    except OperationalError as err:
 | 
						|
        msg = str(err)
 | 
						|
        if "deadlock detected" in msg:
 | 
						|
            return "Deadlock detected"
 | 
						|
        else:
 | 
						|
            assert "could not obtain lock" in msg
 | 
						|
            # This error is possible when nowait is set the True, which only
 | 
						|
            # applies to the recursive query on the subgroups. Because the
 | 
						|
            # recursive query fails, this thread must have not waited on the
 | 
						|
            # barrier yet.
 | 
						|
            BARRIER.wait()
 | 
						|
            return "Busy lock detected"
 | 
						|
    except (
 | 
						|
        threading.BrokenBarrierError
 | 
						|
    ):  # nocoverage # This is only possible when timeout happens or there is a programming error
 | 
						|
        raise JsonableError(
 | 
						|
            "Broken barrier. The tester should make sure that the exact number of parties have waited on the barrier set by the previous immediate set_sync_after_first_lock call"
 | 
						|
        )
 | 
						|
 | 
						|
    return None
 | 
						|
 | 
						|
 | 
						|
class UserGroupRaceConditionTestCase(ZulipTransactionTestCase):
 | 
						|
    created_user_groups: list[NamedUserGroup] = []
 | 
						|
    counter = 0
 | 
						|
    CHAIN_LENGTH = 3
 | 
						|
 | 
						|
    @override
 | 
						|
    def tearDown(self) -> None:
 | 
						|
        # Clean up the user groups created to minimize leakage
 | 
						|
        with transaction.atomic(durable=True):
 | 
						|
            for group in self.created_user_groups:
 | 
						|
                # can_manage_group can be deleted as long as it's the
 | 
						|
                # default group_creator. If we start using non-default
 | 
						|
                # can_manage_group in this test, deleting that group
 | 
						|
                # should be reconsidered.
 | 
						|
                can_manage_group = group.can_manage_group
 | 
						|
                can_add_members_group = group.can_add_members_group
 | 
						|
                group.delete()
 | 
						|
                can_manage_group.delete()
 | 
						|
                can_add_members_group.delete()
 | 
						|
            transaction.on_commit(self.created_user_groups.clear)
 | 
						|
 | 
						|
        super().tearDown()
 | 
						|
 | 
						|
    def create_user_group_chain(self, realm: Realm) -> list[NamedUserGroup]:
 | 
						|
        """Build a user groups forming a chain through group-group memberships
 | 
						|
        returning a list where each group is the supergroup of its subsequent group.
 | 
						|
        """
 | 
						|
        iago = self.example_user("iago")
 | 
						|
        groups = [
 | 
						|
            check_add_user_group(realm, f"chain #{self.counter + i}", [], acting_user=iago)
 | 
						|
            for i in range(self.CHAIN_LENGTH)
 | 
						|
        ]
 | 
						|
        self.counter += self.CHAIN_LENGTH
 | 
						|
        self.created_user_groups.extend(groups)
 | 
						|
        prev_group = groups[0]
 | 
						|
        for group in groups[1:]:
 | 
						|
            add_subgroups_to_user_group(prev_group, [group], acting_user=iago)
 | 
						|
            prev_group = group
 | 
						|
        return groups
 | 
						|
 | 
						|
    def test_lock_subgroups_with_respect_to_supergroup(self) -> None:
 | 
						|
        realm = get_realm("zulip")
 | 
						|
        self.login("iago")
 | 
						|
        iago = self.example_user("iago")
 | 
						|
 | 
						|
        class RacingThread(threading.Thread):
 | 
						|
            def __init__(
 | 
						|
                self,
 | 
						|
                subgroup_ids: list[int],
 | 
						|
                supergroup_id: int,
 | 
						|
            ) -> None:
 | 
						|
                threading.Thread.__init__(self)
 | 
						|
                self.response: str | None = None
 | 
						|
                self.subgroup_ids = subgroup_ids
 | 
						|
                self.supergroup_id = supergroup_id
 | 
						|
 | 
						|
            @override
 | 
						|
            def run(self) -> None:
 | 
						|
                try:
 | 
						|
                    self.response = dev_update_subgroups(
 | 
						|
                        HostRequestMock({"add": orjson.dumps(self.subgroup_ids).decode()}),
 | 
						|
                        iago,
 | 
						|
                        user_group_id=self.supergroup_id,
 | 
						|
                    )
 | 
						|
                finally:
 | 
						|
                    # Close all thread-local database connections
 | 
						|
                    connections.close_all()
 | 
						|
 | 
						|
        def assert_thread_success_count(
 | 
						|
            t1: RacingThread,
 | 
						|
            t2: RacingThread,
 | 
						|
            *,
 | 
						|
            success_count: int,
 | 
						|
            error_message: str = "",
 | 
						|
        ) -> None:
 | 
						|
            help_msg = """We access the test endpoint that wraps around the
 | 
						|
real subgroup update endpoint by synchronizing them after the acquisition of the
 | 
						|
first lock in the critical region. Though unlikely, this test might fail as we
 | 
						|
have no control over the scheduler when the barrier timeouts.
 | 
						|
""".strip()
 | 
						|
            global BARRIER
 | 
						|
            BARRIER = threading.Barrier(parties=2, timeout=3)
 | 
						|
            t1.start()
 | 
						|
            t2.start()
 | 
						|
 | 
						|
            succeeded = 0
 | 
						|
            for t in [t1, t2]:
 | 
						|
                t.join()
 | 
						|
                response = t.response
 | 
						|
                if response is None:
 | 
						|
                    succeeded += 1
 | 
						|
                    continue
 | 
						|
 | 
						|
                self.assertEqual(response, error_message)
 | 
						|
            # Race condition resolution should only allow one thread to succeed
 | 
						|
            self.assertEqual(
 | 
						|
                succeeded,
 | 
						|
                success_count,
 | 
						|
                f"Exactly {success_count} thread(s) should succeed.\n{help_msg}",
 | 
						|
            )
 | 
						|
 | 
						|
        foo_chain = self.create_user_group_chain(realm)
 | 
						|
        bar_chain = self.create_user_group_chain(realm)
 | 
						|
        # These two threads are conflicting because a cycle would be formed if
 | 
						|
        # both of them succeed. There is a deadlock in such circular dependency.
 | 
						|
        assert_thread_success_count(
 | 
						|
            RacingThread(
 | 
						|
                subgroup_ids=[foo_chain[0].id],
 | 
						|
                supergroup_id=bar_chain[-1].id,
 | 
						|
            ),
 | 
						|
            RacingThread(
 | 
						|
                subgroup_ids=[bar_chain[-1].id],
 | 
						|
                supergroup_id=foo_chain[0].id,
 | 
						|
            ),
 | 
						|
            success_count=1,
 | 
						|
            error_message="Deadlock detected",
 | 
						|
        )
 | 
						|
 | 
						|
        foo_chain = self.create_user_group_chain(realm)
 | 
						|
        bar_chain = self.create_user_group_chain(realm)
 | 
						|
        # These two requests would succeed if they didn't race with each other.
 | 
						|
        # However, both threads will attempt to grab a lock on overlapping rows
 | 
						|
        # when they first do the recursive query for subgroups. In this case, we
 | 
						|
        # expect that one of the threads fails due to nowait=True for the
 | 
						|
        # .select_for_update() call.
 | 
						|
        assert_thread_success_count(
 | 
						|
            RacingThread(
 | 
						|
                subgroup_ids=[foo_chain[0].id],
 | 
						|
                supergroup_id=bar_chain[-1].id,
 | 
						|
            ),
 | 
						|
            RacingThread(
 | 
						|
                subgroup_ids=[foo_chain[1].id],
 | 
						|
                supergroup_id=bar_chain[-1].id,
 | 
						|
            ),
 | 
						|
            success_count=1,
 | 
						|
            error_message="Busy lock detected",
 | 
						|
        )
 | 
						|
 | 
						|
        foo_chain = self.create_user_group_chain(realm)
 | 
						|
        bar_chain = self.create_user_group_chain(realm)
 | 
						|
        baz_chain = self.create_user_group_chain(realm)
 | 
						|
        # Adding non-conflicting subgroups should succeed.
 | 
						|
        assert_thread_success_count(
 | 
						|
            RacingThread(
 | 
						|
                subgroup_ids=[foo_chain[1].id, foo_chain[2].id, baz_chain[2].id],
 | 
						|
                supergroup_id=baz_chain[0].id,
 | 
						|
            ),
 | 
						|
            RacingThread(
 | 
						|
                subgroup_ids=[bar_chain[1].id, bar_chain[2].id],
 | 
						|
                supergroup_id=baz_chain[0].id,
 | 
						|
            ),
 | 
						|
            success_count=2,
 | 
						|
        )
 |