This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 76628aebdf Consolidate hook management in SSHOperator (#34428)
76628aebdf is described below
commit 76628aebdf5d41daad433565862e523698e86f94
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Sep 18 21:06:12 2023 +0200
Consolidate hook management in SSHOperator (#34428)
* Consolidate hook management in SSHOperator
* use AirflowProviderDeprecationWarning
---
airflow/providers/ssh/operators/ssh.py | 66 ++++++++++++++++---------------
tests/providers/ssh/operators/test_ssh.py | 7 ++--
2 files changed, 38 insertions(+), 35 deletions(-)
diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py
index b79202c339..6bb9c8663e 100644
--- a/airflow/providers/ssh/operators/ssh.py
+++ b/airflow/providers/ssh/operators/ssh.py
@@ -19,18 +19,20 @@ from __future__ import annotations
import warnings
from base64 import b64encode
+from functools import cached_property
from typing import TYPE_CHECKING, Sequence
+from deprecated.classic import deprecated
+
from airflow.configuration import conf
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
+from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
from paramiko.client import SSHClient
- from airflow.providers.ssh.hooks.ssh import SSHHook
-
class SSHOperator(BaseOperator):
"""
@@ -92,7 +94,10 @@ class SSHOperator(BaseOperator):
**kwargs,
) -> None:
super().__init__(**kwargs)
- self.ssh_hook = ssh_hook
+ if ssh_hook and isinstance(ssh_hook, SSHHook):
+ self.ssh_hook = ssh_hook
+ if remote_host is not None:
+ self.ssh_hook.remote_host = remote_host
self.ssh_conn_id = ssh_conn_id
self.remote_host = remote_host
self.command = command
@@ -102,38 +107,39 @@ class SSHOperator(BaseOperator):
self.get_pty = get_pty
self.banner_timeout = banner_timeout
- def get_hook(self) -> SSHHook:
- from airflow.providers.ssh.hooks.ssh import SSHHook
-
+ @cached_property
+ def ssh_hook(self) -> SSHHook:
+ """Create SSHHook to run commands on remote host."""
if self.ssh_conn_id:
- if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
- self.log.info("ssh_conn_id is ignored when ssh_hook is
provided.")
- else:
- self.log.info("ssh_hook is not provided or invalid. Trying
ssh_conn_id to create SSHHook.")
- self.ssh_hook = SSHHook(
- ssh_conn_id=self.ssh_conn_id,
- conn_timeout=self.conn_timeout,
- cmd_timeout=self.cmd_timeout,
- banner_timeout=self.banner_timeout,
+ self.log.info("ssh_hook is not provided or invalid. Trying
ssh_conn_id to create SSHHook.")
+ hook = SSHHook(
+ ssh_conn_id=self.ssh_conn_id,
+ conn_timeout=self.conn_timeout,
+ cmd_timeout=self.cmd_timeout,
+ banner_timeout=self.banner_timeout,
+ )
+ if self.remote_host is not None:
+ self.log.info(
+ "remote_host is provided explicitly. "
+ "It will replace the remote_host which was defined "
+ "in ssh_hook or predefined in connection of ssh_conn_id."
)
+ hook.remote_host = self.remote_host
+ return hook
+ raise AirflowException("Cannot operate without ssh_hook or
ssh_conn_id.")
- if not self.ssh_hook:
- raise AirflowException("Cannot operate without ssh_hook or
ssh_conn_id.")
-
- if self.remote_host is not None:
- self.log.info(
- "remote_host is provided explicitly. "
- "It will replace the remote_host which was defined "
- "in ssh_hook or predefined in connection of ssh_conn_id."
- )
- self.ssh_hook.remote_host = self.remote_host
+ @property
+ def hook(self) -> SSHHook:
+ return self.ssh_hook
+ @deprecated(reason="use `hook` property instead.",
category=AirflowProviderDeprecationWarning)
+ def get_hook(self) -> SSHHook:
return self.ssh_hook
def get_ssh_client(self) -> SSHClient:
# Remember to use context manager or call .close() on this when done
self.log.info("Creating ssh_client")
- return self.get_hook().get_conn()
+ return self.hook.get_conn()
def exec_ssh_client_command(self, ssh_client: SSHClient, command: str):
warnings.warn(
@@ -141,8 +147,7 @@ class SSHOperator(BaseOperator):
"`ssh_hook.exec_ssh_client_command` instead",
AirflowProviderDeprecationWarning,
)
- assert self.ssh_hook
- return self.ssh_hook.exec_ssh_client_command(
+ return self.hook.exec_ssh_client_command(
ssh_client, command, timeout=self.cmd_timeout,
environment=self.environment, get_pty=self.get_pty
)
@@ -154,8 +159,7 @@ class SSHOperator(BaseOperator):
raise AirflowException(f"SSH operator error: exit status =
{exit_status}")
def run_ssh_client_command(self, ssh_client: SSHClient, command: str,
context=None) -> bytes:
- assert self.ssh_hook
- exit_status, agg_stdout, agg_stderr =
self.ssh_hook.exec_ssh_client_command(
+ exit_status, agg_stdout, agg_stderr =
self.hook.exec_ssh_client_command(
ssh_client, command, timeout=self.cmd_timeout,
environment=self.environment, get_pty=self.get_pty
)
self.raise_for_status(exit_status, agg_stderr, context=context)
@@ -178,5 +182,5 @@ class SSHOperator(BaseOperator):
def tunnel(self) -> None:
"""Get ssh tunnel."""
- ssh_client = self.ssh_hook.get_conn() # type: ignore[union-attr]
+ ssh_client = self.hook.get_conn() # type: ignore[union-attr]
ssh_client.get_transport()
diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py
index 401dc38dd5..3c10d71bdc 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -88,10 +88,9 @@ class TestSSHOperator:
cmd_timeout=cmd_timeout,
ssh_conn_id="ssh_default",
)
- ssh_hook = task.get_hook()
- assert conn_timeout == ssh_hook.conn_timeout
- assert cmd_timeout_expected == ssh_hook.cmd_timeout
- assert "ssh_default" == ssh_hook.ssh_conn_id
+ assert conn_timeout == task.hook.conn_timeout
+ assert cmd_timeout_expected == task.hook.cmd_timeout
+ assert "ssh_default" == task.hook.ssh_conn_id
@pytest.mark.parametrize(
("enable_xcom_pickling", "output", "expected"),