From eb23c6fa6cdd3094bd4074b127d25709a23f53ec Mon Sep 17 00:00:00 2001 From: Hashir Sarwar Date: Mon, 10 Feb 2020 18:22:58 +0500 Subject: [PATCH] test_fixtures: Clean up interface for `template_database_status()`. 1) Created a new class `DatabaseType` and access its objects inside `template_database_status()` instead of sending five arguments with default values. 2) Made `check_files` and `setting_name` local variables instead of function parameters since they had same value(None) for every call. Fixes #13845. --- tools/lib/provision_inner.py | 9 +---- zerver/lib/test_fixtures.py | 72 +++++++++++++++++++----------------- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/tools/lib/provision_inner.py b/tools/lib/provision_inner.py index 771c3c5601..84dbec307b 100755 --- a/tools/lib/provision_inner.py +++ b/tools/lib/provision_inner.py @@ -172,12 +172,7 @@ def main(options: argparse.Namespace) -> int: else: print("RabbitMQ is already configured.") - migration_status_path = os.path.join(UUID_VAR_PATH, "migration_status_dev") - dev_template_db_status = template_database_status( - migration_status=migration_status_path, - settings="zproject.settings", - database_name="zulip", - ) + dev_template_db_status = template_database_status('dev') if options.is_force or dev_template_db_status == 'needs_rebuild': run(["tools/setup/postgres-init-dev-db"]) run(["tools/do-destroy-rebuild-database"]) @@ -186,7 +181,7 @@ def main(options: argparse.Namespace) -> int: elif dev_template_db_status == 'current': print("No need to regenerate the dev DB.") - test_template_db_status = template_database_status() + test_template_db_status = template_database_status('test') if options.is_force or test_template_db_status == 'needs_rebuild': run(["tools/setup/postgres-init-test-db"]) run(["tools/do-destroy-rebuild-test-database"]) diff --git a/zerver/lib/test_fixtures.py b/zerver/lib/test_fixtures.py index 189fe01498..86303da2f9 100644 --- a/zerver/lib/test_fixtures.py +++ b/zerver/lib/test_fixtures.py @@ -5,7 +5,7 @@ import re import hashlib import subprocess import sys -from typing import Any, List, Optional, Set +from typing import Any, List, Set from importlib import import_module from io import StringIO import glob @@ -24,9 +24,23 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) from scripts.lib.zulip_tools import get_dev_uuid_var_path, run, \ file_or_package_hash_updated, TEMPLATE_DATABASE_DIR +class DatabaseType: + def __init__(self, database_name: str, settings: str, migration_status: str): + self.database_name = database_name + self.settings = settings + self.migration_status = migration_status + UUID_VAR_DIR = get_dev_uuid_var_path() FILENAME_SPLITTER = re.compile(r'[\W\-_]') +DEV_DATABASE_TYPE = DatabaseType(database_name='zulip', + settings='zproject.settings', + migration_status=os.path.join(UUID_VAR_DIR, "migration_status_dev")) + +TEST_DATABASE_TYPE = DatabaseType(database_name='zulip_test_template', + settings='zproject.test_settings', + migration_status=os.path.join(UUID_VAR_DIR, 'migration_status_test')) + def run_db_migrations(platform: str) -> None: if platform == 'dev': migration_status_file = 'migration_status_dev' @@ -70,7 +84,7 @@ def update_test_databases_if_required(use_force: bool=False, If use_force is specified, it will always do a full rebuild. """ generate_fixtures_command = ['tools/setup/generate-fixtures'] - test_template_db_status = template_database_status() + test_template_db_status = template_database_status('test') if use_force or test_template_db_status == 'needs_rebuild': generate_fixtures_command.append('--force') elif test_template_db_status == 'run_migrations': @@ -186,43 +200,33 @@ def check_setting_hash(setting_name: str, status_dir: str) -> bool: return _check_hash(source_hash_file, target_content) -def template_database_status( - database_name: str='zulip_test_template', - migration_status: Optional[str]=None, - settings: str='zproject.test_settings', - check_files: Optional[List[str]]=None, - check_settings: Optional[List[str]]=None) -> str: +def template_database_status(database_type: str) -> str: # This function returns a status string specifying the type of # state the template db is in and thus the kind of action required. - if check_files is None: - check_files = [ - 'zilencer/management/commands/populate_db.py', - 'zerver/lib/bulk_create.py', - 'zerver/lib/generate_test_data.py', - 'zerver/lib/server_initialization.py', - 'tools/setup/postgres-init-test-db', - 'tools/setup/postgres-init-dev-db', - 'zerver/migrations/0258_enable_online_push_notifications_default.py', - ] - if check_settings is None: - check_settings = [ - 'REALM_INTERNAL_BOTS', - ] + if database_type == 'dev': + database = DEV_DATABASE_TYPE + elif database_type == 'test': + database = TEST_DATABASE_TYPE + + check_files = [ + 'zilencer/management/commands/populate_db.py', + 'zerver/lib/bulk_create.py', + 'zerver/lib/generate_test_data.py', + 'zerver/lib/server_initialization.py', + 'tools/setup/postgres-init-test-db', + 'tools/setup/postgres-init-dev-db', + 'zerver/migrations/0258_enable_online_push_notifications_default.py', + ] + check_settings = [ + 'REALM_INTERNAL_BOTS', + ] # Construct a directory to store hashes named after the target database. - status_dir = os.path.join(UUID_VAR_DIR, database_name + '_db_status') + status_dir = os.path.join(UUID_VAR_DIR, database.database_name + '_db_status') if not os.path.exists(status_dir): os.mkdir(status_dir) - # Arguably we should move this inside status_dir, but it'd require - # a bit of work since generate_fixtures expects to also know the - # path, and make the directory. We may also want to refactor this - # logic to be inside a couple class objects for the two databases, - # rather than a random-feeling set of option flags. - if migration_status is None: - migration_status = os.path.join(UUID_VAR_DIR, 'migration_status_test') - - if database_exists(database_name): + if database_exists(database.database_name): # To ensure Python evaluates all the hash tests (and thus creates the # hash files about the current state), we evaluate them in a # list and then process the result @@ -239,12 +243,12 @@ def template_database_status( # migrations without spending a few 100ms parsing all the # Python migration code. paths = glob.glob('*/migrations/*.py') - check_migrations = file_or_package_hash_updated(paths, "migrations_hash_" + database_name, + check_migrations = file_or_package_hash_updated(paths, "migrations_hash_" + database.database_name, is_force=False) if not check_migrations: return 'current' - migration_op = what_to_do_with_migrations(migration_status, settings=settings) + migration_op = what_to_do_with_migrations(database.migration_status, settings=database.settings) if migration_op == 'scrap': return 'needs_rebuild'