diff --git a/zerver/lib/narrow.py b/zerver/lib/narrow.py index cefa044665..eaa9ab8363 100644 --- a/zerver/lib/narrow.py +++ b/zerver/lib/narrow.py @@ -1187,8 +1187,28 @@ def add_narrow_conditions( query = builder.add_term(query, term) if search_operands: + # This topic escaping logic ensures consistent escaping of topic names throughout + # the system, ensuring accuracy in string highlighting and avoiding any discrepancies. + # + # When a topic name is fetched from the database, it goes through this logic. + # The `func.escape_html()` function is used to escape the topic name, ensuring that + # special characters are properly escaped. This helps to avoid the need to apply other + # escaping logic to the topic name for string highlighting purposes. As a result, the + # highlighted string will accurately match the actual topic name displayed in the UI. + # This approach prevents any inconsistencies or offsets that could occur if different + # escaping functions were used. + # + # It's important to note that the `process_fts_updates` script, responsible for + # updating the relevant columns in the database, also utilizes the same escaping + # logic. This alignment ensures that the escaped topic names stored in the database + # and the topic names used during string highlighting are in sync. Therefore, there + # is no need for any special handling in `process_fts_updates` to align with this + # escaping logic. is_search = True - query = query.add_columns(topic_column_sa(), column("rendered_content", Text)) + query = query.add_columns( + func.escape_html(topic_column_sa(), type_=Text).label("escaped_topic_name"), + column("rendered_content", Text), + ) search_term = NarrowParameter( operator="search", operand=" ".join(search_operands), diff --git a/zerver/tests/test_message_fetch.py b/zerver/tests/test_message_fetch.py index fefe2e418d..b7454eb27c 100644 --- a/zerver/tests/test_message_fetch.py +++ b/zerver/tests/test_message_fetch.py @@ -3208,6 +3208,7 @@ class GetOldMessagesTest(ZulipTestCase): ("日本", "今朝はごはんを食べました。"), ("日本", "昨日、日本 のお菓子を送りました。"), ("english", "I want to go to 日本!"), + ("James' burger", "James' burger"), ] next_message_id = self.get_last_message().id + 1 @@ -3335,6 +3336,28 @@ class GetOldMessagesTest(ZulipTestCase): '

こんに ちは今日は いい 天気ですね。

', ) + # Search operands with HTML special characters + special_search_narrow = [ + dict(operator="search", operand="burger"), + ] + special_search_result = self.get_and_check_messages( + dict( + narrow=orjson.dumps(special_search_narrow).decode(), + anchor=next_message_id, + num_after=10, + num_before=0, + ) + ) + self.assert_length(special_search_result["messages"], 1) + self.assertEqual( + special_search_result["messages"][0][MATCH_TOPIC], + 'James' burger', + ) + self.assertEqual( + special_search_result["messages"][0]["match_content"], + '

James\' burger

', + ) + @override_settings(USING_PGROONGA=False) def test_get_visible_messages_with_search(self) -> None: self.login("hamlet") @@ -5080,8 +5103,8 @@ WHERE zerver_subscription.user_profile_id = {hamlet_id} AND zerver_subscription. query_ids = self.get_query_ids() sql_template = """\ -SELECT anon_1.message_id, anon_1.flags, anon_1.subject, anon_1.rendered_content, anon_1.content_matches, anon_1.topic_matches \n\ -FROM (SELECT message_id, flags, subject, rendered_content, array((SELECT ARRAY[sum(length(anon_3) - 11) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) + 11, strpos(anon_3, '') - 1] AS anon_2 \n\ +SELECT anon_1.message_id, anon_1.flags, anon_1.escaped_topic_name, anon_1.rendered_content, anon_1.content_matches, anon_1.topic_matches \n\ +FROM (SELECT message_id, flags, escape_html(subject) AS escaped_topic_name, rendered_content, array((SELECT ARRAY[sum(length(anon_3) - 11) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) + 11, strpos(anon_3, '') - 1] AS anon_2 \n\ FROM unnest(string_to_array(ts_headline('zulip.english_us_search', rendered_content, plainto_tsquery('zulip.english_us_search', 'jumping'), 'HighlightAll = TRUE, StartSel = , StopSel = '), '')) AS anon_3\n\ LIMIT ALL OFFSET 1)) AS content_matches, array((SELECT ARRAY[sum(length(anon_5) - 11) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) + 11, strpos(anon_5, '') - 1] AS anon_4 \n\ FROM unnest(string_to_array(ts_headline('zulip.english_us_search', escape_html(subject), plainto_tsquery('zulip.english_us_search', 'jumping'), 'HighlightAll = TRUE, StartSel = , StopSel = '), '')) AS anon_5\n\ @@ -5100,8 +5123,8 @@ WHERE zerver_subscription.user_profile_id = {hamlet_id} AND zerver_subscription. ) sql_template = """\ -SELECT anon_1.message_id, anon_1.subject, anon_1.rendered_content, anon_1.content_matches, anon_1.topic_matches \n\ -FROM (SELECT id AS message_id, subject, rendered_content, array((SELECT ARRAY[sum(length(anon_3) - 11) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) + 11, strpos(anon_3, '') - 1] AS anon_2 \n\ +SELECT anon_1.message_id, anon_1.escaped_topic_name, anon_1.rendered_content, anon_1.content_matches, anon_1.topic_matches \n\ +FROM (SELECT id AS message_id, escape_html(subject) AS escaped_topic_name, rendered_content, array((SELECT ARRAY[sum(length(anon_3) - 11) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) + 11, strpos(anon_3, '') - 1] AS anon_2 \n\ FROM unnest(string_to_array(ts_headline('zulip.english_us_search', rendered_content, plainto_tsquery('zulip.english_us_search', 'jumping'), 'HighlightAll = TRUE, StartSel = , StopSel = '), '')) AS anon_3\n\ LIMIT ALL OFFSET 1)) AS content_matches, array((SELECT ARRAY[sum(length(anon_5) - 11) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) + 11, strpos(anon_5, '') - 1] AS anon_4 \n\ FROM unnest(string_to_array(ts_headline('zulip.english_us_search', escape_html(subject), plainto_tsquery('zulip.english_us_search', 'jumping'), 'HighlightAll = TRUE, StartSel = , StopSel = '), '')) AS anon_5\n\ @@ -5122,8 +5145,8 @@ WHERE realm_id = 2 AND recipient_id = {scotland_recipient} AND (search_tsvector ) sql_template = """\ -SELECT anon_1.message_id, anon_1.flags, anon_1.subject, anon_1.rendered_content, anon_1.content_matches, anon_1.topic_matches \n\ -FROM (SELECT message_id, flags, subject, rendered_content, array((SELECT ARRAY[sum(length(anon_3) - 11) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) + 11, strpos(anon_3, '') - 1] AS anon_2 \n\ +SELECT anon_1.message_id, anon_1.flags, anon_1.escaped_topic_name, anon_1.rendered_content, anon_1.content_matches, anon_1.topic_matches \n\ +FROM (SELECT message_id, flags, escape_html(subject) AS escaped_topic_name, rendered_content, array((SELECT ARRAY[sum(length(anon_3) - 11) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) + 11, strpos(anon_3, '') - 1] AS anon_2 \n\ FROM unnest(string_to_array(ts_headline('zulip.english_us_search', rendered_content, plainto_tsquery('zulip.english_us_search', '"jumping" quickly'), 'HighlightAll = TRUE, StartSel = , StopSel = '), '')) AS anon_3\n\ LIMIT ALL OFFSET 1)) AS content_matches, array((SELECT ARRAY[sum(length(anon_5) - 11) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) + 11, strpos(anon_5, '') - 1] AS anon_4 \n\ FROM unnest(string_to_array(ts_headline('zulip.english_us_search', escape_html(subject), plainto_tsquery('zulip.english_us_search', '"jumping" quickly'), 'HighlightAll = TRUE, StartSel = , StopSel = '), '')) AS anon_5\n\ diff --git a/zerver/views/message_fetch.py b/zerver/views/message_fetch.py index d104e01063..6fbaff170d 100644 --- a/zerver/views/message_fetch.py +++ b/zerver/views/message_fetch.py @@ -5,10 +5,9 @@ from django.conf import settings from django.contrib.auth.models import AnonymousUser from django.db import connection, transaction from django.http import HttpRequest, HttpResponse -from django.utils.html import escape as escape_html from django.utils.translation import gettext as _ from pydantic import Json, NonNegativeInt -from sqlalchemy.sql import column +from sqlalchemy.sql import column, func from sqlalchemy.types import Integer, Text from zerver.context_processors import get_valid_realm_from_request @@ -32,7 +31,7 @@ from zerver.lib.narrow import ( from zerver.lib.request import RequestNotes from zerver.lib.response import json_success from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection -from zerver.lib.topic import DB_TOPIC_NAME, MATCH_TOPIC +from zerver.lib.topic import MATCH_TOPIC from zerver.lib.topic_sqlalchemy import topic_column_sa from zerver.lib.typed_endpoint import ApiParamConfig, typed_endpoint from zerver.models import UserMessage, UserProfile @@ -79,13 +78,13 @@ def highlight_string(text: str, locs: Iterable[tuple[int, int]]) -> str: def get_search_fields( rendered_content: str, - topic_name: str, + escaped_topic_name: str, content_matches: Iterable[tuple[int, int]], topic_matches: Iterable[tuple[int, int]], ) -> dict[str, str]: return { "match_content": highlight_string(rendered_content, content_matches), - MATCH_TOPIC: highlight_string(escape_html(topic_name), topic_matches), + MATCH_TOPIC: highlight_string(escaped_topic_name, topic_matches), } @@ -294,9 +293,9 @@ def get_messages_backend( if is_search: for row in rows: message_id = row[0] - (topic_name, rendered_content, content_matches, topic_matches) = row[-4:] + (escaped_topic_name, rendered_content, content_matches, topic_matches) = row[-4:] search_fields[message_id] = get_search_fields( - rendered_content, topic_name, content_matches, topic_matches + rendered_content, escaped_topic_name, content_matches, topic_matches ) message_list = messages_for_ids( @@ -365,19 +364,22 @@ def messages_in_narrow_backend( if not is_search: # `add_narrow_conditions` adds the following columns only if narrow has search operands. - query = query.add_columns(topic_column_sa(), column("rendered_content", Text)) + query = query.add_columns( + func.escape_html(topic_column_sa(), type_=Text).label("escaped_topic_name"), + column("rendered_content", Text), + ) search_fields = {} with get_sqlalchemy_connection() as sa_conn: for row in sa_conn.execute(query).mappings(): message_id = row["message_id"] - topic_name: str = row[DB_TOPIC_NAME] + escaped_topic_name: str = row["escaped_topic_name"] rendered_content: str = row["rendered_content"] content_matches = row.get("content_matches", []) topic_matches = row.get("topic_matches", []) search_fields[str(message_id)] = get_search_fields( rendered_content, - topic_name, + escaped_topic_name, content_matches, topic_matches, )