Copilot commented on code in PR #64465:
URL: https://github.com/apache/airflow/pull/64465#discussion_r3066489537


##########
providers/sftp/src/airflow/providers/sftp/pools/sftp.py:
##########
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from contextlib import asynccontextmanager, suppress
+from threading import Lock
+from typing import TYPE_CHECKING
+
+from airflow.providers.common.compat.sdk import conf
+from airflow.providers.sftp.hooks.sftp import SFTPHookAsync
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+if TYPE_CHECKING:
+    import asyncssh
+
+
+class SFTPClientPool(LoggingMixin):
+    """Lazy Thread-safe and Async-safe Singleton SFTP pool that keeps SSH and 
SFTP clients alive until exit, and limits concurrent usage to pool_size."""
+
+    _instances: dict[str, SFTPClientPool] = {}
+    _lock = Lock()
+
+    def __new__(cls, sftp_conn_id: str, pool_size: int | None = None):
+        with cls._lock:
+            if sftp_conn_id not in cls._instances:
+                instance = super().__new__(cls)
+                instance._pre_init(sftp_conn_id, pool_size)
+                cls._instances[sftp_conn_id] = instance
+            else:
+                # Validate that subsequent constructions for the same 
sftp_conn_id
+                # do not request a different pool_size, which would otherwise 
be
+                # silently ignored due to the singleton behavior.
+                instance = cls._instances[sftp_conn_id]
+                requested_pool_size = pool_size or conf.getint("core", 
"parallelism")
+                if instance.pool_size != requested_pool_size:
+                    raise ValueError(
+                        f"SFTPClientPool for sftp_conn_id '{sftp_conn_id}' has 
already been "
+                        f"initialised with pool_size={instance.pool_size}, but 
a different "
+                        f"pool_size={requested_pool_size} was requested."
+                    )
+            return cls._instances[sftp_conn_id]
+
+    def __init__(self, sftp_conn_id: str, pool_size: int | None = None):
+        # Prevent parent __init__ argument errors
+        pass
+
+    def _pre_init(self, sftp_conn_id: str, pool_size: int | None):
+        """Initialize the Singleton structure synchronously."""
+        LoggingMixin.__init__(self)
+        self.sftp_conn_id = sftp_conn_id
+        self.pool_size = pool_size or conf.getint("core", "parallelism")
+        self._idle: asyncio.LifoQueue[tuple[asyncssh.SSHClientConnection, 
asyncssh.SFTPClient]] = (
+            asyncio.LifoQueue()
+        )
+        self._in_use: set[tuple[asyncssh.SSHClientConnection, 
asyncssh.SFTPClient]] = set()
+        self._semaphore = asyncio.Semaphore(self.pool_size)
+        self._init_lock = asyncio.Lock()
+        self._initialized = False
+        self._closed = False
+        self.log.info("SFTPClientPool initialised...")
+
+    async def _ensure_initialized(self):
+        """Ensure pool is usable (also handles re-opening after close)."""
+        if self._initialized and not self._closed:
+            return
+
+        async with self._init_lock:
+            if not self._initialized or self._closed:
+                self.log.info("Initializing / resetting SFTPClientPool for 
'%s'", self.sftp_conn_id)
+                self._idle = asyncio.LifoQueue()
+                self._in_use.clear()
+                self._semaphore = asyncio.Semaphore(self.pool_size)
+                self._closed = False
+                self._initialized = True
+
+    async def _create_connection(
+        self,
+    ) -> tuple[asyncssh.SSHClientConnection, asyncssh.SFTPClient]:
+        ssh_conn = await 
SFTPHookAsync(sftp_conn_id=self.sftp_conn_id)._get_conn()
+        sftp = await ssh_conn.start_sftp_client()
+        self.log.info("Created new SFTP connection for sftp_conn_id '%s'", 
self.sftp_conn_id)
+        return ssh_conn, sftp
+
+    async def acquire(self):
+        await self._ensure_initialized()
+
+        if self._closed:
+            raise RuntimeError("Cannot acquire from a closed SFTPClientPool")
+
+        self.log.debug("Acquiring SFTP connection for '%s'", self.sftp_conn_id)
+
+        await self._semaphore.acquire()
+
+        try:
+            try:
+                pair = self._idle.get_nowait()
+            except asyncio.QueueEmpty:
+                pair = await self._create_connection()
+
+            self._in_use.add(pair)
+            return pair
+        except Exception:
+            self._semaphore.release()
+            raise
+
+    async def release(self, pair):
+        if pair not in self._in_use:
+            self.log.warning("Attempted to release unknown or already released 
connection")
+            return
+
+        self._in_use.discard(pair)
+
+        if self._closed:
+            ssh, sftp = pair
+            with suppress(Exception):
+                sftp.exit()
+            with suppress(Exception):
+                ssh.close()
+        else:

Review Comment:
   When closing asyncssh connections, calling `ssh.close()` without awaiting 
`ssh.wait_closed()` can leave the transport not fully shut down (see 
`AsyncSSHTunnel` which awaits `wait_closed()`). Since `close()`/`release()` are 
async here, consider awaiting `wait_closed()` after `ssh.close()` in both the 
`_closed` release path and `close()` to ensure sockets/tasks are cleaned up 
deterministically.



##########
providers/sftp/src/airflow/providers/sftp/hooks/sftp.py:
##########
@@ -789,24 +793,145 @@ def _get_value(self_val, conn_val, default=None):
         ssh_client_conn = await asyncssh.connect(**conn_config)
         return ssh_client_conn
 
-    async def list_directory(self, path: str = "") -> list[str] | None:  # 
type: ignore[return]
-        """Return a list of files on the SFTP server at the provided path."""
+    async def retrieve_file(
+        self,
+        remote_full_path: str,
+        local_full_path: str | os.PathLike[str] | IO[bytes],
+        encoding: str = "utf-8",
+        chunk_size: int = CHUNK_SIZE,
+    ) -> None:
+        """
+        Transfer the remote file to a local location asynchronously.
+
+        If local_full_path is a string or PathLike path, the file will be put 
at that location.
+        If it is a BytesIO or other binary file-like object, the file will be 
streamed into it.
+
+        :param remote_full_path: Full path to the remote file.
+        :param local_full_path: Full path to the local file or a binary 
file-like buffer.
+        :param encoding: Encoding used only as a fallback if backend returns 
text chunks (default: "utf-8").
+        :param chunk_size: Size of chunks to read at a time (default: 64KB).
+        """
+
+        def _to_bytes(chunk: str | bytes) -> bytes:
+            if isinstance(chunk, bytes):
+                return chunk
+            return chunk.encode(encoding)
+
+        async with await self._get_conn() as ssh_conn:
+            async with ssh_conn.start_sftp_client() as sftp:
+                async with sftp.open(remote_full_path, "rb") as remote_file:
+                    if isinstance(local_full_path, (str, os.PathLike)):
+                        async with aiofiles.open(local_full_path, "wb") as f:
+                            while True:
+                                chunk = await remote_file.read(chunk_size)
+                                if not chunk:
+                                    break
+                                await f.write(_to_bytes(chunk))
+                    else:
+                        while True:
+                            chunk = await remote_file.read(chunk_size)
+                            if not chunk:
+                                break
+                            local_full_path.write(_to_bytes(chunk))
+                        if hasattr(local_full_path, "seek"):
+                            local_full_path.seek(0)
+
+    async def store_file(self, remote_full_path: str, local_full_path: str | 
bytes | BytesIO) -> None:
+        """
+        Transfer a local file to the remote location.
+
+        If local_full_path_or_buffer is a string path, the file will be read
+        from that location.
+
+        :param remote_full_path: full path to the remote file
+        :param local_full_path: full path to the local file or a file-like 
buffer
+        """
+        async with await self._get_conn() as ssh_conn:
+            async with ssh_conn.start_sftp_client() as sftp:
+                if isinstance(local_full_path, bytes):
+                    local_full_path = BytesIO(local_full_path)
+
+                if isinstance(local_full_path, BytesIO):
+                    with suppress(asyncssh.SFTPFailure):
+                        remote_path = PurePosixPath(remote_full_path)
+                        await sftp.makedirs(str(remote_path.parent))
+
+                    async with sftp.open(remote_full_path, "wb") as f:
+                        local_full_path.seek(0)
+                        data = local_full_path.read()
+                        await f.write(data)
+                else:
+                    await sftp.put(str(local_full_path), remote_full_path)
+
+    async def mkdir(self, path: str) -> None:
+        """
+        Create a directory on the remote system asynchronously.
+
+        The default permissions are determined by the server. Parent 
directories are created as needed.
+
+        :param path: Full path to the remote directory to create.
+        """
+        async with await self._get_conn() as ssh_conn:
+            async with ssh_conn.start_sftp_client() as sftp:
+                await sftp.makedirs(path)
+
+    async def list_directory(self, path: str, recursive: bool = False) -> 
list[str] | None:

Review Comment:
   `list_directory` previously accepted a default `path` value; changing the 
signature to require `path` is a backward-incompatible change for async hook 
consumers. Consider keeping `path: str = ""` (and adding `recursive` as an 
optional kwarg) to preserve API compatibility.
   



##########
providers/sftp/src/airflow/providers/sftp/hooks/sftp.py:
##########
@@ -789,24 +793,145 @@ def _get_value(self_val, conn_val, default=None):
         ssh_client_conn = await asyncssh.connect(**conn_config)
         return ssh_client_conn
 
-    async def list_directory(self, path: str = "") -> list[str] | None:  # 
type: ignore[return]
-        """Return a list of files on the SFTP server at the provided path."""
+    async def retrieve_file(
+        self,
+        remote_full_path: str,
+        local_full_path: str | os.PathLike[str] | IO[bytes],
+        encoding: str = "utf-8",
+        chunk_size: int = CHUNK_SIZE,
+    ) -> None:
+        """
+        Transfer the remote file to a local location asynchronously.
+
+        If local_full_path is a string or PathLike path, the file will be put 
at that location.
+        If it is a BytesIO or other binary file-like object, the file will be 
streamed into it.
+
+        :param remote_full_path: Full path to the remote file.
+        :param local_full_path: Full path to the local file or a binary 
file-like buffer.
+        :param encoding: Encoding used only as a fallback if backend returns 
text chunks (default: "utf-8").
+        :param chunk_size: Size of chunks to read at a time (default: 64KB).
+        """
+
+        def _to_bytes(chunk: str | bytes) -> bytes:
+            if isinstance(chunk, bytes):
+                return chunk
+            return chunk.encode(encoding)
+
+        async with await self._get_conn() as ssh_conn:
+            async with ssh_conn.start_sftp_client() as sftp:
+                async with sftp.open(remote_full_path, "rb") as remote_file:
+                    if isinstance(local_full_path, (str, os.PathLike)):
+                        async with aiofiles.open(local_full_path, "wb") as f:
+                            while True:
+                                chunk = await remote_file.read(chunk_size)
+                                if not chunk:
+                                    break
+                                await f.write(_to_bytes(chunk))
+                    else:
+                        while True:
+                            chunk = await remote_file.read(chunk_size)
+                            if not chunk:
+                                break
+                            local_full_path.write(_to_bytes(chunk))
+                        if hasattr(local_full_path, "seek"):
+                            local_full_path.seek(0)
+
+    async def store_file(self, remote_full_path: str, local_full_path: str | 
bytes | BytesIO) -> None:
+        """
+        Transfer a local file to the remote location.
+
+        If local_full_path_or_buffer is a string path, the file will be read
+        from that location.
+
+        :param remote_full_path: full path to the remote file
+        :param local_full_path: full path to the local file or a file-like 
buffer
+        """
+        async with await self._get_conn() as ssh_conn:
+            async with ssh_conn.start_sftp_client() as sftp:
+                if isinstance(local_full_path, bytes):
+                    local_full_path = BytesIO(local_full_path)
+
+                if isinstance(local_full_path, BytesIO):
+                    with suppress(asyncssh.SFTPFailure):
+                        remote_path = PurePosixPath(remote_full_path)
+                        await sftp.makedirs(str(remote_path.parent))
+
+                    async with sftp.open(remote_full_path, "wb") as f:
+                        local_full_path.seek(0)
+                        data = local_full_path.read()
+                        await f.write(data)
+                else:
+                    await sftp.put(str(local_full_path), remote_full_path)
+
+    async def mkdir(self, path: str) -> None:
+        """
+        Create a directory on the remote system asynchronously.
+
+        The default permissions are determined by the server. Parent 
directories are created as needed.
+
+        :param path: Full path to the remote directory to create.
+        """
+        async with await self._get_conn() as ssh_conn:
+            async with ssh_conn.start_sftp_client() as sftp:
+                await sftp.makedirs(path)
+
+    async def list_directory(self, path: str, recursive: bool = False) -> 
list[str] | None:
+        """
+        List files in a directory on the remote system asynchronously.
+
+        Lists entries under the given directory path.
+
+        If ``recursive=True``, descendants are returned as full paths.
+        If ``recursive=False`` (default), only one-level filenames are 
returned.
+
+        :param path: Full path to the remote directory to list.
+        :param recursive: Whether to recursively list descendants.
+        :return: List of file paths found under the directory, or None if the 
directory does not exist.
+        """
         async with await self._get_conn() as ssh_conn:
-            sftp_client = await ssh_conn.start_sftp_client()
-            try:
-                files = await sftp_client.listdir(path)
-                return sorted(files)
-            except asyncssh.SFTPNoSuchFile:
-                return None
+            async with ssh_conn.start_sftp_client() as sftp:
+                if not recursive:
+                    try:
+                        files = await sftp.readdir(path)
+                    except asyncssh.SFTPNoSuchFile:
+                        return None
+                    return [
+                        os.fsdecode(file.filename)
+                        for file in files
+                        if os.fsdecode(file.filename) not in {".", ".."}
+                    ]

Review Comment:
   The non-recursive branch returns filenames in server order and doesn’t sort, 
while the sync `SFTPHook.list_directory()` returns a sorted list. For 
determinism and parity between sync/async hooks, consider sorting the filtered 
results before returning (and similarly sorting recursive results if ordering 
matters).



##########
providers/sftp/src/airflow/providers/sftp/pools/sftp.py:
##########
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from contextlib import asynccontextmanager, suppress
+from threading import Lock
+from typing import TYPE_CHECKING
+
+from airflow.providers.common.compat.sdk import conf
+from airflow.providers.sftp.hooks.sftp import SFTPHookAsync
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+if TYPE_CHECKING:
+    import asyncssh
+
+
+class SFTPClientPool(LoggingMixin):
+    """Lazy Thread-safe and Async-safe Singleton SFTP pool that keeps SSH and 
SFTP clients alive until exit, and limits concurrent usage to pool_size."""
+
+    _instances: dict[str, SFTPClientPool] = {}
+    _lock = Lock()
+
+    def __new__(cls, sftp_conn_id: str, pool_size: int | None = None):
+        with cls._lock:
+            if sftp_conn_id not in cls._instances:
+                instance = super().__new__(cls)
+                instance._pre_init(sftp_conn_id, pool_size)
+                cls._instances[sftp_conn_id] = instance
+            else:
+                # Validate that subsequent constructions for the same 
sftp_conn_id
+                # do not request a different pool_size, which would otherwise 
be
+                # silently ignored due to the singleton behavior.
+                instance = cls._instances[sftp_conn_id]
+                requested_pool_size = pool_size or conf.getint("core", 
"parallelism")
+                if instance.pool_size != requested_pool_size:

Review Comment:
   Using `pool_size or conf.getint(...)` treats `pool_size=0` as “unset” and 
silently falls back to the config value. It would be safer to distinguish 
`None` from an explicit value and validate `pool_size >= 1`, raising a clear 
error for invalid values.



##########
providers/sftp/src/airflow/providers/sftp/pools/sftp.py:
##########
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from contextlib import asynccontextmanager, suppress
+from threading import Lock
+from typing import TYPE_CHECKING
+
+from airflow.providers.common.compat.sdk import conf
+from airflow.providers.sftp.hooks.sftp import SFTPHookAsync
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+if TYPE_CHECKING:
+    import asyncssh
+
+
+class SFTPClientPool(LoggingMixin):
+    """Lazy Thread-safe and Async-safe Singleton SFTP pool that keeps SSH and 
SFTP clients alive until exit, and limits concurrent usage to pool_size."""
+
+    _instances: dict[str, SFTPClientPool] = {}
+    _lock = Lock()
+
+    def __new__(cls, sftp_conn_id: str, pool_size: int | None = None):
+        with cls._lock:
+            if sftp_conn_id not in cls._instances:
+                instance = super().__new__(cls)
+                instance._pre_init(sftp_conn_id, pool_size)
+                cls._instances[sftp_conn_id] = instance
+            else:
+                # Validate that subsequent constructions for the same 
sftp_conn_id
+                # do not request a different pool_size, which would otherwise 
be
+                # silently ignored due to the singleton behavior.
+                instance = cls._instances[sftp_conn_id]
+                requested_pool_size = pool_size or conf.getint("core", 
"parallelism")
+                if instance.pool_size != requested_pool_size:
+                    raise ValueError(
+                        f"SFTPClientPool for sftp_conn_id '{sftp_conn_id}' has 
already been "
+                        f"initialised with pool_size={instance.pool_size}, but 
a different "
+                        f"pool_size={requested_pool_size} was requested."
+                    )
+            return cls._instances[sftp_conn_id]
+
+    def __init__(self, sftp_conn_id: str, pool_size: int | None = None):
+        # Prevent parent __init__ argument errors
+        pass
+
+    def _pre_init(self, sftp_conn_id: str, pool_size: int | None):
+        """Initialize the Singleton structure synchronously."""
+        LoggingMixin.__init__(self)
+        self.sftp_conn_id = sftp_conn_id
+        self.pool_size = pool_size or conf.getint("core", "parallelism")
+        self._idle: asyncio.LifoQueue[tuple[asyncssh.SSHClientConnection, 
asyncssh.SFTPClient]] = (
+            asyncio.LifoQueue()
+        )
+        self._in_use: set[tuple[asyncssh.SSHClientConnection, 
asyncssh.SFTPClient]] = set()
+        self._semaphore = asyncio.Semaphore(self.pool_size)
+        self._init_lock = asyncio.Lock()
+        self._initialized = False
+        self._closed = False
+        self.log.info("SFTPClientPool initialised...")
+
+    async def _ensure_initialized(self):
+        """Ensure pool is usable (also handles re-opening after close)."""
+        if self._initialized and not self._closed:
+            return
+
+        async with self._init_lock:
+            if not self._initialized or self._closed:
+                self.log.info("Initializing / resetting SFTPClientPool for 
'%s'", self.sftp_conn_id)
+                self._idle = asyncio.LifoQueue()
+                self._in_use.clear()
+                self._semaphore = asyncio.Semaphore(self.pool_size)
+                self._closed = False
+                self._initialized = True
+
+    async def _create_connection(
+        self,
+    ) -> tuple[asyncssh.SSHClientConnection, asyncssh.SFTPClient]:
+        ssh_conn = await 
SFTPHookAsync(sftp_conn_id=self.sftp_conn_id)._get_conn()
+        sftp = await ssh_conn.start_sftp_client()

Review Comment:
   `_create_connection()` can leak an SSH connection if `start_sftp_client()` 
fails after `_get_conn()` succeeds (the exception is propagated and `acquire()` 
only releases the semaphore). Consider wrapping SFTP client creation in 
`try/except` and closing the SSH connection on failure.
   



##########
providers/sftp/src/airflow/providers/sftp/hooks/sftp.py:
##########
@@ -789,24 +793,145 @@ def _get_value(self_val, conn_val, default=None):
         ssh_client_conn = await asyncssh.connect(**conn_config)
         return ssh_client_conn
 
-    async def list_directory(self, path: str = "") -> list[str] | None:  # 
type: ignore[return]
-        """Return a list of files on the SFTP server at the provided path."""
+    async def retrieve_file(
+        self,
+        remote_full_path: str,
+        local_full_path: str | os.PathLike[str] | IO[bytes],
+        encoding: str = "utf-8",
+        chunk_size: int = CHUNK_SIZE,
+    ) -> None:
+        """
+        Transfer the remote file to a local location asynchronously.
+
+        If local_full_path is a string or PathLike path, the file will be put 
at that location.
+        If it is a BytesIO or other binary file-like object, the file will be 
streamed into it.
+
+        :param remote_full_path: Full path to the remote file.
+        :param local_full_path: Full path to the local file or a binary 
file-like buffer.
+        :param encoding: Encoding used only as a fallback if backend returns 
text chunks (default: "utf-8").
+        :param chunk_size: Size of chunks to read at a time (default: 64KB).
+        """
+
+        def _to_bytes(chunk: str | bytes) -> bytes:
+            if isinstance(chunk, bytes):
+                return chunk
+            return chunk.encode(encoding)
+
+        async with await self._get_conn() as ssh_conn:
+            async with ssh_conn.start_sftp_client() as sftp:
+                async with sftp.open(remote_full_path, "rb") as remote_file:
+                    if isinstance(local_full_path, (str, os.PathLike)):

Review Comment:
   `ssh_conn.start_sftp_client()` and `sftp.open()` are asyncssh coroutines in 
typical asyncssh usage; using them directly in `async with` without awaiting 
will raise a runtime error (e.g. coroutine is not an async context manager). 
Please `await` these calls and ensure the SFTP client + file handles are closed 
in a `try/finally` (or equivalent) so resources are released on errors. This 
same pattern appears in `store_file`, `mkdir`, `list_directory`, 
`read_directory`, and `get_mod_time` below.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to