This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a3a596906cb Add host_proxy_cmd parameter to SSHHook and SFTPHook (#44565)
a3a596906cb is described below
commit a3a596906cb28fd7b303b15411801b733b9b49b4
Author: Ajit J Gupta <[email protected]>
AuthorDate: Tue Dec 3 15:52:27 2024 +0530
Add host_proxy_cmd parameter to SSHHook and SFTPHook (#44565)
* Add host_proxy_cmd parameter to SSHHook and SFTPHook
* Fix unit test cases by mocking paramiko.ProxyCommand
* Fix test cases by adding timeout=None to the expected connect() call
---
providers/src/airflow/providers/sftp/hooks/sftp.py | 2 ++
providers/src/airflow/providers/ssh/hooks/ssh.py | 5 ++--
providers/tests/sftp/hooks/test_sftp.py | 28 ++++++++++++++++++++
providers/tests/ssh/hooks/test_ssh.py | 30 ++++++++++++++++++++++
4 files changed, 63 insertions(+), 2 deletions(-)
diff --git a/providers/src/airflow/providers/sftp/hooks/sftp.py
b/providers/src/airflow/providers/sftp/hooks/sftp.py
index fec11666dec..1a826cd645c 100644
--- a/providers/src/airflow/providers/sftp/hooks/sftp.py
+++ b/providers/src/airflow/providers/sftp/hooks/sftp.py
@@ -83,6 +83,7 @@ class SFTPHook(SSHHook):
self,
ssh_conn_id: str | None = "sftp_default",
ssh_hook: SSHHook | None = None,
+ host_proxy_cmd: str | None = None,
*args,
**kwargs,
) -> None:
@@ -115,6 +116,7 @@ class SFTPHook(SSHHook):
ssh_conn_id = ftp_conn_id
kwargs["ssh_conn_id"] = ssh_conn_id
+ kwargs["host_proxy_cmd"] = host_proxy_cmd
self.ssh_conn_id = ssh_conn_id
super().__init__(*args, **kwargs)
diff --git a/providers/src/airflow/providers/ssh/hooks/ssh.py
b/providers/src/airflow/providers/ssh/hooks/ssh.py
index d41eed910db..3502459c644 100644
--- a/providers/src/airflow/providers/ssh/hooks/ssh.py
+++ b/providers/src/airflow/providers/ssh/hooks/ssh.py
@@ -119,6 +119,7 @@ class SSHHook(BaseHook):
disabled_algorithms: dict | None = None,
ciphers: list[str] | None = None,
auth_timeout: int | None = None,
+ host_proxy_cmd: str | None = None,
) -> None:
super().__init__()
self.ssh_conn_id = ssh_conn_id
@@ -134,7 +135,7 @@ class SSHHook(BaseHook):
self.banner_timeout = banner_timeout
self.disabled_algorithms = disabled_algorithms
self.ciphers = ciphers
- self.host_proxy_cmd = None
+ self.host_proxy_cmd = host_proxy_cmd
self.auth_timeout = auth_timeout
# Default values, overridable from Connection
@@ -246,7 +247,7 @@ class SSHHook(BaseHook):
with open(user_ssh_config_filename) as config_fd:
ssh_conf.parse(config_fd)
host_info = ssh_conf.lookup(self.remote_host)
- if host_info and host_info.get("proxycommand"):
+ if host_info and host_info.get("proxycommand") and not
self.host_proxy_cmd:
self.host_proxy_cmd = host_info["proxycommand"]
if not (self.password or self.key_file):
diff --git a/providers/tests/sftp/hooks/test_sftp.py
b/providers/tests/sftp/hooks/test_sftp.py
index 7a7a2991a70..5f2c34a8cc0 100644
--- a/providers/tests/sftp/hooks/test_sftp.py
+++ b/providers/tests/sftp/hooks/test_sftp.py
@@ -788,3 +788,31 @@ class TestSFTPHookAsync:
with pytest.raises(AirflowException) as exc:
await hook.get_mod_time("/path/does_not/exist/")
assert str(exc.value) == "No files matching"
+
+ @patch("paramiko.SSHClient")
+ @mock.patch("paramiko.ProxyCommand")
+ def test_sftp_hook_with_proxy_command(self, mock_proxy_command,
mock_ssh_client):
+ mock_transport = mock.MagicMock()
+ mock_ssh_client.return_value.get_transport.return_value =
mock_transport
+ mock_proxy_command.return_value = mock.MagicMock()
+
+ host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy
proxy_host:port %h %p"
+ hook = SFTPHook(
+ remote_host="example.com",
+ username="user",
+ host_proxy_cmd=host_proxy_cmd,
+ )
+ hook.get_conn()
+
+ mock_proxy_command.assert_called_once_with(host_proxy_cmd)
+ mock_ssh_client.return_value.connect.assert_called_once_with(
+ hostname="example.com",
+ username="user",
+ timeout=None,
+ compress=True,
+ port=22,
+ sock=mock_proxy_command.return_value,
+ look_for_keys=True,
+ banner_timeout=30.0,
+ auth_timeout=None,
+ )
diff --git a/providers/tests/ssh/hooks/test_ssh.py
b/providers/tests/ssh/hooks/test_ssh.py
index b6ff22b75ea..e09f2eeee0a 100644
--- a/providers/tests/ssh/hooks/test_ssh.py
+++ b/providers/tests/ssh/hooks/test_ssh.py
@@ -955,3 +955,33 @@ class TestSSHHook:
client2 = hook.get_conn()
assert client1 is not client2
assert client2.get_transport().is_active()
+
+ @mock.patch("paramiko.SSHClient")
+ @mock.patch("paramiko.ProxyCommand")
+ def test_ssh_hook_with_proxy_command(self, mock_proxy_command,
mock_ssh_client):
+ # Mock transport and proxy command behavior
+ mock_transport = mock.MagicMock()
+ mock_ssh_client.return_value.get_transport.return_value =
mock_transport
+ mock_proxy_command.return_value = mock.MagicMock()
+
+ # Create the SSHHook with the proxy command
+ host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy
proxy_host:port %h %p"
+ hook = SSHHook(
+ remote_host="example.com",
+ username="user",
+ host_proxy_cmd=host_proxy_cmd,
+ )
+ hook.get_conn()
+
+ mock_proxy_command.assert_called_once_with(host_proxy_cmd)
+ mock_ssh_client.return_value.connect.assert_called_once_with(
+ hostname="example.com",
+ username="user",
+ timeout=None,
+ compress=True,
+ port=22,
+ sock=mock_proxy_command.return_value,
+ look_for_keys=True,
+ banner_timeout=30.0,
+ auth_timeout=None,
+ )