This patch adapts the 'prepare_node_join' tool so
that, instead of copying the cluster SSH key to the new
node, it generates an individual SSH key pair for that
node. Root key pairs sent in the join data
(SSHS_SSH_ROOT_KEY) are consequently no longer installed.

Signed-off-by: Helga Velroyen <hel...@google.com>
---
 lib/ssh.py                                         |  5 ++-
 lib/tools/prepare_node_join.py                     | 47 +++++++++++-----------
 test/py/ganeti.tools.prepare_node_join_unittest.py | 25 ++++--------
 3 files changed, 34 insertions(+), 43 deletions(-)

diff --git a/lib/ssh.py b/lib/ssh.py
index e646a0b..be732a8 100644
--- a/lib/ssh.py
+++ b/lib/ssh.py
@@ -605,14 +605,15 @@ def QueryPubKeyFile(target_uuids, key_file=pathutils.SSH_PUB_KEYS,
   return result
 
 
-def InitSSHSetup(error_fn=errors.OpPrereqError):
+def InitSSHSetup(error_fn=errors.OpPrereqError, _homedir_fn=None):
   """Setup the SSH configuration for the node.
 
   This generates a dsa keypair for root, adds the pub key to the
   permitted hosts and adds the hostkey to its own known hosts.
 
   """
-  priv_key, pub_key, auth_keys = GetUserFiles(constants.SSH_LOGIN_USER)
+  priv_key, pub_key, auth_keys = GetUserFiles(constants.SSH_LOGIN_USER,
+                                              _homedir_fn=_homedir_fn)
 
   for name in priv_key, pub_key:
     if os.path.exists(name):
diff --git a/lib/tools/prepare_node_join.py b/lib/tools/prepare_node_join.py
index ed5a227..28d74f6 100644
--- a/lib/tools/prepare_node_join.py
+++ b/lib/tools/prepare_node_join.py
@@ -110,6 +110,13 @@ def _UpdateKeyFiles(keys, dry_run, keyfiles):
                     backup=True, dry_run=dry_run)
 
 
+def _GenerateRootSshKeys(_homedir_fn=None):
+  """Generates a new SSH key pair for root on this node.
+
+  @param _homedir_fn: home directory lookup override, passed to InitSSHSetup"""
+  ssh.InitSSHSetup(error_fn=JoinError, _homedir_fn=_homedir_fn)
+
+
 def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
                     _keyfiles=None):
   """Updates SSH daemon's keys.
@@ -154,31 +161,25 @@ def UpdateSshRoot(data, dry_run, _homedir_fn=None):
   @param dry_run: Whether to perform a dry run
 
   """
-  keys = data.get(constants.SSHS_SSH_ROOT_KEY)
   authorized_keys = data.get(constants.SSHS_SSH_AUTHORIZED_KEYS)
 
-  if keys or authorized_keys:
-    (auth_keys_file, keyfiles) = \
-      ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
-                          _homedir_fn=_homedir_fn)
-
-    if keys:
-      _UpdateKeyFiles(keys, dry_run, keyfiles)
-
-      if dry_run:
-        logging.info("This is a dry run, not modifying %s", auth_keys_file)
-      else:
-        for (_, _, public_key) in keys:
-          ssh.AddAuthorizedKey(auth_keys_file, public_key)
-
-    if authorized_keys:
-      if dry_run:
-        logging.info("This is a dry run, not modifying %s", auth_keys_file)
-      else:
-        all_authorized_keys = []
-        for keys in authorized_keys.values():
-          all_authorized_keys += keys
-        ssh.AddAuthorizedKeys(auth_keys_file, all_authorized_keys)
+  (auth_keys_file, _) = \
+    ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
+                        _homedir_fn=_homedir_fn)
+
+  if dry_run:
+    logging.info("This is a dry run, not replacing the SSH keys.")
+  else:
+    _GenerateRootSshKeys(_homedir_fn=_homedir_fn)
+
+  if authorized_keys:
+    if dry_run:
+      logging.info("This is a dry run, not modifying %s", auth_keys_file)
+    else:
+      all_authorized_keys = []
+      for keys in authorized_keys.values():
+        all_authorized_keys += keys
+      ssh.AddAuthorizedKeys(auth_keys_file, all_authorized_keys)
 
 
 def Main():
diff --git a/test/py/ganeti.tools.prepare_node_join_unittest.py b/test/py/ganeti.tools.prepare_node_join_unittest.py
index ac30f90..20ef1f1 100755
--- a/test/py/ganeti.tools.prepare_node_join_unittest.py
+++ b/test/py/ganeti.tools.prepare_node_join_unittest.py
@@ -245,17 +245,6 @@ class TestUpdateSshRoot(unittest.TestCase):
     self.assertEqual(user, constants.SSH_LOGIN_USER)
     return self.tmpdir
 
-  def testNoKeys(self):
-    data_empty_keys = {
-      constants.SSHS_SSH_ROOT_KEY: [],
-      }
-
-    for data in [{}, data_empty_keys]:
-      for dry_run in [False, True]:
-        prepare_node_join.UpdateSshRoot(data, dry_run,
-                                        _homedir_fn=NotImplemented)
-    self.assertEqual(os.listdir(self.tmpdir), [])
-
   def testDryRun(self):
     data = {
       constants.SSHS_SSH_ROOT_KEY: [
@@ -280,13 +269,13 @@ class TestUpdateSshRoot(unittest.TestCase):
     self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
     self.assertEqual(sorted(os.listdir(self.sshdir)),
                      sorted(["authorized_keys", "id_dsa", "id_dsa.pub"]))
-    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa")),
-                     "privatedsa")
-    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa.pub")),
-                     "ssh-dss pubdsa")
-    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
-                                                   "authorized_keys")),
-                     "ssh-dss pubdsa\n")
+    self.assertTrue(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa"))
+                    is not None)
+    pub_key = utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa.pub"))
+    self.assertTrue(pub_key is not None)
+    self.assertEquals(utils.ReadFile(utils.PathJoin(self.sshdir,
+                                                    "authorized_keys")),
+                      pub_key)
 
 
 if __name__ == "__main__":
-- 
2.0.0.526.g5318336

Reply via email to