diff --git a/zerver/management/commands/change_auth_backends.py b/zerver/management/commands/change_auth_backends.py new file mode 100644 index 0000000000..2fca0c8d6f --- /dev/null +++ b/zerver/management/commands/change_auth_backends.py @@ -0,0 +1,85 @@ +from argparse import ArgumentParser +from typing import Any + +from django.core.management.base import CommandError +from typing_extensions import override + +from zerver.actions.realm_settings import do_set_realm_authentication_methods +from zerver.lib.management import ZulipBaseCommand + + +class Command(ZulipBaseCommand): + help = """Enable or disable an authentication backend for a realm""" + + @override + def add_arguments(self, parser: ArgumentParser) -> None: + self.add_realm_args(parser, required=True) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "--enable", + type=str, + help="Name of the authentication backend to enable", + ) + group.add_argument( + "--disable", + type=str, + help="Name of the authentication backend to disable", + ) + group.add_argument( + "--show", + action="store_true", + help="Show current authentication backends for the realm", + ) + + @override + def handle(self, *args: Any, **options: str) -> None: + realm = self.get_realm(options) + assert realm is not None # Should be ensured by parser + + auth_methods = realm.authentication_methods_dict() + + if options["show"]: + print("Current authentication backends for the realm:") + print_auth_methods_dict(auth_methods) + return + + if options["enable"]: + backend_name = options["enable"] + check_backend_name_valid(backend_name, auth_methods) + + auth_methods[backend_name] = True + print(f"Enabling {backend_name} backend for realm {realm.name}") + elif options["disable"]: + backend_name = options["disable"] + check_backend_name_valid(backend_name, auth_methods) + + auth_methods[backend_name] = False + print(f"Disabling {backend_name} backend for realm {realm.name}") + + do_set_realm_authentication_methods(realm, auth_methods, acting_user=None) + + print("Updated authentication backends for the realm:") + print_auth_methods_dict(realm.authentication_methods_dict()) + print("Done!") + + +def check_backend_name_valid(backend_name: str, auth_methods_dict: dict[str, bool]) -> None: + if backend_name not in auth_methods_dict: + raise CommandError( + f"Backend {backend_name} is not a valid authentication backend. Valid backends: {list(auth_methods_dict.keys())}" + ) + + +def print_auth_methods_dict(auth_methods: dict[str, bool]) -> None: + enabled_backends = [backend for backend, enabled in auth_methods.items() if enabled] + disabled_backends = [backend for backend, enabled in auth_methods.items() if not enabled] + + if enabled_backends: + print("Enabled backends:") + for backend in enabled_backends: + print(f" {backend}") + + if disabled_backends: + print("Disabled backends:") + for backend in disabled_backends: + print(f" {backend}")