This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d404a1438f SSHHook: check if existing connection is still alive (#41061)
d404a1438f is described below

commit d404a1438fb2680fa3e9f48fb10ae5d68fa738f5
Author: Maxim Martynov <[email protected]>
AuthorDate: Wed Aug 21 14:22:12 2024 +0300

    SSHHook: check if existing connection is still alive (#41061)
---
 airflow/providers/ssh/hooks/ssh.py     | 159 ++++++++++++++++-----------------
 airflow/providers/ssh/operators/ssh.py |   1 -
 tests/providers/ssh/hooks/test_ssh.py  |  23 +++++
 3 files changed, 102 insertions(+), 81 deletions(-)

diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py
index fac93cf262..e5ce9ec27c 100644
--- a/airflow/providers/ssh/hooks/ssh.py
+++ b/airflow/providers/ssh/hooks/ssh.py
@@ -286,94 +286,93 @@ class SSHHook(BaseHook):
 
     def get_conn(self) -> paramiko.SSHClient:
         """Establish an SSH connection to the remote host."""
-        if self.client is None:
-            self.log.debug("Creating SSH client for conn_id: %s", self.ssh_conn_id)
-            client = paramiko.SSHClient()
-
-            if self.allow_host_key_change:
-                self.log.warning(
-                    "Remote Identification Change is not verified. "
-                    "This won't protect against Man-In-The-Middle attacks"
-                )
-                # to avoid BadHostKeyException, skip loading host keys
-                client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
+        if self.client:
+            transport = self.client.get_transport()
+            if transport and transport.is_active():
+                # Return the existing connection
+                return self.client
+
+        self.log.debug("Creating SSH client for conn_id: %s", self.ssh_conn_id)
+        client = paramiko.SSHClient()
+
+        if self.allow_host_key_change:
+            self.log.warning(
+                "Remote Identification Change is not verified. "
+                "This won't protect against Man-In-The-Middle attacks"
+            )
+            # to avoid BadHostKeyException, skip loading host keys
+            client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
+        else:
+            client.load_system_host_keys()
+
+        if self.no_host_key_check:
+            self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks")
+            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())  # nosec B507
+            # to avoid BadHostKeyException, skip loading and saving host keys
+            known_hosts = os.path.expanduser("~/.ssh/known_hosts")
+            if not self.allow_host_key_change and os.path.isfile(known_hosts):
+                client.load_host_keys(known_hosts)
+
+        elif self.host_key is not None:
+            # Get host key from connection extra if it not set or None then we fallback to system host keys
+            client_host_keys = client.get_host_keys()
+            if self.port == SSH_PORT:
+                client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
             else:
-                client.load_system_host_keys()
-
-            if self.no_host_key_check:
-                self.log.warning(
-                    "No Host Key Verification. This won't protect against Man-In-The-Middle attacks"
+                client_host_keys.add(
+                    f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
                 )
-                client.set_missing_host_key_policy(paramiko.AutoAddPolicy())  # nosec B507
-                # to avoid BadHostKeyException, skip loading and saving host keys
-                known_hosts = os.path.expanduser("~/.ssh/known_hosts")
-                if not self.allow_host_key_change and os.path.isfile(known_hosts):
-                    client.load_host_keys(known_hosts)
-
-            elif self.host_key is not None:
-                # Get host key from connection extra if it not set or None then we fallback to system host keys
-                client_host_keys = client.get_host_keys()
-                if self.port == SSH_PORT:
-                    client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
-                else:
-                    client_host_keys.add(
-                        f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
-                    )
 
-            connect_kwargs: dict[str, Any] = {
-                "hostname": self.remote_host,
-                "username": self.username,
-                "timeout": self.conn_timeout,
-                "compress": self.compress,
-                "port": self.port,
-                "sock": self.host_proxy,
-                "look_for_keys": self.look_for_keys,
-                "banner_timeout": self.banner_timeout,
-            }
-
-            if self.password:
-                password = self.password.strip()
-                connect_kwargs.update(password=password)
-
-            if self.pkey:
-                connect_kwargs.update(pkey=self.pkey)
-
-            if self.key_file:
-                connect_kwargs.update(key_filename=self.key_file)
-
-            if self.disabled_algorithms:
-                connect_kwargs.update(disabled_algorithms=self.disabled_algorithms)
-
-            def log_before_sleep(retry_state):
-                return self.log.info(
-                    "Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number
-                )
+        connect_kwargs: dict[str, Any] = {
+            "hostname": self.remote_host,
+            "username": self.username,
+            "timeout": self.conn_timeout,
+            "compress": self.compress,
+            "port": self.port,
+            "sock": self.host_proxy,
+            "look_for_keys": self.look_for_keys,
+            "banner_timeout": self.banner_timeout,
+        }
 
-            for attempt in Retrying(
-                reraise=True,
-                wait=wait_fixed(3) + wait_random(0, 2),
-                stop=stop_after_attempt(3),
-                before_sleep=log_before_sleep,
-            ):
-                with attempt:
-                    client.connect(**connect_kwargs)
+        if self.password:
+            password = self.password.strip()
+            connect_kwargs.update(password=password)
 
-            if self.keepalive_interval:
-                # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
-                # type "Transport | None" and item "None" has no attribute "set_keepalive".
-                client.get_transport().set_keepalive(self.keepalive_interval)  # type: ignore[union-attr]
+        if self.pkey:
+            connect_kwargs.update(pkey=self.pkey)
 
-            if self.ciphers:
-                # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
-                # type "Transport | None" and item "None" has no method `get_security_options`".
-                client.get_transport().get_security_options().ciphers = self.ciphers  # type: ignore[union-attr]
+        if self.key_file:
+            connect_kwargs.update(key_filename=self.key_file)
 
-            self.client = client
-            return client
+        if self.disabled_algorithms:
+            connect_kwargs.update(disabled_algorithms=self.disabled_algorithms)
 
-        else:
-            # Return the existing connection
-            return self.client
+        def log_before_sleep(retry_state):
+            return self.log.info(
+                "Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number
+            )
+
+        for attempt in Retrying(
+            reraise=True,
+            wait=wait_fixed(3) + wait_random(0, 2),
+            stop=stop_after_attempt(3),
+            before_sleep=log_before_sleep,
+        ):
+            with attempt:
+                client.connect(**connect_kwargs)
+
+        if self.keepalive_interval:
+            # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
+            # type "Transport | None" and item "None" has no attribute "set_keepalive".
+            client.get_transport().set_keepalive(self.keepalive_interval)  # type: ignore[union-attr]
+
+        if self.ciphers:
+            # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
+            # type "Transport | None" and item "None" has no method `get_security_options`".
+            client.get_transport().get_security_options().ciphers = self.ciphers  # type: ignore[union-attr]
+
+        self.client = client
+        return client
 
     @deprecated(
         reason=(
diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py
index 409078efe7..9847614eaf 100644
--- a/airflow/providers/ssh/operators/ssh.py
+++ b/airflow/providers/ssh/operators/ssh.py
@@ -148,7 +148,6 @@ class SSHOperator(BaseOperator):
 
     def get_ssh_client(self) -> SSHClient:
         # Remember to use context manager or call .close() on this when done
-        self.log.info("Creating ssh_client")
         return self.hook.get_conn()
 
     @deprecated(
diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py
index 24eb275def..71661e5b4e 100644
--- a/tests/providers/ssh/hooks/test_ssh.py
+++ b/tests/providers/ssh/hooks/test_ssh.py
@@ -1092,3 +1092,26 @@ class TestSSHHook:
         status, msg = hook.test_connection()
         assert status is False
         assert msg == "Test failure case"
+
+    def test_ssh_connection_client_is_reused_if_open(self):
+        hook = SSHHook(ssh_conn_id="ssh_default")
+        client1 = hook.get_conn()
+        client2 = hook.get_conn()
+        assert client1 is client2
+        assert client2.get_transport().is_active()
+
+    def test_ssh_connection_client_is_recreated_if_closed(self):
+        hook = SSHHook(ssh_conn_id="ssh_default")
+        client1 = hook.get_conn()
+        client1.close()
+        client2 = hook.get_conn()
+        assert client1 is not client2
+        assert client2.get_transport().is_active()
+
+    def test_ssh_connection_client_is_recreated_if_transport_closed(self):
+        hook = SSHHook(ssh_conn_id="ssh_default")
+        client1 = hook.get_conn()
+        client1.get_transport().close()
+        client2 = hook.get_conn()
+        assert client1 is not client2
+        assert client2.get_transport().is_active()

Reply via email to