message_summary: Add API endpoint to generate narrow summary.

This prototype API is disabled in production, since the production settings do not configure a default model.
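
For context, here is a minimal sketch of how a client might call the new endpoint once a summarization model is configured. The route, the narrow parameter, and the moderator/administrator restriction come from this commit; the server URL, the account credentials, and the use of Zulip's standard API-key basic auth are assumptions for illustration only.

import json

import requests

ZULIP_SITE = "https://zulip.example.com"  # assumption: your server's URL
EMAIL = "moderator@example.com"           # assumption: a moderator's email
API_KEY = "abcd1234"                      # assumption: that account's API key

# Summarize the latest messages in one conversation, identified by a narrow.
narrow = [
    {"operator": "channel", "operand": "design"},
    {"operator": "topic", "operand": "summaries"},
]
response = requests.get(
    f"{ZULIP_SITE}/api/v1/messages/summary",
    params={"narrow": json.dumps(narrow)},
    auth=(EMAIL, API_KEY),  # Zulip's REST API uses basic auth with the API key
)
print(response.json()["summary"])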
@@ -851,6 +851,9 @@ def get_count_stats(realm: Realm | None = None) -> dict[str, CountStat]:
             ),
             CountStat.DAY,
         ),
+        # AI credit usage stats for users, in units of $1/10^9, which is safe for
+        # aggregation because we're using bigints for the values.
+        LoggingCountStat("ai_credit_usage::day", UserCount, CountStat.DAY),
         # Counts the number of active users in the UserProfile.is_active sense.
         # Important that this stay a daily stat, so that 'realm_active_humans::day' works as expected.
         CountStat(
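
To make the new counter's units concrete: credits are logged in units of $1/10^9 (a billionth of a dollar), and the pricing constants added in zerver/views/message_summary.py below are likewise expressed per 10^9 tokens, so per-request bookkeeping stays in small integers that a bigint column can aggregate safely. A rough worked example, with made-up token counts:

# Worked example of the AI credit arithmetic (token counts are illustrative).
INPUT_COST_PER_GIGATOKEN = 720   # from the new module: $720 per 10^9 tokens
OUTPUT_COST_PER_GIGATOKEN = 720

input_tokens = 4_000   # assumption: roughly 100 short chat messages
output_tokens = 150    # assumption: a summary of a few sentences

credits_used = (
    input_tokens * INPUT_COST_PER_GIGATOKEN
    + output_tokens * OUTPUT_COST_PER_GIGATOKEN
)
print(credits_used)          # 2988000 credits
print(credits_used / 10**9)  # ~$0.003 for this request
# Even heavy daily usage accumulates values nowhere near the 2**63 - 1 bigint
# limit, which is why summing these per-user counts is safe.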
@@ -130,6 +130,8 @@ not_yet_fully_covered = [
     "zerver/webhooks/zapier/view.py",
     # This is hard to get test coverage for, and low value to do so
     "zerver/views/sentry.py",
+    # TODO: Add tests when the details are finalized.
+    "zerver/views/message_summary.py",
     # Cannot have coverage, as tests run in a transaction
     "zerver/lib/safe_session_cached_db.py",
     "zerver/lib/singleton_bmemcached.py",
@@ -526,6 +526,9 @@ def write_instrumentation_reports(full_suite: bool, include_webhooks: bool) -> N
     untested_patterns = {p.replace("\\", "") for p in pattern_cnt if pattern_cnt[p] == 0}
 
     exempt_patterns = {
+        # TODO: Add tests when we are sure about what needs to be
+        # served to the user as summary.
+        "api/v1/messages/summary",
         # We exempt some patterns that are called via Tornado.
         "api/v1/events",
         "api/v1/events/internal",
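
The two TODO entries above anticipate backend tests once the API shape settles. As a rough outline of what such a test might look like (it assumes Zulip's ZulipTestCase helpers like example_user and api_get, and the success path would additionally need the LiteLLM call mocked out):

import json

from zerver.lib.test_classes import ZulipTestCase


class MessageSummaryTest(ZulipTestCase):
    def test_summary_requires_configured_model(self) -> None:
        # Iago is a realm administrator in the development test data.
        iago = self.example_user("iago")
        narrow = json.dumps([{"operator": "channel", "operand": "Denmark"}])
        # With no summarization model configured, the endpoint should refuse.
        with self.settings(TOPIC_SUMMARIZATION_MODEL=None):
            result = self.api_get(iago, "/api/v1/messages/summary", {"narrow": narrow})
        self.assert_json_error(result, "AI features are not enabled on this server.")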
zerver/views/message_summary.py (new file, 204 lines)
@@ -0,0 +1,204 @@
import json
from typing import Any

from django.conf import settings
from django.http import HttpRequest, HttpResponse
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
from pydantic import Json

from analytics.lib.counts import COUNT_STATS, do_increment_logging_stat
from zerver.lib.exceptions import JsonableError
from zerver.lib.message import messages_for_ids
from zerver.lib.narrow import (
    LARGER_THAN_MAX_MESSAGE_ID,
    NarrowParameter,
    clean_narrow_for_message_fetch,
    fetch_messages,
)
from zerver.lib.response import json_success
from zerver.lib.typed_endpoint import typed_endpoint
from zerver.models import UserProfile

# Maximum number of messages that can be summarized in a single request.
MAX_MESSAGES_SUMMARIZED = 100
# Price per token for input and output tokens.
# These values are based on the pricing of the Bedrock API
# for Llama 3.3 Instruct (70B).
# https://aws.amazon.com/bedrock/pricing/
# Unit: USD per 1 billion tokens.
#
# These values likely will want to be declared in configuration,
# rather than here in the code.
OUTPUT_COST_PER_GIGATOKEN = 720
INPUT_COST_PER_GIGATOKEN = 720


def format_zulip_messages_for_model(zulip_messages: list[dict[str, Any]]) -> str:
    # Format the Zulip messages for processing by the model.
    #
    # - We don't need to encode the recipient, since that's the same for
    #   every message in the conversation.
    # - We use full names to reference senders, since we want the
    #   model to refer to users by name. We may want to experiment
    #   with using silent-mention syntax for users if we move to
    #   Markdown-rendering what the model returns.
    # - We don't include timestamps, since experiments with current models
    #   suggest they do not make relevant use of them.
    # - We haven't figured out a useful way to include reaction metadata (either
    #   the emoji themselves or just the counter).
    # - Polls/TODO widgets are currently sent to the model as empty messages,
    #   since this logic doesn't inspect SubMessage objects.
    zulip_messages_list = [
        {"sender": message["sender_full_name"], "content": message["content"]}
        for message in zulip_messages
    ]
    return json.dumps(zulip_messages_list)


def make_message(content: str, role: str = "user") -> dict[str, str]:
    return {"content": content, "role": role}


def get_max_summary_length(conversation_length: int) -> int:
    return min(6, 4 + int((conversation_length - 10) / 10))


@typed_endpoint
def get_messages_summary(
    request: HttpRequest,
    user_profile: UserProfile,
    *,
    narrow: Json[list[NarrowParameter] | None] = None,
) -> HttpResponse:
    if settings.TOPIC_SUMMARIZATION_MODEL is None:
        raise JsonableError(_("AI features are not enabled on this server."))

    if not (user_profile.is_moderator or user_profile.is_realm_admin):
        return json_success(request, {"summary": "Feature limited to moderators for now."})

    # TODO: This implementation does not attempt to make use of
    # caching previous summaries of the same conversation or rolling
    # summaries. Doing so correctly will require careful work around
    # invalidation of caches when messages are edited, moved, or sent.
    narrow = clean_narrow_for_message_fetch(narrow, user_profile.realm, user_profile)
    query_info = fetch_messages(
        narrow=narrow,
        user_profile=user_profile,
        realm=user_profile.realm,
        is_web_public_query=False,
        anchor=LARGER_THAN_MAX_MESSAGE_ID,
        include_anchor=True,
        num_before=MAX_MESSAGES_SUMMARIZED,
        num_after=0,
    )

    if len(query_info.rows) == 0:
        return json_success(request, {"summary": "No messages in conversation to summarize"})

    result_message_ids: list[int] = []
    user_message_flags: dict[int, list[str]] = {}
    for row in query_info.rows:
        message_id = row[0]
        result_message_ids.append(message_id)
        # We skip populating flags, since they would be ignored below anyway.
        user_message_flags[message_id] = []

    message_list = messages_for_ids(
        message_ids=result_message_ids,
        user_message_flags=user_message_flags,
        search_fields={},
        # We currently prefer the plain-text content of messages to
        apply_markdown=False,
        # Avoid wasting resources computing gravatars.
        client_gravatar=True,
        allow_empty_topic_name=False,
        # Avoid fetching edit history, which won't be passed to the model.
        allow_edit_history=False,
        user_profile=user_profile,
        realm=user_profile.realm,
    )

    # IDEA: We could consider translating input and output text to
    # English to improve results when using a summarization model that
    # is primarily trained on English.
    model = settings.TOPIC_SUMMARIZATION_MODEL
    litellm_params: dict[str, Any] = {}
    if model.startswith("huggingface"):
        assert settings.HUGGINGFACE_API_KEY is not None
        litellm_params["api_key"] = settings.HUGGINGFACE_API_KEY
    else:
        assert model.startswith("bedrock")
        litellm_params["aws_access_key_id"] = settings.AWS_ACCESS_KEY_ID
        litellm_params["aws_secret_access_key"] = settings.AWS_SECRET_ACCESS_KEY
        litellm_params["aws_region_name"] = settings.AWS_REGION_NAME

    conversation_length = len(message_list)
    max_summary_length = get_max_summary_length(conversation_length)
    intro = "The following is a chat conversation in the Zulip team chat app."
    topic: str | None = None
    channel: str | None = None
    if narrow and len(narrow) == 2:
        for term in narrow:
            assert not term.negated
            if term.operator == "channel":
                channel = term.operand
            if term.operator == "topic":
                topic = term.operand
    if channel:
        intro += f" channel: {channel}"
        if topic:
            intro += f", topic: {topic}"

    formatted_conversation = format_zulip_messages_for_model(message_list)
    prompt = (
        f"Succinctly summarize this conversation based only on the information provided, "
        f"in up to {max_summary_length} sentences, for someone who is familiar with the context. "
        f"Mention key conclusions and actions, if any. Refer to specific people as appropriate. "
        f"Don't use an intro phrase."
    )
    messages = [
        make_message(intro, "system"),
        make_message(formatted_conversation),
        make_message(prompt),
    ]

    # We import litellm here to avoid a DeprecationWarning.
    # See these issues for more info:
    # https://github.com/BerriAI/litellm/issues/6232
    # https://github.com/BerriAI/litellm/issues/5647
    import litellm

    # Token counter is recommended by LiteLLM but mypy says it's not explicitly exported.
    # https://docs.litellm.ai/docs/completion/token_usage#3-token_counter
    input_tokens = litellm.token_counter(model=model, messages=messages)  # type: ignore[attr-defined] # Explained above

    # TODO when implementing user plans:
    # - Before querying the model, check whether we've enough tokens left using
    #   the estimated token count.
    # - Then increase the `LoggingCountStat` using the estimated token count.
    #   (These first two steps should be a short database transaction that
    #   locks the `LoggingCountStat` row).
    # - Then query the model.
    # - Then adjust the `LoggingCountStat` by `(actual - estimated)`,
    #   being careful to avoid doing this to the next day if the query
    #   happened milliseconds before midnight; changing the
    #   `LoggingCountStat` we added the estimate to.
    # That way, you can't easily get extra tokens by sending
    # 25 requests all at once when you're just below the limit.

    response = litellm.completion(
        model=model,
        messages=messages,
        **litellm_params,
    )
    output_tokens = response["usage"]["completion_tokens"]

    credits_used = (output_tokens * OUTPUT_COST_PER_GIGATOKEN) + (
        input_tokens * INPUT_COST_PER_GIGATOKEN
    )
    do_increment_logging_stat(
        user_profile, COUNT_STATS["ai_credit_usage::day"], None, timezone_now(), credits_used
    )

    return json_success(request, {"summary": response["choices"][0]["message"]["content"]})
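
To illustrate the helpers in the new module: here is what the model actually receives and how the summary-length cap scales with conversation size. The helper bodies are simplified copies of the ones above so the snippet runs standalone, and the sample messages are invented.

import json


def format_zulip_messages_for_model(zulip_messages):
    return json.dumps(
        [{"sender": m["sender_full_name"], "content": m["content"]} for m in zulip_messages]
    )


def get_max_summary_length(conversation_length):
    return min(6, 4 + int((conversation_length - 10) / 10))


sample = [
    {"sender_full_name": "Iago", "content": "Can we ship the summary API this week?"},
    {"sender_full_name": "Polonius", "content": "Yes, once the billing counter lands."},
]
print(format_zulip_messages_for_model(sample))
# [{"sender": "Iago", "content": "Can we ship ..."}, {"sender": "Polonius", ...}]

for n in (5, 20, 50, 100):
    print(n, get_max_summary_length(n))
# 5 -> 4, 20 -> 5, 50 -> 6, 100 -> 6

The endpoint then wraps this JSON string, together with a system intro naming the channel and topic and the summarization prompt, into three chat messages that are passed to litellm.completion().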
@@ -1244,6 +1244,31 @@ EMAIL_HOST_PASSWORD = get_secret("email_password")
 EMAIL_GATEWAY_PASSWORD = get_secret("email_gateway_password")
 AUTH_LDAP_BIND_PASSWORD = get_secret("auth_ldap_bind_password", "")
+
+########################################################################
+# LiteLLM SETTINGS
+########################################################################
+
+# The model name that will be used by the LiteLLM library to configure
+# parameters to be sent to the API.
+# The Llama-3-8B-instruct model is free to use and only requires submitting
+# a small form on the HuggingFace page for the model to gain access.
+# We only support HuggingFace and AWS Bedrock for LLM API requests.
+DEFAULT_TOPIC_SUMMARIZATION_MODEL: str | None = "huggingface/meta-llama/Meta-Llama-3-8B-Instruct"
+if PRODUCTION:
+    DEFAULT_TOPIC_SUMMARIZATION_MODEL = None
+TOPIC_SUMMARIZATION_MODEL = get_secret(
+    "topic_summarization_model", DEFAULT_TOPIC_SUMMARIZATION_MODEL
+)
+
+# Which API key to use will be determined based on TOPIC_SUMMARIZATION_MODEL.
+# HuggingFace access credentials
+HUGGINGFACE_API_KEY = get_secret("huggingface_api_key", None)
+
+# AWS Bedrock access credentials
+AWS_ACCESS_KEY_ID = get_secret("aws_access_key_id", None)
+AWS_SECRET_ACCESS_KEY = get_secret("aws_secret_access_key", None)
+AWS_REGION_NAME = get_secret("aws_region_name", None)
 
 ########################################################################
 # MISC SETTINGS
 ########################################################################
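
As a footnote to the settings block above: in development the HuggingFace Llama-3-8B-Instruct model is preconfigured, while in production the feature stays off until the topic_summarization_model secret is set. A minimal sketch of that resolution logic follows; resolve_model is hypothetical shorthand for the get_secret fallback shown in the diff, and the Bedrock model string is a placeholder.

def resolve_model(production: bool, secret_value: str | None) -> str | None:
    # Stand-in for get_secret("topic_summarization_model", DEFAULT_...),
    # which falls back to the default when no secret is configured.
    default = None if production else "huggingface/meta-llama/Meta-Llama-3-8B-Instruct"
    return secret_value if secret_value is not None else default


# Development: summaries work out of the box (given a HuggingFace API key).
assert resolve_model(production=False, secret_value=None) is not None
# Production: get_messages_summary raises "AI features are not enabled on this
# server." until an administrator configures a model.
assert resolve_model(production=True, secret_value=None) is None
# Configuring a Bedrock model routes the view to the AWS credential settings.
assert resolve_model(production=True, secret_value="bedrock/...").startswith("bedrock")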
@@ -84,6 +84,7 @@ from zerver.views.message_flags import (
     update_message_flags_for_narrow,
 )
 from zerver.views.message_send import render_message_backend, send_message_backend, zcommand_backend
+from zerver.views.message_summary import get_messages_summary
 from zerver.views.muted_users import mute_user, unmute_user
 from zerver.views.onboarding_steps import mark_onboarding_step_as_read
 from zerver.views.presence import (
@@ -367,6 +368,14 @@ v1_api_and_json_patterns = [
         PATCH=update_message_backend,
         DELETE=delete_message_backend,
     ),
+    rest_path(
+        "messages/summary",
+        GET=(
+            get_messages_summary,
+            # Not documented since the API details haven't been finalized yet.
+            {"intentionally_undocumented"},
+        ),
+    ),
     rest_path("messages/render", POST=render_message_backend),
     rest_path("messages/flags", POST=update_message_flags),
     rest_path("messages/flags/narrow", POST=update_message_flags_for_narrow),