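Add a new tool, prepare-node-join, as described in the design document
“design-ssh-setup”. It reads a JSON data structure from its standard
input, checks the given cluster name and node daemon certificate against
the local node, and configures the SSH daemon and root's SSH keys
accordingly. Unit tests are included.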

Signed-off-by: Michael Hanselmann <[email protected]>
---
 .gitignore                                      |    1 +
 Makefile.am                                     |    8 +-
 lib/constants.py                                |   12 +
 lib/ssh.py                                      |    5 +-
 lib/tools/prepare_node_join.py                  |  369 +++++++++++++++++++++++
 lib/utils/io.py                                 |    5 +-
 test/data/cert2.pem                             |   22 ++
 test/ganeti.tools.prepare_node_join_unittest.py |  309 +++++++++++++++++++
 8 files changed, 724 insertions(+), 7 deletions(-)
 create mode 100644 lib/tools/prepare_node_join.py
 create mode 100644 test/data/cert2.pem
 create mode 100755 test/ganeti.tools.prepare_node_join_unittest.py

diff --git a/.gitignore b/.gitignore
index 3d5cb0f..2759bc8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -94,6 +94,7 @@
 /tools/kvm-ifup
 /tools/ensure-dirs
 /tools/vcluster-setup
+/tools/prepare-node-join
 
 # scripts
 /scripts/gnt-backup
diff --git a/Makefile.am b/Makefile.am
index f9b4a94..78282a5 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -578,7 +578,8 @@ PYTHON_BOOTSTRAP_SBIN = \
 
 PYTHON_BOOTSTRAP = \
        $(PYTHON_BOOTSTRAP_SBIN) \
-       tools/ensure-dirs
+       tools/ensure-dirs \
+       tools/prepare-node-join
 
 qa_scripts = \
        qa/__init__.py \
@@ -690,7 +691,8 @@ pkglib_python_scripts = \
        tools/check-cert-expired
 
 nodist_pkglib_python_scripts = \
-       tools/ensure-dirs
+       tools/ensure-dirs \
+       tools/prepare-node-join
 
 myexeclib_SCRIPTS = \
        daemons/daemon-util \
@@ -926,6 +928,7 @@ python_tests = \
        test/ganeti.ssh_unittest.py \
        test/ganeti.storage_unittest.py \
        test/ganeti.tools.ensure_dirs_unittest.py \
+       test/ganeti.tools.prepare_node_join_unittest.py \
        test/ganeti.uidpool_unittest.py \
        test/ganeti.utils.algo_unittest.py \
        test/ganeti.utils.filelock_unittest.py \
@@ -1327,6 +1330,7 @@ daemons/ganeti-%: MODULE = ganeti.server.$(patsubst ganeti-%,%,$(notdir $@))
 daemons/ganeti-watcher: MODULE = ganeti.watcher
 scripts/%: MODULE = ganeti.client.$(subst -,_,$(notdir $@))
 tools/ensure-dirs: MODULE = ganeti.tools.ensure_dirs
+tools/prepare-node-join: MODULE = ganeti.tools.prepare_node_join
 $(HS_BUILT_TEST_HELPERS): TESTROLE = $(patsubst htest/%,%,$@)
 
 $(PYTHON_BOOTSTRAP): Makefile | stamp-directories
diff --git a/lib/constants.py b/lib/constants.py
index 39b6895..35cf53b 100644
--- a/lib/constants.py
+++ b/lib/constants.py
@@ -2049,5 +2049,17 @@ SSHK_RSA = "rsa"
 SSHK_DSA = "dsa"
 SSHK_ALL = frozenset([SSHK_RSA, SSHK_DSA])
 
+# Key type prefixes as used in SSH authorized_keys lines
+SSHAK_RSA = "ssh-rsa"
+SSHAK_DSS = "ssh-dss"
+SSHAK_ALL = frozenset([SSHAK_RSA, SSHAK_DSS])
+
+# SSH setup: keys of the JSON input read by tools/prepare-node-join
+SSHS_CLUSTER_NAME = "cluster_name"
+SSHS_FORCE = "force"
+SSHS_SSH_HOST_KEY = "ssh_host_key"
+SSHS_SSH_ROOT_KEY = "ssh_root_key"
+SSHS_NODE_DAEMON_CERTIFICATE = "node_daemon_certificate"
+
 # Do not re-export imported modules
 del re, _vcsversion, _autoconf, socket, pathutils
diff --git a/lib/ssh.py b/lib/ssh.py
index 1e35feb..27c197f 100644
--- a/lib/ssh.py
+++ b/lib/ssh.py
@@ -50,7 +50,7 @@ def FormatParamikoFingerprint(fingerprint):
 
 
 def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
-                 _homedir_fn=utils.GetHomeDir):
+                 _homedir_fn=None):
   """Return the paths of a user's SSH files.
 
   @type user: string
@@ -68,6 +68,9 @@ def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
     exception is raised if C{~$user/.ssh} is not a directory
 
   """
+  if _homedir_fn is None:
+    _homedir_fn = utils.GetHomeDir
+
   user_dir = _homedir_fn(user)
   if not user_dir:
     raise errors.OpExecError("Cannot resolve home of user '%s'" % user)
diff --git a/lib/tools/prepare_node_join.py b/lib/tools/prepare_node_join.py
new file mode 100644
index 0000000..4531e8c
--- /dev/null
+++ b/lib/tools/prepare_node_join.py
@@ -0,0 +1,369 @@
+#
+#
+
+# Copyright (C) 2012 Google Inc.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+"""Script to prepare a node for joining a cluster.
+
+"""
+
+import os
+import os.path
+import optparse
+import sys
+import logging
+import errno
+import OpenSSL
+
+from ganeti import cli
+from ganeti import constants
+from ganeti import errors
+from ganeti import pathutils
+from ganeti import utils
+from ganeti import serializer
+from ganeti import ht
+from ganeti import ssh
+from ganeti import ssconf
+
+
+_SSH_KEY_LIST = \
+  ht.TListOf(ht.TAnd(ht.TIsLength(3),  # (kind, public, private)
+                     ht.TItems([
+                       ht.TElemOf(constants.SSHK_ALL),
+                       ht.Comment("public")(ht.TNonEmptyString),
+                       ht.Comment("private")(ht.TNonEmptyString),
+                       ])))
+
+_DATA_CHECK = ht.TStrictDict(False, True, {  # optional keys, no extras
+  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
+  constants.SSHS_FORCE: ht.TBool,
+  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
+  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
+  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
+  })
+
+_SSHK_TO_SSHAK = {  # key kind -> type prefix in authorized_keys lines
+  constants.SSHK_RSA: constants.SSHAK_RSA,
+  constants.SSHK_DSA: constants.SSHAK_DSS,
+  }
+
+_SSH_DAEMON_KEYFILES = {  # key kind -> (public, private) host key files
+  constants.SSHK_RSA:
+    (pathutils.SSH_HOST_RSA_PUB, pathutils.SSH_HOST_RSA_PRIV),
+  constants.SSHK_DSA:
+    (pathutils.SSH_HOST_DSA_PUB, pathutils.SSH_HOST_DSA_PRIV),
+    }
+
+assert frozenset(_SSHK_TO_SSHAK.keys()) == constants.SSHK_ALL
+assert frozenset(_SSHK_TO_SSHAK.values()) == constants.SSHAK_ALL
+
+
+class JoinError(errors.GenericError):
+  """Error raised when this node can not be prepared to join a cluster.
+
+  """
+
+
+def ParseOptions():
+  """Parses the options passed to the program.
+
+  @return: Parsed options; positional arguments are rejected
+
+  """
+  program = os.path.basename(sys.argv[0])
+
+  parser = optparse.OptionParser(usage="%prog [--dry-run]",
+                                 prog=program)
+  parser.add_option(cli.DEBUG_OPT)
+  parser.add_option(cli.VERBOSE_OPT)
+  parser.add_option(cli.DRY_RUN_OPT)
+
+  (opts, args) = parser.parse_args()
+
+  return VerifyOptions(parser, opts, args)
+
+
+def VerifyOptions(parser, opts, args):
+  """Rejects positional arguments and returns the parsed options.
+
+  """
+  if args:
+    parser.error("No arguments are expected")
+
+  return opts
+
+
+def SetupLogging(opts):
+  """Logs to stderr; the level depends on the debug/verbose options.
+
+  """
+  formatter = logging.Formatter("%(asctime)s: %(message)s")
+
+  stderr_handler = logging.StreamHandler()
+  stderr_handler.setFormatter(formatter)
+  if opts.debug:
+    stderr_handler.setLevel(logging.NOTSET)  # pass every message
+  elif opts.verbose:
+    stderr_handler.setLevel(logging.INFO)
+  else:
+    stderr_handler.setLevel(logging.WARNING)
+
+  root_logger = logging.getLogger("")
+  root_logger.setLevel(logging.NOTSET)
+  root_logger.addHandler(stderr_handler)
+
+
+def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
+  """Checks a certificate against the local node daemon's private key.
+
+  @type cert: string
+  @param cert: Certificate in PEM format (no key)
+
+  """
+  try:
+    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
+  except OpenSSL.crypto.Error:
+    pass  # expected: the input must not contain a private key
+  else:
+    raise JoinError("No private key may be given")
+
+  try:
+    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
+  except Exception, err:
+    raise errors.X509CertError("(stdin)",
+                               "Unable to load certificate: %s" % err)
+
+  try:
+    noded_pem = utils.ReadFile(_noded_cert_file)
+  except EnvironmentError, err:
+    if err.errno != errno.ENOENT:
+      raise
+
+    logging.debug("Local node certificate was not found (file %s)",
+                  _noded_cert_file)
+    return
+
+  try:
+    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
+  except Exception, err:
+    raise errors.X509CertError(_noded_cert_file,
+                               "Unable to load private key: %s" % err)
+
+  ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
+  ctx.use_privatekey(key)
+  ctx.use_certificate(cert)
+  try:
+    ctx.check_privatekey()
+  except OpenSSL.SSL.Error:
+    raise JoinError("Given cluster certificate does not match local key")
+
+
+def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
+  """Verifies the node daemon certificate in C{data}, if one was given.
+
+  @type data: dict
+
+  """
+  try:
+    cert = data[constants.SSHS_NODE_DAEMON_CERTIFICATE]
+  except KeyError:
+    pass  # the certificate is optional
+  else:
+    _verify_fn(cert)
+
+
+def _VerifyClusterName(name, _ss_cluster_name_file=None):
+  """Compares a cluster name with the local ssconf one, if that exists.
+
+  @type name: string
+  @param name: Cluster name
+
+  """
+  if _ss_cluster_name_file is None:
+    _ss_cluster_name_file = \
+      ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)
+
+  try:
+    local_name = utils.ReadOneLineFile(_ss_cluster_name_file)
+  except EnvironmentError, err:
+    if err.errno != errno.ENOENT:
+      raise
+
+    logging.debug("Local cluster name was not found (file %s)",
+                  _ss_cluster_name_file)
+  else:
+    if name != local_name:
+      raise JoinError("Current cluster name is '%s'" % local_name)
+
+
+def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
+  """Verifies the cluster name in C{data}, which is mandatory.
+
+  @type data: dict
+
+  """
+  try:
+    name = data[constants.SSHS_CLUSTER_NAME]
+  except KeyError:
+    raise JoinError("Cluster name must be specified")
+  else:
+    _verify_fn(name)
+
+
+def _UpdateKeyFiles(keys, dry_run, keyfiles):
+  """Updates SSH key files.
+
+  @type keys: sequence of tuple; (string, string, string)
+  @param keys: Keys to write, tuples consist of key type
+    (L{constants.SSHK_ALL}), public and private key
+  @type dry_run: boolean
+  @param dry_run: Whether to perform a dry run
+  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
+  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
+    names; value tuples consist of public key filename and private key filename
+
+  """
+  assert set(keyfiles) == constants.SSHK_ALL
+
+  for (kind, public_key, private_key) in keys:
+    (public_file, private_file) = keyfiles[kind]
+
+    logging.debug("Writing %s ...", public_file)
+    utils.WriteFile(public_file, data=public_key, mode=0644,  # world-readable
+                    backup=True, dry_run=dry_run)
+
+    logging.debug("Writing %s ...", private_file)
+    utils.WriteFile(private_file, data=private_key, mode=0600,  # owner only
+                    backup=True, dry_run=dry_run)
+
+
+def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
+                    _keyfiles=None):
+  """Updates SSH daemon's keys.
+
+  Unless C{dry_run} is set, the keys are reloaded via C{reload-ssh-keys}.
+
+  @type data: dict
+  @param data: Input data
+  @type dry_run: boolean
+  @param dry_run: Whether to perform a dry run
+
+  """
+  keys = data.get(constants.SSHS_SSH_HOST_KEY)
+  if not keys:
+    return
+
+  if _keyfiles is None:
+    _keyfiles = _SSH_DAEMON_KEYFILES
+
+  logging.info("Updating SSH daemon key files")
+  _UpdateKeyFiles(keys, dry_run, _keyfiles)
+
+  if dry_run:
+    logging.info("This is a dry run, not restarting SSH daemon")
+  else:
+    result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
+                        interactive=True)
+    if result.failed:
+      raise JoinError("Could not reload SSH keys, command '%s'"
+                      " had exitcode %s and error %s" %
+                      (result.cmd, result.exit_code, result.output))
+
+
+def UpdateSshRoot(data, dry_run, _homedir_fn=None):
+  """Updates root's (L{constants.SSH_LOGIN_USER}) SSH keys.
+
+  Root's C{authorized_keys} file is also updated with new public keys.
+
+  @type data: dict
+  @param data: Input data
+  @type dry_run: boolean
+  @param dry_run: Whether to perform a dry run
+
+  """
+  keys = data.get(constants.SSHS_SSH_ROOT_KEY)
+  if not keys:
+    return
+
+  (dsa_private_file, dsa_public_file, auth_keys_file) = \
+    ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
+                     kind=constants.SSHK_DSA, _homedir_fn=_homedir_fn)
+  (rsa_private_file, rsa_public_file, _) = \
+    ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
+                     kind=constants.SSHK_RSA, _homedir_fn=_homedir_fn)
+
+  _UpdateKeyFiles(keys, dry_run, {
+    constants.SSHK_RSA: (rsa_public_file, rsa_private_file),
+    constants.SSHK_DSA: (dsa_public_file, dsa_private_file),
+    })
+
+  if dry_run:  # note: the .ssh directory was created above (mkdir=True)
+    logging.info("This is a dry run, not modifying %s", auth_keys_file)
+  else:
+    for (kind, public_key, _) in keys:  # NOTE: assumes keys lack a type prefix
+      line = "%s %s" % (_SSHK_TO_SSHAK[kind], public_key)
+      utils.AddAuthorizedKey(auth_keys_file, line)
+
+
+def LoadData(raw):
+  """Parses JSON input data and checks it against L{_DATA_CHECK}.
+
+  @rtype: dict
+
+  """
+  try:
+    data = serializer.LoadJson(raw)
+  except Exception, err:
+    raise errors.ParseError("Can't parse input data: %s" % err)
+
+  if not _DATA_CHECK(data):
+    raise errors.ParseError("Input data does not match expected format: %s" %
+                            _DATA_CHECK)
+
+  return data
+
+
+def Main():
+  """Main routine; returns the program's exit code.
+
+  """
+  opts = ParseOptions()
+
+  SetupLogging(opts)
+
+  try:
+    data = LoadData(sys.stdin.read())
+
+    # Check if input data is correct
+    VerifyClusterName(data)
+    VerifyCertificate(data)
+
+    # Update SSH files
+    UpdateSshDaemon(data, opts.dry_run)
+    UpdateSshRoot(data, opts.dry_run)
+
+    logging.info("Setup finished successfully")
+  except Exception, err: # pylint: disable=W0703
+    logging.debug("Caught unhandled exception", exc_info=True)
+
+    (retcode, message) = cli.FormatError(err)
+    logging.error(message)
+
+    return retcode
+  else:
+    return constants.EXIT_SUCCESS
diff --git a/lib/utils/io.py b/lib/utils/io.py
index a10ba84..bb34b11 100644
--- a/lib/utils/io.py
+++ b/lib/utils/io.py
@@ -828,9 +828,6 @@ def ReadLockedPidFile(path):
   return None
 
 
-_SSH_KEYS_WITH_TWO_PARTS = frozenset(["ssh-dss", "ssh-rsa"])
-
-
 def _SplitSshKey(key):
   """Splits a line for SSH's C{authorized_keys} file.
 
@@ -845,7 +842,7 @@ def _SplitSshKey(key):
   """
   parts = key.split()
 
-  if parts and parts[0] in _SSH_KEYS_WITH_TWO_PARTS:
+  if parts and parts[0] in constants.SSHAK_ALL:
     # If the key has no options in front of it, we only want the significant
     # fields
     return (False, parts[:2])
diff --git a/test/data/cert2.pem b/test/data/cert2.pem
new file mode 100644
index 0000000..0d15d55
--- /dev/null
+++ b/test/data/cert2.pem
@@ -0,0 +1,22 @@
+-----BEGIN PRIVATE KEY-----
+MIIBUwIBADANBgkqhkiG9w0BAQEFAASCAT0wggE5AgEAAkEAt8OZYvvi8noVPlpR
+/SrHcya9ne7RG5DjvMssksUqyGriUs/WGnpZlL4nz+BcLFGwNNntoxqR30Tjk47S
+cmSBRQIDAQABAkAqTP5MCMuPIYcuWUAyVNygpzRS3JyKCepClUpnZreYdo4sUQE3
+/AM7xeb92R06iZ3f9/MPrbaMKTWRh3uCyfKBAiEA5TxdacnVxdS8+ZLyys4p/C1s
+iajrarBb/j+NIAnsdnECIQDNOCDO7Jq/iN5qE4Vbi/3zmnP1Ca5aBo+KJ/hhSjRq
+FQIgIBpWEqybbXsfg+waaGB67MAHxTeM0IImP/LydpwtK2ECIB3SrlHj6Ik1Jr1b
+oOGw8nLYW0mc4o2KrolxTZM16XARAiBKW3aSjY5UrnoEqa8pAeiO8LJaRj73Epmr
+zC89IuLZfg==
+-----END PRIVATE KEY-----
+-----BEGIN CERTIFICATE-----
+MIIB0zCCAX2gAwIBAgIJAKrAqGX6UolVMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
+BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
+aWRnaXRzIFB0eSBMdGQwHhcNMTIxMDE5MTQ1NjA4WhcNMTIxMDIwMTQ1NjA4WjBF
+MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
+ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALfD
+mWL74vJ6FT5aUf0qx3MmvZ3u0RuQ47zLLJLFKshq4lLP1hp6WZS+J8/gXCxRsDTZ
+7aMakd9E45OO0nJkgUUCAwEAAaNQME4wHQYDVR0OBBYEFA1Fc/GIVtd6nMocrSsA
+e5bxmVhMMB8GA1UdIwQYMBaAFA1Fc/GIVtd6nMocrSsAe5bxmVhMMAwGA1UdEwQF
+MAMBAf8wDQYJKoZIhvcNAQEFBQADQQCTUwzDGU+IJTQ3PIJrA3fHMyKbBvc4Rkvi
+ZNFsmgsidWhb+5APlPjtlS7rXlonNHBzDoGb4RNArtxhEx+rBcAE
+-----END CERTIFICATE-----
diff --git a/test/ganeti.tools.prepare_node_join_unittest.py b/test/ganeti.tools.prepare_node_join_unittest.py
new file mode 100755
index 0000000..b629f77
--- /dev/null
+++ b/test/ganeti.tools.prepare_node_join_unittest.py
@@ -0,0 +1,309 @@
+#!/usr/bin/python
+#
+
+# Copyright (C) 2012 Google Inc.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+
+"""Script for testing ganeti.tools.prepare_node_join"""
+
+import unittest
+import shutil
+import tempfile
+import os.path
+import OpenSSL
+
+from ganeti import errors
+from ganeti import constants
+from ganeti import serializer
+from ganeti import pathutils
+from ganeti import compat
+from ganeti import utils
+from ganeti.tools import prepare_node_join
+
+import testutils
+
+
+_JoinError = prepare_node_join.JoinError  # short alias used in assertions
+
+
+class TestLoadData(unittest.TestCase):  # LoadData parsing and validation
+  def testNoJson(self):
+    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "")
+    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}")
+
+  def testInvalidDataStructure(self):
+    raw = serializer.DumpJson({  # unknown keys are rejected
+      "some other thing": False,
+      })
+    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
+
+    raw = serializer.DumpJson([])  # not a dict
+    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
+
+  def testValidData(self):
+    raw = serializer.DumpJson({})  # all keys are optional
+    self.assertEqual(prepare_node_join.LoadData(raw), {})
+
+
+class TestVerifyCertificate(testutils.GanetiTestCase):  # certificate checks
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    testutils.GanetiTestCase.tearDown(self)
+    shutil.rmtree(self.tmpdir)
+
+  def testNoCert(self):
+    prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented)
+
+  def testMismatchingKey(self):
+    other_cert = self._TestDataFilename("cert1.pem")
+    node_cert = self._TestDataFilename("cert2.pem")
+
+    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
+                      utils.ReadFile(other_cert), _noded_cert_file=node_cert)
+
+  def testGivenPrivateKey(self):
+    cert_filename = self._TestDataFilename("cert2.pem")
+    cert_pem = utils.ReadFile(cert_filename)  # key and certificate
+
+    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
+                      cert_pem, _noded_cert_file=cert_filename)
+
+  def testMatchingKey(self):
+    cert_filename = self._TestDataFilename("cert2.pem")
+
+    # Extract certificate
+    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
+                                           utils.ReadFile(cert_filename))
+    cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
+                                               cert)
+
+    prepare_node_join._VerifyCertificate(cert_pem,
+                                         _noded_cert_file=cert_filename)
+
+  def testMissingFile(self):
+    cert = self._TestDataFilename("cert1.pem")
+    nodecert = utils.PathJoin(self.tmpdir, "does-not-exist")
+    prepare_node_join._VerifyCertificate(utils.ReadFile(cert),
+                                         _noded_cert_file=nodecert)
+
+  def testInvalidCertificate(self):
+    self.assertRaises(errors.X509CertError,
+                      prepare_node_join._VerifyCertificate,
+                      "Something that's not a certificate",
+                      _noded_cert_file=NotImplemented)
+
+  def testNoPrivateKey(self):
+    cert = self._TestDataFilename("cert1.pem")
+    self.assertRaises(errors.X509CertError,
+                      prepare_node_join._VerifyCertificate,
+                      utils.ReadFile(cert), _noded_cert_file=cert)
+
+
+class TestVerifyClusterName(unittest.TestCase):  # cluster name checks
+  def setUp(self):
+    unittest.TestCase.setUp(self)
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    unittest.TestCase.tearDown(self)
+    shutil.rmtree(self.tmpdir)
+
+  def testNoName(self):
+    self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
+                      {}, _verify_fn=NotImplemented)
+
+  def testMissingFile(self):
+    tmpfile = utils.PathJoin(self.tmpdir, "does-not-exist")
+    prepare_node_join._VerifyClusterName(NotImplemented,
+                                         _ss_cluster_name_file=tmpfile)
+
+  def testMatchingName(self):
+    tmpfile = utils.PathJoin(self.tmpdir, "cluster_name")
+
+    for content in ["cluster.example.com", "cluster.example.com\n\n"]:
+      utils.WriteFile(tmpfile, data=content)
+      prepare_node_join._VerifyClusterName("cluster.example.com",
+                                           _ss_cluster_name_file=tmpfile)
+
+  def testNameMismatch(self):
+    tmpfile = utils.PathJoin(self.tmpdir, "cluster_name")
+
+    for content in ["something.example.com", "foobar\n\ncluster.example.com"]:
+      utils.WriteFile(tmpfile, data=content)  # only line 1 is compared
+      self.assertRaises(_JoinError, prepare_node_join._VerifyClusterName,
+                        "cluster.example.com", _ss_cluster_name_file=tmpfile)
+
+
+class TestUpdateSshDaemon(unittest.TestCase):  # SSH daemon host key files
+  def setUp(self):
+    unittest.TestCase.setUp(self)
+    self.tmpdir = tempfile.mkdtemp()
+
+    self.keyfiles = {
+      constants.SSHK_RSA:
+        (utils.PathJoin(self.tmpdir, "rsa.public"),
+         utils.PathJoin(self.tmpdir, "rsa.private")),
+      constants.SSHK_DSA:
+        (utils.PathJoin(self.tmpdir, "dsa.public"),
+         utils.PathJoin(self.tmpdir, "dsa.private")),
+      }
+
+  def tearDown(self):
+    unittest.TestCase.tearDown(self)
+    shutil.rmtree(self.tmpdir)
+
+  def testNoKeys(self):
+    data_empty_keys = {
+      constants.SSHS_SSH_HOST_KEY: [],
+      }
+
+    for data in [{}, data_empty_keys]:
+      for dry_run in [False, True]:
+        prepare_node_join.UpdateSshDaemon(data, dry_run,
+                                          _runcmd_fn=NotImplemented,
+                                          _keyfiles=NotImplemented)
+    self.assertEqual(os.listdir(self.tmpdir), [])
+
+  def _TestDryRun(self, data):
+    prepare_node_join.UpdateSshDaemon(data, True, _runcmd_fn=NotImplemented,
+                                      _keyfiles=self.keyfiles)
+    self.assertEqual(os.listdir(self.tmpdir), [])
+
+  def testDryRunRsa(self):
+    self._TestDryRun({
+      constants.SSHS_SSH_HOST_KEY: [
+        (constants.SSHK_RSA, "rsapub", "rsapriv"),
+        ],
+      })
+
+  def testDryRunDsa(self):
+    self._TestDryRun({
+      constants.SSHS_SSH_HOST_KEY: [
+        (constants.SSHK_DSA, "dsapub", "dsapriv"),
+        ],
+      })
+
+  def _RunCmd(self, fail, cmd, interactive=False):  # stand-in for RunCmd
+    self.assertTrue(interactive)
+    self.assertEqual(cmd, [pathutils.DAEMON_UTIL, "reload-ssh-keys"])
+    if fail:
+      exit_code = constants.EXIT_FAILURE
+    else:
+      exit_code = constants.EXIT_SUCCESS
+    return utils.RunResult(exit_code, None, "stdout", "stderr",
+                           utils.ShellQuoteArgs(cmd),
+                           NotImplemented, NotImplemented)
+
+  def _TestUpdate(self, failcmd):
+    data = {
+      constants.SSHS_SSH_HOST_KEY: [
+        (constants.SSHK_DSA, "dsapub", "dsapriv"),
+        (constants.SSHK_RSA, "rsapub", "rsapriv"),
+        ],
+      }
+    runcmd_fn = compat.partial(self._RunCmd, failcmd)
+    if failcmd:
+      self.assertRaises(_JoinError, prepare_node_join.UpdateSshDaemon,
+                        data, False, _runcmd_fn=runcmd_fn,
+                        _keyfiles=self.keyfiles)
+    else:
+      prepare_node_join.UpdateSshDaemon(data, False, _runcmd_fn=runcmd_fn,
+                                        _keyfiles=self.keyfiles)
+    self.assertEqual(sorted(os.listdir(self.tmpdir)), sorted([
+      "rsa.private", "rsa.public",
+      "dsa.private", "dsa.public",
+      ]))
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.public")),
+                     "rsapub")
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.private")),
+                     "rsapriv")
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.public")),
+                     "dsapub")
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.private")),
+                     "dsapriv")
+
+  def testSuccess(self):
+    self._TestUpdate(False)
+
+  def testFailure(self):
+    self._TestUpdate(True)
+
+
+class TestUpdateSshRoot(unittest.TestCase):  # root's SSH key files
+  def setUp(self):
+    unittest.TestCase.setUp(self)
+    self.tmpdir = tempfile.mkdtemp()
+    self.sshdir = utils.PathJoin(self.tmpdir, ".ssh")
+
+  def tearDown(self):
+    unittest.TestCase.tearDown(self)
+    shutil.rmtree(self.tmpdir)
+
+  def _GetHomeDir(self, user):
+    self.assertEqual(user, constants.SSH_LOGIN_USER)
+    return self.tmpdir
+
+  def testNoKeys(self):
+    data_empty_keys = {
+      constants.SSHS_SSH_ROOT_KEY: [],
+      }
+
+    for data in [{}, data_empty_keys]:
+      for dry_run in [False, True]:
+        prepare_node_join.UpdateSshRoot(data, dry_run,
+                                        _homedir_fn=NotImplemented)
+    self.assertEqual(os.listdir(self.tmpdir), [])
+
+  def testDryRun(self):
+    data = {
+      constants.SSHS_SSH_ROOT_KEY: [
+        (constants.SSHK_RSA, "aaa", "bbb"),
+        ]
+      }
+
+    prepare_node_join.UpdateSshRoot(data, True,
+                                    _homedir_fn=self._GetHomeDir)
+    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])  # made in dry run too
+    self.assertEqual(os.listdir(self.sshdir), [])
+
+  def testUpdate(self):
+    data = {
+      constants.SSHS_SSH_ROOT_KEY: [
+        (constants.SSHK_DSA, "pubdsa", "privatedsa"),
+        ]
+      }
+
+    prepare_node_join.UpdateSshRoot(data, False,
+                                    _homedir_fn=self._GetHomeDir)
+    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+    self.assertEqual(sorted(os.listdir(self.sshdir)),
+                     sorted(["authorized_keys", "id_dsa", "id_dsa.pub"]))
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa")),
+                     "privatedsa")
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa.pub")),
+                     "pubdsa")
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
+                                                   "authorized_keys")),
+                     "ssh-dss pubdsa\n")
+
+
+if __name__ == "__main__":  # run the tests when executed directly
+  testutils.GanetiTestProgram()
-- 
1.7.7.3

Reply via email to