This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 76628aebdf Consolidate hook management in SSHOperator (#34428)
76628aebdf is described below

commit 76628aebdf5d41daad433565862e523698e86f94
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Sep 18 21:06:12 2023 +0200

    Consolidate hook management in SSHOperator (#34428)
    
    * Consolidate hook management in SSHOperator
    
    * use AirflowProviderDeprecationWarning
---
 airflow/providers/ssh/operators/ssh.py    | 66 ++++++++++++++++---------------
 tests/providers/ssh/operators/test_ssh.py |  7 ++--
 2 files changed, 38 insertions(+), 35 deletions(-)

diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py
index b79202c339..6bb9c8663e 100644
--- a/airflow/providers/ssh/operators/ssh.py
+++ b/airflow/providers/ssh/operators/ssh.py
@@ -19,18 +19,20 @@ from __future__ import annotations
 
 import warnings
 from base64 import b64encode
+from functools import cached_property
 from typing import TYPE_CHECKING, Sequence
 
+from deprecated.classic import deprecated
+
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
+from airflow.providers.ssh.hooks.ssh import SSHHook
 from airflow.utils.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from paramiko.client import SSHClient
 
-    from airflow.providers.ssh.hooks.ssh import SSHHook
-
 
 class SSHOperator(BaseOperator):
     """
@@ -92,7 +94,10 @@ class SSHOperator(BaseOperator):
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        self.ssh_hook = ssh_hook
+        if ssh_hook and isinstance(ssh_hook, SSHHook):
+            self.ssh_hook = ssh_hook
+            if remote_host is not None:
+                self.ssh_hook.remote_host = remote_host
         self.ssh_conn_id = ssh_conn_id
         self.remote_host = remote_host
         self.command = command
@@ -102,38 +107,39 @@ class SSHOperator(BaseOperator):
         self.get_pty = get_pty
         self.banner_timeout = banner_timeout
 
-    def get_hook(self) -> SSHHook:
-        from airflow.providers.ssh.hooks.ssh import SSHHook
-
+    @cached_property
+    def ssh_hook(self) -> SSHHook:
+        """Create SSHHook to run commands on remote host."""
         if self.ssh_conn_id:
-            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
-                self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
-            else:
-                self.log.info("ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook.")
-                self.ssh_hook = SSHHook(
-                    ssh_conn_id=self.ssh_conn_id,
-                    conn_timeout=self.conn_timeout,
-                    cmd_timeout=self.cmd_timeout,
-                    banner_timeout=self.banner_timeout,
+            self.log.info("ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook.")
+            hook = SSHHook(
+                ssh_conn_id=self.ssh_conn_id,
+                conn_timeout=self.conn_timeout,
+                cmd_timeout=self.cmd_timeout,
+                banner_timeout=self.banner_timeout,
+            )
+            if self.remote_host is not None:
+                self.log.info(
+                    "remote_host is provided explicitly. "
+                    "It will replace the remote_host which was defined "
+                    "in ssh_hook or predefined in connection of ssh_conn_id."
                 )
+                hook.remote_host = self.remote_host
+            return hook
+        raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
 
-        if not self.ssh_hook:
-            raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
-
-        if self.remote_host is not None:
-            self.log.info(
-                "remote_host is provided explicitly. "
-                "It will replace the remote_host which was defined "
-                "in ssh_hook or predefined in connection of ssh_conn_id."
-            )
-            self.ssh_hook.remote_host = self.remote_host
+    @property
+    def hook(self) -> SSHHook:
+        return self.ssh_hook
 
+    @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
+    def get_hook(self) -> SSHHook:
         return self.ssh_hook
 
     def get_ssh_client(self) -> SSHClient:
         # Remember to use context manager or call .close() on this when done
         self.log.info("Creating ssh_client")
-        return self.get_hook().get_conn()
+        return self.hook.get_conn()
 
     def exec_ssh_client_command(self, ssh_client: SSHClient, command: str):
         warnings.warn(
@@ -141,8 +147,7 @@ class SSHOperator(BaseOperator):
             "`ssh_hook.exec_ssh_client_command` instead",
             AirflowProviderDeprecationWarning,
         )
-        assert self.ssh_hook
-        return self.ssh_hook.exec_ssh_client_command(
+        return self.hook.exec_ssh_client_command(
             ssh_client, command, timeout=self.cmd_timeout, environment=self.environment, get_pty=self.get_pty
         )
 
@@ -154,8 +159,7 @@ class SSHOperator(BaseOperator):
             raise AirflowException(f"SSH operator error: exit status = {exit_status}")
 
     def run_ssh_client_command(self, ssh_client: SSHClient, command: str, context=None) -> bytes:
-        assert self.ssh_hook
-        exit_status, agg_stdout, agg_stderr = self.ssh_hook.exec_ssh_client_command(
+        exit_status, agg_stdout, agg_stderr = self.hook.exec_ssh_client_command(
             ssh_client, command, timeout=self.cmd_timeout, environment=self.environment, get_pty=self.get_pty
         )
         self.raise_for_status(exit_status, agg_stderr, context=context)
@@ -178,5 +182,5 @@ class SSHOperator(BaseOperator):
 
     def tunnel(self) -> None:
         """Get ssh tunnel."""
-        ssh_client = self.ssh_hook.get_conn()  # type: ignore[union-attr]
+        ssh_client = self.hook.get_conn()  # type: ignore[union-attr]
         ssh_client.get_transport()
diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py
index 401dc38dd5..3c10d71bdc 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -88,10 +88,9 @@ class TestSSHOperator:
                 cmd_timeout=cmd_timeout,
                 ssh_conn_id="ssh_default",
             )
-        ssh_hook = task.get_hook()
-        assert conn_timeout == ssh_hook.conn_timeout
-        assert cmd_timeout_expected == ssh_hook.cmd_timeout
-        assert "ssh_default" == ssh_hook.ssh_conn_id
+        assert conn_timeout == task.hook.conn_timeout
+        assert cmd_timeout_expected == task.hook.cmd_timeout
+        assert "ssh_default" == task.hook.ssh_conn_id
 
     @pytest.mark.parametrize(
         ("enable_xcom_pickling", "output", "expected"),

Reply via email to