This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 464f7c45603 Move responsibility to run a command from WinRMOperator to WinRMHook (#43646)
464f7c45603 is described below

commit 464f7c456038683992cea2b665a968fd23e4b2a6
Author: David Blain <[email protected]>
AuthorDate: Wed Nov 6 09:52:47 2024 +0100

    Move responsibility to run a command from WinRMOperator to WinRMHook (#43646)
    
    * refactor: Moved responsibility to run a command away from WinRMOperator to WinRMHook and also made WinRMHook closable
    
    * refactor: Reformatted exception message in WinRMOperator
    
    * refactor: command parameter of run method in WinRMHook must be specified
    
    * refactor: Changed return type of run method in WinRMHook
    
    * refactor: WinRMHook cannot be closable as it doesn't have the winrm_client instance
    
    * refactor: Reorganized imports in WinRMHook
    
    * refactor: Added unit tests for new run method in WinRMHook
    
    * refactor: Reorganized imports in TestWinRMHook
    
    ---------
    
    Co-authored-by: David Blain <[email protected]>
---
 .../providers/microsoft/winrm/hooks/winrm.py       | 72 +++++++++++++++++++
 .../providers/microsoft/winrm/operators/winrm.py   | 73 ++++---------------
 .../tests/microsoft/winrm/hooks/test_winrm.py      | 84 +++++++++++++++++++++-
 3 files changed, 167 insertions(+), 62 deletions(-)

diff --git a/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py b/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py
index 96abdf8e9bf..961e37ba3fe 100644
--- a/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py
+++ b/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py
@@ -19,6 +19,10 @@
 
 from __future__ import annotations
 
+from base64 import b64encode
+from contextlib import suppress
+
+from winrm.exceptions import WinRMOperationTimeoutError
 from winrm.protocol import Protocol
 
 from airflow.exceptions import AirflowException
@@ -218,3 +222,71 @@ class WinRMHook(BaseHook):
             raise AirflowException(error_msg)
 
         return self.client
+
+    def run(
+        self,
+        command: str,
+        ps_path: str | None = None,
+        output_encoding: str = "utf-8",
+        return_output: bool = True,
+    ) -> tuple[int, list[bytes], list[bytes]]:
+        """
+        Run a command.
+
+        :param command: command to execute on remote host.
+        :param ps_path: path to powershell, `powershell` for v5.1- and `pwsh` for v6+.
+            If specified, it will execute the command as powershell script.
+        :param output_encoding: the encoding used to decode stdout and stderr.
+        :param return_output: Whether to accumulate and return the stdout or not.
+        :return: a tuple of (return_code, stdout chunks, stderr chunks) as raw bytes;
+            the stdout list is empty when ``return_output`` is False.
+        """
+        winrm_client = self.get_conn()
+
+        try:
+            if ps_path is not None:
+                self.log.info("Running command as powershell script: '%s'...", command)
+                encoded_ps = b64encode(command.encode("utf_16_le")).decode("ascii")
+                command_id = self.winrm_protocol.run_command(  # type: ignore[attr-defined]
+                    winrm_client, f"{ps_path} -encodedcommand {encoded_ps}"
+                )
+            else:
+                self.log.info("Running command: '%s'...", command)
+                command_id = self.winrm_protocol.run_command(  # type: ignore[attr-defined]
+                    winrm_client, command
+                )
+
+            # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
+            stdout_buffer = []
+            stderr_buffer = []
+            command_done = False
+            while not command_done:
+                # this is an expected error when waiting for a long-running process, just silently retry
+                with suppress(WinRMOperationTimeoutError):
+                    (
+                        stdout,
+                        stderr,
+                        return_code,
+                        command_done,
+                    ) = self.winrm_protocol._raw_get_command_output(  # type: ignore[attr-defined]
+                        winrm_client, command_id
+                    )
+
+                    # Only buffer stdout if we need to so that we minimize memory usage.
+                    if return_output:
+                        stdout_buffer.append(stdout)
+                    stderr_buffer.append(stderr)
+
+                    for line in stdout.decode(output_encoding).splitlines():
+                        self.log.info(line)
+                    for line in stderr.decode(output_encoding).splitlines():
+                        self.log.warning(line)
+
+            self.winrm_protocol.cleanup_command(  # type: ignore[attr-defined]
+                winrm_client, command_id
+            )
+
+            return return_code, stdout_buffer, stderr_buffer
+        except Exception as e:
+            raise AirflowException(f"WinRM operator error: {e}")
+        finally:
+            self.winrm_protocol.close_shell(winrm_client)  # type: ignore[attr-defined]
diff --git a/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py b/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py
index 3b61afb195b..0662333c788 100644
--- a/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py
+++ b/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py
@@ -21,8 +21,6 @@ import logging
 from base64 import b64encode
 from typing import TYPE_CHECKING, Sequence
 
-from winrm.exceptions import WinRMOperationTimeoutError
-
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -90,68 +88,21 @@ class WinRMOperator(BaseOperator):
         if not self.command:
             raise AirflowException("No command specified so nothing to execute here.")
 
-        winrm_client = self.winrm_hook.get_conn()
-
-        try:
-            if self.ps_path is not None:
-                self.log.info("Running command as powershell script: '%s'...", self.command)
-                encoded_ps = b64encode(self.command.encode("utf_16_le")).decode("ascii")
-                command_id = self.winrm_hook.winrm_protocol.run_command(  # type: ignore[attr-defined]
-                    winrm_client, f"{self.ps_path} -encodedcommand {encoded_ps}"
-                )
-            else:
-                self.log.info("Running command: '%s'...", self.command)
-                command_id = self.winrm_hook.winrm_protocol.run_command(  # type: ignore[attr-defined]
-                    winrm_client, self.command
-                )
-
-            # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
-            stdout_buffer = []
-            stderr_buffer = []
-            command_done = False
-            while not command_done:
-                try:
-                    (
-                        stdout,
-                        stderr,
-                        return_code,
-                        command_done,
-                    ) = self.winrm_hook.winrm_protocol._raw_get_command_output(  # type: ignore[attr-defined]
-                        winrm_client, command_id
-                    )
-
-                    # Only buffer stdout if we need to so that we minimize memory usage.
-                    if self.do_xcom_push:
-                        stdout_buffer.append(stdout)
-                    stderr_buffer.append(stderr)
-
-                    for line in stdout.decode(self.output_encoding).splitlines():
-                        self.log.info(line)
-                    for line in stderr.decode(self.output_encoding).splitlines():
-                        self.log.warning(line)
-                except WinRMOperationTimeoutError:
-                    # this is an expected error when waiting for a
-                    # long-running process, just silently retry
-                    pass
-
-            self.winrm_hook.winrm_protocol.cleanup_command(  # type: ignore[attr-defined]
-                winrm_client, command_id
-            )
-            self.winrm_hook.winrm_protocol.close_shell(winrm_client)  # type: ignore[attr-defined]
-
-        except Exception as e:
-            raise AirflowException(f"WinRM operator error: {e}")
+        return_code, stdout_buffer, stderr_buffer = self.winrm_hook.run(
+            command=self.command,
+            ps_path=self.ps_path,
+            output_encoding=self.output_encoding,
+            return_output=self.do_xcom_push,
+        )
 
         if return_code == 0:
             # returning output if do_xcom_push is set
             enable_pickling = conf.getboolean("core", "enable_xcom_pickling")
+
             if enable_pickling:
                 return stdout_buffer
-            else:
-                return b64encode(b"".join(stdout_buffer)).decode(self.output_encoding)
-        else:
-            stderr_output = b"".join(stderr_buffer).decode(self.output_encoding)
-            error_msg = (
-                f"Error running cmd: {self.command}, return code: {return_code}, error: {stderr_output}"
-            )
-            raise AirflowException(error_msg)
+            return b64encode(b"".join(stdout_buffer)).decode(self.output_encoding)
+
+        stderr_output = b"".join(stderr_buffer).decode(self.output_encoding)
+        error_msg = f"Error running cmd: {self.command}, return code: {return_code}, error: {stderr_output}"
+        raise AirflowException(error_msg)
diff --git a/providers/tests/microsoft/winrm/hooks/test_winrm.py b/providers/tests/microsoft/winrm/hooks/test_winrm.py
index 83411ccf9cc..7c2223cadc3 100644
--- a/providers/tests/microsoft/winrm/hooks/test_winrm.py
+++ b/providers/tests/microsoft/winrm/hooks/test_winrm.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 
@@ -119,3 +119,85 @@ class TestWinRMHook:
         winrm_hook.get_conn()
 
         assert f"http://{winrm_hook.remote_host}:{winrm_hook.remote_port}/wsman" == winrm_hook.endpoint
+
+    @patch("airflow.providers.microsoft.winrm.hooks.winrm.Protocol", autospec=True)
+    @patch(
+        "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
+        return_value=Connection(
+            login="username",
+            password="password",
+            host="remote_host",
+            extra="""{
+                      "endpoint": "endpoint",
+                      "remote_port": 123,
+                      "transport": "plaintext",
+                      "service": "service",
+                      "keytab": "keytab",
+                      "ca_trust_path": "ca_trust_path",
+                      "cert_pem": "cert_pem",
+                      "cert_key_pem": "cert_key_pem",
+                      "server_cert_validation": "validate",
+                      "kerberos_delegation": "true",
+                      "read_timeout_sec": 124,
+                      "operation_timeout_sec": 123,
+                      "kerberos_hostname_override": "kerberos_hostname_override",
+                      "message_encryption": "auto",
+                      "credssp_disable_tlsv1_2": "true",
+                      "send_cbt": "false"
+                  }""",
+        ),
+    )
+    def test_run_with_stdout(self, mock_get_connection, mock_protocol):
+        winrm_hook = WinRMHook(ssh_conn_id="conn_id")
+
+        mock_protocol.return_value.run_command = MagicMock(return_value="command_id")
+        mock_protocol.return_value._raw_get_command_output = MagicMock(
+            return_value=(b"stdout", b"stderr", 0, True)
+        )
+
+        return_code, stdout_buffer, stderr_buffer = winrm_hook.run("dir")
+
+        assert return_code == 0
+        assert stdout_buffer == [b"stdout"]
+        assert stderr_buffer == [b"stderr"]
+
+    @patch("airflow.providers.microsoft.winrm.hooks.winrm.Protocol", autospec=True)
+    @patch(
+        "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
+        return_value=Connection(
+            login="username",
+            password="password",
+            host="remote_host",
+            extra="""{
+                      "endpoint": "endpoint",
+                      "remote_port": 123,
+                      "transport": "plaintext",
+                      "service": "service",
+                      "keytab": "keytab",
+                      "ca_trust_path": "ca_trust_path",
+                      "cert_pem": "cert_pem",
+                      "cert_key_pem": "cert_key_pem",
+                      "server_cert_validation": "validate",
+                      "kerberos_delegation": "true",
+                      "read_timeout_sec": 124,
+                      "operation_timeout_sec": 123,
+                      "kerberos_hostname_override": "kerberos_hostname_override",
+                      "message_encryption": "auto",
+                      "credssp_disable_tlsv1_2": "true",
+                      "send_cbt": "false"
+                  }""",
+        ),
+    )
+    def test_run_without_stdout(self, mock_get_connection, mock_protocol):
+        winrm_hook = WinRMHook(ssh_conn_id="conn_id")
+
+        mock_protocol.return_value.run_command = MagicMock(return_value="command_id")
+        mock_protocol.return_value._raw_get_command_output = MagicMock(
+            return_value=(b"stdout", b"stderr", 0, True)
+        )
+
+        return_code, stdout_buffer, stderr_buffer = winrm_hook.run("dir", return_output=False)
+
+        assert return_code == 0
+        assert not stdout_buffer
+        assert stderr_buffer == [b"stderr"]

Reply via email to