mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-04 05:53:43 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			307 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			307 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# -*- coding: utf-8 -*-
 | 
						|
from __future__ import absolute_import
from __future__ import print_function

import os
import shutil
from typing import Any, Dict, List, Optional, Set

from django.conf import settings
from django.test import TestCase
from mock import patch, MagicMock
from six.moves import range
import ujson

from zerver.lib.actions import (
    do_claim_attachments,
)
from zerver.lib.export import (
    do_export_realm,
    export_usermessages_batch,
)
from zerver.lib.upload import (
    claim_attachment,
    upload_message_image,
)
from zerver.lib.utils import (
    mkdir_p,
    query_chunker,
)
from zerver.lib.test_classes import (
    ZulipTestCase,
)
from zerver.lib.test_runner import slow
from zerver.models import (
    get_user_profile_by_email,
    Message,
    Realm,
    Recipient,
    UserMessage,
)
 | 
						|
 | 
						|
def rm_tree(path):
    # type: (str) -> None
    """Recursively delete the directory tree rooted at `path`.

    Does nothing when `path` does not exist, so callers can use it
    unconditionally before recreating a directory.
    """
    if not os.path.exists(path):
        return
    shutil.rmtree(path)
 | 
						|
 | 
						|
class QueryUtilTest(ZulipTestCase):
    """Tests for zerver.lib.utils.query_chunker."""

    def _create_messages(self):
        # type: () -> None
        """Have cordelia, hamlet and iago each send five PMs to othello."""
        for email in ['cordelia@zulip.com', 'hamlet@zulip.com', 'iago@zulip.com']:
            for _ in range(5):
                self.send_message(email, 'othello@zulip.com', Recipient.PERSONAL)

    @slow('creates lots of data')
    def test_query_chunker(self):
        # type: () -> None
        self._create_messages()

        cordelia = get_user_profile_by_email('cordelia@zulip.com')
        hamlet = get_user_profile_by_email('hamlet@zulip.com')

        def get_queries():
            # type: () -> List[Any]
            """Three disjoint queries that together cover every message."""
            return [
                Message.objects.filter(sender_id=cordelia.id),
                Message.objects.filter(sender_id=hamlet.id),
                Message.objects.exclude(sender_id__in=[cordelia.id, hamlet.id]),
            ]

        for query in get_queries():
            # For our test to be meaningful, we want non-empty queries
            # at first.  Use a TestCase assertion rather than a bare
            # `assert`, which is stripped when Python runs with -O.
            self.assertGreater(query.count(), 0)

        queries = get_queries()

        all_msg_ids = set()  # type: Set[int]
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=20,
        )

        all_row_ids = [row.id for chunk in chunker for row in chunk]

        # Rows must come back in ascending id order across all queries,
        # and every message must be collected exactly once.
        self.assertEqual(all_row_ids, sorted(all_row_ids))
        self.assertEqual(len(all_msg_ids), Message.objects.count())

        # Now just search for cordelia/hamlet.  Note that we don't really
        # need the order_by here, but it should be harmless.
        queries = [
            Message.objects.filter(sender_id=cordelia.id).order_by('id'),
            Message.objects.filter(sender_id=hamlet.id),
        ]
        all_msg_ids = set()
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=7,  # use a different size
        )
        list(chunker)  # exhaust the iterator
        self.assertEqual(
            len(all_msg_ids),
            Message.objects.filter(sender_id__in=[cordelia.id, hamlet.id]).count()
        )

        # Try just a single query to validate chunking.
        queries = [
            Message.objects.exclude(sender_id=cordelia.id),
        ]
        all_msg_ids = set()
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=11,  # use a different size each time
        )
        list(chunker)  # exhaust the iterator
        self.assertEqual(
            len(all_msg_ids),
            Message.objects.exclude(sender_id=cordelia.id).count()
        )
        # Make sure the query spanned more than one chunk.
        self.assertGreater(len(all_msg_ids), 15)

        # Verify assertions about disjoint-ness: hamlet's messages appear
        # in both queries, which query_chunker must reject.
        queries = [
            Message.objects.exclude(sender_id=cordelia.id),
            Message.objects.filter(sender_id=hamlet.id),
        ]
        all_msg_ids = set()
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=13,  # use a different size each time
        )
        with self.assertRaises(AssertionError):
            list(chunker)  # exercise the iterator

        # Try to confuse things with ids part of the query...
        queries = [
            Message.objects.filter(id__lte=10),
            Message.objects.filter(id__gt=10),
        ]
        all_msg_ids = set()
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=11,  # use a different size each time
        )
        self.assertEqual(len(all_msg_ids), 0)  # until we actually use the iterator
        list(chunker)  # exhaust the iterator
        self.assertEqual(len(all_msg_ids), Message.objects.count())

        # Verify that we can just get the first chunk with a next() call.
        queries = [
            Message.objects.all(),
        ]
        all_msg_ids = set()
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=10,  # use a different size each time
        )
        first_chunk = next(chunker)  # type: ignore
        self.assertEqual(len(first_chunk), 10)
        self.assertEqual(len(all_msg_ids), 10)
        # The chunker yields rows in ascending id order (asserted above),
        # so compare against an explicitly id-ordered query; an unordered
        # queryset gives no guarantee about which row lands at index 5.
        expected_msg = Message.objects.all().order_by('id')[5]
        actual_msg = first_chunk[5]
        self.assertEqual(actual_msg.content, expected_msg.content)
        self.assertEqual(actual_msg.sender_id, expected_msg.sender_id)
 | 
						|
 | 
						|
 | 
						|
