This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 432697d90c allow multiple prefixes in gcs delete/list hooks and
operators (#30815)
432697d90c is described below
commit 432697d90cdcea35607bcaa970c694c88053222c
Author: Shahar Epstein <[email protected]>
AuthorDate: Sun Apr 23 09:31:03 2023 +0300
allow multiple prefixes in gcs delete/list hooks and operators (#30815)
---
airflow/providers/google/cloud/hooks/gcs.py | 54 ++++++++++++++++++++--
airflow/providers/google/cloud/operators/gcs.py | 15 +++---
tests/providers/google/cloud/hooks/test_gcs.py | 30 ++++++++++++
tests/providers/google/cloud/operators/test_gcs.py | 3 +-
4 files changed, 88 insertions(+), 14 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/gcs.py
b/airflow/providers/google/cloud/hooks/gcs.py
index 742b05e32e..c0600dde8e 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -696,15 +696,63 @@ class GCSHook(GoogleBaseHook):
except NotFound:
self.log.info("Bucket %s not exists", bucket_name)
- def list(self, bucket_name, versions=None, max_results=None, prefix=None,
delimiter=None) -> List:
+ def list(
+ self,
+ bucket_name: str,
+ versions: bool | None = None,
+ max_results: int | None = None,
+ prefix: str | List[str] | None = None,
+ delimiter: str | None = None,
+ ):
+ """
+ List all objects in the bucket whose names begin with the given prefix or with any of the given prefixes.
+
+ :param bucket_name: bucket name
+ :param versions: if true, list all versions of the objects
+ :param max_results: max count of items to return in a single page of
responses
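+ :param prefix: string or list of strings; only objects whose names begin with it (or with any of them) are listed. Each prefix is listed separately and the results are concatenated, so an object matching several overlapping prefixes is returned once per matching prefix.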
+ :param delimiter: filters objects based on the delimiter (e.g. '.csv')
+ :return: a list of object names matching the filtering criteria
+ """
+ objects = []
+ if isinstance(prefix, list):
+ for prefix_item in prefix:
+ objects.extend(
+ self._list(
+ bucket_name=bucket_name,
+ versions=versions,
+ max_results=max_results,
+ prefix=prefix_item,
+ delimiter=delimiter,
+ )
+ )
+ else:
+ objects.extend(
+ self._list(
+ bucket_name=bucket_name,
+ versions=versions,
+ max_results=max_results,
+ prefix=prefix,
+ delimiter=delimiter,
+ )
+ )
+ return objects
+
+ def _list(
+ self,
+ bucket_name: str,
+ versions: bool | None = None,
+ max_results: int | None = None,
+ prefix: str | None = None,
+ delimiter: str | None = None,
+ ) -> List:
"""
List all objects from the bucket with the give string prefix in name
:param bucket_name: bucket name
:param versions: if true, list all versions of the objects
:param max_results: max count of items to return in a single page of
responses
- :param prefix: prefix string which filters objects whose name begin
with
- this prefix
+ :param prefix: string which filters objects whose names begin with it
:param delimiter: filters objects based on the delimiter (for e.g
'.csv')
:return: a stream of object names matching the filtering criteria
"""
diff --git a/airflow/providers/google/cloud/operators/gcs.py
b/airflow/providers/google/cloud/operators/gcs.py
index e2936c0933..e2ac68c90d 100644
--- a/airflow/providers/google/cloud/operators/gcs.py
+++ b/airflow/providers/google/cloud/operators/gcs.py
@@ -163,8 +163,8 @@ class GCSListObjectsOperator(GoogleCloudBaseOperator):
XCom in the downstream task.
:param bucket: The Google Cloud Storage bucket to find the objects.
(templated)
- :param prefix: Prefix string which filters objects whose name begin with
- this prefix. (templated)
+ :param prefix: String or list of strings used to filter objects whose names begin with
+ it/any of them. (templated)
:param delimiter: The delimiter by which you want to filter the objects.
(templated)
For example, to lists the CSV files from in a directory in GCS you
would use
delimiter='.csv'.
@@ -206,7 +206,7 @@ class GCSListObjectsOperator(GoogleCloudBaseOperator):
self,
*,
bucket: str,
- prefix: str | None = None,
+ prefix: str | list[str] | None = None,
delimiter: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
@@ -220,14 +220,13 @@ class GCSListObjectsOperator(GoogleCloudBaseOperator):
self.impersonation_chain = impersonation_chain
def execute(self, context: Context) -> list:
-
hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.log.info(
- "Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s",
+ "Getting list of the files. Bucket: %s; Delimiter: %s; Prefix(es):
%s",
self.bucket,
self.delimiter,
self.prefix,
@@ -239,7 +238,6 @@ class GCSListObjectsOperator(GoogleCloudBaseOperator):
uri=self.bucket,
project_id=hook.project_id,
)
-
return hook.list(bucket_name=self.bucket, prefix=self.prefix,
delimiter=self.delimiter)
@@ -252,8 +250,8 @@ class GCSDeleteObjectsOperator(GoogleCloudBaseOperator):
:param bucket_name: The GCS bucket to delete from
:param objects: List of objects to delete. These should be the names
of objects in the bucket, not including gs://bucket/
- :param prefix: Prefix of objects to delete. All objects matching this
- prefix in the bucket will be deleted.
+ :param prefix: String or list of strings. All objects in the bucket whose names begin
+ with it/any of them will be deleted. (templated)
:param gcp_conn_id: (Optional) The connection ID used to connect to Google
Cloud.
:param impersonation_chain: Optional service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token
@@ -307,7 +305,6 @@ class GCSDeleteObjectsOperator(GoogleCloudBaseOperator):
objects = self.objects
else:
objects = hook.list(bucket_name=self.bucket_name,
prefix=self.prefix)
-
self.log.info("Deleting %s objects from %s", len(objects),
self.bucket_name)
for object_name in objects:
hook.delete(bucket_name=self.bucket_name, object_name=object_name)
diff --git a/tests/providers/google/cloud/hooks/test_gcs.py
b/tests/providers/google/cloud/hooks/test_gcs.py
index c70a67a66a..22cedfbd0a 100644
--- a/tests/providers/google/cloud/hooks/test_gcs.py
+++ b/tests/providers/google/cloud/hooks/test_gcs.py
@@ -758,6 +758,36 @@ class TestGCSHook:
]
)
+ @pytest.mark.parametrize(
+ "prefix, result",
+ (
+ (
+ "prefix",
+ [mock.call(delimiter=",", prefix="prefix", versions=None,
max_results=None, page_token=None)],
+ ),
+ (
+ ["prefix", "prefix_2"],
+ [
+ mock.call(
+ delimiter=",", prefix="prefix", versions=None,
max_results=None, page_token=None
+ ),
+ mock.call(
+ delimiter=",", prefix="prefix_2", versions=None,
max_results=None, page_token=None
+ ),
+ ],
+ ),
+ ),
+ )
+ @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
+ def test_list(self, mock_service, prefix, result):
+
mock_service.return_value.bucket.return_value.list_blobs.return_value.next_page_token
= None
+ self.gcs_hook.list(
+ bucket_name="test_bucket",
+ prefix=prefix,
+ delimiter=",",
+ )
+ assert
mock_service.return_value.bucket.return_value.list_blobs.call_args_list ==
result
+
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_list_by_timespans(self, mock_service):
test_bucket = "test_bucket"
diff --git a/tests/providers/google/cloud/operators/test_gcs.py
b/tests/providers/google/cloud/operators/test_gcs.py
index 4e3ee66d38..bf9a4f5d7e 100644
--- a/tests/providers/google/cloud/operators/test_gcs.py
+++ b/tests/providers/google/cloud/operators/test_gcs.py
@@ -38,6 +38,7 @@ TEST_BUCKET = "test-bucket"
TEST_PROJECT = "test-project"
DELIMITER = ".csv"
PREFIX = "TEST"
+PREFIX_2 = "TEST2"
MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv", "OTHERTEST1.csv"]
TEST_OBJECT = "dir1/test-object"
LOCAL_FILE_PATH = "/home/airflow/gcp/test-object"
@@ -160,11 +161,9 @@ class TestGoogleCloudStorageListOperator:
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
def test_execute(self, mock_hook):
mock_hook.return_value.list.return_value = MOCK_FILES
-
operator = GCSListObjectsOperator(
task_id=TASK_ID, bucket=TEST_BUCKET, prefix=PREFIX,
delimiter=DELIMITER
)
-
files = operator.execute(context=mock.MagicMock())
mock_hook.return_value.list.assert_called_once_with(
bucket_name=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER