This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 045439ec1fe Add working_directory parameter (#58210)
045439ec1fe is described below

commit 045439ec1fed4d925104a1cc1c17ee4e7701c1da
Author: darkag <[email protected]>
AuthorDate: Thu Nov 13 19:23:34 2025 +0100

    Add working_directory parameter (#58210)
    
    * Add working_directory parameter
    
    * fix test
    
    * restore import Connection
    
    * ignore type check
---
 .../providers/microsoft/winrm/hooks/winrm.py       |  4 +++-
 .../providers/microsoft/winrm/operators/winrm.py   | 11 +++++++++--
 .../tests/unit/microsoft/winrm/hooks/test_winrm.py | 22 ++++++++++++++++++----
 .../unit/microsoft/winrm/operators/test_winrm.py   |  7 ++++++-
 4 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py
index edc08ef5b85..0fb1e75a3f4 100644
--- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py
+++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py
@@ -226,6 +226,7 @@ class WinRMHook(BaseHook):
         ps_path: str | None = None,
         output_encoding: str = "utf-8",
         return_output: bool = True,
+        working_directory: str | None = None,
     ) -> tuple[int, list[bytes], list[bytes]]:
         """
         Run a command.
@@ -235,12 +236,13 @@ class WinRMHook(BaseHook):
             If specified, it will execute the command as powershell script.
         :param output_encoding: the encoding used to decode stout and stderr.
         :param return_output: Whether to accumulate and return the stdout or not.
+        :param working_directory: specify working directory.
         :return: returns a tuple containing return_code, stdout and stderr in order.
         """
         winrm_client = self.get_conn()
         self.log.info("Establishing WinRM connection to host: %s", self.remote_host)
         try:
-            shell_id = winrm_client.open_shell()
+            shell_id = winrm_client.open_shell(working_directory=working_directory)
         except Exception as error:
             error_msg = f"Error connecting to host: {self.remote_host}, error: {error}"
             self.log.error(error_msg)
diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py
index 917297c898e..c72dd3bf114 100644
--- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py
+++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py
@@ -50,10 +50,14 @@ class WinRMOperator(BaseOperator):
     :param output_encoding: the encoding used to decode stout and stderr
     :param timeout: timeout for executing the command.
     :param expected_return_code: expected return code value(s) of command.
+    :param working_directory: specify working directory.
     """
 
-    template_fields: Sequence[str] = ("command",)
-    template_fields_renderers = {"command": "powershell"}
+    template_fields: Sequence[str] = (
+        "command",
+        "working_directory",
+    )
+    template_fields_renderers = {"command": "powershell", "working_directory": "powershell"}
 
     def __init__(
         self,
@@ -66,6 +70,7 @@ class WinRMOperator(BaseOperator):
         output_encoding: str = "utf-8",
         timeout: int = 10,
         expected_return_code: int | list[int] | range = 0,
+        working_directory: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -77,6 +82,7 @@ class WinRMOperator(BaseOperator):
         self.output_encoding = output_encoding
         self.timeout = timeout
         self.expected_return_code = expected_return_code
+        self.working_directory = working_directory
 
     def execute(self, context: Context) -> list | str:
         if self.ssh_conn_id and not self.winrm_hook:
@@ -97,6 +103,7 @@ class WinRMOperator(BaseOperator):
             ps_path=self.ps_path,
             output_encoding=self.output_encoding,
             return_output=self.do_xcom_push,
+            working_directory=self.working_directory,
         )
 
         success = False
diff --git a/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py b/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py
index b4ff975b7eb..2403fd6628e 100644
--- a/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py
+++ b/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py
@@ -22,9 +22,13 @@ from unittest.mock import MagicMock, patch
 import pytest
 
 from airflow.exceptions import AirflowException
-from airflow.models import Connection
 from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook
 
+try:
+    from airflow.sdk import Connection  # type: ignore
+except ImportError:
+    from airflow.models import Connection  # type: ignore
+
 
 class TestWinRMHook:
     def test_get_conn_missing_remote_host(self):
@@ -42,6 +46,8 @@ class TestWinRMHook:
     @patch(
         "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
         return_value=Connection(
+            conn_id="",
+            conn_type="",
             login="username",
             password="password",
             host="remote_host",
@@ -113,6 +119,8 @@ class TestWinRMHook:
     @patch(
         "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
         return_value=Connection(
+            conn_id="",
+            conn_type="",
             login="username",
             password="password",
             host="remote_host",
@@ -154,6 +162,8 @@ class TestWinRMHook:
     @patch(
         "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
         return_value=Connection(
+            conn_id="",
+            conn_type="",
             login="username",
             password="password",
             host="remote_host",
@@ -177,16 +187,20 @@ class TestWinRMHook:
                   }""",
         ),
     )
-    def test_run_without_stdout(self, mock_get_connection, mock_protocol):
+    def test_run_without_stdout_and_working_dir(self, mock_get_connection, mock_protocol):
         winrm_hook = WinRMHook(ssh_conn_id="conn_id")
-
+        working_dir = "c:\\test"
         mock_protocol.return_value.run_command = MagicMock(return_value="command_id")
         mock_protocol.return_value.get_command_output_raw = MagicMock(
             return_value=(b"stdout", b"stderr", 0, True)
         )
+        mock_protocol.return_value.open_shell = MagicMock()
 
-        return_code, stdout_buffer, stderr_buffer = winrm_hook.run("dir", return_output=False)
+        return_code, stdout_buffer, stderr_buffer = winrm_hook.run(
+            "dir", return_output=False, working_directory=working_dir
+        )
 
+        mock_protocol.return_value.open_shell.assert_called_once_with(working_directory=working_dir)
         assert return_code == 0
         assert not stdout_buffer
         assert stderr_buffer == [b"stderr"]
diff --git a/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py b/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py
index 8997395b0bf..95650799a3b 100644
--- a/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py
+++ b/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py
@@ -44,8 +44,11 @@ class TestWinRMOperator:
     def test_default_returning_0_command(self, mock_hook):
         stdout = [b"O", b"K"]
         command = "not_empty"
+        working_dir = "c:\\temp"
         mock_hook.run.return_value = (0, stdout, [])
-        op = WinRMOperator(task_id="test_task_id", winrm_hook=mock_hook, command=command)
+        op = WinRMOperator(
+            task_id="test_task_id", winrm_hook=mock_hook, command=command, working_directory=working_dir
+        )
         execute_result = op.execute(None)
         assert execute_result == b64encode(b"".join(stdout)).decode("utf-8")
         mock_hook.run.assert_called_once_with(
@@ -53,6 +56,7 @@ class TestWinRMOperator:
             ps_path=None,
             output_encoding="utf-8",
             return_output=True,
+            working_directory=working_dir,
         )
 
     @mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook")
@@ -94,6 +98,7 @@ class TestWinRMOperator:
                 ps_path=None,
                 output_encoding="utf-8",
                 return_output=True,
+                working_directory=None,
             )
         else:
             exception_msg = f"Error running cmd: {command}, return code: {real_return_code}, error: KO"

Reply via email to