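This patch extends the SSH utility library (lib/ssh.py) and
the SSH update tool (lib/tools/ssh_update.py) with the
ability to remove SSH keys from the 'authorized_keys' and
the 'ganeti_pub_keys' files. To select the operation, the
key data passed to the update tool is now a pair of an
action (one of the new SSHS_ACTIONS constants) and the
dictionary of keys.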

Signed-off-by: Helga Velroyen <hel...@google.com>
---
 lib/ssh.py                                  | 24 +++++++---
 lib/tools/ssh_update.py                     | 42 ++++++++++++-----
 src/Ganeti/Constants.hs                     | 15 +++++++
 test/py/ganeti.backend_unittest.py          |  7 ---
 test/py/ganeti.ssh_unittest.py              | 13 ++++++
 test/py/ganeti.tools.ssh_update_unittest.py | 70 ++++++++++++++++++++++-------
 6 files changed, 131 insertions(+), 40 deletions(-)

diff --git a/lib/ssh.py b/lib/ssh.py
index daeece3..e646a0b 100644
--- a/lib/ssh.py
+++ b/lib/ssh.py
@@ -187,16 +187,16 @@ def AddAuthorizedKey(file_obj, key):
   AddAuthorizedKeys(file_obj, [key])
 
 
-def RemoveAuthorizedKey(file_name, key):
-  """Removes an SSH public key from an authorized_keys file.
+def RemoveAuthorizedKeys(file_name, keys):
+  """Removes public SSH keys from an authorized_keys file.
 
   @type file_name: str
   @param file_name: path to authorized_keys file
-  @type key: str
-  @param key: string containing key
+  @type keys: list of str
+  @param keys: list of strings containing keys
 
   """
-  key_fields = _SplitSshKey(key)
+  key_field_list = [_SplitSshKey(key) for key in keys]
 
   fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
   try:
@@ -206,7 +206,7 @@ def RemoveAuthorizedKey(file_name, key):
       try:
         for line in f:
           # Ignore whitespace changes while comparing lines
-          if _SplitSshKey(line) != key_fields:
+          if _SplitSshKey(line) not in key_field_list:
             out.write(line)
 
         out.flush()
@@ -220,6 +220,18 @@ def RemoveAuthorizedKey(file_name, key):
     raise
 
 
+def RemoveAuthorizedKey(file_name, key):
+  """Removes an SSH public key from an authorized_keys file.
+
+  @type file_name: str
+  @param file_name: path to authorized_keys file
+  @type key: str
+  @param key: string containing key
+
+  """
+  RemoveAuthorizedKeys(file_name, [key])
+
+
 def _AddPublicKeyProcessLine(new_uuid, new_key, line_uuid, line_key, tmp_file,
                              found):
   """Processes one line of the public key file when adding a key.
diff --git a/lib/tools/ssh_update.py b/lib/tools/ssh_update.py
index cace9a2..41a9de8 100644
--- a/lib/tools/ssh_update.py
+++ b/lib/tools/ssh_update.py
@@ -46,9 +46,13 @@ _DATA_CHECK = ht.TStrictDict(False, True, {
   constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
   constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
   constants.SSHS_SSH_PUBLIC_KEYS:
-    ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)),
+    ht.TItems(
+      [ht.TElemOf(constants.SSHS_ACTIONS),
+       ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString))]),
   constants.SSHS_SSH_AUTHORIZED_KEYS:
-    ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)),
+    ht.TItems(
+      [ht.TElemOf(constants.SSHS_ACTIONS),
+       ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString))]),
   })
 
 
@@ -86,22 +90,36 @@ def UpdateAuthorizedKeys(data, dry_run, _homedir_fn=None):
   @param dry_run: Whether to perform a dry run
 
   """
-  authorized_keys = data.get(constants.SSHS_SSH_AUTHORIZED_KEYS)
+  instructions = data.get(constants.SSHS_SSH_AUTHORIZED_KEYS)
+  if not instructions:
+    logging.info("No change to the authorized_keys file requested.")
+    return
+  (action, authorized_keys) = instructions
 
-  if authorized_keys:
-    (auth_keys_file, _) = \
-      ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
-                          _homedir_fn=_homedir_fn)
+  (auth_keys_file, _) = \
+    ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
+                        _homedir_fn=_homedir_fn)
 
+  key_values = []
+  for key_value in authorized_keys.values():
+    key_values += key_value
+  if action == constants.SSHS_ADD:
     if dry_run:
-      logging.info("This is a dry run, not modifying %s", auth_keys_file)
+      logging.info("This is a dry run, not adding keys to %s",
+                   auth_keys_file)
     else:
-      all_authorized_keys = []
-      for keys in authorized_keys.values():
-        all_authorized_keys += keys
       if not os.path.exists(auth_keys_file):
         utils.WriteFile(auth_keys_file, mode=0600, data="")
-      ssh.AddAuthorizedKeys(auth_keys_file, all_authorized_keys)
+      ssh.AddAuthorizedKeys(auth_keys_file, key_values)
+  elif action == constants.SSHS_REMOVE:
+    if dry_run:
+      logging.info("This is a dry run, not removing keys from %s",
+                   auth_keys_file)
+    else:
+      ssh.RemoveAuthorizedKeys(auth_keys_file, key_values)
+  else:
+    raise SshUpdateError("Action '%s' not implemented for authorized keys."
+                         % action)
 
 
 def UpdatePubKeyFile(data, dry_run, key_file=pathutils.SSH_PUB_KEYS):
diff --git a/src/Ganeti/Constants.hs b/src/Ganeti/Constants.hs
index df7fb9d..387de31 100644
--- a/src/Ganeti/Constants.hs
+++ b/src/Ganeti/Constants.hs
@@ -4459,6 +4459,21 @@ sshsSshPublicKeys = "public_keys"
 sshsNodeDaemonCertificate :: String
 sshsNodeDaemonCertificate = "node_daemon_certificate"
 
+sshsAdd :: String
+sshsAdd = "add"
+
+sshsRemove :: String
+sshsRemove = "remove"
+
+sshsOverride :: String
+sshsOverride = "override"
+
+sshsClear :: String
+sshsClear = "clear"
+
+sshsActions :: FrozenSet String
+sshsActions = ConstantUtils.mkSet [sshsAdd, sshsRemove, sshsOverride, sshsClear]
+
 -- * Key files for SSH daemon
 
 sshHostDsaPriv :: String
diff --git a/test/py/ganeti.backend_unittest.py b/test/py/ganeti.backend_unittest.py
index d403751..6e0b294 100755
--- a/test/py/ganeti.backend_unittest.py
+++ b/test/py/ganeti.backend_unittest.py
@@ -992,13 +992,6 @@ class TestAddNodeSshKey(testutils.GanetiTestCase):
 
       self._master_node = "node_name_%s" % (number_of_pot_mcs / 2)
 
-      self._ssconf_store = self._MySsconfStore(
-        self._CLUSTER_NAME, self._all_nodes, self._master_node)
-      self._command_runner = self._MyCommandRunner(
-        self._CLUSTER_NAME, self._master_node, self._all_nodes,
-        self._potential_master_candidates,
-        new_node_master_candidate)
-
   def _TearDownTestData(self):
     os.remove(self._pub_key_file)
 
diff --git a/test/py/ganeti.ssh_unittest.py b/test/py/ganeti.ssh_unittest.py
index cd6ca5c..74d55d3 100755
--- a/test/py/ganeti.ssh_unittest.py
+++ b/test/py/ganeti.ssh_unittest.py
@@ -175,6 +175,19 @@ class TestSshKeys(testutils.GanetiTestCase):
       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
       "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
 
+  def testAddingDuplicateKeys(self):
+    ssh.AddAuthorizedKey(self.tmpname,
+                         "ssh-dss AAAAB3NzaC1kc3MAAACB root@test")
+    ssh.AddAuthorizedKeys(self.tmpname,
+                          ["ssh-dss AAAAB3NzaC1kc3MAAACB root@test",
+                           "ssh-dss AAAAB3NzaC1kc3MAAACB root@test"])
+
+    self.assertFileContent(self.tmpname,
+      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
+      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
+      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
+      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
+
   def testAddingAlmostButNotCompletelyTheSameKey(self):
     ssh.AddAuthorizedKey(self.tmpname,
       "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test")
diff --git a/test/py/ganeti.tools.ssh_update_unittest.py b/test/py/ganeti.tools.ssh_update_unittest.py
index af3205a..d4605ad 100755
--- a/test/py/ganeti.tools.ssh_update_unittest.py
+++ b/test/py/ganeti.tools.ssh_update_unittest.py
@@ -51,10 +51,8 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase):
     self.assertEqual(user, constants.SSH_LOGIN_USER)
     return self.tmpdir
 
-  def testNoKeys(self):
-    data_empty_keys = {
-      constants.SSHS_SSH_AUTHORIZED_KEYS: {},
-      }
+  def testNoop(self):
+    data_empty_keys = {}
 
     for data in [{}, data_empty_keys]:
       for dry_run in [False, True]:
@@ -64,9 +62,9 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase):
 
   def testDryRun(self):
     data = {
-      constants.SSHS_SSH_AUTHORIZED_KEYS: {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_ADD, {
         "node1" : ["key11", "key12", "key13"],
-        "node2" : ["key21", "key22"]},
+        "node2" : ["key21", "key22"]}),
       }
 
     ssh_update.UpdateAuthorizedKeys(data, True,
@@ -74,11 +72,11 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase):
     self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
     self.assertEqual(os.listdir(self.sshdir), [])
 
-  def testUpdate(self):
+  def testAddAndRemove(self):
     data = {
-      constants.SSHS_SSH_AUTHORIZED_KEYS: {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_ADD, {
         "node1": ["key11", "key12"],
-        "node2": ["key21"]},
+        "node2": ["key21"]}),
       }
 
     ssh_update.UpdateAuthorizedKeys(data, False,
@@ -89,6 +87,41 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase):
     self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
                                                    "authorized_keys")),
                      "key11\nkey12\nkey21\n")
+    data = {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_REMOVE, {
+        "node1": ["key12"],
+        "node2": ["key21"]}),
+      }
+    ssh_update.UpdateAuthorizedKeys(data, False,
+                                    _homedir_fn=self._GetHomeDir)
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
+                                                   "authorized_keys")),
+                     "key11\n")
+
+  def testAddAndRemoveDuplicates(self):
+    data = {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_ADD, {
+        "node1": ["key11", "key12"],
+        "node2": ["key12"]}),
+      }
+
+    ssh_update.UpdateAuthorizedKeys(data, False,
+                                    _homedir_fn=self._GetHomeDir)
+    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+    self.assertEqual(sorted(os.listdir(self.sshdir)),
+                     sorted(["authorized_keys"]))
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
+                                                   "authorized_keys")),
+                     "key11\nkey12\nkey12\n")
+    data = {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_REMOVE, {
+        "node1": ["key12"]}),
+      }
+    ssh_update.UpdateAuthorizedKeys(data, False,
+                                    _homedir_fn=self._GetHomeDir)
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
+                                                   "authorized_keys")),
+                     "key11\n")
 
 
 class TestUpdatePubKeyFile(testutils.GanetiTestCase):
@@ -97,9 +130,7 @@ class TestUpdatePubKeyFile(testutils.GanetiTestCase):
 
   def testNoKeys(self):
     pub_key_file = self._CreateTempFile()
-    data_empty_keys = {
-      constants.SSHS_SSH_PUBLIC_KEYS: {},
-      }
+    data_empty_keys = {}
 
     for data in [{}, data_empty_keys]:
       for dry_run in [False, True]:
@@ -107,16 +138,25 @@ class TestUpdatePubKeyFile(testutils.GanetiTestCase):
                                     key_file=pub_key_file)
     self.assertEqual(utils.ReadFile(pub_key_file), "")
 
-  def testValidKeys(self):
+  def testAddAndRemoveKeys(self):
     pub_key_file = self._CreateTempFile()
     data = {
-      constants.SSHS_SSH_PUBLIC_KEYS: {
+      constants.SSHS_SSH_PUBLIC_KEYS: (constants.SSHS_OVERRIDE, {
         "node1": ["key11", "key12"],
-        "node2": ["key21"]},
+        "node2": ["key21"]}),
       }
     ssh_update.UpdatePubKeyFile(data, False, key_file=pub_key_file)
     self.assertEqual(utils.ReadFile(pub_key_file),
       "node1 key11\nnode1 key12\nnode2 key21\n")
+    data = {
+      constants.SSHS_SSH_PUBLIC_KEYS: (constants.SSHS_REMOVE, {
+        "node1": ["key12"],
+        "node3": ["key21"],
+        "node4": ["key33"]}),
+      }
+    ssh_update.UpdatePubKeyFile(data, False, key_file=pub_key_file)
+    self.assertEqual(utils.ReadFile(pub_key_file),
+     "node2 key21\n")
 
 
 if __name__ == "__main__":
-- 
2.0.0.526.g5318336

Reply via email to