zulip/zerver/tests/test_message_summary.py
Commit 4dfb285b5b (Alex Vandiver): message_summary: Reorganize settings and secrets.
This allows arbitrary litellm-supported backends.
2025-01-29 14:55:35 -08:00


import os
import warnings
from unittest import mock

import orjson
from django.conf import settings
from typing_extensions import override

from analytics.models import UserCount
from zerver.actions.message_summary import INPUT_COST_PER_GIGATOKEN, OUTPUT_COST_PER_GIGATOKEN
from zerver.lib.test_classes import ZulipTestCase

warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pydantic")
warnings.filterwarnings("ignore", category=DeprecationWarning, module="litellm")

# Avoid network query to fetch the model cost map.
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
import litellm
# Fixture file to store recorded responses
LLM_FIXTURES_FILE = "zerver/tests/fixtures/litellm/summary.json"


class MessagesSummaryTestCase(ZulipTestCase):
    @override
    def setUp(self) -> None:
        super().setUp()
        self.user = self.example_user("iago")
        self.topic_name = "New feature launch"
        self.channel_name = "Zulip features"
        self.login_user(self.user)
        self.subscribe(self.user, self.channel_name)
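        # Seed the topic with a short conversation for the summarization endpoint to work on.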
content = "Zulip just launched a feature to generate summary of messages."
self.send_stream_message(
self.user, self.channel_name, content=content, topic_name=self.topic_name
)
content = "Sounds awesome! This will greatly help me when catching up."
self.send_stream_message(
self.user, self.channel_name, content=content, topic_name=self.topic_name
)
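        # In fixture-generation mode, wrap the real litellm.completion so the test can
        # record the arguments it was called with and the live response it returned.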
        if settings.GENERATE_LITELLM_FIXTURES:  # nocoverage
            self.patcher = mock.patch("litellm.completion", wraps=litellm.completion)
            self.mocked_completion = self.patcher.start()

    @override
    def tearDown(self) -> None:
        if settings.GENERATE_LITELLM_FIXTURES:  # nocoverage
            self.patcher.stop()
        super().tearDown()

    def test_summarize_messages_in_topic(self) -> None:
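        # Build a narrow targeting the channel/topic populated in setUp.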
        narrow = orjson.dumps([["channel", self.channel_name], ["topic", self.topic_name]]).decode()
        if settings.GENERATE_LITELLM_FIXTURES:  # nocoverage
            # NOTE: You need to have proper credentials in zproject/dev-secrets.conf
            # to generate the fixtures. (Tested using AWS Bedrock.)
            # Trigger the API call to extract the arguments.
self.client_get("/json/messages/summary", dict(narrow=narrow))
call_args = self.mocked_completion.call_args
# Once we have the arguments, call the original method and save its response.
response = self.mocked_completion(**call_args.kwargs).json()
with open(LLM_FIXTURES_FILE, "wb") as f:
fixture_data = {
# Only store model and messages.
# We don't want to store any secrets.
"model": call_args.kwargs["model"],
"messages": call_args.kwargs["messages"],
"response": response,
}
f.write(orjson.dumps(fixture_data, option=orjson.OPT_INDENT_2) + b"\n")
return

        # In this code path, we test using the fixtures.
        with open(LLM_FIXTURES_FILE, "rb") as f:
            fixture_data = orjson.loads(f.read())

        # Fake credentials to ensure we crash if actual network
        # requests occur, which would reflect a problem with how the
        # fixtures were set up.
        with self.settings(
            TOPIC_SUMMARIZATION_MODEL="groq/llama-3.3-70b-versatile",
            TOPIC_SUMMARIZATION_API_KEY="test",
        ):
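            # Compute the expected credit usage from the fixture's recorded token counts,
            # using the same per-gigatoken cost constants as the summarization action.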
            input_tokens = fixture_data["response"]["usage"]["prompt_tokens"]
            output_tokens = fixture_data["response"]["usage"]["completion_tokens"]
            credits_used = (output_tokens * OUTPUT_COST_PER_GIGATOKEN) + (
                input_tokens * INPUT_COST_PER_GIGATOKEN
            )
            self.assertFalse(
                UserCount.objects.filter(
                    property="ai_credit_usage::day", value=credits_used, user_id=self.user.id
                ).exists()
            )
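            # Serve the recorded response instead of making a network request to the backend.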
with mock.patch("litellm.completion", return_value=fixture_data["response"]):
payload = self.client_get("/json/messages/summary", dict(narrow=narrow))
self.assertEqual(payload.status_code, 200)
# Check that we recorded this usage.
self.assertTrue(
UserCount.objects.filter(
property="ai_credit_usage::day", value=credits_used, user_id=self.user.id
).exists()
)