mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-04 05:53:43 +00:00 
			
		
		
		
	get_realm is better in two key ways: * It uses memcached to fetch the data from the cache and thus is faster. * It does a case-insensitive query and thus is more safe.
		
			
				
	
	
		
			353 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			353 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from django.test import TestCase
 | 
						|
 | 
						|
from zerver.lib.initial_password import initial_password
 | 
						|
from zerver.lib.db import TimeTrackingCursor
 | 
						|
from zerver.lib import cache
 | 
						|
from zerver.lib import event_queue
 | 
						|
from zerver.worker import queue_processors
 | 
						|
 | 
						|
from zerver.lib.actions import (
 | 
						|
    check_send_message, create_stream_if_needed, do_add_subscription,
 | 
						|
    get_display_recipient,
 | 
						|
)
 | 
						|
 | 
						|
from zerver.models import (
 | 
						|
    get_realm,
 | 
						|
    get_user_profile_by_email,
 | 
						|
    resolve_email_to_domain,
 | 
						|
    Client,
 | 
						|
    Message,
 | 
						|
    Realm,
 | 
						|
    Recipient,
 | 
						|
    Stream,
 | 
						|
    Subscription,
 | 
						|
    UserMessage,
 | 
						|
)
 | 
						|
 | 
						|
import base64
 | 
						|
import os
 | 
						|
import re
 | 
						|
import time
 | 
						|
import ujson
 | 
						|
import urllib
 | 
						|
 | 
						|
from contextlib import contextmanager
 | 
						|
 | 
						|
API_KEYS = {}
 | 
						|
 | 
						|
@contextmanager
 | 
						|
def stub(obj, name, f):
 | 
						|
    old_f = getattr(obj, name)
 | 
						|
    setattr(obj, name, f)
 | 
						|
    yield
 | 
						|
    setattr(obj, name, old_f)
 | 
						|
 | 
						|
@contextmanager
 | 
						|
def simulated_queue_client(client):
 | 
						|
    real_SimpleQueueClient = queue_processors.SimpleQueueClient
 | 
						|
    queue_processors.SimpleQueueClient = client
 | 
						|
    yield
 | 
						|
    queue_processors.SimpleQueueClient = real_SimpleQueueClient
 | 
						|
 | 
						|
@contextmanager
 | 
						|
def tornado_redirected_to_list(lst):
 | 
						|
    real_event_queue_process_notification = event_queue.process_notification
 | 
						|
    event_queue.process_notification = lst.append
 | 
						|
    yield
 | 
						|
    event_queue.process_notification = real_event_queue_process_notification
 | 
						|
 | 
						|
@contextmanager
 | 
						|
def simulated_empty_cache():
 | 
						|
    cache_queries = []
 | 
						|
    def my_cache_get(key, cache_name=None):
 | 
						|
        cache_queries.append(('get', key, cache_name))
 | 
						|
        return None
 | 
						|
 | 
						|
    def my_cache_get_many(keys, cache_name=None):
 | 
						|
        cache_queries.append(('getmany', keys, cache_name))
 | 
						|
        return None
 | 
						|
 | 
						|
    old_get = cache.cache_get
 | 
						|
    old_get_many = cache.cache_get_many
 | 
						|
    cache.cache_get = my_cache_get
 | 
						|
    cache.cache_get_many = my_cache_get_many
 | 
						|
    yield cache_queries
 | 
						|
    cache.cache_get = old_get
 | 
						|
    cache.cache_get_many = old_get_many
 | 
						|
 | 
						|
@contextmanager
 | 
						|
def queries_captured():
 | 
						|
    '''
 | 
						|
    Allow a user to capture just the queries executed during
 | 
						|
    the with statement.
 | 
						|
    '''
 | 
						|
 | 
						|
    queries = []
 | 
						|
 | 
						|
    def wrapper_execute(self, action, sql, params=()):
 | 
						|
        start = time.time()
 | 
						|
        try:
 | 
						|
            return action(sql, params)
 | 
						|
        finally:
 | 
						|
            stop = time.time()
 | 
						|
            duration = stop - start
 | 
						|
            queries.append({
 | 
						|
                    'sql': self.mogrify(sql, params),
 | 
						|
                    'time': "%.3f" % duration,
 | 
						|
                    })
 | 
						|
 | 
						|
    old_execute = TimeTrackingCursor.execute
 | 
						|
    old_executemany = TimeTrackingCursor.executemany
 | 
						|
 | 
						|
    def cursor_execute(self, sql, params=()):
 | 
						|
        return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params)
 | 
						|
    TimeTrackingCursor.execute = cursor_execute
 | 
						|
 | 
						|
    def cursor_executemany(self, sql, params=()):
 | 
						|
        return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params)
 | 
						|
    TimeTrackingCursor.executemany = cursor_executemany
 | 
						|
 | 
						|
    yield queries
 | 
						|
 | 
						|
    TimeTrackingCursor.execute = old_execute
 | 
						|
    TimeTrackingCursor.executemany = old_executemany
 | 
						|
 | 
						|
 | 
						|
