mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-03 21:43:21 +00:00 
			
		
		
		
	Add type annotations to zerver.lib.test_helpers.
This commit is contained in:
		@@ -1,8 +1,9 @@
 | 
			
		||||
from __future__ import absolute_import
 | 
			
		||||
from typing import Any, Callable, Generator, Iterable, Tuple
 | 
			
		||||
from typing import Any, Callable, Generator, Iterable, Tuple, Sized, Union, Optional
 | 
			
		||||
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
from django.template import loader
 | 
			
		||||
from django.http import HttpResponse
 | 
			
		||||
 | 
			
		||||
from zerver.lib.initial_password import initial_password
 | 
			
		||||
from zerver.lib.db import TimeTrackingCursor
 | 
			
		||||
@@ -16,8 +17,6 @@ from zerver.lib.actions import (
 | 
			
		||||
    get_display_recipient,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from zerver.lib.handlers import allocate_handler_id
 | 
			
		||||
 | 
			
		||||
from zerver.models import (
 | 
			
		||||
    get_realm,
 | 
			
		||||
    get_stream,
 | 
			
		||||
@@ -30,6 +29,7 @@ from zerver.models import (
 | 
			
		||||
    Stream,
 | 
			
		||||
    Subscription,
 | 
			
		||||
    UserMessage,
 | 
			
		||||
    UserProfile,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import base64
 | 
			
		||||
@@ -70,13 +70,15 @@ def tornado_redirected_to_list(lst):
 | 
			
		||||
 | 
			
		||||
@contextmanager
 | 
			
		||||
def simulated_empty_cache():
 | 
			
		||||
    # type: () -> Generator[List[Tuple[str, str, str]], None, None]
 | 
			
		||||
    cache_queries = []
 | 
			
		||||
    # type: () -> Generator[List[Tuple[str, Union[str, List[str]], str]], None, None]
 | 
			
		||||
    cache_queries = [] # type: List[Tuple[str, Union[str, List[str]], str]]
 | 
			
		||||
    def my_cache_get(key, cache_name=None):
 | 
			
		||||
        # type: (str, Optional[str]) -> Any
 | 
			
		||||
        cache_queries.append(('get', key, cache_name))
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def my_cache_get_many(keys, cache_name=None):
 | 
			
		||||
        # type: (List[str], Optional[str]) -> Dict[str, Any]
 | 
			
		||||
        cache_queries.append(('getmany', keys, cache_name))
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
@@ -99,6 +101,7 @@ def queries_captured():
 | 
			
		||||
    queries = []
 | 
			
		||||
 | 
			
		||||
    def wrapper_execute(self, action, sql, params=()):
 | 
			
		||||
        # type: (TimeTrackingCursor, Callable, str, Iterable[Any]) -> None
 | 
			
		||||
        start = time.time()
 | 
			
		||||
        try:
 | 
			
		||||
            return action(sql, params)
 | 
			
		||||
@@ -106,19 +109,21 @@ def queries_captured():
 | 
			
		||||
            stop = time.time()
 | 
			
		||||
            duration = stop - start
 | 
			
		||||
            queries.append({
 | 
			
		||||
                    'sql': self.mogrify(sql, params),
 | 
			
		||||
                    'time': "%.3f" % duration,
 | 
			
		||||
                    })
 | 
			
		||||
                'sql': self.mogrify(sql, params),
 | 
			
		||||
                'time': "%.3f" % duration,
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
    old_execute = TimeTrackingCursor.execute
 | 
			
		||||
    old_executemany = TimeTrackingCursor.executemany
 | 
			
		||||
 | 
			
		||||
    def cursor_execute(self, sql, params=()):
 | 
			
		||||
        return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params)  # type: ignore # https://github.com/JukkaL/mypy/issues/1167
 | 
			
		||||
        # type: (TimeTrackingCursor, str, Iterable[Any]) -> None
 | 
			
		||||
        return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params) # type: ignore # https://github.com/JukkaL/mypy/issues/1167
 | 
			
		||||
    TimeTrackingCursor.execute = cursor_execute # type: ignore # https://github.com/JukkaL/mypy/issues/1167
 | 
			
		||||
 | 
			
		||||
    def cursor_executemany(self, sql, params=()):
 | 
			
		||||
        return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params)  # type: ignore # https://github.com/JukkaL/mypy/issues/1167
 | 
			
		||||
        # type: (TimeTrackingCursor, str, Iterable[Any]) -> None
 | 
			
		||||
        return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params) # type: ignore # https://github.com/JukkaL/mypy/issues/1167
 | 
			
		||||
    TimeTrackingCursor.executemany = cursor_executemany # type: ignore # https://github.com/JukkaL/mypy/issues/1167
 | 
			
		||||
 | 
			
		||||
    yield queries
 | 
			
		||||
