mirror of
https://github.com/zulip/zulip.git
synced 2025-11-05 14:35:27 +00:00
Adds a new field org_type to Realm. Defaults for restricted_to_domain and invite_required are now controlled by org_type at time of realm creation (see zerver.lib.actions.do_create_realm), rather than at the database level. Note that the backend defaults are all org_type=corporate, since that matches the current assumptions in the codebase, whereas the frontend default is org_type=community, since if a user isn't sure they probably want community. Since we will likely in the future enable/disable various administrative features based on whether an organization is corporate or community, we discuss those issues in the realm creation form. Before we actually implement any such features, we'll want to make sure users understand what type of organization they are a member of. Choice of org_type (via radio button) has been added to the realm creation flow and the realm creation management command, and the open-realm option removed. The database defaults have not been changed, which allows our testing code to work unchanged. [includes some HTML/CSS work by Brock Whittaker to make it look nice]
686 lines
27 KiB
Python
686 lines
27 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import print_function
|
|
from contextlib import contextmanager
|
|
from typing import (cast, Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional,
|
|
Sized, Tuple, Union)
|
|
|
|
from django.conf import settings
|
|
from django.test import TestCase
|
|
from django.test.client import (
|
|
BOUNDARY, MULTIPART_CONTENT, encode_multipart,
|
|
)
|
|
from django.template import loader
|
|
from django.http import HttpResponse
|
|
from django.utils.translation import ugettext as _
|
|
|
|
from zerver.lib.initial_password import initial_password
|
|
from zerver.lib.db import TimeTrackingCursor
|
|
from zerver.lib.handlers import allocate_handler_id
|
|
from zerver.lib.str_utils import force_text
|
|
from zerver.lib import cache
|
|
from zerver.lib import event_queue
|
|
from zerver.worker import queue_processors
|
|
|
|
from zerver.lib.actions import (
|
|
check_send_message, create_stream_if_needed, do_add_subscription,
|
|
get_display_recipient,
|
|
)
|
|
|
|
from zerver.models import (
|
|
get_realm,
|
|
get_stream,
|
|
get_user_profile_by_email,
|
|
resolve_email_to_domain,
|
|
Client,
|
|
Message,
|
|
Realm,
|
|
Recipient,
|
|
Stream,
|
|
Subscription,
|
|
UserMessage,
|
|
UserProfile,
|
|
)
|
|
|
|
from zerver.lib.request import JsonableError
|
|
|
|
|
|
import base64
|
|
import mock
|
|
import os
|
|
import re
|
|
import time
|
|
import ujson
|
|
import unittest
|
|
from six.moves import urllib
|
|
from six import text_type, binary_type
|
|
from zerver.lib.str_utils import NonBinaryStr
|
|
|
|
from contextlib import contextmanager
|
|
import six
|
|
|
|
API_KEYS = {} # type: Dict[text_type, text_type]
|
|
|
|
skip_py3 = unittest.skipIf(six.PY3, "Expected failure on Python 3")
|
|
|
|
|
|
@contextmanager
|
|
def simulated_queue_client(client):
|
|
# type: (type) -> Generator[None, None, None]
|
|
real_SimpleQueueClient = queue_processors.SimpleQueueClient
|
|
queue_processors.SimpleQueueClient = client # type: ignore # https://github.com/JukkaL/mypy/issues/1152
|
|
yield
|
|
queue_processors.SimpleQueueClient = real_SimpleQueueClient # type: ignore # https://github.com/JukkaL/mypy/issues/1152
|
|
|
|
@contextmanager
|
|
def tornado_redirected_to_list(lst):
|
|
# type: (List[Mapping[str, Any]]) -> Generator[None, None, None]
|
|
real_event_queue_process_notification = event_queue.process_notification
|
|
event_queue.process_notification = lst.append
|
|
yield
|
|
event_queue.process_notification = real_event_queue_process_notification
|
|
|
|
@contextmanager
|
|
def simulated_empty_cache():
|
|
# type: () -> Generator[List[Tuple[str, Union[text_type, List[text_type]], text_type]], None, None]
|
|
cache_queries = [] # type: List[Tuple[str, Union[text_type, List[text_type]], text_type]]
|
|
def my_cache_get(key, cache_name=None):
|
|
# type: (text_type, Optional[str]) -> Any
|
|
cache_queries.append(('get', key, cache_name))
|
|
return None
|
|
|
|
def my_cache_get_many(keys, cache_name=None):
|
|
# type: (List[text_type], Optional[str]) -> Dict[text_type, Any]
|
|
cache_queries.append(('getmany', keys, cache_name))
|
|
return None
|
|
|
|
old_get = cache.cache_get
|
|
old_get_many = cache.cache_get_many
|
|
cache.cache_get = my_cache_get
|
|
cache.cache_get_many = my_cache_get_many
|
|
yield cache_queries
|
|
cache.cache_get = old_get
|
|
cache.cache_get_many = old_get_many
|
|
|
|
@contextmanager
|
|
def queries_captured():
|
|
# type: () -> Generator[List[Dict[str, Union[str, binary_type]]], None, None]
|
|
'''
|
|
Allow a user to capture just the queries executed during
|
|
the with statement.
|
|
'''
|
|
|
|
queries = [] # type: List[Dict[str, Union[str, binary_type]]]
|
|
|
|
def wrapper_execute(self, action, sql, params=()):
|
|
# type: (TimeTrackingCursor, Callable, NonBinaryStr, Iterable[Any]) -> None
|
|
start = time.time()
|
|
try:
|
|
return action(sql, params)
|
|
finally:
|
|
stop = time.time()
|
|
duration = stop - start
|
|
queries.append({
|
|
'sql': self.mogrify(sql, params).decode('utf-8'),
|
|
'time': "%.3f" % duration,
|
|
})
|
|
|
|
old_execute = TimeTrackingCursor.execute
|
|
old_executemany = TimeTrackingCursor.executemany
|
|
|
|
def cursor_execute(self, sql, params=()):
|
|
# type: (TimeTrackingCursor, NonBinaryStr, Iterable[Any]) -> None
|
|
return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params) # type: ignore # https://github.com/JukkaL/mypy/issues/1167
|
|
TimeTrackingCursor.execute = cursor_execute # type: ignore # https://github.com/JukkaL/mypy/issues/1167
|
|
|
|
def cursor_executemany(self, sql, params=()):
|
|
# type: (TimeTrackingCursor, NonBinaryStr, Iterable[Any]) -> None
|
|
return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params) # type: ignore # https://github.com/JukkaL/mypy/issues/1167
|
|
TimeTrackingCursor.executemany = cursor_executemany # type: ignore # https://github.com/JukkaL/mypy/issues/1167
|
|
|
|
yield queries
|
|
|
|
TimeTrackingCursor.execute = old_execute # type: ignore # https://github.com/JukkaL/mypy/issues/1167
|
|
TimeTrackingCursor.executemany = old_executemany # type: ignore # https://github.com/JukkaL/mypy/issues/1167
|
|
|
|
|
|
def make_client(name):
|
|
# type: (str) -> Client
|
|
client, _ = Client.objects.get_or_create(name=name)
|
|
return client
|
|
|
|
def find_key_by_email(address):
|
|
# type: (text_type) -> text_type
|
|
from django.core.mail import outbox
|
|
key_regex = re.compile("accounts/do_confirm/([a-f0-9]{40})>")
|
|
for message in reversed(outbox):
|
|
if address in message.to:
|
|
return key_regex.search(message.body).groups()[0]
|
|
|
|
def message_ids(result):
|
|
# type: (Dict[str, Any]) -> Set[int]
|
|
return set(message['id'] for message in result['messages'])
|
|
|
|
def message_stream_count(user_profile):
|
|
# type: (UserProfile) -> int
|
|
return UserMessage.objects. \
|
|
select_related("message"). \
|
|
filter(user_profile=user_profile). \
|
|
count()
|
|
|
|
def most_recent_usermessage(user_profile):
|
|
# type: (UserProfile) -> UserMessage
|
|
query = UserMessage.objects. \
|
|
select_related("message"). \
|
|
filter(user_profile=user_profile). \
|
|
order_by('-message')
|
|
return query[0] # Django does LIMIT here
|
|
|
|
def most_recent_message(user_profile):
|
|
# type: (UserProfile) -> Message
|
|
usermessage = most_recent_usermessage(user_profile)
|
|
return usermessage.message
|
|
|
|
def get_user_messages(user_profile):
|
|
# type: (UserProfile) -> List[Message]
|
|
query = UserMessage.objects. \
|
|
select_related("message"). \
|
|
filter(user_profile=user_profile). \
|
|
order_by('message')
|
|
return [um.message for um in query]
|
|
|
|
class DummyHandler(object):
|
|
def __init__(self):
|
|
# type: (Callable) -> None
|
|
allocate_handler_id(self)
|
|
|
|
class POSTRequestMock(object):
|
|
method = "POST"
|
|
|
|
def __init__(self, post_data, user_profile):
|
|
# type: (Dict[str, Any], UserProfile) -> None
|
|
self.REQUEST = self.POST = post_data
|
|
self.user = user_profile
|
|
self._tornado_handler = DummyHandler()
|
|
self._log_data = {} # type: Dict[str, Any]
|
|
self.META = {'PATH_INFO': 'test'}
|
|
|
|
class HostRequestMock(object):
|
|
"""A mock request object where get_host() works. Useful for testing
|
|
routes that use Zulip's subdomains feature"""
|
|
def __init__(self, host=settings.EXTERNAL_HOST):
|
|
# type: (text_type) -> None
|
|
self.host = host
|
|
|
|
def get_host(self):
|
|
# type: () -> text_type
|
|
return self.host
|
|
|
|
INSTRUMENTING = os.environ.get('TEST_INSTRUMENT_URL_COVERAGE', '') == 'TRUE'
|
|
INSTRUMENTED_CALLS = [] # type: List[Dict[str, Any]]
|
|
|
|
UrlFuncT = Callable[..., HttpResponse] # TODO: make more specific
|
|
|
|
def instrument_url(f):
|
|
# type: (UrlFuncT) -> UrlFuncT
|
|
if not INSTRUMENTING:
|
|
return f
|
|
else:
|
|
def wrapper(self, url, info={}, **kwargs):
|
|
# type: (Any, text_type, Dict[str, Any], **Any) -> HttpResponse
|
|
start = time.time()
|
|
result = f(self, url, info, **kwargs)
|
|
delay = time.time() - start
|
|
test_name = self.id()
|
|
if '?' in url:
|
|
url, extra_info = url.split('?', 1)
|
|
else:
|
|
extra_info = ''
|
|
|
|
INSTRUMENTED_CALLS.append(dict(
|
|
url=url,
|
|
status_code=result.status_code,
|
|
method=f.__name__,
|
|
delay=delay,
|
|
extra_info=extra_info,
|
|
info=info,
|
|
test_name=test_name,
|
|
kwargs=kwargs))
|
|
return result
|
|
return wrapper
|
|
|
|
def write_instrumentation_reports():
|
|
# type: () -> None
|
|
if INSTRUMENTING:
|
|
calls = INSTRUMENTED_CALLS
|
|
var_dir = 'var' # TODO make sure path is robust here
|
|
fn = os.path.join(var_dir, 'url_coverage.txt')
|
|
with open(fn, 'w') as f:
|
|
for call in calls:
|
|
try:
|
|
line = ujson.dumps(call)
|
|
f.write(line + '\n')
|
|
except OverflowError:
|
|
print('''
|
|
A JSON overflow error was encountered while
|
|
producing the URL coverage report. Sometimes
|
|
this indicates that a test is passing objects
|
|
into methods like client_post(), which is
|
|
unnecessary and leads to false positives.
|
|
''')
|
|
print(call)
|
|
|
|
print('URL coverage report is in %s' % (fn,))
|
|
print('Try running: ./tools/analyze-url-coverage')
|
|
|
|
# Find our untested urls.
|
|
from zproject.urls import urlpatterns
|
|
untested_patterns = []
|
|
for pattern in urlpatterns:
|
|
for call in calls:
|
|
url = call['url']
|
|
if url.startswith('/'):
|
|
url = url[1:]
|
|
if pattern.regex.match(url):
|
|
break
|
|
else:
|
|
untested_patterns.append(pattern.regex.pattern)
|
|
|
|
fn = os.path.join(var_dir, 'untested_url_report.txt')
|
|
with open(fn, 'w') as f:
|
|
f.write('untested urls\n')
|
|
for untested_pattern in sorted(untested_patterns):
|
|
f.write(' %s\n' % (untested_pattern,))
|
|
print('Untested-url report is in %s' % (fn,))
|
|
|
|
|
|
class ZulipTestCase(TestCase):
|
|
'''
|
|
WRAPPER_COMMENT:
|
|
|
|
We wrap calls to self.client.{patch,put,get,post,delete} for various
|
|
reasons. Some of this has to do with fixing encodings before calling
|
|
into the Django code. Some of this has to do with providing a future
|
|
path for instrumentation. Some of it's just consistency.
|
|
|
|
The linter will prevent direct calls to self.client.foo, so the wrapper
|
|
functions have to fake out the linter by using a local variable called
|
|
django_client to fool the regext.
|
|
'''
|
|
|
|
@instrument_url
|
|
def client_patch(self, url, info={}, **kwargs):
|
|
# type: (text_type, Dict[str, Any], **Any) -> HttpResponse
|
|
"""
|
|
We need to urlencode, since Django's function won't do it for us.
|
|
"""
|
|
encoded = urllib.parse.urlencode(info)
|
|
django_client = self.client # see WRAPPER_COMMENT
|
|
return django_client.patch(url, encoded, **kwargs)
|
|
|
|
@instrument_url
|
|
def client_patch_multipart(self, url, info={}, **kwargs):
|
|
# type: (text_type, Dict[str, Any], **Any) -> HttpResponse
|
|
"""
|
|
Use this for patch requests that have file uploads or
|
|
that need some sort of multi-part content. In the future
|
|
Django's test client may become a bit more flexible,
|
|
so we can hopefully eliminate this. (When you post
|
|
with the Django test client, it deals with MULTIPART_CONTENT
|
|
automatically, but not patch.)
|
|
"""
|
|
encoded = encode_multipart(BOUNDARY, info)
|
|
django_client = self.client # see WRAPPER_COMMENT
|
|
return django_client.patch(
|
|
url,
|
|
encoded,
|
|
content_type=MULTIPART_CONTENT,
|
|
**kwargs)
|
|
|
|
@instrument_url
|
|
def client_put(self, url, info={}, **kwargs):
|
|
# type: (text_type, Dict[str, Any], **Any) -> HttpResponse
|
|
encoded = urllib.parse.urlencode(info)
|
|
django_client = self.client # see WRAPPER_COMMENT
|
|
return django_client.put(url, encoded, **kwargs)
|
|
|
|
@instrument_url
|
|
def client_delete(self, url, info={}, **kwargs):
|
|
# type: (text_type, Dict[str, Any], **Any) -> HttpResponse
|
|
encoded = urllib.parse.urlencode(info)
|
|
django_client = self.client # see WRAPPER_COMMENT
|
|
return django_client.delete(url, encoded, **kwargs)
|
|
|
|
@instrument_url
|
|
def client_post(self, url, info={}, **kwargs):
|
|
# type: (text_type, Dict[str, Any], **Any) -> HttpResponse
|
|
django_client = self.client # see WRAPPER_COMMENT
|
|
return django_client.post(url, info, **kwargs)
|
|
|
|
@instrument_url
|
|
def client_get(self, url, info={}, **kwargs):
|
|
# type: (text_type, Dict[str, Any], **Any) -> HttpResponse
|
|
django_client = self.client # see WRAPPER_COMMENT
|
|
return django_client.get(url, info, **kwargs)
|
|
|
|
def login_with_return(self, email, password=None):
|
|
# type: (text_type, Optional[text_type]) -> HttpResponse
|
|
if password is None:
|
|
password = initial_password(email)
|
|
return self.client_post('/accounts/login/',
|
|
{'username': email, 'password': password})
|
|
|
|
def login(self, email, password=None, fails=False):
|
|
# type: (text_type, Optional[text_type], bool) -> HttpResponse
|
|
if password is None:
|
|
password = initial_password(email)
|
|
if not fails:
|
|
self.assertTrue(self.client.login(username=email, password=password))
|
|
else:
|
|
self.assertFalse(self.client.login(username=email, password=password))
|
|
|
|
def register(self, username, password, domain="zulip.com"):
|
|
# type: (text_type, text_type, text_type) -> HttpResponse
|
|
self.client_post('/accounts/home/',
|
|
{'email': username + "@" + domain})
|
|
return self.submit_reg_form_for_user(username, password, domain=domain)
|
|
|
|
def submit_reg_form_for_user(self, username, password, domain="zulip.com",
|
|
realm_name=None, realm_subdomain=None,
|
|
realm_org_type=Realm.COMMUNITY, **kwargs):
|
|
# type: (text_type, text_type, text_type, Optional[text_type], Optional[text_type], int, **Any) -> HttpResponse
|
|
"""
|
|
Stage two of the two-step registration process.
|
|
|
|
If things are working correctly the account should be fully
|
|
registered after this call.
|
|
|
|
You can pass the HTTP_HOST variable for subdomains via kwargs.
|
|
"""
|
|
return self.client_post('/accounts/register/',
|
|
{'full_name': username, 'password': password,
|
|
'realm_name': realm_name,
|
|
'realm_subdomain': realm_subdomain,
|
|
'key': find_key_by_email(username + '@' + domain),
|
|
'realm_org_type': realm_org_type,
|
|
'terms': True},
|
|
**kwargs)
|
|
|
|
def get_confirmation_url_from_outbox(self, email_address, path_pattern="(\S+)>"):
|
|
# type: (text_type, text_type) -> text_type
|
|
from django.core.mail import outbox
|
|
for message in reversed(outbox):
|
|
if email_address in message.to:
|
|
return re.search(settings.EXTERNAL_HOST + path_pattern,
|
|
message.body).groups()[0]
|
|
else:
|
|
raise ValueError("Couldn't find a confirmation email.")
|
|
|
|
def get_api_key(self, email):
|
|
# type: (text_type) -> text_type
|
|
if email not in API_KEYS:
|
|
API_KEYS[email] = get_user_profile_by_email(email).api_key
|
|
return API_KEYS[email]
|
|
|
|
def api_auth(self, email):
|
|
# type: (text_type) -> Dict[str, text_type]
|
|
credentials = u"%s:%s" % (email, self.get_api_key(email))
|
|
return {
|
|
'HTTP_AUTHORIZATION': u'Basic ' + base64.b64encode(credentials.encode('utf-8')).decode('utf-8')
|
|
}
|
|
|
|
def get_streams(self, email):
|
|
# type: (text_type) -> List[text_type]
|
|
"""
|
|
Helper function to get the stream names for a user
|
|
"""
|
|
user_profile = get_user_profile_by_email(email)
|
|
subs = Subscription.objects.filter(
|
|
user_profile=user_profile,
|
|
active=True,
|
|
recipient__type=Recipient.STREAM)
|
|
return [cast(text_type, get_display_recipient(sub.recipient)) for sub in subs]
|
|
|
|
def send_message(self, sender_name, raw_recipients, message_type,
|
|
content=u"test content", subject=u"test", **kwargs):
|
|
# type: (text_type, Union[text_type, List[text_type]], int, text_type, text_type, **Any) -> int
|
|
sender = get_user_profile_by_email(sender_name)
|
|
if message_type == Recipient.PERSONAL:
|
|
message_type_name = "private"
|
|
else:
|
|
message_type_name = "stream"
|
|
if isinstance(raw_recipients, six.string_types):
|
|
recipient_list = [raw_recipients]
|
|
else:
|
|
recipient_list = raw_recipients
|
|
(sending_client, _) = Client.objects.get_or_create(name="test suite")
|
|
|
|
return check_send_message(
|
|
sender, sending_client, message_type_name, recipient_list, subject,
|
|
content, forged=False, forged_timestamp=None,
|
|
forwarder_user_profile=sender, realm=sender.realm, **kwargs)
|
|
|
|
def get_old_messages(self, anchor=1, num_before=100, num_after=100):
|
|
# type: (int, int, int) -> List[Dict[str, Any]]
|
|
post_params = {"anchor": anchor, "num_before": num_before,
|
|
"num_after": num_after}
|
|
result = self.client_get("/json/messages", dict(post_params))
|
|
data = ujson.loads(result.content)
|
|
return data['messages']
|
|
|
|
def users_subscribed_to_stream(self, stream_name, realm_domain):
|
|
# type: (text_type, text_type) -> List[UserProfile]
|
|
realm = get_realm(realm_domain)
|
|
stream = Stream.objects.get(name=stream_name, realm=realm)
|
|
recipient = Recipient.objects.get(type_id=stream.id, type=Recipient.STREAM)
|
|
subscriptions = Subscription.objects.filter(recipient=recipient, active=True)
|
|
|
|
return [subscription.user_profile for subscription in subscriptions]
|
|
|
|
def assert_json_success(self, result):
|
|
# type: (HttpResponse) -> Dict[str, Any]
|
|
"""
|
|
Successful POSTs return a 200 and JSON of the form {"result": "success",
|
|
"msg": ""}.
|
|
"""
|
|
self.assertEqual(result.status_code, 200, result)
|
|
json = ujson.loads(result.content)
|
|
self.assertEqual(json.get("result"), "success")
|
|
# We have a msg key for consistency with errors, but it typically has an
|
|
# empty value.
|
|
self.assertIn("msg", json)
|
|
return json
|
|
|
|
def get_json_error(self, result, status_code=400):
|
|
# type: (HttpResponse, int) -> Dict[str, Any]
|
|
self.assertEqual(result.status_code, status_code)
|
|
json = ujson.loads(result.content)
|
|
self.assertEqual(json.get("result"), "error")
|
|
return json['msg']
|
|
|
|
def assert_json_error(self, result, msg, status_code=400):
|
|
# type: (HttpResponse, text_type, int) -> None
|
|
"""
|
|
Invalid POSTs return an error status code and JSON of the form
|
|
{"result": "error", "msg": "reason"}.
|
|
"""
|
|
self.assertEqual(self.get_json_error(result, status_code=status_code), msg)
|
|
|
|
def assert_length(self, queries, count):
|
|
# type: (Sized, int) -> None
|
|
actual_count = len(queries)
|
|
return self.assertTrue(actual_count == count,
|
|
"len(%s) == %s, != %s" % (queries, actual_count, count))
|
|
|
|
def assert_max_length(self, queries, count):
|
|
# type: (Sized, int) -> None
|
|
actual_count = len(queries)
|
|
return self.assertTrue(actual_count <= count,
|
|
"len(%s) == %s, > %s" % (queries, actual_count, count))
|
|
|
|
def assert_json_error_contains(self, result, msg_substring, status_code=400):
|
|
# type: (HttpResponse, text_type, int) -> None
|
|
self.assertIn(msg_substring, self.get_json_error(result, status_code=status_code))
|
|
|
|
def assert_equals_response(self, string, response):
|
|
# type: (text_type, HttpResponse) -> None
|
|
self.assertEqual(string, response.content.decode('utf-8'))
|
|
|
|
def assert_in_response(self, substring, response):
|
|
# type: (text_type, HttpResponse) -> None
|
|
self.assertIn(substring, response.content.decode('utf-8'))
|
|
|
|
def fixture_data(self, type, action, file_type='json'):
|
|
# type: (text_type, text_type, text_type) -> text_type
|
|
return force_text(open(os.path.join(os.path.dirname(__file__),
|
|
"../fixtures/%s/%s_%s.%s" % (type, type, action, file_type))).read())
|
|
|
|
# Subscribe to a stream directly
|
|
def subscribe_to_stream(self, email, stream_name, realm=None):
|
|
# type: (text_type, text_type, Optional[Realm]) -> None
|
|
if realm is None:
|
|
realm = get_realm(resolve_email_to_domain(email))
|
|
stream = get_stream(stream_name, realm)
|
|
if stream is None:
|
|
stream, _ = create_stream_if_needed(realm, stream_name)
|
|
user_profile = get_user_profile_by_email(email)
|
|
do_add_subscription(user_profile, stream, no_log=True)
|
|
|
|
# Subscribe to a stream by making an API request
|
|
def common_subscribe_to_streams(self, email, streams, extra_post_data={}, invite_only=False):
|
|
# type: (text_type, Iterable[text_type], Dict[str, Any], bool) -> HttpResponse
|
|
post_data = {'subscriptions': ujson.dumps([{"name": stream} for stream in streams]),
|
|
'invite_only': ujson.dumps(invite_only)}
|
|
post_data.update(extra_post_data)
|
|
result = self.client_post("/api/v1/users/me/subscriptions", post_data, **self.api_auth(email))
|
|
return result
|
|
|
|
def send_json_payload(self, email, url, payload, stream_name=None, **post_params):
|
|
# type: (text_type, text_type, Union[text_type, Dict[str, Any]], Optional[text_type], **Any) -> Message
|
|
if stream_name is not None:
|
|
self.subscribe_to_stream(email, stream_name)
|
|
|
|
result = self.client_post(url, payload, **post_params)
|
|
self.assert_json_success(result)
|
|
|
|
# Check the correct message was sent
|
|
msg = self.get_last_message()
|
|
self.assertEqual(msg.sender.email, email)
|
|
if stream_name is not None:
|
|
self.assertEqual(get_display_recipient(msg.recipient), stream_name)
|
|
# TODO: should also validate recipient for private messages
|
|
|
|
return msg
|
|
|
|
def get_last_message(self):
|
|
# type: () -> Message
|
|
return Message.objects.latest('id')
|
|
|
|
def get_second_to_last_message(self):
|
|
# type: () -> Message
|
|
return Message.objects.all().order_by('-id')[1]
|
|
|
|
@contextmanager
|
|
def simulated_markdown_failure(self):
|
|
# type: () -> Generator[None, None, None]
|
|
'''
|
|
This raises a failure inside of the try/except block of
|
|
bugdown.__init__.do_convert.
|
|
'''
|
|
with \
|
|
self.settings(ERROR_BOT=None), \
|
|
mock.patch('zerver.lib.bugdown.timeout', side_effect=KeyError('foo')), \
|
|
mock.patch('zerver.lib.bugdown.log_bugdown_error'):
|
|
yield
|
|
|
|
class WebhookTestCase(ZulipTestCase):
|
|
"""
|
|
Common for all webhooks tests
|
|
|
|
Override below class attributes and run send_and_test_message
|
|
If you create your url in uncommon way you can override build_webhook_url method
|
|
In case that you need modify body or create it without using fixture you can also override get_body method
|
|
"""
|
|
STREAM_NAME = None # type: Optional[text_type]
|
|
TEST_USER_EMAIL = 'webhook-bot@zulip.com'
|
|
URL_TEMPLATE = None # type: Optional[text_type]
|
|
FIXTURE_DIR_NAME = None # type: Optional[text_type]
|
|
|
|
def setUp(self):
|
|
# type: () -> None
|
|
self.url = self.build_webhook_url()
|
|
|
|
def send_and_test_stream_message(self, fixture_name, expected_subject=None,
|
|
expected_message=None, content_type="application/json", **kwargs):
|
|
# type: (text_type, Optional[text_type], Optional[text_type], Optional[text_type], **Any) -> Message
|
|
payload = self.get_body(fixture_name)
|
|
if content_type is not None:
|
|
kwargs['content_type'] = content_type
|
|
msg = self.send_json_payload(self.TEST_USER_EMAIL, self.url, payload,
|
|
self.STREAM_NAME, **kwargs)
|
|
self.do_test_subject(msg, expected_subject)
|
|
self.do_test_message(msg, expected_message)
|
|
|
|
return msg
|
|
|
|
def send_and_test_private_message(self, fixture_name, expected_subject=None,
|
|
expected_message=None, content_type="application/json", **kwargs):
|
|
# type: (text_type, text_type, text_type, str, **Any) -> Message
|
|
payload = self.get_body(fixture_name)
|
|
if content_type is not None:
|
|
kwargs['content_type'] = content_type
|
|
|
|
msg = self.send_json_payload(self.TEST_USER_EMAIL, self.url, payload,
|
|
stream_name=None, **kwargs)
|
|
self.do_test_message(msg, expected_message)
|
|
|
|
return msg
|
|
|
|
def build_webhook_url(self):
|
|
# type: () -> text_type
|
|
api_key = self.get_api_key(self.TEST_USER_EMAIL)
|
|
return self.URL_TEMPLATE.format(stream=self.STREAM_NAME, api_key=api_key)
|
|
|
|
def get_body(self, fixture_name):
|
|
# type: (text_type) -> Union[text_type, Dict[str, text_type]]
|
|
"""Can be implemented either as returning a dictionary containing the
|
|
post parameters or as string containing the body of the request."""
|
|
return ujson.dumps(ujson.loads(self.fixture_data(self.FIXTURE_DIR_NAME, fixture_name)))
|
|
|
|
def do_test_subject(self, msg, expected_subject):
|
|
# type: (Message, Optional[text_type]) -> None
|
|
if expected_subject is not None:
|
|
self.assertEqual(msg.topic_name(), expected_subject)
|
|
|
|
def do_test_message(self, msg, expected_message):
|
|
# type: (Message, Optional[text_type]) -> None
|
|
if expected_message is not None:
|
|
self.assertEqual(msg.content, expected_message)
|
|
|
|
def get_all_templates():
|
|
# type: () -> List[str]
|
|
templates = []
|
|
|
|
relpath = os.path.relpath
|
|
isfile = os.path.isfile
|
|
path_exists = os.path.exists
|
|
|
|
def is_valid_template(p, n):
|
|
# type: (text_type, text_type) -> bool
|
|
return not n.startswith('.') and not n.startswith('__init__') and isfile(p)
|
|
|
|
def process(template_dir, dirname, fnames):
|
|
# type: (str, str, Iterable[str]) -> None
|
|
for name in fnames:
|
|
path = os.path.join(dirname, name)
|
|
if is_valid_template(path, name):
|
|
templates.append(relpath(path, template_dir))
|
|
|
|
for engine in loader.engines.all():
|
|
template_dirs = [d for d in engine.template_dirs if path_exists(d)]
|
|
for template_dir in template_dirs:
|
|
template_dir = os.path.normpath(template_dir)
|
|
for dirpath, dirnames, fnames in os.walk(template_dir):
|
|
process(template_dir, dirpath, fnames)
|
|
|
|
return templates
|