This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7fed7f31c3 Fix S3KeySensor. See #24321 (#24378)
7fed7f31c3 is described below

commit 7fed7f31c3a895c0df08228541f955efb16fbf79
Author: Vincent <[email protected]>
AuthorDate: Sat Jun 11 15:31:17 2022 -0400

    Fix S3KeySensor. See #24321 (#24378)
---
 airflow/providers/amazon/aws/sensors/s3.py        | 12 ++++++------
 tests/providers/amazon/aws/sensors/test_s3_key.py | 18 ++++++++++++------
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py
index 182b05864c..21b06432f2 100644
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-
+import fnmatch
 import os
 import re
 import sys
@@ -112,12 +112,13 @@ class S3KeySensor(BaseSensorOperator):
         """
         if self.wildcard_match:
             prefix = re.split(r'[\[\*\?]', key, 1)[0]
-            files = self.get_hook().get_file_metadata(prefix, bucket_name)
-            if len(files) == 0:
+            keys = self.get_hook().get_file_metadata(prefix, bucket_name)
+            key_matches = [k for k in keys if fnmatch.fnmatch(k['Key'], key)]
+            if len(key_matches) == 0:
                 return False
 
             # Reduce the set of metadata to size only
-            files = list(map(lambda f: {'Size': f['Size']}, files))
+            files = list(map(lambda f: {'Size': f['Size']}, key_matches))
         else:
             obj = self.get_hook().head_object(key, bucket_name)
             if obj is None:
@@ -341,8 +342,7 @@ class S3PrefixSensor(S3KeySensor):
             stacklevel=2,
         )
 
-        self.prefix = prefix
-        prefixes = [self.prefix] if isinstance(self.prefix, str) else self.prefix
+        prefixes = [prefix] if isinstance(prefix, str) else prefix
         keys = [pref if pref.endswith(delimiter) else pref + delimiter for pref in prefixes]
 
         super().__init__(bucket_key=keys, **kwargs)
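
For readers following along: the change above makes the wildcard branch of S3KeySensor.poke filter the objects returned for the literal prefix with fnmatch against the full glob, instead of succeeding as soon as the prefix listing is non-empty. The sketch below is not part of the commit; it approximates that logic with the S3 listing supplied as plain dicts rather than going through S3Hook.get_file_metadata.

import fnmatch
import re


def wildcard_poke(key: str, listed_objects: list[dict]) -> bool:
    """Approximate the patched wildcard branch of S3KeySensor.poke.

    listed_objects stands in for what S3Hook.get_file_metadata(prefix, bucket)
    would return: dicts carrying at least 'Key' and 'Size'.
    """
    # The sensor lists objects under the longest literal prefix of the glob
    # ('file*.csv' -> 'file'); here we filter the supplied listing instead.
    prefix = re.split(r'[\[\*\?]', key, maxsplit=1)[0]
    keys = [obj for obj in listed_objects if obj['Key'].startswith(prefix)]

    # The fix: an object under the prefix only counts if its key actually
    # matches the glob, not merely because the prefix listing is non-empty.
    key_matches = [obj for obj in keys if fnmatch.fnmatch(obj['Key'], key)]
    return len(key_matches) > 0


# 'file1.txt' shares the 'file' prefix but does not match the glob, so the
# sensor keeps poking; 'file2.csv' matches and lets the task succeed.
assert wildcard_poke('file*.csv', [{'Key': 'file1.txt', 'Size': 0}]) is False
assert wildcard_poke('file*.csv', [{'Key': 'file2.csv', 'Size': 0}]) is True
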
diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py
index 5ec8e242f8..6dca0df4a1 100644
--- a/tests/providers/amazon/aws/sensors/test_s3_key.py
+++ b/tests/providers/amazon/aws/sensors/test_s3_key.py
@@ -160,25 +160,31 @@ class TestS3KeySensor(unittest.TestCase):
         assert op.poke(None) is False
         mock_get_file_metadata.assert_called_once_with("file", "test_bucket")
 
-        mock_get_file_metadata.return_value = [{'Size': 0}]
+        mock_get_file_metadata.return_value = [{'Key': 'dummyFile', 'Size': 0}]
+        assert op.poke(None) is False
+
+        mock_get_file_metadata.return_value = [{'Key': 'file1', 'Size': 0}]
         assert op.poke(None) is True
 
     @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.get_file_metadata')
     def test_poke_wildcard_multiple_files(self, mock_get_file_metadata):
         op = S3KeySensor(
             task_id='s3_key_sensor',
-            bucket_key=['s3://test_bucket/file1*', 's3://test_bucket/file2*'],
+            bucket_key=['s3://test_bucket/file*', 's3://test_bucket/*.zip'],
             wildcard_match=True,
         )
 
-        mock_get_file_metadata.side_effect = [[{'Size': 0}], []]
+        mock_get_file_metadata.side_effect = [[{'Key': 'file1', 'Size': 0}], []]
+        assert op.poke(None) is False
+
+        mock_get_file_metadata.side_effect = [[{'Key': 'file1', 'Size': 0}], [{'Key': 'file2', 'Size': 0}]]
         assert op.poke(None) is False
 
-        mock_get_file_metadata.side_effect = [[{'Size': 0}], [{'Size': 0}]]
+        mock_get_file_metadata.side_effect = [[{'Key': 'file1', 'Size': 0}], [{'Key': 'test.zip', 'Size': 0}]]
         assert op.poke(None) is True
 
-        mock_get_file_metadata.assert_any_call("file1", "test_bucket")
-        mock_get_file_metadata.assert_any_call("file2", "test_bucket")
+        mock_get_file_metadata.assert_any_call("file", "test_bucket")
+        mock_get_file_metadata.assert_any_call("", "test_bucket")
 
     @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object')
     def test_poke_with_check_function(self, mock_head_object):

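As a usage illustration (not taken from the commit; the DAG id, bucket, and key pattern below are hypothetical), a sensor configured the way the updated tests construct it would, with this fix, only succeed once an object actually matching the glob exists rather than on any object sharing its prefix:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

with DAG(
    dag_id="example_s3_key_wildcard",  # hypothetical DAG
    start_date=datetime(2022, 6, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    wait_for_report = S3KeySensor(
        task_id="wait_for_report",
        bucket_name="my-bucket",             # hypothetical bucket
        bucket_key="incoming/report_*.csv",  # glob, as in the wildcard_match tests
        wildcard_match=True,
        poke_interval=60,
    )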