Fix most strict-optional issues in export.py.

This commit is contained in:
Christian Hudon
2017-05-24 16:41:24 -07:00
committed by Tim Abbott
parent 1761a3b1c1
commit 8ab6a23a30

View File

@@ -28,7 +28,7 @@ from zerver.models import UserProfile, Realm, Client, Huddle, Stream, \
from zerver.lib.parallel import run_parallel
from zerver.lib.utils import mkdir_p
from six.moves import range
from typing import Any, Callable, Dict, List, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
# Custom mypy types follow:
Record = Dict[str, Any]
@@ -229,30 +229,34 @@ class Config(object):
self.source_filter = source_filter
self.children = [] # type: List[Config]
if normal_parent:
self.parent = normal_parent
if normal_parent is not None:
self.parent = normal_parent # type: Optional[Config]
else:
self.parent = None
if virtual_parent and normal_parent:
raise Exception('''
if virtual_parent is not None and normal_parent is not None:
raise ValueError('''
If you specify a normal_parent, please
do not create a virtual_parent.
''')
if normal_parent:
if normal_parent is not None:
normal_parent.children.append(self)
elif virtual_parent:
elif virtual_parent is not None:
virtual_parent.children.append(self)
elif not is_seeded:
raise Exception('''
elif is_seeded is None:
raise ValueError('''
You must specify a parent if you are
not using is_seeded.
''')
if self.id_source:
if self.id_source is not None:
if self.virtual_parent is None:
raise ValueError('''
You must specify a virtual_parent if you are
using id_source.''')
if self.id_source[0] != self.virtual_parent.table:
raise Exception('''
raise ValueError('''
Configuration error. To populate %s, you
want data from %s, but that differs from
the table name of your virtual parent (%s),
@@ -277,6 +281,10 @@ def export_from_config(response, config, seed_object=None, context=None):
if table:
exported_tables = [table]
else:
if config.custom_tables is None:
raise ValueError('''
You must specify config.custom_tables if you
are not specifying config.table''')
exported_tables = config.custom_tables
for t in exported_tables:
@@ -306,9 +314,11 @@ def export_from_config(response, config, seed_object=None, context=None):
data += response[t]
del response[t]
logging.info('Deleted temporary %s' % (t,))
assert table is not None
response[table] = data
elif config.use_all:
assert model is not None
query = model.objects.all()
rows = list(query)
@@ -318,10 +328,14 @@ def export_from_config(response, config, seed_object=None, context=None):
# now we just need to get all the articles
# contained by the blogs.
model = config.model
assert parent is not None
assert parent.table is not None
assert config.parent_key is not None
parent_ids = [r['id'] for r in response[parent.table]]
filter_parms = {config.parent_key: parent_ids}
if config.filter_args:
filter_parms = {config.parent_key: parent_ids} # type: Dict[str, Any]
if config.filter_args is not None:
filter_parms.update(config.filter_args)
assert model is not None
query = model.objects.filter(**filter_parms)
rows = list(query)
@@ -330,6 +344,7 @@ def export_from_config(response, config, seed_object=None, context=None):
# need to look at the current response to get all the
# blog ids from the Article rows we fetched previously.
model = config.model
assert model is not None
# This will be a tuple of the form ('zerver_article', 'blog').
(child_table, field) = config.id_source
child_rows = response[child_table]
@@ -344,6 +359,7 @@ def export_from_config(response, config, seed_object=None, context=None):
# Post-process rows (which won't apply to custom fetches/concats)
if rows is not None:
assert table is not None # Hint for mypy
response[table] = make_raw(rows, exclude=config.exclude)
if table in DATE_FIELDS:
floatify_datetime_fields(response, table)
@@ -625,6 +641,8 @@ def fetch_huddle_objects(response, config, context):
# type: (TableData, Config, Context) -> None
realm = context['realm']
assert config.parent is not None
assert config.parent.table is not None
user_profile_ids = set(r['id'] for r in response[config.parent.table])
# First we get all huddles involving someone in the realm.
@@ -1415,7 +1433,7 @@ def do_import_realm(import_dir):
fix_datetime_fields(data, 'zerver_realm')
realm = Realm(**data['zerver_realm'][0])
if realm.notifications_stream_id is not None:
notifications_stream_id = int(realm.notifications_stream_id)
notifications_stream_id = int(realm.notifications_stream_id) # type: Optional[int]
else:
notifications_stream_id = None
realm.notifications_stream_id = None