On Tue, Sep 02, 2014 at 04:19:35PM +0200, 'Helga Velroyen' via ganeti-devel wrote:
This patch prepares the ssh utility library ssh.py and
the ssh update tool with the ability to remove SSH keys
from the 'authorized_keys' and the 'ganeti_pub_keys'
files.

Signed-off-by: Helga Velroyen <[email protected]>
---
lib/ssh.py                                  | 24 +++++++---
lib/tools/ssh_update.py                     | 42 ++++++++++++-----
src/Ganeti/Constants.hs                     | 15 +++++++
test/py/ganeti.backend_unittest.py          |  7 ---
test/py/ganeti.tools.ssh_update_unittest.py | 70 ++++++++++++++++++++++-------
5 files changed, 118 insertions(+), 40 deletions(-)

diff --git a/lib/ssh.py b/lib/ssh.py
index 6848244..52a7484 100644
--- a/lib/ssh.py
+++ b/lib/ssh.py
@@ -185,16 +185,16 @@ def AddAuthorizedKey(file_obj, key):
  AddAuthorizedKeys(file_obj, [key])


-def RemoveAuthorizedKey(file_name, key):
-  """Removes an SSH public key from an authorized_keys file.
+def RemoveAuthorizedKeys(file_name, keys):
+  """Removes public SSH keys from an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
-  @type key: str
-  @param key: string containing key
+  @type keys: list of str
+  @param keys: list of strings containing keys

  """
-  key_fields = _SplitSshKey(key)
+  key_field_list = [_SplitSshKey(key) for key in keys]

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
@@ -204,7 +204,7 @@ def RemoveAuthorizedKey(file_name, key):
      try:
        for line in f:
          # Ignore whitespace changes while comparing lines
-          if _SplitSshKey(line) != key_fields:
+          if _SplitSshKey(line) not in key_field_list:
            out.write(line)

        out.flush()
@@ -218,6 +218,18 @@ def RemoveAuthorizedKey(file_name, key):
    raise


+def RemoveAuthorizedKey(file_name, key):
+  """Removes an SSH public key from an authorized_keys file.
+
+  @type file_name: str
+  @param file_name: path to authorized_keys file
+  @type key: str
+  @param key: string containing key
+
+  """
+  RemoveAuthorizedKeys(file_name, [key])
+
+
def _AddPublicKeyProcessLine(new_uuid, new_key, line_uuid, line_key, tmp_file,
                             found):
  """Processes one line of the public key file when adding a key.
diff --git a/lib/tools/ssh_update.py b/lib/tools/ssh_update.py
index cace9a2..41a9de8 100644
--- a/lib/tools/ssh_update.py
+++ b/lib/tools/ssh_update.py
@@ -46,9 +46,13 @@ _DATA_CHECK = ht.TStrictDict(False, True, {
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
  constants.SSHS_SSH_PUBLIC_KEYS:
-    ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)),
+    ht.TItems(
+      [ht.TElemOf(constants.SSHS_ACTIONS),
+       ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString))]),

I'd say this part belongs in the previous patch, since the instructions for public keys are already used there.

  constants.SSHS_SSH_AUTHORIZED_KEYS:
-    ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)),
+    ht.TItems(
+      [ht.TElemOf(constants.SSHS_ACTIONS),
+       ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString))]),
  })