def find_key_by_email(address):
 | 
						|
    from django.core.mail import outbox
 | 
						|
    key_regex = re.compile("accounts/do_confirm/([a-f0-9]{40})>")
 | 
						|
    for message in reversed(outbox):
 | 
						|
        if address in message.to:
 | 
						|
            return key_regex.search(message.body).groups()[0]
 | 
						|
 | 
						|
def message_ids(result):
 | 
						|
    return set(message['id'] for message in result['messages'])
 | 
						|
 | 
						|
def message_stream_count(user_profile):
 | 
						|
    return UserMessage.objects. \
 | 
						|
        select_related("message"). \
 | 
						|
        filter(user_profile=user_profile). \
 | 
						|
        count()
 | 
						|
 | 
						|
def most_recent_usermessage(user_profile):
 | 
						|
    query = UserMessage.objects. \
 | 
						|
        select_related("message"). \
 | 
						|
        filter(user_profile=user_profile). \
 | 
						|
        order_by('-message')
 | 
						|
    return query[0] # Django does LIMIT here
 | 
						|
 | 
						|
def most_recent_message(user_profile):
 | 
						|
    usermessage = most_recent_usermessage(user_profile)
 | 
						|
    return usermessage.message
 | 
						|
 | 
						|
def get_user_messages(user_profile):
 | 
						|
    query = UserMessage.objects. \
 | 
						|
        select_related("message"). \
 | 
						|
        filter(user_profile=user_profile). \
 | 
						|
        order_by('message')
 | 
						|
    return [um.message for um in query]
 | 
						|
 | 
						|
class DummyObject:
 | 
						|
    pass
 | 
						|
 | 
						|
class DummyTornadoRequest:
 | 
						|
    def __init__(self):
 | 
						|
        self.connection = DummyObject()
 | 
						|
        self.connection.stream = DummyStream()
 | 
						|
 | 
						|
class DummyHandler(object):
 | 
						|
    def __init__(self, assert_callback):
 | 
						|
        self.assert_callback = assert_callback
 | 
						|
        self.request = DummyTornadoRequest()
 | 
						|
 | 
						|
    # Mocks RequestHandler.async_callback, which wraps a callback to
 | 
						|
    # handle exceptions.  We return the callback as-is.
 | 
						|
    def async_callback(self, cb):
 | 
						|
        return cb
 | 
						|
 | 
						|
    def write(self, response):
 | 
						|
        raise NotImplemented
 | 
						|
 | 
						|
    def zulip_finish(self, response, *ignore):
 | 
						|
        if self.assert_callback:
 | 
						|
            self.assert_callback(response)
 | 
						|
 | 
						|
 | 
						|
class DummySession(object):
 | 
						|
    session_key = "0"
 | 
						|
 | 
						|
class DummyStream:
 | 
						|
    def closed(self):
 | 
						|
        return False
 | 
						|
 | 
						|
class POSTRequestMock(object):
 | 
						|
    method = "POST"
 | 
						|
 | 
						|
    def __init__(self, post_data, user_profile, assert_callback=None):
 | 
						|
        self.REQUEST = self.POST = post_data
 | 
						|
        self.user = user_profile
 | 
						|
        self._tornado_handler = DummyHandler(assert_callback)
 | 
						|
        self.session = DummySession()
 | 
						|
        self._log_data = {}
 | 
						|
        self.META = {'PATH_INFO': 'test'}
 | 
						|
        self._log_data = {}
 | 
						|
 | 
						|
