This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 045439ec1fe Add working_directory parameter (#58210)
045439ec1fe is described below
commit 045439ec1fed4d925104a1cc1c17ee4e7701c1da
Author: darkag <[email protected]>
AuthorDate: Thu Nov 13 19:23:34 2025 +0100
Add working_directory parameter (#58210)
* Add working_directory parameter
* Fix tests
* Restore the Connection import (prefer airflow.sdk, fall back to airflow.models)
* Add "# type: ignore" to the Connection imports
---
.../providers/microsoft/winrm/hooks/winrm.py | 4 +++-
.../providers/microsoft/winrm/operators/winrm.py | 11 +++++++++--
.../tests/unit/microsoft/winrm/hooks/test_winrm.py | 22 ++++++++++++++++++----
.../unit/microsoft/winrm/operators/test_winrm.py | 7 ++++++-
4 files changed, 36 insertions(+), 8 deletions(-)
diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py
index edc08ef5b85..0fb1e75a3f4 100644
--- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py
+++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py
@@ -226,6 +226,7 @@ class WinRMHook(BaseHook):
ps_path: str | None = None,
output_encoding: str = "utf-8",
return_output: bool = True,
+ working_directory: str | None = None,
) -> tuple[int, list[bytes], list[bytes]]:
"""
Run a command.
@@ -235,12 +236,13 @@ class WinRMHook(BaseHook):
If specified, it will execute the command as powershell script.
:param output_encoding: the encoding used to decode stout and stderr.
:param return_output: Whether to accumulate and return the stdout or
not.
+ :param working_directory: specify working directory.
:return: returns a tuple containing return_code, stdout and stderr in
order.
"""
winrm_client = self.get_conn()
self.log.info("Establishing WinRM connection to host: %s",
self.remote_host)
try:
- shell_id = winrm_client.open_shell()
+ shell_id =
winrm_client.open_shell(working_directory=working_directory)
except Exception as error:
error_msg = f"Error connecting to host: {self.remote_host}, error:
{error}"
self.log.error(error_msg)
diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py
index 917297c898e..c72dd3bf114 100644
--- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py
+++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py
@@ -50,10 +50,14 @@ class WinRMOperator(BaseOperator):
:param output_encoding: the encoding used to decode stout and stderr
:param timeout: timeout for executing the command.
:param expected_return_code: expected return code value(s) of command.
+ :param working_directory: specify working directory.
"""
- template_fields: Sequence[str] = ("command",)
- template_fields_renderers = {"command": "powershell"}
+ template_fields: Sequence[str] = (
+ "command",
+ "working_directory",
+ )
+ template_fields_renderers = {"command": "powershell", "working_directory":
"powershell"}
def __init__(
self,
@@ -66,6 +70,7 @@ class WinRMOperator(BaseOperator):
output_encoding: str = "utf-8",
timeout: int = 10,
expected_return_code: int | list[int] | range = 0,
+ working_directory: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -77,6 +82,7 @@ class WinRMOperator(BaseOperator):
self.output_encoding = output_encoding
self.timeout = timeout
self.expected_return_code = expected_return_code
+ self.working_directory = working_directory
def execute(self, context: Context) -> list | str:
if self.ssh_conn_id and not self.winrm_hook:
@@ -97,6 +103,7 @@ class WinRMOperator(BaseOperator):
ps_path=self.ps_path,
output_encoding=self.output_encoding,
return_output=self.do_xcom_push,
+ working_directory=self.working_directory,
)
success = False
diff --git a/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py b/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py
index b4ff975b7eb..2403fd6628e 100644
--- a/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py
+++ b/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py
@@ -22,9 +22,13 @@ from unittest.mock import MagicMock, patch
import pytest
from airflow.exceptions import AirflowException
-from airflow.models import Connection
from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook
+try:
+ from airflow.sdk import Connection # type: ignore
+except ImportError:
+ from airflow.models import Connection # type: ignore
+
class TestWinRMHook:
def test_get_conn_missing_remote_host(self):
@@ -42,6 +46,8 @@ class TestWinRMHook:
@patch(
"airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
return_value=Connection(
+ conn_id="",
+ conn_type="",
login="username",
password="password",
host="remote_host",
@@ -113,6 +119,8 @@ class TestWinRMHook:
@patch(
"airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
return_value=Connection(
+ conn_id="",
+ conn_type="",
login="username",
password="password",
host="remote_host",
@@ -154,6 +162,8 @@ class TestWinRMHook:
@patch(
"airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
return_value=Connection(
+ conn_id="",
+ conn_type="",
login="username",
password="password",
host="remote_host",
@@ -177,16 +187,20 @@ class TestWinRMHook:
}""",
),
)
- def test_run_without_stdout(self, mock_get_connection, mock_protocol):
+ def test_run_without_stdout_and_working_dir(self, mock_get_connection,
mock_protocol):
winrm_hook = WinRMHook(ssh_conn_id="conn_id")
-
+ working_dir = "c:\\test"
mock_protocol.return_value.run_command =
MagicMock(return_value="command_id")
mock_protocol.return_value.get_command_output_raw = MagicMock(
return_value=(b"stdout", b"stderr", 0, True)
)
+ mock_protocol.return_value.open_shell = MagicMock()
- return_code, stdout_buffer, stderr_buffer = winrm_hook.run("dir",
return_output=False)
+ return_code, stdout_buffer, stderr_buffer = winrm_hook.run(
+ "dir", return_output=False, working_directory=working_dir
+ )
+
mock_protocol.return_value.open_shell.assert_called_once_with(working_directory=working_dir)
assert return_code == 0
assert not stdout_buffer
assert stderr_buffer == [b"stderr"]
diff --git a/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py b/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py
index 8997395b0bf..95650799a3b 100644
--- a/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py
+++ b/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py
@@ -44,8 +44,11 @@ class TestWinRMOperator:
def test_default_returning_0_command(self, mock_hook):
stdout = [b"O", b"K"]
command = "not_empty"
+ working_dir = "c:\\temp"
mock_hook.run.return_value = (0, stdout, [])
- op = WinRMOperator(task_id="test_task_id", winrm_hook=mock_hook,
command=command)
+ op = WinRMOperator(
+ task_id="test_task_id", winrm_hook=mock_hook, command=command,
working_directory=working_dir
+ )
execute_result = op.execute(None)
assert execute_result == b64encode(b"".join(stdout)).decode("utf-8")
mock_hook.run.assert_called_once_with(
@@ -53,6 +56,7 @@ class TestWinRMOperator:
ps_path=None,
output_encoding="utf-8",
return_output=True,
+ working_directory=working_dir,
)
@mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook")
@@ -94,6 +98,7 @@ class TestWinRMOperator:
ps_path=None,
output_encoding="utf-8",
return_output=True,
+ working_directory=None,
)
else:
exception_msg = f"Error running cmd: {command}, return code:
{real_return_code}, error: KO"