This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d404a1438f SSHHook: check if existing connection is still alive (#41061)
d404a1438f is described below
commit d404a1438fb2680fa3e9f48fb10ae5d68fa738f5
Author: Maxim Martynov <[email protected]>
AuthorDate: Wed Aug 21 14:22:12 2024 +0300
SSHHook: check if existing connection is still alive (#41061)
---
airflow/providers/ssh/hooks/ssh.py | 159 ++++++++++++++++-----------------
airflow/providers/ssh/operators/ssh.py | 1 -
tests/providers/ssh/hooks/test_ssh.py | 23 +++++
3 files changed, 102 insertions(+), 81 deletions(-)
diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py
index fac93cf262..e5ce9ec27c 100644
--- a/airflow/providers/ssh/hooks/ssh.py
+++ b/airflow/providers/ssh/hooks/ssh.py
@@ -286,94 +286,93 @@ class SSHHook(BaseHook):
def get_conn(self) -> paramiko.SSHClient:
"""Establish an SSH connection to the remote host."""
- if self.client is None:
- self.log.debug("Creating SSH client for conn_id: %s",
self.ssh_conn_id)
- client = paramiko.SSHClient()
-
- if self.allow_host_key_change:
- self.log.warning(
- "Remote Identification Change is not verified. "
- "This won't protect against Man-In-The-Middle attacks"
- )
- # to avoid BadHostKeyException, skip loading host keys
-
client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
+ if self.client:
+ transport = self.client.get_transport()
+ if transport and transport.is_active():
+ # Return the existing connection
+ return self.client
+
+ self.log.debug("Creating SSH client for conn_id: %s", self.ssh_conn_id)
+ client = paramiko.SSHClient()
+
+ if self.allow_host_key_change:
+ self.log.warning(
+ "Remote Identification Change is not verified. "
+ "This won't protect against Man-In-The-Middle attacks"
+ )
+ # to avoid BadHostKeyException, skip loading host keys
+ client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
+ else:
+ client.load_system_host_keys()
+
+ if self.no_host_key_check:
+ self.log.warning("No Host Key Verification. This won't protect
against Man-In-The-Middle attacks")
+ client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) #
nosec B507
+ # to avoid BadHostKeyException, skip loading and saving host keys
+ known_hosts = os.path.expanduser("~/.ssh/known_hosts")
+ if not self.allow_host_key_change and os.path.isfile(known_hosts):
+ client.load_host_keys(known_hosts)
+
+ elif self.host_key is not None:
+ # Get host key from connection extra if it not set or None then we
fallback to system host keys
+ client_host_keys = client.get_host_keys()
+ if self.port == SSH_PORT:
+ client_host_keys.add(self.remote_host,
self.host_key.get_name(), self.host_key)
else:
- client.load_system_host_keys()
-
- if self.no_host_key_check:
- self.log.warning(
- "No Host Key Verification. This won't protect against
Man-In-The-Middle attacks"
+ client_host_keys.add(
+ f"[{self.remote_host}]:{self.port}",
self.host_key.get_name(), self.host_key
)
- client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
# nosec B507
- # to avoid BadHostKeyException, skip loading and saving host
keys
- known_hosts = os.path.expanduser("~/.ssh/known_hosts")
- if not self.allow_host_key_change and
os.path.isfile(known_hosts):
- client.load_host_keys(known_hosts)
-
- elif self.host_key is not None:
- # Get host key from connection extra if it not set or None
then we fallback to system host keys
- client_host_keys = client.get_host_keys()
- if self.port == SSH_PORT:
- client_host_keys.add(self.remote_host,
self.host_key.get_name(), self.host_key)
- else:
- client_host_keys.add(
- f"[{self.remote_host}]:{self.port}",
self.host_key.get_name(), self.host_key
- )
- connect_kwargs: dict[str, Any] = {
- "hostname": self.remote_host,
- "username": self.username,
- "timeout": self.conn_timeout,
- "compress": self.compress,
- "port": self.port,
- "sock": self.host_proxy,
- "look_for_keys": self.look_for_keys,
- "banner_timeout": self.banner_timeout,
- }
-
- if self.password:
- password = self.password.strip()
- connect_kwargs.update(password=password)
-
- if self.pkey:
- connect_kwargs.update(pkey=self.pkey)
-
- if self.key_file:
- connect_kwargs.update(key_filename=self.key_file)
-
- if self.disabled_algorithms:
-
connect_kwargs.update(disabled_algorithms=self.disabled_algorithms)
-
- def log_before_sleep(retry_state):
- return self.log.info(
- "Failed to connect. Sleeping before retry attempt %d",
retry_state.attempt_number
- )
+ connect_kwargs: dict[str, Any] = {
+ "hostname": self.remote_host,
+ "username": self.username,
+ "timeout": self.conn_timeout,
+ "compress": self.compress,
+ "port": self.port,
+ "sock": self.host_proxy,
+ "look_for_keys": self.look_for_keys,
+ "banner_timeout": self.banner_timeout,
+ }
- for attempt in Retrying(
- reraise=True,
- wait=wait_fixed(3) + wait_random(0, 2),
- stop=stop_after_attempt(3),
- before_sleep=log_before_sleep,
- ):
- with attempt:
- client.connect(**connect_kwargs)
+ if self.password:
+ password = self.password.strip()
+ connect_kwargs.update(password=password)
- if self.keepalive_interval:
- # MyPy check ignored because "paramiko" isn't well-typed. The
`client.get_transport()` returns
- # type "Transport | None" and item "None" has no attribute
"set_keepalive".
- client.get_transport().set_keepalive(self.keepalive_interval)
# type: ignore[union-attr]
+ if self.pkey:
+ connect_kwargs.update(pkey=self.pkey)
- if self.ciphers:
- # MyPy check ignored because "paramiko" isn't well-typed. The
`client.get_transport()` returns
- # type "Transport | None" and item "None" has no method
`get_security_options`".
- client.get_transport().get_security_options().ciphers =
self.ciphers # type: ignore[union-attr]
+ if self.key_file:
+ connect_kwargs.update(key_filename=self.key_file)
- self.client = client
- return client
+ if self.disabled_algorithms:
+ connect_kwargs.update(disabled_algorithms=self.disabled_algorithms)
- else:
- # Return the existing connection
- return self.client
+ def log_before_sleep(retry_state):
+ return self.log.info(
+ "Failed to connect. Sleeping before retry attempt %d",
retry_state.attempt_number
+ )
+
+ for attempt in Retrying(
+ reraise=True,
+ wait=wait_fixed(3) + wait_random(0, 2),
+ stop=stop_after_attempt(3),
+ before_sleep=log_before_sleep,
+ ):
+ with attempt:
+ client.connect(**connect_kwargs)
+
+ if self.keepalive_interval:
+ # MyPy check ignored because "paramiko" isn't well-typed. The
`client.get_transport()` returns
+ # type "Transport | None" and item "None" has no attribute
"set_keepalive".
+ client.get_transport().set_keepalive(self.keepalive_interval) #
type: ignore[union-attr]
+
+ if self.ciphers:
+ # MyPy check ignored because "paramiko" isn't well-typed. The
`client.get_transport()` returns
+ # type "Transport | None" and item "None" has no method
`get_security_options`".
+ client.get_transport().get_security_options().ciphers =
self.ciphers # type: ignore[union-attr]
+
+ self.client = client
+ return client
@deprecated(
reason=(
diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py
index 409078efe7..9847614eaf 100644
--- a/airflow/providers/ssh/operators/ssh.py
+++ b/airflow/providers/ssh/operators/ssh.py
@@ -148,7 +148,6 @@ class SSHOperator(BaseOperator):
def get_ssh_client(self) -> SSHClient:
# Remember to use context manager or call .close() on this when done
- self.log.info("Creating ssh_client")
return self.hook.get_conn()
@deprecated(
diff --git a/tests/providers/ssh/hooks/test_ssh.py
b/tests/providers/ssh/hooks/test_ssh.py
index 24eb275def..71661e5b4e 100644
--- a/tests/providers/ssh/hooks/test_ssh.py
+++ b/tests/providers/ssh/hooks/test_ssh.py
@@ -1092,3 +1092,26 @@ class TestSSHHook:
status, msg = hook.test_connection()
assert status is False
assert msg == "Test failure case"
+
+ def test_ssh_connection_client_is_reused_if_open(self):
+ hook = SSHHook(ssh_conn_id="ssh_default")
+ client1 = hook.get_conn()
+ client2 = hook.get_conn()
+ assert client1 is client2
+ assert client2.get_transport().is_active()
+
+ def test_ssh_connection_client_is_recreated_if_closed(self):
+ hook = SSHHook(ssh_conn_id="ssh_default")
+ client1 = hook.get_conn()
+ client1.close()
+ client2 = hook.get_conn()
+ assert client1 is not client2
+ assert client2.get_transport().is_active()
+
+ def test_ssh_connection_client_is_recreated_if_transport_closed(self):
+ hook = SSHHook(ssh_conn_id="ssh_default")
+ client1 = hook.get_conn()
+ client1.get_transport().close()
+ client2 = hook.get_conn()
+ assert client1 is not client2
+ assert client2.get_transport().is_active()