diff --git a/analytics/lib/counts.py b/analytics/lib/counts.py index cedc613e89..0b9b9b2b7e 100644 --- a/analytics/lib/counts.py +++ b/analytics/lib/counts.py @@ -851,6 +851,9 @@ def get_count_stats(realm: Realm | None = None) -> dict[str, CountStat]: ), CountStat.DAY, ), + # AI credit usage stats for users, in units of $1/10^9, which is safe for + # aggregation because we're using bigints for the values. + LoggingCountStat("ai_credit_usage::day", UserCount, CountStat.DAY), # Counts the number of active users in the UserProfile.is_active sense. # Important that this stay a daily stat, so that 'realm_active_humans::day' works as expected. CountStat( diff --git a/tools/test-backend b/tools/test-backend index 81e399c442..23b9cc4de6 100755 --- a/tools/test-backend +++ b/tools/test-backend @@ -130,6 +130,8 @@ not_yet_fully_covered = [ "zerver/webhooks/zapier/view.py", # This is hard to get test coverage for, and low value to do so "zerver/views/sentry.py", + # TODO: Add tests when the details are finalized. + "zerver/views/message_summary.py", # Cannot have coverage, as tests run in a transaction "zerver/lib/safe_session_cached_db.py", "zerver/lib/singleton_bmemcached.py", diff --git a/zerver/lib/test_helpers.py b/zerver/lib/test_helpers.py index 45a417ac1f..da0aa0a11a 100644 --- a/zerver/lib/test_helpers.py +++ b/zerver/lib/test_helpers.py @@ -526,6 +526,9 @@ def write_instrumentation_reports(full_suite: bool, include_webhooks: bool) -> N untested_patterns = {p.replace("\\", "") for p in pattern_cnt if pattern_cnt[p] == 0} exempt_patterns = { + # TODO: Add tests when we are sure about what needs to be + # served to the user as summary. + "api/v1/messages/summary", # We exempt some patterns that are called via Tornado. 
import json
from typing import Any

from django.conf import settings
from django.http import HttpRequest, HttpResponse
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
from pydantic import Json

from analytics.lib.counts import COUNT_STATS, do_increment_logging_stat
from zerver.lib.exceptions import JsonableError
from zerver.lib.message import messages_for_ids
from zerver.lib.narrow import (
    LARGER_THAN_MAX_MESSAGE_ID,
    NarrowParameter,
    clean_narrow_for_message_fetch,
    fetch_messages,
)
from zerver.lib.response import json_success
from zerver.lib.typed_endpoint import typed_endpoint
from zerver.models import UserProfile

# Maximum number of messages that can be summarized in a single request.
MAX_MESSAGES_SUMMARIZED = 100
# Price per token for input and output tokens.
# These values are based on the pricing of the Bedrock API
# for Llama 3.3 Instruct (70B).
# https://aws.amazon.com/bedrock/pricing/
# Unit: USD per 1 billion tokens.
#
# These values likely will want to be declared in configuration,
# rather than here in the code.
OUTPUT_COST_PER_GIGATOKEN = 720
INPUT_COST_PER_GIGATOKEN = 720


def format_zulip_messages_for_model(zulip_messages: list[dict[str, Any]]) -> str:
    """Serialize messages into the JSON string that is sent to the model.

    Each message dict must provide the "sender_full_name" and "content"
    keys; the result is a JSON array of {"sender", "content"} objects in
    the same order as the input.
    """
    # Format the Zulip messages for processing by the model.
    #
    # - We don't need to encode the recipient, since that's the same for
    #   every message in the conversation.
    # - We use full names to reference senders, since we want the
    #   model to refer to users by name. We may want to experiment
    #   with using silent-mention syntax for users if we move to
    #   Markdown-rendering what the model returns.
    # - We don't include timestamps, since experiments with current models
    #   suggest they do not make relevant use of them.
    # - We haven't figured out a useful way to include reaction metadata (either
    #   the emoji themselves or just the counter).
    # - Polls/TODO widgets are currently sent to the model as empty messages,
    #   since this logic doesn't inspect SubMessage objects.
    zulip_messages_list = [
        {"sender": message["sender_full_name"], "content": message["content"]}
        for message in zulip_messages
    ]
    # ensure_ascii=False keeps non-ASCII names and content as the original
    # characters instead of \uXXXX escapes, which would make the prompt
    # longer and harder for the model to read.
    return json.dumps(zulip_messages_list, ensure_ascii=False)


def make_message(content: str, role: str = "user") -> dict[str, str]:
    """Return a chat message dict with the given content and role."""
    return {"content": content, "role": role}


