kaxil commented on code in PR #60651:
URL: https://github.com/apache/airflow/pull/60651#discussion_r2705582571


##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py:
##########
@@ -122,3 +180,28 @@ def execute(self, context: Context) -> list | str:
         stderr_output = b"".join(stderr_buffer).decode(self.output_encoding)
         error_msg = f"Error running cmd: {self.command}, return code: 
{return_code}, error: {stderr_output}"
         raise AirflowException(error_msg)
+
+    def execute_complete(
+        self,
+        context: Context,
+        event: dict[Any, Any] | None = None,
+    ) -> Any:
+        """
+        Execute callback when WinRMCommandOutputTrigger finishes execution.
+
+        This method gets executed automatically when WinRMCommandOutputTrigger 
completes its execution.
+        """
+        if event:
+            status = event.get("status")
+            return_code = event.get("return_code")
+
+            self.log.info("%s completed with %s", self.task_id, status)
+
+            stdout = base64.standard_b64decode(event.get("stdout", b""))

Review Comment:
   nit: The default `b""` is bytes, but the trigger sends a string. It works 
fine either way, but for consistency maybe use `""` instead?



##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py:
##########
@@ -122,3 +180,28 @@ def execute(self, context: Context) -> list | str:
         stderr_output = b"".join(stderr_buffer).decode(self.output_encoding)
         error_msg = f"Error running cmd: {self.command}, return code: 
{return_code}, error: {stderr_output}"
         raise AirflowException(error_msg)
+
+    def execute_complete(
+        self,
+        context: Context,
+        event: dict[Any, Any] | None = None,
+    ) -> Any:
+        """
+        Execute callback when WinRMCommandOutputTrigger finishes execution.
+
+        This method gets executed automatically when WinRMCommandOutputTrigger 
completes its execution.
+        """
+        if event:

Review Comment:
   Should we check for error status here? If the trigger yields an error event 
(e.g., timeout), the code will proceed to try decoding stdout/stderr which 
won't exist in the event dict. Something like:
   
   ```python
   if event:
       status = event.get("status")
       if status == "error":
           raise AirflowException(f"Trigger failed: {event.get('message')}")
       ...
   ```



##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py:
##########
@@ -78,37 +86,87 @@ def __init__(
         self.command = command
         self.ps_path = ps_path
         self.output_encoding = output_encoding
-        self.timeout = timeout
+        self.timeout = timeout.total_seconds() if isinstance(timeout, 
timedelta) else timeout
+        self.poll_interval = (
+            poll_interval.total_seconds()
+            if isinstance(poll_interval, timedelta)
+            else poll_interval
+            if poll_interval is not None
+            else 1.0
+        )
         self.expected_return_code = expected_return_code
         self.working_directory = working_directory
+        self.deferrable = deferrable
 
-    def execute(self, context: Context) -> list | str:
-        if self.ssh_conn_id and not self.winrm_hook:
-            self.log.info("Hook not found, creating...")
-            self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id)
-
+    @property
+    def hook(self) -> WinRMHook:
         if not self.winrm_hook:
