diff --git a/zerver/tests/test_event_queue.py b/zerver/tests/test_event_queue.py index bbb2320bde..5dd0f2eac7 100644 --- a/zerver/tests/test_event_queue.py +++ b/zerver/tests/test_event_queue.py @@ -13,8 +13,8 @@ from zerver.lib.user_groups import create_user_group, remove_user_from_user_grou from zerver.models import Recipient, Stream, Subscription, UserProfile, get_stream from zerver.tornado.event_queue import ( ClientDescriptor, + access_client_descriptor, allocate_client_descriptor, - get_client_descriptor, maybe_enqueue_notifications, missedmessage_hook, persistent_queue_filename, @@ -173,7 +173,7 @@ class MissedMessageNotificationsTest(ZulipTestCase): ) self.assert_json_success(result) queue_id = orjson.loads(result.content)["queue_id"] - return get_client_descriptor(queue_id) + return access_client_descriptor(user.id, queue_id) def destroy_event_queue(user: UserProfile, queue_id: str) -> None: result = self.tornado_call(cleanup_event_queue, user, {"queue_id": queue_id}) diff --git a/zerver/tests/test_event_system.py b/zerver/tests/test_event_system.py index 433caa2f0a..8a9c3fea78 100644 --- a/zerver/tests/test_event_system.py +++ b/zerver/tests/test_event_system.py @@ -36,6 +36,7 @@ from zerver.tornado.event_queue import ( process_message_event, send_restart_events, ) +from zerver.tornado.exceptions import BadEventQueueIdError from zerver.tornado.views import get_events, get_events_backend from zerver.views.events_register import ( _default_all_public_streams, @@ -418,6 +419,52 @@ class GetEventsTest(ZulipTestCase): self.assertEqual(message["content"], "
hello
") self.assertEqual(message["avatar_url"], None) + def test_bogus_queue_id(self) -> None: + user = self.example_user("hamlet") + + with self.assertRaises(BadEventQueueIdError): + self.tornado_call( + get_events, + user, + { + "queue_id": "hamster", + "user_client": "website", + "last_event_id": -1, + "dont_block": orjson.dumps(True).decode(), + }, + ) + + def test_wrong_user_queue_id(self) -> None: + user = self.example_user("hamlet") + wrong_user = self.example_user("othello") + + result = self.tornado_call( + get_events, + user, + { + "apply_markdown": orjson.dumps(True).decode(), + "client_gravatar": orjson.dumps(True).decode(), + "event_types": orjson.dumps(["message"]).decode(), + "user_client": "website", + "dont_block": orjson.dumps(True).decode(), + }, + ) + self.assert_json_success(result) + queue_id = orjson.loads(result.content)["queue_id"] + + with self.assertLogs(level="WARNING") as cm, self.assertRaises(BadEventQueueIdError): + self.tornado_call( + get_events, + wrong_user, + { + "queue_id": queue_id, + "user_client": "website", + "last_event_id": -1, + "dont_block": orjson.dumps(True).decode(), + }, + ) + self.assertIn("not authorized for queue", cm.output[0]) + class FetchInitialStateDataTest(ZulipTestCase): # Non-admin users don't have access to all bots diff --git a/zerver/tornado/event_queue.py b/zerver/tornado/event_queue.py index cfc650c656..feca3d942a 100644 --- a/zerver/tornado/event_queue.py +++ b/zerver/tornado/event_queue.py @@ -433,11 +433,19 @@ def add_client_gc_hook(hook: Callable[[int, ClientDescriptor, bool], None]) -> N gc_hooks.append(hook) -def get_client_descriptor(queue_id: str) -> ClientDescriptor: - try: - return clients[queue_id] - except KeyError: - raise BadEventQueueIdError(queue_id) +def access_client_descriptor(user_id: int, queue_id: str) -> ClientDescriptor: + client = clients.get(queue_id) + if client is not None: + if user_id == client.user_profile_id: + return client + logging.warning( + "User %d is not authorized for queue %s (%d via %s)", + user_id, + queue_id, + client.user_profile_id, + client.current_client_name, + ) + raise BadEventQueueIdError(queue_id) def get_client_descriptors_for_user(user_profile_id: int) -> List[ClientDescriptor]: @@ -635,9 +643,7 @@ def fetch_events(query: Mapping[str, Any]) -> Dict[str, Any]: else: if last_event_id is None: raise JsonableError(_("Missing 'last_event_id' argument")) - client = get_client_descriptor(queue_id) - if user_profile_id != client.user_profile_id: - raise JsonableError(_("You are not authorized to get events from this queue")) + client = access_client_descriptor(user_profile_id, queue_id) if ( client.event_queue.newest_pruned_id is not None and last_event_id < client.event_queue.newest_pruned_id diff --git a/zerver/tornado/views.py b/zerver/tornado/views.py index 49b0ddd523..385a745978 100644 --- a/zerver/tornado/views.py +++ b/zerver/tornado/views.py @@ -18,8 +18,7 @@ from zerver.lib.validator import ( to_non_negative_int, ) from zerver.models import Client, UserProfile, get_client, get_user_profile_by_id -from zerver.tornado.event_queue import fetch_events, get_client_descriptor, process_notification -from zerver.tornado.exceptions import BadEventQueueIdError +from zerver.tornado.event_queue import access_client_descriptor, fetch_events, process_notification T = TypeVar("T") @@ -41,11 +40,7 @@ def notify(request: HttpRequest) -> HttpResponse: def cleanup_event_queue( request: HttpRequest, user_profile: UserProfile, queue_id: str = REQ() ) -> HttpResponse: - client = get_client_descriptor(str(queue_id)) - if client is None: - raise BadEventQueueIdError(queue_id) - if user_profile.id != client.user_profile_id: - raise JsonableError(_("You are not authorized to access this queue")) + client = access_client_descriptor(user_profile.id, queue_id) log_data = RequestNotes.get_notes(request).log_data assert log_data is not None log_data["extra"] = f"[{queue_id}]"