class ExportTest(TestCase):
    """End-to-end tests for realm export (do_export_realm and friends)."""

    def setUp(self):
        # type: () -> None
        # Clear uploads left over from earlier tests so each export only
        # contains what the current test created.
        rm_tree(settings.LOCAL_UPLOADS_DIR)

    def _make_output_dir(self):
        # type: () -> str
        """Wipe and recreate the scratch export directory; return its path."""
        output_dir = 'var/test-export'
        rm_tree(output_dir)
        mkdir_p(output_dir)
        return output_dir

    def _export_realm(self, realm, exportable_user_ids=None):
        # type: (Realm, Optional[Set[int]]) -> Dict[str, Any]
        """Export `realm` to a scratch directory and load the results.

        `exportable_user_ids` is passed straight through to
        do_export_realm.  Returns a dict holding the parsed 'realm',
        'attachment' and 'message' JSON files plus 'uploads_dir', the
        path of the exported uploads directory.
        """
        output_dir = self._make_output_dir()
        # Silence logging.info and stub out create_soft_link while exporting.
        with patch('logging.info'), patch('zerver.lib.export.create_soft_link'):
            do_export_realm(
                realm=realm,
                output_dir=output_dir,
                threads=0,
                exportable_user_ids=exportable_user_ids,
            )
            # TODO: Process the second partial file, which can be created
            #       for certain edge cases.
            export_usermessages_batch(
                input_path=os.path.join(output_dir, 'messages-000001.json.partial'),
                output_path=os.path.join(output_dir, 'message.json')
            )

        def read_file(fn):
            # type: (str) -> Any
            full_fn = os.path.join(output_dir, fn)
            with open(full_fn) as f:
                return ujson.load(f)

        result = {}  # type: Dict[str, Any]
        result['realm'] = read_file('realm.json')
        result['attachment'] = read_file('attachment.json')
        result['message'] = read_file('message.json')
        result['uploads_dir'] = os.path.join(output_dir, 'uploads')
        return result

    @staticmethod
    def _get_set(data, table, field):
        # type: (Dict[str, Any], str, str) -> Set[str]
        """Return the set of `field` values across all rows of `table` in `data`."""
        return {r[field] for r in data[table]}

    @staticmethod
    def _find_by_id(data, table, db_id):
        # type: (Dict[str, Any], str, int) -> Dict[str, Any]
        """Return the first row of `table` in `data` whose 'id' is `db_id`.

        Raises IndexError if no such row was exported.
        """
        return [r for r in data[table] if r['id'] == db_id][0]

    def test_attachment(self):
        # type: () -> None
        message = Message.objects.all()[0]
        user_profile = message.sender
        url = upload_message_image(u'dummy.txt', u'text/plain', b'zulip!', user_profile)
        path_id = url.replace('/user_uploads/', '')
        claim_attachment(
            user_profile=user_profile,
            path_id=path_id,
            message=message,
            is_message_realm_public=True
        )

        realm = Realm.objects.get(string_id='zulip')
        full_data = self._export_realm(realm)

        data = full_data['attachment']
        self.assertEqual(len(data['zerver_attachment']), 1)
        record = data['zerver_attachment'][0]
        self.assertEqual(record['path_id'], path_id)

        # The upload was written as bytes, so compare the exported copy
        # byte-for-byte.
        fn = os.path.join(full_data['uploads_dir'], path_id)
        with open(fn, 'rb') as f:
            self.assertEqual(f.read(), b'zulip!')

    def test_zulip_realm(self):
        # type: () -> None
        realm = Realm.objects.get(string_id='zulip')
        full_data = self._export_realm(realm)

        realm_data = full_data['realm']
        self.assertEqual(len(realm_data['zerver_userprofile_crossrealm']), 0)
        self.assertEqual(len(realm_data['zerver_userprofile_mirrordummy']), 0)

        exported_user_emails = self._get_set(realm_data, 'zerver_userprofile', 'email')
        self.assertIn('cordelia@zulip.com', exported_user_emails)
        self.assertIn('default-bot@zulip.com', exported_user_emails)
        self.assertIn('emailgateway@zulip.com', exported_user_emails)

        exported_streams = self._get_set(realm_data, 'zerver_stream', 'name')
        self.assertEqual(
            exported_streams,
            {u'Denmark', u'Rome', u'Scotland', u'Venice', u'Verona'}
        )

        message_data = full_data['message']
        um = UserMessage.objects.all()[0]
        exported_um = self._find_by_id(message_data, 'zerver_usermessage', um.id)
        self.assertEqual(exported_um['message'], um.message_id)
        self.assertEqual(exported_um['user_profile'], um.user_profile_id)

        exported_message = self._find_by_id(message_data, 'zerver_message', um.message_id)
        self.assertEqual(exported_message['content'], um.message.content)

        # TODO: split this test up now that the lookup helpers no longer
        #       depend on local state.

        # Now, restrict users: only cordelia and hamlet are exportable, so
        # other users referenced by the realm become mirror dummies.
        cordelia = get_user_profile_by_email('cordelia@zulip.com')
        hamlet = get_user_profile_by_email('hamlet@zulip.com')
        user_ids = {cordelia.id, hamlet.id}

        restricted_data = self._export_realm(realm, exportable_user_ids=user_ids)
        realm_data = restricted_data['realm']
        exported_user_emails = self._get_set(realm_data, 'zerver_userprofile', 'email')
        self.assertIn('cordelia@zulip.com', exported_user_emails)
        self.assertIn('hamlet@zulip.com', exported_user_emails)
        self.assertNotIn('default-bot@zulip.com', exported_user_emails)
        self.assertNotIn('iago@zulip.com', exported_user_emails)

        dummy_user_emails = self._get_set(realm_data, 'zerver_userprofile_mirrordummy', 'email')
        self.assertIn('iago@zulip.com', dummy_user_emails)
        self.assertNotIn('cordelia@zulip.com', dummy_user_emails)
 |