def get_max_summary_length(conversation_length: int) -> int:
    """Return the maximum number of sentences to request for the summary.

    The limit is 4 sentences for conversations of 1-19 messages, 5 for
    20-29 messages, and is capped at 6 for 30 or more messages.
    """
    return min(6, 4 + int((conversation_length - 10) / 10))


def _get_conversation_intro(narrow: list[NarrowParameter] | None) -> str:
    """Build the system message describing where the conversation happened.

    The channel and topic are only mentioned for two-term narrows; any
    other narrow gets the generic introduction.
    """
    intro = "The following is a chat conversation in the Zulip team chat app."
    if not narrow or len(narrow) != 2:
        return intro

    channel: str | None = None
    topic: str | None = None
    for term in narrow:
        if term.negated:
            # The narrow is supplied by the client, so a negated term must
            # not crash the request. Such a term does not name the place
            # where the conversation happened, so it is simply not used
            # for the introduction.
            continue
        if term.operator == "channel":
            channel = term.operand
        elif term.operator == "topic":
            topic = term.operand

    context = []
    if channel:
        context.append(f"channel: {channel}")
    if topic:
        context.append(f"topic: {topic}")
    if context:
        # Joining the parts avoids a dangling ", topic: ..." when the
        # narrow has a topic term but no channel term.
        intro += " " + ", ".join(context)
    return intro


def _get_litellm_params(model: str) -> dict[str, Any]:
    """Return the credential keyword arguments for litellm.completion.

    Only model names starting with "huggingface" or "bedrock" are
    supported.
    """
    litellm_params: dict[str, Any] = {}
    if model.startswith("huggingface"):
        assert settings.HUGGINGFACE_API_KEY is not None
        litellm_params["api_key"] = settings.HUGGINGFACE_API_KEY
    else:
        assert model.startswith("bedrock")
        litellm_params["aws_access_key_id"] = settings.AWS_ACCESS_KEY_ID
        litellm_params["aws_secret_access_key"] = settings.AWS_SECRET_ACCESS_KEY
        litellm_params["aws_region_name"] = settings.AWS_REGION_NAME
    return litellm_params


@typed_endpoint
def get_messages_summary(
    request: HttpRequest,
    user_profile: UserProfile,
    *,
    narrow: Json[list[NarrowParameter] | None] = None,
) -> HttpResponse:
    """Return a model-generated summary of the messages matching `narrow`.

    At most MAX_MESSAGES_SUMMARIZED messages are fetched and sent to the
    model configured in settings.TOPIC_SUMMARIZATION_MODEL. The cost of the
    request is recorded against the user in the "ai_credit_usage::day"
    stat.

    Raises JsonableError if no summarization model is configured. Users who
    are neither moderators nor realm administrators, and narrows that match
    no messages, get a fixed explanatory string as the summary instead.
    """
    if settings.TOPIC_SUMMARIZATION_MODEL is None:
        raise JsonableError(_("AI features are not enabled on this server."))

    if not (user_profile.is_moderator or user_profile.is_realm_admin):
        return json_success(request, {"summary": "Feature limited to moderators for now."})

    # TODO: This implementation does not attempt to make use of
    # caching previous summaries of the same conversation or rolling
    # summaries. Doing so correctly will require careful work around
    # invalidation of caches when messages are edited, moved, or sent.
    narrow = clean_narrow_for_message_fetch(narrow, user_profile.realm, user_profile)
    query_info = fetch_messages(
        narrow=narrow,
        user_profile=user_profile,
        realm=user_profile.realm,
        is_web_public_query=False,
        anchor=LARGER_THAN_MAX_MESSAGE_ID,
        include_anchor=True,
        num_before=MAX_MESSAGES_SUMMARIZED,
        num_after=0,
    )

    if len(query_info.rows) == 0:
        return json_success(request, {"summary": "No messages in conversation to summarize"})

    # The first column of each row is the message ID.
    result_message_ids: list[int] = [row[0] for row in query_info.rows]
    # We skip populating flags, since they would be ignored below anyway.
    user_message_flags: dict[int, list[str]] = {
        message_id: [] for message_id in result_message_ids
    }

    message_list = messages_for_ids(
        message_ids=result_message_ids,
        user_message_flags=user_message_flags,
        search_fields={},
        # We currently send the model the unrendered content of messages,
        # rather than the Markdown-rendered version.
        apply_markdown=False,
        # Avoid wasting resources computing gravatars.
        client_gravatar=True,
        allow_empty_topic_name=False,
        # Avoid fetching edit history, which won't be passed to the model.
        allow_edit_history=False,
        user_profile=user_profile,
        realm=user_profile.realm,
    )

    # IDEA: We could consider translating input and output text to
    # English to improve results when using a summarization model that
    # is primarily trained on English.
    model = settings.TOPIC_SUMMARIZATION_MODEL
    litellm_params = _get_litellm_params(model)

    max_summary_length = get_max_summary_length(len(message_list))
    intro = _get_conversation_intro(narrow)
    formatted_conversation = format_zulip_messages_for_model(message_list)
    prompt = (
        f"Succinctly summarize this conversation based only on the information provided, "
        f"in up to {max_summary_length} sentences, for someone who is familiar with the context. "
        f"Mention key conclusions and actions, if any. Refer to specific people as appropriate. "
        f"Don't use an intro phrase."
    )
    messages = [
        make_message(intro, "system"),
        make_message(formatted_conversation),
        make_message(prompt),
    ]

    # We import litellm here to avoid a DeprecationWarning.
    # See these issues for more info:
    # https://github.com/BerriAI/litellm/issues/6232
    # https://github.com/BerriAI/litellm/issues/5647
    import litellm

    # Token counter is recommended by LiteLLM but mypy says it's not explicitly exported.
    # https://docs.litellm.ai/docs/completion/token_usage#3-token_counter
    input_tokens = litellm.token_counter(model=model, messages=messages)  # type: ignore[attr-defined] # Explained above

    # TODO when implementing user plans:
    # - Before querying the model, check whether we've enough tokens left using
    #   the estimated token count.
    # - Then increase the `LoggingCountStat` using the estimated token count.
    #   (These first two steps should be a short database transaction that
    #   locks the `LoggingCountStat` row).
    # - Then query the model.
    # - Then adjust the `LoggingCountStat` by `(actual - estimated)`,
    #   being careful to avoid doing this to the next day if the query
    #   happened milliseconds before midnight; changing the
    #   `LoggingCountStat` we added the estimate to.
    # That way, you can't easily get extra tokens by sending
    # 25 requests all at once when you're just below the limit.

    response = litellm.completion(
        model=model,
        messages=messages,
        **litellm_params,
    )
    output_tokens = response["usage"]["completion_tokens"]

    # The prices are in USD per 10^9 tokens, so multiplying them by token
    # counts yields integer credits in units of 10^-9 USD.
    credits_used = (output_tokens * OUTPUT_COST_PER_GIGATOKEN) + (
        input_tokens * INPUT_COST_PER_GIGATOKEN
    )
    do_increment_logging_stat(
        user_profile, COUNT_STATS["ai_credit_usage::day"], None, timezone_now(), credits_used
    )

    return json_success(request, {"summary": response["choices"][0]["message"]["content"]})
