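This patch extends the SSH utility library (lib/ssh.py) and the SSH update tool (lib/tools/ssh_update.py) with the ability to remove SSH keys from the 'authorized_keys' and the 'ganeti_pub_keys' files. To support this, the key data passed to the update tool is now a pair of an action (one of 'add', 'remove', 'override' or 'clear') and the dictionary of keys the action applies to.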
Signed-off-by: Helga Velroyen <hel...@google.com> --- lib/ssh.py | 24 +++++++--- lib/tools/ssh_update.py | 42 ++++++++++++----- src/Ganeti/Constants.hs | 15 +++++++ test/py/ganeti.backend_unittest.py | 7 --- test/py/ganeti.ssh_unittest.py | 13 ++++++ test/py/ganeti.tools.ssh_update_unittest.py | 70 ++++++++++++++++++++++------- 6 files changed, 131 insertions(+), 40 deletions(-) diff --git a/lib/ssh.py b/lib/ssh.py index daeece3..e646a0b 100644 --- a/lib/ssh.py +++ b/lib/ssh.py @@ -187,16 +187,16 @@ def AddAuthorizedKey(file_obj, key): AddAuthorizedKeys(file_obj, [key]) -def RemoveAuthorizedKey(file_name, key): - """Removes an SSH public key from an authorized_keys file. +def RemoveAuthorizedKeys(file_name, keys): + """Removes public SSH keys from an authorized_keys file. @type file_name: str @param file_name: path to authorized_keys file - @type key: str - @param key: string containing key + @type keys: list of str + @param keys: list of strings containing keys """ - key_fields = _SplitSshKey(key) + key_field_list = [_SplitSshKey(key) for key in keys] fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name)) try: @@ -206,7 +206,7 @@ def RemoveAuthorizedKey(file_name, key): try: for line in f: # Ignore whitespace changes while comparing lines - if _SplitSshKey(line) != key_fields: + if _SplitSshKey(line) not in key_field_list: out.write(line) out.flush() @@ -220,6 +220,18 @@ def RemoveAuthorizedKey(file_name, key): raise +def RemoveAuthorizedKey(file_name, key): + """Removes an SSH public key from an authorized_keys file. + + @type file_name: str + @param file_name: path to authorized_keys file + @type key: str + @param key: string containing key + + """ + RemoveAuthorizedKeys(file_name, [key]) + + def _AddPublicKeyProcessLine(new_uuid, new_key, line_uuid, line_key, tmp_file, found): """Processes one line of the public key file when adding a key. 
diff --git a/lib/tools/ssh_update.py b/lib/tools/ssh_update.py index cace9a2..41a9de8 100644 --- a/lib/tools/ssh_update.py +++ b/lib/tools/ssh_update.py @@ -46,9 +46,13 @@ _DATA_CHECK = ht.TStrictDict(False, True, { constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString, constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString, constants.SSHS_SSH_PUBLIC_KEYS: - ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)), + ht.TItems( + [ht.TElemOf(constants.SSHS_ACTIONS), + ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString))]), constants.SSHS_SSH_AUTHORIZED_KEYS: - ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)), + ht.TItems( + [ht.TElemOf(constants.SSHS_ACTIONS), + ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString))]), }) @@ -86,22 +90,36 @@ def UpdateAuthorizedKeys(data, dry_run, _homedir_fn=None): @param dry_run: Whether to perform a dry run """ - authorized_keys = data.get(constants.SSHS_SSH_AUTHORIZED_KEYS) + instructions = data.get(constants.SSHS_SSH_AUTHORIZED_KEYS) + if not instructions: + logging.info("No change to the authorized_keys file requested.") + return + (action, authorized_keys) = instructions - if authorized_keys: - (auth_keys_file, _) = \ - ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True, - _homedir_fn=_homedir_fn) + (auth_keys_file, _) = \ + ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True, + _homedir_fn=_homedir_fn) + key_values = [] + for key_value in authorized_keys.values(): + key_values += key_value + if action == constants.SSHS_ADD: if dry_run: - logging.info("This is a dry run, not modifying %s", auth_keys_file) + logging.info("This is a dry run, not adding keys to %s", + auth_keys_file) else: - all_authorized_keys = [] - for keys in authorized_keys.values(): - all_authorized_keys += keys if not os.path.exists(auth_keys_file): utils.WriteFile(auth_keys_file, mode=0600, data="") - ssh.AddAuthorizedKeys(auth_keys_file, all_authorized_keys) + ssh.AddAuthorizedKeys(auth_keys_file, 
key_values) + elif action == constants.SSHS_REMOVE: + if dry_run: + logging.info("This is a dry run, not removing keys from %s", + auth_keys_file) + else: + ssh.RemoveAuthorizedKeys(auth_keys_file, key_values) + else: + raise SshUpdateError("Action '%s' not implemented for authorized keys." + % action) def UpdatePubKeyFile(data, dry_run, key_file=pathutils.SSH_PUB_KEYS): diff --git a/src/Ganeti/Constants.hs b/src/Ganeti/Constants.hs index df7fb9d..387de31 100644 --- a/src/Ganeti/Constants.hs +++ b/src/Ganeti/Constants.hs @@ -4459,6 +4459,21 @@ sshsSshPublicKeys = "public_keys" sshsNodeDaemonCertificate :: String sshsNodeDaemonCertificate = "node_daemon_certificate" +sshsAdd :: String +sshsAdd = "add" + +sshsRemove :: String +sshsRemove = "remove" + +sshsOverride :: String +sshsOverride = "override" + +sshsClear :: String +sshsClear = "clear" + +sshsActions :: FrozenSet String +sshsActions = ConstantUtils.mkSet [sshsAdd, sshsRemove, sshsOverride, sshsClear] + -- * Key files for SSH daemon sshHostDsaPriv :: String diff --git a/test/py/ganeti.backend_unittest.py b/test/py/ganeti.backend_unittest.py index d403751..6e0b294 100755 --- a/test/py/ganeti.backend_unittest.py +++ b/test/py/ganeti.backend_unittest.py @@ -992,13 +992,6 @@ class TestAddNodeSshKey(testutils.GanetiTestCase): self._master_node = "node_name_%s" % (number_of_pot_mcs / 2) - self._ssconf_store = self._MySsconfStore( - self._CLUSTER_NAME, self._all_nodes, self._master_node) - self._command_runner = self._MyCommandRunner( - self._CLUSTER_NAME, self._master_node, self._all_nodes, - self._potential_master_candidates, - new_node_master_candidate) - def _TearDownTestData(self): os.remove(self._pub_key_file) diff --git a/test/py/ganeti.ssh_unittest.py b/test/py/ganeti.ssh_unittest.py index cd6ca5c..74d55d3 100755 --- a/test/py/ganeti.ssh_unittest.py +++ b/test/py/ganeti.ssh_unittest.py @@ -175,6 +175,19 @@ class TestSshKeys(testutils.GanetiTestCase): " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n" 
"ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n") + def testAddingDuplicateKeys(self): + ssh.AddAuthorizedKey(self.tmpname, + "ssh-dss AAAAB3NzaC1kc3MAAACB root@test") + ssh.AddAuthorizedKeys(self.tmpname, + ["ssh-dss AAAAB3NzaC1kc3MAAACB root@test", + "ssh-dss AAAAB3NzaC1kc3MAAACB root@test"]) + + self.assertFileContent(self.tmpname, + "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n" + 'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"' + " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n" + "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n") + def testAddingAlmostButNotCompletelyTheSameKey(self): ssh.AddAuthorizedKey(self.tmpname, "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test") diff --git a/test/py/ganeti.tools.ssh_update_unittest.py b/test/py/ganeti.tools.ssh_update_unittest.py index af3205a..d4605ad 100755 --- a/test/py/ganeti.tools.ssh_update_unittest.py +++ b/test/py/ganeti.tools.ssh_update_unittest.py @@ -51,10 +51,8 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase): self.assertEqual(user, constants.SSH_LOGIN_USER) return self.tmpdir - def testNoKeys(self): - data_empty_keys = { - constants.SSHS_SSH_AUTHORIZED_KEYS: {}, - } + def testNoop(self): + data_empty_keys = {} for data in [{}, data_empty_keys]: for dry_run in [False, True]: @@ -64,9 +62,9 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase): def testDryRun(self): data = { - constants.SSHS_SSH_AUTHORIZED_KEYS: { + constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_ADD, { "node1" : ["key11", "key12", "key13"], - "node2" : ["key21", "key22"]}, + "node2" : ["key21", "key22"]}), } ssh_update.UpdateAuthorizedKeys(data, True, @@ -74,11 +72,11 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase): self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) self.assertEqual(os.listdir(self.sshdir), []) - def testUpdate(self): + def testAddAndRemove(self): data = { - constants.SSHS_SSH_AUTHORIZED_KEYS: { + constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_ADD, { "node1": ["key11", 
"key12"], - "node2": ["key21"]}, + "node2": ["key21"]}), } ssh_update.UpdateAuthorizedKeys(data, False, @@ -89,6 +87,41 @@ class TestUpdateAuthorizedKeys(testutils.GanetiTestCase): self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "authorized_keys")), "key11\nkey12\nkey21\n") + data = { + constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_REMOVE, { + "node1": ["key12"], + "node2": ["key21"]}), + } + ssh_update.UpdateAuthorizedKeys(data, False, + _homedir_fn=self._GetHomeDir) + self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, + "authorized_keys")), + "key11\n") + + def testAddAndRemoveDuplicates(self): + data = { + constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_ADD, { + "node1": ["key11", "key12"], + "node2": ["key12"]}), + } + + ssh_update.UpdateAuthorizedKeys(data, False, + _homedir_fn=self._GetHomeDir) + self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) + self.assertEqual(sorted(os.listdir(self.sshdir)), + sorted(["authorized_keys"])) + self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, + "authorized_keys")), + "key11\nkey12\nkey12\n") + data = { + constants.SSHS_SSH_AUTHORIZED_KEYS: (constants.SSHS_REMOVE, { + "node1": ["key12"]}), + } + ssh_update.UpdateAuthorizedKeys(data, False, + _homedir_fn=self._GetHomeDir) + self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, + "authorized_keys")), + "key11\n") class TestUpdatePubKeyFile(testutils.GanetiTestCase): @@ -97,9 +130,7 @@ class TestUpdatePubKeyFile(testutils.GanetiTestCase): def testNoKeys(self): pub_key_file = self._CreateTempFile() - data_empty_keys = { - constants.SSHS_SSH_PUBLIC_KEYS: {}, - } + data_empty_keys = {} for data in [{}, data_empty_keys]: for dry_run in [False, True]: @@ -107,16 +138,25 @@ class TestUpdatePubKeyFile(testutils.GanetiTestCase): key_file=pub_key_file) self.assertEqual(utils.ReadFile(pub_key_file), "") - def testValidKeys(self): + def testAddAndRemoveKeys(self): pub_key_file = self._CreateTempFile() data = { - 
constants.SSHS_SSH_PUBLIC_KEYS: { + constants.SSHS_SSH_PUBLIC_KEYS: (constants.SSHS_OVERRIDE, { "node1": ["key11", "key12"], - "node2": ["key21"]}, + "node2": ["key21"]}), } ssh_update.UpdatePubKeyFile(data, False, key_file=pub_key_file) self.assertEqual(utils.ReadFile(pub_key_file), "node1 key11\nnode1 key12\nnode2 key21\n") + data = { + constants.SSHS_SSH_PUBLIC_KEYS: (constants.SSHS_REMOVE, { + "node1": ["key12"], + "node3": ["key21"], + "node4": ["key33"]}), + } + ssh_update.UpdatePubKeyFile(data, False, key_file=pub_key_file) + self.assertEqual(utils.ReadFile(pub_key_file), + "node2 key21\n") if __name__ == "__main__": -- 2.0.0.526.g5318336