On Tue, Sep 02, 2014 at 04:19:33PM +0200, 'Helga Velroyen' via ganeti-devel 
wrote:
In order to update the 'ganeti_pub_keys' and the
'authorized_keys' files of various nodes via SSH, we
introduce the tool 'ssh_update'. It works similarly to the
tool 'prepare_node_join', which is also a tool invoked
via SSH on a remote node.

This patch includes some refactoring to reuse code from
the 'prepare_node_join' tool and provides unit tests as
well. Note that the actual invocation of the 'ssh_update'
tool will be done in later patches of this series.

Signed-off-by: Helga Velroyen <[email protected]>
---
.gitignore                                         |   1 +
Makefile.am                                        |  13 +-
lib/pathutils.py                                   |   1 +
lib/tools/common.py                                | 100 +++++++++++++
lib/tools/prepare_node_join.py                     |  79 +----------
lib/tools/ssh_update.py                            | 154 +++++++++++++++++++++
test/py/ganeti.tools.prepare_node_join_unittest.py |  38 ++---
test/py/ganeti.tools.ssh_update_unittest.py        | 123 ++++++++++++++++
8 files changed, 414 insertions(+), 95 deletions(-)
create mode 100644 lib/tools/common.py
create mode 100644 lib/tools/ssh_update.py
create mode 100755 test/py/ganeti.tools.ssh_update_unittest.py

diff --git a/.gitignore b/.gitignore
index bbc5402..7e90385 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,6 +127,7 @@
/tools/node-cleanup
/tools/node-daemon-setup
/tools/prepare-node-join
+/tools/ssh-update

# scripts
/scripts/gnt-backup
diff --git a/Makefile.am b/Makefile.am
index 767ff47..ab24c93 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -312,6 +312,8 @@ CLEANFILES = \
        tools/net-common \
        tools/users-setup \
        tools/vcluster-setup \
+       tools/prepare-node-join \
+       tools/ssh-update \
        stamp-directories \
        stamp-srclinks \
        $(nodist_pkgpython_PYTHON) \
@@ -558,7 +560,9 @@ pytools_PYTHON = \
        lib/tools/ensure_dirs.py \
        lib/tools/node_cleanup.py \
        lib/tools/node_daemon_setup.py \
-       lib/tools/prepare_node_join.py
+       lib/tools/prepare_node_join.py \
+       lib/tools/common.py \
+       lib/tools/ssh_update.py

utils_PYTHON = \
        lib/utils/__init__.py \
@@ -1151,7 +1155,8 @@ PYTHON_BOOTSTRAP = \
        tools/ensure-dirs \
        tools/node-cleanup \
        tools/node-daemon-setup \
-       tools/prepare-node-join
+       tools/prepare-node-join \
+       tools/ssh-update

qa_scripts = \
        qa/__init__.py \
@@ -1324,7 +1329,8 @@ pkglib_python_scripts = \
nodist_pkglib_python_scripts = \
        tools/ensure-dirs \
        tools/node-daemon-setup \
-       tools/prepare-node-join
+       tools/prepare-node-join \
+       tools/ssh-update

