diff --git a/zerver/lib/export.py b/zerver/lib/export.py index 84a2b5e62c..0e41154c79 100644 --- a/zerver/lib/export.py +++ b/zerver/lib/export.py @@ -28,7 +28,7 @@ from zerver.models import UserProfile, Realm, Client, Huddle, Stream, \ from zerver.lib.parallel import run_parallel from zerver.lib.utils import mkdir_p from six.moves import range -from typing import Any, Callable, Dict, List, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple # Custom mypy types follow: Record = Dict[str, Any] @@ -229,30 +229,34 @@ class Config(object): self.source_filter = source_filter self.children = [] # type: List[Config] - if normal_parent: - self.parent = normal_parent + if normal_parent is not None: + self.parent = normal_parent # type: Optional[Config] else: self.parent = None - if virtual_parent and normal_parent: - raise Exception(''' + if virtual_parent is not None and normal_parent is not None: + raise ValueError(''' If you specify a normal_parent, please do not create a virtual_parent. ''') - if normal_parent: + if normal_parent is not None: normal_parent.children.append(self) - elif virtual_parent: + elif virtual_parent is not None: virtual_parent.children.append(self) - elif not is_seeded: - raise Exception(''' + elif is_seeded is None: + raise ValueError(''' You must specify a parent if you are not using is_seeded. ''') - if self.id_source: + if self.id_source is not None: + if self.virtual_parent is None: + raise ValueError(''' + You must specify a virtual_parent if you are + using id_source.''') if self.id_source[0] != self.virtual_parent.table: - raise Exception(''' + raise ValueError(''' Configuration error. To populate %s, you want data from %s, but that differs from the table name of your virtual parent (%s), @@ -277,6 +281,10 @@ def export_from_config(response, config, seed_object=None, context=None): if table: exported_tables = [table] else: + if config.custom_tables is None: + raise ValueError(''' + You must specify config.custom_tables if you + are not specifying config.table''') exported_tables = config.custom_tables for t in exported_tables: @@ -306,9 +314,11 @@ def export_from_config(response, config, seed_object=None, context=None): data += response[t] del response[t] logging.info('Deleted temporary %s' % (t,)) + assert table is not None response[table] = data elif config.use_all: + assert model is not None query = model.objects.all() rows = list(query) @@ -318,10 +328,14 @@ def export_from_config(response, config, seed_object=None, context=None): # now we just need to get all the articles # contained by the blogs. model = config.model + assert parent is not None + assert parent.table is not None + assert config.parent_key is not None parent_ids = [r['id'] for r in response[parent.table]] - filter_parms = {config.parent_key: parent_ids} - if config.filter_args: + filter_parms = {config.parent_key: parent_ids} # type: Dict[str, Any] + if config.filter_args is not None: filter_parms.update(config.filter_args) + assert model is not None query = model.objects.filter(**filter_parms) rows = list(query) @@ -330,6 +344,7 @@ def export_from_config(response, config, seed_object=None, context=None): # need to look at the current response to get all the # blog ids from the Article rows we fetched previously. model = config.model + assert model is not None # This will be a tuple of the form ('zerver_article', 'blog'). (child_table, field) = config.id_source child_rows = response[child_table] @@ -344,6 +359,7 @@ def export_from_config(response, config, seed_object=None, context=None): # Post-process rows (which won't apply to custom fetches/concats) if rows is not None: + assert table is not None # Hint for mypy response[table] = make_raw(rows, exclude=config.exclude) if table in DATE_FIELDS: floatify_datetime_fields(response, table) @@ -625,6 +641,8 @@ def fetch_huddle_objects(response, config, context): # type: (TableData, Config, Context) -> None realm = context['realm'] + assert config.parent is not None + assert config.parent.table is not None user_profile_ids = set(r['id'] for r in response[config.parent.table]) # First we get all huddles involving someone in the realm. @@ -1415,7 +1433,7 @@ def do_import_realm(import_dir): fix_datetime_fields(data, 'zerver_realm') realm = Realm(**data['zerver_realm'][0]) if realm.notifications_stream_id is not None: - notifications_stream_id = int(realm.notifications_stream_id) + notifications_stream_id = int(realm.notifications_stream_id) # type: Optional[int] else: notifications_stream_id = None realm.notifications_stream_id = None