This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 73fcbb0  Refactor SSHOperator so a subclass can run many commands (#10874) (#17378)
73fcbb0 is described below

commit 73fcbb0e4e151c9965fd69ba08de59462bbbe6dc
Author: Bjorn Olsen <bjorn.ols...@gmail.com>
AuthorDate: Wed Oct 13 22:14:54 2021 +0200

    Refactor SSHOperator so a subclass can run many commands (#10874) (#17378)
---
 airflow/providers/ssh/operators/ssh.py    | 204 ++++++++++++++++--------------
 tests/providers/ssh/operators/test_ssh.py | 112 ++++++++++++++--
 2 files changed, 211 insertions(+), 105 deletions(-)

diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py
index 7f97b03..300e155 100644
--- a/airflow/providers/ssh/operators/ssh.py
+++ b/airflow/providers/ssh/operators/ssh.py
@@ -19,7 +19,9 @@
 import warnings
 from base64 import b64encode
 from select import select
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
+
+from paramiko.client import SSHClient
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -107,103 +109,115 @@ class SSHOperator(BaseOperator):
                 stacklevel=1,
             )
 
-    def execute(self, context) -> Union[bytes, str, bool]:
+    def get_hook(self) -> SSHHook:
+        if self.ssh_conn_id:
+            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
+                self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
+            else:
+                self.log.info("ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook.")
+                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, conn_timeout=self.conn_timeout)
+
+        if not self.ssh_hook:
+            raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
+
+        if self.remote_host is not None:
+            self.log.info(
+                "remote_host is provided explicitly. "
+                "It will replace the remote_host which was defined "
+                "in ssh_hook or predefined in connection of ssh_conn_id."
+            )
+            self.ssh_hook.remote_host = self.remote_host
+
+        return self.ssh_hook
+
+    def get_ssh_client(self) -> SSHClient:
+        # Remember to use context manager or call .close() on this when done
+        self.log.info('Creating ssh_client')
+        return self.get_hook().get_conn()
+
+    def exec_ssh_client_command(self, ssh_client: SSHClient, command: str) -> Tuple[int, bytes, bytes]:
+        self.log.info("Running command: %s", command)
+
+        # set timeout taken as params
+        stdin, stdout, stderr = ssh_client.exec_command(
+            command=command,
+            get_pty=self.get_pty,
+            timeout=self.timeout,
+            environment=self.environment,
+        )
+        # get channels
+        channel = stdout.channel
+
+        # closing stdin
+        stdin.close()
+        channel.shutdown_write()
+
+        agg_stdout = b''
+        agg_stderr = b''
+
+        # capture any initial output in case channel is closed already
+        stdout_buffer_length = len(stdout.channel.in_buffer)
+
+        if stdout_buffer_length > 0:
+            agg_stdout += stdout.channel.recv(stdout_buffer_length)
+
+        # read from both stdout and stderr
+        while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
+            readq, _, _ = select([channel], [], [], self.cmd_timeout)
+            for recv in readq:
+                if recv.recv_ready():
+                    line = stdout.channel.recv(len(recv.in_buffer))
+                    agg_stdout += line
+                    self.log.info(line.decode('utf-8', 'replace').strip('\n'))
+                if recv.recv_stderr_ready():
+                    line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
+                    agg_stderr += line
+                    self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
+            if (
+                stdout.channel.exit_status_ready()
+                and not stderr.channel.recv_stderr_ready()
+                and not stdout.channel.recv_ready()
+            ):
+                stdout.channel.shutdown_read()
+                try:
+                    stdout.channel.close()
+                except Exception:
+                    # there is a race that when shutdown_read has been called and when
+                    # you try to close the connection, the socket is already closed
+                    # We should ignore such errors (but we should log them with warning)
+                    self.log.warning("Ignoring exception on close", exc_info=True)
+                break
+
+        stdout.close()
+        stderr.close()
+
+        exit_status = stdout.channel.recv_exit_status()
+
+        return exit_status, agg_stdout, agg_stderr
+
+    def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
+        if exit_status != 0:
+            error_msg = stderr.decode('utf-8')
+            raise AirflowException(f"error running cmd: {self.command}, error: {error_msg}")
+
+    def run_ssh_client_command(self, ssh_client: SSHClient, command: str) -> bytes:
+        exit_status, agg_stdout, agg_stderr = self.exec_ssh_client_command(ssh_client, command)
+        self.raise_for_status(exit_status, agg_stderr)
+        return agg_stdout
+
+    def execute(self, context=None) -> Union[bytes, str]:
+        result = None
+        if self.command is None:
+            raise AirflowException("SSH operator error: SSH command not specified. Aborting.")
         try:
-            if self.ssh_conn_id:
-                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
-                    self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
-                else:
-                    self.log.info(
-                        "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook."
-                    )
-                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, conn_timeout=self.conn_timeout)
-
-            if not self.ssh_hook:
-                raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
-
-            if self.remote_host is not None:
-                self.log.info(
-                    "remote_host is provided explicitly. "
-                    "It will replace the remote_host which was defined "
-                    "in ssh_hook or predefined in connection of ssh_conn_id."
-                )
-                self.ssh_hook.remote_host = self.remote_host
-
-            if not self.command:
-                raise AirflowException("SSH command not specified. Aborting.")
-
-            with self.ssh_hook.get_conn() as ssh_client:
-                self.log.info("Running command: %s", self.command)
-
-                # set timeout taken as params
-                stdin, stdout, stderr = ssh_client.exec_command(
-                    command=self.command,
-                    get_pty=self.get_pty,
-                    timeout=self.cmd_timeout,
-                    environment=self.environment,
-                )
-                # get channels
-                channel = stdout.channel
-
-                # closing stdin
-                stdin.close()
-                channel.shutdown_write()
-
-                agg_stdout = b''
-                agg_stderr = b''
-
-                # capture any initial output in case channel is closed already
-                stdout_buffer_length = len(stdout.channel.in_buffer)
-
-                if stdout_buffer_length > 0:
-                    agg_stdout += stdout.channel.recv(stdout_buffer_length)
-
-                # read from both stdout and stderr
-                while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
-                    readq, _, _ = select([channel], [], [], self.cmd_timeout)
-                    for recv in readq:
-                        if recv.recv_ready():
-                            line = stdout.channel.recv(len(recv.in_buffer))
-                            agg_stdout += line
-                            self.log.info(line.decode('utf-8', 'replace').strip('\n'))
-                        if recv.recv_stderr_ready():
-                            line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
-                            agg_stderr += line
-                            self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
-                    if (
-                        stdout.channel.exit_status_ready()
-                        and not stderr.channel.recv_stderr_ready()
-                        and not stdout.channel.recv_ready()
-                    ):
-                        stdout.channel.shutdown_read()
-                        try:
-                            stdout.channel.close()
-                        except Exception:
-                            # there is a race that when shutdown_read has been called and when
-                            # you try to close the connection, the socket is already closed
-                            # We should ignore such errors (but we should log them with warning)
-                            self.log.warning("Ignoring exception on close", exc_info=True)
-                        break
-
-                stdout.close()
-                stderr.close()
-
-                exit_status = stdout.channel.recv_exit_status()
-                if exit_status == 0:
-                    enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
-                    if enable_pickling:
-                        return agg_stdout
-                    else:
-                        return b64encode(agg_stdout).decode('utf-8')
-
-                else:
-                    error_msg = agg_stderr.decode('utf-8')
-                    raise AirflowException(f"error running cmd: {self.command}, error: {error_msg}")
-
+            with self.get_ssh_client() as ssh_client:
+                result = self.run_ssh_client_command(ssh_client, self.command)
         except Exception as e:
             raise AirflowException(f"SSH operator error: {str(e)}")
-
-        return True
+        enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
+        if not enable_pickling:
+            result = b64encode(result).decode('utf-8')
+        return result
 
     def tunnel(self) -> None:
         """Get ssh tunnel"""
diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py
index c477416..551b64b 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -38,18 +38,28 @@ COMMAND = "echo -n airflow"
 COMMAND_WITH_SUDO = "sudo " + COMMAND
 
 
+class SSHClientSideEffect:
+    def __init__(self, hook):
+        self.hook = hook
+
+    def __call__(self):
+        self.return_value = self.hook.get_conn()
+        return self.return_value
+
+
 class TestSSHOperator:
     def setup_method(self):
         from airflow.providers.ssh.hooks.ssh import SSHHook
 
         hook = SSHHook(ssh_conn_id='ssh_default')
         hook.no_host_key_check = True
+        self.dag = DAG('ssh_test', default_args={'start_date': DEFAULT_DATE})
         self.hook = hook
 
     def test_hook_created_correctly_with_timeout(self):
         timeout = 20
         ssh_id = "ssh_default"
-        with DAG('unit_tests_ssh_test_op_arg_checking', default_args={'start_date': DEFAULT_DATE}):
+        with self.dag:
             task = SSHOperator(task_id="test", command=COMMAND, timeout=timeout, ssh_conn_id="ssh_default")
             task.execute(None)
         assert timeout == task.ssh_hook.conn_timeout
@@ -59,7 +69,7 @@ class TestSSHOperator:
         conn_timeout = 20
         cmd_timeout = 45
         ssh_id = 'ssh_default'
-        with DAG('unit_tests_ssh_test_op_arg_checking', default_args={'start_date': DEFAULT_DATE}):
+        with self.dag:
             task = SSHOperator(
                 task_id="test",
                 command=COMMAND,
@@ -130,10 +140,8 @@ class TestSSHOperator:
 
     @unittest.mock.patch('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
     def test_arg_checking(self):
-        dag = DAG('unit_tests_ssh_test_op_arg_checking', default_args={'start_date': DEFAULT_DATE})
-
         # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided.
-        task_0 = SSHOperator(task_id="test", command=COMMAND, timeout=TIMEOUT, dag=dag)
+        task_0 = SSHOperator(task_id="test", command=COMMAND, timeout=TIMEOUT, dag=self.dag)
         with pytest.raises(AirflowException, match="Cannot operate without ssh_hook or ssh_conn_id."):
             task_0.execute(None)
 
@@ -144,7 +152,7 @@ class TestSSHOperator:
             ssh_conn_id=TEST_CONN_ID,
             command=COMMAND,
             timeout=TIMEOUT,
-            dag=dag,
+            dag=self.dag,
         )
         try:
             task_1.execute(None)
@@ -157,7 +165,7 @@ class TestSSHOperator:
             ssh_conn_id=TEST_CONN_ID,  # No ssh_hook provided.
             command=COMMAND,
             timeout=TIMEOUT,
-            dag=dag,
+            dag=self.dag,
         )
         try:
             task_2.execute(None)
@@ -172,10 +180,26 @@ class TestSSHOperator:
             ssh_conn_id=TEST_CONN_ID,
             command=COMMAND,
             timeout=TIMEOUT,
-            dag=dag,
+            dag=self.dag,
         )
         task_3.execute(None)
         assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
 
+        # If remote_host was specified, ensure it is used
+        task_4 = SSHOperator(
+            task_id="test_4",
+            ssh_hook=self.hook,
+            ssh_conn_id=TEST_CONN_ID,
+            command=COMMAND,
+            timeout=TIMEOUT,
+            dag=self.dag,
+            remote_host='operator_remote_host',
+        )
+        try:
+            task_4.execute(None)
+        except Exception:
+            pass
+        assert task_4.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
+        assert task_4.ssh_hook.remote_host == 'operator_remote_host'
 
     @pytest.mark.parametrize(
         "command, get_pty_in, get_pty_out",
         [
@@ -188,7 +212,6 @@ class TestSSHOperator:
         ],
     )
     def test_get_pyt_set_correctly(self, command, get_pty_in, get_pty_out):
-        dag = DAG('unit_tests_ssh_test_op_arg_checking', default_args={'start_date': DEFAULT_DATE})
         task = SSHOperator(
             task_id="test",
             ssh_hook=self.hook,
@@ -196,7 +219,7 @@ class TestSSHOperator:
             conn_timeout=TIMEOUT,
             cmd_timeout=TIMEOUT,
             get_pty=get_pty_in,
-            dag=dag,
+            dag=self.dag,
         )
         if command is None:
             with pytest.raises(AirflowException) as ctx:
@@ -205,3 +228,72 @@ class TestSSHOperator:
         else:
             task.execute(None)
             assert task.get_pty == get_pty_out
+
+    def test_ssh_client_managed_correctly(self):
+        # Ensure ssh_client gets created once
+        # Ensure connection gets closed once
+        task = SSHOperator(
+            task_id="test",
+            ssh_hook=self.hook,
+            command="ls",
+            dag=self.dag,
+        )
+
+        se = SSHClientSideEffect(self.hook)
+        with unittest.mock.patch.object(task, 'get_ssh_client') as mock_get, unittest.mock.patch(
+            'paramiko.client.SSHClient.close'
+        ) as mock_close:
+            mock_get.side_effect = se
+            task.execute()
+            mock_get.assert_called_once()
+            mock_close.assert_called_once()
+
+    def test_one_ssh_client_many_commands(self):
+        # Ensure we can run multiple commands with one client
+        many_commands = ['ls', 'date', 'pwd']
+
+        class CustomSSHOperator(SSHOperator):
+            def execute(self, context=None):
+                success = False
+                with self.get_ssh_client() as ssh_client:
+                    for c in many_commands:
+                        self.run_ssh_client_command(ssh_client, c)
+                    success = True
+                return success
+
+        task = CustomSSHOperator(task_id="test", ssh_hook=self.hook, dag=self.dag)
+        se = SSHClientSideEffect(self.hook)
+        with unittest.mock.patch.object(task, 'get_ssh_client') as mock_get, unittest.mock.patch.object(
+            task, 'run_ssh_client_command'
+        ) as mock_run_cmd, unittest.mock.patch('paramiko.client.SSHClient.close') as mock_close:
+            mock_get.side_effect = se
+            task.execute()
+            mock_get.assert_called_once()
+            mock_close.assert_called_once()
+
+            ssh_client = se.return_value
+            calls = [unittest.mock.call(ssh_client, c) for c in many_commands]
+            mock_run_cmd.assert_has_calls(calls)
+
+    def test_fail_with_no_command(self):
+        # Test that run_ssh_client_command fails on no command
+        task = SSHOperator(
+            task_id="test",
+            ssh_hook=self.hook,
+            # command="ls",
+            dag=self.dag,
+        )
+        with pytest.raises(AirflowException, match="SSH command not specified. Aborting."):
+            task.execute(None)
+
+    def test_command_errored(self):
+        # Test that run_ssh_client_command works on invalid commands
+        command = "not_a_real_command"
+        task = SSHOperator(
+            task_id="test",
+            ssh_hook=self.hook,
+            command=command,
+            dag=self.dag,
+        )
+        with pytest.raises(AirflowException, match=f"error running cmd: {command}, error: .*"):
+            task.execute(None)
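
Usage note (not part of the commit): the split into get_ssh_client(), exec_ssh_client_command(), raise_for_status() and run_ssh_client_command() is what lets a subclass reuse one SSH connection for several commands, as exercised by CustomSSHOperator in the test above. Below is a minimal sketch of such a subclass, assuming the provider API introduced by this commit; the class name MultiCommandSSHOperator and its commands argument are hypothetical, not part of the provider.

from typing import List

from airflow.exceptions import AirflowException
from airflow.providers.ssh.operators.ssh import SSHOperator


class MultiCommandSSHOperator(SSHOperator):
    """Hypothetical subclass that runs several commands over one SSH connection."""

    def __init__(self, *, commands: List[str], **kwargs):
        # No single `command` is passed to the base class; this subclass
        # manages its own list of commands instead.
        super().__init__(**kwargs)
        self.commands = commands

    def execute(self, context=None):
        if not self.commands:
            raise AirflowException("No commands to run. Aborting.")
        results = []
        # One client (and hence one connection) is shared by all commands;
        # the context manager closes it when the last command is done.
        with self.get_ssh_client() as ssh_client:
            for command in self.commands:
                # run_ssh_client_command() raises AirflowException on a
                # non-zero exit status and returns the aggregated stdout.
                results.append(self.run_ssh_client_command(ssh_client, command))
        return results


# Example instantiation (hypothetical task and connection ids):
# task = MultiCommandSSHOperator(
#     task_id="setup_remote_dirs",
#     ssh_conn_id="ssh_default",
#     commands=["mkdir -p /tmp/staging", "ls /tmp"],
# )

Reusing the client this way avoids the per-command reconnect that the old monolithic execute() forced on subclasses, which is the stated point of the refactor.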