library to configure +# parameters to be sent to API. +# The Llama-3-8B-instruct model is free to use and only requires submitting +# a small form on the HuggingFace page for the model to gain access. +# We only support HuggingFace and AWS Bedrock for LLM API requests. +DEFAULT_TOPIC_SUMMARIZATION_MODEL: str | None = "huggingface/meta-llama/Meta-Llama-3-8B-Instruct" +if PRODUCTION: + DEFAULT_TOPIC_SUMMARIZATION_MODEL = None +TOPIC_SUMMARIZATION_MODEL = get_secret( + "topic_summarization_model", DEFAULT_TOPIC_SUMMARIZATION_MODEL +) + +# Which API key to use will be determined based on TOPIC_SUMMARIZATION_MODEL. +# HuggingFace access credentials +HUGGINGFACE_API_KEY = get_secret("huggingface_api_key", None) + +# AWS Bedrock access credentials +AWS_ACCESS_KEY_ID = get_secret("aws_access_key_id", None) +AWS_SECRET_ACCESS_KEY = get_secret("aws_secret_access_key", None) +AWS_REGION_NAME = get_secret("aws_region_name", None) + ######################################################################## # MISC SETTINGS ######################################################################## diff --git a/zproject/urls.py b/zproject/urls.py index e68e696a77..6e84e2e1db 100644 --- a/zproject/urls.py +++ b/zproject/urls.py @@ -84,6 +84,7 @@ from zerver.views.message_flags import ( update_message_flags_for_narrow, ) from zerver.views.message_send import render_message_backend, send_message_backend, zcommand_backend +from zerver.views.message_summary import get_messages_summary from zerver.views.muted_users import mute_user, unmute_user from zerver.views.onboarding_steps import mark_onboarding_step_as_read from zerver.views.presence import ( @@ -367,6 +368,14 @@ v1_api_and_json_patterns = [ PATCH=update_message_backend, DELETE=delete_message_backend, ), + rest_path( + "messages/summary", + GET=( + get_messages_summary, + # Not documented since the API details haven't been finalized yet. 
+ {"intentionally_undocumented"}, + ), + ), rest_path("messages/render", POST=render_message_backend), rest_path("messages/flags", POST=update_message_flags), rest_path("messages/flags/narrow", POST=update_message_flags_for_narrow),