@@ -86,22 +90,36 @@ def UpdateAuthorizedKeys(data, dry_run, _homedir_fn=None):
  @param dry_run: Whether to perform a dry run

  """
-  authorized_keys = data.get(constants.SSHS_SSH_AUTHORIZED_KEYS)
+  instructions = data.get(constants.SSHS_SSH_AUTHORIZED_KEYS)
+  if not instructions:
+    logging.info("No change to the authorized_keys file requested.")
+    return
+  (action, authorized_keys) = instructions

-  if authorized_keys:
-    (auth_keys_file, _) = \
-      ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
-                          _homedir_fn=_homedir_fn)
+  (auth_keys_file, _) = \
+    ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
+                        _homedir_fn=_homedir_fn)

+  key_values = []
+  for key_value in authorized_keys.values():
+    key_values += key_value

Following the discussion with Klaus, this could be simplified to the following (note that it requires importing itertools):

key_values = list(itertools.chain.from_iterable(authorized_keys.values()))

if desired.

+  if action == constants.SSHS_ADD:
    if dry_run:
-      logging.info("This is a dry run, not modifying %s", auth_keys_file)
+      logging.info("This is a dry run, not adding keys to %s",
+                   auth_keys_file)
    else:
-      all_authorized_keys = []
-      for keys in authorized_keys.values():
-        all_authorized_keys += keys
      if not os.path.exists(auth_keys_file):
        utils.WriteFile(auth_keys_file, mode=0600, data="")
-      ssh.AddAuthorizedKeys(auth_keys_file, all_authorized_keys)
+      ssh.AddAuthorizedKeys(auth_keys_file, key_values)
+  elif action == constants.SSHS_REMOVE:
+    if dry_run:
+      logging.info("This is a dry run, not removing keys from %s",
+                   auth_keys_file)
+    else:
+      ssh.RemoveAuthorizedKeys(auth_keys_file, key_values)
+  else:
+    raise SshUpdateError("Action '%s' not implemented for authorized keys."
+                         % action)


def UpdatePubKeyFile(data, dry_run, key_file=pathutils.SSH_PUB_KEYS):
diff --git a/src/Ganeti/Constants.hs b/src/Ganeti/Constants.hs
index d018b20..36227e6 100644
--- a/src/Ganeti/Constants.hs
+++ b/src/Ganeti/Constants.hs
@@ -4481,6 +4481,21 @@ sshsSshPublicKeys = "public_keys"
sshsNodeDaemonCertificate :: String
sshsNodeDaemonCertificate = "node_daemon_certificate"

+sshsAdd :: String
+sshsAdd = "add"
+
+sshsRemove :: String
+sshsRemove = "remove"
+
+sshsOverride :: String
+sshsOverride = "override"
+
+sshsClear :: String
+sshsClear = "clear"
+
+sshsActions :: FrozenSet String
+sshsActions = ConstantUtils.mkSet [sshsAdd, sshsRemove, sshsOverride, 
sshsClear]
+
-- * Key files for SSH daemon

sshHostDsaPriv :: String
diff --git a/test/py/ganeti.backend_unittest.py 
b/test/py/ganeti.backend_unittest.py
index d403751..6e0b294 100755
--- a/test/py/ganeti.backend_unittest.py
+++ b/test/py/ganeti.backend_unittest.py
@@ -992,13 +992,6 @@ class TestAddNodeSshKey(testutils.GanetiTestCase):

      self._master_node = "node_name_%s" % (number_of_pot_mcs / 2)

-      self._ssconf_store = self._MySsconfStore(
-        self._CLUSTER_NAME, self._all_nodes, self._master_node)
-      self._command_runner = self._MyCommandRunner(
-        self._CLUSTER_NAME, self._master_node, self._all_nodes,
-        self._potential_master_candidates,
-        new_node_master_candidate)
-
  def _TearDownTestData(self):
    os.remove(self._pub_key_file)

diff --git a/test/py/ganeti.tools.ssh_update_unittest.py 
b/test/py/ganeti.tools.ssh_update_unittest.py
index af3205a..d4605ad 100755
--- a/test/py/ganeti.tools.ssh_update_unittest.py
+++ b/test/py/ganeti.tools.ssh_update_unittest.py
@@ -51,10 +51,8 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase):
    self.assertEqual(user, constants.SSH_LOGIN_USER)
    return self.tmpdir

-  def testNoKeys(self):
-    data_empty_keys = {
-      constants.SSHS_SSH_AUTHORIZED_KEYS: {},
-      }
+  def testNoop(self):
+    data_empty_keys = {}

    for data in [{}, data_empty_keys]:
      for dry_run in [False, True]:
@@ -64,9 +62,9 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase):

  def testDryRun(self):
    data = {
-      constants.SSHS_SSH_AUTHORIZED_KEYS: {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_ADD, {
        "node1" : ["key11", "key12", "key13"],
-        "node2" : ["key21", "key22"]},
+        "node2" : ["key21", "key22"]}),
      }

    ssh_update.UpdateAuthorizedKeys(data, True,
@@ -74,11 +72,11 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase):
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
    self.assertEqual(os.listdir(self.sshdir), [])

-  def testUpdate(self):
+  def testAddAndRemove(self):
    data = {
-      constants.SSHS_SSH_AUTHORIZED_KEYS: {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_ADD, {
        "node1": ["key11", "key12"],
-        "node2": ["key21"]},
+        "node2": ["key21"]}),
      }

    ssh_update.UpdateAuthorizedKeys(data, False,
@@ -89,6 +87,41 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase):
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
                                                   "authorized_keys")),
                     "key11\nkey12\nkey21\n")
+    data = {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_REMOVE, {
+        "node1": ["key12"],
+        "node2": ["key21"]}),
+      }
+    ssh_update.UpdateAuthorizedKeys(data, False,
+                                    _homedir_fn=self._GetHomeDir)
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
+                                                   "authorized_keys")),
+                     "key11\n")
+
+  def testAddAndRemoveDuplicates(self):
+    data = {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_ADD, {
+        "node1": ["key11", "key12"],
+        "node2": ["key12"]}),
+      }
+
+    ssh_update.UpdateAuthorizedKeys(data, False,
+                                    _homedir_fn=self._GetHomeDir)
+    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+    self.assertEqual(sorted(os.listdir(self.sshdir)),
+                     sorted(["authorized_keys"]))

In this test, 'sorted' is redundant, as there should be just one element. But if it's used everywhere else, we can keep it for consistency.

+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
+                                                   "authorized_keys")),
+                     "key11\nkey12\nkey12\n")
+    data = {
+      constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_REMOVE, {
+        "node1": ["key12"]}),
+      }
+    ssh_update.UpdateAuthorizedKeys(data, False,
+                                    _homedir_fn=self._GetHomeDir)
+    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
+                                                   "authorized_keys")),
+                     "key11\n")


class TestUpdatePubKeyFile(testutils.GanetiTestCase):
@@ -97,9 +130,7 @@ class TestUpdatePubKeyFile(testutils.GanetiTestCase):

  def testNoKeys(self):
    pub_key_file = self._CreateTempFile()
-    data_empty_keys = {
-      constants.SSHS_SSH_PUBLIC_KEYS: {},
-      }
+    data_empty_keys = {}

    for data in [{}, data_empty_keys]:
      for dry_run in [False, True]:
@@ -107,16 +138,25 @@ class TestUpdatePubKeyFile(testutils.GanetiTestCase):
                                    key_file=pub_key_file)
    self.assertEqual(utils.ReadFile(pub_key_file), "")

-  def testValidKeys(self):
+  def testAddAndRemoveKeys(self):
    pub_key_file = self._CreateTempFile()
    data = {
-      constants.SSHS_SSH_PUBLIC_KEYS: {
+      constants.SSHS_SSH_PUBLIC_KEYS: (constants.SSHS_OVERRIDE, {
        "node1": ["key11", "key12"],
-        "node2": ["key21"]},
+        "node2": ["key21"]}),
      }
    ssh_update.UpdatePubKeyFile(data, False, key_file=pub_key_file)
    self.assertEqual(utils.ReadFile(pub_key_file),
      "node1 key11\nnode1 key12\nnode2 key21\n")
+    data = {
+      constants.SSHS_SSH_PUBLIC_KEYS: (constants.SSHS_REMOVE, {
+        "node1": ["key12"],
+        "node3": ["key21"],
+        "node4": ["key33"]}),
+      }
+    ssh_update.UpdatePubKeyFile(data, False, key_file=pub_key_file)
+    self.assertEqual(utils.ReadFile(pub_key_file),
+     "node2 key21\n")


if __name__ == "__main__":
--
2.1.0.rc2.206.gedb03e5

Rest LGTM

Reply via email to