Ask the user to confirm the target's SSH host key and, if confirmed,
insert it into the ~/.ssh/known_hosts file so that rsync will pick it up.

Open question: it is not clear how the user is expected to know whether the displayed host key is correct.

Signed-off-by: Ben Lipton <[email protected]>
---
 p2v-transfer/p2v_transfer.py           |   23 +++++++++++++++++++++--
 p2v-transfer/test/p2v_transfer_test.py |    9 +++++++--
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/p2v-transfer/p2v_transfer.py b/p2v-transfer/p2v_transfer.py
index e232103..b03f92f 100755
--- a/p2v-transfer/p2v_transfer.py
+++ b/p2v-transfer/p2v_transfer.py
@@ -28,6 +28,7 @@ necessary to gain access to the bootstrap OS.
 """
 
 
+import binascii
 import re
 import stat
 import sys
@@ -47,6 +48,18 @@ class P2VError(Exception):
   pass
 
 
+class AskAddPolicy(paramiko.AutoAddPolicy):
+  """Add a new host key if the user accepts it; else raise SSHException."""
+  def missing_host_key(self, client, hostname, key):
+    hex_fp = binascii.hexlify(key.get_fingerprint())
+    # Colon-separated MD5 hex, as `ssh-keygen -l -E md5` prints it.
+    print "Target has ssh host key fingerprint %s" % ":".join(
+        hex_fp[i:i + 2] for i in range(0, len(hex_fp), 2))
+    if raw_input("Is this correct? y/N: ").strip().lower() not in ("y", "yes"):
+      raise paramiko.SSHException("Incorrect host key for %s" % hostname)
+    super(AskAddPolicy, self).missing_host_key(client, hostname, key)
+
+
 def ParseOptions(argv):
   usage = "Usage: %prog [options] root_dev target_host private_key"
 
@@ -123,8 +136,14 @@ def EstablishConnection(user, host, key):
   DisplayCommandStart("Connecting to instance...")
 
   client = paramiko.SSHClient()
-  client.set_missing_host_key_policy(paramiko.WarningPolicy())
-  client.load_system_host_keys()
+  client.set_missing_host_key_policy(AskAddPolicy())
+  known_hosts_filename = os.path.expanduser("~/.ssh/known_hosts")
+  try:
+    # Load from the known_hosts file. Additional keys will be saved back there.
+    client.load_host_keys(known_hosts_filename)
+  except IOError:
+    pass
+
   try:
     client.connect(host, username=user, pkey=key,
                    allow_agent=False, look_for_keys=False)
diff --git a/p2v-transfer/test/p2v_transfer_test.py b/p2v-transfer/test/p2v_transfer_test.py
index 4de7983..1b372a8 100755
--- a/p2v-transfer/test/p2v_transfer_test.py
+++ b/p2v-transfer/test/p2v_transfer_test.py
@@ -414,9 +414,14 @@ EOF
   def testEstablishConnectionCreatesClient(self):
     self.mox.StubOutWithMock(self.module.paramiko, "SSHClient",
                              use_mock_anything=True)
+    self.mox.StubOutWithMock(self.module.os.path, 'expanduser')
+
+    known_hosts = "/home/%s/.ssh/known_hosts" % self.user
+
     self.module.paramiko.SSHClient().AndReturn(self.client)
-    self.client.set_missing_host_key_policy(mox.IsA(paramiko.WarningPolicy))
-    self.client.load_system_host_keys()
+    self.module.os.path.expanduser("~/.ssh/known_hosts").AndReturn(known_hosts)
+    self.client.set_missing_host_key_policy(mox.IsA(self.module.AskAddPolicy))
+    self.client.load_host_keys(known_hosts)
     self.client.connect(self.host, username=self.user, pkey=self.pkey,
                         allow_agent=False, look_for_keys=False)
 
-- 
1.7.3.1

Reply via email to