uranusjr commented on a change in pull request #17921:
URL: https://github.com/apache/airflow/pull/17921#discussion_r699103802



##########
File path: tests/providers/sftp/operators/test_sftp.py
##########
@@ -417,21 +614,48 @@ def test_arg_checking(self):
             pass
         assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
 
+        # when you work with specific files, then you should use *_files 
arguments
+        task_4 = SFTPOperator(
+            task_id="task_4",
+            ssh_conn_id=TEST_CONN_ID,
+            local_folder="/tmp/dir_for_remote_transfer/from_remote/csv/",
+            remote_filepath=[
+                "/tmp/transfer_file/remote/put_files_file1.txt",
+                "/tmp/transfer_file/remote/put_files_file2.txt",
+            ],
+            operation=SFTPOperation.GET,
+            create_intermediate_dirs=True,
+        )
+        try:
+            task_4.execute(None)
+        except Exception:
+            pass
+        assert task_4.local_filepath is None
+
     def delete_local_resource(self):
         if os.path.exists(self.test_local_filepath):
-            os.remove(self.test_local_filepath)
+            if os.path.isdir(self.test_local_filepath):
+                shutil.rmtree(self.test_local_filepath)
+            else:
+                os.remove(self.test_local_filepath)
         if os.path.exists(self.test_local_filepath_int_dir):
-            os.remove(self.test_local_filepath_int_dir)
+            if os.path.isdir(self.test_local_filepath_int_dir):
+                shutil.rmtree(self.test_local_filepath_int_dir)
+            else:
+                os.remove(self.test_local_filepath_int_dir)
         if os.path.exists(self.test_local_dir):
-            os.rmdir(self.test_local_dir)
+            if os.path.isdir(self.test_local_dir):
+                shutil.rmtree(self.test_local_dir)
+            else:
+                os.remove(self.test_local_dir)
 
     def delete_remote_resource(self):
         if os.path.exists(self.test_remote_filepath):
             # check the remote file content
             remove_file_task = SSHOperator(
                 task_id="test_check_file",
                 ssh_hook=self.hook,
-                command=f"rm {self.test_remote_filepath}",
+                command=f"rm -rf {self.test_remote_filepath}",

Review comment:
       ```suggestion
                   command=f"rm -r {self.test_remote_filepath}",
   ```
   
   But is this (and the if-else branches above) needed? I don’t see anywhere that a `filepath` points to a directory, or a `dir` points to a file. The usages seem pretty consistent.

##########
File path: tests/providers/sftp/operators/test_sftp.py
##########
@@ -441,7 +665,9 @@ def delete_remote_resource(self):
         if os.path.exists(self.test_remote_filepath_int_dir):
             os.remove(self.test_remote_filepath_int_dir)
         if os.path.exists(self.test_remote_dir):
-            os.rmdir(self.test_remote_dir)
+            import shutil

Review comment:
       ```suggestion
   ```
   
   Is this not already imported globally?

##########
File path: airflow/providers/sftp/operators/sftp.py
##########
@@ -134,29 +175,81 @@ def execute(self, context: Any) -> str:
 
             with self.ssh_hook.get_conn() as ssh_client:
                 sftp_client = ssh_client.open_sftp()
-                if self.operation.lower() == SFTPOperation.GET:
-                    local_folder = os.path.dirname(self.local_filepath)
-                    if self.create_intermediate_dirs:
-                        Path(local_folder).mkdir(parents=True, exist_ok=True)
-                    file_msg = f"from {self.remote_filepath} to 
{self.local_filepath}"
-                    self.log.info("Starting to transfer %s", file_msg)
-                    sftp_client.get(self.remote_filepath, self.local_filepath)
-                else:
-                    remote_folder = os.path.dirname(self.remote_filepath)
-                    if self.create_intermediate_dirs:
-                        _make_intermediate_dirs(
-                            sftp_client=sftp_client,
-                            remote_directory=remote_folder,
+                if self.local_filepath and self.remote_filepath:
+                    if isinstance(self.local_filepath, list) and 
isinstance(self.remote_filepath, str):
+                        for file_path in self.local_filepath:
+                            local_folder = os.path.dirname(file_path)
+                            local_file = os.path.basename(file_path)
+                            file_msg = file_path
+                            self._transfer(sftp_client, local_folder, 
local_file, self.remote_filepath)
+                    elif isinstance(self.remote_filepath, list) and 
isinstance(self.local_filepath, str):
+                        for file_path in self.remote_filepath:
+                            remote_folder = os.path.dirname(file_path)
+                            remote_file = os.path.basename(file_path)
+                            file_msg = file_path
+                            self._transfer(sftp_client, self.local_filepath, 
remote_file, remote_folder)
+                    elif isinstance(self.remote_filepath, str) and 
isinstance(self.local_filepath, str):
+                        local_folder = os.path.dirname(self.local_filepath)
+                        file_msg = self.local_filepath
+                        self._transfer(
+                            sftp_client,
+                            local_folder,
+                            self.local_filepath,
+                            self.remote_filepath,
+                            only_file=True,
                         )
-                    file_msg = f"from {self.local_filepath} to 
{self.remote_filepath}"
-                    self.log.info("Starting to transfer file %s", file_msg)
-                    sftp_client.put(self.local_filepath, self.remote_filepath, 
confirm=self.confirm)
+                elif self.local_folder and self.remote_folder:
+                    if self.operation.lower() == SFTPOperation.PUT:
+                        files_list = 
self._search_files(os.listdir(self.local_folder))
+                        for file in files_list:
+                            local_file = os.path.basename(file)
+                            file_msg = file
+                            self._transfer(sftp_client, self.local_folder, 
local_file, self.remote_folder)
+                    elif self.operation.lower() == SFTPOperation.GET:
+                        files_list = 
self._search_files(sftp_client.listdir(self.remote_folder))
+                        for file in files_list:
+                            remote_file = os.path.basename(file)
+                            file_msg = file
+                            self._transfer(sftp_client, self.local_folder, 
remote_file, self.remote_folder)
+                else:
+                    raise AirflowException(f"Argument mismatch, please read 
docs \n {SFTPOperator.__doc__}")

Review comment:
       The argument check should be done in `__init__` (i.e. fail when the DAG is parsed instead of when the task runs) and raise `TypeError` (the customary Python exception for mismatched arguments). It should also state clearly what the mismatch is (e.g. `local_folder cannot be used with remote_file`) instead of including a big block of text.

##########
File path: tests/providers/sftp/operators/test_sftp.py
##########
@@ -417,21 +614,48 @@ def test_arg_checking(self):
             pass
         assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
 
+        # when you work with specific files, then you should use *_files 
arguments
+        task_4 = SFTPOperator(
+            task_id="task_4",
+            ssh_conn_id=TEST_CONN_ID,
+            local_folder="/tmp/dir_for_remote_transfer/from_remote/csv/",
+            remote_filepath=[
+                "/tmp/transfer_file/remote/put_files_file1.txt",
+                "/tmp/transfer_file/remote/put_files_file2.txt",
+            ],
+            operation=SFTPOperation.GET,
+            create_intermediate_dirs=True,
+        )
+        try:
+            task_4.execute(None)
+        except Exception:
+            pass

Review comment:
       Use `pytest.raises` to test exceptions.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to