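This patch implements the handling of SSH keys when a node
is removed from the cluster. It covers the implementation
in the backend as well as the introduction and invocation of
a new RPC call for that purpose. Additionally, the SSH setup
is adjusted when a node is demoted from or promoted to
master candidate.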

Signed-off-by: Helga Velroyen <hel...@google.com>
---
 lib/backend.py                     | 126 +++++++++++++++++++++++
 lib/cmdlib/node.py                 |  51 ++++++++-
 lib/rpc_defs.py                    |  13 +++
 lib/server/noded.py                |  15 +++
 test/py/ganeti.backend_unittest.py | 206 ++++++++++++++++++++++++++++++++++---
 5 files changed, 393 insertions(+), 18 deletions(-)

diff --git a/lib/backend.py b/lib/backend.py
index ef5b77b..5768ac4 100644
--- a/lib/backend.py
+++ b/lib/backend.py
@@ -1390,6 +1390,132 @@ def AddNodeSshKey(node_uuid, node_name,
                ssh_port_map.get(node_name), node_data, ssconf_store)
 
 
+def RemoveNodeSshKey(node_uuid, node_name, from_authorized_keys,
+                     from_public_keys, clear_authorized_keys,
+                     ssh_port_map, master_candidate_uuids,
+                     potential_master_candidates,
+                     pub_key_file=pathutils.SSH_PUB_KEYS,
+                     ssconf_store=None,
+                     noded_cert_file=pathutils.NODED_CERT_FILE,
+                     run_cmd_fn=ssh.RunSshCmdWithStdin):
+  """Removes the node's SSH keys from the key files and distributes those.
+
+  @type node_uuid: str
+  @param node_uuid: UUID of the node whose key is removed
+  @type node_name: str
+  @param node_name: name of the node whose key is removed
+  @type from_authorized_keys: boolean
+  @param from_authorized_keys: whether or not the key should be removed
+    from the C{authorized_keys} file
+  @type from_public_keys: boolean
+  @param from_public_keys: whether or not the key should be removed from
+    the C{ganeti_pub_keys} file
+  @type clear_authorized_keys: boolean
+  @param clear_authorized_keys: whether or not the C{authorized_keys} file
+    should be cleared on the node whose keys are removed
+  @type ssh_port_map: dict of str to int
+  @param ssh_port_map: mapping of node names to their SSH port
+  @type master_candidate_uuids: list of str
+  @param master_candidate_uuids: list of UUIDs of the current master candidates
+  @type potential_master_candidates: list of str
+  @param potential_master_candidates: list of names of potential master
+    candidates
+  @type pub_key_file: str
+  @param pub_key_file: path of the file holding the cluster's public SSH keys
+  @raise errors.SshUpdateError: if no removal was requested, if the node is
+    the master node or if a key to remove is not in the public key file
+  @raise errors.OpExecError: if the SSH port of an affected node is unknown
+
+  """
+  if not ssconf_store:
+    ssconf_store = ssconf.SimpleStore()
+
+  if not (from_authorized_keys or from_public_keys or clear_authorized_keys):
+    raise errors.SshUpdateError("No removal from any key file was requested.")
+
+  master_node = ssconf_store.GetMasterNode()
+  # checked first, so that the master's 'authorized_keys' is never cleared
+  if node_name == master_node:
+    raise errors.SshUpdateError("Cannot remove the master node's keys.")
+
+  if from_authorized_keys or from_public_keys:
+    keys = ssh.QueryPubKeyFile([node_uuid], key_file=pub_key_file)
+    if not keys or node_uuid not in keys:
+      raise errors.SshUpdateError("Node '%s' not found in the list of public"
+                                  " SSH keys. It seems someone tries to"
+                                  " remove a key from outside the cluster!"
+                                  % node_uuid)
+
+    base_data = {}
+    _InitSshUpdateData(base_data, noded_cert_file, ssconf_store)
+    cluster_name = base_data[constants.SSHS_CLUSTER_NAME]
+
+    if from_authorized_keys:
+      base_data[constants.SSHS_SSH_AUTHORIZED_KEYS] = \
+        (constants.SSHS_REMOVE, keys)
+      (auth_key_file, _) = ssh.GetAllUserFiles(constants.SSH_LOGIN_USER,
+                                               mkdir=False, dircheck=False)
+      ssh.RemoveAuthorizedKeys(auth_key_file, keys[node_uuid])
+
+    pot_mc_data = copy.deepcopy(base_data)
+
+    if from_public_keys:
+      pot_mc_data[constants.SSHS_SSH_PUBLIC_KEYS] = \
+        (constants.SSHS_REMOVE, keys)
+      ssh.RemovePublicKey(node_uuid, key_file=pub_key_file)
+
+    for node in ssconf_store.GetNodeList():
+      if node in [master_node, node_name]:
+        continue
+      ssh_port = ssh_port_map.get(node)
+      if not ssh_port:
+        raise errors.OpExecError("No SSH port information available for"
+                                 " node '%s', map: %s." % (node, ssh_port_map))
+      if node in potential_master_candidates:
+        run_cmd_fn(cluster_name, node, pathutils.SSH_UPDATE,
+                   True, True, False, False, False,
+                   ssh_port, pot_mc_data, ssconf_store)
+      elif from_authorized_keys:
+        run_cmd_fn(cluster_name, node, pathutils.SSH_UPDATE,
+                   True, True, False, False, False,
+                   ssh_port, base_data, ssconf_store)
+
+  data = {}
+  _InitSshUpdateData(data, noded_cert_file, ssconf_store)
+  cluster_name = data[constants.SSHS_CLUSTER_NAME]
+  ssh_port = ssh_port_map.get(node_name)
+  if not ssh_port:
+    raise errors.OpExecError("No SSH port information available for"
+                             " node '%s', which is leaving the cluster."
+                             % node_name)
+
+  authorized_keys_to_clear = {}
+  if clear_authorized_keys:
+    # Only the keys of the *other* master candidates are cleared here; the
+    # node's own key is removed below only if from_authorized_keys is set.
+    other_master_candidate_uuids = [uuid for uuid in master_candidate_uuids
+                                    if uuid != node_uuid]
+    authorized_keys_to_clear = ssh.QueryPubKeyFile(
+      other_master_candidate_uuids, key_file=pub_key_file)
+  if from_authorized_keys:
+    authorized_keys_to_clear[node_uuid] = keys[node_uuid]
+  if authorized_keys_to_clear:
+    data[constants.SSHS_SSH_AUTHORIZED_KEYS] = \
+      (constants.SSHS_REMOVE, authorized_keys_to_clear)
+
+  if from_public_keys:
+    data[constants.SSHS_SSH_PUBLIC_KEYS] = \
+      (constants.SSHS_REMOVE, keys)
+
+  try:
+    run_cmd_fn(cluster_name, node_name, pathutils.SSH_UPDATE,
+               True, True, False, False, False,
+               ssh_port, data, ssconf_store)
+  except errors.OpExecError, e:
+    logging.info("Removing SSH keys from node '%s' failed. This can happen"
+                 " when the node is already unreachable. Error: %s",
+                 node_name, e)
+
 def GetBlockDevSizes(devices):
   """Return the size of the given block devices
 
diff --git a/lib/cmdlib/node.py b/lib/cmdlib/node.py
index 55483c7..2b95402 100644
--- a/lib/cmdlib/node.py
+++ b/lib/cmdlib/node.py
@@ -828,12 +828,39 @@ class LUNodeSetParams(LogicalUnit):
       if self.old_role == self._ROLE_CANDIDATE:
         RemoveNodeCertFromCandidateCerts(self.cfg, node.uuid)
 
+    EnsureKvmdOnNodes(self, feedback_fn, nodes=[node.uuid])
+
     # this will trigger job queue propagation or cleanup if the mc
     # flag changed
     if [self.old_role, self.new_role].count(self._ROLE_CANDIDATE) == 1:
       self.context.ReaddNode(node)
 
-    EnsureKvmdOnNodes(self, feedback_fn, nodes=[node.uuid])
+      if self.cfg.GetClusterInfo().modify_ssh_setup:
+        potential_master_candidates = self.cfg.GetPotentialMasterCandidates()
+        ssh_port_map = GetSshPortMap(potential_master_candidates, self.cfg)
+        master_node = self.cfg.GetMasterNode()
+        if self.old_role == self._ROLE_CANDIDATE:
+          master_candidate_uuids = self.cfg.GetMasterCandidateUuids()
+          ssh_result = self.rpc.call_node_ssh_key_remove(
+            [master_node],
+            node.uuid, node.name,
+            True, # remove node's key from all nodes' authorized_keys file
+            False, # currently, all nodes are potential master candidates
+            False, # do not clear node's 'authorized_keys'
+            ssh_port_map, master_candidate_uuids, potential_master_candidates)
+          ssh_result[master_node].Raise(
+            "Could not adjust the SSH setup after demoting node '%s'"
+            " (UUID: %s)." % (node.name, node.uuid))
+        if self.new_role == self._ROLE_CANDIDATE:
+          ssh_result = self.rpc.call_node_ssh_key_add(
+            [master_node], node.uuid, node.name,
+            True, # add node's key to all nodes' 'authorized_keys'
+            True, # all nodes are potential master candidates
+            False, # do not update the node's public keys
+            ssh_port_map, potential_master_candidates)
+          ssh_result[master_node].Raise(
+            "Could not update the SSH setup of node '%s' after promotion"
+            " (UUID: %s)." % (node.name, node.uuid))
 
     return result
 
@@ -1514,6 +1541,28 @@ class LUNodeRemove(LogicalUnit):
     assert locking.BGL in self.owned_locks(locking.LEVEL_CLUSTER), \
       "Not owning BGL"
 
+    if modify_ssh_setup:
+      # retrieve the list of potential master candidates before the node is
+      # removed
+      potential_master_candidates = self.cfg.GetPotentialMasterCandidates()
+      potential_master_candidate = \
+        self.op.node_name in potential_master_candidates
+      ssh_port_map = GetSshPortMap(potential_master_candidates, self.cfg)
+      # the node is still part of the configuration at this point, so it is
+      # included here if it is a master candidate itself
+      master_candidate_uuids = self.cfg.GetMasterCandidateUuids()
+      master_node = self.cfg.GetMasterNode()
+      result = self.rpc.call_node_ssh_key_remove(
+        [master_node],
+        self.node.uuid, self.op.node_name,
+        self.node.master_candidate,
+        potential_master_candidate,
+        True, # clear node's 'authorized_keys'
+        ssh_port_map, master_candidate_uuids, potential_master_candidates)
+      result[master_node].Raise(
+        "Could not remove the SSH key of node '%s' (UUID: %s)." %
+        (self.op.node_name, self.node.uuid))
+
     # Promote nodes to master candidate as needed
     AdjustCandidatePool(self, [self.node.uuid])
     self.context.RemoveNode(self.cfg, self.node)
diff --git a/lib/rpc_defs.py b/lib/rpc_defs.py
index 5e47209..f6bfe6a 100644
--- a/lib/rpc_defs.py
+++ b/lib/rpc_defs.py
@@ -534,6 +534,19 @@ _NODE_CALLS = [
     ("ssh_port_map", None, "Map of nodes' SSH ports to be used for transfers"),
     ("potential_master_candidates", None, "Potential master candidates")],
     None, None, "Distribute a new node's public SSH key on the cluster."),
+  ("node_ssh_key_remove", MULTI, None, constants.RPC_TMO_URGENT, [
+    ("node_uuid", None, "UUID of the node whose key is removed"),
+    ("node_name", None, "Name of the node whose key is removed"),
+    ("from_authorized_keys", None,
+     "If the key should be removed from the 'authorized_keys' file."),
+    ("from_public_keys", None,
+     "If the key should be removed from the public key file."),
+    ("clear_authorized_keys", None,
+     "If the 'authorized_keys' file of the node should be cleared."),
+    ("ssh_port_map", None, "Map of nodes' SSH ports to be used for transfers"),
+    ("master_candidate_uuids", None, "List of UUIDs of master candidates."),
+    ("potential_master_candidates", None, "Potential master candidates")],
+    None, None, "Remove a node's SSH key from the other nodes' key files."),
   ]
 
 _MISC_CALLS = [
diff --git a/lib/server/noded.py b/lib/server/noded.py
index 5071fdf..7bb229e 100644
--- a/lib/server/noded.py
+++ b/lib/server/noded.py
@@ -920,6 +920,21 @@ class NodeRequestHandler(http.server.HttpServerHandler):
                                  to_public_keys, get_public_keys,
                                  ssh_port_map, potential_master_candidates)
 
+  @staticmethod
+  def perspective_node_ssh_key_remove(params):
+    """Remove a node's SSH key from the key files of the other nodes.
+
+    """
+    # unpacked in the order in which the 'node_ssh_key_remove' RPC lists them
+    (node_uuid, node_name, from_authorized_keys, from_public_keys,
+     clear_authorized_keys, ssh_port_map, master_candidate_uuids,
+     potential_master_candidates) = params
+    return backend.RemoveNodeSshKey(node_uuid, node_name,
+                                    from_authorized_keys, from_public_keys,
+                                    clear_authorized_keys, ssh_port_map,
+                                    master_candidate_uuids,
+                                    potential_master_candidates)
+
   # cluster --------------------------
 
   @staticmethod
diff --git a/test/py/ganeti.backend_unittest.py b/test/py/ganeti.backend_unittest.py
index 6e0b294..6e54273 100755
--- a/test/py/ganeti.backend_unittest.py
+++ b/test/py/ganeti.backend_unittest.py
@@ -940,7 +940,7 @@ class TestSpaceReportingConstants(unittest.TestCase):
       self.assertEqual(None, backend._STORAGE_TYPE_INFO_FN[storage_type])
 
 
-class TestAddNodeSshKey(testutils.GanetiTestCase):
+class TestAddAndRemoveNodeSshKey(testutils.GanetiTestCase):
 
   _CLUSTER_NAME = "mycluster"
   _SSH_PORT = 22
@@ -949,6 +949,8 @@ class TestAddNodeSshKey(testutils.GanetiTestCase):
     testutils.GanetiTestCase.setUp(self)
     self._ssh_add_authorized_patcher = testutils \
       .patch_object(ssh, "AddAuthorizedKeys")
+    self._ssh_remove_authorized_patcher = testutils \
+      .patch_object(ssh, "RemoveAuthorizedKeys")
     self._ssh_add_authorized_mock = self._ssh_add_authorized_patcher.start()
 
     self._ssconf_mock = mock.Mock()
@@ -958,11 +960,14 @@ class TestAddNodeSshKey(testutils.GanetiTestCase):
 
     self._run_cmd_mock = mock.Mock()
 
+    self._ssh_remove_authorized_mock = \
+      self._ssh_remove_authorized_patcher.start()
     self.noded_cert_file = testutils.TestDataFilename("cert1.pem")
 
   def tearDown(self):
     super(testutils.GanetiTestCase, self).tearDown()
     self._ssh_add_authorized_patcher.stop()
+    self._ssh_remove_authorized_patcher.stop()
 
   def _SetupTestData(self, number_of_nodes=15, number_of_pot_mcs=5,
                      number_of_mcs=5):
@@ -970,7 +975,9 @@ class TestAddNodeSshKey(testutils.GanetiTestCase):
 
     """
     self._pub_key_file = self._CreateTempFile()
+    self._all_nodes = []
     self._potential_master_candidates = []
+    self._master_candidate_uuids = []
     self._ssh_port_map = {}
 
     self._ssconf_mock.reset_mock()
@@ -981,33 +988,68 @@ class TestAddNodeSshKey(testutils.GanetiTestCase):
 
     for i in range(number_of_nodes):
       node_name = "node_name_%s" % i
-      self._potential_master_candidates.append(node_name)
+      node_uuid = "node_uuid_%s" % i
       self._ssh_port_map[node_name] = self._SSH_PORT
+      self._all_nodes.append(node_name)
 
-      self._all_nodes = self._potential_master_candidates[:]
-      for j in range(number_of_pot_mcs, number_of_nodes):
-        node_name = "node_name_%s"
-        self._all_nodes.append(node_name)
-        self._ssh_port_map[node_name] = self._SSH_PORT
+      if i in range(number_of_mcs + number_of_pot_mcs):
+        ssh.AddPublicKey("node_uuid_%s" % i, "key%s" % i,
+                         key_file=self._pub_key_file)
+        self._potential_master_candidates.append(node_name)
 
-      self._master_node = "node_name_%s" % (number_of_pot_mcs / 2)
+      if i in range(number_of_mcs):
+        self._master_candidate_uuids.append(node_uuid)
+
+    self._master_node = "node_name_%s" % (number_of_mcs / 2)
+    self._ssconf_mock.GetNodeList.return_value = self._all_nodes
 
   def _TearDownTestData(self):
     os.remove(self._pub_key_file)
 
-  def _KeyReceived(self, key_data, node_name, expected_type,
-                   expected_key):
+  def _KeyOperationExecuted(self, key_data, node_name, expected_type,
+                            expected_key, action_types):
     if not node_name in key_data:
       return False
     for data in key_data[node_name]:
       if expected_type in data:
         (action, key_dict) = data[expected_type]
-        if action in [constants.SSHS_ADD, constants.SSHS_OVERRIDE]:
+        if action in action_types:
           for key_list in key_dict.values():
             if expected_key in key_list:
               return True
     return False
 
+  def _KeyReceived(self, key_data, node_name, expected_type, expected_key):
+    """Checks if a node was asked to add or override the given key."""
+    add_actions = [constants.SSHS_ADD, constants.SSHS_OVERRIDE]
+    return self._KeyOperationExecuted(key_data, node_name, expected_type,
+                                      expected_key, add_actions)
+
+  def _KeyRemoved(self, key_data, node_name, expected_type,
+                  expected_key):
+    """Checks if a node was asked to remove a key or clear the key file.
+
+    """
+    if self._KeyOperationExecuted(key_data, node_name, expected_type,
+                                  expected_key, [constants.SSHS_REMOVE]):
+      return True
+    # a request to clear the whole key file also counts as a removal
+    for data in key_data.get(node_name, []):
+      if expected_type in data:
+        (action, _) = data[expected_type]
+        if action == constants.SSHS_CLEAR:
+          return True
+    return False
+
+  def _GetCallsPerNode(self):
+    """Groups the data of all mocked SSH update calls by target node."""
+    calls_per_node = {}
+    for (pos, _) in self._run_cmd_mock.call_args_list:
+      # the 11 positional arguments the backend passes to run_cmd_fn
+      (_, node, _, _, _, _, _, _, _, data, _) = pos
+      calls_per_node.setdefault(node, []).append(data)
+    return calls_per_node
+
   def testAddNodeSshKeyValid(self):
     new_node_name = "new_node_name"
     new_node_uuid = "new_node_uuid"
@@ -1039,12 +1081,7 @@ class TestAddNodeSshKey(testutils.GanetiTestCase):
                             noded_cert_file=self.noded_cert_file,
                             run_cmd_fn=self._run_cmd_mock)
 
-      calls_per_node = {}
-      for (pos, keyword) in self._run_cmd_mock.call_args_list:
-        (cluster_name, node, _, _, _, _, _, _, _, data, _) = pos
-        if not node in calls_per_node:
-          calls_per_node[node] = []
-        calls_per_node[node].append(data)
+      calls_per_node = self._GetCallsPerNode()
 
       # one sample node per type (master candidate, potential master candidate,
       # normal node)
@@ -1105,6 +1142,141 @@ class TestAddNodeSshKey(testutils.GanetiTestCase):
 
       self._TearDownTestData()
 
+  def testRemoveNodeSshKeyValid(self):
+    """Checks the removal requests sent for each combination of flags."""
+    node_name = "node_name"
+    node_uuid = "node_uuid"
+    node_key1 = "node_key1"
+    node_key2 = "node_key2"
+
+    for (from_authorized_keys, from_public_keys,
+         clear_authorized_keys) in \
+       [(True, True, False),
+        (True, False, False),
+        (False, True, False),
+        (False, True, True),
+        (False, False, True),
+        (True, True, True),
+       ]:
+
+      self._SetupTestData()
+
+      # register the node's keys and add it to the node lists and port map
+      # (no authorized_keys file is seeded: ssh.AddAuthorizedKeys is mocked)
+      if from_public_keys or from_authorized_keys:
+        for key in [node_key1, node_key2]:
+          ssh.AddPublicKey(node_uuid, key, key_file=self._pub_key_file)
+        self._potential_master_candidates.append(node_name)
+      if from_authorized_keys:
+        self._master_candidate_uuids.append(node_uuid)
+      self._ssh_port_map[node_name] = self._SSH_PORT
+
+      backend.RemoveNodeSshKey(node_uuid, node_name,
+                               from_authorized_keys,
+                               from_public_keys,
+                               clear_authorized_keys,
+                               self._ssh_port_map,
+                               self._master_candidate_uuids,
+                               self._potential_master_candidates,
+                               pub_key_file=self._pub_key_file,
+                               ssconf_store=self._ssconf_mock,
+                               noded_cert_file=self.noded_cert_file,
+                               run_cmd_fn=self._run_cmd_mock)
+
+      calls_per_node = self._GetCallsPerNode()
+
+      # one sample node per type (master candidate, potential master candidate,
+      # normal node)
+      mc_idx = 3
+      pot_mc_idx = 7
+      normal_idx = 12
+      sample_nodes = [mc_idx, pot_mc_idx, normal_idx]
+      pot_sample_nodes = [mc_idx, pot_mc_idx]
+
+      if from_authorized_keys:
+        for node_idx in sample_nodes:
+          self.assertTrue(self._KeyRemoved(
+            calls_per_node, "node_name_%i" % node_idx,
+            constants.SSHS_SSH_AUTHORIZED_KEYS, node_key1),
+            "Node %i did not get request to remove authorized key '%s'"
+            " although it should have." % (node_idx, node_key1))
+      else:
+        for node_idx in sample_nodes:
+          self.assertFalse(self._KeyRemoved(
+            calls_per_node, "node_name_%i" % node_idx,
+            constants.SSHS_SSH_AUTHORIZED_KEYS, node_key1),
+            "Node %i got requested to remove authorized key '%s', although it"
+            " should not have." % (node_idx, node_key1))
+
+      if from_public_keys:
+        for node_idx in pot_sample_nodes:
+          self.assertTrue(self._KeyRemoved(
+            calls_per_node, "node_name_%i" % node_idx,
+            constants.SSHS_SSH_PUBLIC_KEYS, node_key1),
+            "Node %i did not receive request to remove public key '%s',"
+            " although it should have." % (node_idx, node_key1))
+        self.assertTrue(self._KeyRemoved(
+          calls_per_node, node_name,
+          constants.SSHS_SSH_PUBLIC_KEYS, node_key1),
+          "Node %s did not receive request to remove its own public key '%s',"
+          " although it should have." % (node_name, node_key1))
+        for node_idx in list(set(sample_nodes) - set(pot_sample_nodes)):
+          self.assertFalse(self._KeyRemoved(
+            calls_per_node, "node_name_%i" % node_idx,
+            constants.SSHS_SSH_PUBLIC_KEYS, node_key1),
+            "Node %i received a request to remove public key '%s',"
+            " although it should not have." % (node_idx, node_key1))
+      else:
+        for node_idx in sample_nodes:
+          self.assertFalse(self._KeyRemoved(
+            calls_per_node, "node_name_%i" % node_idx,
+            constants.SSHS_SSH_PUBLIC_KEYS, node_key1),
+            "Node %i received a request to remove public key '%s',"
+            " although it should not have." % (node_idx, node_key1))
+
+      if clear_authorized_keys:
+        for node_idx in list(set(sample_nodes) - set([mc_idx])):
+          key = "key%s" % node_idx
+          self.assertFalse(self._KeyRemoved(
+            calls_per_node, node_name,
+            constants.SSHS_SSH_AUTHORIZED_KEYS, key),
+            "Node %s did receive request to remove authorized key '%s',"
+            " although it should not have." % (node_name, key))
+        mc_key = "key%s" % mc_idx
+        self.assertTrue(self._KeyRemoved(
+          calls_per_node, node_name,
+          constants.SSHS_SSH_AUTHORIZED_KEYS, mc_key),
+          "Node %s did not receive request to remove authorized key '%s',"
+          " although it should have." % (node_name, mc_key))
+        if from_authorized_keys:
+          self.assertTrue(self._KeyRemoved(
+            calls_per_node, node_name,
+            constants.SSHS_SSH_AUTHORIZED_KEYS, node_key1),
+            "Node %s did not receive request to remove its own authorized"
+            " key '%s', although it should have." % (node_name, node_key1))
+      else:
+        for node_idx in sample_nodes:
+          key = "key%s" % node_idx
+          self.assertFalse(self._KeyRemoved(
+            calls_per_node, node_name,
+            constants.SSHS_SSH_AUTHORIZED_KEYS, key),
+            "Node %s did receive request to remove authorized key '%s',"
+            " although it should not have." % (node_name, key))
+
+      self._TearDownTestData()
+
+
+class TestVerifySshSetup(testutils.GanetiTestCase):
+  # NOTE(review): this class only defines fixture constants and no test
+  # methods yet; presumably scaffolding for a follow-up patch.
+  _NODE1_UUID = "uuid1"
+  _NODE2_UUID = "uuid2"
+  _NODE3_UUID = "uuid3"
+  _NODE1_NAME = "name1"
+  _NODE2_NAME = "name2"
+  _NODE3_NAME = "name3"
+  _NODE1_KEYS = ["key11", "key12"]
+
 
 if __name__ == "__main__":
   testutils.GanetiTestProgram()
-- 
2.0.0.526.g5318336

Reply via email to