@@ -128,6 +133,7 @@ def queries_captured():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_key_by_email(address):
 | 
			
		||||
    # type: (str) -> str
 | 
			
		||||
    from django.core.mail import outbox
 | 
			
		||||
    key_regex = re.compile("accounts/do_confirm/([a-f0-9]{40})>")
 | 
			
		||||
    for message in reversed(outbox):
 | 
			
		||||
@@ -135,15 +141,18 @@ def find_key_by_email(address):
 | 
			
		||||
            return key_regex.search(message.body).groups()[0]
 | 
			
		||||
 | 
			
		||||
def message_ids(result):
 | 
			
		||||
    # type: (Dict[str, Any]) -> Set[int]
 | 
			
		||||
    return set(message['id'] for message in result['messages'])
 | 
			
		||||
 | 
			
		||||
def message_stream_count(user_profile):
 | 
			
		||||
    # type: (UserProfile) -> int
 | 
			
		||||
    return UserMessage.objects. \
 | 
			
		||||
        select_related("message"). \
 | 
			
		||||
        filter(user_profile=user_profile). \
 | 
			
		||||
        count()
 | 
			
		||||
 | 
			
		||||
def most_recent_usermessage(user_profile):
 | 
			
		||||
    # type: (UserProfile) -> UserMessage
 | 
			
		||||
    query = UserMessage.objects. \
 | 
			
		||||
        select_related("message"). \
 | 
			
		||||
        filter(user_profile=user_profile). \
 | 
			
		||||
@@ -151,10 +160,12 @@ def most_recent_usermessage(user_profile):
 | 
			
		||||
    return query[0] # Django does LIMIT here
 | 
			
		||||
 | 
			
		||||
def most_recent_message(user_profile):
 | 
			
		||||
    # type: (UserProfile) -> Message
 | 
			
		||||
    usermessage = most_recent_usermessage(user_profile)
 | 
			
		||||
    return usermessage.message
 | 
			
		||||
 | 
			
		||||
def get_user_messages(user_profile):
 | 
			
		||||
    # type: (UserProfile) -> List[Message]
 | 
			
		||||
    query = UserMessage.objects. \
 | 
			
		||||
        select_related("message"). \
 | 
			
		||||
        filter(user_profile=user_profile). \
 | 
			
		||||
@@ -166,11 +177,13 @@ class DummyObject(object):
 | 
			
		||||
 | 
			
		||||
class DummyTornadoRequest(object):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        # type: () -> None
 | 
			
		||||
        self.connection = DummyObject()
 | 
			
		||||
        self.connection.stream = DummyStream() # type: ignore # monkey-patching here
 | 
			
		||||
 | 
			
		||||
class DummyHandler(object):
 | 
			
		||||
    def __init__(self, assert_callback):
 | 
			
		||||
        # type: (Any) -> None
 | 
			
		||||
        self.assert_callback = assert_callback
 | 
			
		||||
        self.request = DummyTornadoRequest()
 | 
			
		||||
        allocate_handler_id(self)
 | 
			
		||||
@@ -178,12 +191,15 @@ class DummyHandler(object):
 | 
			
		||||
    # Mocks RequestHandler.async_callback, which wraps a callback to
 | 
			
		||||
    # handle exceptions.  We return the callback as-is.
 | 
			
		||||
    def async_callback(self, cb):
 | 
			
		||||
        # type: (Callable) -> Callable
 | 
			
		||||
        return cb
 | 
			
		||||
 | 
			
		||||
    def write(self, response):
 | 
			
		||||
        # type: (str) -> None
 | 
			
		||||
        raise NotImplemented
 | 
			
		||||
 | 
			
		||||
    def zulip_finish(self, response, *ignore):
 | 
			
		||||
        # type: (HttpResponse, *Any) -> None
 | 
			
		||||
        if self.assert_callback:
 | 
			
		||||
            self.assert_callback(response)
 | 
			
		||||
 | 
			
		||||