class AuthedTestCase(TestCase):
 | 
						|
    # Helper because self.client.patch annoying requires you to urlencode
 | 
						|
    def client_patch(self, url, info={}, **kwargs):
 | 
						|
        info = urllib.urlencode(info)
 | 
						|
        return self.client.patch(url, info, **kwargs)
 | 
						|
    def client_put(self, url, info={}, **kwargs):
 | 
						|
        info = urllib.urlencode(info)
 | 
						|
        return self.client.put(url, info, **kwargs)
 | 
						|
    def client_delete(self, url, info={}, **kwargs):
 | 
						|
        info = urllib.urlencode(info)
 | 
						|
        return self.client.delete(url, info, **kwargs)
 | 
						|
 | 
						|
    def login(self, email, password=None):
 | 
						|
        if password is None:
 | 
						|
            password = initial_password(email)
 | 
						|
        return self.client.post('/accounts/login/',
 | 
						|
                                {'username':email, 'password':password})
 | 
						|
 | 
						|
    def register(self, username, password, domain="zulip.com"):
 | 
						|
        self.client.post('/accounts/home/',
 | 
						|
                         {'email': username + "@" + domain})
 | 
						|
        return self.submit_reg_form_for_user(username, password, domain=domain)
 | 
						|
 | 
						|
    def submit_reg_form_for_user(self, username, password, domain="zulip.com"):
 | 
						|
        """
 | 
						|
        Stage two of the two-step registration process.
 | 
						|
 | 
						|
        If things are working correctly the account should be fully
 | 
						|
        registered after this call.
 | 
						|
        """
 | 
						|
        return self.client.post('/accounts/register/',
 | 
						|
                                {'full_name': username, 'password': password,
 | 
						|
                                 'key': find_key_by_email(username + '@' + domain),
 | 
						|
                                 'terms': True})
 | 
						|
 | 
						|
    def get_api_key(self, email):
 | 
						|
        if email not in API_KEYS:
 | 
						|
            API_KEYS[email] =  get_user_profile_by_email(email).api_key
 | 
						|
        return API_KEYS[email]
 | 
						|
 | 
						|
    def api_auth(self, email):
 | 
						|
        credentials = "%s:%s" % (email, self.get_api_key(email))
 | 
						|
        return {
 | 
						|
            'HTTP_AUTHORIZATION': 'Basic ' + base64.b64encode(credentials)
 | 
						|
            }
 | 
						|
 | 
						|
    def get_streams(self, email):
 | 
						|
        """
 | 
						|
        Helper function to get the stream names for a user
 | 
						|
        """
 | 
						|
        user_profile = get_user_profile_by_email(email)
 | 
						|
        subs = Subscription.objects.filter(
 | 
						|
            user_profile    = user_profile,
 | 
						|
            active          = True,
 | 
						|
            recipient__type = Recipient.STREAM)
 | 
						|
        return [get_display_recipient(sub.recipient) for sub in subs]
 | 
						|
 | 
						|
    def send_message(self, sender_name, recipient_list, message_type,
 | 
						|
                     content="test content", subject="test", **kwargs):
 | 
						|
        sender = get_user_profile_by_email(sender_name)
 | 
						|
        if message_type == Recipient.PERSONAL:
 | 
						|
            message_type_name = "private"
 | 
						|
        else:
 | 
						|
            message_type_name = "stream"
 | 
						|
        if isinstance(recipient_list, basestring):
 | 
						|
            recipient_list = [recipient_list]
 | 
						|
        (sending_client, _) = Client.objects.get_or_create(name="test suite")
 | 
						|
 | 
						|
        return check_send_message(
 | 
						|
            sender, sending_client, message_type_name, recipient_list, subject,
 | 
						|
            content, forged=False, forged_timestamp=None,
 | 
						|
            forwarder_user_profile=sender, realm=sender.realm, **kwargs)
 | 
						|
 | 
						|
    def get_old_messages(self, anchor=1, num_before=100, num_after=100):
 | 
						|
        post_params = {"anchor": anchor, "num_before": num_before,
 | 
						|
                       "num_after": num_after}
 | 
						|
        result = self.client.post("/json/get_old_messages", dict(post_params))
 | 
						|
        data = ujson.loads(result.content)
 | 
						|
        return data['messages']
 | 
						|
 | 
						|
    def users_subscribed_to_stream(self, stream_name, realm_domain):
 | 
						|
        realm = get_realm(realm_domain)
 | 
						|
        stream = Stream.objects.get(name=stream_name, realm=realm)
 | 
						|
        recipient = Recipient.objects.get(type_id=stream.id, type=Recipient.STREAM)
 | 
						|
        subscriptions = Subscription.objects.filter(recipient=recipient, active=True)
 | 
						|
 | 
						|
        return [subscription.user_profile for subscription in subscriptions]
 | 
						|
 | 
						|
    def assert_json_success(self, result):
 | 
						|
        """
 | 
						|
        Successful POSTs return a 200 and JSON of the form {"result": "success",
 | 
						|
        "msg": ""}.
 | 
						|
        """
 | 
						|
        self.assertEqual(result.status_code, 200, result)
 | 
						|
        json = ujson.loads(result.content)
 | 
						|
        self.assertEqual(json.get("result"), "success")
 | 
						|
        # We have a msg key for consistency with errors, but it typically has an
 | 
						|
        # empty value.
 | 
						|
        self.assertIn("msg", json)
 | 
						|
        return json
 | 
						|
 | 
						|
    def get_json_error(self, result, status_code=400):
 | 
						|
        self.assertEqual(result.status_code, status_code)
 | 
						|
        json = ujson.loads(result.content)
 | 
						|
        self.assertEqual(json.get("result"), "error")
 | 
						|
        return json['msg']
 | 
						|
 | 
						|
    def assert_json_error(self, result, msg, status_code=400):
 | 
						|
        """
 | 
						|
        Invalid POSTs return an error status code and JSON of the form
 | 
						|
        {"result": "error", "msg": "reason"}.
 | 
						|
        """
 | 
						|
        self.assertEqual(self.get_json_error(result, status_code=status_code), msg)
 | 
						|
 | 
						|
    def assert_length(self, queries, count, exact=False):
 | 
						|
        actual_count = len(queries)
 | 
						|
        if exact:
 | 
						|
            return self.assertTrue(actual_count == count,
 | 
						|
                                   "len(%s) == %s, != %s" % (queries, actual_count, count))
 | 
						|
        return self.assertTrue(actual_count <= count,
 | 
						|
                               "len(%s) == %s, > %s" % (queries, actual_count, count))
 | 
						|
 | 
						|
    def assert_json_error_contains(self, result, msg_substring):
 | 
						|
        self.assertIn(msg_substring, self.get_json_error(result))
 | 
						|
 | 
						|
    def fixture_data(self, type, action, file_type='json'):
 | 
						|
        return open(os.path.join(os.path.dirname(__file__),
 | 
						|
                                 "../fixtures/%s/%s_%s.%s" % (type, type, action,file_type))).read()
 | 
						|
 | 
						|
    # Subscribe to a stream directly
 | 
						|
    def subscribe_to_stream(self, email, stream_name, realm=None):
 | 
						|
        realm = get_realm(resolve_email_to_domain(email))
 | 
						|
        stream, _ = create_stream_if_needed(realm, stream_name)
 | 
						|
        user_profile = get_user_profile_by_email(email)
 | 
						|
        do_add_subscription(user_profile, stream, no_log=True)
 | 
						|
 | 
						|
    # Subscribe to a stream by making an API request
 | 
						|
    def common_subscribe_to_streams(self, email, streams, extra_post_data = {}, invite_only=False):
 | 
						|
        post_data = {'subscriptions': ujson.dumps([{"name": stream} for stream in streams]),
 | 
						|
                     'invite_only': ujson.dumps(invite_only)}
 | 
						|
        post_data.update(extra_post_data)
 | 
						|
        result = self.client.post("/api/v1/users/me/subscriptions", post_data, **self.api_auth(email))
 | 
						|
        return result
 | 
						|
 | 
						|
    def send_json_payload(self, email, url, payload, stream_name=None, **post_params):
 | 
						|
        if stream_name != None:
 | 
						|
            self.subscribe_to_stream(email, stream_name)
 | 
						|
 | 
						|
        result = self.client.post(url, payload, **post_params)
 | 
						|
        self.assert_json_success(result)
 | 
						|
 | 
						|
        # Check the correct message was sent
 | 
						|
        msg = Message.objects.filter().order_by('-id')[0]
 | 
						|
        self.assertEqual(msg.sender.email, email)
 | 
						|
        self.assertEqual(get_display_recipient(msg.recipient), stream_name)
 | 
						|
 | 
						|
        return msg
 | 
						|
 |