diff --git a/zerver/lib/narrow.py b/zerver/lib/narrow.py index 3cd2d26ac3..1f3f06993d 100644 --- a/zerver/lib/narrow.py +++ b/zerver/lib/narrow.py @@ -1064,11 +1064,10 @@ def exclude_muting_conditions( def get_base_query_for_search( - realm_id: int, user_profile: UserProfile | None, need_message: bool, need_user_message: bool + realm_id: int, user_profile: UserProfile | None, need_user_message: bool ) -> tuple[Select, ColumnElement[Integer]]: # Handle the simple case where user_message isn't involved first. if not need_user_message: - assert need_message query = ( select(column("id", Integer).label("message_id")) .select_from(table("zerver_message")) @@ -1079,30 +1078,21 @@ def get_base_query_for_search( return (query, inner_msg_id_col) assert user_profile is not None - if need_message: - query = ( - select(column("message_id", Integer)) - # We don't limit by realm_id despite the join to - # zerver_messages, since the user_profile_id limit in - # usermessage is more selective, and the query planner - # can't know about that cross-table correlation. - .where(column("user_profile_id", Integer) == literal(user_profile.id)) - .select_from( - join( - table("zerver_usermessage"), - table("zerver_message"), - literal_column("zerver_usermessage.message_id", Integer) - == literal_column("zerver_message.id", Integer), - ) - ) - ) - inner_msg_id_col = column("message_id", Integer) - return (query, inner_msg_id_col) - query = ( select(column("message_id", Integer)) + # We don't limit by realm_id despite the join to + # zerver_messages, since the user_profile_id limit in + # usermessage is more selective, and the query planner + # can't know about that cross-table correlation. .where(column("user_profile_id", Integer) == literal(user_profile.id)) - .select_from(table("zerver_usermessage")) + .select_from( + join( + table("zerver_usermessage"), + table("zerver_message"), + literal_column("zerver_usermessage.message_id", Integer) + == literal_column("zerver_message.id", Integer), + ) + ) ) inner_msg_id_col = column("message_id", Integer) return (query, inner_msg_id_col) @@ -1160,17 +1150,9 @@ def find_first_unread_anchor( # flag for the user. need_user_message = True - # Because we will need to call exclude_muting_conditions, unless - # the user hasn't muted anything, we will need to include Message - # in our query. It may be worth eventually adding an optimization - # for the case of a user who hasn't muted anything to avoid the - # join in that case, but it's low priority. - need_message = True - query, inner_msg_id_col = get_base_query_for_search( realm_id=user_profile.realm_id, user_profile=user_profile, - need_message=need_message, need_user_message=need_user_message, ) query = query.add_columns(column("flags", Integer)) @@ -1442,15 +1424,8 @@ def fetch_messages( # # Note that is_web_public_query=True goes here, since # include_history is semantically correct for is_web_public_query. - need_message = True need_user_message = False - elif narrow is None: - # We need to limit to messages the user has received, but we don't actually - # need any fields from Message - need_message = False - need_user_message = True else: - need_message = True need_user_message = True # get_base_query_for_search and ok_to_include_history are responsible for ensuring @@ -1459,7 +1434,6 @@ def fetch_messages( query, inner_msg_id_col = get_base_query_for_search( realm_id=realm.id, user_profile=user_profile, - need_message=need_message, need_user_message=need_user_message, ) if need_user_message: diff --git a/zerver/tests/test_message_fetch.py b/zerver/tests/test_message_fetch.py index 6a70cad458..647025e603 100644 --- a/zerver/tests/test_message_fetch.py +++ b/zerver/tests/test_message_fetch.py @@ -4690,27 +4690,64 @@ AND NOT (recipient_id = %(recipient_id_4)s AND upper(subject) = upper(%(param_3) def test_get_messages_queries(self) -> None: query_ids = self.get_query_ids() - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 0}, sql) - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query({"anchor": 0, "num_before": 1, "num_after": 0}, sql) - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n LIMIT 2) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n\ + LIMIT 2) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 1}, sql) - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n LIMIT 11) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n\ + LIMIT 11) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 10}, sql) - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id <= 100 ORDER BY message_id DESC \n LIMIT 11) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} AND message_id <= 100 ORDER BY message_id DESC \n\ + LIMIT 11) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query({"anchor": 100, "num_before": 10, "num_after": 0}, sql) - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM ((SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id <= 99 ORDER BY message_id DESC \n LIMIT 10) UNION ALL (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id >= 100 ORDER BY message_id ASC \n LIMIT 11)) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM ((SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} AND message_id <= 99 ORDER BY message_id DESC \n\ + LIMIT 10) UNION ALL (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} AND message_id >= 100 ORDER BY message_id ASC \n\ + LIMIT 11)) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query( {"anchor": 100, "num_before": 10, "num_after": 10}, sql diff --git a/zerver/views/message_fetch.py b/zerver/views/message_fetch.py index 928655e61b..6823e45e2d 100644 --- a/zerver/views/message_fetch.py +++ b/zerver/views/message_fetch.py @@ -349,7 +349,7 @@ def messages_in_narrow_backend( # This query is limited to messages the user has access to because they # actually received them, as reflected in `zerver_usermessage`. query, inner_msg_id_col = get_base_query_for_search( - user_profile.realm_id, user_profile, need_message=True, need_user_message=True + user_profile.realm_id, user_profile, need_user_message=True ) query = query.where(column("message_id", Integer).in_(msg_ids))