Use topic_match_sa() for topic searches.

Note this introduce literal(), which makes the way
we handle topic mutes more consistent with general
topic searches.
This commit is contained in:
Steve Howell
2018-11-01 21:15:43 +00:00
committed by Tim Abbott
parent 79d5e36ca3
commit ff60055fa4
3 changed files with 32 additions and 28 deletions

View File

@@ -6,6 +6,7 @@ from django.utils.timezone import now as timezone_now
from sqlalchemy.sql import ( from sqlalchemy.sql import (
column, column,
literal,
func, func,
) )
@@ -26,7 +27,7 @@ PREV_TOPIC = "prev_subject"
def topic_match_sa(topic_name: str) -> Any: def topic_match_sa(topic_name: str) -> Any:
# _sa is short for Sql Alchemy, which we use mostly for # _sa is short for Sql Alchemy, which we use mostly for
# queries that search messages # queries that search messages
topic_cond = func.upper(column("subject")) == func.upper(topic_name) topic_cond = func.upper(column("subject")) == func.upper(literal(topic_name))
return topic_cond return topic_cond
def filter_by_exact_message_topic(query: QuerySet, message: Message) -> QuerySet: def filter_by_exact_message_topic(query: QuerySet, message: Message) -> QuerySet:

View File

@@ -2410,13 +2410,13 @@ class GetOldMessagesTest(ZulipTestCase):
expected_query = ''' expected_query = '''
SELECT id AS message_id SELECT id AS message_id
FROM zerver_message FROM zerver_message
WHERE NOT (recipient_id = :recipient_id_1 AND upper(subject) = upper(:upper_1)) WHERE NOT (recipient_id = :recipient_id_1 AND upper(subject) = upper(:param_1))
''' '''
self.assertEqual(fix_ws(query), fix_ws(expected_query)) self.assertEqual(fix_ws(query), fix_ws(expected_query))
params = get_sqlalchemy_query_params(query) params = get_sqlalchemy_query_params(query)
self.assertEqual(params['recipient_id_1'], get_recipient_id_for_stream_name(realm, 'Scotland')) self.assertEqual(params['recipient_id_1'], get_recipient_id_for_stream_name(realm, 'Scotland'))
self.assertEqual(params['upper_1'], 'golf') self.assertEqual(params['param_1'], 'golf')
mute_stream(realm, user_profile, 'Verona') mute_stream(realm, user_profile, 'Verona')
@@ -2435,15 +2435,15 @@ class GetOldMessagesTest(ZulipTestCase):
FROM zerver_message FROM zerver_message
WHERE recipient_id NOT IN (:recipient_id_1) WHERE recipient_id NOT IN (:recipient_id_1)
AND NOT AND NOT
(recipient_id = :recipient_id_2 AND upper(subject) = upper(:upper_1) OR (recipient_id = :recipient_id_2 AND upper(subject) = upper(:param_1) OR
recipient_id = :recipient_id_3 AND upper(subject) = upper(:upper_2))''' recipient_id = :recipient_id_3 AND upper(subject) = upper(:param_2))'''
self.assertEqual(fix_ws(query), fix_ws(expected_query)) self.assertEqual(fix_ws(query), fix_ws(expected_query))
params = get_sqlalchemy_query_params(query) params = get_sqlalchemy_query_params(query)
self.assertEqual(params['recipient_id_1'], get_recipient_id_for_stream_name(realm, 'Verona')) self.assertEqual(params['recipient_id_1'], get_recipient_id_for_stream_name(realm, 'Verona'))
self.assertEqual(params['recipient_id_2'], get_recipient_id_for_stream_name(realm, 'Scotland')) self.assertEqual(params['recipient_id_2'], get_recipient_id_for_stream_name(realm, 'Scotland'))
self.assertEqual(params['upper_1'], 'golf') self.assertEqual(params['param_1'], 'golf')
self.assertEqual(params['recipient_id_3'], get_recipient_id_for_stream_name(realm, 'web stuff')) self.assertEqual(params['recipient_id_3'], get_recipient_id_for_stream_name(realm, 'web stuff'))
self.assertEqual(params['upper_2'], 'css') self.assertEqual(params['param_2'], 'css')
def test_get_messages_queries(self) -> None: def test_get_messages_queries(self) -> None:
query_ids = self.get_query_ids() query_ids = self.get_query_ids()

View File

@@ -33,6 +33,9 @@ from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection
from zerver.lib.streams import access_stream_by_id, can_access_stream_history_by_name from zerver.lib.streams import access_stream_by_id, can_access_stream_history_by_name
from zerver.lib.timestamp import datetime_to_timestamp, convert_to_UTC from zerver.lib.timestamp import datetime_to_timestamp, convert_to_UTC
from zerver.lib.timezone import get_timezone from zerver.lib.timezone import get_timezone
from zerver.lib.topic import (
topic_match_sa,
)
from zerver.lib.topic_mutes import exclude_topic_mutes from zerver.lib.topic_mutes import exclude_topic_mutes
from zerver.lib.utils import statsd from zerver.lib.utils import statsd
from zerver.lib.validator import \ from zerver.lib.validator import \
@@ -241,36 +244,36 @@ class NarrowBuilder:
# instance "personal" to be the same. # instance "personal" to be the same.
if base_topic in ('', 'personal', '(instance "")'): if base_topic in ('', 'personal', '(instance "")'):
cond = or_( cond = or_(
func.upper(column("subject")) == func.upper(literal("")), topic_match_sa(""),
func.upper(column("subject")) == func.upper(literal(".d")), topic_match_sa(".d"),
func.upper(column("subject")) == func.upper(literal(".d.d")), topic_match_sa(".d.d"),
func.upper(column("subject")) == func.upper(literal(".d.d.d")), topic_match_sa(".d.d.d"),
func.upper(column("subject")) == func.upper(literal(".d.d.d.d")), topic_match_sa(".d.d.d.d"),
func.upper(column("subject")) == func.upper(literal("personal")), topic_match_sa("personal"),
func.upper(column("subject")) == func.upper(literal("personal.d")), topic_match_sa("personal.d"),
func.upper(column("subject")) == func.upper(literal("personal.d.d")), topic_match_sa("personal.d.d"),
func.upper(column("subject")) == func.upper(literal("personal.d.d.d")), topic_match_sa("personal.d.d.d"),
func.upper(column("subject")) == func.upper(literal("personal.d.d.d.d")), topic_match_sa("personal.d.d.d.d"),
func.upper(column("subject")) == func.upper(literal('(instance "")')), topic_match_sa('(instance "")'),
func.upper(column("subject")) == func.upper(literal('(instance "").d')), topic_match_sa('(instance "").d'),
func.upper(column("subject")) == func.upper(literal('(instance "").d.d')), topic_match_sa('(instance "").d.d'),
func.upper(column("subject")) == func.upper(literal('(instance "").d.d.d')), topic_match_sa('(instance "").d.d.d'),
func.upper(column("subject")) == func.upper(literal('(instance "").d.d.d.d')), topic_match_sa('(instance "").d.d.d.d'),
) )
else: else:
# We limit `.d` counts, since postgres has much better # We limit `.d` counts, since postgres has much better
# query planning for this than they do for a regular # query planning for this than they do for a regular
# expression (which would sometimes table scan). # expression (which would sometimes table scan).
cond = or_( cond = or_(
func.upper(column("subject")) == func.upper(literal(base_topic)), topic_match_sa(base_topic),
func.upper(column("subject")) == func.upper(literal(base_topic + ".d")), topic_match_sa(base_topic + ".d"),
func.upper(column("subject")) == func.upper(literal(base_topic + ".d.d")), topic_match_sa(base_topic + ".d.d"),
func.upper(column("subject")) == func.upper(literal(base_topic + ".d.d.d")), topic_match_sa(base_topic + ".d.d.d"),
func.upper(column("subject")) == func.upper(literal(base_topic + ".d.d.d.d")), topic_match_sa(base_topic + ".d.d.d.d"),
) )
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
cond = func.upper(column("subject")) == func.upper(literal(operand)) cond = topic_match_sa(operand)
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
def by_sender(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: def by_sender(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: