uranusjr commented on a change in pull request #17921:
URL: https://github.com/apache/airflow/pull/17921#discussion_r699103802
##########
File path: tests/providers/sftp/operators/test_sftp.py
##########
@@ -417,21 +614,48 @@ def test_arg_checking(self):
pass
assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
+ # when you work with specific files, then you should use *_files
arguments
+ task_4 = SFTPOperator(
+ task_id="task_4",
+ ssh_conn_id=TEST_CONN_ID,
+ local_folder="/tmp/dir_for_remote_transfer/from_remote/csv/",
+ remote_filepath=[
+ "/tmp/transfer_file/remote/put_files_file1.txt",
+ "/tmp/transfer_file/remote/put_files_file2.txt",
+ ],
+ operation=SFTPOperation.GET,
+ create_intermediate_dirs=True,
+ )
+ try:
+ task_4.execute(None)
+ except Exception:
+ pass
+ assert task_4.local_filepath is None
+
def delete_local_resource(self):
if os.path.exists(self.test_local_filepath):
- os.remove(self.test_local_filepath)
+ if os.path.isdir(self.test_local_filepath):
+ shutil.rmtree(self.test_local_filepath)
+ else:
+ os.remove(self.test_local_filepath)
if os.path.exists(self.test_local_filepath_int_dir):
- os.remove(self.test_local_filepath_int_dir)
+ if os.path.isdir(self.test_local_filepath_int_dir):
+ shutil.rmtree(self.test_local_filepath_int_dir)
+ else:
+ os.remove(self.test_local_filepath_int_dir)
if os.path.exists(self.test_local_dir):
- os.rmdir(self.test_local_dir)
+ if os.path.isdir(self.test_local_dir):
+ shutil.rmtree(self.test_local_dir)
+ else:
+ os.remove(self.test_local_dir)
def delete_remote_resource(self):
if os.path.exists(self.test_remote_filepath):
# check the remote file content
remove_file_task = SSHOperator(
task_id="test_check_file",
ssh_hook=self.hook,
- command=f"rm {self.test_remote_filepath}",
+ command=f"rm -rf {self.test_remote_filepath}",
Review comment:
```suggestion
command=f"rm -r {self.test_remote_filepath}",
```
But is this (and the if-else logic above) needed? I don’t see anywhere that a
`filepath` points to a directory, or that a `dir` points to a file. The usages
seem pretty consistent.
##########
File path: tests/providers/sftp/operators/test_sftp.py
##########
@@ -441,7 +665,9 @@ def delete_remote_resource(self):
if os.path.exists(self.test_remote_filepath_int_dir):
os.remove(self.test_remote_filepath_int_dir)
if os.path.exists(self.test_remote_dir):
- os.rmdir(self.test_remote_dir)
+ import shutil
Review comment:
```suggestion
```
Is this not already imported globally?
##########
File path: airflow/providers/sftp/operators/sftp.py
##########
@@ -134,29 +175,81 @@ def execute(self, context: Any) -> str:
with self.ssh_hook.get_conn() as ssh_client:
sftp_client = ssh_client.open_sftp()
- if self.operation.lower() == SFTPOperation.GET:
- local_folder = os.path.dirname(self.local_filepath)
- if self.create_intermediate_dirs:
- Path(local_folder).mkdir(parents=True, exist_ok=True)
- file_msg = f"from {self.remote_filepath} to
{self.local_filepath}"
- self.log.info("Starting to transfer %s", file_msg)
- sftp_client.get(self.remote_filepath, self.local_filepath)
- else:
- remote_folder = os.path.dirname(self.remote_filepath)
- if self.create_intermediate_dirs:
- _make_intermediate_dirs(
- sftp_client=sftp_client,
- remote_directory=remote_folder,
+ if self.local_filepath and self.remote_filepath:
+ if isinstance(self.local_filepath, list) and
isinstance(self.remote_filepath, str):
+ for file_path in self.local_filepath:
+ local_folder = os.path.dirname(file_path)
+ local_file = os.path.basename(file_path)
+ file_msg = file_path
+ self._transfer(sftp_client, local_folder,
local_file, self.remote_filepath)
+ elif isinstance(self.remote_filepath, list) and
isinstance(self.local_filepath, str):
+ for file_path in self.remote_filepath:
+ remote_folder = os.path.dirname(file_path)
+ remote_file = os.path.basename(file_path)
+ file_msg = file_path
+ self._transfer(sftp_client, self.local_filepath,
remote_file, remote_folder)
+ elif isinstance(self.remote_filepath, str) and
isinstance(self.local_filepath, str):
+ local_folder = os.path.dirname(self.local_filepath)
+ file_msg = self.local_filepath
+ self._transfer(
+ sftp_client,
+ local_folder,
+ self.local_filepath,
+ self.remote_filepath,
+ only_file=True,
)
- file_msg = f"from {self.local_filepath} to
{self.remote_filepath}"
- self.log.info("Starting to transfer file %s", file_msg)
- sftp_client.put(self.local_filepath, self.remote_filepath,
confirm=self.confirm)
+ elif self.local_folder and self.remote_folder:
+ if self.operation.lower() == SFTPOperation.PUT:
+ files_list =
self._search_files(os.listdir(self.local_folder))
+ for file in files_list:
+ local_file = os.path.basename(file)
+ file_msg = file
+ self._transfer(sftp_client, self.local_folder,
local_file, self.remote_folder)
+ elif self.operation.lower() == SFTPOperation.GET:
+ files_list =
self._search_files(sftp_client.listdir(self.remote_folder))
+ for file in files_list:
+ remote_file = os.path.basename(file)
+ file_msg = file
+ self._transfer(sftp_client, self.local_folder,
remote_file, self.remote_folder)
+ else:
+ raise AirflowException(f"Argument mismatch, please read
docs \n {SFTPOperator.__doc__}")
Review comment:
The argument check should be done in `__init__` (i.e. fail at DAG parsing
time instead of at task runtime) and should raise `TypeError` (the customary
Python exception for mismatched arguments). It should also state clearly what
the mismatch is (e.g. `local_folder cannot be used with remote_file`) instead
of including a big block of text.
##########
File path: tests/providers/sftp/operators/test_sftp.py
##########
@@ -417,21 +614,48 @@ def test_arg_checking(self):
pass
assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
+ # when you work with specific files, then you should use *_files
arguments
+ task_4 = SFTPOperator(
+ task_id="task_4",
+ ssh_conn_id=TEST_CONN_ID,
+ local_folder="/tmp/dir_for_remote_transfer/from_remote/csv/",
+ remote_filepath=[
+ "/tmp/transfer_file/remote/put_files_file1.txt",
+ "/tmp/transfer_file/remote/put_files_file2.txt",
+ ],
+ operation=SFTPOperation.GET,
+ create_intermediate_dirs=True,
+ )
+ try:
+ task_4.execute(None)
+ except Exception:
+ pass
Review comment:
Use `pytest.raises` to test exceptions instead of wrapping the call in `try`/`except Exception: pass`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]