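This patch makes prepare_node_join use functions that were moved to tools/common.py (VerifyOptions, LoadData and VerifyClusterName) and removes its local copies of them. The corresponding unit tests are moved from prepare_node_join_unittest.py to common_unittest.py and adapted to the signatures of the common functions.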
Signed-off-by: Helga Velroyen <[email protected]> --- lib/tools/prepare_node_join.py | 41 ++--------------- test/py/ganeti.tools.common_unittest.py | 53 ++++++++++++++++++++++ test/py/ganeti.tools.prepare_node_join_unittest.py | 48 -------------------- 3 files changed, 57 insertions(+), 85 deletions(-) diff --git a/lib/tools/prepare_node_join.py b/lib/tools/prepare_node_join.py index 7eb1e5a..4db335f 100644 --- a/lib/tools/prepare_node_join.py +++ b/lib/tools/prepare_node_join.py @@ -43,10 +43,9 @@ from ganeti import constants from ganeti import errors from ganeti import pathutils from ganeti import utils -from ganeti import serializer from ganeti import ht from ganeti import ssh -from ganeti import ssconf +from ganeti.tools import common _SSH_KEY_LIST_ITEM = \ @@ -89,17 +88,7 @@ def ParseOptions(): (opts, args) = parser.parse_args() - return VerifyOptions(parser, opts, args) - - -def VerifyOptions(parser, opts, args): - """Verifies options and arguments for correctness. - - """ - if args: - parser.error("No arguments are expected") - - return opts + return common.VerifyOptions(parser, opts, args) def _VerifyCertificate(cert_pem, _check_fn=utils.CheckNodeCertificate): @@ -137,19 +126,6 @@ def VerifyCertificate(data, _verify_fn=_VerifyCertificate): _verify_fn(cert) -def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName): - """Verifies cluster name. - - @type data: dict - - """ - name = data.get(constants.SSHS_CLUSTER_NAME) - if name: - _verify_fn(name) - else: - raise JoinError("Cluster name must be specified") - - def _UpdateKeyFiles(keys, dry_run, keyfiles): """Updates SSH key files. @@ -238,15 +214,6 @@ def UpdateSshRoot(data, dry_run, _homedir_fn=None): utils.AddAuthorizedKey(auth_keys_file, public_key) -def LoadData(raw): - """Parses and verifies input data. - - @rtype: dict - - """ - return serializer.LoadAndVerifyJson(raw, _DATA_CHECK) - - def Main(): """Main routine. 
@@ -256,10 +223,10 @@ def Main(): utils.SetupToolLogging(opts.debug, opts.verbose) try: - data = LoadData(sys.stdin.read()) + data = common.LoadData(sys.stdin.read(), _DATA_CHECK) # Check if input data is correct - VerifyClusterName(data) + common.VerifyClusterName(data, JoinError) VerifyCertificate(data) # Update SSH files diff --git a/test/py/ganeti.tools.common_unittest.py b/test/py/ganeti.tools.common_unittest.py index 1652088..427b851 100755 --- a/test/py/ganeti.tools.common_unittest.py +++ b/test/py/ganeti.tools.common_unittest.py @@ -38,6 +38,8 @@ import OpenSSL import time from ganeti import constants +from ganeti import errors +from ganeti import serializer from ganeti import utils from ganeti.tools import common @@ -78,5 +80,56 @@ class TestGenerateClientCert(unittest.TestCase): self.assertEqual(client_cert.get_subject().CN, my_node_name) +class TestLoadData(unittest.TestCase): + + def testNoJson(self): + self.assertRaises(errors.ParseError, common.LoadData, "", lambda _: True) + self.assertRaises(errors.ParseError, common.LoadData, "}", lambda _: True) + + def testInvalidDataStructure(self): + raw = serializer.DumpJson({ + "some other thing": False, + }) + self.assertRaises(errors.ParseError, common.LoadData, raw, lambda _: False) + + raw = serializer.DumpJson([]) + self.assertRaises(errors.ParseError, common.LoadData, raw, lambda _: False) + + def testValidData(self): + raw = serializer.DumpJson({}) + self.assertEqual(common.LoadData(raw, lambda _: True), {}) + + +class TestVerifyClusterName(unittest.TestCase): + + class MyException(Exception): + pass + + def setUp(self): + unittest.TestCase.setUp(self) + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + unittest.TestCase.tearDown(self) + shutil.rmtree(self.tmpdir) + + def testNoName(self): + self.assertRaises(self.MyException, common.VerifyClusterName, + {}, self.MyException, _verify_fn=NotImplemented) + + @staticmethod + def _FailingVerify(name): + assert name == "cluster.example.com" + raise errors.GenericError() + 
def testFailingVerification(self): + data = { + constants.SSHS_CLUSTER_NAME: "cluster.example.com", + } + + self.assertRaises(errors.GenericError, common.VerifyClusterName, + data, Exception, _verify_fn=self._FailingVerify) + + if __name__ == "__main__": testutils.GanetiTestProgram() diff --git a/test/py/ganeti.tools.prepare_node_join_unittest.py b/test/py/ganeti.tools.prepare_node_join_unittest.py index e0c60a4..92cb1de 100755 --- a/test/py/ganeti.tools.prepare_node_join_unittest.py +++ b/test/py/ganeti.tools.prepare_node_join_unittest.py @@ -34,11 +34,9 @@ import unittest import shutil import tempfile import os.path -import OpenSSL from ganeti import errors from ganeti import constants -from ganeti import serializer from ganeti import pathutils from ganeti import compat from ganeti import utils @@ -50,25 +48,6 @@ import testutils _JoinError = prepare_node_join.JoinError -class TestLoadData(unittest.TestCase): - def testNoJson(self): - self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "") - self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}") - - def testInvalidDataStructure(self): - raw = serializer.DumpJson({ - "some other thing": False, - }) - self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw) - - raw = serializer.DumpJson([]) - self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw) - - def testValidData(self): - raw = serializer.DumpJson({}) - self.assertEqual(prepare_node_join.LoadData(raw), {}) - - class TestVerifyCertificate(testutils.GanetiTestCase): def setUp(self): testutils.GanetiTestCase.setUp(self) @@ -104,33 +83,6 @@ class TestVerifyCertificate(testutils.GanetiTestCase): prepare_node_join._VerifyCertificate(cert_pem, _check_fn=self._Check) -class TestVerifyClusterName(unittest.TestCase): - def setUp(self): - unittest.TestCase.setUp(self) - self.tmpdir = tempfile.mkdtemp() - - def tearDown(self): - unittest.TestCase.tearDown(self) - shutil.rmtree(self.tmpdir) - - def testNoName(self): - 
self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName, - {}, _verify_fn=NotImplemented) - - @staticmethod - def _FailingVerify(name): - assert name == "cluster.example.com" - raise errors.GenericError() - - def testFailingVerification(self): - data = { - constants.SSHS_CLUSTER_NAME: "cluster.example.com", - } - - self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName, - data, _verify_fn=self._FailingVerify) - - class TestUpdateSshDaemon(unittest.TestCase): def setUp(self): unittest.TestCase.setUp(self) -- 2.4.3.573.g4eafbef
