This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 73fcbb0  Refactor SSHOperator so a subclass can run many commands (#10874) (#17378)
73fcbb0 is described below

commit 73fcbb0e4e151c9965fd69ba08de59462bbbe6dc
Author: Bjorn Olsen <bjorn.ols...@gmail.com>
AuthorDate: Wed Oct 13 22:14:54 2021 +0200

    Refactor SSHOperator so a subclass can run many commands (#10874) (#17378)
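    
    The change splits execute() into smaller hook points (get_hook,
    get_ssh_client, exec_ssh_client_command, run_ssh_client_command,
    raise_for_status) so a subclass can open one SSH client and reuse it
    for several commands. A minimal sketch of the pattern this enables
    (the subclass name and the command list are illustrative, not part
    of this change):
    
        from airflow.providers.ssh.operators.ssh import SSHOperator
    
        class MultiCommandSSHOperator(SSHOperator):
            """Runs several commands over a single SSH connection."""
    
            def execute(self, context=None):
                commands = ['mkdir -p /tmp/out', 'date > /tmp/out/ts', 'ls /tmp/out']
                results = []
                with self.get_ssh_client() as ssh_client:
                    for command in commands:
                        # run_ssh_client_command raises AirflowException
                        # on a non-zero exit status
                        results.append(self.run_ssh_client_command(ssh_client, command))
                return results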
---
 airflow/providers/ssh/operators/ssh.py    | 204 ++++++++++++++++--------------
 tests/providers/ssh/operators/test_ssh.py | 112 ++++++++++++++--
 2 files changed, 211 insertions(+), 105 deletions(-)

diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py
index 7f97b03..300e155 100644
--- a/airflow/providers/ssh/operators/ssh.py
+++ b/airflow/providers/ssh/operators/ssh.py
@@ -19,7 +19,9 @@
 import warnings
 from base64 import b64encode
 from select import select
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
+
+from paramiko.client import SSHClient
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -107,103 +109,115 @@ class SSHOperator(BaseOperator):
                 stacklevel=1,
             )
 
-    def execute(self, context) -> Union[bytes, str, bool]:
+    def get_hook(self) -> SSHHook:
+        if self.ssh_conn_id:
+            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
+                self.log.info("ssh_conn_id is ignored when ssh_hook is 
provided.")
+            else:
+                self.log.info("ssh_hook is not provided or invalid. Trying 
ssh_conn_id to create SSHHook.")
+                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, 
conn_timeout=self.conn_timeout)
+
+        if not self.ssh_hook:
+            raise AirflowException("Cannot operate without ssh_hook or 
ssh_conn_id.")
+
+        if self.remote_host is not None:
+            self.log.info(
+                "remote_host is provided explicitly. "
+                "It will replace the remote_host which was defined "
+                "in ssh_hook or predefined in connection of ssh_conn_id."
+            )
+            self.ssh_hook.remote_host = self.remote_host
+
+        return self.ssh_hook
+
+    def get_ssh_client(self) -> SSHClient:
+        # Remember to use context manager or call .close() on this when done
+        self.log.info('Creating ssh_client')
+        return self.get_hook().get_conn()
+
+    def exec_ssh_client_command(self, ssh_client: SSHClient, command: str) -> Tuple[int, bytes, bytes]:
+        self.log.info("Running command: %s", command)
+
+        # set timeout taken as params
+        stdin, stdout, stderr = ssh_client.exec_command(
+            command=command,
+            get_pty=self.get_pty,
+            timeout=self.timeout,
+            environment=self.environment,
+        )
+        # get channels
+        channel = stdout.channel
+
+        # closing stdin
+        stdin.close()
+        channel.shutdown_write()
+
+        agg_stdout = b''
+        agg_stderr = b''
+
+        # capture any initial output in case channel is closed already
+        stdout_buffer_length = len(stdout.channel.in_buffer)
+
+        if stdout_buffer_length > 0:
+            agg_stdout += stdout.channel.recv(stdout_buffer_length)
+
+        # read from both stdout and stderr
+        while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
+            readq, _, _ = select([channel], [], [], self.cmd_timeout)
+            for recv in readq:
+                if recv.recv_ready():
+                    line = stdout.channel.recv(len(recv.in_buffer))
+                    agg_stdout += line
+                    self.log.info(line.decode('utf-8', 'replace').strip('\n'))
+                if recv.recv_stderr_ready():
+                    line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
+                    agg_stderr += line
+                    self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
+            if (
+                stdout.channel.exit_status_ready()
+                and not stderr.channel.recv_stderr_ready()
+                and not stdout.channel.recv_ready()
+            ):
+                stdout.channel.shutdown_read()
+                try:
+                    stdout.channel.close()
+                except Exception:
+                    # there is a race that when shutdown_read has been called and when
+                    # you try to close the connection, the socket is already closed
+                    # We should ignore such errors (but we should log them with warning)
+                    self.log.warning("Ignoring exception on close", exc_info=True)
+                break
+
+        stdout.close()
+        stderr.close()
+
+        exit_status = stdout.channel.recv_exit_status()
+
+        return exit_status, agg_stdout, agg_stderr
+
+    def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
+        if exit_status != 0:
+            error_msg = stderr.decode('utf-8')
+            raise AirflowException(f"error running cmd: {self.command}, error: 
{error_msg}")
+
+    def run_ssh_client_command(self, ssh_client: SSHClient, command: str) -> bytes:
+        exit_status, agg_stdout, agg_stderr = self.exec_ssh_client_command(ssh_client, command)
+        self.raise_for_status(exit_status, agg_stderr)
+        return agg_stdout
+
+    def execute(self, context=None) -> Union[bytes, str]:
+        result = None
+        if self.command is None:
+            raise AirflowException("SSH operator error: SSH command not 
specified. Aborting.")
         try:
-            if self.ssh_conn_id:
-                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
-                    self.log.info("ssh_conn_id is ignored when ssh_hook is 
provided.")
-                else:
-                    self.log.info(
-                        "ssh_hook is not provided or invalid. Trying 
ssh_conn_id to create SSHHook."
-                    )
-                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, conn_timeout=self.conn_timeout)
-
-            if not self.ssh_hook:
-                raise AirflowException("Cannot operate without ssh_hook or 
ssh_conn_id.")
-
-            if self.remote_host is not None:
-                self.log.info(
-                    "remote_host is provided explicitly. "
-                    "It will replace the remote_host which was defined "
-                    "in ssh_hook or predefined in connection of ssh_conn_id."
-                )
-                self.ssh_hook.remote_host = self.remote_host
-
-            if not self.command:
-                raise AirflowException("SSH command not specified. Aborting.")
-
-            with self.ssh_hook.get_conn() as ssh_client:
-                self.log.info("Running command: %s", self.command)
-
-                # set timeout taken as params
-                stdin, stdout, stderr = ssh_client.exec_command(
-                    command=self.command,
-                    get_pty=self.get_pty,
-                    timeout=self.cmd_timeout,
-                    environment=self.environment,
-                )
-                # get channels
-                channel = stdout.channel
-
-                # closing stdin
-                stdin.close()
-                channel.shutdown_write()
-
-                agg_stdout = b''
-                agg_stderr = b''
-
-                # capture any initial output in case channel is closed already
-                stdout_buffer_length = len(stdout.channel.in_buffer)
-
-                if stdout_buffer_length > 0:
-                    agg_stdout += stdout.channel.recv(stdout_buffer_length)
-
-                # read from both stdout and stderr
-                while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
-                    readq, _, _ = select([channel], [], [], self.cmd_timeout)
-                    for recv in readq:
-                        if recv.recv_ready():
-                            line = stdout.channel.recv(len(recv.in_buffer))
-                            agg_stdout += line
-                            self.log.info(line.decode('utf-8', 'replace').strip('\n'))
-                        if recv.recv_stderr_ready():
-                            line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
-                            agg_stderr += line
-                            self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
-                    if (
-                        stdout.channel.exit_status_ready()
-                        and not stderr.channel.recv_stderr_ready()
-                        and not stdout.channel.recv_ready()
-                    ):
-                        stdout.channel.shutdown_read()
-                        try:
-                            stdout.channel.close()
-                        except Exception:
-                            # there is a race that when shutdown_read has been called and when
-                            # you try to close the connection, the socket is already closed
-                            # We should ignore such errors (but we should log them with warning)
-                            self.log.warning("Ignoring exception on close", exc_info=True)
-                        break
-
-                stdout.close()
-                stderr.close()
-
-                exit_status = stdout.channel.recv_exit_status()
-                if exit_status == 0:
-                    enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
-                    if enable_pickling:
-                        return agg_stdout
-                    else:
-                        return b64encode(agg_stdout).decode('utf-8')
-
-                else:
-                    error_msg = agg_stderr.decode('utf-8')
-                    raise AirflowException(f"error running cmd: 
{self.command}, error: {error_msg}")
-
+            with self.get_ssh_client() as ssh_client:
+                result = self.run_ssh_client_command(ssh_client, self.command)
         except Exception as e:
             raise AirflowException(f"SSH operator error: {str(e)}")
-
-        return True
+        enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
+        if not enable_pickling:
+            result = b64encode(result).decode('utf-8')
+        return result
 
     def tunnel(self) -> None:
         """Get ssh tunnel"""
diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py
index c477416..551b64b 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -38,18 +38,28 @@ COMMAND = "echo -n airflow"
 COMMAND_WITH_SUDO = "sudo " + COMMAND
 
 
+class SSHClientSideEffect:
+    def __init__(self, hook):
+        self.hook = hook
+
+    def __call__(self):
+        self.return_value = self.hook.get_conn()
+        return self.return_value
+
+
 class TestSSHOperator:
     def setup_method(self):
         from airflow.providers.ssh.hooks.ssh import SSHHook
 
         hook = SSHHook(ssh_conn_id='ssh_default')
         hook.no_host_key_check = True
+        self.dag = DAG('ssh_test', default_args={'start_date': DEFAULT_DATE})
         self.hook = hook
 
     def test_hook_created_correctly_with_timeout(self):
         timeout = 20
         ssh_id = "ssh_default"
-        with DAG('unit_tests_ssh_test_op_arg_checking', default_args={'start_date': DEFAULT_DATE}):
+        with self.dag:
             task = SSHOperator(task_id="test", command=COMMAND, 
timeout=timeout, ssh_conn_id="ssh_default")
         task.execute(None)
         assert timeout == task.ssh_hook.conn_timeout
@@ -59,7 +69,7 @@ class TestSSHOperator:
         conn_timeout = 20
         cmd_timeout = 45
         ssh_id = 'ssh_default'
-        with DAG('unit_tests_ssh_test_op_arg_checking', default_args={'start_date': DEFAULT_DATE}):
+        with self.dag:
             task = SSHOperator(
                 task_id="test",
                 command=COMMAND,
@@ -130,10 +140,8 @@ class TestSSHOperator:
 
     @unittest.mock.patch('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
     def test_arg_checking(self):
-        dag = DAG('unit_tests_ssh_test_op_arg_checking', default_args={'start_date': DEFAULT_DATE})
-
         # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided.
-        task_0 = SSHOperator(task_id="test", command=COMMAND, timeout=TIMEOUT, dag=dag)
+        task_0 = SSHOperator(task_id="test", command=COMMAND, timeout=TIMEOUT, dag=self.dag)
         with pytest.raises(AirflowException, match="Cannot operate without 
ssh_hook or ssh_conn_id."):
             task_0.execute(None)
 
@@ -144,7 +152,7 @@ class TestSSHOperator:
             ssh_conn_id=TEST_CONN_ID,
             command=COMMAND,
             timeout=TIMEOUT,
-            dag=dag,
+            dag=self.dag,
         )
         try:
             task_1.execute(None)
@@ -157,7 +165,7 @@ class TestSSHOperator:
             ssh_conn_id=TEST_CONN_ID,  # No ssh_hook provided.
             command=COMMAND,
             timeout=TIMEOUT,
-            dag=dag,
+            dag=self.dag,
         )
         try:
             task_2.execute(None)
@@ -172,10 +180,26 @@ class TestSSHOperator:
             ssh_conn_id=TEST_CONN_ID,
             command=COMMAND,
             timeout=TIMEOUT,
-            dag=dag,
+            dag=self.dag,
         )
         task_3.execute(None)
         assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
+        # If remote_host was specified, ensure it is used
+        task_4 = SSHOperator(
+            task_id="test_4",
+            ssh_hook=self.hook,
+            ssh_conn_id=TEST_CONN_ID,
+            command=COMMAND,
+            timeout=TIMEOUT,
+            dag=self.dag,
+            remote_host='operator_remote_host',
+        )
+        try:
+            task_4.execute(None)
+        except Exception:
+            pass
+        assert task_4.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
+        assert task_4.ssh_hook.remote_host == 'operator_remote_host'
 
     @pytest.mark.parametrize(
         "command, get_pty_in, get_pty_out",
@@ -188,7 +212,6 @@ class TestSSHOperator:
         ],
     )
     def test_get_pyt_set_correctly(self, command, get_pty_in, get_pty_out):
-        dag = DAG('unit_tests_ssh_test_op_arg_checking', default_args={'start_date': DEFAULT_DATE})
         task = SSHOperator(
             task_id="test",
             ssh_hook=self.hook,
@@ -196,7 +219,7 @@ class TestSSHOperator:
             conn_timeout=TIMEOUT,
             cmd_timeout=TIMEOUT,
             get_pty=get_pty_in,
-            dag=dag,
+            dag=self.dag,
         )
         if command is None:
             with pytest.raises(AirflowException) as ctx:
@@ -205,3 +228,72 @@ class TestSSHOperator:
         else:
             task.execute(None)
         assert task.get_pty == get_pty_out
+
+    def test_ssh_client_managed_correctly(self):
+        # Ensure ssh_client gets created once
+        # Ensure connection gets closed once
+        task = SSHOperator(
+            task_id="test",
+            ssh_hook=self.hook,
+            command="ls",
+            dag=self.dag,
+        )
+
+        se = SSHClientSideEffect(self.hook)
+        with unittest.mock.patch.object(task, 'get_ssh_client') as mock_get, unittest.mock.patch(
+            'paramiko.client.SSHClient.close'
+        ) as mock_close:
+            mock_get.side_effect = se
+            task.execute()
+            mock_get.assert_called_once()
+            mock_close.assert_called_once()
+
+    def test_one_ssh_client_many_commands(self):
+        # Ensure we can run multiple commands with one client
+        many_commands = ['ls', 'date', 'pwd']
+
+        class CustomSSHOperator(SSHOperator):
+            def execute(self, context=None):
+                success = False
+                with self.get_ssh_client() as ssh_client:
+                    for c in many_commands:
+                        self.run_ssh_client_command(ssh_client, c)
+                    success = True
+                return success
+
+        task = CustomSSHOperator(task_id="test", ssh_hook=self.hook, dag=self.dag)
+        se = SSHClientSideEffect(self.hook)
+        with unittest.mock.patch.object(task, 'get_ssh_client') as mock_get, unittest.mock.patch.object(
+            task, 'run_ssh_client_command'
+        ) as mock_run_cmd, unittest.mock.patch('paramiko.client.SSHClient.close') as mock_close:
+            mock_get.side_effect = se
+            task.execute()
+            mock_get.assert_called_once()
+            mock_close.assert_called_once()
+
+            ssh_client = se.return_value
+            calls = [unittest.mock.call(ssh_client, c) for c in many_commands]
+            mock_run_cmd.assert_has_calls(calls)
+
+    def test_fail_with_no_command(self):
+        # Test that run_ssh_client_command fails on no command
+        task = SSHOperator(
+            task_id="test",
+            ssh_hook=self.hook,
+            # command="ls",
+            dag=self.dag,
+        )
+        with pytest.raises(AirflowException, match="SSH command not specified. 
Aborting."):
+            task.execute(None)
+
+    def test_command_errored(self):
+        # Test that run_ssh_client_command works on invalid commands
+        command = "not_a_real_command"
+        task = SSHOperator(
+            task_id="test",
+            ssh_hook=self.hook,
+            command=command,
+            dag=self.dag,
+        )
+        with pytest.raises(AirflowException, match=f"error running cmd: 
{command}, error: .*"):
+            task.execute(None)
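
Exit-status handling is likewise its own hook point now (raise_for_status),
so a subclass can relax it without re-implementing the I/O loop. A sketch
under the same caveats (the set of allowed exit codes is illustrative):

    from airflow.exceptions import AirflowException
    from airflow.providers.ssh.operators.ssh import SSHOperator

    class LenientSSHOperator(SSHOperator):
        """Treats selected non-zero exit statuses as success."""

        allowed_exit_statuses = {0, 1}

        def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
            if exit_status not in self.allowed_exit_statuses:
                error_msg = stderr.decode('utf-8')
                raise AirflowException(f"error running cmd: {self.command}, error: {error_msg}")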
