This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c255ac411b Support of wildcard in S3ListOperator and S3ToGCSOperator (#31640)
c255ac411b is described below
commit c255ac411b93d222bc9a0dbd4139a15687d2c981
Author: max <[email protected]>
AuthorDate: Mon Jun 5 21:22:56 2023 +0200
Support of wildcard in S3ListOperator and S3ToGCSOperator (#31640)
---
airflow/providers/amazon/aws/hooks/s3.py | 13 ++++++++++---
airflow/providers/amazon/aws/operators/s3.py | 10 +++++++++-
tests/providers/amazon/aws/hooks/test_s3.py | 15 ++++++++++++---
tests/providers/amazon/aws/operators/test_s3_list.py | 5 ++++-
4 files changed, 35 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 1cfa6b7b01..c5dc46349f 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -384,6 +384,7 @@ class S3Hook(AwsBaseHook):
from_datetime: datetime | None = None,
to_datetime: datetime | None = None,
object_filter: Callable[..., list] | None = None,
+ apply_wildcard: bool = False,
) -> list:
"""
Lists keys in a bucket under prefix and not containing delimiter
@@ -402,6 +403,7 @@ class S3Hook(AwsBaseHook):
:param to_datetime: should return only keys with LastModified attr
less than this to_datetime
:param object_filter: Function that receives the list of the S3
objects, from_datetime and
to_datetime and returns the List of matched key.
+ :param apply_wildcard: whether to treat '*' as a wildcard or a plain symbol in the prefix.
**Example**: Returns the list of S3 object with LastModified attr
greater than from_datetime
and less than to_datetime:
@@ -425,7 +427,9 @@ class S3Hook(AwsBaseHook):
:return: a list of matched keys
"""
- prefix = prefix or ""
+ _original_prefix = prefix or ""
+ _apply_wildcard = bool(apply_wildcard and "*" in _original_prefix)
+ _prefix = _original_prefix.split("*", 1)[0] if _apply_wildcard else _original_prefix
delimiter = delimiter or ""
start_after_key = start_after_key or ""
self.object_filter_usr = object_filter
@@ -437,7 +441,7 @@ class S3Hook(AwsBaseHook):
paginator = self.get_conn().get_paginator("list_objects_v2")
response = paginator.paginate(
Bucket=bucket_name,
- Prefix=prefix,
+ Prefix=_prefix,
Delimiter=delimiter,
PaginationConfig=config,
StartAfter=start_after_key,
@@ -446,7 +450,10 @@ class S3Hook(AwsBaseHook):
keys: list[str] = []
for page in response:
if "Contents" in page:
- keys.extend(iter(page["Contents"]))
+ new_keys = page["Contents"]
+ if _apply_wildcard:
+ new_keys = (k for k in new_keys if fnmatch.fnmatch(k["Key"], _original_prefix))
+ keys.extend(new_keys)
if self.object_filter_usr is not None:
return self.object_filter_usr(keys, from_datetime, to_datetime)
diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py
index d9ab1ab75c..f7854791d2 100644
--- a/airflow/providers/amazon/aws/operators/s3.py
+++ b/airflow/providers/amazon/aws/operators/s3.py
@@ -628,6 +628,7 @@ class S3ListOperator(BaseOperator):
:param delimiter: the delimiter marks key hierarchy. (templated)
:param aws_conn_id: The connection ID to use when connecting to S3 storage.
:param verify: Whether or not to verify SSL certificates for S3 connection.
+ :param apply_wildcard: whether to treat '*' as a wildcard or a plain symbol in the prefix.
By default SSL certificates are verified.
You can provide the following values:
@@ -664,6 +665,7 @@ class S3ListOperator(BaseOperator):
delimiter: str = "",
aws_conn_id: str = "aws_default",
verify: str | bool | None = None,
+ apply_wildcard: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -672,6 +674,7 @@ class S3ListOperator(BaseOperator):
self.delimiter = delimiter
self.aws_conn_id = aws_conn_id
self.verify = verify
+ self.apply_wildcard = apply_wildcard
def execute(self, context: Context):
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
@@ -683,7 +686,12 @@ class S3ListOperator(BaseOperator):
self.delimiter,
)
- return hook.list_keys(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
+ return hook.list_keys(
+ bucket_name=self.bucket,
+ prefix=self.prefix,
+ delimiter=self.delimiter,
+ apply_wildcard=self.apply_wildcard,
+ )
class S3ListPrefixesOperator(BaseOperator):
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index aa47d7e038..cd23823dcb 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -221,6 +221,9 @@ class TestAwsS3Hook:
hook = S3Hook()
bucket = hook.get_bucket(s3_bucket)
bucket.put_object(Key="a", Body=b"a")
+ bucket.put_object(Key="ba", Body=b"ab")
+ bucket.put_object(Key="bxa", Body=b"axa")
+ bucket.put_object(Key="bxb", Body=b"axb")
bucket.put_object(Key="dir/b", Body=b"b")
from_datetime = datetime(1992, 3, 8, 18, 52, 51)
@@ -230,14 +233,20 @@ class TestAwsS3Hook:
return []
assert [] == hook.list_keys(s3_bucket, prefix="non-existent/")
- assert ["a", "dir/b"] == hook.list_keys(s3_bucket)
- assert ["a"] == hook.list_keys(s3_bucket, delimiter="/")
+ assert ["a", "ba", "bxa", "bxb", "dir/b"] == hook.list_keys(s3_bucket)
+ assert ["a", "ba", "bxa", "bxb"] == hook.list_keys(s3_bucket, delimiter="/")
assert ["dir/b"] == hook.list_keys(s3_bucket, prefix="dir/")
- assert ["dir/b"] == hook.list_keys(s3_bucket, start_after_key="a")
+ assert ["ba", "bxa", "bxb", "dir/b"] == hook.list_keys(s3_bucket, start_after_key="a")
assert [] == hook.list_keys(s3_bucket, from_datetime=from_datetime, to_datetime=to_datetime)
assert [] == hook.list_keys(
s3_bucket, from_datetime=from_datetime, to_datetime=to_datetime,
object_filter=dummy_object_filter
)
+ assert [] == hook.list_keys(s3_bucket, prefix="*a")
+ assert ["a", "ba", "bxa"] == hook.list_keys(s3_bucket, prefix="*a", apply_wildcard=True)
+ assert [] == hook.list_keys(s3_bucket, prefix="b*a")
+ assert ["ba", "bxa"] == hook.list_keys(s3_bucket, prefix="b*a", apply_wildcard=True)
+ assert [] == hook.list_keys(s3_bucket, prefix="b*")
+ assert ["ba", "bxa", "bxb"] == hook.list_keys(s3_bucket, prefix="b*", apply_wildcard=True)
def test_list_keys_paged(self, s3_bucket):
hook = S3Hook()
diff --git a/tests/providers/amazon/aws/operators/test_s3_list.py b/tests/providers/amazon/aws/operators/test_s3_list.py
index 2823bb2b75..0773044573 100644
--- a/tests/providers/amazon/aws/operators/test_s3_list.py
+++ b/tests/providers/amazon/aws/operators/test_s3_list.py
@@ -39,6 +39,9 @@ class TestS3ListOperator:
files = operator.execute(None)
mock_hook.return_value.list_keys.assert_called_once_with(
- bucket_name=BUCKET, prefix=PREFIX, delimiter=DELIMITER
+ bucket_name=BUCKET,
+ prefix=PREFIX,
+ delimiter=DELIMITER,
+ apply_wildcard=False,
)
assert sorted(files) == sorted(MOCK_FILES)