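This patch makes the node_daemon_setup tool use some of the recently introduced functions in tools/common.py. In doing so, it also cleans up the usage of the cluster name constants: common.VerifyClusterName now takes the data key of the cluster name as a parameter (and returns the verified name), so each tool passes the constant that matches its own input format (SSHS_CLUSTER_NAME or NDS_CLUSTER_NAME).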
Signed-off-by: Helga Velroyen <[email protected]> --- lib/tools/common.py | 6 +- lib/tools/node_daemon_setup.py | 100 +-------------- lib/tools/prepare_node_join.py | 2 +- lib/tools/ssh_update.py | 2 +- lib/tools/ssl_update.py | 2 +- test/py/ganeti.tools.common_unittest.py | 85 ++++++++++++- test/py/ganeti.tools.node_daemon_setup_unittest.py | 138 --------------------- 7 files changed, 94 insertions(+), 241 deletions(-) diff --git a/lib/tools/common.py b/lib/tools/common.py index 9478655..a9149f6 100644 --- a/lib/tools/common.py +++ b/lib/tools/common.py @@ -166,19 +166,21 @@ def VerifyCertificateStrong(data, error_fn, return _verify_fn(cert, error_fn) -def VerifyClusterName(data, error_fn, +def VerifyClusterName(data, error_fn, cluster_name_constant, _verify_fn=ssconf.VerifyClusterName): """Verifies cluster name. @type data: dict """ - name = data.get(constants.SSHS_CLUSTER_NAME) + name = data.get(cluster_name_constant) if name: _verify_fn(name) else: raise error_fn("Cluster name must be specified") + return name + def LoadData(raw, data_check): """Parses and verifies input data. diff --git a/lib/tools/node_daemon_setup.py b/lib/tools/node_daemon_setup.py index 89e8a18..e45e2e0 100644 --- a/lib/tools/node_daemon_setup.py +++ b/lib/tools/node_daemon_setup.py @@ -36,15 +36,12 @@ import os.path import optparse import sys import logging -import OpenSSL -from cStringIO import StringIO from ganeti import cli from ganeti import constants from ganeti import errors from ganeti import pathutils from ganeti import utils -from ganeti import serializer from ganeti import runtime from ganeti import ht from ganeti import ssconf @@ -93,87 +90,6 @@ def VerifyOptions(parser, opts, args): return opts -def _VerifyCertificate(cert_pem, _check_fn=utils.CheckNodeCertificate): - """Verifies a certificate against the local node daemon certificate. 
- - @type cert_pem: string - @param cert_pem: Certificate and key in PEM format - @rtype: string - @return: Formatted key and certificate - - """ - try: - cert = \ - OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem) - except Exception, err: - raise errors.X509CertError("(stdin)", - "Unable to load certificate: %s" % err) - - try: - key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert_pem) - except OpenSSL.crypto.Error, err: - raise errors.X509CertError("(stdin)", - "Unable to load private key: %s" % err) - - # Check certificate with given key; this detects cases where the key given on - # stdin doesn't match the certificate also given on stdin - try: - utils.X509CertKeyCheck(cert, key) - except OpenSSL.SSL.Error: - raise errors.X509CertError("(stdin)", - "Certificate is not signed with given key") - - # Standard checks, including check against an existing local certificate - # (no-op if that doesn't exist) - _check_fn(cert) - - key_encoded = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key) - cert_encoded = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, - cert) - complete_cert_encoded = key_encoded + cert_encoded - if not cert_pem == complete_cert_encoded: - logging.error("The certificate differs after being reencoded. Please" - " renew the certificates cluster-wide to prevent future" - " inconsistencies.") - - # Format for storing on disk - buf = StringIO() - buf.write(cert_pem) - return buf.getvalue() - - -def VerifyCertificate(data, _verify_fn=_VerifyCertificate): - """Verifies cluster certificate. - - @type data: dict - @rtype: string - @return: Formatted key and certificate - - """ - cert = data.get(constants.NDS_NODE_DAEMON_CERTIFICATE) - if not cert: - raise SetupError("Node daemon certificate must be specified") - - return _verify_fn(cert) - - -def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName): - """Verifies cluster name. 
- - @type data: dict - @rtype: string - @return: Cluster name - - """ - name = data.get(constants.NDS_CLUSTER_NAME) - if not name: - raise SetupError("Cluster name must be specified") - - _verify_fn(name) - - return name - - def VerifySsconf(data, cluster_name, _verify_fn=ssconf.VerifyKeys): """Verifies ssconf names. @@ -195,15 +111,6 @@ def VerifySsconf(data, cluster_name, _verify_fn=ssconf.VerifyKeys): return items -def LoadData(raw): - """Parses and verifies input data. - - @rtype: dict - - """ - return serializer.LoadAndVerifyJson(raw, _DATA_CHECK) - - def Main(): """Main routine. @@ -215,10 +122,11 @@ def Main(): try: getent = runtime.GetEnts() - data = LoadData(sys.stdin.read()) + data = common.LoadData(sys.stdin.read(), SetupError) - cluster_name = VerifyClusterName(data) - cert_pem = VerifyCertificate(data) + cluster_name = common.VerifyClusterName(data, SetupError, + constants.NDS_CLUSTER_NAME) + cert_pem = common.VerifyCertificateStrong(data, SetupError) ssdata = VerifySsconf(data, cluster_name) logging.info("Writing ssconf files ...") diff --git a/lib/tools/prepare_node_join.py b/lib/tools/prepare_node_join.py index 0902cf4..82a35dc 100644 --- a/lib/tools/prepare_node_join.py +++ b/lib/tools/prepare_node_join.py @@ -196,7 +196,7 @@ def Main(): data = common.LoadData(sys.stdin.read(), _DATA_CHECK) # Check if input data is correct - common.VerifyClusterName(data, JoinError) + common.VerifyClusterName(data, JoinError, constants.SSHS_CLUSTER_NAME) common.VerifyCertificateSoft(data, JoinError) # Update SSH files diff --git a/lib/tools/ssh_update.py b/lib/tools/ssh_update.py index 904cbd3..f9d1b6d 100644 --- a/lib/tools/ssh_update.py +++ b/lib/tools/ssh_update.py @@ -209,7 +209,7 @@ def Main(): data = common.LoadData(sys.stdin.read(), _DATA_CHECK) # Check if input data is correct - common.VerifyClusterName(data, SshUpdateError) + common.VerifyClusterName(data, SshUpdateError, constants.SSHS_CLUSTER_NAME) common.VerifyCertificateSoft(data, SshUpdateError) # 
Update / Generate SSH files diff --git a/lib/tools/ssl_update.py b/lib/tools/ssl_update.py index f9c5c19..56e8d6a 100644 --- a/lib/tools/ssl_update.py +++ b/lib/tools/ssl_update.py @@ -119,7 +119,7 @@ def Main(): try: data = common.LoadData(sys.stdin.read(), _DATA_CHECK) - common.VerifyClusterName(data, SslSetupError) + common.VerifyClusterName(data, SslSetupError, constants.NDS_CLUSTER_NAME) # Verifies whether the server certificate of the caller # is the same as on this node. diff --git a/test/py/ganeti.tools.common_unittest.py b/test/py/ganeti.tools.common_unittest.py index 427b851..0eb7e45 100755 --- a/test/py/ganeti.tools.common_unittest.py +++ b/test/py/ganeti.tools.common_unittest.py @@ -115,7 +115,8 @@ class TestVerifyClusterName(unittest.TestCase): def testNoName(self): self.assertRaises(self.MyException, common.VerifyClusterName, - {}, self.MyException, _verify_fn=NotImplemented) + {}, self.MyException, "cluster_name", + _verify_fn=NotImplemented) @staticmethod def _FailingVerify(name): @@ -128,7 +129,87 @@ class TestVerifyClusterName(unittest.TestCase): } self.assertRaises(errors.GenericError, common.VerifyClusterName, - data, Exception, _verify_fn=self._FailingVerify) + data, self.MyException, "cluster_name", + _verify_fn=self._FailingVerify) + + +class TestVerifyCertificateStrong(testutils.GanetiTestCase): + + class MyException(Exception): + pass + + def setUp(self): + testutils.GanetiTestCase.setUp(self) + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + testutils.GanetiTestCase.tearDown(self) + shutil.rmtree(self.tmpdir) + + def testNoCert(self): + self.assertRaises(self.MyException, common.VerifyCertificateStrong, + {}, self.MyException, _verify_fn=NotImplemented) + + def testVerificationSuccessWithCert(self): + common.VerifyCertificateStrong({ + constants.NDS_NODE_DAEMON_CERTIFICATE: "something", + }, self.MyException, _verify_fn=lambda x,y: None) + + def testNoPrivateKey(self): + cert_filename = testutils.TestDataFilename("cert1.pem") + 
cert_pem = utils.ReadFile(cert_filename) + + self.assertRaises(self.MyException, + common._VerifyCertificateStrong, + cert_pem, self.MyException, _check_fn=NotImplemented) + + def testInvalidCertificate(self): + self.assertRaises(self.MyException, + common._VerifyCertificateStrong, + "Something that's not a certificate", + self.MyException, + _check_fn=NotImplemented) + + @staticmethod + def _Check(cert): + assert cert.get_subject() + + def testSuccessfulCheck(self): + cert_filename = testutils.TestDataFilename("cert2.pem") + cert_pem = utils.ReadFile(cert_filename) + result = \ + common._VerifyCertificateStrong(cert_pem, self.MyException, + _check_fn=self._Check) + + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, result) + self.assertTrue(cert) + + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, result) + self.assertTrue(key) + + def testMismatchingKey(self): + cert1_path = testutils.TestDataFilename("cert1.pem") + cert2_path = testutils.TestDataFilename("cert2.pem") + + # Extract certificate + cert1 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, + utils.ReadFile(cert1_path)) + cert1_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, + cert1) + + # Extract mismatching key + key2 = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, + utils.ReadFile(cert2_path)) + key2_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, + key2) + + try: + common._VerifyCertificateStrong(cert1_pem + key2_pem, self.MyException, + _check_fn=NotImplemented) + except self.MyException, err: + self.assertTrue("not signed with given key" in str(err)) + else: + self.fail("Exception was not raised") if __name__ == "__main__": diff --git a/test/py/ganeti.tools.node_daemon_setup_unittest.py b/test/py/ganeti.tools.node_daemon_setup_unittest.py index a9fd1a9..9a3abdf 100755 --- a/test/py/ganeti.tools.node_daemon_setup_unittest.py +++ b/test/py/ganeti.tools.node_daemon_setup_unittest.py @@ -31,17 +31,9 @@ 
"""Script for testing ganeti.tools.node_daemon_setup""" import unittest -import shutil -import tempfile -import os.path -import OpenSSL from ganeti import errors from ganeti import constants -from ganeti import serializer -from ganeti import pathutils -from ganeti import compat -from ganeti import utils from ganeti.tools import node_daemon_setup import testutils @@ -50,136 +42,6 @@ import testutils _SetupError = node_daemon_setup.SetupError -class TestLoadData(unittest.TestCase): - def testNoJson(self): - for data in ["", "{", "}"]: - self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, data) - - def testInvalidDataStructure(self): - raw = serializer.DumpJson({ - "some other thing": False, - }) - self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, raw) - - raw = serializer.DumpJson([]) - self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, raw) - - def testValidData(self): - raw = serializer.DumpJson({}) - self.assertEqual(node_daemon_setup.LoadData(raw), {}) - - -class TestVerifyCertificate(testutils.GanetiTestCase): - def setUp(self): - testutils.GanetiTestCase.setUp(self) - self.tmpdir = tempfile.mkdtemp() - - def tearDown(self): - testutils.GanetiTestCase.tearDown(self) - shutil.rmtree(self.tmpdir) - - def testNoCert(self): - self.assertRaises(_SetupError, node_daemon_setup.VerifyCertificate, - {}, _verify_fn=NotImplemented) - - def testVerificationSuccessWithCert(self): - node_daemon_setup.VerifyCertificate({ - constants.NDS_NODE_DAEMON_CERTIFICATE: "something", - }, _verify_fn=lambda _: None) - - def testNoPrivateKey(self): - cert_filename = testutils.TestDataFilename("cert1.pem") - cert_pem = utils.ReadFile(cert_filename) - - self.assertRaises(errors.X509CertError, - node_daemon_setup._VerifyCertificate, - cert_pem, _check_fn=NotImplemented) - - def testInvalidCertificate(self): - self.assertRaises(errors.X509CertError, - node_daemon_setup._VerifyCertificate, - "Something that's not a certificate", - 
_check_fn=NotImplemented) - - @staticmethod - def _Check(cert): - assert cert.get_subject() - - def testSuccessfulCheck(self): - cert_filename = testutils.TestDataFilename("cert2.pem") - cert_pem = utils.ReadFile(cert_filename) - result = \ - node_daemon_setup._VerifyCertificate(cert_pem, _check_fn=self._Check) - - cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, result) - self.assertTrue(cert) - - key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, result) - self.assertTrue(key) - - def testMismatchingKey(self): - cert1_path = testutils.TestDataFilename("cert1.pem") - cert2_path = testutils.TestDataFilename("cert2.pem") - - # Extract certificate - cert1 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, - utils.ReadFile(cert1_path)) - cert1_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, - cert1) - - # Extract mismatching key - key2 = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, - utils.ReadFile(cert2_path)) - key2_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, - key2) - - try: - node_daemon_setup._VerifyCertificate(cert1_pem + key2_pem, - _check_fn=NotImplemented) - except errors.X509CertError, err: - self.assertEqual(err.args, - ("(stdin)", "Certificate is not signed with given key")) - else: - self.fail("Exception was not raised") - - -class TestVerifyClusterName(unittest.TestCase): - def setUp(self): - unittest.TestCase.setUp(self) - self.tmpdir = tempfile.mkdtemp() - - def tearDown(self): - unittest.TestCase.tearDown(self) - shutil.rmtree(self.tmpdir) - - def testNoName(self): - self.assertRaises(_SetupError, node_daemon_setup.VerifyClusterName, - {}, _verify_fn=NotImplemented) - - @staticmethod - def _FailingVerify(name): - assert name == "somecluster.example.com" - raise errors.GenericError() - - def testFailingVerification(self): - data = { - constants.NDS_CLUSTER_NAME: "somecluster.example.com", - } - - self.assertRaises(errors.GenericError, 
node_daemon_setup.VerifyClusterName, - data, _verify_fn=self._FailingVerify) - - def testSuccess(self): - data = { - constants.NDS_CLUSTER_NAME: "cluster.example.com", - } - - result = \ - node_daemon_setup.VerifyClusterName(data, _verify_fn=lambda _: None) - - self.assertEqual(result, "cluster.example.com") - - class TestVerifySsconf(unittest.TestCase): def testNoSsconf(self): self.assertRaises(_SetupError, node_daemon_setup.VerifySsconf, -- 2.4.3.573.g4eafbef
