message_summary: Add API endpoint to generate narrow summary.

This prototype API is disabled in production through settings not
configuring a default model.
This commit is contained in:
Aman Agrawal
2025-01-06 12:18:17 +05:30
committed by Tim Abbott
parent 6a5c33788d
commit b047c4d322
6 changed files with 246 additions and 0 deletions

View File

@@ -851,6 +851,9 @@ def get_count_stats(realm: Realm | None = None) -> dict[str, CountStat]:
), ),
CountStat.DAY, CountStat.DAY,
), ),
# AI credit usage stats for users, in units of $1/10^9, which is safe for
# aggregation because we're using bigints for the values.
LoggingCountStat("ai_credit_usage::day", UserCount, CountStat.DAY),
# Counts the number of active users in the UserProfile.is_active sense. # Counts the number of active users in the UserProfile.is_active sense.
# Important that this stay a daily stat, so that 'realm_active_humans::day' works as expected. # Important that this stay a daily stat, so that 'realm_active_humans::day' works as expected.
CountStat( CountStat(

View File

@@ -130,6 +130,8 @@ not_yet_fully_covered = [
"zerver/webhooks/zapier/view.py", "zerver/webhooks/zapier/view.py",
# This is hard to get test coverage for, and low value to do so # This is hard to get test coverage for, and low value to do so
"zerver/views/sentry.py", "zerver/views/sentry.py",
# TODO: Add tests when the details are finalized.
"zerver/views/message_summary.py",
# Cannot have coverage, as tests run in a transaction # Cannot have coverage, as tests run in a transaction
"zerver/lib/safe_session_cached_db.py", "zerver/lib/safe_session_cached_db.py",
"zerver/lib/singleton_bmemcached.py", "zerver/lib/singleton_bmemcached.py",

View File

@@ -526,6 +526,9 @@ def write_instrumentation_reports(full_suite: bool, include_webhooks: bool) -> N
untested_patterns = {p.replace("\\", "") for p in pattern_cnt if pattern_cnt[p] == 0} untested_patterns = {p.replace("\\", "") for p in pattern_cnt if pattern_cnt[p] == 0}
exempt_patterns = { exempt_patterns = {
# TODO: Add tests when we are sure about what needs to be
# served to the user as summary.
"api/v1/messages/summary",
# We exempt some patterns that are called via Tornado. # We exempt some patterns that are called via Tornado.
"api/v1/events", "api/v1/events",
"api/v1/events/internal", "api/v1/events/internal",

View File

@@ -0,0 +1,204 @@
import json
from typing import Any
from django.conf import settings
from django.http import HttpRequest, HttpResponse
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
from pydantic import Json
from analytics.lib.counts import COUNT_STATS, do_increment_logging_stat
from zerver.lib.exceptions import JsonableError
from zerver.lib.message import messages_for_ids
from zerver.lib.narrow import (
LARGER_THAN_MAX_MESSAGE_ID,
NarrowParameter,
clean_narrow_for_message_fetch,
fetch_messages,
)
from zerver.lib.response import json_success
from zerver.lib.typed_endpoint import typed_endpoint
from zerver.models import UserProfile
# Maximum number of messages that can be summarized in a single request.
# Also serves as the `num_before` bound when fetching the narrow's messages.
MAX_MESSAGES_SUMMARIZED = 100
# Price per token for input and output tokens.
# These values are based on the pricing of the Bedrock API
# for Llama 3.3 Instruct (70B).
# https://aws.amazon.com/bedrock/pricing/
# Unit: USD per 1 billion tokens (so stored credit values are in
# units of $1/10^9, matching the "ai_credit_usage::day" CountStat).
#
# These values likely will want to be declared in configuration,
# rather than here in the code.
OUTPUT_COST_PER_GIGATOKEN = 720
INPUT_COST_PER_GIGATOKEN = 720
def format_zulip_messages_for_model(zulip_messages: list[dict[str, Any]]) -> str:
    """Serialize a conversation into the JSON string sent to the model.

    Each message is reduced to just the sender's full name and its
    content:

    - The recipient is omitted, since it is identical for every message
      in the conversation.
    - Full names (not mention syntax) are used, so the model refers to
      users by name; silent-mention syntax may be worth trying if we
      move to Markdown-rendering the model output.
    - Timestamps are omitted; experiments suggest current models do not
      make relevant use of them.
    - Reaction metadata is not included, as we have not found a useful
      encoding for it.
    - Polls/TODO widgets come through as empty messages, since this
      logic does not inspect SubMessage objects.
    """
    simplified = []
    for message in zulip_messages:
        simplified.append(
            {"sender": message["sender_full_name"], "content": message["content"]}
        )
    return json.dumps(simplified)
def make_message(content: str, role: str = "user") -> dict[str, str]:
    """Build one chat message in the role/content shape the LLM API expects."""
    message = {"content": content, "role": role}
    return message
def get_max_summary_length(conversation_length: int) -> int:
    """Return the sentence budget for a summary of the given conversation size.

    Grows by one sentence per ten messages beyond the first ten,
    starting from a base of 4, and is capped at 6 sentences.
    (Note: int() truncates toward zero, so very short conversations
    get a smaller budget.)
    """
    extra_sentences = int((conversation_length - 10) / 10)
    return min(6, 4 + extra_sentences)
@typed_endpoint
def get_messages_summary(
    request: HttpRequest,
    user_profile: UserProfile,
    *,
    narrow: Json[list[NarrowParameter] | None] = None,
) -> HttpResponse:
    """Generate an LLM summary of the messages matching the given narrow.

    Fetches up to MAX_MESSAGES_SUMMARIZED of the newest matching
    messages, sends them to the configured summarization model via
    LiteLLM, records the credit cost against the user's
    "ai_credit_usage::day" stat, and returns the summary text.

    Raises JsonableError if no summarization model is configured
    (which is how the prototype is disabled in production).
    """
    if settings.TOPIC_SUMMARIZATION_MODEL is None:
        raise JsonableError(_("AI features are not enabled on this server."))
    # Prototype access control: non-moderators get a placeholder string
    # in a successful response, rather than an error.
    if not (user_profile.is_moderator or user_profile.is_realm_admin):
        return json_success(request, {"summary": "Feature limited to moderators for now."})
    # TODO: This implementation does not attempt to make use of
    # caching previous summaries of the same conversation or rolling
    # summaries. Doing so correctly will require careful work around
    # invalidation of caches when messages are edited, moved, or sent.
    narrow = clean_narrow_for_message_fetch(narrow, user_profile.realm, user_profile)
    # Fetch the newest MAX_MESSAGES_SUMMARIZED messages in the narrow by
    # anchoring past the maximum possible message ID and looking backwards.
    query_info = fetch_messages(
        narrow=narrow,
        user_profile=user_profile,
        realm=user_profile.realm,
        is_web_public_query=False,
        anchor=LARGER_THAN_MAX_MESSAGE_ID,
        include_anchor=True,
        num_before=MAX_MESSAGES_SUMMARIZED,
        num_after=0,
    )
    if len(query_info.rows) == 0:
        return json_success(request, {"summary": "No messages in conversation to summarize"})
    result_message_ids: list[int] = []
    user_message_flags: dict[int, list[str]] = {}
    for row in query_info.rows:
        # The first column of each row is the message ID.
        message_id = row[0]
        result_message_ids.append(message_id)
        # We skip populating flags, since they would be ignored below anyway.
        user_message_flags[message_id] = []
    message_list = messages_for_ids(
        message_ids=result_message_ids,
        user_message_flags=user_message_flags,
        search_fields={},
        # We currently prefer the plain-text content of messages to
        # rendered HTML for the model input.
        apply_markdown=False,
        # Avoid wasting resources computing gravatars.
        client_gravatar=True,
        allow_empty_topic_name=False,
        # Avoid fetching edit history, which won't be passed to the model.
        allow_edit_history=False,
        user_profile=user_profile,
        realm=user_profile.realm,
    )
    # IDEA: We could consider translating input and output text to
    # English to improve results when using a summarization model that
    # is primarily trained on English.
    model = settings.TOPIC_SUMMARIZATION_MODEL
    # Select provider credentials based on the model's prefix; only
    # HuggingFace and AWS Bedrock models are supported.
    litellm_params: dict[str, Any] = {}
    if model.startswith("huggingface"):
        assert settings.HUGGINGFACE_API_KEY is not None
        litellm_params["api_key"] = settings.HUGGINGFACE_API_KEY
    else:
        assert model.startswith("bedrock")
        litellm_params["aws_access_key_id"] = settings.AWS_ACCESS_KEY_ID
        litellm_params["aws_secret_access_key"] = settings.AWS_SECRET_ACCESS_KEY
        litellm_params["aws_region_name"] = settings.AWS_REGION_NAME
    conversation_length = len(message_list)
    max_summary_length = get_max_summary_length(conversation_length)
    intro = "The following is a chat conversation in the Zulip team chat app."
    topic: str | None = None
    channel: str | None = None
    # A channel+topic narrow has exactly two (non-negated) terms; pull
    # out the names so the model gets the conversation's context.
    if narrow and len(narrow) == 2:
        for term in narrow:
            assert not term.negated
            if term.operator == "channel":
                channel = term.operand
            if term.operator == "topic":
                topic = term.operand
        if channel:
            intro += f" channel: {channel}"
        if topic:
            intro += f", topic: {topic}"
    formatted_conversation = format_zulip_messages_for_model(message_list)
    prompt = (
        f"Succinctly summarize this conversation based only on the information provided, "
        f"in up to {max_summary_length} sentences, for someone who is familiar with the context. "
        f"Mention key conclusions and actions, if any. Refer to specific people as appropriate. "
        f"Don't use an intro phrase."
    )
    # Conversation context goes in the system message; the conversation
    # itself and the summarization instruction are separate user messages.
    messages = [
        make_message(intro, "system"),
        make_message(formatted_conversation),
        make_message(prompt),
    ]
    # We import litellm here to avoid a DeprecationWarning.
    # See these issues for more info:
    # https://github.com/BerriAI/litellm/issues/6232
    # https://github.com/BerriAI/litellm/issues/5647
    import litellm

    # Token counter is recommended by LiteLLM but mypy says it's not explicitly exported.
    # https://docs.litellm.ai/docs/completion/token_usage#3-token_counter
    input_tokens = litellm.token_counter(model=model, messages=messages)  # type: ignore[attr-defined] # Explained above
    # TODO when implementing user plans:
    # - Before querying the model, check whether we've enough tokens left using
    #   the estimated token count.
    # - Then increase the `LoggingCountStat` using the estimated token count.
    #   (These first two steps should be a short database transaction that
    #   locks the `LoggingCountStat` row).
    # - Then query the model.
    # - Then adjust the `LoggingCountStat` by `(actual - estimated)`,
    #   being careful to avoid doing this to the next day if the query
    #   happened milliseconds before midnight; changing the
    #   `LoggingCountStat` we added the estimate to.
    # That way, you can't easily get extra tokens by sending
    # 25 requests all at once when you're just below the limit.
    response = litellm.completion(
        model=model,
        messages=messages,
        **litellm_params,
    )
    # Bill input tokens at the estimated count and output tokens at the
    # count the provider reports; credits are in units of $1/10^9.
    output_tokens = response["usage"]["completion_tokens"]
    credits_used = (output_tokens * OUTPUT_COST_PER_GIGATOKEN) + (
        input_tokens * INPUT_COST_PER_GIGATOKEN
    )
    do_increment_logging_stat(
        user_profile, COUNT_STATS["ai_credit_usage::day"], None, timezone_now(), credits_used
    )
    return json_success(request, {"summary": response["choices"][0]["message"]["content"]})

View File

@@ -1244,6 +1244,31 @@ EMAIL_HOST_PASSWORD = get_secret("email_password")
EMAIL_GATEWAY_PASSWORD = get_secret("email_gateway_password") EMAIL_GATEWAY_PASSWORD = get_secret("email_gateway_password")
AUTH_LDAP_BIND_PASSWORD = get_secret("auth_ldap_bind_password", "") AUTH_LDAP_BIND_PASSWORD = get_secret("auth_ldap_bind_password", "")
########################################################################
# LiteLLM SETTINGS
########################################################################
# The model name that will be used by the LiteLLM library to configure
# the parameters to be sent to the API.
# The Llama-3-8B-instruct model is free to use and only requires submitting
# a small form on the HuggingFace page for the model to gain access.
# We only support HuggingFace and AWS Bedrock for LLM API requests.
DEFAULT_TOPIC_SUMMARIZATION_MODEL: str | None = "huggingface/meta-llama/Meta-Llama-3-8B-Instruct"
if PRODUCTION:
DEFAULT_TOPIC_SUMMARIZATION_MODEL = None
TOPIC_SUMMARIZATION_MODEL = get_secret(
"topic_summarization_model", DEFAULT_TOPIC_SUMMARIZATION_MODEL
)
# Which API key to use will be determined based on TOPIC_SUMMARIZATION_MODEL.
# HuggingFace access credentials
HUGGINGFACE_API_KEY = get_secret("huggingface_api_key", None)
# AWS Bedrock access credentials
AWS_ACCESS_KEY_ID = get_secret("aws_access_key_id", None)
AWS_SECRET_ACCESS_KEY = get_secret("aws_secret_access_key", None)
AWS_REGION_NAME = get_secret("aws_region_name", None)
######################################################################## ########################################################################
# MISC SETTINGS # MISC SETTINGS
######################################################################## ########################################################################

View File

@@ -84,6 +84,7 @@ from zerver.views.message_flags import (
update_message_flags_for_narrow, update_message_flags_for_narrow,
) )
from zerver.views.message_send import render_message_backend, send_message_backend, zcommand_backend from zerver.views.message_send import render_message_backend, send_message_backend, zcommand_backend
from zerver.views.message_summary import get_messages_summary
from zerver.views.muted_users import mute_user, unmute_user from zerver.views.muted_users import mute_user, unmute_user
from zerver.views.onboarding_steps import mark_onboarding_step_as_read from zerver.views.onboarding_steps import mark_onboarding_step_as_read
from zerver.views.presence import ( from zerver.views.presence import (
@@ -367,6 +368,14 @@ v1_api_and_json_patterns = [
PATCH=update_message_backend, PATCH=update_message_backend,
DELETE=delete_message_backend, DELETE=delete_message_backend,
), ),
rest_path(
"messages/summary",
GET=(
get_messages_summary,
# Not documented since the API details haven't been finalized yet.
{"intentionally_undocumented"},
),
),
rest_path("messages/render", POST=render_message_backend), rest_path("messages/render", POST=render_message_backend),
rest_path("messages/flags", POST=update_message_flags), rest_path("messages/flags", POST=update_message_flags),
rest_path("messages/flags/narrow", POST=update_message_flags_for_narrow), rest_path("messages/flags/narrow", POST=update_message_flags_for_narrow),