dabla commented on code in PR #64465:
URL: https://github.com/apache/airflow/pull/64465#discussion_r3189291907


##########
providers/sftp/src/airflow/providers/sftp/hooks/sftp.py:
##########
@@ -789,24 +793,147 @@ def _get_value(self_val, conn_val, default=None):
         ssh_client_conn = await asyncssh.connect(**conn_config)
         return ssh_client_conn
 
-    async def list_directory(self, path: str = "") -> list[str] | None:  # 
type: ignore[return]
-        """Return a list of files on the SFTP server at the provided path."""
+    async def retrieve_file(
+        self,
+        remote_full_path: str,
+        local_full_path: str | os.PathLike[str] | IO[bytes],
+        encoding: str = "utf-8",
+        chunk_size: int = CHUNK_SIZE,
+    ) -> None:
+        """
+        Transfer the remote file to a local location asynchronously.
+
+        If local_full_path is a string or PathLike path, the file will be put 
at that location.
+        If it is a BytesIO or other binary file-like object, the file will be 
streamed into it.
+
+        :param remote_full_path: Full path to the remote file.
+        :param local_full_path: Full path to the local file or a binary 
file-like buffer.
+        :param encoding: Encoding used only as a fallback if backend returns 
text chunks (default: "utf-8").
+        :param chunk_size: Size of chunks to read at a time (default: 64KB).
+        """
+
+        def _to_bytes(chunk: str | bytes) -> bytes:
+            if isinstance(chunk, bytes):
+                return chunk
+            return chunk.encode(encoding)
+
         async with await self._get_conn() as ssh_conn:
-            sftp_client = await ssh_conn.start_sftp_client()
-            try:
-                files = await sftp_client.listdir(path)
-                return sorted(files)
-            except asyncssh.SFTPNoSuchFile:
-                return None
+            async with ssh_conn.start_sftp_client() as sftp:
+                async with sftp.open(remote_full_path, "rb") as remote_file:
+                    if isinstance(local_full_path, (str, os.PathLike)):
+                        async with aiofiles.open(local_full_path, "wb") as f:
+                            while True:
+                                chunk = await remote_file.read(chunk_size)
+                                if not chunk:
+                                    break
+                                await f.write(_to_bytes(chunk))
+                    else:
+                        while True:
+                            chunk = await remote_file.read(chunk_size)
+                            if not chunk:
+                                break
+                            local_full_path.write(_to_bytes(chunk))
+                        if hasattr(local_full_path, "seek"):
+                            local_full_path.seek(0)
+
+    async def store_file(self, remote_full_path: str, local_full_path: str | 
bytes | BytesIO) -> None:
+        """
+        Transfer a local file to the remote location.
+
+        If local_full_path_or_buffer is a string path, the file will be read
+        from that location.
+
+        :param remote_full_path: full path to the remote file
+        :param local_full_path: full path to the local file or a file-like 
buffer
+        """
+        async with await self._get_conn() as ssh_conn:
+            async with ssh_conn.start_sftp_client() as sftp:
+                if isinstance(local_full_path, bytes):
+                    local_full_path = BytesIO(local_full_path)
+
+                if isinstance(local_full_path, BytesIO):
+                    with suppress(asyncssh.SFTPFailure):
+                        remote_path = PurePosixPath(remote_full_path)
+                        await sftp.makedirs(str(remote_path.parent))
+
+                    async with sftp.open(remote_full_path, "wb") as f:
+                        local_full_path.seek(0)
+                        data = local_full_path.read()
+                        await f.write(data)
+                else:
+                    await sftp.put(str(local_full_path), remote_full_path)
+
+    async def mkdir(self, path: str) -> None:
+        """
+        Create a directory on the remote system asynchronously.
+
+        The default permissions are determined by the server. Parent 
directories are created as needed.
+
+        :param path: Full path to the remote directory to create.
+        """
+        async with await self._get_conn() as ssh_conn:
+            async with ssh_conn.start_sftp_client() as sftp:
+                await sftp.makedirs(path)
+
+    async def list_directory(self, path: str = "", recursive: bool = False) -> 
list[str] | None:
+        """
+        List files in a directory on the remote system asynchronously.
+
+        Lists entries under the given directory path.
+
+        If ``recursive=True``, descendants are returned as full paths.
+        If ``recursive=False`` (default), only one-level filenames are 
returned.

Review Comment:
   I've split this into `list_directory`, which doesn't recurse and respects 
the same contract as the sync version, and added an async `walktree`, which 
does the recursion like the sync version.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to