From 0bad785f03f89b248b7ec7d79e78fefb927da99e Mon Sep 17 00:00:00 2001 From: Aman Agrawal Date: Tue, 4 Feb 2025 17:44:59 +0530 Subject: [PATCH] message_summary: Log time to generate summary. Mostly copy pasted things from markdown time logging. --- zerver/actions/message_summary.py | 28 ++++++++++++++++++++++++++++ zerver/middleware.py | 13 ++++++++++++- zerver/views/message_summary.py | 1 + 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/zerver/actions/message_summary.py b/zerver/actions/message_summary.py index faa7e88386..dc1d38f5d6 100644 --- a/zerver/actions/message_summary.py +++ b/zerver/actions/message_summary.py @@ -1,3 +1,4 @@ +import time from typing import Any import orjson @@ -28,6 +29,30 @@ OUTPUT_COST_PER_GIGATOKEN = 720 INPUT_COST_PER_GIGATOKEN = 720 +ai_time_start = 0.0 +ai_total_time = 0.0 +ai_total_requests = 0 + + +def get_ai_time() -> float: + return ai_total_time + + +def ai_stats_start() -> None: + global ai_time_start + ai_time_start = time.time() + + +def get_ai_requests() -> int: + return ai_total_requests + + +def ai_stats_finish() -> None: + global ai_total_time, ai_total_requests + ai_total_requests += 1 + ai_total_time += time.time() - ai_time_start + + def format_zulip_messages_for_model(zulip_messages: list[dict[str, Any]]) -> str: # Format the Zulip messages for processing by the model. # @@ -143,6 +168,8 @@ def do_summarize_narrow( make_message(prompt), ] + # Stats for database queries are tracked separately. + ai_stats_start() # We import litellm here to avoid a DeprecationWarning. # See these issues for more info: # https://github.com/BerriAI/litellm/issues/6232 @@ -185,4 +212,5 @@ def do_summarize_narrow( user_profile, COUNT_STATS["ai_credit_usage::day"], None, timezone_now(), credits_used ) + ai_stats_finish() return response["choices"][0]["message"]["content"] diff --git a/zerver/middleware.py b/zerver/middleware.py index b1bd413a06..50ac23e95c 100644 --- a/zerver/middleware.py +++ b/zerver/middleware.py @@ -25,6 +25,7 @@ from django_scim.settings import scim_settings from sentry_sdk import set_tag from typing_extensions import ParamSpec, override +from zerver.actions.message_summary import get_ai_requests, get_ai_time from zerver.lib.cache import get_remote_cache_requests, get_remote_cache_time from zerver.lib.db_connections import reset_queries from zerver.lib.debug import maybe_tracemalloc_listen @@ -97,6 +98,8 @@ def record_request_start_data(log_data: MutableMapping[str, Any]) -> None: log_data["remote_cache_requests_start"] = get_remote_cache_requests() log_data["markdown_time_start"] = get_markdown_time() log_data["markdown_requests_start"] = get_markdown_requests() + log_data["ai_time_start"] = get_ai_time() + log_data["ai_requests_start"] = get_ai_time() def timedelta_ms(timedelta: float) -> float: @@ -186,6 +189,14 @@ def write_log_line( f" (md: {format_timedelta(markdown_time_delta)}/{markdown_count_delta})" ) + ai_output = "" + if "ai_time_start" in log_data: + ai_time_delta = get_ai_time() - log_data["ai_time_start"] + ai_count_delta = get_ai_requests() - log_data["ai_requests_start"] + + if ai_time_delta > 0.005: + ai_output = f" (ai: {format_timedelta(ai_time_delta)}/{ai_count_delta})" + # Get the amount of time spent doing database queries db_time_output = "" queries = connection.connection.queries if connection.connection is not None else [] @@ -201,7 +212,7 @@ def write_log_line( logger_client = f"({requester_for_logs} via {client_name})" else: logger_client = f"({requester_for_logs} via {client_name}/{client_version})" - logger_timing = f"{format_timedelta(time_delta):>5}{optional_orig_delta}{remote_cache_output}{markdown_output}{db_time_output}{startup_output} {path}" + logger_timing = f"{format_timedelta(time_delta):>5}{optional_orig_delta}{remote_cache_output}{markdown_output}{ai_output}{db_time_output}{startup_output} {path}" logger_line = f"{remote_ip:<15} {method:<7} {status_code:3} {logger_timing}{extra_request_data} {logger_client}" if status_code in [200, 304] and method == "GET" and path.startswith("/static"): logger.debug(logger_line) diff --git a/zerver/views/message_summary.py b/zerver/views/message_summary.py index 0a83a0f2fd..d0d02b73de 100644 --- a/zerver/views/message_summary.py +++ b/zerver/views/message_summary.py @@ -2,6 +2,7 @@ import warnings warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") + from django.conf import settings from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _