Add type annotations to zerver.lib.test_helpers.

This commit is contained in:
Conrad Dean
2016-06-02 17:10:13 -07:00
committed by Tim Abbott
parent 4a10923bf1
commit 7f61a5e862

View File

@@ -1,8 +1,9 @@
from __future__ import absolute_import
from typing import Any, Callable, Generator, Iterable, Tuple
from typing import Any, Callable, Generator, Iterable, Tuple, Sized, Union, Optional
from django.test import TestCase
from django.template import loader
from django.http import HttpResponse
from zerver.lib.initial_password import initial_password
from zerver.lib.db import TimeTrackingCursor
@@ -16,8 +17,6 @@ from zerver.lib.actions import (
get_display_recipient,
)
from zerver.lib.handlers import allocate_handler_id
from zerver.models import (
get_realm,
get_stream,
@@ -30,6 +29,7 @@ from zerver.models import (
Stream,
Subscription,
UserMessage,
UserProfile,
)
import base64
@@ -70,13 +70,15 @@ def tornado_redirected_to_list(lst):
@contextmanager
def simulated_empty_cache():
# type: () -> Generator[List[Tuple[str, str, str]], None, None]
cache_queries = []
# type: () -> Generator[List[Tuple[str, Union[str, List[str]], str]], None, None]
cache_queries = [] # type: List[Tuple[str, Union[str, List[str]], str]]
def my_cache_get(key, cache_name=None):
# type: (str, Optional[str]) -> Any
cache_queries.append(('get', key, cache_name))
return None
def my_cache_get_many(keys, cache_name=None):
# type: (List[str], Optional[str]) -> Dict[str, Any]
cache_queries.append(('getmany', keys, cache_name))
return None
@@ -99,6 +101,7 @@ def queries_captured():
queries = []
def wrapper_execute(self, action, sql, params=()):
# type: (TimeTrackingCursor, Callable, str, Iterable[Any]) -> None
start = time.time()
try:
return action(sql, params)
@@ -106,19 +109,21 @@ def queries_captured():
stop = time.time()
duration = stop - start
queries.append({
'sql': self.mogrify(sql, params),
'time': "%.3f" % duration,
})
'sql': self.mogrify(sql, params),
'time': "%.3f" % duration,
})
old_execute = TimeTrackingCursor.execute
old_executemany = TimeTrackingCursor.executemany
def cursor_execute(self, sql, params=()):
return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params) # type: ignore # https://github.com/JukkaL/mypy/issues/1167
# type: (TimeTrackingCursor, str, Iterable[Any]) -> None
return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params) # type: ignore # https://github.com/JukkaL/mypy/issues/1167
TimeTrackingCursor.execute = cursor_execute # type: ignore # https://github.com/JukkaL/mypy/issues/1167
def cursor_executemany(self, sql, params=()):
return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params) # type: ignore # https://github.com/JukkaL/mypy/issues/1167
# type: (TimeTrackingCursor, str, Iterable[Any]) -> None
return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params) # type: ignore # https://github.com/JukkaL/mypy/issues/1167
TimeTrackingCursor.executemany = cursor_executemany # type: ignore # https://github.com/JukkaL/mypy/issues/1167
yield queries
@@ -128,6 +133,7 @@ def queries_captured():
def find_key_by_email(address):
# type: (str) -> str
from django.core.mail import outbox
key_regex = re.compile("accounts/do_confirm/([a-f0-9]{40})>")
for message in reversed(outbox):
@@ -135,15 +141,18 @@ def find_key_by_email(address):
return key_regex.search(message.body).groups()[0]
def message_ids(result):
# type: (Dict[str, Any]) -> Set[int]
return set(message['id'] for message in result['messages'])
def message_stream_count(user_profile):
# type: (UserProfile) -> int
return UserMessage.objects. \
select_related("message"). \
filter(user_profile=user_profile). \
count()
def most_recent_usermessage(user_profile):
# type: (UserProfile) -> UserMessage
query = UserMessage.objects. \
select_related("message"). \
filter(user_profile=user_profile). \
@@ -151,10 +160,12 @@ def most_recent_usermessage(user_profile):
return query[0] # Django does LIMIT here
def most_recent_message(user_profile):
# type: (UserProfile) -> Message
usermessage = most_recent_usermessage(user_profile)
return usermessage.message
def get_user_messages(user_profile):
# type: (UserProfile) -> List[Message]
query = UserMessage.objects. \
select_related("message"). \
filter(user_profile=user_profile). \
@@ -166,11 +177,13 @@ class DummyObject(object):
class DummyTornadoRequest(object):
def __init__(self):
# type: () -> None
self.connection = DummyObject()
self.connection.stream = DummyStream() # type: ignore # monkey-patching here
class DummyHandler(object):
def __init__(self, assert_callback):
# type: (Any) -> None
self.assert_callback = assert_callback
self.request = DummyTornadoRequest()
allocate_handler_id(self)
@@ -178,12 +191,15 @@ class DummyHandler(object):
# Mocks RequestHandler.async_callback, which wraps a callback to
# handle exceptions. We return the callback as-is.
def async_callback(self, cb):
# type: (Callable) -> Callable
return cb
def write(self, response):
# type: (str) -> None
raise NotImplemented
def zulip_finish(self, response, *ignore):
# type: (HttpResponse, *Any) -> None
if self.assert_callback:
self.assert_callback(response)
@@ -193,12 +209,14 @@ class DummySession(object):
class DummyStream(object):
def closed(self):
# type: () -> bool
return False
class POSTRequestMock(object):
method = "POST"
def __init__(self, post_data, user_profile, assert_callback=None):
# type: (Dict[str, Any], UserProfile, Optional[Callable]) -> None
self.REQUEST = self.POST = post_data
self.user = user_profile
self._tornado_handler = DummyHandler(assert_callback)
@@ -208,28 +226,37 @@ class POSTRequestMock(object):
class AuthedTestCase(TestCase):
# Helper because self.client.patch annoying requires you to urlencode
def client_patch(self, url, info={}, **kwargs):
info = urllib.parse.urlencode(info)
return self.client.patch(url, info, **kwargs)
# type: (str, Dict[str, Any], **Any) -> HttpResponse
encoded = urllib.parse.urlencode(info)
return self.client.patch(url, encoded, **kwargs)
def client_put(self, url, info={}, **kwargs):
info = urllib.parse.urlencode(info)
return self.client.put(url, info, **kwargs)
# type: (str, Dict[str, Any], **Any) -> HttpResponse
encoded = urllib.parse.urlencode(info)
return self.client.put(url, encoded, **kwargs)
def client_delete(self, url, info={}, **kwargs):
info = urllib.parse.urlencode(info)
return self.client.delete(url, info, **kwargs)
# type: (str, Dict[str, Any], **Any) -> HttpResponse
encoded = urllib.parse.urlencode(info)
return self.client.delete(url, encoded, **kwargs)
def login(self, email, password=None):
# type: (str, Optional[str]) -> HttpResponse
if password is None:
password = initial_password(email)
return self.client.post('/accounts/login/',
{'username': email, 'password': password})
def register(self, username, password, domain="zulip.com"):
# type: (str, str, str) -> HttpResponse
self.client.post('/accounts/home/',
{'email': username + "@" + domain})
return self.submit_reg_form_for_user(username, password, domain=domain)
def submit_reg_form_for_user(self, username, password, domain="zulip.com"):
# type: (str, str, str) -> HttpResponse
"""
Stage two of the two-step registration process.
@@ -242,29 +269,33 @@ class AuthedTestCase(TestCase):
'terms': True})
def get_api_key(self, email):
# type: (str) -> str
if email not in API_KEYS:
API_KEYS[email] = get_user_profile_by_email(email).api_key
API_KEYS[email] = get_user_profile_by_email(email).api_key
return API_KEYS[email]
def api_auth(self, email):
# type: (str) -> Dict[str, str]
credentials = "%s:%s" % (email, self.get_api_key(email))
return {
'HTTP_AUTHORIZATION': 'Basic ' + base64.b64encode(credentials)
}
}
def get_streams(self, email):
# type: (str) -> Iterable[Dict[str, Any]]
"""
Helper function to get the stream names for a user
"""
user_profile = get_user_profile_by_email(email)
subs = Subscription.objects.filter(
user_profile = user_profile,
active = True,
recipient__type = Recipient.STREAM)
user_profile=user_profile,
active=True,
recipient__type=Recipient.STREAM)
return [get_display_recipient(sub.recipient) for sub in subs]
def send_message(self, sender_name, recipient_list, message_type,
content="test content", subject="test", **kwargs):
# type: (str, Iterable[str], int, str, str, **Any) -> int
sender = get_user_profile_by_email(sender_name)
if message_type == Recipient.PERSONAL:
message_type_name = "private"
@@ -280,6 +311,7 @@ class AuthedTestCase(TestCase):
forwarder_user_profile=sender, realm=sender.realm, **kwargs)
def get_old_messages(self, anchor=1, num_before=100, num_after=100):
# type: (int, int, int) -> List[Dict[str, Any]]
post_params = {"anchor": anchor, "num_before": num_before,
"num_after": num_after}
result = self.client.get("/json/messages", dict(post_params))
@@ -287,6 +319,7 @@ class AuthedTestCase(TestCase):
return data['messages']
def users_subscribed_to_stream(self, stream_name, realm_domain):
# type: (str, str) -> List[UserProfile]
realm = get_realm(realm_domain)
stream = Stream.objects.get(name=stream_name, realm=realm)
recipient = Recipient.objects.get(type_id=stream.id, type=Recipient.STREAM)
@@ -295,6 +328,7 @@ class AuthedTestCase(TestCase):
return [subscription.user_profile for subscription in subscriptions]
def assert_json_success(self, result):
# type: (HttpResponse) -> Dict[str, Any]
"""
Successful POSTs return a 200 and JSON of the form {"result": "success",
"msg": ""}.
@@ -308,12 +342,14 @@ class AuthedTestCase(TestCase):
return json
def get_json_error(self, result, status_code=400):
# type: (HttpResponse, int) -> Dict[str, Any]
self.assertEqual(result.status_code, status_code)
json = ujson.loads(result.content)
self.assertEqual(json.get("result"), "error")
return json['msg']
def assert_json_error(self, result, msg, status_code=400):
# type: (HttpResponse, str, int) -> None
"""
Invalid POSTs return an error status code and JSON of the form
{"result": "error", "msg": "reason"}.
@@ -321,6 +357,7 @@ class AuthedTestCase(TestCase):
self.assertEqual(self.get_json_error(result, status_code=status_code), msg)
def assert_length(self, queries, count, exact=False):
# type: (Sized, int, bool) -> None
actual_count = len(queries)
if exact:
return self.assertTrue(actual_count == count,
@@ -329,15 +366,19 @@ class AuthedTestCase(TestCase):
"len(%s) == %s, > %s" % (queries, actual_count, count))
def assert_json_error_contains(self, result, msg_substring, status_code=400):
# type: (HttpResponse, str, int) -> None
self.assertIn(msg_substring, self.get_json_error(result, status_code=status_code))
def fixture_data(self, type, action, file_type='json'):
# type: (str, str, str) -> str
return open(os.path.join(os.path.dirname(__file__),
"../fixtures/%s/%s_%s.%s" % (type, type, action, file_type))).read()
# Subscribe to a stream directly
def subscribe_to_stream(self, email, stream_name, realm=None):
realm = get_realm(resolve_email_to_domain(email))
# type: (str, str, Optional[Realm]) -> None
if realm is None:
realm = get_realm(resolve_email_to_domain(email))
stream = get_stream(stream_name, realm)
if stream is None:
stream, _ = create_stream_if_needed(realm, stream_name)
@@ -345,7 +386,8 @@ class AuthedTestCase(TestCase):
do_add_subscription(user_profile, stream, no_log=True)
# Subscribe to a stream by making an API request
def common_subscribe_to_streams(self, email, streams, extra_post_data = {}, invite_only=False):
def common_subscribe_to_streams(self, email, streams, extra_post_data={}, invite_only=False):
# type: (str, Iterable[str], Dict[str, Any], bool) -> HttpResponse
post_data = {'subscriptions': ujson.dumps([{"name": stream} for stream in streams]),
'invite_only': ujson.dumps(invite_only)}
post_data.update(extra_post_data)
@@ -353,7 +395,8 @@ class AuthedTestCase(TestCase):
return result
def send_json_payload(self, email, url, payload, stream_name=None, **post_params):
if stream_name != None:
# type: (str, str, Dict[str, Any], Optional[str], **Any) -> Message
if stream_name is not None:
self.subscribe_to_stream(email, stream_name)
result = self.client.post(url, payload, **post_params)
@@ -367,9 +410,11 @@ class AuthedTestCase(TestCase):
return msg
def get_last_message(self):
# type: () -> Message
return Message.objects.latest('id')
def get_all_templates():
# type: () -> List[str]
templates = []
relpath = os.path.relpath
@@ -379,6 +424,7 @@ def get_all_templates():
is_valid_template = lambda p, n: not n.startswith('.') and isfile(p)
def process(template_dir, dirname, fnames):
# type: (str, str, Iterable[str]) -> None
for name in fnames:
path = os.path.join(dirname, name)
if is_valid_template(path, name):