This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 432697d90c allow multiple prefixes in gcs delete/list hooks and 
operators (#30815)
432697d90c is described below

commit 432697d90cdcea35607bcaa970c694c88053222c
Author: Shahar Epstein <[email protected]>
AuthorDate: Sun Apr 23 09:31:03 2023 +0300

    allow multiple prefixes in gcs delete/list hooks and operators (#30815)
---
 airflow/providers/google/cloud/hooks/gcs.py        | 54 ++++++++++++++++++++--
 airflow/providers/google/cloud/operators/gcs.py    | 15 +++---
 tests/providers/google/cloud/hooks/test_gcs.py     | 30 ++++++++++++
 tests/providers/google/cloud/operators/test_gcs.py |  3 +-
 4 files changed, 88 insertions(+), 14 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/gcs.py 
b/airflow/providers/google/cloud/hooks/gcs.py
index 742b05e32e..c0600dde8e 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -696,15 +696,63 @@ class GCSHook(GoogleBaseHook):
         except NotFound:
             self.log.info("Bucket %s not exists", bucket_name)
 
-    def list(self, bucket_name, versions=None, max_results=None, prefix=None, 
delimiter=None) -> List:
+    def list(
+        self,
+        bucket_name: str,
+        versions: bool | None = None,
+        max_results: int | None = None,
+        prefix: str | List[str] | None = None,
+        delimiter: str | None = None,
+    ):
+        """
+        List all objects from the bucket with the given single prefix or 
multiple prefixes
+
+        :param bucket_name: bucket name
+        :param versions: if true, list all versions of the objects
+        :param max_results: max count of items to return in a single page of 
responses
+        :param prefix: string or list of strings which filter objects whose 
names begin with it/them
+        :param delimiter: filters objects based on the delimiter (for e.g 
'.csv')
+        :return: a stream of object names matching the filtering criteria
+        """
+        objects = []
+        if isinstance(prefix, list):
+            for prefix_item in prefix:
+                objects.extend(
+                    self._list(
+                        bucket_name=bucket_name,
+                        versions=versions,
+                        max_results=max_results,
+                        prefix=prefix_item,
+                        delimiter=delimiter,
+                    )
+                )
+        else:
+            objects.extend(
+                self._list(
+                    bucket_name=bucket_name,
+                    versions=versions,
+                    max_results=max_results,
+                    prefix=prefix,
+                    delimiter=delimiter,
+                )
+            )
+        return objects
+
+    def _list(
+        self,
+        bucket_name: str,
+        versions: bool | None = None,
+        max_results: int | None = None,
+        prefix: str | None = None,
+        delimiter: str | None = None,
+    ) -> List:
         """
         List all objects from the bucket with the give string prefix in name
 
         :param bucket_name: bucket name
         :param versions: if true, list all versions of the objects
         :param max_results: max count of items to return in a single page of 
responses
-        :param prefix: prefix string which filters objects whose name begin 
with
-            this prefix
+        :param prefix: string which filters objects whose name begin with it
         :param delimiter: filters objects based on the delimiter (for e.g 
'.csv')
         :return: a stream of object names matching the filtering criteria
         """
diff --git a/airflow/providers/google/cloud/operators/gcs.py 
b/airflow/providers/google/cloud/operators/gcs.py
index e2936c0933..e2ac68c90d 100644
--- a/airflow/providers/google/cloud/operators/gcs.py
+++ b/airflow/providers/google/cloud/operators/gcs.py
@@ -163,8 +163,8 @@ class GCSListObjectsOperator(GoogleCloudBaseOperator):
     XCom in the downstream task.
 
     :param bucket: The Google Cloud Storage bucket to find the objects. 
(templated)
-    :param prefix: Prefix string which filters objects whose name begin with
-           this prefix. (templated)
+    :param prefix: String or list of strings, which filter objects whose names 
begin with
+           it/them. (templated)
     :param delimiter: The delimiter by which you want to filter the objects. 
(templated)
         For example, to lists the CSV files from in a directory in GCS you 
would use
         delimiter='.csv'.
@@ -206,7 +206,7 @@ class GCSListObjectsOperator(GoogleCloudBaseOperator):
         self,
         *,
         bucket: str,
-        prefix: str | None = None,
+        prefix: str | list[str] | None = None,
         delimiter: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
@@ -220,14 +220,13 @@ class GCSListObjectsOperator(GoogleCloudBaseOperator):
         self.impersonation_chain = impersonation_chain
 
     def execute(self, context: Context) -> list:
-
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
 
         self.log.info(
-            "Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s",
+            "Getting list of the files. Bucket: %s; Delimiter: %s; Prefix(es): 
%s",
             self.bucket,
             self.delimiter,
             self.prefix,
@@ -239,7 +238,6 @@ class GCSListObjectsOperator(GoogleCloudBaseOperator):
             uri=self.bucket,
             project_id=hook.project_id,
         )
-
         return hook.list(bucket_name=self.bucket, prefix=self.prefix, 
delimiter=self.delimiter)
 
 
@@ -252,8 +250,8 @@ class GCSDeleteObjectsOperator(GoogleCloudBaseOperator):
     :param bucket_name: The GCS bucket to delete from
     :param objects: List of objects to delete. These should be the names
         of objects in the bucket, not including gs://bucket/
-    :param prefix: Prefix of objects to delete. All objects matching this
-        prefix in the bucket will be deleted.
+    :param prefix: String or list of strings. All objects in the bucket whose 
names begin with
+           it/them will be deleted. (templated)
     :param gcp_conn_id: (Optional) The connection ID used to connect to Google 
Cloud.
     :param impersonation_chain: Optional service account to impersonate using 
short-term
         credentials, or chained list of accounts required to get the 
access_token
@@ -307,7 +305,6 @@ class GCSDeleteObjectsOperator(GoogleCloudBaseOperator):
             objects = self.objects
         else:
             objects = hook.list(bucket_name=self.bucket_name, 
prefix=self.prefix)
-
         self.log.info("Deleting %s objects from %s", len(objects), 
self.bucket_name)
         for object_name in objects:
             hook.delete(bucket_name=self.bucket_name, object_name=object_name)
diff --git a/tests/providers/google/cloud/hooks/test_gcs.py 
b/tests/providers/google/cloud/hooks/test_gcs.py
index c70a67a66a..22cedfbd0a 100644
--- a/tests/providers/google/cloud/hooks/test_gcs.py
+++ b/tests/providers/google/cloud/hooks/test_gcs.py
@@ -758,6 +758,36 @@ class TestGCSHook:
             ]
         )
 