-            raise AirflowException("Cannot operate without winrm_hook or 
ssh_conn_id.")
+            if self.ssh_conn_id and not self.winrm_hook:

Review Comment:
   nit: `not self.winrm_hook` is redundant here since we're already inside the 
outer `if not self.winrm_hook:` block



##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py:
##########
@@ -0,0 +1,156 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hook for winrm remote execution."""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import time
+from collections.abc import AsyncIterator
+from contextlib import suppress
+from typing import Any
+
+from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class WinRMCommandOutputTrigger(BaseTrigger):
+    """
+    A trigger that polls the command output executed by the WinRMHook.
+
+    This trigger avoids blocking a worker when using the WinRMOperator in 
deferred mode.
+
+    The behavior of this trigger is as follows:
+    - poll the command output from the shell launched by WinRM,
+    - if command not done then sleep and retry,
+    - when command done then return the output.
+
+    :param ssh_conn_id: connection id from airflow Connections from where
+        all the required parameters can be fetched like username and password,
+        though priority is given to the params passed during init.
+    :param shell_id: The shell id on the remote machine.
+    :param command_id: The command id executed on the remote machine.
+    :param output_encoding: the encoding used to decode stout and stderr, 
defaults to utf-8.
+    :param return_output: Whether to accumulate and return the stdout or not, 
defaults to True.
+    :param working_directory: specify working directory.
+    :param poll_interval: How often, in seconds, the trigger should poll the 
output command of the launched command,
+        defaults to 1.
+    :param timeout: max time allowed for polling, if it goes beyond it will 
raise and fail.
+    """
+
+    def __init__(
+        self,
+        ssh_conn_id: str,
+        shell_id: str,
+        command_id: str,
+        output_encoding: str = "utf-8",
+        return_output: bool = True,
+        working_directory: str | None = None,
+        expected_return_code: int | list[int] | range = 0,
+        poll_interval: float = 1,
+        timeout: float | None = None,
+        deadline: float | None = None,
+    ) -> None:
+        super().__init__()
+        self.ssh_conn_id = ssh_conn_id
+        self.shell_id = shell_id
+        self.command_id = command_id
+        self.output_encoding = output_encoding
+        self.return_output = return_output
+        self.working_directory = working_directory
+        self.expected_return_code = expected_return_code
+        self.poll_interval = poll_interval
+        self.timeout = timeout
+        self.deadline = deadline or (time.monotonic() + self.timeout if 
self.timeout is not None else None)
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize WinRMCommandOutputTrigger arguments and classpath."""
+        return (
+            f"{self.__class__.__module__}.{self.__class__.__name__}",
+            {
+                "ssh_conn_id": self.ssh_conn_id,
+                "shell_id": self.shell_id,
+                "command_id": self.command_id,
+                "output_encoding": self.output_encoding,
+                "return_output": self.return_output,
+                "working_directory": self.working_directory,
+                "expected_return_code": self.expected_return_code,
+                "poll_interval": self.poll_interval,
+                "timeout": self.timeout,
+                "deadline": self.deadline,
+            },
+        )
+
+    @property
+    def hook(self) -> WinRMHook:

Review Comment:
   This property creates a new WinRMHook on every access. In the `run()` 
method, `self.hook` is accessed twice - once at line 110 for `get_conn()` and 
again at line 121 for `get_command_output()`. Consider caching the hook 
instance to avoid creating a new one on each poll iteration.



##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py:
##########
@@ -0,0 +1,156 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hook for winrm remote execution."""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import time
+from collections.abc import AsyncIterator
+from contextlib import suppress
+from typing import Any
+
+from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class WinRMCommandOutputTrigger(BaseTrigger):
+    """
+    A trigger that polls the command output executed by the WinRMHook.
+
+    This trigger avoids blocking a worker when using the WinRMOperator in 
deferred mode.
+
+    The behavior of this trigger is as follows:
+    - poll the command output from the shell launched by WinRM,
+    - if command not done then sleep and retry,
+    - when command done then return the output.
+
+    :param ssh_conn_id: connection id from airflow Connections from where
+        all the required parameters can be fetched like username and password,
+        though priority is given to the params passed during init.
+    :param shell_id: The shell id on the remote machine.
+    :param command_id: The command id executed on the remote machine.
+    :param output_encoding: the encoding used to decode stout and stderr, 
defaults to utf-8.
+    :param return_output: Whether to accumulate and return the stdout or not, 
defaults to True.
+    :param working_directory: specify working directory.
+    :param poll_interval: How often, in seconds, the trigger should poll the 
output command of the launched command,
+        defaults to 1.
+    :param timeout: max time allowed for polling, if it goes beyond it will 
raise and fail.
+    """
+
+    def __init__(
+        self,
+        ssh_conn_id: str,
+        shell_id: str,
+        command_id: str,
+        output_encoding: str = "utf-8",
+        return_output: bool = True,
+        working_directory: str | None = None,

Review Comment:
   `working_directory` and `expected_return_code` are stored and serialized but 
never used in the trigger. The shell is already opened with the working 
directory by the operator before deferring, and return code validation happens 
in `execute_complete`. Should these be removed to avoid confusion?



##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py:
##########
@@ -243,55 +250,104 @@ def run(
         :param working_directory: specify working directory.
         :return: returns a tuple containing return_code, stdout and stderr in 
order.
         """
-        winrm_client = self.get_conn()
-        self.log.info("Establishing WinRM connection to host: %s", 
self.remote_host)
-        try:
-            shell_id = 
winrm_client.open_shell(working_directory=working_directory)
-        except Exception as error:
-            error_msg = f"Error connecting to host: {self.remote_host}, error: 
{error}"
-            self.log.error(error_msg)
-            raise AirflowException(error_msg)
+        conn = self.get_conn()
+        shell_id, command_id = self._run_command(
+            conn=conn,
+            command=command,
+            ps_path=ps_path,
+            working_directory=working_directory,
+        )
 
         try:
-            if ps_path is not None:
-                self.log.info("Running command as powershell script: '%s'...", 
command)
-                encoded_ps = 
b64encode(command.encode("utf_16_le")).decode("ascii")
-                command_id = winrm_client.run_command(shell_id, f"{ps_path} 
-encodedcommand {encoded_ps}")
-            else:
-                self.log.info("Running command: '%s'...", command)
-                command_id = winrm_client.run_command(shell_id, command)
-
-                # See: 
https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
+            command_done = False
             stdout_buffer = []
             stderr_buffer = []
-            command_done = False
+            return_code: int | None = None
+
+            # See: 
https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
             while not command_done:
-                # this is an expected error when waiting for a long-running 
process, just silently retry
-                with suppress(WinRMOperationTimeoutError):
-                    (
-                        stdout,
-                        stderr,
-                        return_code,
-                        command_done,
-                    ) = winrm_client.get_command_output_raw(shell_id, 
command_id)
-
-                    # Only buffer stdout if we need to so that we minimize 
memory usage.
-                    if return_output:
-                        stdout_buffer.append(stdout)
-                    stderr_buffer.append(stderr)
-
-                    for line in stdout.decode(output_encoding).splitlines():
-                        self.log.info(line)
-                    for line in stderr.decode(output_encoding).splitlines():
-                        self.log.warning(line)
-
-            winrm_client.cleanup_command(shell_id, command_id)
+                (
+                    stdout,
+                    stderr,
+                    return_code,
+                    command_done,
+                ) = self.get_command_output(conn, shell_id, command_id, 
output_encoding)
+
+                # Only buffer stdout if we need to so that we minimize memory 
usage.
+                if return_output:
+                    stdout_buffer.append(stdout)
+                stderr_buffer.append(stderr)
 
             return return_code, stdout_buffer, stderr_buffer
         except Exception as e:
             raise AirflowException(f"WinRM operator error: {e}")
         finally:
-            winrm_client.close_shell(shell_id)
+            conn.cleanup_command(shell_id, command_id)
+            conn.close_shell(shell_id)
+
+    def run_command(
+        self,
+        command: str | None,
+        ps_path: str | None = None,
+        working_directory: str | None = None,
+    ) -> tuple[str, str]:
+        return self._run_command(self.get_conn(), command, ps_path, 
working_directory)
+
+    def _run_command(
+        self,
+        conn: Protocol,
+        command: str | None,
+        ps_path: str | None = None,
+        working_directory: str | None = None,
+    ) -> tuple[str, str]:
+        if not command:
+            raise AirflowException("No command specified so nothing to execute 
here.")
+
+        try:
+            shell_id = conn.open_shell(working_directory=working_directory)
+
+            if ps_path is not None:
+                self.log.info("Running command as powershell script: '%s'...", 
command)
+                encoded_ps = 
b64encode(command.encode("utf_16_le")).decode("ascii")
+                command_id = conn.run_command(shell_id, f"{ps_path} 
-encodedcommand {encoded_ps}")
+            else:
+                self.log.info("Running command: '%s'...", command)
+                command_id = conn.run_command(shell_id, command)
+        except Exception as error:
+            error_msg = f"Error connecting to host: {self.remote_host}, error: 
{error}"
+            self.log.error(error_msg)
+            raise AirflowException(error_msg)
+
+        return shell_id, command_id
+
+    def get_command_output(
+        self, conn: Protocol, shell_id: str, command_id: str, output_encoding: 
str = "utf-8"
+    ) -> tuple[bytes, bytes, int | None, bool]:
+        with suppress(WinRMOperationTimeoutError):
+            (
+                stdout,
+                stderr,
+                return_code,
+                command_done,
+            ) = conn.get_command_output_raw(shell_id, command_id)
+
+            self.log.debug("return_code: ", return_code)

Review Comment:
   These debug statements won't log the actual values. Unlike `print()`, the 
logger treats extra positional args as %-format arguments, and the message has 
no placeholder for them. Should be:
   
   ```python
   self.log.debug("return_code: %s", return_code)
   self.log.debug("command_done: %s", command_done)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to