This patch adds a couple of new SSH utility functions to
the ssh module:
- clearing the whole 'ganeti_pub_keys' file
- overwriting the whole 'ganeti_pub_keys' file with a given set of keys
- retrieving all keys from the file at once

These functions will be used in later patches. Unit tests
are provided. The patch also adds the constant 'sshsSshPublicKeys'
("public_keys") to Constants.hs.

Signed-off-by: Helga Velroyen <hel...@google.com>
---
 lib/ssh.py                     | 44 ++++++++++++++++++++++++++++++++++++++++--
 src/Ganeti/Constants.hs        |  3 +++
 test/py/ganeti.ssh_unittest.py | 22 +++++++++++++++++++++
 3 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/lib/ssh.py b/lib/ssh.py
index 34d709d..e6f8d24 100644
--- a/lib/ssh.py
+++ b/lib/ssh.py
@@ -27,6 +27,7 @@
 import logging
 import os
 import tempfile
+import stat
 
 from functools import partial
 
@@ -37,6 +38,7 @@ from ganeti import netutils
 from ganeti import pathutils
 from ganeti import vcluster
 from ganeti import compat
+from ganeti import serializer
 
 
 def GetUserFiles(user, mkdir=False, dircheck=True, kind=constants.SSHK_DSA,
@@ -513,6 +515,50 @@ def ReplaceNameByUuid(node_uuid, node_name, key_file=pathutils.SSH_PUB_KEYS,
                         error_fn=error_fn)


+def ClearPubKeyFile(key_file=pathutils.SSH_PUB_KEYS, mode=0600):
+  """Resets the content of the public key file.
+
+  The file is overwritten with empty content via L{utils.WriteFile}.
+
+  @type key_file: str
+  @param key_file: path of the public key file
+  @type mode: int
+  @param mode: permissions to set on the emptied file
+
+  """
+  utils.WriteFile(key_file, data="", mode=mode)
+
+
+def OverridePubKeyFile(key_map, key_file=pathutils.SSH_PUB_KEYS,
+                       error_fn=errors.ProgrammerError, mode=0600):
+  """Overrides the public key file with a list of given keys.
+
+  The file is rewritten from scratch with one C{"<uuid> <key>"} line
+  per key. UUIDs are written in sorted order (the keys of one UUID
+  keep their list order), so the resulting content is deterministic.
+
+  @type key_map: dict from str to list of str
+  @param key_map: dictionary mapping uuids to lists of SSH keys
+  @type key_file: str
+  @param key_file: path of the public key file
+  @type error_fn: callable
+  @param error_fn: exception class used to report write errors
+  @type mode: int
+  @param mode: permissions to set on the written file
+
+  """
+  data = "".join("%s %s\n" % (uuid, key)
+                 for uuid in sorted(key_map)
+                 for key in key_map[uuid])
+  try:
+    # Same helper as ClearPubKeyFile: replaces the whole content and
+    # applies the requested permissions.
+    utils.WriteFile(key_file, data=data, mode=mode)
+  except EnvironmentError, err:
+    # EnvironmentError is the common base of IOError and OSError
+    raise error_fn("Cannot override key file due to error '%s'" % err)
+
+
 def QueryPubKeyFile(target_uuids, key_file=pathutils.SSH_PUB_KEYS,
                     error_fn=errors.ProgrammerError):
   """Retrieves a map of keys for the requested node UUIDs.
@@ -530,6 +562,8 @@ def QueryPubKeyFile(target_uuids, key_file=pathutils.SSH_PUB_KEYS,
   @return: dictionary mapping node uuids to their ssh keys

   """
+  # target_uuids=None selects the keys of all UUIDs found in the file
+  all_keys = target_uuids is None
   if isinstance(target_uuids, str):
     target_uuids = [target_uuids]
   result = {}
@@ -544,10 +578,10 @@ def QueryPubKeyFile(target_uuids, key_file=pathutils.SSH_PUB_KEYS,
                       % line)
       uuid = chunks[0]
       key = " ".join(chunks[1:]).rstrip()
-      if uuid in target_uuids:
+      if all_keys or uuid in target_uuids:
         if uuid not in result:
           result[uuid] = []
         result[uuid].append(key)
   finally:
     f.close()
   return result
@@ -587,7 +627,7 @@ def InitPubKeyFile(master_uuid, key_file=pathutils.SSH_PUB_KEYS):
 
   """
   _, pub_key, _ = GetUserFiles(constants.SSH_LOGIN_USER)
-  utils.WriteFile(key_file, data="", mode=0600)
+  ClearPubKeyFile(key_file=key_file)
   key = utils.ReadFile(pub_key)
   AddPublicKey(master_uuid, key, key_file=key_file)
 
diff --git a/src/Ganeti/Constants.hs b/src/Ganeti/Constants.hs
index e29526b..df7fb9d 100644
--- a/src/Ganeti/Constants.hs
+++ b/src/Ganeti/Constants.hs
@@ -4453,6 +4453,9 @@ sshsSshRootKey = "ssh_root_key"
 sshsSshAuthorizedKeys :: String
 sshsSshAuthorizedKeys = "authorized_keys"
 
+sshsSshPublicKeys :: String
+sshsSshPublicKeys = "public_keys"
+
 sshsNodeDaemonCertificate :: String
 sshsNodeDaemonCertificate = "node_daemon_certificate"
 
diff --git a/test/py/ganeti.ssh_unittest.py b/test/py/ganeti.ssh_unittest.py
index 09d3169..cd6ca5c 100755
--- a/test/py/ganeti.ssh_unittest.py
+++ b/test/py/ganeti.ssh_unittest.py
@@ -298,6 +298,12 @@ class TestPublicSshKeys(testutils.GanetiTestCase):
     self.assertEquals([self.KEY_B], result[self.UUID_2])
     self.assertEquals(2, len(result))
 
+    # Query all keys
+    target_uuids = None
+    result = ssh.QueryPubKeyFile(target_uuids, key_file=pub_key_file)
+    self.assertEquals([self.KEY_A], result[self.UUID_1])
+    self.assertEquals([self.KEY_B], result[self.UUID_2])
+
   def testReplaceNameByUuid(self):
     pub_key_file = self._CreateTempFile()
     name = "my.precious.node"
@@ -315,6 +321,22 @@ class TestPublicSshKeys(testutils.GanetiTestCase):
       "789-ABC ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
       "123-456 ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n")

+  def testClearPubKeyFile(self):
+    pub_key_file = self._CreateTempFile()
+    ssh.AddPublicKey(self.UUID_2, self.KEY_A, key_file=pub_key_file)
+    ssh.ClearPubKeyFile(key_file=pub_key_file)
+    self.assertFileContent(pub_key_file, "")
+
+  def testOverridePubKeyFile(self):
+    pub_key_file = self._CreateTempFile()
+    key_map = {self.UUID_1: [self.KEY_A, self.KEY_B],
+               self.UUID_2: [self.KEY_A]}
+    ssh.OverridePubKeyFile(key_map, key_file=pub_key_file)
+    self.assertFileContent(pub_key_file,
+      "123-456 ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
+      "123-456 ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n"
+      "789-ABC ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n")
+
 
 if __name__ == "__main__":
   testutils.GanetiTestProgram()
-- 
2.0.0.526.g5318336

Reply via email to