This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 12fd5fbdc94 Replace `sshtunnel` with native paramiko/asyncssh tunneling (#64299)
12fd5fbdc94 is described below
commit 12fd5fbdc94e3325574043aee2ee437914924c15
Author: Dev-iL <[email protected]>
AuthorDate: Sat Apr 4 15:17:03 2026 +0300
Replace `sshtunnel` with native paramiko/asyncssh tunneling (#64299)
---
devel-common/src/docs/utils/conf_constants.py | 1 -
docker-tests/tests/docker_tests/test_prod_image.py | 2 +-
docs/spelling_wordlist.txt | 4 +-
providers/ssh/docs/index.rst | 1 -
providers/ssh/pyproject.toml | 2 +-
.../ssh/src/airflow/providers/ssh/hooks/ssh.py | 89 +++--
providers/ssh/src/airflow/providers/ssh/tunnel.py | 427 +++++++++++++++++++++
providers/ssh/tests/unit/ssh/hooks/test_ssh.py | 233 ++++++++---
.../ssh/tests/unit/ssh/hooks/test_ssh_async.py | 62 +++
providers/ssh/tests/unit/ssh/test_tunnel.py | 171 +++++++++
uv.lock | 14 -
11 files changed, 891 insertions(+), 115 deletions(-)
diff --git a/devel-common/src/docs/utils/conf_constants.py b/devel-common/src/docs/utils/conf_constants.py
index 1bd89b3e8d6..b885ff0941e 100644
--- a/devel-common/src/docs/utils/conf_constants.py
+++ b/devel-common/src/docs/utils/conf_constants.py
@@ -265,7 +265,6 @@ def get_autodoc_mock_imports() -> list[str]:
"smbclient",
"snowflake",
"sqlalchemy-drill",
- "sshtunnel",
"telegram",
"tenacity",
"vertica_python",
diff --git a/docker-tests/tests/docker_tests/test_prod_image.py b/docker-tests/tests/docker_tests/test_prod_image.py
index 585f6a39d60..808013538fe 100644
--- a/docker-tests/tests/docker_tests/test_prod_image.py
+++ b/docker-tests/tests/docker_tests/test_prod_image.py
@@ -217,7 +217,7 @@ class TestPythonPackages:
"pyodbc": ["pyodbc"],
"redis": ["redis"],
"sendgrid": ["sendgrid"],
- "sftp/ssh": ["paramiko", "sshtunnel"],
+ "sftp/ssh": ["paramiko"],
"slack": ["slack_sdk"],
"statsd": ["statsd"],
"providers": [provider[len("apache-") :].replace("-", ".") for
provider in REGULAR_IMAGE_PROVIDERS],
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 1008b37f12e..4465ada3dfe 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -104,6 +104,7 @@ astroid
Async
async
asyncio
+asyncssh
athena
Atlassian
atlassian
@@ -1503,8 +1504,7 @@ srv
ssc
ssd
SSHClient
-sshtunnel
-SSHTunnelForwarder
+SSHTunnel
ssl
sslcert
sslkey
diff --git a/providers/ssh/docs/index.rst b/providers/ssh/docs/index.rst
index 2359510402b..c27343398ca 100644
--- a/providers/ssh/docs/index.rst
+++ b/providers/ssh/docs/index.rst
@@ -96,7 +96,6 @@ PIP package Version required
``apache-airflow-providers-common-compat`` ``>=1.12.0``
``asyncssh`` ``>=2.12.0``
``paramiko`` ``>=3.4.0,<4.0.0``
-``sshtunnel`` ``>=0.3.2``
========================================== ==================
Cross provider package dependencies
diff --git a/providers/ssh/pyproject.toml b/providers/ssh/pyproject.toml
index 1f36149da19..9fbce1d87cd 100644
--- a/providers/ssh/pyproject.toml
+++ b/providers/ssh/pyproject.toml
@@ -64,7 +64,7 @@ dependencies = [
"asyncssh>=2.12.0",
# TODO: Bump to >= 4.0.0 once https://github.com/apache/airflow/issues/54079 is handled
"paramiko>=3.4.0,<4.0.0",
- "sshtunnel>=0.3.2",
+
]
[dependency-groups]
diff --git a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
index 7d5f29c4278..d029a50772b 100644
--- a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
+++ b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
@@ -19,7 +19,6 @@
from __future__ import annotations
-import logging
import os
from base64 import decodebytes
from collections.abc import Sequence
@@ -30,11 +29,11 @@ from typing import Any
import paramiko
from paramiko.config import SSH_PORT
-from sshtunnel import SSHTunnelForwarder
from tenacity import Retrying, stop_after_attempt, wait_fixed, wait_random
from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
+from airflow.providers.ssh.tunnel import AsyncSSHTunnel, SSHTunnel
from airflow.utils.platform import getuser
try:
@@ -366,51 +365,37 @@ class SSHHook(BaseHook):
def get_tunnel(
self, remote_port: int, remote_host: str = "localhost", local_port:
int | None = None
- ) -> SSHTunnelForwarder:
+ ) -> SSHTunnel:
"""
- Create a tunnel between two hosts.
+ Create a local port-forwarding tunnel through the SSH connection.
- This is conceptually similar to ``ssh -L
<LOCAL_PORT>:host:<REMOTE_PORT>``.
+ This is conceptually similar to ``ssh -L
<LOCAL_PORT>:<remote_host>:<REMOTE_PORT>``.
- :param remote_port: The remote port to create a tunnel to
- :param remote_host: The remote host to create a tunnel to (default
localhost)
- :param local_port: The local port to attach the tunnel to
+ The returned ``SSHTunnel`` should be used as a context manager::
- :return: sshtunnel.SSHTunnelForwarder object
- """
- if local_port:
- local_bind_address: tuple[str, int] | tuple[str] = ("localhost",
local_port)
- else:
- local_bind_address = ("localhost",)
-
- tunnel_kwargs = {
- "ssh_port": self.port,
- "ssh_username": self.username,
- "ssh_pkey": self.key_file or self.pkey,
- "ssh_proxy": self.host_proxy,
- "local_bind_address": local_bind_address,
- "remote_bind_address": (remote_host, remote_port),
- "logger": self.log,
- }
+ with hook.get_tunnel(remote_port=5432) as tunnel:
+ connect_to("localhost", tunnel.local_bind_port)
- if self.password:
- password = self.password.strip()
- tunnel_kwargs.update(
- ssh_password=password,
- )
- else:
- tunnel_kwargs.update(
- host_pkey_directories=None,
- )
+ The ``.start()`` / ``.stop()`` methods still work but are deprecated.
- if not hasattr(self.log, "handlers"):
- # We need to not hit this
https://github.com/pahaz/sshtunnel/blob/dc0732884379a19a21bf7a49650d0708519ec54f/sshtunnel.py#L238-L239
- paramkio_log = logging.getLogger("paramiko.transport")
- paramkio_log.addHandler(logging.NullHandler())
- paramkio_log.propagate = True
- client = SSHTunnelForwarder(self.remote_host, **tunnel_kwargs)
+ .. versionchanged:: 4.4.0
+ Returns ``SSHTunnel`` instead of ``sshtunnel.SSHTunnelForwarder``.
+ The tunnel now reuses the hook's SSH connection (``get_conn()``)
+ instead of establishing a separate one.
- return client
+ :param remote_port: The remote port to create a tunnel to
+ :param remote_host: The remote host to create a tunnel to (default
localhost)
+ :param local_port: The local port to attach the tunnel to (None for
ephemeral)
+ :return: SSHTunnel instance
+ """
+ ssh_client = self.get_conn()
+ return SSHTunnel(
+ ssh_client=ssh_client,
+ remote_host=remote_host,
+ remote_port=remote_port,
+ local_port=local_port,
+ logger=self.log,
+ )
def _pkey_from_private_key(self, private_key: str, passphrase: str | None
= None) -> paramiko.PKey:
"""
@@ -662,6 +647,30 @@ class SSHHookAsync(BaseHook):
result = await ssh_conn.run(command, timeout=timeout, check=False)
return result.exit_status or 0, result.stdout or "", result.stderr
or ""
+ async def get_tunnel(
+ self, remote_port: int, remote_host: str = "localhost", local_port:
int | None = None
+ ) -> AsyncSSHTunnel:
+ """
+ Create an async local port-forwarding tunnel through the SSH
connection.
+
+ Usage::
+
+ async with await hook.get_tunnel(remote_port=5432) as tunnel:
+ connect_to("localhost", tunnel.local_bind_port)
+
+ :param remote_port: The remote port to create a tunnel to
+ :param remote_host: The remote host to create a tunnel to (default
localhost)
+ :param local_port: The local port to attach the tunnel to (None for
ephemeral)
+ :return: AsyncSSHTunnel instance
+ """
+ ssh_conn = await self._get_conn()
+ return AsyncSSHTunnel(
+ ssh_conn=ssh_conn,
+ remote_host=remote_host,
+ remote_port=remote_port,
+ local_port=local_port,
+ )
+
async def run_command_output(self, command: str, timeout: float | None =
None) -> str:
"""
Execute a command and return stdout.
diff --git a/providers/ssh/src/airflow/providers/ssh/tunnel.py b/providers/ssh/src/airflow/providers/ssh/tunnel.py
new file mode 100644
index 00000000000..69bb00743e6
--- /dev/null
+++ b/providers/ssh/src/airflow/providers/ssh/tunnel.py
@@ -0,0 +1,427 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+SSH tunnel implementations for the Airflow SSH provider.
+
+This module provides ``SSHTunnel`` (sync, paramiko-based) and
``AsyncSSHTunnel``
+(async, asyncssh-based) as replacements for the removed
``sshtunnel.SSHTunnelForwarder``.
+
+**SSHTunnel** reuses an already-connected ``paramiko.SSHClient`` from
+``SSHHook.get_conn()``, so all authentication and proxy configuration is
inherited
+automatically. It binds a local TCP socket and forwards connections to a remote
+host/port through the SSH transport using ``open_channel('direct-tcpip',
...)``.
+
+**AsyncSSHTunnel** wraps ``asyncssh.forward_local_port()`` and is intended for
use
+with ``SSHHookAsync``.
+
+Migration from ``sshtunnel.SSHTunnelForwarder``
+-----------------------------------------------
+
+Before::
+
+ from sshtunnel import SSHTunnelForwarder
+
+ tunnel = hook.get_tunnel(remote_port=5432)
+ tunnel.start()
+ # use tunnel.local_bind_port
+ tunnel.stop()
+
+After::
+
+ with hook.get_tunnel(remote_port=5432) as tunnel:
+ # use tunnel.local_bind_port
+
+The ``.start()`` / ``.stop()`` methods still exist but emit deprecation
warnings.
+Use the context manager interface instead.
+"""
+
+from __future__ import annotations
+
+import logging
+import socket
+import threading
+import warnings
+from select import select
+from typing import TYPE_CHECKING
+
+from airflow.exceptions import AirflowProviderDeprecationWarning
+
+if TYPE_CHECKING:
+ import asyncssh
+ import paramiko
+
+ from airflow.sdk.types import Logger
+
+
+log = logging.getLogger(__name__)
+
+# Attributes that existed on SSHTunnelForwarder but not on SSHTunnel.
+# Used by __getattr__ to provide a helpful migration message.
+_SSHTUNNELFORWARDER_ATTRS = frozenset(
+ {
+ "tunnel_is_up",
+ "skip_tunnel_checkup",
+ "ssh_host",
+ "ssh_port",
+ "ssh_username",
+ "ssh_password",
+ "ssh_pkey",
+ "ssh_proxy",
+ "local_bind_address",
+ "local_bind_addresses",
+ "local_bind_host",
+ "local_bind_hosts",
+ "remote_bind_address",
+ "remote_bind_addresses",
+ "tunnel_bindings",
+ "is_alive",
+ "raise_fwd_exc",
+ "daemon_forward_servers",
+ "daemon_transport",
+ }
+)
+
+
+class SSHTunnel:
+ """
+ Local port-forwarding tunnel over an existing paramiko SSH connection.
+
+ This replaces ``sshtunnel.SSHTunnelForwarder`` by using the SSH client's
+ transport directly via ``open_channel('direct-tcpip', ...)``.
+
+ The recommended usage is as a context manager::
+
+ client = hook.get_conn()
+ with SSHTunnel(client, "dbhost", 5432) as tunnel:
+ connect_to_db("localhost", tunnel.local_bind_port)
+
+ :param ssh_client: An already-connected ``paramiko.SSHClient``.
+ :param remote_host: The destination host to forward to (from the SSH
server's perspective).
+ :param remote_port: The destination port to forward to.
+ :param local_port: Local port to bind. ``None`` means an OS-assigned
ephemeral port.
+ :param logger: Optional logger instance. Falls back to the module logger.
+ """
+
+ def __init__(
+ self,
+ ssh_client: paramiko.SSHClient,
+ remote_host: str,
+ remote_port: int,
+ local_port: int | None = None,
+ logger: logging.Logger | Logger | None = None,
+ ) -> None:
+ self._ssh_client = ssh_client
+ self._remote_host = remote_host
+ self._remote_port = remote_port
+ self._logger: logging.Logger | Logger = logger or log
+ self._server_socket: socket.socket | None = None
+ self._thread: threading.Thread | None = None
+ # Self-pipe for waking the select loop on shutdown
+ self._shutdown_r: socket.socket | None = None
+ self._shutdown_w: socket.socket | None = None
+ self._running = False
+
+ # Bind the listening socket eagerly so local_bind_port is available
+ # before entering the context manager or calling start().
+ self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
1)
+ try:
+ self._server_socket.bind(("localhost", local_port or 0))
+ self._server_socket.listen(5)
+ except OSError:
+ self._server_socket.close()
+ raise
+ self._server_socket.setblocking(False)
+
+ # -- Public properties ---------------------------------------------------
+
+ @property
+ def local_bind_port(self) -> int:
+ """Return the local port the tunnel is listening on."""
+ if self._server_socket is None:
+ raise RuntimeError("Tunnel socket is not bound")
+ return self._server_socket.getsockname()[1]
+
+ @property
+ def local_bind_address(self) -> tuple[str, int]:
+ """Return ``('localhost', <port>)`` — the local address the tunnel is
listening on."""
+ return ("localhost", self.local_bind_port)
+
+ # -- Context manager -----------------------------------------------------
+
+ def __enter__(self) -> SSHTunnel:
+ self._start_forwarding()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+ self._stop_forwarding()
+
+ # -- Deprecated start/stop -----------------------------------------------
+
+ def start(self) -> None:
+ """Start the tunnel. **Deprecated** — use the context manager
interface instead."""
+ warnings.warn(
+ "SSHTunnel.start() is deprecated. Use the context manager
interface: "
+ "`with hook.get_tunnel(...) as tunnel:`",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ self._start_forwarding()
+
+ def stop(self) -> None:
+ """Stop the tunnel. **Deprecated** — use the context manager interface
instead."""
+ warnings.warn(
+ "SSHTunnel.stop() is deprecated. Use the context manager
interface: "
+ "`with hook.get_tunnel(...) as tunnel:`",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ self._stop_forwarding()
+
+ # -- Migration helper ----------------------------------------------------
+
+ def __getattr__(self, name: str):
+ if name in _SSHTUNNELFORWARDER_ATTRS:
+ raise AttributeError(
+ f"'{type(self).__name__}' has no attribute '{name}'. "
+ f"sshtunnel.SSHTunnelForwarder has been replaced by SSHTunnel.
"
+ f"Use the context manager interface and .local_bind_port
instead."
+ )
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute
'{name}'")
+
+ # -- Internal ------------------------------------------------------------
+
+ def _start_forwarding(self) -> None:
+ """Start the forwarding thread."""
+ if self._running:
+ return
+ self._shutdown_r, self._shutdown_w = socket.socketpair()
+ self._running = True
+ self._thread = threading.Thread(target=self._serve_forever,
daemon=True)
+ self._thread.start()
+
+ def _stop_forwarding(self) -> None:
+ """Signal the forwarding thread to stop and wait for it."""
+ if not self._running:
+ return
+ self._running = False
+ # Wake the select loop
+ if self._shutdown_w is not None:
+ try:
+ self._shutdown_w.send(b"\x00")
+ except OSError:
+ pass
+ if self._thread is not None:
+ self._thread.join(timeout=5)
+ self._thread = None
+ # Close the shutdown pair
+ for sock in (self._shutdown_r, self._shutdown_w):
+ if sock is not None:
+ try:
+ sock.close()
+ except OSError:
+ pass
+ self._shutdown_r = None
+ self._shutdown_w = None
+ # Close the server socket
+ if self._server_socket is not None:
+ try:
+ self._server_socket.close()
+ except OSError:
+ pass
+ self._server_socket = None
+
+ def _serve_forever(self) -> None:
+ """Accept connections on the local socket and forward them through
SSH."""
+ if self._server_socket is None or self._shutdown_r is None:
+ return
+ server_socket = self._server_socket
+ shutdown_r = self._shutdown_r
+ active_channels: list[tuple[socket.socket, paramiko.Channel]] = []
+ try:
+ while self._running:
+ read_fds: list[socket.socket | paramiko.Channel] =
[server_socket, shutdown_r]
+ for local_sock, chan in active_channels:
+ read_fds.append(local_sock)
+ read_fds.append(chan)
+
+ try:
+ readable, _, _ = select(read_fds, [], [], 1.0)
+ except (OSError, ValueError):
+ break
+
+ for fd in readable:
+ if fd is shutdown_r:
+ return
+ if fd is server_socket:
+ self._accept_connection(active_channels)
+ else:
+ self._forward_data(fd, active_channels)
+
+ # Clean up closed channels
+ active_channels = [(s, c) for s, c in active_channels if not
(s.fileno() == -1 or c.closed)]
+ finally:
+ for local_sock, chan in active_channels:
+ self._close_pair(local_sock, chan)
+
+ def _accept_connection(self, active_channels: list[tuple[socket.socket,
paramiko.Channel]]) -> None:
+ """Accept a new local connection and open an SSH channel for it."""
+ if self._server_socket is None:
+ return
+ try:
+ client_sock, addr = self._server_socket.accept()
+ except OSError:
+ return
+
+ transport = self._ssh_client.get_transport()
+ if transport is None or not transport.is_active():
+ self._logger.warning("SSH transport is not active, rejecting
connection from %s", addr)
+ client_sock.close()
+ return
+
+ try:
+ channel = transport.open_channel(
+ "direct-tcpip",
+ (self._remote_host, self._remote_port),
+ addr,
+ )
+ except Exception:
+ self._logger.warning(
+ "Failed to open SSH channel to %s:%s",
+ self._remote_host,
+ self._remote_port,
+ exc_info=True,
+ )
+ client_sock.close()
+ return
+
+ if channel is None:
+ self._logger.warning("SSH channel request was rejected")
+ client_sock.close()
+ return
+
+ active_channels.append((client_sock, channel))
+
+ def _forward_data(
+ self,
+ fd: socket.socket | paramiko.Channel,
+ active_channels: list[tuple[socket.socket, paramiko.Channel]],
+ ) -> None:
+ """Forward data between a local socket and its paired SSH channel."""
+ for local_sock, chan in active_channels:
+ if fd is local_sock:
+ try:
+ data = local_sock.recv(16384)
+ except OSError:
+ data = b""
+ if not data:
+ self._close_pair(local_sock, chan)
+ return
+ try:
+ chan.sendall(data)
+ except OSError:
+ self._logger.warning("Error sending data to SSH channel,
closing connection")
+ self._close_pair(local_sock, chan)
+ return
+ if fd is chan:
+ try:
+ data = chan.recv(16384)
+ except OSError:
+ data = b""
+ if not data:
+ self._close_pair(local_sock, chan)
+ return
+ try:
+ local_sock.sendall(data)
+ except OSError:
+ self._logger.warning("Error sending data to local socket,
closing connection")
+ self._close_pair(local_sock, chan)
+ return
+
+ @staticmethod
+ def _close_pair(local_sock: socket.socket, chan: paramiko.Channel) -> None:
+ """Close both ends of a forwarded connection."""
+ for closeable in (chan, local_sock):
+ try:
+ closeable.close()
+ except OSError:
+ pass
+
+
+class AsyncSSHTunnel:
+ """
+ Async local port-forwarding tunnel over an asyncssh SSH connection.
+
+ This wraps ``asyncssh.SSHClientConnection.forward_local_port()`` and is
+ intended for use with ``SSHHookAsync``.
+
+ Usage::
+
+ async with await hook.get_tunnel(remote_port=5432) as tunnel:
+ connect_to_db("localhost", tunnel.local_bind_port)
+
+ On exit, both the forwarding listener and the underlying SSH connection
+ are closed.
+
+ :param ssh_conn: An ``asyncssh.SSHClientConnection``.
+ :param remote_host: The destination host to forward to.
+ :param remote_port: The destination port to forward to.
+ :param local_port: Local port to bind. ``None`` means an OS-assigned
ephemeral port.
+ """
+
+ def __init__(
+ self,
+ ssh_conn: asyncssh.SSHClientConnection,
+ remote_host: str,
+ remote_port: int,
+ local_port: int | None = None,
+ ) -> None:
+ self._ssh_conn = ssh_conn
+ self._remote_host = remote_host
+ self._remote_port = remote_port
+ self._local_port = local_port or 0
+ self._listener: asyncssh.SSHListener | None = None
+
+ @property
+ def local_bind_port(self) -> int:
+ """Return the local port the tunnel is listening on."""
+ if self._listener is None:
+ raise RuntimeError("Tunnel is not started. Use `async with` to
start it.")
+ return self._listener.get_port()
+
+ async def __aenter__(self) -> AsyncSSHTunnel:
+ try:
+ self._listener = await self._ssh_conn.forward_local_port(
+ "localhost",
+ self._local_port,
+ self._remote_host,
+ self._remote_port,
+ )
+ except BaseException:
+ self._ssh_conn.close()
+ await self._ssh_conn.wait_closed()
+ raise
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+ if self._listener is not None:
+ self._listener.close()
+ await self._listener.wait_closed()
+ self._listener = None
+ self._ssh_conn.close()
+ await self._ssh_conn.wait_closed()
diff --git a/providers/ssh/tests/unit/ssh/hooks/test_ssh.py b/providers/ssh/tests/unit/ssh/hooks/test_ssh.py
index 62a91e21afd..74652cd3258 100644
--- a/providers/ssh/tests/unit/ssh/hooks/test_ssh.py
+++ b/providers/ssh/tests/unit/ssh/hooks/test_ssh.py
@@ -354,8 +354,9 @@ class TestSSHHook:
auth_timeout=None,
)
- @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnelForwarder")
- def test_tunnel_with_password(self, ssh_mock):
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnel", autospec=True)
+ @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
+ def test_tunnel_with_password(self, ssh_client_mock, tunnel_mock):
hook = SSHHook(
remote_host="remote_host",
port="port",
@@ -366,34 +367,27 @@ class TestSSHHook:
)
with hook.get_tunnel(1234):
- ssh_mock.assert_called_once_with(
- "remote_host",
- ssh_port="port",
- ssh_username="username",
- ssh_password="password",
- ssh_pkey="fake.file",
- ssh_proxy=None,
- local_bind_address=("localhost",),
- remote_bind_address=("localhost", 1234),
+ tunnel_mock.assert_called_once_with(
+ ssh_client=ssh_client_mock.return_value,
+ remote_host="localhost",
+ remote_port=1234,
+ local_port=None,
logger=hook.log,
)
- @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnelForwarder")
- def test_tunnel_without_password(self, ssh_mock):
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnel", autospec=True)
+ @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
+ def test_tunnel_without_password(self, ssh_client_mock, tunnel_mock):
hook = SSHHook(
remote_host="remote_host", port="port", username="username",
conn_timeout=10, key_file="fake.file"
)
with hook.get_tunnel(1234):
- ssh_mock.assert_called_once_with(
- "remote_host",
- ssh_port="port",
- ssh_username="username",
- ssh_pkey="fake.file",
- ssh_proxy=None,
- local_bind_address=("localhost",),
- remote_bind_address=("localhost", 1234),
- host_pkey_directories=None,
+ tunnel_mock.assert_called_once_with(
+ ssh_client=ssh_client_mock.return_value,
+ remote_host="localhost",
+ remote_port=1234,
+ local_port=None,
logger=hook.log,
)
@@ -408,8 +402,9 @@ class TestSSHHook:
ssh_hook =
SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA_FALSE_LOOK_FOR_KEYS)
assert ssh_hook.look_for_keys is False
- @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnelForwarder")
- def test_tunnel_with_private_key(self, ssh_mock):
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnel", autospec=True)
+ @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
+ def test_tunnel_with_private_key(self, ssh_client_mock, tunnel_mock):
hook = SSHHook(
ssh_conn_id=self.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
remote_host="remote_host",
@@ -419,20 +414,17 @@ class TestSSHHook:
)
with hook.get_tunnel(1234):
- ssh_mock.assert_called_once_with(
- "remote_host",
- ssh_port="port",
- ssh_username="username",
- ssh_pkey=TEST_PKEY,
- ssh_proxy=None,
- local_bind_address=("localhost",),
- remote_bind_address=("localhost", 1234),
- host_pkey_directories=None,
+ tunnel_mock.assert_called_once_with(
+ ssh_client=ssh_client_mock.return_value,
+ remote_host="localhost",
+ remote_port=1234,
+ local_port=None,
logger=hook.log,
)
- @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnelForwarder")
- def test_tunnel_with_private_key_passphrase(self, ssh_mock):
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnel", autospec=True)
+ @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
+ def test_tunnel_with_private_key_passphrase(self, ssh_client_mock,
tunnel_mock):
hook = SSHHook(
ssh_conn_id=self.CONN_SSH_WITH_PRIVATE_KEY_PASSPHRASE_EXTRA,
remote_host="remote_host",
@@ -442,20 +434,17 @@ class TestSSHHook:
)
with hook.get_tunnel(1234):
- ssh_mock.assert_called_once_with(
- "remote_host",
- ssh_port="port",
- ssh_username="username",
- ssh_pkey=TEST_PKEY,
- ssh_proxy=None,
- local_bind_address=("localhost",),
- remote_bind_address=("localhost", 1234),
- host_pkey_directories=None,
+ tunnel_mock.assert_called_once_with(
+ ssh_client=ssh_client_mock.return_value,
+ remote_host="localhost",
+ remote_port=1234,
+ local_port=None,
logger=hook.log,
)
- @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnelForwarder")
- def test_tunnel_with_private_key_ecdsa(self, ssh_mock):
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnel", autospec=True)
+ @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
+ def test_tunnel_with_private_key_ecdsa(self, ssh_client_mock, tunnel_mock):
hook = SSHHook(
ssh_conn_id=self.CONN_SSH_WITH_PRIVATE_KEY_ECDSA_EXTRA,
remote_host="remote_host",
@@ -465,15 +454,11 @@ class TestSSHHook:
)
with hook.get_tunnel(1234):
- ssh_mock.assert_called_once_with(
- "remote_host",
- ssh_port="port",
- ssh_username="username",
- ssh_pkey=TEST_PKEY_ECDSA,
- ssh_proxy=None,
- local_bind_address=("localhost",),
- remote_bind_address=("localhost", 1234),
- host_pkey_directories=None,
+ tunnel_mock.assert_called_once_with(
+ ssh_client=ssh_client_mock.return_value,
+ remote_host="localhost",
+ remote_port=1234,
+ local_port=None,
logger=hook.log,
)
@@ -945,3 +930,141 @@ class TestSSHHook:
banner_timeout=30.0,
auth_timeout=None,
)
+
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnel", autospec=True)
+ @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
+ def test_tunnel_with_local_port(self, ssh_client_mock, tunnel_mock):
+ hook = SSHHook(
+ remote_host="remote_host",
+ port="port",
+ username="username",
+ conn_timeout=10,
+ key_file="fake.file",
+ )
+
+ with hook.get_tunnel(remote_port=5432, local_port=15432):
+ tunnel_mock.assert_called_once_with(
+ ssh_client=ssh_client_mock.return_value,
+ remote_host="localhost",
+ remote_port=5432,
+ local_port=15432,
+ logger=hook.log,
+ )
+
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHTunnel", autospec=True)
+ @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
+ def test_tunnel_with_remote_host(self, ssh_client_mock, tunnel_mock):
+ hook = SSHHook(
+ remote_host="remote_host",
+ port="port",
+ username="username",
+ conn_timeout=10,
+ key_file="fake.file",
+ )
+
+ with hook.get_tunnel(remote_port=5432, remote_host="dbhost"):
+ tunnel_mock.assert_called_once_with(
+ ssh_client=ssh_client_mock.return_value,
+ remote_host="dbhost",
+ remote_port=5432,
+ local_port=None,
+ logger=hook.log,
+ )
+
+
+class TestSSHTunnel:
+ """Tests for the SSHTunnel class."""
+
+ def test_deprecation_warning_on_start(self):
+ """Calling .start() should emit a DeprecationWarning."""
+ import warnings
+
+ from airflow.providers.ssh.tunnel import SSHTunnel
+
+ ssh_client = mock.MagicMock(spec=paramiko.SSHClient)
+ tunnel = SSHTunnel(ssh_client, "remote", 5432)
+ try:
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ tunnel.start()
+ assert len(w) == 1
+ assert issubclass(w[0].category, DeprecationWarning)
+ assert "deprecated" in str(w[0].message).lower()
+ finally:
+ tunnel._stop_forwarding()
+
+ def test_deprecation_warning_on_stop(self):
+ """Calling .stop() should emit a DeprecationWarning."""
+ import warnings
+
+ from airflow.providers.ssh.tunnel import SSHTunnel
+
+ ssh_client = mock.MagicMock(spec=paramiko.SSHClient)
+ tunnel = SSHTunnel(ssh_client, "remote", 5432)
+ try:
+ tunnel._start_forwarding()
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ tunnel.stop()
+ assert len(w) == 1
+ assert issubclass(w[0].category, DeprecationWarning)
+ assert "deprecated" in str(w[0].message).lower()
+ finally:
+ tunnel._stop_forwarding()
+
+ def test_getattr_migration_hint(self):
+ """Accessing SSHTunnelForwarder-specific attrs raises AttributeError
with migration hint."""
+ from airflow.providers.ssh.tunnel import SSHTunnel
+
+ ssh_client = mock.MagicMock(spec=paramiko.SSHClient)
+ tunnel = SSHTunnel(ssh_client, "remote", 5432)
+ try:
+ with pytest.raises(AttributeError, match="SSHTunnelForwarder has
been replaced"):
+ tunnel.tunnel_is_up
+ with pytest.raises(AttributeError, match="SSHTunnelForwarder has
been replaced"):
+ tunnel.skip_tunnel_checkup
+ finally:
+ tunnel._stop_forwarding()
+
+ def test_getattr_unknown_attr(self):
+ """Accessing a truly unknown attr raises a normal AttributeError."""
+ from airflow.providers.ssh.tunnel import SSHTunnel
+
+ ssh_client = mock.MagicMock(spec=paramiko.SSHClient)
+ tunnel = SSHTunnel(ssh_client, "remote", 5432)
+ try:
+ with pytest.raises(AttributeError, match="has no attribute
'nonexistent'"):
+ tunnel.nonexistent
+ finally:
+ tunnel._stop_forwarding()
+
+ def test_ephemeral_port(self):
+ """When local_port is None, an ephemeral port is assigned."""
+ from airflow.providers.ssh.tunnel import SSHTunnel
+
+ ssh_client = mock.MagicMock(spec=paramiko.SSHClient)
+ tunnel = SSHTunnel(ssh_client, "remote", 5432)
+ try:
+ assert tunnel.local_bind_port > 0
+ assert tunnel.local_bind_address == ("localhost",
tunnel.local_bind_port)
+ finally:
+ tunnel._stop_forwarding()
+
+ def test_explicit_port(self):
+ """When local_port is given, it binds to that port."""
+ from airflow.providers.ssh.tunnel import SSHTunnel
+
+ ssh_client = mock.MagicMock(spec=paramiko.SSHClient)
+ tunnel = SSHTunnel(ssh_client, "remote", 5432, local_port=19999)
+ try:
+ assert tunnel.local_bind_port == 19999
+ finally:
+ tunnel._stop_forwarding()
+
+ def test_context_manager(self):
+ """Context manager enters and exits cleanly."""
+ from airflow.providers.ssh.tunnel import SSHTunnel
+
+ ssh_client = mock.MagicMock(spec=paramiko.SSHClient)
+ with SSHTunnel(ssh_client, "remote", 5432) as tunnel:
+ assert tunnel.local_bind_port > 0
diff --git a/providers/ssh/tests/unit/ssh/hooks/test_ssh_async.py b/providers/ssh/tests/unit/ssh/hooks/test_ssh_async.py
index ee92c809be8..1e56fb4c372 100644
--- a/providers/ssh/tests/unit/ssh/hooks/test_ssh_async.py
+++ b/providers/ssh/tests/unit/ssh/hooks/test_ssh_async.py
@@ -170,3 +170,65 @@ class TestSSHHookAsync:
with mock.patch.object(hook, "run_command", return_value=(0, "test
output", "")):
output = await hook.run_command_output("echo test")
assert output == "test output"
+
+ @pytest.mark.asyncio
+ async def test_get_tunnel(self):
+ """Test that get_tunnel returns an AsyncSSHTunnel."""
+ from airflow.providers.ssh.tunnel import AsyncSSHTunnel
+
+ hook = SSHHookAsync(ssh_conn_id="test_conn")
+
+ mock_ssh_conn = mock.MagicMock()
+
+ with mock.patch.object(hook, "_get_conn", new_callable=mock.AsyncMock,
return_value=mock_ssh_conn):
+ tunnel = await hook.get_tunnel(remote_port=5432)
+ assert isinstance(tunnel, AsyncSSHTunnel)
+
+ @pytest.mark.asyncio
+ async def test_get_tunnel_async_context_manager(self):
+ """Test AsyncSSHTunnel as async context manager with
local_bind_port."""
+
+ hook = SSHHookAsync(ssh_conn_id="test_conn")
+
+ mock_listener = mock.MagicMock()
+ mock_listener.get_port.return_value = 15432
+ mock_listener.close = mock.MagicMock()
+ mock_listener.wait_closed = mock.AsyncMock()
+
+ mock_ssh_conn = mock.MagicMock()
+ mock_ssh_conn.forward_local_port =
mock.AsyncMock(return_value=mock_listener)
+ mock_ssh_conn.close = mock.MagicMock()
+ mock_ssh_conn.wait_closed = mock.AsyncMock()
+
+ with mock.patch.object(hook, "_get_conn", new_callable=mock.AsyncMock,
return_value=mock_ssh_conn):
+ async with await hook.get_tunnel(remote_port=5432) as tunnel:
+ assert tunnel.local_bind_port == 15432
+
mock_ssh_conn.forward_local_port.assert_called_once_with("localhost", 0,
"localhost", 5432)
+
+ # Verify cleanup
+ mock_listener.close.assert_called_once()
+ mock_listener.wait_closed.assert_called_once()
+ mock_ssh_conn.close.assert_called_once()
+ mock_ssh_conn.wait_closed.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_get_tunnel_with_local_port(self):
+ """Test AsyncSSHTunnel with explicit local port."""
+ hook = SSHHookAsync(ssh_conn_id="test_conn")
+
+ mock_listener = mock.MagicMock()
+ mock_listener.get_port.return_value = 19999
+ mock_listener.close = mock.MagicMock()
+ mock_listener.wait_closed = mock.AsyncMock()
+
+ mock_ssh_conn = mock.MagicMock()
+ mock_ssh_conn.forward_local_port =
mock.AsyncMock(return_value=mock_listener)
+ mock_ssh_conn.close = mock.MagicMock()
+ mock_ssh_conn.wait_closed = mock.AsyncMock()
+
+ with mock.patch.object(hook, "_get_conn", new_callable=mock.AsyncMock,
return_value=mock_ssh_conn):
+ async with await hook.get_tunnel(remote_port=5432,
local_port=19999) as tunnel:
+ assert tunnel.local_bind_port == 19999
+ mock_ssh_conn.forward_local_port.assert_called_once_with(
+ "localhost", 19999, "localhost", 5432
+ )
diff --git a/providers/ssh/tests/unit/ssh/test_tunnel.py
b/providers/ssh/tests/unit/ssh/test_tunnel.py
new file mode 100644
index 00000000000..12cf24e7dc5
--- /dev/null
+++ b/providers/ssh/tests/unit/ssh/test_tunnel.py
@@ -0,0 +1,171 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import socket
+import threading
+from unittest import mock
+
+import paramiko
+import pytest
+
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.ssh.tunnel import SSHTunnel
+
+
[email protected]
+def mock_ssh_client():
+ client = mock.MagicMock(spec=paramiko.SSHClient)
+ transport = mock.MagicMock(spec=paramiko.Transport)
+ transport.is_active.return_value = True
+ client.get_transport.return_value = transport
+ return client
+
+
+class TestSSHTunnel:
+ def test_local_bind_port_is_available_after_init(self, mock_ssh_client):
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432)
+ try:
+ port = tunnel.local_bind_port
+ assert isinstance(port, int)
+ assert port > 0
+ finally:
+ tunnel._stop_forwarding()
+ if tunnel._server_socket is not None:
+ tunnel._server_socket.close()
+
+ def test_local_bind_address(self, mock_ssh_client):
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432)
+ try:
+ host, port = tunnel.local_bind_address
+ assert host == "localhost"
+ assert port == tunnel.local_bind_port
+ finally:
+ tunnel._stop_forwarding()
+ if tunnel._server_socket is not None:
+ tunnel._server_socket.close()
+
+ def test_explicit_local_port(self, mock_ssh_client):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.bind(("localhost", 0))
+ free_port = sock.getsockname()[1]
+ sock.close()
+
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432,
local_port=free_port)
+ try:
+ assert tunnel.local_bind_port == free_port
+ finally:
+ tunnel._stop_forwarding()
+ if tunnel._server_socket is not None:
+ tunnel._server_socket.close()
+
+ def test_context_manager_starts_and_stops(self, mock_ssh_client):
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432)
+ with tunnel as t:
+ assert t is tunnel
+ assert tunnel._running is True
+ assert tunnel._thread is not None
+ assert tunnel._thread.is_alive()
+ assert tunnel._running is False
+ assert tunnel._thread is None
+
+ def test_start_emits_deprecation_warning(self, mock_ssh_client):
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432)
+ try:
+ with pytest.warns(AirflowProviderDeprecationWarning,
match="SSHTunnel.start"):
+ tunnel.start()
+ finally:
+ tunnel._stop_forwarding()
+ if tunnel._server_socket is not None:
+ tunnel._server_socket.close()
+
+ def test_stop_emits_deprecation_warning(self, mock_ssh_client):
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432)
+ tunnel._start_forwarding()
+ with pytest.warns(AirflowProviderDeprecationWarning,
match="SSHTunnel.stop"):
+ tunnel.stop()
+
+ def test_getattr_migration_error(self, mock_ssh_client):
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432)
+ try:
+ with pytest.raises(AttributeError, match="SSHTunnelForwarder has
been replaced"):
+ tunnel.tunnel_is_up
+ finally:
+ tunnel._stop_forwarding()
+ if tunnel._server_socket is not None:
+ tunnel._server_socket.close()
+
+ def test_getattr_unknown_attribute(self, mock_ssh_client):
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432)
+ try:
+ with pytest.raises(AttributeError, match="has no attribute
'nonexistent'"):
+ tunnel.nonexistent
+ finally:
+ tunnel._stop_forwarding()
+ if tunnel._server_socket is not None:
+ tunnel._server_socket.close()
+
+ def test_double_start_is_noop(self, mock_ssh_client):
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432)
+ try:
+ tunnel._start_forwarding()
+ thread1 = tunnel._thread
+ tunnel._start_forwarding()
+ assert tunnel._thread is thread1
+ finally:
+ tunnel._stop_forwarding()
+ if tunnel._server_socket is not None:
+ tunnel._server_socket.close()
+
+ def test_stop_without_start_is_noop(self, mock_ssh_client):
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432)
+ try:
+ tunnel._stop_forwarding()
+ finally:
+ if tunnel._server_socket is not None:
+ tunnel._server_socket.close()
+
+ def test_forwarding_thread_accepts_and_forwards(self, mock_ssh_client):
+ """Test that data is forwarded between local socket and SSH channel."""
+ channel = mock.MagicMock(spec=paramiko.Channel)
+ channel.closed = False
+ channel.recv.return_value = b"response"
+
+ transport = mock_ssh_client.get_transport.return_value
+ transport.open_channel.return_value = channel
+
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432)
+ with tunnel:
+ client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ try:
+ client.connect(("localhost", tunnel.local_bind_port))
+ client.sendall(b"hello")
+ # Give the forwarding thread time to process
+ threading.Event().wait(0.2)
+ finally:
+ client.close()
+
+ def test_custom_logger(self, mock_ssh_client):
+ custom_logger = mock.MagicMock()
+ tunnel = SSHTunnel(mock_ssh_client, "remotehost", 5432,
logger=custom_logger)
+ try:
+ assert tunnel._logger is custom_logger
+ finally:
+ tunnel._stop_forwarding()
+ if tunnel._server_socket is not None:
+ tunnel._server_socket.close()
diff --git a/uv.lock b/uv.lock
index 77805548b68..83a4cab1bc0 100644
--- a/uv.lock
+++ b/uv.lock
@@ -7052,7 +7052,6 @@ dependencies = [
{ name = "apache-airflow-providers-common-compat" },
{ name = "asyncssh" },
{ name = "paramiko" },
- { name = "sshtunnel" },
]
[package.dev-dependencies]
@@ -7072,7 +7071,6 @@ requires-dist = [
{ name = "apache-airflow-providers-common-compat", editable =
"providers/common/compat" },
{ name = "asyncssh", specifier = ">=2.12.0" },
{ name = "paramiko", specifier = ">=3.4.0,<4.0.0" },
- { name = "sshtunnel", specifier = ">=0.3.2" },
]
[package.metadata.requires-dev]
@@ -20615,18 +20613,6 @@ wheels = [
{ url =
"https://files.pythonhosted.org/packages/cc/a7/29935d7b8572ba378514864bd3d99dfb01518a5c20366b794b6cdb283356/sshfs-2025.11.0-py3-none-any.whl",
hash =
"sha256:90a3a2e815d28a0e8475f10fe8ef0127a507b60f77df17c3473c1a28e78c7f4b", size
= 16420, upload-time = "2025-12-05T05:07:21.935Z" },
]
-[[package]]
-name = "sshtunnel"
-version = "0.4.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "paramiko" },
-]
-sdist = { url =
"https://files.pythonhosted.org/packages/8d/ad/4c587adf79865be268ee0b6bd52cfaa7a75d827a23ced072dc5ab554b4af/sshtunnel-0.4.0.tar.gz",
hash =
"sha256:e7cb0ea774db81bf91844db22de72a40aae8f7b0f9bb9ba0f666d474ef6bf9fc", size
= 62716, upload-time = "2021-01-11T13:26:32.975Z" }
-wheels = [
- { url =
"https://files.pythonhosted.org/packages/58/13/8476c4328dcadfe26f8bd7f3a1a03bf9ddb890a7e7b692f54a179bc525bf/sshtunnel-0.4.0-py2.py3-none-any.whl",
hash =
"sha256:98e54c26f726ab8bd42b47a3a21fca5c3e60f58956f0f70de2fb8ab0046d0606", size
= 24729, upload-time = "2021-01-11T13:26:29.969Z" },
-]
-
[[package]]
name = "sspilib"
version = "0.5.0"