+    @pytest.mark.parametrize(
+        "prefix, result",
+        (
+            (
+                "prefix",
+                [mock.call(delimiter=",", prefix="prefix", versions=None, 
max_results=None, page_token=None)],
+            ),
+            (
+                ["prefix", "prefix_2"],
+                [
+                    mock.call(
+                        delimiter=",", prefix="prefix", versions=None, 
max_results=None, page_token=None
+                    ),
+                    mock.call(
+                        delimiter=",", prefix="prefix_2", versions=None, 
max_results=None, page_token=None
+                    ),
+                ],
+            ),
+        ),
+    )
+    @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
+    def test_list(self, mock_service, prefix, result):
+        
mock_service.return_value.bucket.return_value.list_blobs.return_value.next_page_token
 = None
+        self.gcs_hook.list(
+            bucket_name="test_bucket",
+            prefix=prefix,
+            delimiter=",",
+        )
+        assert 
mock_service.return_value.bucket.return_value.list_blobs.call_args_list == 
result
+
     @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
     def test_list_by_timespans(self, mock_service):
         test_bucket = "test_bucket"
diff --git a/tests/providers/google/cloud/operators/test_gcs.py 
b/tests/providers/google/cloud/operators/test_gcs.py
index 4e3ee66d38..bf9a4f5d7e 100644
--- a/tests/providers/google/cloud/operators/test_gcs.py
+++ b/tests/providers/google/cloud/operators/test_gcs.py
@@ -38,6 +38,7 @@ TEST_BUCKET = "test-bucket"
 TEST_PROJECT = "test-project"
 DELIMITER = ".csv"
 PREFIX = "TEST"
+PREFIX_2 = "TEST2"
 MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv", "OTHERTEST1.csv"]
 TEST_OBJECT = "dir1/test-object"
 LOCAL_FILE_PATH = "/home/airflow/gcp/test-object"
@@ -160,11 +161,9 @@ class TestGoogleCloudStorageListOperator:
     @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
     def test_execute(self, mock_hook):
         mock_hook.return_value.list.return_value = MOCK_FILES
-
         operator = GCSListObjectsOperator(
             task_id=TASK_ID, bucket=TEST_BUCKET, prefix=PREFIX, 
delimiter=DELIMITER
         )
-
         files = operator.execute(context=mock.MagicMock())
         mock_hook.return_value.list.assert_called_once_with(
             bucket_name=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER

Reply via email to