kaxil commented on code in PR #60651:
URL: https://github.com/apache/airflow/pull/60651#discussion_r2705582571
##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py:
##########
@@ -122,3 +180,28 @@ def execute(self, context: Context) -> list | str:
stderr_output = b"".join(stderr_buffer).decode(self.output_encoding)
error_msg = f"Error running cmd: {self.command}, return code:
{return_code}, error: {stderr_output}"
raise AirflowException(error_msg)
+
+ def execute_complete(
+ self,
+ context: Context,
+ event: dict[Any, Any] | None = None,
+ ) -> Any:
+ """
+ Execute callback when WinRMCommandOutputTrigger finishes execution.
+
+ This method gets executed automatically when WinRMCommandOutputTrigger
completes its execution.
+ """
+ if event:
+ status = event.get("status")
+ return_code = event.get("return_code")
+
+ self.log.info("%s completed with %s", self.task_id, status)
+
+ stdout = base64.standard_b64decode(event.get("stdout", b""))
Review Comment:
nit: The default `b""` is bytes, but the trigger sends a string. It works fine
either way, but for consistency maybe use `""` instead?
##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py:
##########
@@ -122,3 +180,28 @@ def execute(self, context: Context) -> list | str:
stderr_output = b"".join(stderr_buffer).decode(self.output_encoding)
error_msg = f"Error running cmd: {self.command}, return code:
{return_code}, error: {stderr_output}"
raise AirflowException(error_msg)
+
+ def execute_complete(
+ self,
+ context: Context,
+ event: dict[Any, Any] | None = None,
+ ) -> Any:
+ """
+ Execute callback when WinRMCommandOutputTrigger finishes execution.
+
+ This method gets executed automatically when WinRMCommandOutputTrigger
completes its execution.
+ """
+ if event:
Review Comment:
Should we check for error status here? If the trigger yields an error event
(e.g., timeout), the code will proceed to try decoding stdout/stderr which
won't exist in the event dict. Something like:
```python
if event:
    status = event.get("status")
    if status == "error":
        raise AirflowException(f"Trigger failed: {event.get('message')}")
    ...
```
##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py:
##########
@@ -78,37 +86,87 @@ def __init__(
self.command = command
self.ps_path = ps_path
self.output_encoding = output_encoding
- self.timeout = timeout
+ self.timeout = timeout.total_seconds() if isinstance(timeout,
timedelta) else timeout
+ self.poll_interval = (
+ poll_interval.total_seconds()
+ if isinstance(poll_interval, timedelta)
+ else poll_interval
+ if poll_interval is not None
+ else 1.0
+ )
self.expected_return_code = expected_return_code
self.working_directory = working_directory
+ self.deferrable = deferrable
- def execute(self, context: Context) -> list | str:
- if self.ssh_conn_id and not self.winrm_hook:
- self.log.info("Hook not found, creating...")
- self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id)
-
+ @property
+ def hook(self) -> WinRMHook:
if not self.winrm_hook:
- raise AirflowException("Cannot operate without winrm_hook or
ssh_conn_id.")
+ if self.ssh_conn_id and not self.winrm_hook:
Review Comment:
nit: `not self.winrm_hook` is redundant here since we're already inside the
outer `if not self.winrm_hook:` block
##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py:
##########
@@ -0,0 +1,156 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hook for winrm remote execution."""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import time
+from collections.abc import AsyncIterator
+from contextlib import suppress
+from typing import Any
+
+from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class WinRMCommandOutputTrigger(BaseTrigger):
+ """
+ A trigger that polls the command output executed by the WinRMHook.
+
+ This trigger avoids blocking a worker when using the WinRMOperator in
deferred mode.
+
+ The behavior of this trigger is as follows:
+ - poll the command output from the shell launched by WinRM,
+ - if command not done then sleep and retry,
+ - when command done then return the output.
+
+ :param ssh_conn_id: connection id from airflow Connections from where
+ all the required parameters can be fetched like username and password,
+ though priority is given to the params passed during init.
+ :param shell_id: The shell id on the remote machine.
+ :param command_id: The command id executed on the remote machine.
+ :param output_encoding: the encoding used to decode stout and stderr,
defaults to utf-8.
+ :param return_output: Whether to accumulate and return the stdout or not,
defaults to True.
+ :param working_directory: specify working directory.
+ :param poll_interval: How often, in seconds, the trigger should poll the
output command of the launched command,
+ defaults to 1.
+ :param timeout: max time allowed for polling, if it goes beyond it will
raise and fail.
+ """
+
+ def __init__(
+ self,
+ ssh_conn_id: str,
+ shell_id: str,
+ command_id: str,
+ output_encoding: str = "utf-8",
+ return_output: bool = True,
+ working_directory: str | None = None,
+ expected_return_code: int | list[int] | range = 0,
+ poll_interval: float = 1,
+ timeout: float | None = None,
+ deadline: float | None = None,
+ ) -> None:
+ super().__init__()
+ self.ssh_conn_id = ssh_conn_id
+ self.shell_id = shell_id
+ self.command_id = command_id
+ self.output_encoding = output_encoding
+ self.return_output = return_output
+ self.working_directory = working_directory
+ self.expected_return_code = expected_return_code
+ self.poll_interval = poll_interval
+ self.timeout = timeout
+ self.deadline = deadline or (time.monotonic() + self.timeout if
self.timeout is not None else None)
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize WinRMCommandOutputTrigger arguments and classpath."""
+ return (
+ f"{self.__class__.__module__}.{self.__class__.__name__}",
+ {
+ "ssh_conn_id": self.ssh_conn_id,
+ "shell_id": self.shell_id,
+ "command_id": self.command_id,
+ "output_encoding": self.output_encoding,
+ "return_output": self.return_output,
+ "working_directory": self.working_directory,
+ "expected_return_code": self.expected_return_code,
+ "poll_interval": self.poll_interval,
+ "timeout": self.timeout,
+ "deadline": self.deadline,
+ },
+ )
+
+ @property
+ def hook(self) -> WinRMHook:
Review Comment:
This creates a new WinRMHook on every access. In the `run()` method,
`self.hook` is accessed twice: once at line 110 for `get_conn()` and again at
line 121 for `get_command_output()`. Consider caching the hook instance to
avoid creating a new one on each poll iteration.
##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py:
##########
@@ -0,0 +1,156 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hook for winrm remote execution."""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import time
+from collections.abc import AsyncIterator
+from contextlib import suppress
+from typing import Any
+
+from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class WinRMCommandOutputTrigger(BaseTrigger):
+ """
+ A trigger that polls the command output executed by the WinRMHook.
+
+ This trigger avoids blocking a worker when using the WinRMOperator in
deferred mode.
+
+ The behavior of this trigger is as follows:
+ - poll the command output from the shell launched by WinRM,
+ - if command not done then sleep and retry,
+ - when command done then return the output.
+
+ :param ssh_conn_id: connection id from airflow Connections from where
+ all the required parameters can be fetched like username and password,
+ though priority is given to the params passed during init.
+ :param shell_id: The shell id on the remote machine.
+ :param command_id: The command id executed on the remote machine.
+ :param output_encoding: the encoding used to decode stout and stderr,
defaults to utf-8.
+ :param return_output: Whether to accumulate and return the stdout or not,
defaults to True.
+ :param working_directory: specify working directory.
+ :param poll_interval: How often, in seconds, the trigger should poll the
output command of the launched command,
+ defaults to 1.
+ :param timeout: max time allowed for polling, if it goes beyond it will
raise and fail.
+ """
+
+ def __init__(
+ self,
+ ssh_conn_id: str,
+ shell_id: str,
+ command_id: str,
+ output_encoding: str = "utf-8",
+ return_output: bool = True,
+ working_directory: str | None = None,
Review Comment:
`working_directory` and `expected_return_code` are stored and serialized but
never used in the trigger. The shell is already opened with the working
directory by the operator before deferring, and return code validation happens
in `execute_complete`. Should these be removed to avoid confusion?
##########
providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py:
##########
@@ -243,55 +250,104 @@ def run(
:param working_directory: specify working directory.
:return: returns a tuple containing return_code, stdout and stderr in
order.
"""
- winrm_client = self.get_conn()
- self.log.info("Establishing WinRM connection to host: %s",
self.remote_host)
- try:
- shell_id =
winrm_client.open_shell(working_directory=working_directory)
- except Exception as error:
- error_msg = f"Error connecting to host: {self.remote_host}, error:
{error}"
- self.log.error(error_msg)
- raise AirflowException(error_msg)
+ conn = self.get_conn()
+ shell_id, command_id = self._run_command(
+ conn=conn,
+ command=command,
+ ps_path=ps_path,
+ working_directory=working_directory,
+ )
try:
- if ps_path is not None:
- self.log.info("Running command as powershell script: '%s'...",
command)
- encoded_ps =
b64encode(command.encode("utf_16_le")).decode("ascii")
- command_id = winrm_client.run_command(shell_id, f"{ps_path}
-encodedcommand {encoded_ps}")
- else:
- self.log.info("Running command: '%s'...", command)
- command_id = winrm_client.run_command(shell_id, command)
-
- # See:
https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
+ command_done = False
stdout_buffer = []
stderr_buffer = []
- command_done = False
+ return_code: int | None = None
+
+ # See:
https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
while not command_done:
- # this is an expected error when waiting for a long-running
process, just silently retry
- with suppress(WinRMOperationTimeoutError):
- (
- stdout,
- stderr,
- return_code,
- command_done,
- ) = winrm_client.get_command_output_raw(shell_id,
command_id)
-
- # Only buffer stdout if we need to so that we minimize
memory usage.
- if return_output:
- stdout_buffer.append(stdout)
- stderr_buffer.append(stderr)
-
- for line in stdout.decode(output_encoding).splitlines():
- self.log.info(line)
- for line in stderr.decode(output_encoding).splitlines():
- self.log.warning(line)
-
- winrm_client.cleanup_command(shell_id, command_id)
+ (
+ stdout,
+ stderr,
+ return_code,
+ command_done,
+ ) = self.get_command_output(conn, shell_id, command_id,
output_encoding)
+
+ # Only buffer stdout if we need to so that we minimize memory
usage.
+ if return_output:
+ stdout_buffer.append(stdout)
+ stderr_buffer.append(stderr)
return return_code, stdout_buffer, stderr_buffer
except Exception as e:
raise AirflowException(f"WinRM operator error: {e}")
finally:
- winrm_client.close_shell(shell_id)
+ conn.cleanup_command(shell_id, command_id)
+ conn.close_shell(shell_id)
+
+ def run_command(
+ self,
+ command: str | None,
+ ps_path: str | None = None,
+ working_directory: str | None = None,
+ ) -> tuple[str, str]:
+ return self._run_command(self.get_conn(), command, ps_path,
working_directory)
+
+ def _run_command(
+ self,
+ conn: Protocol,
+ command: str | None,
+ ps_path: str | None = None,
+ working_directory: str | None = None,
+ ) -> tuple[str, str]:
+ if not command:
+ raise AirflowException("No command specified so nothing to execute
here.")
+
+ try:
+ shell_id = conn.open_shell(working_directory=working_directory)
+
+ if ps_path is not None:
+ self.log.info("Running command as powershell script: '%s'...",
command)
+ encoded_ps =
b64encode(command.encode("utf_16_le")).decode("ascii")
+ command_id = conn.run_command(shell_id, f"{ps_path}
-encodedcommand {encoded_ps}")
+ else:
+ self.log.info("Running command: '%s'...", command)
+ command_id = conn.run_command(shell_id, command)
+ except Exception as error:
+ error_msg = f"Error connecting to host: {self.remote_host}, error:
{error}"
+ self.log.error(error_msg)
+ raise AirflowException(error_msg)
+
+ return shell_id, command_id
+
+ def get_command_output(
+ self, conn: Protocol, shell_id: str, command_id: str, output_encoding:
str = "utf-8"
+ ) -> tuple[bytes, bytes, int | None, bool]:
+ with suppress(WinRMOperationTimeoutError):
+ (
+ stdout,
+ stderr,
+ return_code,
+ command_done,
+ ) = conn.get_command_output_raw(shell_id, command_id)
+
+ self.log.debug("return_code: ", return_code)
Review Comment:
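These debug statements won't log the actual values. Unlike `print()`, logging
treats the extra positional args as `%`-format arguments for the message, and
since the message has no placeholder the formatting fails instead of appending
the value. Should be: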
```python
self.log.debug("return_code: %s", return_code)
self.log.debug("command_done: %s", command_done)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]