patchviasocket: improve error handling. Report a more detailed error if the host's SSH public key file is not found or cannot be opened. In the unit test, use mkstemp and mkdtemp for improved security, and improve resource cleanup in error conditions.
Project: http://git-wip-us.apache.org/repos/asf/cloudstack/repo Commit: http://git-wip-us.apache.org/repos/asf/cloudstack/commit/751d3552 Tree: http://git-wip-us.apache.org/repos/asf/cloudstack/tree/751d3552 Diff: http://git-wip-us.apache.org/repos/asf/cloudstack/diff/751d3552 Branch: refs/heads/master Commit: 751d3552dc3e3514c51fb9038ab91a625470212f Parents: 0acd3c1 Author: Sverrir Berg <[email protected]> Authored: Tue May 17 15:06:35 2016 +0000 Committer: Sverrir Berg <[email protected]> Committed: Fri May 20 15:42:34 2016 +0000 ---------------------------------------------------------------------- scripts/vm/hypervisor/kvm/patchviasocket.py | 20 ++++----- .../vm/hypervisor/kvm/test_patchviasocket.py | 46 ++++++++++---------- 2 files changed, 32 insertions(+), 34 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/cloudstack/blob/751d3552/scripts/vm/hypervisor/kvm/patchviasocket.py ---------------------------------------------------------------------- diff --git a/scripts/vm/hypervisor/kvm/patchviasocket.py b/scripts/vm/hypervisor/kvm/patchviasocket.py index d9616c9..c971d5d 100755 --- a/scripts/vm/hypervisor/kvm/patchviasocket.py +++ b/scripts/vm/hypervisor/kvm/patchviasocket.py @@ -31,22 +31,18 @@ PUB_KEY_FILE = "/root/.ssh/id_rsa.pub.cloud" MESSAGE = "pubkey:{key}\ncmdline:{cmdline}\n" -def read_pub_key(key_file): - try: - if os.path.isfile(key_file): - with open(key_file, "r") as f: - return f.read() - except IOError: - return None - - def send_to_socket(sock_file, key_file, cmdline): - pub_key = read_pub_key(key_file) - - if not pub_key: + if not os.path.exists(key_file): print("ERROR: ssh public key not found on host at {0}".format(key_file)) return 1 + try: + with open(key_file, "r") as f: + pub_key = f.read() + except IOError as e: + print("ERROR: unable to open {0} - {1}".format(key_file, e.strerror)) + return 1 + # Keep old substitution from perl code: cmdline = cmdline.replace("%", 
" ") http://git-wip-us.apache.org/repos/asf/cloudstack/blob/751d3552/scripts/vm/hypervisor/kvm/test_patchviasocket.py ---------------------------------------------------------------------- diff --git a/scripts/vm/hypervisor/kvm/test_patchviasocket.py b/scripts/vm/hypervisor/kvm/test_patchviasocket.py index 074b159..6b411d3 100755 --- a/scripts/vm/hypervisor/kvm/test_patchviasocket.py +++ b/scripts/vm/hypervisor/kvm/test_patchviasocket.py @@ -32,7 +32,7 @@ NON_EXISTING_FILE = "must-not-exist" def write_key_file(): - tmpfile = tempfile.mktemp(".sck") + _, tmpfile = tempfile.mkstemp(".sck") with open(tmpfile, "w") as f: f.write(KEY_DATA) return tmpfile @@ -42,7 +42,8 @@ class SocketThread(threading.Thread): def __init__(self): super(SocketThread, self).__init__() self._data = "" - self._file = tempfile.mktemp(".sck") + self._folder = tempfile.mkdtemp(".sck") + self._file = os.path.join(self._folder, "socket") self._ready = False def data(self): @@ -60,18 +61,21 @@ class SocketThread(threading.Thread): MAX_SIZE = 10 * 1024 s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - s.bind(self._file) - s.listen(1) - s.settimeout(TIMEOUT) try: - self._ready = True - client, address = s.accept() - self._data = client.recv(MAX_SIZE) - client.close() - except socket.timeout: - pass - s.close() - os.remove(self._file) + s.bind(self._file) + s.listen(1) + s.settimeout(TIMEOUT) + try: + self._ready = True + client, address = s.accept() + self._data = client.recv(MAX_SIZE) + client.close() + except socket.timeout: + pass + finally: + s.close() + os.remove(self._file) + os.rmdir(self._folder) class TestPatchViaSocket(unittest.TestCase): @@ -88,15 +92,6 @@ class TestPatchViaSocket(unittest.TestCase): os.remove(self._key_file) os.remove(self._unreadable) - def test_read_file(self): - pub_key = patchviasocket.read_pub_key(self._key_file) - self.assertEqual(KEY_DATA, pub_key) - - def test_read_file_error(self): - self.assertIsNone(patchviasocket.read_pub_key(NON_EXISTING_FILE)) - 
self.assertIsNone(patchviasocket.read_pub_key(self._unreadable)) - self.assertIsNone(patchviasocket.read_pub_key("/tmp")) # folder is not a file - def test_write_to_socket(self): reader = SocketThread() reader.start() @@ -116,6 +111,13 @@ class TestPatchViaSocket(unittest.TestCase): self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), NON_EXISTING_FILE, CMD_DATA)) reader.join() # timeout + def test_host_key_access_denied(self): + reader = SocketThread() + reader.start() + reader.wait_until_ready() + self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), self._unreadable, CMD_DATA)) + reader.join() # timeout + def test_nonexistant_socket_error(self): reader = SocketThread() reader.start()
