From e09ef69a8d238d3f192cef8c961cd17832c02cc2 Mon Sep 17 00:00:00 2001 From: Harshit Bansal Date: Fri, 11 Jan 2019 10:25:36 +0000 Subject: [PATCH] management: Extend `sync_ldap_user_data` to allow update of a single user. --- zerver/lib/management.py | 8 +++++-- .../commands/sync_ldap_user_data.py | 24 ++++++++++++++----- zerver/tests/test_management_commands.py | 14 +++++++++-- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/zerver/lib/management.py b/zerver/lib/management.py index ebe479f8cc..3fafd5b944 100644 --- a/zerver/lib/management.py +++ b/zerver/lib/management.py @@ -88,7 +88,8 @@ You can use the command list_realms to find ID of the realms in this server.""" raise CommandError("There is no realm with id '%s'. Aborting." % (options["realm_id"],)) - def get_users(self, options: Dict[str, Any], realm: Optional[Realm]) -> List[UserProfile]: + def get_users(self, options: Dict[str, Any], realm: Optional[Realm], + is_bot: Optional[bool]=None) -> List[UserProfile]: if "all_users" in options: all_users = options["all_users"] @@ -102,7 +103,10 @@ You can use the command list_realms to find ID of the realms in this server.""" raise CommandError("The --all-users option requires a realm; please pass --realm.") if all_users: - return UserProfile.objects.filter(realm=realm) + user_profiles = UserProfile.objects.filter(realm=realm) + if is_bot is not None: + return user_profiles.filter(is_bot=is_bot) + return user_profiles if options["users"] is None: return [] diff --git a/zerver/management/commands/sync_ldap_user_data.py b/zerver/management/commands/sync_ldap_user_data.py index 484de1dab5..c8251d93ae 100644 --- a/zerver/management/commands/sync_ldap_user_data.py +++ b/zerver/management/commands/sync_ldap_user_data.py @@ -1,12 +1,15 @@ import logging -from typing import Any + +from argparse import ArgumentParser +from typing import Any, List + from django.conf import settings -from django.core.management.base import BaseCommand from django.db.utils import IntegrityError from zerver.lib.logging_util import log_to_file +from zerver.lib.management import ZulipBaseCommand from zerver.models import UserProfile from zproject.backends import ZulipLDAPUserPopulator, ZulipLDAPException @@ -15,10 +18,10 @@ logger = logging.getLogger(__name__) log_to_file(logger, settings.LDAP_SYNC_LOG_PATH) # Run this on a cronjob to pick up on name changes. -def sync_ldap_user_data() -> None: +def sync_ldap_user_data(user_profiles: List[UserProfile]) -> None: logger.info("Starting update.") backend = ZulipLDAPUserPopulator() - for u in UserProfile.objects.select_related().filter(is_bot=False).all(): + for u in user_profiles: # This will save the user if relevant, and will do nothing if the user # does not exist. try: @@ -31,6 +34,15 @@ def sync_ldap_user_data() -> None: logger.error(e) logger.info("Finished update.") -class Command(BaseCommand): +class Command(ZulipBaseCommand): + def add_arguments(self, parser: ArgumentParser) -> None: + self.add_realm_args(parser) + self.add_user_list_args(parser) + def handle(self, *args: Any, **options: Any) -> None: - sync_ldap_user_data() + if "realm_id" in options: + realm = self.get_realm(options) + user_profiles = self.get_users(options, realm, is_bot=False) + else: + user_profiles = UserProfile.objects.select_related().filter(is_bot=False) + sync_ldap_user_data(user_profiles) diff --git a/zerver/tests/test_management_commands.py b/zerver/tests/test_management_commands.py index c4a7cd49b0..6574ce744c 100644 --- a/zerver/tests/test_management_commands.py +++ b/zerver/tests/test_management_commands.py @@ -80,8 +80,9 @@ class TestZulipBaseCommand(ZulipTestCase): self.assertEqual(get_user_profile_by_email(email), user_profile) - def get_users_sorted(self, options: Dict[str, Any], realm: Optional[Realm]) -> List[UserProfile]: - user_profiles = self.command.get_users(options, realm) + def get_users_sorted(self, options: Dict[str, Any], realm: Optional[Realm], + is_bot: Optional[bool]=None) -> List[UserProfile]: + user_profiles = self.command.get_users(options, realm, is_bot=is_bot) return sorted(user_profiles, key = lambda x: x.email) def test_get_users(self) -> None: @@ -127,6 +128,15 @@ class TestZulipBaseCommand(ZulipTestCase): with self.assertRaisesRegex(CommandError, error_message): self.command.get_users(dict(users=None, all_users=True), None) + def test_get_non_bot_users(self) -> None: + expected_user_profiles = sorted(UserProfile.objects.filter(realm=self.zulip_realm, + is_bot=False), + key = lambda x: x.email) + user_profiles = self.get_users_sorted(dict(users=None, all_users=True), + self.zulip_realm, + is_bot=False) + self.assertEqual(user_profiles, expected_user_profiles) + class TestCommandsCanStart(TestCase): def setUp(self) -> None: