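This is a new tool as per the design document “design-ssh-setup”. It receives a JSON data structure on its standard input, checks the given cluster name and node daemon certificate against the node's existing ones (if any), and then updates the SSH daemon's host keys and root's SSH keys accordingly. Unit tests are included.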
Signed-off-by: Michael Hanselmann <[email protected]> --- .gitignore | 1 + Makefile.am | 8 +- lib/constants.py | 12 + lib/ssh.py | 5 +- lib/tools/prepare_node_join.py | 369 +++++++++++++++++++++++ lib/utils/io.py | 5 +- test/data/cert2.pem | 22 ++ test/ganeti.tools.prepare_node_join_unittest.py | 309 +++++++++++++++++++ 8 files changed, 724 insertions(+), 7 deletions(-) create mode 100644 lib/tools/prepare_node_join.py create mode 100644 test/data/cert2.pem create mode 100755 test/ganeti.tools.prepare_node_join_unittest.py diff --git a/.gitignore b/.gitignore index 3d5cb0f..2759bc8 100644 --- a/.gitignore +++ b/.gitignore @@ -94,6 +94,7 @@ /tools/kvm-ifup /tools/ensure-dirs /tools/vcluster-setup +/tools/prepare-node-join # scripts /scripts/gnt-backup diff --git a/Makefile.am b/Makefile.am index f9b4a94..78282a5 100644 --- a/Makefile.am +++ b/Makefile.am @@ -578,7 +578,8 @@ PYTHON_BOOTSTRAP_SBIN = \ PYTHON_BOOTSTRAP = \ $(PYTHON_BOOTSTRAP_SBIN) \ - tools/ensure-dirs + tools/ensure-dirs \ + tools/prepare-node-join qa_scripts = \ qa/__init__.py \ @@ -690,7 +691,8 @@ pkglib_python_scripts = \ tools/check-cert-expired nodist_pkglib_python_scripts = \ - tools/ensure-dirs + tools/ensure-dirs \ + tools/prepare-node-join myexeclib_SCRIPTS = \ daemons/daemon-util \ @@ -926,6 +928,7 @@ python_tests = \ test/ganeti.ssh_unittest.py \ test/ganeti.storage_unittest.py \ test/ganeti.tools.ensure_dirs_unittest.py \ + test/ganeti.tools.prepare_node_join_unittest.py \ test/ganeti.uidpool_unittest.py \ test/ganeti.utils.algo_unittest.py \ test/ganeti.utils.filelock_unittest.py \ @@ -1327,6 +1330,7 @@ daemons/ganeti-%: MODULE = ganeti.server.$(patsubst ganeti-%,%,$(notdir $@)) daemons/ganeti-watcher: MODULE = ganeti.watcher scripts/%: MODULE = ganeti.client.$(subst -,_,$(notdir $@)) tools/ensure-dirs: MODULE = ganeti.tools.ensure_dirs +tools/prepare-node-join: MODULE = ganeti.tools.prepare_node_join $(HS_BUILT_TEST_HELPERS): TESTROLE = $(patsubst htest/%,%,$@) $(PYTHON_BOOTSTRAP): 
Makefile | stamp-directories diff --git a/lib/constants.py b/lib/constants.py index 39b6895..35cf53b 100644 --- a/lib/constants.py +++ b/lib/constants.py @@ -2049,5 +2049,17 @@ SSHK_RSA = "rsa" SSHK_DSA = "dsa" SSHK_ALL = frozenset([SSHK_RSA, SSHK_DSA]) +# SSH authorized key types +SSHAK_RSA = "ssh-rsa" +SSHAK_DSS = "ssh-dss" +SSHAK_ALL = frozenset([SSHAK_RSA, SSHAK_DSS]) + +# SSH setup +SSHS_CLUSTER_NAME = "cluster_name" +SSHS_FORCE = "force" +SSHS_SSH_HOST_KEY = "ssh_host_key" +SSHS_SSH_ROOT_KEY = "ssh_root_key" +SSHS_NODE_DAEMON_CERTIFICATE = "node_daemon_certificate" + # Do not re-export imported modules del re, _vcsversion, _autoconf, socket, pathutils diff --git a/lib/ssh.py b/lib/ssh.py index 1e35feb..27c197f 100644 --- a/lib/ssh.py +++ b/lib/ssh.py @@ -50,7 +50,7 @@ def FormatParamikoFingerprint(fingerprint): def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA, - _homedir_fn=utils.GetHomeDir): + _homedir_fn=None): """Return the paths of a user's SSH files. @type user: string @@ -68,6 +68,9 @@ def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA, exception is raised if C{~$user/.ssh} is not a directory """ + if _homedir_fn is None: + _homedir_fn = utils.GetHomeDir + user_dir = _homedir_fn(user) if not user_dir: raise errors.OpExecError("Cannot resolve home of user '%s'" % user) diff --git a/lib/tools/prepare_node_join.py b/lib/tools/prepare_node_join.py new file mode 100644 index 0000000..4531e8c --- /dev/null +++ b/lib/tools/prepare_node_join.py @@ -0,0 +1,369 @@ +# +# + +# Copyright (C) 2012 Google Inc. +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +"""Script to prepare a node for joining a cluster. + +""" + +import os +import os.path +import optparse +import sys +import logging +import errno +import OpenSSL + +from ganeti import cli +from ganeti import constants +from ganeti import errors +from ganeti import pathutils +from ganeti import utils +from ganeti import serializer +from ganeti import ht +from ganeti import ssh +from ganeti import ssconf + + +_SSH_KEY_LIST = \ + ht.TListOf(ht.TAnd(ht.TIsLength(3), + ht.TItems([ + ht.TElemOf(constants.SSHK_ALL), + ht.Comment("public")(ht.TNonEmptyString), + ht.Comment("private")(ht.TNonEmptyString), + ]))) + +_DATA_CHECK = ht.TStrictDict(False, True, { + constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString, + constants.SSHS_FORCE: ht.TBool, + constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString, + constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST, + constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST, + }) + +_SSHK_TO_SSHAK = { + constants.SSHK_RSA: constants.SSHAK_RSA, + constants.SSHK_DSA: constants.SSHAK_DSS, + } + +_SSH_DAEMON_KEYFILES = { + constants.SSHK_RSA: + (pathutils.SSH_HOST_RSA_PUB, pathutils.SSH_HOST_RSA_PRIV), + constants.SSHK_DSA: + (pathutils.SSH_HOST_DSA_PUB, pathutils.SSH_HOST_DSA_PRIV), + } + +assert frozenset(_SSHK_TO_SSHAK.keys()) == constants.SSHK_ALL +assert frozenset(_SSHK_TO_SSHAK.values()) == constants.SSHAK_ALL + + +class JoinError(errors.GenericError): + """Local class for reporting errors. + + """ + + +def ParseOptions(): + """Parses the options passed to the program. 
+ + @return: Options and arguments + + """ + program = os.path.basename(sys.argv[0]) + + parser = optparse.OptionParser(usage="%prog [--dry-run]", + prog=program) + parser.add_option(cli.DEBUG_OPT) + parser.add_option(cli.VERBOSE_OPT) + parser.add_option(cli.DRY_RUN_OPT) + + (opts, args) = parser.parse_args() + + return VerifyOptions(parser, opts, args) + + +def VerifyOptions(parser, opts, args): + """Verifies options and arguments for correctness. + + """ + if args: + parser.error("No arguments are expected") + + return opts + + +def SetupLogging(opts): + """Configures the logging module. + + """ + formatter = logging.Formatter("%(asctime)s: %(message)s") + + stderr_handler = logging.StreamHandler() + stderr_handler.setFormatter(formatter) + if opts.debug: + stderr_handler.setLevel(logging.NOTSET) + elif opts.verbose: + stderr_handler.setLevel(logging.INFO) + else: + stderr_handler.setLevel(logging.WARNING) + + root_logger = logging.getLogger("") + root_logger.setLevel(logging.NOTSET) + root_logger.addHandler(stderr_handler) + + +def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE): + """Verifies a certificate against the local node daemon certificate. 
+ + @type cert: string + @param cert: Certificate in PEM format (no key) + + """ + try: + OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert) + except OpenSSL.crypto.Error, err: + pass + else: + raise JoinError("No private key may be given") + + try: + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert) + except Exception, err: + raise errors.X509CertError("(stdin)", + "Unable to load certificate: %s" % err) + + try: + noded_pem = utils.ReadFile(_noded_cert_file) + except EnvironmentError, err: + if err.errno != errno.ENOENT: + raise + + logging.debug("Local node certificate was not found (file %s)", + _noded_cert_file) + return + + try: + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem) + except Exception, err: + raise errors.X509CertError(_noded_cert_file, + "Unable to load private key: %s" % err) + + ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD) + ctx.use_privatekey(key) + ctx.use_certificate(cert) + try: + ctx.check_privatekey() + except OpenSSL.SSL.Error: + raise JoinError("Given cluster certificate does not match local key") + + +def VerifyCertificate(data, _verify_fn=_VerifyCertificate): + """Verifies cluster certificate. + + @type data: dict + + """ + try: + cert = data[constants.SSHS_NODE_DAEMON_CERTIFICATE] + except KeyError: + pass + else: + _verify_fn(cert) + + +def _VerifyClusterName(name, _ss_cluster_name_file=None): + """Verifies cluster name against a local cluster name. 
+ + @type name: string + @param name: Cluster name + + """ + if _ss_cluster_name_file is None: + _ss_cluster_name_file = \ + ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME) + + try: + local_name = utils.ReadOneLineFile(_ss_cluster_name_file) + except EnvironmentError, err: + if err.errno != errno.ENOENT: + raise + + logging.debug("Local cluster name was not found (file %s)", + _ss_cluster_name_file) + else: + if name != local_name: + raise JoinError("Current cluster name is '%s'" % local_name) + + +def VerifyClusterName(data, _verify_fn=_VerifyClusterName): + """Verifies cluster name. + + @type data: dict + + """ + try: + name = data[constants.SSHS_CLUSTER_NAME] + except KeyError: + raise JoinError("Cluster name must be specified") + else: + _verify_fn(name) + + +def _UpdateKeyFiles(keys, dry_run, keyfiles): + """Updates SSH key files. + + @type keys: sequence of tuple; (string, string, string) + @param keys: Keys to write, tuples consist of key type + (L{constants.SSHK_ALL}), public and private key + @type dry_run: boolean + @param dry_run: Whether to perform a dry run + @type keyfiles: dict; (string as key, tuple with (string, string) as values) + @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file + names; value tuples consist of public key filename and private key filename + + """ + assert set(keyfiles) == constants.SSHK_ALL + + for (kind, public_key, private_key) in keys: + (public_file, private_file) = keyfiles[kind] + + logging.debug("Writing %s ...", public_file) + utils.WriteFile(public_file, data=public_key, mode=0644, + backup=True, dry_run=dry_run) + + logging.debug("Writing %s ...", private_file) + utils.WriteFile(private_file, data=private_key, mode=0600, + backup=True, dry_run=dry_run) + + +def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd, + _keyfiles=None): + """Updates SSH daemon's keys. + + Unless C{dry_run} is set, the daemon is restarted at the end. 
+ + @type data: dict + @param data: Input data + @type dry_run: boolean + @param dry_run: Whether to perform a dry run + + """ + keys = data.get(constants.SSHS_SSH_HOST_KEY) + if not keys: + return + + if _keyfiles is None: + _keyfiles = _SSH_DAEMON_KEYFILES + + logging.info("Updating SSH daemon key files") + _UpdateKeyFiles(keys, dry_run, _keyfiles) + + if dry_run: + logging.info("This is a dry run, not restarting SSH daemon") + else: + result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"], + interactive=True) + if result.failed: + raise JoinError("Could not reload SSH keys, command '%s'" + " had exitcode %s and error %s" % + (result.cmd, result.exit_code, result.output)) + + +def UpdateSshRoot(data, dry_run, _homedir_fn=None): + """Updates root's SSH keys. + + Root's C{authorized_keys} file is also updated with new public keys. + + @type data: dict + @param data: Input data + @type dry_run: boolean + @param dry_run: Whether to perform a dry run + + """ + keys = data.get(constants.SSHS_SSH_ROOT_KEY) + if not keys: + return + + (dsa_private_file, dsa_public_file, auth_keys_file) = \ + ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True, + kind=constants.SSHK_DSA, _homedir_fn=_homedir_fn) + (rsa_private_file, rsa_public_file, _) = \ + ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True, + kind=constants.SSHK_RSA, _homedir_fn=_homedir_fn) + + _UpdateKeyFiles(keys, dry_run, { + constants.SSHK_RSA: (rsa_public_file, rsa_private_file), + constants.SSHK_DSA: (dsa_public_file, dsa_private_file), + }) + + if dry_run: + logging.info("This is a dry run, not modifying %s", auth_keys_file) + else: + for (kind, public_key, _) in keys: + line = "%s %s" % (_SSHK_TO_SSHAK[kind], public_key) + utils.AddAuthorizedKey(auth_keys_file, line) + + +def LoadData(raw): + """Parses and verifies input data. 
+ + @rtype: dict + + """ + try: + data = serializer.LoadJson(raw) + except Exception, err: + raise errors.ParseError("Can't parse input data: %s" % err) + + if not _DATA_CHECK(data): + raise errors.ParseError("Input data does not match expected format: %s" % + _DATA_CHECK) + + return data + + +def Main(): + """Main routine. + + """ + opts = ParseOptions() + + SetupLogging(opts) + + try: + data = LoadData(sys.stdin.read()) + + # Check if input data is correct + VerifyClusterName(data) + VerifyCertificate(data) + + # Update SSH files + UpdateSshDaemon(data, opts.dry_run) + UpdateSshRoot(data, opts.dry_run) + + logging.info("Setup finished successfully") + except Exception, err: # pylint: disable=W0703 + logging.debug("Caught unhandled exception", exc_info=True) + + (retcode, message) = cli.FormatError(err) + logging.error(message) + + return retcode + else: + return constants.EXIT_SUCCESS diff --git a/lib/utils/io.py b/lib/utils/io.py index a10ba84..bb34b11 100644 --- a/lib/utils/io.py +++ b/lib/utils/io.py @@ -828,9 +828,6 @@ def ReadLockedPidFile(path): return None -_SSH_KEYS_WITH_TWO_PARTS = frozenset(["ssh-dss", "ssh-rsa"]) - - def _SplitSshKey(key): """Splits a line for SSH's C{authorized_keys} file. 
@@ -845,7 +842,7 @@ def _SplitSshKey(key): """ parts = key.split() - if parts and parts[0] in _SSH_KEYS_WITH_TWO_PARTS: + if parts and parts[0] in constants.SSHAK_ALL: # If the key has no options in front of it, we only want the significant # fields return (False, parts[:2]) diff --git a/test/data/cert2.pem b/test/data/cert2.pem new file mode 100644 index 0000000..0d15d55 --- /dev/null +++ b/test/data/cert2.pem @@ -0,0 +1,22 @@ +-----BEGIN PRIVATE KEY----- +MIIBUwIBADANBgkqhkiG9w0BAQEFAASCAT0wggE5AgEAAkEAt8OZYvvi8noVPlpR +/SrHcya9ne7RG5DjvMssksUqyGriUs/WGnpZlL4nz+BcLFGwNNntoxqR30Tjk47S +cmSBRQIDAQABAkAqTP5MCMuPIYcuWUAyVNygpzRS3JyKCepClUpnZreYdo4sUQE3 +/AM7xeb92R06iZ3f9/MPrbaMKTWRh3uCyfKBAiEA5TxdacnVxdS8+ZLyys4p/C1s +iajrarBb/j+NIAnsdnECIQDNOCDO7Jq/iN5qE4Vbi/3zmnP1Ca5aBo+KJ/hhSjRq +FQIgIBpWEqybbXsfg+waaGB67MAHxTeM0IImP/LydpwtK2ECIB3SrlHj6Ik1Jr1b +oOGw8nLYW0mc4o2KrolxTZM16XARAiBKW3aSjY5UrnoEqa8pAeiO8LJaRj73Epmr +zC89IuLZfg== +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIB0zCCAX2gAwIBAgIJAKrAqGX6UolVMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTIxMDE5MTQ1NjA4WhcNMTIxMDIwMTQ1NjA4WjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALfD +mWL74vJ6FT5aUf0qx3MmvZ3u0RuQ47zLLJLFKshq4lLP1hp6WZS+J8/gXCxRsDTZ +7aMakd9E45OO0nJkgUUCAwEAAaNQME4wHQYDVR0OBBYEFA1Fc/GIVtd6nMocrSsA +e5bxmVhMMB8GA1UdIwQYMBaAFA1Fc/GIVtd6nMocrSsAe5bxmVhMMAwGA1UdEwQF +MAMBAf8wDQYJKoZIhvcNAQEFBQADQQCTUwzDGU+IJTQ3PIJrA3fHMyKbBvc4Rkvi +ZNFsmgsidWhb+5APlPjtlS7rXlonNHBzDoGb4RNArtxhEx+rBcAE +-----END CERTIFICATE----- diff --git a/test/ganeti.tools.prepare_node_join_unittest.py b/test/ganeti.tools.prepare_node_join_unittest.py new file mode 100755 index 0000000..b629f77 --- /dev/null +++ b/test/ganeti.tools.prepare_node_join_unittest.py @@ -0,0 +1,309 @@ +#!/usr/bin/python +# + +# Copyright (C) 2012 Google Inc. 
+# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + + +"""Script for testing ganeti.tools.prepare_node_join""" + +import unittest +import shutil +import tempfile +import os.path +import OpenSSL + +from ganeti import errors +from ganeti import constants +from ganeti import serializer +from ganeti import pathutils +from ganeti import compat +from ganeti import utils +from ganeti.tools import prepare_node_join + +import testutils + + +_JoinError = prepare_node_join.JoinError + + +class TestLoadData(unittest.TestCase): + def testNoJson(self): + self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "") + self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}") + + def testInvalidDataStructure(self): + raw = serializer.DumpJson({ + "some other thing": False, + }) + self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw) + + raw = serializer.DumpJson([]) + self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw) + + def testValidData(self): + raw = serializer.DumpJson({}) + self.assertEqual(prepare_node_join.LoadData(raw), {}) + + +class TestVerifyCertificate(testutils.GanetiTestCase): + def setUp(self): + testutils.GanetiTestCase.setUp(self) + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + testutils.GanetiTestCase.tearDown(self) + 
shutil.rmtree(self.tmpdir) + + def testNoCert(self): + prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented) + + def testMismatchingKey(self): + other_cert = self._TestDataFilename("cert1.pem") + node_cert = self._TestDataFilename("cert2.pem") + + self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate, + utils.ReadFile(other_cert), _noded_cert_file=node_cert) + + def testGivenPrivateKey(self): + cert_filename = self._TestDataFilename("cert2.pem") + cert_pem = utils.ReadFile(cert_filename) + + self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate, + cert_pem, _noded_cert_file=cert_filename) + + def testMatchingKey(self): + cert_filename = self._TestDataFilename("cert2.pem") + + # Extract certificate + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, + utils.ReadFile(cert_filename)) + cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, + cert) + + prepare_node_join._VerifyCertificate(cert_pem, + _noded_cert_file=cert_filename) + + def testMissingFile(self): + cert = self._TestDataFilename("cert1.pem") + nodecert = utils.PathJoin(self.tmpdir, "does-not-exist") + prepare_node_join._VerifyCertificate(utils.ReadFile(cert), + _noded_cert_file=nodecert) + + def testInvalidCertificate(self): + self.assertRaises(errors.X509CertError, + prepare_node_join._VerifyCertificate, + "Something that's not a certificate", + _noded_cert_file=NotImplemented) + + def testNoPrivateKey(self): + cert = self._TestDataFilename("cert1.pem") + self.assertRaises(errors.X509CertError, + prepare_node_join._VerifyCertificate, + utils.ReadFile(cert), _noded_cert_file=cert) + + +class TestVerifyClusterName(unittest.TestCase): + def setUp(self): + unittest.TestCase.setUp(self) + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + unittest.TestCase.tearDown(self) + shutil.rmtree(self.tmpdir) + + def testNoName(self): + self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName, + {}, _verify_fn=NotImplemented) + 
+ def testMissingFile(self): + tmpfile = utils.PathJoin(self.tmpdir, "does-not-exist") + prepare_node_join._VerifyClusterName(NotImplemented, + _ss_cluster_name_file=tmpfile) + + def testMatchingName(self): + tmpfile = utils.PathJoin(self.tmpdir, "cluster_name") + + for content in ["cluster.example.com", "cluster.example.com\n\n"]: + utils.WriteFile(tmpfile, data=content) + prepare_node_join._VerifyClusterName("cluster.example.com", + _ss_cluster_name_file=tmpfile) + + def testNameMismatch(self): + tmpfile = utils.PathJoin(self.tmpdir, "cluster_name") + + for content in ["something.example.com", "foobar\n\ncluster.example.com"]: + utils.WriteFile(tmpfile, data=content) + self.assertRaises(_JoinError, prepare_node_join._VerifyClusterName, + "cluster.example.com", _ss_cluster_name_file=tmpfile) + + +class TestUpdateSshDaemon(unittest.TestCase): + def setUp(self): + unittest.TestCase.setUp(self) + self.tmpdir = tempfile.mkdtemp() + + self.keyfiles = { + constants.SSHK_RSA: + (utils.PathJoin(self.tmpdir, "rsa.public"), + utils.PathJoin(self.tmpdir, "rsa.private")), + constants.SSHK_DSA: + (utils.PathJoin(self.tmpdir, "dsa.public"), + utils.PathJoin(self.tmpdir, "dsa.private")), + } + + def tearDown(self): + unittest.TestCase.tearDown(self) + shutil.rmtree(self.tmpdir) + + def testNoKeys(self): + data_empty_keys = { + constants.SSHS_SSH_HOST_KEY: [], + } + + for data in [{}, data_empty_keys]: + for dry_run in [False, True]: + prepare_node_join.UpdateSshDaemon(data, dry_run, + _runcmd_fn=NotImplemented, + _keyfiles=NotImplemented) + self.assertEqual(os.listdir(self.tmpdir), []) + + def _TestDryRun(self, data): + prepare_node_join.UpdateSshDaemon(data, True, _runcmd_fn=NotImplemented, + _keyfiles=self.keyfiles) + self.assertEqual(os.listdir(self.tmpdir), []) + + def testDryRunRsa(self): + self._TestDryRun({ + constants.SSHS_SSH_HOST_KEY: [ + (constants.SSHK_RSA, "rsapub", "rsapriv"), + ], + }) + + def testDryRunDsa(self): + self._TestDryRun({ + 
constants.SSHS_SSH_HOST_KEY: [ + (constants.SSHK_DSA, "dsapub", "dsapriv"), + ], + }) + + def _RunCmd(self, fail, cmd, interactive=NotImplemented): + self.assertTrue(interactive) + self.assertEqual(cmd, [pathutils.DAEMON_UTIL, "reload-ssh-keys"]) + if fail: + exit_code = constants.EXIT_FAILURE + else: + exit_code = constants.EXIT_SUCCESS + return utils.RunResult(exit_code, None, "stdout", "stderr", + utils.ShellQuoteArgs(cmd), + NotImplemented, NotImplemented) + + def _TestUpdate(self, failcmd): + data = { + constants.SSHS_SSH_HOST_KEY: [ + (constants.SSHK_DSA, "dsapub", "dsapriv"), + (constants.SSHK_RSA, "rsapub", "rsapriv"), + ], + } + runcmd_fn = compat.partial(self._RunCmd, failcmd) + if failcmd: + self.assertRaises(_JoinError, prepare_node_join.UpdateSshDaemon, + data, False, _runcmd_fn=runcmd_fn, + _keyfiles=self.keyfiles) + else: + prepare_node_join.UpdateSshDaemon(data, False, _runcmd_fn=runcmd_fn, + _keyfiles=self.keyfiles) + self.assertEqual(sorted(os.listdir(self.tmpdir)), sorted([ + "rsa.private", "rsa.public", + "dsa.private", "dsa.public", + ])) + self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.public")), + "rsapub") + self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.private")), + "rsapriv") + self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.public")), + "dsapub") + self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.private")), + "dsapriv") + + def testSuccess(self): + self._TestUpdate(False) + + def testFailure(self): + self._TestUpdate(True) + + +class TestUpdateSshRoot(unittest.TestCase): + def setUp(self): + unittest.TestCase.setUp(self) + self.tmpdir = tempfile.mkdtemp() + self.sshdir = utils.PathJoin(self.tmpdir, ".ssh") + + def tearDown(self): + unittest.TestCase.tearDown(self) + shutil.rmtree(self.tmpdir) + + def _GetHomeDir(self, user): + self.assertEqual(user, constants.SSH_LOGIN_USER) + return self.tmpdir + + def testNoKeys(self): + data_empty_keys = { + 
constants.SSHS_SSH_ROOT_KEY: [], + } + + for data in [{}, data_empty_keys]: + for dry_run in [False, True]: + prepare_node_join.UpdateSshRoot(data, dry_run, + _homedir_fn=NotImplemented) + self.assertEqual(os.listdir(self.tmpdir), []) + + def testDryRun(self): + data = { + constants.SSHS_SSH_ROOT_KEY: [ + (constants.SSHK_RSA, "aaa", "bbb"), + ] + } + + prepare_node_join.UpdateSshRoot(data, True, + _homedir_fn=self._GetHomeDir) + self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) + self.assertEqual(os.listdir(self.sshdir), []) + + def testUpdate(self): + data = { + constants.SSHS_SSH_ROOT_KEY: [ + (constants.SSHK_DSA, "pubdsa", "privatedsa"), + ] + } + + prepare_node_join.UpdateSshRoot(data, False, + _homedir_fn=self._GetHomeDir) + self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) + self.assertEqual(sorted(os.listdir(self.sshdir)), + sorted(["authorized_keys", "id_dsa", "id_dsa.pub"])) + self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa")), + "privatedsa") + self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa.pub")), + "pubdsa") + self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, + "authorized_keys")), + "ssh-dss pubdsa\n") + + +if __name__ == "__main__": + testutils.GanetiTestProgram() -- 1.7.7.3
