diff --git a/zerver/lib/test_helpers.py b/zerver/lib/test_helpers.py index e04a139b03..7f3889dc59 100644 --- a/zerver/lib/test_helpers.py +++ b/zerver/lib/test_helpers.py @@ -283,14 +283,21 @@ class HostRequestMock: return self.host class MockPythonResponse: - def __init__(self, text: str, status_code: int) -> None: + def __init__(self, text: str, status_code: int, headers: Optional[Dict[str, str]]=None) -> None: self.text = text self.status_code = status_code + if headers is None: + headers = {'content-type': 'text/html'} + self.headers = headers @property def ok(self) -> bool: return self.status_code == 200 + def iter_content(self, n: int) -> Generator[str, Any, None]: + yield self.text[:n] + + INSTRUMENTING = os.environ.get('TEST_INSTRUMENT_URL_COVERAGE', '') == 'TRUE' INSTRUMENTED_CALLS = [] # type: List[Dict[str, Any]] diff --git a/zerver/lib/url_preview/preview.py b/zerver/lib/url_preview/preview.py index 31577b2a3e..251cd69579 100644 --- a/zerver/lib/url_preview/preview.py +++ b/zerver/lib/url_preview/preview.py @@ -3,6 +3,7 @@ import requests from django.conf import settings from django.utils.encoding import smart_text +import magic from typing import Any, Optional, Dict from typing.re import Match @@ -28,12 +29,40 @@ def is_link(url: str) -> Match[str]: return link_regex.match(smart_text(url)) +def guess_mimetype_from_content(response: requests.Response) -> str: + mime_magic = magic.Magic(mime=True) + try: + content = next(response.iter_content(1000)) + except StopIteration: + content = '' + return mime_magic.from_buffer(content) + +def valid_content_type(url: str) -> bool: + try: + response = requests.get(url, stream=True) + except requests.RequestException: + return False + + if not response.ok: + return False + + content_type = response.headers.get('content-type') + # Be accommodating of bad servers: assume content may be html if no content-type header + if not content_type or content_type.startswith('text/html'): + # Verify that the content is actually HTML if the server claims it is + content_type = guess_mimetype_from_content(response) + return content_type.startswith('text/html') + @cache_with_key(preview_url_cache_key, cache_name=CACHE_NAME, with_statsd_key="urlpreview_data") def get_link_embed_data(url: str, maxwidth: Optional[int]=640, maxheight: Optional[int]=480) -> Optional[Dict[str, Any]]: if not is_link(url): return None + + if not valid_content_type(url): + return None + # Fetch information from URL. # We are using three sources in next order: # 1. OEmbed @@ -47,7 +76,7 @@ def get_link_embed_data(url: str, # open graph data. return None data = data or {} - response = requests.get(url) + response = requests.get(url, stream=True) if response.ok: og_data = OpenGraphParser(response.text).extract_data() if og_data: diff --git a/zerver/tests/test_link_embed.py b/zerver/tests/test_link_embed.py index 31833ec916..74cd2718de 100644 --- a/zerver/tests/test_link_embed.py +++ b/zerver/tests/test_link_embed.py @@ -2,7 +2,7 @@ import mock import ujson -from typing import Any, Callable +from typing import Any, Callable, Dict, Optional from requests.exceptions import ConnectionError from django.test import override_settings @@ -171,12 +171,13 @@ class PreviewTestCase(ZulipTestCase): """ @classmethod - def create_mock_response(cls, url: str, relative_url: bool=False) -> Callable[..., MockPythonResponse]: + def create_mock_response(cls, url: str, relative_url: bool=False, + headers: Optional[Dict[str, str]]=None) -> Callable[..., MockPythonResponse]: html = cls.open_graph_html if relative_url is True: html = html.replace('http://ia.media-imdb.com', '') - response = MockPythonResponse(html, 200) - return lambda k: {url: response}.get(k, MockPythonResponse('', 404)) + response = MockPythonResponse(html, 200, headers) + return lambda k, **kwargs: {url: response}.get(k, MockPythonResponse('', 404, headers)) @override_settings(INLINE_URL_EMBED_PREVIEW=True) def test_edit_message_history(self) -> None: @@ -374,3 +375,116 @@ class PreviewTestCase(ZulipTestCase): key = preview_url_cache_key(url) cache_set(key, link_embed_data, 'database') self.assertEqual(link_embed_data, link_embed_data_from_cache(url)) + + @override_settings(INLINE_URL_EMBED_PREVIEW=True) + def test_link_preview_non_html_data(self) -> None: + email = self.example_email('hamlet') + self.login(email) + url = 'http://test.org/audio.mp3' + with mock.patch('zerver.lib.actions.queue_json_publish') as patched: + msg_id = self.send_stream_message(email, "Scotland", topic_name="foo", content=url) + patched.assert_called_once() + queue = patched.call_args[0][0] + self.assertEqual(queue, "embed_links") + event = patched.call_args[0][1] + + headers = {'content-type': 'application/octet-stream'} + mocked_response = mock.Mock(side_effect=self.create_mock_response(url, headers=headers)) + + with self.settings(TEST_SUITE=False, CACHES=TEST_CACHES): + with mock.patch('requests.get', mocked_response): + FetchLinksEmbedData().consume(event) + + cached_data = link_embed_data_from_cache(url) + + self.assertIsNone(cached_data) + msg = Message.objects.select_related("sender").get(id=msg_id) + self.assertEqual( + ('

' + 'http://test.org/audio.mp3

'), + msg.rendered_content) + + @override_settings(INLINE_URL_EMBED_PREVIEW=True) + def test_link_preview_no_content_type_header(self) -> None: + email = self.example_email('hamlet') + self.login(email) + url = 'http://test.org/' + with mock.patch('zerver.lib.actions.queue_json_publish') as patched: + msg_id = self.send_stream_message(email, "Scotland", topic_name="foo", content=url) + patched.assert_called_once() + queue = patched.call_args[0][0] + self.assertEqual(queue, "embed_links") + event = patched.call_args[0][1] + + headers = {'content-type': ''} # No content type header + mocked_response = mock.Mock(side_effect=self.create_mock_response(url, headers=headers)) + with self.settings(TEST_SUITE=False, CACHES=TEST_CACHES): + with mock.patch('requests.get', mocked_response): + FetchLinksEmbedData().consume(event) + data = link_embed_data_from_cache(url) + + self.assertIn('title', data) + self.assertIn('image', data) + + msg = Message.objects.select_related("sender").get(id=msg_id) + self.assertIn(data['title'], msg.rendered_content) + self.assertIn(data['image'], msg.rendered_content) + + @override_settings(INLINE_URL_EMBED_PREVIEW=True) + def test_valid_content_type_error_get_data(self) -> None: + url = 'http://test.org/' + with mock.patch('zerver.lib.actions.queue_json_publish'): + msg_id = self.send_personal_message( + self.example_email('hamlet'), + self.example_email('cordelia'), + content=url, + ) + msg = Message.objects.select_related("sender").get(id=msg_id) + event = { + 'message_id': msg_id, + 'urls': [url], + 'message_realm_id': msg.sender.realm_id, + 'message_content': url} + + with mock.patch('zerver.lib.url_preview.preview.valid_content_type', side_effect=lambda k: True): + with self.settings(TEST_SUITE=False, CACHES=TEST_CACHES): + with mock.patch('requests.get', mock.Mock(side_effect=ConnectionError())): + FetchLinksEmbedData().consume(event) + cached_data = link_embed_data_from_cache(url) + + # FIXME: Should we really cache this, looks like a network error? + self.assertIsNone(cached_data) + msg.refresh_from_db() + self.assertEqual( + '

http://test.org/

', + msg.rendered_content) + + @override_settings(INLINE_URL_EMBED_PREVIEW=True) + def test_invalid_url(self) -> None: + url = 'http://test.org/' + error_url = 'http://test.org/x' + with mock.patch('zerver.lib.actions.queue_json_publish'): + msg_id = self.send_personal_message( + self.example_email('hamlet'), + self.example_email('cordelia'), + content=error_url, + ) + msg = Message.objects.select_related("sender").get(id=msg_id) + event = { + 'message_id': msg_id, + 'urls': [error_url], + 'message_realm_id': msg.sender.realm_id, + 'message_content': error_url} + + mocked_response = mock.Mock(side_effect=self.create_mock_response(url)) + with self.settings(TEST_SUITE=False, CACHES=TEST_CACHES): + with mock.patch('requests.get', mocked_response): + FetchLinksEmbedData().consume(event) + cached_data = link_embed_data_from_cache(error_url) + + # FIXME: Should we really cache this, especially without cache invalidation? + self.assertIsNone(cached_data) + msg.refresh_from_db() + self.assertEqual( + '

http://test.org/x

', + msg.rendered_content)