This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 464f7c45603 Move responsibility to run a command from WinRMOperator to WinRMHook (#43646)
464f7c45603 is described below
commit 464f7c456038683992cea2b665a968fd23e4b2a6
Author: David Blain <[email protected]>
AuthorDate: Wed Nov 6 09:52:47 2024 +0100
Move responsibility to run a command from WinRMOperator to WinRMHook (#43646)
* refactor: Moved responsibility to run a command away from WinRMOperator to WinRMHook and also made WinRMHook closable
* refactor: Reformatted exception message in WinRMOperator
* refactor: command parameter of run method in WinRMHook must be specified
* refactor: Changed return type of run method in WinRMHook
* refactor: WinRMHook cannot be made closable as it doesn't hold the winrm_client instance
* refactor: Reorganized imports in WinRMHook
* refactor: Added unit tests for new run method in WinRMHook
* refactor: Reorganized imports in TestWinRMHook
---------
Co-authored-by: David Blain <[email protected]>
---
.../providers/microsoft/winrm/hooks/winrm.py | 72 +++++++++++++++++++
.../providers/microsoft/winrm/operators/winrm.py | 73 ++++---------------
.../tests/microsoft/winrm/hooks/test_winrm.py | 84 +++++++++++++++++++++-
3 files changed, 167 insertions(+), 62 deletions(-)
diff --git a/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py
b/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py
index 96abdf8e9bf..961e37ba3fe 100644
--- a/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py
+++ b/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py
@@ -19,6 +19,10 @@
from __future__ import annotations
+from base64 import b64encode
+from contextlib import suppress
+
+from winrm.exceptions import WinRMOperationTimeoutError
from winrm.protocol import Protocol
from airflow.exceptions import AirflowException
@@ -218,3 +222,71 @@ class WinRMHook(BaseHook):
raise AirflowException(error_msg)
return self.client
+
+ def run(
+ self,
+ command: str,
+ ps_path: str | None = None,
+ output_encoding: str = "utf-8",
+ return_output: bool = True,
+ ) -> tuple[int, list[bytes], list[bytes]]:
+ """
+ Run a command.
+
+ :param command: command to execute on remote host.
+ :param ps_path: path to powershell, `powershell` for v5.1- and `pwsh`
for v6+.
+ If specified, it will execute the command as powershell script.
+ :param output_encoding: the encoding used to decode stout and stderr.
+ :param return_output: Whether to accumulate and return the stdout or
not.
+ :return: returns a tuple containing return_code, stdout and stderr in
order.
+ """
+ winrm_client = self.get_conn()
+
+ try:
+ if ps_path is not None:
+ self.log.info("Running command as powershell script: '%s'...",
command)
+ encoded_ps =
b64encode(command.encode("utf_16_le")).decode("ascii")
+ command_id = self.winrm_protocol.run_command( # type:
ignore[attr-defined]
+ winrm_client, f"{ps_path} -encodedcommand {encoded_ps}"
+ )
+ else:
+ self.log.info("Running command: '%s'...", command)
+ command_id = self.winrm_protocol.run_command( # type:
ignore[attr-defined]
+ winrm_client, command
+ )
+
+ # See:
https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
+ stdout_buffer = []
+ stderr_buffer = []
+ command_done = False
+ while not command_done:
+ # this is an expected error when waiting for a long-running
process, just silently retry
+ with suppress(WinRMOperationTimeoutError):
+ (
+ stdout,
+ stderr,
+ return_code,
+ command_done,
+ ) = self.winrm_protocol._raw_get_command_output( # type:
ignore[attr-defined]
+ winrm_client, command_id
+ )
+
+ # Only buffer stdout if we need to so that we minimize
memory usage.
+ if return_output:
+ stdout_buffer.append(stdout)
+ stderr_buffer.append(stderr)
+
+ for line in stdout.decode(output_encoding).splitlines():
+ self.log.info(line)
+ for line in stderr.decode(output_encoding).splitlines():
+ self.log.warning(line)
+
+ self.winrm_protocol.cleanup_command( # type: ignore[attr-defined]
+ winrm_client, command_id
+ )
+
+ return return_code, stdout_buffer, stderr_buffer
+ except Exception as e:
+ raise AirflowException(f"WinRM operator error: {e}")
+ finally:
+ self.winrm_protocol.close_shell(winrm_client) # type:
ignore[attr-defined]
diff --git a/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py
b/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py
index 3b61afb195b..0662333c788 100644
--- a/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py
+++ b/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py
@@ -21,8 +21,6 @@ import logging
from base64 import b64encode
from typing import TYPE_CHECKING, Sequence
-from winrm.exceptions import WinRMOperationTimeoutError
-
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -90,68 +88,21 @@ class WinRMOperator(BaseOperator):
if not self.command:
raise AirflowException("No command specified so nothing to execute
here.")
- winrm_client = self.winrm_hook.get_conn()
-
- try:
- if self.ps_path is not None:
- self.log.info("Running command as powershell script: '%s'...",
self.command)
- encoded_ps =
b64encode(self.command.encode("utf_16_le")).decode("ascii")
- command_id = self.winrm_hook.winrm_protocol.run_command( #
type: ignore[attr-defined]
- winrm_client, f"{self.ps_path} -encodedcommand
{encoded_ps}"
- )
- else:
- self.log.info("Running command: '%s'...", self.command)
- command_id = self.winrm_hook.winrm_protocol.run_command( #
type: ignore[attr-defined]
- winrm_client, self.command
- )
-
- # See:
https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
- stdout_buffer = []
- stderr_buffer = []
- command_done = False
- while not command_done:
- try:
- (
- stdout,
- stderr,
- return_code,
- command_done,
- ) =
self.winrm_hook.winrm_protocol._raw_get_command_output( # type:
ignore[attr-defined]
- winrm_client, command_id
- )
-
- # Only buffer stdout if we need to so that we minimize
memory usage.
- if self.do_xcom_push:
- stdout_buffer.append(stdout)
- stderr_buffer.append(stderr)
-
- for line in
stdout.decode(self.output_encoding).splitlines():
- self.log.info(line)
- for line in
stderr.decode(self.output_encoding).splitlines():
- self.log.warning(line)
- except WinRMOperationTimeoutError:
- # this is an expected error when waiting for a
- # long-running process, just silently retry
- pass
-
- self.winrm_hook.winrm_protocol.cleanup_command( # type:
ignore[attr-defined]
- winrm_client, command_id
- )
- self.winrm_hook.winrm_protocol.close_shell(winrm_client) # type:
ignore[attr-defined]
-
- except Exception as e:
- raise AirflowException(f"WinRM operator error: {e}")
+ return_code, stdout_buffer, stderr_buffer = self.winrm_hook.run(
+ command=self.command,
+ ps_path=self.ps_path,
+ output_encoding=self.output_encoding,
+ return_output=self.do_xcom_push,
+ )
if return_code == 0:
# returning output if do_xcom_push is set
enable_pickling = conf.getboolean("core", "enable_xcom_pickling")
+
if enable_pickling:
return stdout_buffer
- else:
- return
b64encode(b"".join(stdout_buffer)).decode(self.output_encoding)
- else:
- stderr_output =
b"".join(stderr_buffer).decode(self.output_encoding)
- error_msg = (
- f"Error running cmd: {self.command}, return code:
{return_code}, error: {stderr_output}"
- )
- raise AirflowException(error_msg)
+ return
b64encode(b"".join(stdout_buffer)).decode(self.output_encoding)
+
+ stderr_output = b"".join(stderr_buffer).decode(self.output_encoding)
+ error_msg = f"Error running cmd: {self.command}, return code:
{return_code}, error: {stderr_output}"
+ raise AirflowException(error_msg)
diff --git a/providers/tests/microsoft/winrm/hooks/test_winrm.py
b/providers/tests/microsoft/winrm/hooks/test_winrm.py
index 83411ccf9cc..7c2223cadc3 100644
--- a/providers/tests/microsoft/winrm/hooks/test_winrm.py
+++ b/providers/tests/microsoft/winrm/hooks/test_winrm.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
import pytest
@@ -119,3 +119,85 @@ class TestWinRMHook:
winrm_hook.get_conn()
assert
f"http://{winrm_hook.remote_host}:{winrm_hook.remote_port}/wsman" ==
winrm_hook.endpoint
+
+ @patch("airflow.providers.microsoft.winrm.hooks.winrm.Protocol",
autospec=True)
+ @patch(
+
"airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
+ return_value=Connection(
+ login="username",
+ password="password",
+ host="remote_host",
+ extra="""{
+ "endpoint": "endpoint",
+ "remote_port": 123,
+ "transport": "plaintext",
+ "service": "service",
+ "keytab": "keytab",
+ "ca_trust_path": "ca_trust_path",
+ "cert_pem": "cert_pem",
+ "cert_key_pem": "cert_key_pem",
+ "server_cert_validation": "validate",
+ "kerberos_delegation": "true",
+ "read_timeout_sec": 124,
+ "operation_timeout_sec": 123,
+ "kerberos_hostname_override":
"kerberos_hostname_override",
+ "message_encryption": "auto",
+ "credssp_disable_tlsv1_2": "true",
+ "send_cbt": "false"
+ }""",
+ ),
+ )
+ def test_run_with_stdout(self, mock_get_connection, mock_protocol):
+ winrm_hook = WinRMHook(ssh_conn_id="conn_id")
+
+ mock_protocol.return_value.run_command =
MagicMock(return_value="command_id")
+ mock_protocol.return_value._raw_get_command_output = MagicMock(
+ return_value=(b"stdout", b"stderr", 0, True)
+ )
+
+ return_code, stdout_buffer, stderr_buffer = winrm_hook.run("dir")
+
+ assert return_code == 0
+ assert stdout_buffer == [b"stdout"]
+ assert stderr_buffer == [b"stderr"]
+
+ @patch("airflow.providers.microsoft.winrm.hooks.winrm.Protocol",
autospec=True)
+ @patch(
+
"airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
+ return_value=Connection(
+ login="username",
+ password="password",
+ host="remote_host",
+ extra="""{
+ "endpoint": "endpoint",
+ "remote_port": 123,
+ "transport": "plaintext",
+ "service": "service",
+ "keytab": "keytab",
+ "ca_trust_path": "ca_trust_path",
+ "cert_pem": "cert_pem",
+ "cert_key_pem": "cert_key_pem",
+ "server_cert_validation": "validate",
+ "kerberos_delegation": "true",
+ "read_timeout_sec": 124,
+ "operation_timeout_sec": 123,
+ "kerberos_hostname_override":
"kerberos_hostname_override",
+ "message_encryption": "auto",
+ "credssp_disable_tlsv1_2": "true",
+ "send_cbt": "false"
+ }""",
+ ),
+ )
+ def test_run_without_stdout(self, mock_get_connection, mock_protocol):
+ winrm_hook = WinRMHook(ssh_conn_id="conn_id")
+
+ mock_protocol.return_value.run_command =
MagicMock(return_value="command_id")
+ mock_protocol.return_value._raw_get_command_output = MagicMock(
+ return_value=(b"stdout", b"stderr", 0, True)
+ )
+
+ return_code, stdout_buffer, stderr_buffer = winrm_hook.run("dir",
return_output=False)
+
+ assert return_code == 0
+ assert not stdout_buffer
+ assert stderr_buffer == [b"stderr"]