This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a3a596906cb Add host_proxy_cmd parameter to SSHHook and SFTPHook (#44565)
a3a596906cb is described below

commit a3a596906cb28fd7b303b15411801b733b9b49b4
Author: Ajit J Gupta <[email protected]>
AuthorDate: Tue Dec 3 15:52:27 2024 +0530

    Add host_proxy_cmd parameter to SSHHook and SFTPHook (#44565)
    
    * Add host_proxy_cmd parameter to SSHHook and SFTPHook
    
    * Fix unit test case by mocking the paramiko.ProxyCommand
    
    * Fixed test cases by adding timeout=None
---
 providers/src/airflow/providers/sftp/hooks/sftp.py |  2 ++
 providers/src/airflow/providers/ssh/hooks/ssh.py   |  5 ++--
 providers/tests/sftp/hooks/test_sftp.py            | 28 ++++++++++++++++++++
 providers/tests/ssh/hooks/test_ssh.py              | 30 ++++++++++++++++++++++
 4 files changed, 63 insertions(+), 2 deletions(-)

diff --git a/providers/src/airflow/providers/sftp/hooks/sftp.py b/providers/src/airflow/providers/sftp/hooks/sftp.py
index fec11666dec..1a826cd645c 100644
--- a/providers/src/airflow/providers/sftp/hooks/sftp.py
+++ b/providers/src/airflow/providers/sftp/hooks/sftp.py
@@ -83,6 +83,7 @@ class SFTPHook(SSHHook):
         self,
         ssh_conn_id: str | None = "sftp_default",
         ssh_hook: SSHHook | None = None,
+        host_proxy_cmd: str | None = None,
         *args,
         **kwargs,
     ) -> None:
@@ -115,6 +116,7 @@ class SFTPHook(SSHHook):
             ssh_conn_id = ftp_conn_id
 
         kwargs["ssh_conn_id"] = ssh_conn_id
+        kwargs["host_proxy_cmd"] = host_proxy_cmd
         self.ssh_conn_id = ssh_conn_id
 
         super().__init__(*args, **kwargs)
diff --git a/providers/src/airflow/providers/ssh/hooks/ssh.py b/providers/src/airflow/providers/ssh/hooks/ssh.py
index d41eed910db..3502459c644 100644
--- a/providers/src/airflow/providers/ssh/hooks/ssh.py
+++ b/providers/src/airflow/providers/ssh/hooks/ssh.py
@@ -119,6 +119,7 @@ class SSHHook(BaseHook):
         disabled_algorithms: dict | None = None,
         ciphers: list[str] | None = None,
         auth_timeout: int | None = None,
+        host_proxy_cmd: str | None = None,
     ) -> None:
         super().__init__()
         self.ssh_conn_id = ssh_conn_id
@@ -134,7 +135,7 @@ class SSHHook(BaseHook):
         self.banner_timeout = banner_timeout
         self.disabled_algorithms = disabled_algorithms
         self.ciphers = ciphers
-        self.host_proxy_cmd = None
+        self.host_proxy_cmd = host_proxy_cmd
         self.auth_timeout = auth_timeout
 
         # Default values, overridable from Connection
@@ -246,7 +247,7 @@ class SSHHook(BaseHook):
             with open(user_ssh_config_filename) as config_fd:
                 ssh_conf.parse(config_fd)
             host_info = ssh_conf.lookup(self.remote_host)
-            if host_info and host_info.get("proxycommand"):
+            if host_info and host_info.get("proxycommand") and not 
self.host_proxy_cmd:
                 self.host_proxy_cmd = host_info["proxycommand"]
 
             if not (self.password or self.key_file):
diff --git a/providers/tests/sftp/hooks/test_sftp.py b/providers/tests/sftp/hooks/test_sftp.py
index 7a7a2991a70..5f2c34a8cc0 100644
--- a/providers/tests/sftp/hooks/test_sftp.py
+++ b/providers/tests/sftp/hooks/test_sftp.py
@@ -788,3 +788,31 @@ class TestSFTPHookAsync:
         with pytest.raises(AirflowException) as exc:
             await hook.get_mod_time("/path/does_not/exist/")
         assert str(exc.value) == "No files matching"
+
+    @patch("paramiko.SSHClient")
+    @mock.patch("paramiko.ProxyCommand")
+    def test_sftp_hook_with_proxy_command(self, mock_proxy_command, 
mock_ssh_client):
+        mock_transport = mock.MagicMock()
+        mock_ssh_client.return_value.get_transport.return_value = 
mock_transport
+        mock_proxy_command.return_value = mock.MagicMock()
+
+        host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy 
proxy_host:port %h %p"
+        hook = SFTPHook(
+            remote_host="example.com",
+            username="user",
+            host_proxy_cmd=host_proxy_cmd,
+        )
+        hook.get_conn()
+
+        mock_proxy_command.assert_called_once_with(host_proxy_cmd)
+        mock_ssh_client.return_value.connect.assert_called_once_with(
+            hostname="example.com",
+            username="user",
+            timeout=None,
+            compress=True,
+            port=22,
+            sock=mock_proxy_command.return_value,
+            look_for_keys=True,
+            banner_timeout=30.0,
+            auth_timeout=None,
+        )
diff --git a/providers/tests/ssh/hooks/test_ssh.py b/providers/tests/ssh/hooks/test_ssh.py
index b6ff22b75ea..e09f2eeee0a 100644
--- a/providers/tests/ssh/hooks/test_ssh.py
+++ b/providers/tests/ssh/hooks/test_ssh.py
@@ -955,3 +955,33 @@ class TestSSHHook:
         client2 = hook.get_conn()
         assert client1 is not client2
         assert client2.get_transport().is_active()
+
+    @mock.patch("paramiko.SSHClient")
+    @mock.patch("paramiko.ProxyCommand")
+    def test_ssh_hook_with_proxy_command(self, mock_proxy_command, 
mock_ssh_client):
+        # Mock transport and proxy command behavior
+        mock_transport = mock.MagicMock()
+        mock_ssh_client.return_value.get_transport.return_value = 
mock_transport
+        mock_proxy_command.return_value = mock.MagicMock()
+
+        # Create the SSHHook with the proxy command
+        host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy 
proxy_host:port %h %p"
+        hook = SSHHook(
+            remote_host="example.com",
+            username="user",
+            host_proxy_cmd=host_proxy_cmd,
+        )
+        hook.get_conn()
+
+        mock_proxy_command.assert_called_once_with(host_proxy_cmd)
+        mock_ssh_client.return_value.connect.assert_called_once_with(
+            hostname="example.com",
+            username="user",
+            timeout=None,
+            compress=True,
+            port=22,
+            sock=mock_proxy_command.return_value,
+            look_for_keys=True,
+            banner_timeout=30.0,
+            auth_timeout=None,
+        )

Reply via email to