This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e656e1de55 Adding fnmatch type regex to SFTPSensor (#24084)
e656e1de55 is described below

commit e656e1de55094e8369cab80b9b1669b1d1225f54
Author: Alex Kruchkov <[email protected]>
AuthorDate: Mon Jun 6 15:54:27 2022 +0300

    Adding fnmatch type regex to SFTPSensor (#24084)
---
 airflow/providers/sftp/hooks/sftp.py      | 19 ++++++++++++++++
 airflow/providers/sftp/sensors/sftp.py    | 18 ++++++++++++---
 tests/providers/sftp/hooks/test_sftp.py   | 38 ++++++++++++++++++++++++++++---
 tests/providers/sftp/sensors/test_sftp.py | 28 +++++++++++++++++++++++
 4 files changed, 97 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py
index 58c820b838..d436d091b5 100644
--- a/airflow/providers/sftp/hooks/sftp.py
+++ b/airflow/providers/sftp/hooks/sftp.py
@@ -19,6 +19,7 @@
 import datetime
 import stat
 import warnings
+from fnmatch import fnmatch
 from typing import Any, Dict, List, Optional, Tuple
 
 import pysftp
@@ -329,3 +330,21 @@ class SFTPHook(SSHHook):
             return True, "Connection successfully tested"
         except Exception as e:
             return False, str(e)
+
+    def get_file_by_pattern(self, path, fnmatch_pattern) -> str:
+        """
+        Returning the first matching file based on the given fnmatch type pattern
+
+        :param path: path to be checked
+        :param fnmatch_pattern: The pattern that will be matched with `fnmatch`
+        :return: string containing the first found file, or an empty string if none matched
+        """
+        files_list = self.list_directory(path)
+
+        for file in files_list:
+            if not fnmatch(file, fnmatch_pattern):
+                pass
+            else:
+                return file
+
+        return ""
diff --git a/airflow/providers/sftp/sensors/sftp.py b/airflow/providers/sftp/sensors/sftp.py
index 904321e9b8..757a23b1d8 100644
--- a/airflow/providers/sftp/sensors/sftp.py
+++ b/airflow/providers/sftp/sensors/sftp.py
@@ -34,6 +34,7 @@ class SFTPSensor(BaseSensorOperator):
     Waits for a file or directory to be present on SFTP.
 
     :param path: Remote file or directory path
+    :param file_pattern: The pattern that will be used to match the file (fnmatch format)
     :param sftp_conn_id: The connection to run the sensor against
     :param newer_than: DateTime for which the file or file path should be newer than, comparison is inclusive
     """
@@ -47,22 +48,33 @@ class SFTPSensor(BaseSensorOperator):
         self,
         *,
         path: str,
+        file_pattern: str = "",
         newer_than: Optional[datetime] = None,
         sftp_conn_id: str = 'sftp_default',
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.path = path
+        self.file_pattern = file_pattern
         self.hook: Optional[SFTPHook] = None
         self.sftp_conn_id = sftp_conn_id
         self.newer_than: Optional[datetime] = newer_than
+        self.actual_file_to_check = self.path
 
     def poke(self, context: 'Context') -> bool:
         self.hook = SFTPHook(self.sftp_conn_id)
-        self.log.info('Poking for %s', self.path)
+        self.log.info(f"Poking for {self.path}, with pattern {self.file_pattern}")
+
+        if self.file_pattern:
+            file_from_pattern = self.hook.get_file_by_pattern(self.path, self.file_pattern)
+            if file_from_pattern:
+                self.actual_file_to_check = file_from_pattern
+            else:
+                return False
+
         try:
-            mod_time = self.hook.get_mod_time(self.path)
-            self.log.info('Found File %s last modified: %s', str(self.path), str(mod_time))
+            mod_time = self.hook.get_mod_time(self.actual_file_to_check)
+            self.log.info('Found File %s last modified: %s', str(self.actual_file_to_check), str(mod_time))
         except OSError as e:
             if e.errno != SFTP_NO_SUCH_FILE:
                 raise e
diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py
index 9c63402054..95bb971bdf 100644
--- a/tests/providers/sftp/hooks/test_sftp.py
+++ b/tests/providers/sftp/hooks/test_sftp.py
@@ -43,6 +43,8 @@ TMP_PATH = '/tmp'
 TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
 SUB_DIR = "sub_dir"
 TMP_FILE_FOR_TESTS = 'test_file.txt'
+ANOTHER_FILE_FOR_TESTS = 'test_file_1.txt'
+LOG_FILE_FOR_TESTS = 'test_log.log'
 
 SFTP_CONNECTION_USER = "root"
 
@@ -60,13 +62,18 @@ class TestSFTPHook(unittest.TestCase):
         session.commit()
         return old_login
 
+    def _create_additional_test_file(self, file_name):
+        with open(os.path.join(TMP_PATH, file_name), 'a') as file:
+            file.write('Test file')
+
     def setUp(self):
         self.old_login = self.update_connection(SFTP_CONNECTION_USER)
         self.hook = SFTPHook()
         os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))
 
-        with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
-            file.write('Test file')
+        for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]:
+            with open(os.path.join(TMP_PATH, file_name), 'a') as file:
+                file.write('Test file')
         with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
             file.write('Test file')
 
@@ -353,7 +360,32 @@ class TestSFTPHook(unittest.TestCase):
         # Default is 'sftp_default
         assert SFTPHook().ssh_conn_id == 'sftp_default'
 
+    def test_get_suffix_pattern_match(self):
+        output = self.hook.get_file_by_pattern(TMP_PATH, "*.txt")
+        self.assertTrue(output, TMP_FILE_FOR_TESTS)
+
+    def test_get_prefix_pattern_match(self):
+        output = self.hook.get_file_by_pattern(TMP_PATH, "test*")
+        self.assertTrue(output, TMP_FILE_FOR_TESTS)
+
+    def test_get_pattern_not_match(self):
+        output = self.hook.get_file_by_pattern(TMP_PATH, "*.text")
+        self.assertFalse(output)
+
+    def test_get_several_pattern_match(self):
+        output = self.hook.get_file_by_pattern(TMP_PATH, "*.log")
+        self.assertEqual(LOG_FILE_FOR_TESTS, output)
+
+    def test_get_first_pattern_match(self):
+        output = self.hook.get_file_by_pattern(TMP_PATH, "test_*.txt")
+        self.assertEqual(TMP_FILE_FOR_TESTS, output)
+
+    def test_get_middle_pattern_match(self):
+        output = self.hook.get_file_by_pattern(TMP_PATH, "*_file_*.txt")
+        self.assertEqual(ANOTHER_FILE_FOR_TESTS, output)
+
     def tearDown(self):
         shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
-        os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
+        for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]:
+            os.remove(os.path.join(TMP_PATH, file_name))
         self.update_connection(self.old_login)
diff --git a/tests/providers/sftp/sensors/test_sftp.py b/tests/providers/sftp/sensors/test_sftp.py
index f7c26495bf..1bb6c71068 100644
--- a/tests/providers/sftp/sensors/test_sftp.py
+++ b/tests/providers/sftp/sensors/test_sftp.py
@@ -97,3 +97,31 @@ class TestSFTPSensor(unittest.TestCase):
         output = sftp_sensor.poke(context)
         sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt')
         assert not output
+
+    @patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
+    def test_file_with_pattern_parameter_call(self, sftp_hook_mock):
+        sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
+        sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
+        context = {'ds': '1970-01-01'}
+        output = sftp_sensor.poke(context)
+        sftp_hook_mock.return_value.get_file_by_pattern.assert_called_once_with('/path/to/file/', '*.txt')
+        assert output
+
+    @patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
+    def test_file_present_with_pattern(self, sftp_hook_mock):
+        sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
+        sftp_hook_mock.return_value.get_file_by_pattern.return_value = '/path/to/file/text_file.txt'
+        sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
+        context = {'ds': '1970-01-01'}
+        output = sftp_sensor.poke(context)
+        sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/text_file.txt')
+        assert output
+
+    @patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
+    def test_file_not_present_with_pattern(self, sftp_hook_mock):
+        sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
+        sftp_hook_mock.return_value.get_file_by_pattern.return_value = ""
+        sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
+        context = {'ds': '1970-01-01'}
+        output = sftp_sensor.poke(context)
+        assert not output

Reply via email to