@@ -193,12 +209,14 @@ class DummySession(object):
 | 
			
		||||
 | 
			
		||||
class DummyStream(object):
 | 
			
		||||
    def closed(self):
 | 
			
		||||
        # type: () -> bool
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
class POSTRequestMock(object):
 | 
			
		||||
    method = "POST"
 | 
			
		||||
 | 
			
		||||
    def __init__(self, post_data, user_profile, assert_callback=None):
 | 
			
		||||
        # type: (Dict[str, Any], UserProfile, Optional[Callable]) -> None
 | 
			
		||||
        self.REQUEST = self.POST = post_data
 | 
			
		||||
        self.user = user_profile
 | 
			
		||||
        self._tornado_handler = DummyHandler(assert_callback)
 | 
			
		||||
@@ -208,28 +226,37 @@ class POSTRequestMock(object):
 | 
			
		||||
 | 
			
		||||
class AuthedTestCase(TestCase):
 | 
			
		||||
    # Helper because self.client.patch annoying requires you to urlencode
 | 
			
		||||
 | 
			
		||||
    def client_patch(self, url, info={}, **kwargs):
 | 
			
		||||
        info = urllib.parse.urlencode(info)
 | 
			
		||||
        return self.client.patch(url, info, **kwargs)
 | 
			
		||||
        # type: (str, Dict[str, Any], **Any) -> HttpResponse
 | 
			
		||||
        encoded = urllib.parse.urlencode(info)
 | 
			
		||||
        return self.client.patch(url, encoded, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def client_put(self, url, info={}, **kwargs):
 | 
			
		||||
        info = urllib.parse.urlencode(info)
 | 
			
		||||
        return self.client.put(url, info, **kwargs)
 | 
			
		||||
        # type: (str, Dict[str, Any], **Any) -> HttpResponse
 | 
			
		||||
        encoded = urllib.parse.urlencode(info)
 | 
			
		||||
        return self.client.put(url, encoded, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def client_delete(self, url, info={}, **kwargs):
 | 
			
		||||
        info = urllib.parse.urlencode(info)
 | 
			
		||||
        return self.client.delete(url, info, **kwargs)
 | 
			
		||||
        # type: (str, Dict[str, Any], **Any) -> HttpResponse
 | 
			
		||||
        encoded = urllib.parse.urlencode(info)
 | 
			
		||||
        return self.client.delete(url, encoded, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def login(self, email, password=None):
 | 
			
		||||
        # type: (str, Optional[str]) -> HttpResponse
 | 
			
		||||
        if password is None:
 | 
			
		||||
            password = initial_password(email)
 | 
			
		||||
        return self.client.post('/accounts/login/',
 | 
			
		||||
                                {'username': email, 'password': password})
 | 
			
		||||
 | 
			
		||||
    def register(self, username, password, domain="zulip.com"):
 | 
			
		||||
        # type: (str, str, str) -> HttpResponse
 | 
			
		||||
        self.client.post('/accounts/home/',
 | 
			
		||||
                         {'email': username + "@" + domain})
 | 
			
		||||
        return self.submit_reg_form_for_user(username, password, domain=domain)
 | 
			
		||||
 | 
			
		||||
    def submit_reg_form_for_user(self, username, password, domain="zulip.com"):
 | 
			
		||||
        # type: (str, str, str) -> HttpResponse
 | 
			
		||||
        """
 | 
			
		||||
        Stage two of the two-step registration process.
 | 
			
		||||
 | 
			
		||||
@@ -242,29 +269,33 @@ class AuthedTestCase(TestCase):
 | 
			
		||||
                                 'terms': True})
 | 
			
		||||
 | 
			
		||||
    def get_api_key(self, email):
 | 
			
		||||
        # type: (str) -> str
 | 
			
		||||
        if email not in API_KEYS:
 | 
			
		||||
            API_KEYS[email] =  get_user_profile_by_email(email).api_key
 | 
			
		||||
            API_KEYS[email] = get_user_profile_by_email(email).api_key
 | 
			
		||||
        return API_KEYS[email]
 | 
			
		||||
 | 
			
		||||
    def api_auth(self, email):
 | 
			
		||||
        # type: (str) -> Dict[str, str]
 | 
			
		||||
        credentials = "%s:%s" % (email, self.get_api_key(email))
 | 
			
		||||
        return {
 | 
			
		||||
            'HTTP_AUTHORIZATION': 'Basic ' + base64.b64encode(credentials)
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def get_streams(self, email):
 | 
			
		||||
        # type: (str) -> Iterable[Dict[str, Any]]
 | 
			
		||||
        """
 | 
			
		||||
        Helper function to get the stream names for a user
 | 
			
		||||
        """
 | 
			
		||||
        user_profile = get_user_profile_by_email(email)
 | 
			
		||||
        subs = Subscription.objects.filter(
 | 
			
		||||
            user_profile    = user_profile,
 | 
			
		||||
            active          = True,
 | 
			
		||||
            recipient__type = Recipient.STREAM)
 | 
			
		||||
            user_profile=user_profile,
 | 
			
		||||
            active=True,
 | 
			
		||||
            recipient__type=Recipient.STREAM)
 | 
			
		||||
        return [get_display_recipient(sub.recipient) for sub in subs]
 | 
			
		||||
 | 
			
		||||
    def send_message(self, sender_name, recipient_list, message_type,
 | 
			
		||||
                     content="test content", subject="test", **kwargs):
 | 
			
		||||
        # type: (str, Iterable[str], int, str, str, **Any) -> int
 | 
			
		||||
        sender = get_user_profile_by_email(sender_name)
 | 
			
		||||
        if message_type == Recipient.PERSONAL:
 | 
			
		||||
            message_type_name = "private"
 | 
			
		||||
@@ -280,6 +311,7 @@ class AuthedTestCase(TestCase):
 | 
			
		||||
            forwarder_user_profile=sender, realm=sender.realm, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def get_old_messages(self, anchor=1, num_before=100, num_after=100):
 | 
			
		||||
        # type: (int, int, int) -> List[Dict[str, Any]]
 | 
			
		||||
        post_params = {"anchor": anchor, "num_before": num_before,
 | 
			
		||||
                       "num_after": num_after}
 | 
			
		||||
        result = self.client.get("/json/messages", dict(post_params))
 | 
			
		||||
@@ -287,6 +319,7 @@ class AuthedTestCase(TestCase):
 | 
			
		||||
        return data['messages']
 | 
			
		||||
 | 
			
		||||
    def users_subscribed_to_stream(self, stream_name, realm_domain):
 | 
			
		||||
        # type: (str, str) -> List[UserProfile]
 | 
			
		||||
        realm = get_realm(realm_domain)
 | 
			
		||||
        stream = Stream.objects.get(name=stream_name, realm=realm)
 | 
			
		||||
        recipient = Recipient.objects.get(type_id=stream.id, type=Recipient.STREAM)
 | 
			
		||||
@@ -295,6 +328,7 @@ class AuthedTestCase(TestCase):
 | 
			
		||||
        return [subscription.user_profile for subscription in subscriptions]
 | 
			
		||||
 | 
			
		||||
    def assert_json_success(self, result):
 | 
			
		||||
        # type: (HttpResponse) -> Dict[str, Any]
 | 
			
		||||
        """
 | 
			
		||||
        Successful POSTs return a 200 and JSON of the form {"result": "success",
 | 
			
		||||
        "msg": ""}.
 | 
			
		||||
@@ -308,12 +342,14 @@ class AuthedTestCase(TestCase):
 | 
			
		||||
        return json
 | 
			
		||||
 | 
			
		||||
    def get_json_error(self, result, status_code=400):
 | 
			
		||||
        # type: (HttpResponse, int) -> Dict[str, Any]
 | 
			
		||||
        self.assertEqual(result.status_code, status_code)
 | 
			
		||||
        json = ujson.loads(result.content)
 | 
			
		||||
        self.assertEqual(json.get("result"), "error")
 | 
			
		||||
        return json['msg']
 | 
			
		||||
 | 
			
		||||
    def assert_json_error(self, result, msg, status_code=400):
 | 
			
		||||
        # type: (HttpResponse, str, int) -> None
 | 
			
		||||
        """
 | 
			
		||||
        Invalid POSTs return an error status code and JSON of the form
 | 
			
		||||
        {"result": "error", "msg": "reason"}.
 | 
			
		||||
@@ -321,6 +357,7 @@ class AuthedTestCase(TestCase):
 | 
			
		||||
        self.assertEqual(self.get_json_error(result, status_code=status_code), msg)
 | 
			
		||||
 | 
			
		||||
    def assert_length(self, queries, count, exact=False):
 | 
			
		||||
        # type: (Sized, int, bool) -> None
 | 
			
		||||
        actual_count = len(queries)
 | 
			
		||||
        if exact:
 | 
			
		||||
            return self.assertTrue(actual_count == count,
 | 
			
		||||
@@ -329,15 +366,19 @@ class AuthedTestCase(TestCase):
 | 
			
		||||
                               "len(%s) == %s, > %s" % (queries, actual_count, count))
 | 
			
		||||
 | 
			
		||||
    def assert_json_error_contains(self, result, msg_substring, status_code=400):
 | 
			
		||||
        # type: (HttpResponse, str, int) -> None
 | 
			
		||||
        self.assertIn(msg_substring, self.get_json_error(result, status_code=status_code))
 | 
			
		||||
 | 
			
		||||
    def fixture_data(self, type, action, file_type='json'):
 | 
			
		||||
        # type: (str, str, str) -> str
 | 
			
		||||
        return open(os.path.join(os.path.dirname(__file__),
 | 
			
		||||
                                 "../fixtures/%s/%s_%s.%s" % (type, type, action, file_type))).read()
 | 
			
		||||
 | 
			
		||||
    # Subscribe to a stream directly
 | 
			
		||||
    def subscribe_to_stream(self, email, stream_name, realm=None):
 | 
			
		||||
        realm = get_realm(resolve_email_to_domain(email))
 | 
			
		||||
        # type: (str, str, Optional[Realm]) -> None
 | 
			
		||||
        if realm is None:
 | 
			
		||||
            realm = get_realm(resolve_email_to_domain(email))
 | 
			
		||||
        stream = get_stream(stream_name, realm)
 | 
			
		||||
        if stream is None:
 | 
			
		||||
            stream, _ = create_stream_if_needed(realm, stream_name)
 | 
			
		||||
@@ -345,7 +386,8 @@ class AuthedTestCase(TestCase):
 | 
			
		||||
        do_add_subscription(user_profile, stream, no_log=True)
 | 
			
		||||
 | 
			
		||||
    # Subscribe to a stream by making an API request
 | 
			
		||||
    def common_subscribe_to_streams(self, email, streams, extra_post_data = {}, invite_only=False):
 | 
			
		||||
    def common_subscribe_to_streams(self, email, streams, extra_post_data={}, invite_only=False):
 | 
			
		||||
        # type: (str, Iterable[str], Dict[str, Any], bool) -> HttpResponse
 | 
			
		||||
        post_data = {'subscriptions': ujson.dumps([{"name": stream} for stream in streams]),
 | 
			
		||||
                     'invite_only': ujson.dumps(invite_only)}
 | 
			
		||||
        post_data.update(extra_post_data)
 | 
			
		||||
@@ -353,7 +395,8 @@ class AuthedTestCase(TestCase):
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def send_json_payload(self, email, url, payload, stream_name=None, **post_params):
 | 
			
		||||
        if stream_name != None:
 | 
			
		||||
        # type: (str, str, Dict[str, Any], Optional[str], **Any) -> Message
 | 
			
		||||
        if stream_name is not None:
 | 
			
		||||
            self.subscribe_to_stream(email, stream_name)
 | 
			
		||||
 | 
			
		||||
        result = self.client.post(url, payload, **post_params)
 | 
			
		||||
@@ -367,9 +410,11 @@ class AuthedTestCase(TestCase):
 | 
			
		||||
        return msg
 | 
			
		||||
 | 
			
		||||
    def get_last_message(self):
 | 
			
		||||
        # type: () -> Message
 | 
			
		||||
        return Message.objects.latest('id')
 | 
			
		||||
 | 
			
		||||
def get_all_templates():
 | 
			
		||||
    # type: () -> List[str]
 | 
			
		||||
    templates = []
 | 
			
		||||
 | 
			
		||||
    relpath = os.path.relpath
 | 
			
		||||
@@ -379,6 +424,7 @@ def get_all_templates():
 | 
			
		||||
    is_valid_template = lambda p, n: not n.startswith('.') and isfile(p)
 | 
			
		||||
 | 
			
		||||
    def process(template_dir, dirname, fnames):
 | 
			
		||||
        # type: (str, str, Iterable[str]) -> None
 | 
			
		||||
        for name in fnames:
 | 
			
		||||
            path = os.path.join(dirname, name)
 | 
			
		||||
            if is_valid_template(path, name):
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user