This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c255ac411b Support of wildcard in S3ListOperator and S3ToGCSOperator (#31640)
c255ac411b is described below

commit c255ac411b93d222bc9a0dbd4139a15687d2c981
Author: max <[email protected]>
AuthorDate: Mon Jun 5 21:22:56 2023 +0200

    Support of wildcard in S3ListOperator and S3ToGCSOperator (#31640)
---
 airflow/providers/amazon/aws/hooks/s3.py             | 13 ++++++++++---
 airflow/providers/amazon/aws/operators/s3.py         | 10 +++++++++-
 tests/providers/amazon/aws/hooks/test_s3.py          | 15 ++++++++++++---
 tests/providers/amazon/aws/operators/test_s3_list.py |  5 ++++-
 4 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 1cfa6b7b01..c5dc46349f 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -384,6 +384,7 @@ class S3Hook(AwsBaseHook):
         from_datetime: datetime | None = None,
         to_datetime: datetime | None = None,
         object_filter: Callable[..., list] | None = None,
+        apply_wildcard: bool = False,
     ) -> list:
         """
         Lists keys in a bucket under prefix and not containing delimiter
@@ -402,6 +403,7 @@ class S3Hook(AwsBaseHook):
         :param to_datetime: should return only keys with LastModified attr less than this to_datetime
         :param object_filter: Function that receives the list of the S3 objects, from_datetime and
             to_datetime and returns the List of matched key.
+        :param apply_wildcard: whether to treat '*' as a wildcard or a plain symbol in the prefix.
 
         **Example**: Returns the list of S3 object with LastModified attr greater than from_datetime
              and less than to_datetime:
@@ -425,7 +427,9 @@ class S3Hook(AwsBaseHook):
 
         :return: a list of matched keys
         """
-        prefix = prefix or ""
+        _original_prefix = prefix or ""
+        _apply_wildcard = bool(apply_wildcard and "*" in _original_prefix)
+        _prefix = _original_prefix.split("*", 1)[0] if _apply_wildcard else _original_prefix
         delimiter = delimiter or ""
         start_after_key = start_after_key or ""
         self.object_filter_usr = object_filter
@@ -437,7 +441,7 @@ class S3Hook(AwsBaseHook):
         paginator = self.get_conn().get_paginator("list_objects_v2")
         response = paginator.paginate(
             Bucket=bucket_name,
-            Prefix=prefix,
+            Prefix=_prefix,
             Delimiter=delimiter,
             PaginationConfig=config,
             StartAfter=start_after_key,
@@ -446,7 +450,10 @@ class S3Hook(AwsBaseHook):
         keys: list[str] = []
         for page in response:
             if "Contents" in page:
-                keys.extend(iter(page["Contents"]))
+                new_keys = page["Contents"]
+                if _apply_wildcard:
+                    new_keys = (k for k in new_keys if fnmatch.fnmatch(k["Key"], _original_prefix))
+                keys.extend(new_keys)
         if self.object_filter_usr is not None:
             return self.object_filter_usr(keys, from_datetime, to_datetime)
 
diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py
index d9ab1ab75c..f7854791d2 100644
--- a/airflow/providers/amazon/aws/operators/s3.py
+++ b/airflow/providers/amazon/aws/operators/s3.py
@@ -628,6 +628,7 @@ class S3ListOperator(BaseOperator):
     :param delimiter: the delimiter marks key hierarchy. (templated)
     :param aws_conn_id: The connection ID to use when connecting to S3 storage.
     :param verify: Whether or not to verify SSL certificates for S3 connection.
+    :param apply_wildcard: whether to treat '*' as a wildcard or a plain symbol in the prefix.
         By default SSL certificates are verified.
         You can provide the following values:
 
@@ -664,6 +665,7 @@ class S3ListOperator(BaseOperator):
         delimiter: str = "",
         aws_conn_id: str = "aws_default",
         verify: str | bool | None = None,
+        apply_wildcard: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -672,6 +674,7 @@ class S3ListOperator(BaseOperator):
         self.delimiter = delimiter
         self.aws_conn_id = aws_conn_id
         self.verify = verify
+        self.apply_wildcard = apply_wildcard
 
     def execute(self, context: Context):
         hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
@@ -683,7 +686,12 @@ class S3ListOperator(BaseOperator):
             self.delimiter,
         )
 
-        return hook.list_keys(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
+        return hook.list_keys(
+            bucket_name=self.bucket,
+            prefix=self.prefix,
+            delimiter=self.delimiter,
+            apply_wildcard=self.apply_wildcard,
+        )
 
 
 class S3ListPrefixesOperator(BaseOperator):
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index aa47d7e038..cd23823dcb 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -221,6 +221,9 @@ class TestAwsS3Hook:
         hook = S3Hook()
         bucket = hook.get_bucket(s3_bucket)
         bucket.put_object(Key="a", Body=b"a")
+        bucket.put_object(Key="ba", Body=b"ab")
+        bucket.put_object(Key="bxa", Body=b"axa")
+        bucket.put_object(Key="bxb", Body=b"axb")
         bucket.put_object(Key="dir/b", Body=b"b")
 
         from_datetime = datetime(1992, 3, 8, 18, 52, 51)
@@ -230,14 +233,20 @@ class TestAwsS3Hook:
             return []
 
         assert [] == hook.list_keys(s3_bucket, prefix="non-existent/")
-        assert ["a", "dir/b"] == hook.list_keys(s3_bucket)
-        assert ["a"] == hook.list_keys(s3_bucket, delimiter="/")
+        assert ["a", "ba", "bxa", "bxb", "dir/b"] == hook.list_keys(s3_bucket)
+        assert ["a", "ba", "bxa", "bxb"] == hook.list_keys(s3_bucket, delimiter="/")
         assert ["dir/b"] == hook.list_keys(s3_bucket, prefix="dir/")
-        assert ["dir/b"] == hook.list_keys(s3_bucket, start_after_key="a")
+        assert ["ba", "bxa", "bxb", "dir/b"] == hook.list_keys(s3_bucket, start_after_key="a")
         assert [] == hook.list_keys(s3_bucket, from_datetime=from_datetime, to_datetime=to_datetime)
         assert [] == hook.list_keys(
             s3_bucket, from_datetime=from_datetime, to_datetime=to_datetime, object_filter=dummy_object_filter
         )
+        assert [] == hook.list_keys(s3_bucket, prefix="*a")
+        assert ["a", "ba", "bxa"] == hook.list_keys(s3_bucket, prefix="*a", apply_wildcard=True)
+        assert [] == hook.list_keys(s3_bucket, prefix="b*a")
+        assert ["ba", "bxa"] == hook.list_keys(s3_bucket, prefix="b*a", apply_wildcard=True)
+        assert [] == hook.list_keys(s3_bucket, prefix="b*")
+        assert ["ba", "bxa", "bxb"] == hook.list_keys(s3_bucket, prefix="b*", apply_wildcard=True)
 
     def test_list_keys_paged(self, s3_bucket):
         hook = S3Hook()
diff --git a/tests/providers/amazon/aws/operators/test_s3_list.py b/tests/providers/amazon/aws/operators/test_s3_list.py
index 2823bb2b75..0773044573 100644
--- a/tests/providers/amazon/aws/operators/test_s3_list.py
+++ b/tests/providers/amazon/aws/operators/test_s3_list.py
@@ -39,6 +39,9 @@ class TestS3ListOperator:
         files = operator.execute(None)
 
         mock_hook.return_value.list_keys.assert_called_once_with(
-            bucket_name=BUCKET, prefix=PREFIX, delimiter=DELIMITER
+            bucket_name=BUCKET,
+            prefix=PREFIX,
+            delimiter=DELIMITER,
+            apply_wildcard=False,
         )
         assert sorted(files) == sorted(MOCK_FILES)

Reply via email to