pkglib_python_basenames = \
        $(patsubst daemons/%,%,$(patsubst tools/%,%,\
@@ -2221,6 +2227,7 @@ tools/burnin: MODULE = ganeti.tools.burnin
tools/ensure-dirs: MODULE = ganeti.tools.ensure_dirs
tools/node-daemon-setup: MODULE = ganeti.tools.node_daemon_setup
tools/prepare-node-join: MODULE = ganeti.tools.prepare_node_join
+tools/ssh-update: MODULE = ganeti.tools.ssh_update
tools/node-cleanup: MODULE = ganeti.tools.node_cleanup
$(HS_BUILT_TEST_HELPERS): TESTROLE = $(patsubst test/hs/%,%,$@)

diff --git a/lib/pathutils.py b/lib/pathutils.py
index 2715504..1cc02e9 100644
--- a/lib/pathutils.py
+++ b/lib/pathutils.py
@@ -55,6 +55,7 @@ IMPORT_EXPORT_DAEMON = _constants.PKGLIBDIR + "/import-export"
KVM_CONSOLE_WRAPPER = _constants.PKGLIBDIR + "/tools/kvm-console-wrapper"
KVM_IFUP = _constants.PKGLIBDIR + "/kvm-ifup"
PREPARE_NODE_JOIN = _constants.PKGLIBDIR + "/prepare-node-join"
+SSH_UPDATE = _constants.PKGLIBDIR + "/ssh-update"
NODE_DAEMON_SETUP = _constants.PKGLIBDIR + "/node-daemon-setup"
XEN_CONSOLE_WRAPPER = _constants.PKGLIBDIR + "/tools/xen-console-wrapper"
CFGUPGRADE = _constants.PKGLIBDIR + "/tools/cfgupgrade"
diff --git a/lib/tools/common.py b/lib/tools/common.py
new file mode 100644
index 0000000..de042e4
--- /dev/null
+++ b/lib/tools/common.py
@@ -0,0 +1,100 @@
+#
+#
+
+# Copyright (C) 2014 Google Inc.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+"""Common functions for tool scripts.
+
+"""
+
+import OpenSSL
+
+from ganeti import constants
+from ganeti import errors
+from ganeti import utils
+from ganeti import serializer
+from ganeti import ssconf
+
+
+def VerifyOptions(parser, opts, args):
+  """Verifies options and arguments for correctness.
+
+  """
+  if args:
+    parser.error("No arguments are expected")
+
+  return opts
+
+
+def _VerifyCertificate(cert_pem, error_fn,
+                       _check_fn=utils.CheckNodeCertificate):
+  """Verifies a certificate against the local node daemon certificate.
+
+  @type cert_pem: string
+  @param cert_pem: Certificate in PEM format (no key)
+
+  """
+  try:
+    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
+  except OpenSSL.crypto.Error, err:
+    pass
+  else:
+    raise error_fn("No private key may be given")
+
+  try:
+    cert = \
+      OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
+  except Exception, err:
+    raise errors.X509CertError("(stdin)",
+                               "Unable to load certificate: %s" % err)
+
+  _check_fn(cert)
+
+
+def VerifyCertificate(data, error_fn, _verify_fn=_VerifyCertificate):
+  """Verifies cluster certificate.
+
+  @type data: dict
+
+  """
+  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
+  if cert:
+    _verify_fn(cert, error_fn)
+
+
+def VerifyClusterName(data, error_fn,
+                      _verify_fn=ssconf.VerifyClusterName):
+  """Verifies cluster name.
+
+  @type data: dict
+
+  """
+  name = data.get(constants.SSHS_CLUSTER_NAME)
+  if name:
+    _verify_fn(name)
+  else:
+    raise error_fn("Cluster name must be specified")
+
+
+def LoadData(raw, data_check):
+  """Parses and verifies input data.
+
+  @rtype: dict
+
+  """
+  return serializer.LoadAndVerifyJson(raw, data_check)
diff --git a/lib/tools/prepare_node_join.py b/lib/tools/prepare_node_join.py
index 7d96c6a..ed5a227 100644
--- a/lib/tools/prepare_node_join.py
+++ b/lib/tools/prepare_node_join.py
@@ -27,17 +27,15 @@ import os.path
import optparse
import sys
import logging
-import OpenSSL

from ganeti import cli
from ganeti import constants
from ganeti import errors
from ganeti import pathutils
from ganeti import utils
-from ganeti import serializer
from ganeti import ht
from ganeti import ssh
-from ganeti import ssconf
+from ganeti.tools import common


_SSH_KEY_LIST_ITEM = \
@@ -82,65 +80,7 @@ def ParseOptions():

  (opts, args) = parser.parse_args()

-  return VerifyOptions(parser, opts, args)
-
-
-def VerifyOptions(parser, opts, args):
-  """Verifies options and arguments for correctness.
-
-  """
-  if args:
-    parser.error("No arguments are expected")
-
-  return opts
-
-
-def _VerifyCertificate(cert_pem, _check_fn=utils.CheckNodeCertificate):
-  """Verifies a certificate against the local node daemon certificate.
-
-  @type cert_pem: string
-  @param cert_pem: Certificate in PEM format (no key)
-
-  """
-  try:
-    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
-  except OpenSSL.crypto.Error, err:
-    pass
-  else:
-    raise JoinError("No private key may be given")
-
-  try:
-    cert = \
-      OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
-  except Exception, err:
-    raise errors.X509CertError("(stdin)",
-                               "Unable to load certificate: %s" % err)
-
-  _check_fn(cert)
-
-
-def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
-  """Verifies cluster certificate.
-
-  @type data: dict
-
-  """
-  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
-  if cert:
-    _verify_fn(cert)
-
-
-def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
-  """Verifies cluster name.
-
-  @type data: dict
-
-  """
-  name = data.get(constants.SSHS_CLUSTER_NAME)
-  if name:
-    _verify_fn(name)
-  else:
-    raise JoinError("Cluster name must be specified")
+  return common.VerifyOptions(parser, opts, args)


def _UpdateKeyFiles(keys, dry_run, keyfiles):
@@ -241,15 +181,6 @@ def UpdateSshRoot(data, dry_run, _homedir_fn=None):
        ssh.AddAuthorizedKeys(auth_keys_file, all_authorized_keys)


-def LoadData(raw):
-  """Parses and verifies input data.
-
-  @rtype: dict
-
-  """
-  return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
-
-
def Main():
  """Main routine.

@@ -259,11 +190,11 @@ def Main():
  utils.SetupToolLogging(opts.debug, opts.verbose)

  try:
-    data = LoadData(sys.stdin.read())
+    data = common.LoadData(sys.stdin.read(), _DATA_CHECK)

    # Check if input data is correct
-    VerifyClusterName(data)
-    VerifyCertificate(data)
+    common.VerifyClusterName(data, JoinError)
+    common.VerifyCertificate(data, JoinError)

    # Update SSH files
    UpdateSshDaemon(data, opts.dry_run)
diff --git a/lib/tools/ssh_update.py b/lib/tools/ssh_update.py
new file mode 100644
index 0000000..db0f189
--- /dev/null
+++ b/lib/tools/ssh_update.py
@@ -0,0 +1,154 @@
+#
+#
+
+# Copyright (C) 2014 Google Inc.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+"""Script to update a node's SSH key files.
+
+This script is used to update the node's 'authorized_keys' and
+'ganeti_pub_key' files. It will be called via SSH from the master
+node.
+
+"""
+
+import os
+import os.path
+import optparse
+import sys
+import logging
+
+from ganeti import cli
+from ganeti import constants
+from ganeti import errors
+from ganeti import utils
+from ganeti import ht
+from ganeti import ssh
+from ganeti import pathutils
+from ganeti.tools import common
+
+
+# Type description of the JSON document read from stdin; passed to
+# common.LoadData, which rejects input that does not match it.
+# The public/authorized key entries map a non-empty string (presumably a
+# node name or UUID -- TODO confirm against the caller) to a list of keys.
+# NOTE(review): the meaning of the two positional TStrictDict flags
+# (False, True) is defined in ganeti.ht -- confirm there.
+_DATA_CHECK = ht.TStrictDict(False, True, {
+  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
+  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
+  constants.SSHS_SSH_PUBLIC_KEYS:
+    ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)),
+  constants.SSHS_SSH_AUTHORIZED_KEYS:
+    ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)),
+  })
+
+
+class SshUpdateError(errors.GenericError):
+  """Error reported by the ssh-update tool, e.g. for invalid input data.
+
+  """
+
+
+def ParseOptions():
+  """Parses the options passed to the program.
+
+  @return: Options and arguments
+
+  """
+  program = os.path.basename(sys.argv[0])
+
+  parser = optparse.OptionParser(
+    usage="%prog [--dry-run] [--verbose] [--debug]", prog=program)
+  parser.add_option(cli.DEBUG_OPT)
+  parser.add_option(cli.VERBOSE_OPT)
+  parser.add_option(cli.DRY_RUN_OPT)
+
+  (opts, args) = parser.parse_args()
+
+  return common.VerifyOptions(parser, opts, args)
+
+
+def UpdateAuthorizedKeys(data, dry_run, _homedir_fn=None):
+  """Updates root's C{authorized_keys} file.
+
+  @type data: dict
+  @param data: Input data
+  @type dry_run: boolean
+  @param dry_run: Whether to perform a dry run
+
+  """
+  authorized_keys = data.get(constants.SSHS_SSH_AUTHORIZED_KEYS)
+
+  if authorized_keys:
+    (auth_keys_file, _) = \
+      ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
+                          _homedir_fn=_homedir_fn)
+
+    if dry_run:
+      logging.info("This is a dry run, not modifying %s", auth_keys_file)
+    else:
+      all_authorized_keys = []
+      for keys in authorized_keys.values():
+        all_authorized_keys += keys
+      if not os.path.exists(auth_keys_file):
+        utils.WriteFile(auth_keys_file, mode=0600, data="")
+      ssh.AddAuthorizedKeys(auth_keys_file, all_authorized_keys)
+
+
+def UpdatePubKeyFile(data, dry_run, key_file=pathutils.SSH_PUB_KEYS):
+  """Updates the file of public SSH keys.
+
+  @type data: dict
+  @param data: Input data
+  @type dry_run: boolean
+  @param dry_run: Whether to perform a dry run
+
+  """
+  public_keys = data.get(constants.SSHS_SSH_PUBLIC_KEYS)
+  if not public_keys:
+    logging.info("No public keys received. Not modifying"
+                 " the public key file at all.")
+    return
+  if dry_run:
+    logging.info("This is a dry run, not modifying %s", key_file)
+  ssh.OverridePubKeyFile(public_keys, key_file=key_file)

Here 'else:' is missing, but as already explained in the comment in the previous patch series, it's fixed in a later patch in the series, so it's OK.

+
+
+def Main():
+  """Main routine.
+
+  """
+  opts = ParseOptions()
+
+  utils.SetupToolLogging(opts.debug, opts.verbose)
+
+  try:
+    data = common.LoadData(sys.stdin.read(), _DATA_CHECK)
+
+    # Check if input data is correct
+    common.VerifyClusterName(data, SshUpdateError)
+    common.VerifyCertificate(data, SshUpdateError)
+
+    # Update SSH files
+    UpdateAuthorizedKeys(data, opts.dry_run)
+    UpdatePubKeyFile(data, opts.dry_run)
+
+    logging.info("Setup finished successfully")
+  except Exception, err: # pylint: disable=W0703
+    logging.debug("Caught unhandled exception", exc_info=True)
+
+    (retcode, message) = cli.FormatError(err)
+    logging.error(message)
+
+    return retcode
+  else:
+    return constants.EXIT_SUCCESS
diff --git a/test/py/ganeti.tools.prepare_node_join_unittest.py 
b/test/py/ganeti.tools.prepare_node_join_unittest.py
index fe4ff26..ac30f90 100755
--- a/test/py/ganeti.tools.prepare_node_join_unittest.py
+++ b/test/py/ganeti.tools.prepare_node_join_unittest.py
@@ -25,7 +25,6 @@ import unittest
import shutil
import tempfile
import os.path
-import OpenSSL

from ganeti import errors
from ganeti import constants
@@ -34,30 +33,31 @@ from ganeti import pathutils
from ganeti import compat
from ganeti import utils
from ganeti.tools import prepare_node_join
+from ganeti.tools import common

import testutils


_JoinError = prepare_node_join.JoinError
-
+_DATA_CHECK = prepare_node_join._DATA_CHECK

class TestLoadData(unittest.TestCase):
  def testNoJson(self):
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "")
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}")
+    self.assertRaises(errors.ParseError, common.LoadData, "", _DATA_CHECK)
+    self.assertRaises(errors.ParseError, common.LoadData, "}", _DATA_CHECK)

  def testInvalidDataStructure(self):
    raw = serializer.DumpJson({
      "some other thing": False,
      })
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
+    self.assertRaises(errors.ParseError, common.LoadData, raw, _DATA_CHECK)

    raw = serializer.DumpJson([])
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
+    self.assertRaises(errors.ParseError, common.LoadData, raw, _DATA_CHECK)

  def testEmptyDict(self):
    raw = serializer.DumpJson({})
-    self.assertEqual(prepare_node_join.LoadData(raw), {})
+    self.assertEqual(common.LoadData(raw, _DATA_CHECK), {})

  def testValidData(self):
    key_list = [[constants.SSHK_DSA, "private foo", "public bar"]]
@@ -69,7 +69,7 @@ class TestLoadData(unittest.TestCase):
        {"nodeuuid01234": ["foo"],
         "nodeuuid56789": ["bar"]}}
    raw = serializer.DumpJson(data_dict)
-    self.assertEqual(prepare_node_join.LoadData(raw), data_dict)
+    self.assertEqual(common.LoadData(raw, _DATA_CHECK), data_dict)


class TestVerifyCertificate(testutils.GanetiTestCase):
@@ -82,20 +82,21 @@ class TestVerifyCertificate(testutils.GanetiTestCase):
    shutil.rmtree(self.tmpdir)

  def testNoCert(self):
-    prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented)
+    common.VerifyCertificate({}, error_fn=prepare_node_join.JoinError,
+                             _verify_fn=NotImplemented)

  def testGivenPrivateKey(self):
    cert_filename = testutils.TestDataFilename("cert2.pem")
    cert_pem = utils.ReadFile(cert_filename)

-    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
-                      cert_pem, _check_fn=NotImplemented)
+    self.assertRaises(_JoinError, common._VerifyCertificate,
+                      cert_pem, _JoinError, _check_fn=NotImplemented)

  def testInvalidCertificate(self):
    self.assertRaises(errors.X509CertError,
-                      prepare_node_join._VerifyCertificate,
+                      common._VerifyCertificate,
                      "Something that's not a certificate",
-                      _check_fn=NotImplemented)
+                      _JoinError, _check_fn=NotImplemented)

  @staticmethod
  def _Check(cert):
@@ -104,7 +105,8 @@ class TestVerifyCertificate(testutils.GanetiTestCase):
  def testSuccessfulCheck(self):
    cert_filename = testutils.TestDataFilename("cert1.pem")
    cert_pem = utils.ReadFile(cert_filename)
-    prepare_node_join._VerifyCertificate(cert_pem, _check_fn=self._Check)
+    common._VerifyCertificate(cert_pem, _JoinError,
+      _check_fn=self._Check)


class TestVerifyClusterName(unittest.TestCase):
@@ -117,8 +119,8 @@ class TestVerifyClusterName(unittest.TestCase):
    shutil.rmtree(self.tmpdir)

  def testNoName(self):
-    self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
-                      {}, _verify_fn=NotImplemented)
+    self.assertRaises(_JoinError, common.VerifyClusterName,
+                      {}, _JoinError, _verify_fn=NotImplemented)

  @staticmethod
  def _FailingVerify(name):
@@ -130,8 +132,8 @@ class TestVerifyClusterName(unittest.TestCase):
      constants.SSHS_CLUSTER_NAME: "cluster.example.com",
      }

-    self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName,
-                      data, _verify_fn=self._FailingVerify)
+    self.assertRaises(errors.GenericError, common.VerifyClusterName,
+                      data, _JoinError, _verify_fn=self._FailingVerify)


class TestUpdateSshDaemon(unittest.TestCase):
diff --git a/test/py/ganeti.tools.ssh_update_unittest.py 
b/test/py/ganeti.tools.ssh_update_unittest.py
new file mode 100755
index 0000000..af3205a
--- /dev/null
+++ b/test/py/ganeti.tools.ssh_update_unittest.py
@@ -0,0 +1,123 @@
+#!/usr/bin/python
+#
+
+# Copyright (C) 2014 Google Inc.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+
+"""Script for testing ganeti.tools.ssh_update"""
+
+import unittest
+import shutil
+import tempfile
+import os.path
+
+from ganeti import constants
+from ganeti import utils
+from ganeti.tools import ssh_update
+
+import testutils
+
+
+# Aliases for names from the module under test.
+# NOTE(review): neither alias is referenced by the tests in this file, and
+# the name _JoinError mirrors prepare_node_join's tests although it refers
+# to SshUpdateError here -- consider renaming or removing.
+_JoinError = ssh_update.SshUpdateError
+_DATA_CHECK = ssh_update._DATA_CHECK
+
+
+class TestUpdateAuthorizedKeys(testutils.GanetiTestCase):
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+    self.tmpdir = tempfile.mkdtemp()
+    self.sshdir = utils.PathJoin(self.tmpdir, ".ssh")
+
+  def tearDown(self):
+    unittest.TestCase.tearDown(self)
+    shutil.rmtree(self.tmpdir)
+
+  def _GetHomeDir(self, user):
+    self.assertEqual(user, constants.SSH_LOGIN_USER)
+    return self.tmpdir
+
+  def testNoKeys(self):
+    data_empty_keys = {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: {},
+      }
+
+    for data in [{}, data_empty_keys]:
+      for dry_run in [False, True]:
+        ssh_update.UpdateAuthorizedKeys(data, dry_run,
+                                        _homedir_fn=NotImplemented)
+    self.assertEqual(os.listdir(self.tmpdir), [])
+
+  def testDryRun(self):
+    data = {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: {
+        "node1" : ["key11", "key12", "key13"],
+        "node2" : ["key21", "key22"]},
+      }
+
+    ssh_update.UpdateAuthorizedKeys(data, True,
+                                    _homedir_fn=self._GetHomeDir)
+    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+    self.assertEqual(os.listdir(self.sshdir), [])
+
+  def testUpdate(self):
+    data = {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: {
+        "node1": ["key11", "key12"],
+        "node2": ["key21"]},
+      }
+
+    ssh_update.UpdateAuthorizedKeys(data, False,
+                                    _homedir_fn=self._GetHomeDir)
+    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+    self.assertEqual(sorted(os.listdir(self.sshdir)),
+                     sorted(["authorized_keys"]))
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
+                                                   "authorized_keys")),
+                     "key11\nkey12\nkey21\n")
+
+
+class TestUpdatePubKeyFile(testutils.GanetiTestCase):
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+
+  def testNoKeys(self):
+    pub_key_file = self._CreateTempFile()
+    data_empty_keys = {
+      constants.SSHS_SSH_PUBLIC_KEYS: {},
+      }
+
+    for data in [{}, data_empty_keys]:
+      for dry_run in [False, True]:
+        ssh_update.UpdatePubKeyFile(data, dry_run,
+                                    key_file=pub_key_file)
+    self.assertEqual(utils.ReadFile(pub_key_file), "")
+
+  def testValidKeys(self):
+    pub_key_file = self._CreateTempFile()
+    data = {
+      constants.SSHS_SSH_PUBLIC_KEYS: {
+        "node1": ["key11", "key12"],
+        "node2": ["key21"]},
+      }
+    ssh_update.UpdatePubKeyFile(data, False, key_file=pub_key_file)
+    self.assertEqual(utils.ReadFile(pub_key_file),
+      "node1 key11\nnode1 key12\nnode2 key21\n")
+
+
+if __name__ == "__main__":
+  # Run this module's test cases via testutils.GanetiTestProgram
+  testutils.GanetiTestProgram()
--
2.1.0.rc2.206.gedb03e5


LGTM

Reply via email to