This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e83a98603e iterate through blobs before checking prefixes (#36202)
e83a98603e is described below
commit e83a98603ef15c7d57910c482ba75eb76ed79553
Author: Wei Lee <[email protected]>
AuthorDate: Thu Dec 14 20:09:53 2023 +0530
iterate through blobs before checking prefixes (#36202)
* fix(providers/google): iterate through blobs before checking prefixes
According to
https://github.com/googleapis/python-storage/blob/v2.14.0/google/cloud/storage/client.py#L1213-L1217,
the `prefixes` of the iterator returned by `list_blobs` are not populated until its blobs have been consumed.
* test(providers/google): add test cases to check gcs.list result
---
airflow/providers/google/cloud/hooks/gcs.py | 18 ++++++-----
tests/providers/google/cloud/hooks/test_gcs.py | 42 +++++++++++++++++++++++---
2 files changed, 47 insertions(+), 13 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index 45a202124d..02055583ce 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -821,12 +821,13 @@ class GCSHook(GoogleBaseHook):
delimiter=delimiter,
versions=versions,
)
- list(blobs)
+
+ blob_names = [blob.name for blob in blobs]
if blobs.prefixes:
ids.extend(blobs.prefixes)
else:
- ids.extend(blob.name for blob in blobs)
+ ids.extend(blob_names)
page_token = blobs.next_page_token
if page_token is None:
@@ -933,16 +934,17 @@ class GCSHook(GoogleBaseHook):
delimiter=delimiter,
versions=versions,
)
- list(blobs)
+
+ blob_names = [
+ blob.name
+ for blob in blobs
+ if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end
+ ]
if blobs.prefixes:
ids.extend(blobs.prefixes)
else:
- ids.extend(
- blob.name
- for blob in blobs
- if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end
- )
+ ids.extend(blob_names)
page_token = blobs.next_page_token
if page_token is None:
diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py
index 33df98e37b..825a357d39 100644
--- a/tests/providers/google/cloud/hooks/test_gcs.py
+++ b/tests/providers/google/cloud/hooks/test_gcs.py
@@ -21,6 +21,7 @@ import copy
import logging
import os
import re
+from collections import namedtuple
from datetime import datetime, timedelta
from io import BytesIO
from unittest import mock
@@ -799,14 +800,26 @@ class TestGCSHook:
)
@pytest.mark.parametrize(
- "prefix, result",
+ "prefix, blob_names, returned_prefixes, call_args, result",
(
(
"prefix",
+ ["prefix"],
+ None,
+ [mock.call(delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None)],
+ ["prefix"],
+ ),
+ (
+ "prefix",
+ ["prefix"],
+ {"prefix,"},
[mock.call(delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None)],
+ ["prefix,"],
),
(
["prefix", "prefix_2"],
+ ["prefix", "prefix2"],
+ None,
[
mock.call(
delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None
@@ -815,19 +828,38 @@ class TestGCSHook:
delimiter=",", prefix="prefix_2", versions=None, max_results=None, page_token=None
),
],
+ ["prefix", "prefix2"],
),
),
)
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
- def test_list__delimiter(self, mock_service, prefix, result):
- mock_service.return_value.bucket.return_value.list_blobs.return_value.next_page_token = None
+ def test_list__delimiter(self, mock_service, prefix, blob_names, returned_prefixes, call_args, result):
+ Blob = namedtuple("Blob", ["name"])
+
+ class BlobsIterator:
+ def __init__(self):
+ self._item_iter = (Blob(name=name) for name in blob_names)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ try:
+ return next(self._item_iter)
+ except StopIteration:
+ self.prefixes = returned_prefixes
+ self.next_page_token = None
+ raise
+
+ mock_service.return_value.bucket.return_value.list_blobs.return_value = BlobsIterator()
with pytest.deprecated_call():
- self.gcs_hook.list(
+ blobs = self.gcs_hook.list(
bucket_name="test_bucket",
prefix=prefix,
delimiter=",",
)
- assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == result
+ assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == call_args
+ assert blobs == result
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
@mock.patch("airflow.providers.google.cloud.hooks.gcs.functools")