This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e656e1de55 Adding fnmatch type regex to SFTPSensor (#24084)
e656e1de55 is described below
commit e656e1de55094e8369cab80b9b1669b1d1225f54
Author: Alex Kruchkov <[email protected]>
AuthorDate: Mon Jun 6 15:54:27 2022 +0300
Adding fnmatch type regex to SFTPSensor (#24084)
---
airflow/providers/sftp/hooks/sftp.py | 19 ++++++++++++++++
airflow/providers/sftp/sensors/sftp.py | 18 ++++++++++++---
tests/providers/sftp/hooks/test_sftp.py | 38 ++++++++++++++++++++++++++++---
tests/providers/sftp/sensors/test_sftp.py | 28 +++++++++++++++++++++++
4 files changed, 97 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py
index 58c820b838..d436d091b5 100644
--- a/airflow/providers/sftp/hooks/sftp.py
+++ b/airflow/providers/sftp/hooks/sftp.py
@@ -19,6 +19,7 @@
import datetime
import stat
import warnings
+from fnmatch import fnmatch
from typing import Any, Dict, List, Optional, Tuple
import pysftp
@@ -329,3 +330,21 @@ class SFTPHook(SSHHook):
return True, "Connection successfully tested"
except Exception as e:
return False, str(e)
+
+ def get_file_by_pattern(self, path, fnmatch_pattern) -> str:
+ """
+ Returning the first matching file based on the given fnmatch type pattern
+
+ :param path: path to be checked
+ :param fnmatch_pattern: The pattern that will be matched with `fnmatch`
+ :return: string containing the first found file, or an empty string if none matched
+ """
+ files_list = self.list_directory(path)
+
+ for file in files_list:
+ if not fnmatch(file, fnmatch_pattern):
+ pass
+ else:
+ return file
+
+ return ""
diff --git a/airflow/providers/sftp/sensors/sftp.py b/airflow/providers/sftp/sensors/sftp.py
index 904321e9b8..757a23b1d8 100644
--- a/airflow/providers/sftp/sensors/sftp.py
+++ b/airflow/providers/sftp/sensors/sftp.py
@@ -34,6 +34,7 @@ class SFTPSensor(BaseSensorOperator):
Waits for a file or directory to be present on SFTP.
:param path: Remote file or directory path
+ :param file_pattern: The pattern that will be used to match the file (fnmatch format)
:param sftp_conn_id: The connection to run the sensor against
:param newer_than: DateTime for which the file or file path should be newer than, comparison is inclusive
"""
@@ -47,22 +48,33 @@ class SFTPSensor(BaseSensorOperator):
self,
*,
path: str,
+ file_pattern: str = "",
newer_than: Optional[datetime] = None,
sftp_conn_id: str = 'sftp_default',
**kwargs,
) -> None:
super().__init__(**kwargs)
self.path = path
+ self.file_pattern = file_pattern
self.hook: Optional[SFTPHook] = None
self.sftp_conn_id = sftp_conn_id
self.newer_than: Optional[datetime] = newer_than
+ self.actual_file_to_check = self.path
def poke(self, context: 'Context') -> bool:
self.hook = SFTPHook(self.sftp_conn_id)
- self.log.info('Poking for %s', self.path)
+ self.log.info(f"Poking for {self.path}, with pattern {self.file_pattern}")
+
+ if self.file_pattern:
+ file_from_pattern = self.hook.get_file_by_pattern(self.path, self.file_pattern)
+ if file_from_pattern:
+ self.actual_file_to_check = file_from_pattern
+ else:
+ return False
+
try:
- mod_time = self.hook.get_mod_time(self.path)
- self.log.info('Found File %s last modified: %s', str(self.path), str(mod_time))
+ mod_time = self.hook.get_mod_time(self.actual_file_to_check)
+ self.log.info('Found File %s last modified: %s', str(self.actual_file_to_check), str(mod_time))
except OSError as e:
if e.errno != SFTP_NO_SUCH_FILE:
raise e
diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py
index 9c63402054..95bb971bdf 100644
--- a/tests/providers/sftp/hooks/test_sftp.py
+++ b/tests/providers/sftp/hooks/test_sftp.py
@@ -43,6 +43,8 @@ TMP_PATH = '/tmp'
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
SUB_DIR = "sub_dir"
TMP_FILE_FOR_TESTS = 'test_file.txt'
+ANOTHER_FILE_FOR_TESTS = 'test_file_1.txt'
+LOG_FILE_FOR_TESTS = 'test_log.log'
SFTP_CONNECTION_USER = "root"
@@ -60,13 +62,18 @@ class TestSFTPHook(unittest.TestCase):
session.commit()
return old_login
+ def _create_additional_test_file(self, file_name):
+ with open(os.path.join(TMP_PATH, file_name), 'a') as file:
+ file.write('Test file')
+
def setUp(self):
self.old_login = self.update_connection(SFTP_CONNECTION_USER)
self.hook = SFTPHook()
os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))
- with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
- file.write('Test file')
+ for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]:
+ with open(os.path.join(TMP_PATH, file_name), 'a') as file:
+ file.write('Test file')
with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
file.write('Test file')
@@ -353,7 +360,32 @@ class TestSFTPHook(unittest.TestCase):
# Default is 'sftp_default
assert SFTPHook().ssh_conn_id == 'sftp_default'
+ def test_get_suffix_pattern_match(self):
+ output = self.hook.get_file_by_pattern(TMP_PATH, "*.txt")
+ self.assertTrue(output, TMP_FILE_FOR_TESTS)
+
+ def test_get_prefix_pattern_match(self):
+ output = self.hook.get_file_by_pattern(TMP_PATH, "test*")
+ self.assertTrue(output, TMP_FILE_FOR_TESTS)
+
+ def test_get_pattern_not_match(self):
+ output = self.hook.get_file_by_pattern(TMP_PATH, "*.text")
+ self.assertFalse(output)
+
+ def test_get_several_pattern_match(self):
+ output = self.hook.get_file_by_pattern(TMP_PATH, "*.log")
+ self.assertEqual(LOG_FILE_FOR_TESTS, output)
+
+ def test_get_first_pattern_match(self):
+ output = self.hook.get_file_by_pattern(TMP_PATH, "test_*.txt")
+ self.assertEqual(TMP_FILE_FOR_TESTS, output)
+
+ def test_get_middle_pattern_match(self):
+ output = self.hook.get_file_by_pattern(TMP_PATH, "*_file_*.txt")
+ self.assertEqual(ANOTHER_FILE_FOR_TESTS, output)
+
def tearDown(self):
shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
+ for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]:
+ os.remove(os.path.join(TMP_PATH, file_name))
self.update_connection(self.old_login)
diff --git a/tests/providers/sftp/sensors/test_sftp.py b/tests/providers/sftp/sensors/test_sftp.py
index f7c26495bf..1bb6c71068 100644
--- a/tests/providers/sftp/sensors/test_sftp.py
+++ b/tests/providers/sftp/sensors/test_sftp.py
@@ -97,3 +97,31 @@ class TestSFTPSensor(unittest.TestCase):
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt')
assert not output
+
+ @patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
+ def test_file_with_pattern_parameter_call(self, sftp_hook_mock):
+ sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
+ sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
+ context = {'ds': '1970-01-01'}
+ output = sftp_sensor.poke(context)
+ sftp_hook_mock.return_value.get_file_by_pattern.assert_called_once_with('/path/to/file/', '*.txt')
+ assert output
+
+ @patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
+ def test_file_present_with_pattern(self, sftp_hook_mock):
+ sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
+ sftp_hook_mock.return_value.get_file_by_pattern.return_value = '/path/to/file/text_file.txt'
+ sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
+ context = {'ds': '1970-01-01'}
+ output = sftp_sensor.poke(context)
+ sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/text_file.txt')
+ assert output
+
+ @patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
+ def test_file_not_present_with_pattern(self, sftp_hook_mock):
+ sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
+ sftp_hook_mock.return_value.get_file_by_pattern.return_value = ""
+ sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
+ context = {'ds': '1970-01-01'}
+ output = sftp_sensor.poke(context)
+ assert not output