diff --git a/tools/test-backend b/tools/test-backend
index 23b9cc4de6..0220d14671 100755
--- a/tools/test-backend
+++ b/tools/test-backend
@@ -130,8 +130,6 @@ not_yet_fully_covered = [
     "zerver/webhooks/zapier/view.py",
     # This is hard to get test coverage for, and low value to do so
     "zerver/views/sentry.py",
-    # TODO: Add tests when the details are finalized.
-    "zerver/views/message_summary.py",
     # Cannot have coverage, as tests run in a transaction
     "zerver/lib/safe_session_cached_db.py",
     "zerver/lib/singleton_bmemcached.py",
@@ -282,6 +280,11 @@ def main() -> None:
         action="store_true",
         help="Generate Stripe test fixtures by making requests to Stripe test network",
     )
+    parser.add_argument(
+        "--generate-litellm-fixtures",
+        action="store_true",
+        help="Generate litellm test fixtures using credentials in zproject/dev-secrets.conf",
+    )
     parser.add_argument("args", nargs="*")
     parser.add_argument(
         "--ban-console-output",
@@ -390,6 +393,16 @@ def main() -> None:
         default_parallel = 1
         os.environ["GENERATE_STRIPE_FIXTURES"] = "1"
 
+    if options.generate_litellm_fixtures:
+        if full_suite:
+            suites = [
+                "zerver.tests.test_message_summary",
+            ]
+            full_suite = False
+        print("-- Forcing serial mode for generating litellm fixtures.", flush=True)
+        default_parallel = 1
+        os.environ["GENERATE_LITELLM_FIXTURES"] = "1"
+
     assert_provisioning_status_ok(options.skip_provision_check)
 
     if options.coverage:
diff --git a/zerver/lib/test_helpers.py b/zerver/lib/test_helpers.py
index da0aa0a11a..45a417ac1f 100644
--- a/zerver/lib/test_helpers.py
+++ b/zerver/lib/test_helpers.py
@@ -526,9 +526,6 @@ def write_instrumentation_reports(full_suite: bool, include_webhooks: bool) -> None:
     untested_patterns = {p.replace("\\", "") for p in pattern_cnt if pattern_cnt[p] == 0}
 
     exempt_patterns = {
-        # TODO: Add tests when we are sure about what needs to be
-        # served to the user as summary.
-        "api/v1/messages/summary",
         # We exempt some patterns that are called via Tornado.
         "api/v1/events",
         "api/v1/events/internal",
diff --git a/zerver/tests/fixtures/litellm/summary.json b/zerver/tests/fixtures/litellm/summary.json
new file mode 100644
index 0000000000..698e03d7df
--- /dev/null
+++ b/zerver/tests/fixtures/litellm/summary.json
@@ -0,0 +1,43 @@
+{
+  "model": "bedrock/meta.llama3-8b-instruct-v1:0",
+  "messages": [
+    {
+      "content": "The following is a chat conversation in the Zulip team chat app. channel: Zulip features, topic: New feature launch",
+      "role": "system"
+    },
+    {
+      "content": "[{\"sender\": \"Iago\", \"content\": \"Zulip just launched a feature to generate summary of messages.\"}, {\"sender\": \"Iago\", \"content\": \"Sounds awesome! This will greatly help me when catching up.\"}]",
+      "role": "user"
+    },
+    {
+      "content": "Succinctly summarize this conversation based only on the information provided, in up to 4 sentences, for someone who is familiar with the context. Mention key conclusions and actions, if any. Refer to specific people as appropriate. Don't use an intro phrase.",
+      "role": "user"
+    }
+  ],
+  "response": {
+    "id": "chatcmpl-a86e270f-a634-40f3-92f4-da786ccb263b",
+    "created": 1737832810,
+    "model": "meta.llama3-8b-instruct-v1:0",
+    "object": "chat.completion",
+    "system_fingerprint": null,
+    "choices": [
+      {
+        "finish_reason": "stop",
+        "index": 0,
+        "message": {
+          "content": "\n\nIago announced the launch of a new feature in Zulip, which generates summaries of messages. He expressed enthusiasm for the feature, stating it will greatly help him when catching up.",
+          "role": "assistant",
+          "tool_calls": null,
+          "function_call": null
+        }
+      }
+    ],
+    "usage": {
+      "completion_tokens": 39,
+      "prompt_tokens": 144,
+      "total_tokens": 183,
+      "completion_tokens_details": null,
+      "prompt_tokens_details": null
+    }
+  }
+}
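The fixture above captures the full request/response shape: a system message carrying channel/topic context, the conversation serialized as a JSON array, and a final user message holding the summarization instruction. As a reading aid only, here is a hypothetical helper that would assemble the same three-part `messages` payload; this function is not part of the PR, and the real prompt construction lives in `zerver/views/message_summary.py`. All string literals are copied from the recorded fixture.

```python
import json


def build_summary_messages(
    channel: str, topic: str, conversation: list[dict[str, str]]
) -> list[dict[str, str]]:
    # System message: channel/topic context for the model.
    context = (
        "The following is a chat conversation in the Zulip team chat app. "
        f"channel: {channel}, topic: {topic}"
    )
    # Final user message: the summarization instruction.
    instruction = (
        "Succinctly summarize this conversation based only on the information "
        "provided, in up to 4 sentences, for someone who is familiar with the "
        "context. Mention key conclusions and actions, if any. Refer to "
        "specific people as appropriate. Don't use an intro phrase."
    )
    return [
        {"content": context, "role": "system"},
        # The conversation itself is passed as a JSON-encoded string,
        # as seen in the fixture's second message.
        {"content": json.dumps(conversation), "role": "user"},
        {"content": instruction, "role": "user"},
    ]
```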
diff --git a/zerver/tests/test_message_summary.py b/zerver/tests/test_message_summary.py
new file mode 100644
index 0000000000..ed9411fb91
--- /dev/null
+++ b/zerver/tests/test_message_summary.py
@@ -0,0 +1,108 @@
+import os
+import warnings
+from unittest import mock
+
+import orjson
+from django.conf import settings
+from typing_extensions import override
+
+from analytics.models import UserCount
+from zerver.lib.test_classes import ZulipTestCase
+from zerver.views.message_summary import INPUT_COST_PER_GIGATOKEN, OUTPUT_COST_PER_GIGATOKEN
+
+warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
+warnings.filterwarnings("ignore", category=DeprecationWarning, module="pydantic")
+warnings.filterwarnings("ignore", category=DeprecationWarning, module="litellm")
+# Avoid network query to fetch the model cost map.
+os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+import litellm
+
+# Fixture file to store recorded responses.
+LLM_FIXTURES_FILE = "zerver/tests/fixtures/litellm/summary.json"
+
+
+class MessagesSummaryTestCase(ZulipTestCase):
+    @override
+    def setUp(self) -> None:
+        super().setUp()
+        self.user = self.example_user("iago")
+        self.topic_name = "New feature launch"
+        self.channel_name = "Zulip features"
+
+        self.login_user(self.user)
+        self.subscribe(self.user, self.channel_name)
+        content = "Zulip just launched a feature to generate summary of messages."
+        self.send_stream_message(
+            self.user, self.channel_name, content=content, topic_name=self.topic_name
+        )
+
+        content = "Sounds awesome! This will greatly help me when catching up."
+        self.send_stream_message(
+            self.user, self.channel_name, content=content, topic_name=self.topic_name
+        )
+
+        if settings.GENERATE_LITELLM_FIXTURES:  # nocoverage
+            self.patcher = mock.patch("litellm.completion", wraps=litellm.completion)
+            self.mocked_completion = self.patcher.start()
+
+    @override
+    def tearDown(self) -> None:
+        if settings.GENERATE_LITELLM_FIXTURES:  # nocoverage
+            self.patcher.stop()
+        super().tearDown()
+
+    def test_summarize_messages_in_topic(self) -> None:
+        narrow = orjson.dumps([["channel", self.channel_name], ["topic", self.topic_name]]).decode()
+
+        if settings.GENERATE_LITELLM_FIXTURES:  # nocoverage
+            # NOTE: You need to have proper credentials in zproject/dev-secrets.conf
+            # to generate the fixtures. (Tested using AWS Bedrock.)
+            # Trigger the API call to extract the arguments.
+            self.client_get("/json/messages/summary", dict(narrow=narrow))
+            call_args = self.mocked_completion.call_args
+
+            # Once we have the arguments, call the original method and save its response.
+            response = self.mocked_completion(**call_args.kwargs).json()
+            with open(LLM_FIXTURES_FILE, "wb") as f:
+                fixture_data = {
+                    # Only store model and messages.
+                    # We don't want to store any secrets.
+                    "model": call_args.kwargs["model"],
+                    "messages": call_args.kwargs["messages"],
+                    "response": response,
+                }
+                f.write(orjson.dumps(fixture_data, option=orjson.OPT_INDENT_2) + b"\n")
+            return
+
+        # In this code path, we test using the fixtures.
+        with open(LLM_FIXTURES_FILE, "rb") as f:
+            fixture_data = orjson.loads(f.read())
+
+        # Fake credentials to ensure we crash if actual network
+        # requests occur, which would reflect a problem with how the
+        # fixtures were set up.
+        with self.settings(
+            TOPIC_SUMMARIZATION_MODEL="bedrock/meta.llama3-8b-instruct-v1:0",
+            AWS_ACCESS_KEY_ID="fakeKeyID",
+            AWS_SECRET_ACCESS_KEY="fakeAccessKey",
+            AWS_REGION_NAME="ap-south-1",
+        ):
+            input_tokens = fixture_data["response"]["usage"]["prompt_tokens"]
+            output_tokens = fixture_data["response"]["usage"]["completion_tokens"]
+            credits_used = (output_tokens * OUTPUT_COST_PER_GIGATOKEN) + (
+                input_tokens * INPUT_COST_PER_GIGATOKEN
+            )
+            self.assertFalse(
+                UserCount.objects.filter(
+                    property="ai_credit_usage::day", value=credits_used, user_id=self.user.id
+                ).exists()
+            )
+            with mock.patch("litellm.completion", return_value=fixture_data["response"]):
+                payload = self.client_get("/json/messages/summary", dict(narrow=narrow))
+            self.assertEqual(payload.status_code, 200)
+            # Check that we recorded this usage.
+            self.assertTrue(
+                UserCount.objects.filter(
+                    property="ai_credit_usage::day", value=credits_used, user_id=self.user.id
+                ).exists()
+            )
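The test's two code paths form a record/replay pair: in record mode, `mock.patch(..., wraps=litellm.completion)` installs a spy that still performs the real API call while capturing the exact keyword arguments (so only the model name and messages, never credentials, end up in the fixture); in replay mode, `mock.patch(..., return_value=...)` serves the canned response with no network access. A minimal standalone sketch of the same pattern, with `call_llm` as a hypothetical stand-in for the view code:

```python
from unittest import mock

import litellm


def call_llm() -> object:
    # Hypothetical stand-in for application code that talks to the model.
    return litellm.completion(
        model="bedrock/meta.llama3-8b-instruct-v1:0",
        messages=[{"role": "user", "content": "Say hello."}],
    )


# Record mode: the spy forwards to the real API, but the call arguments
# and the genuine response can be captured and persisted for replay.
with mock.patch("litellm.completion", wraps=litellm.completion) as spy:
    response = call_llm()
    recorded = {"kwargs": spy.call_args.kwargs, "response": response.json()}

# Replay mode: the canned response is returned without any network access.
with mock.patch("litellm.completion", return_value=recorded["response"]):
    assert call_llm() == recorded["response"]
```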
diff --git a/zerver/views/message_summary.py b/zerver/views/message_summary.py
index fc3934d18a..1b1fe5c842 100644
--- a/zerver/views/message_summary.py
+++ b/zerver/views/message_summary.py
@@ -1,3 +1,6 @@
+import warnings
+
+warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
 import json
 from typing import Any
 
@@ -73,10 +76,10 @@ def get_messages_summary(
     *,
     narrow: Json[list[NarrowParameter] | None] = None,
 ) -> HttpResponse:
-    if settings.TOPIC_SUMMARIZATION_MODEL is None:
+    if settings.TOPIC_SUMMARIZATION_MODEL is None:  # nocoverage
         raise JsonableError(_("AI features are not enabled on this server."))
 
-    if not (user_profile.is_moderator or user_profile.is_realm_admin):
+    if not (user_profile.is_moderator or user_profile.is_realm_admin):  # nocoverage
         return json_success(request, {"summary": "Feature limited to moderators for now."})
 
     # TODO: This implementation does not attempt to make use of
@@ -95,7 +98,7 @@ def get_messages_summary(
         num_after=0,
     )
 
-    if len(query_info.rows) == 0:
+    if len(query_info.rows) == 0:  # nocoverage
         return json_success(request, {"summary": "No messages in conversation to summarize"})
 
     result_message_ids: list[int] = []
@@ -126,7 +129,7 @@ def get_messages_summary(
     # is primarily trained on English.
     model = settings.TOPIC_SUMMARIZATION_MODEL
     litellm_params: dict[str, Any] = {}
-    if model.startswith("huggingface"):
+    if model.startswith("huggingface"):  # nocoverage
         assert settings.HUGGINGFACE_API_KEY is not None
         litellm_params["api_key"] = settings.HUGGINGFACE_API_KEY
     else:
@@ -194,6 +197,7 @@ def get_messages_summary(
         messages=messages,
         **litellm_params,
    )
+
     input_tokens = response["usage"]["prompt_tokens"]
     output_tokens = response["usage"]["completion_tokens"]
     credits_used = (output_tokens * OUTPUT_COST_PER_GIGATOKEN) + (
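To make the accounting concrete: with the fixture's usage block (144 prompt tokens, 39 completion tokens), the charge the view records reduces to a single integer. The constants below are placeholders for illustration only; the real `INPUT_COST_PER_GIGATOKEN` and `OUTPUT_COST_PER_GIGATOKEN` values are defined in `zerver/views/message_summary.py` and are not visible in this diff.

```python
# Placeholder prices, chosen only for illustration; the real constants
# live in zerver/views/message_summary.py and may differ.
INPUT_COST_PER_GIGATOKEN = 100_000_000
OUTPUT_COST_PER_GIGATOKEN = 300_000_000


def credits_for(prompt_tokens: int, completion_tokens: int) -> int:
    # Same formula as the view: input and output tokens are priced
    # separately, in credits per 10**9 tokens.
    return (completion_tokens * OUTPUT_COST_PER_GIGATOKEN) + (
        prompt_tokens * INPUT_COST_PER_GIGATOKEN
    )


# Usage numbers from the recorded fixture:
# 39 * 300_000_000 + 144 * 100_000_000 = 26_100_000_000
print(credits_for(prompt_tokens=144, completion_tokens=39))
```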
diff --git a/zproject/computed_settings.py b/zproject/computed_settings.py
index a82b745d15..2e822f6463 100644
--- a/zproject/computed_settings.py
+++ b/zproject/computed_settings.py
@@ -179,6 +179,8 @@ RUNNING_OPENAPI_CURL_TEST = False
 # This is overridden in test_settings.py for the test suites
 GENERATE_STRIPE_FIXTURES = False
 # This is overridden in test_settings.py for the test suites
+GENERATE_LITELLM_FIXTURES = False
+# This is overridden in test_settings.py for the test suites
 BAN_CONSOLE_OUTPUT = False
 # This is overridden in test_settings.py for the test suites
 TEST_WORKER_DIR = ""
diff --git a/zproject/test_extra_settings.py b/zproject/test_extra_settings.py
index f0951ef364..9d4fc05145 100644
--- a/zproject/test_extra_settings.py
+++ b/zproject/test_extra_settings.py
@@ -55,6 +55,9 @@ if "RUNNING_OPENAPI_CURL_TEST" in os.environ:
 if "GENERATE_STRIPE_FIXTURES" in os.environ:
     GENERATE_STRIPE_FIXTURES = True
 
+if "GENERATE_LITELLM_FIXTURES" in os.environ:
+    GENERATE_LITELLM_FIXTURES = True
+
 if "BAN_CONSOLE_OUTPUT" in os.environ:
     BAN_CONSOLE_OUTPUT = True
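End to end, the plumbing reads: passing `--generate-litellm-fixtures` to `tools/test-backend` restricts the run to `zerver.tests.test_message_summary`, forces serial mode, and exports `GENERATE_LITELLM_FIXTURES=1`, which `test_extra_settings.py` turns into the Django setting the test checks. After regenerating, a quick standalone sanity check (a sketch, not part of this PR) can confirm the recorded fixture still carries every field the view and the test read:

```python
import orjson

# Load the fixture exactly as the test does.
with open("zerver/tests/fixtures/litellm/summary.json", "rb") as f:
    fixture = orjson.loads(f.read())

# The view bills against these two usage counters.
usage = fixture["response"]["usage"]
assert {"prompt_tokens", "completion_tokens"} <= usage.keys()

# And it returns this string as the topic summary.
summary = fixture["response"]["choices"][0]["message"]["content"]
assert summary.strip()
print(f"{usage['prompt_tokens']} prompt -> {usage['completion_tokens']} completion tokens")
```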