This is an automated email from the ASF dual-hosted git repository.
yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new eb8639b013e Add throttling counter in gcsio and refactor retrying (#32428)
eb8639b013e is described below
commit eb8639b013e9d820a253970ed03ac0e7cfc6e13e
Author: Shunping Huang <[email protected]>
AuthorDate: Wed Sep 18 13:14:57 2024 -0400
Add throttling counter in gcsio and refactor retrying (#32428)
* Add retry instance that records throttling metric.
* Use retry with throttling counters by default. Add pipeline option.
* Fix lint
* Fix broken tests.
* Retrieve a more accurate throttling time from the caller frame.
* Apply yapf and linter
* Refactoring copy and delete
- Remove extra retries for copy, delete, _gcs_object.
- Remove the use of client.batch() as the function has no built-in
retry.
* Fix a typo and apply yapf
* Use counter instead of counters in pipeline option.
Additionally, the variable name for the new retry object is changed.
Add a new pipeline option to enable the use of blob generation to
mitigate race conditions (at the expense of more http requests)
* Parameterize existing tests for the new pipeline options.
* Apply yapf
* Fix a typo.
* Revert the change of copy_batch and delete_batch and add warning in their
docstring.
* Fix lint
* Minor change according to code review.
* Restore the previous tox.ini that got accidentally changed.
---
sdks/python/apache_beam/io/gcp/gcsio.py | 73 ++++++++++++-------
.../apache_beam/io/gcp/gcsio_integration_test.py | 39 +++++++++-
sdks/python/apache_beam/io/gcp/gcsio_retry.py | 71 ++++++++++++++++++
sdks/python/apache_beam/io/gcp/gcsio_retry_test.py | 84 ++++++++++++++++++++++
sdks/python/apache_beam/io/gcp/gcsio_test.py | 16 +++--
.../python/apache_beam/options/pipeline_options.py | 12 ++++
6 files changed, 265 insertions(+), 30 deletions(-)
diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py
index 6b0470b8236..22a33fa13c6 100644
--- a/sdks/python/apache_beam/io/gcp/gcsio.py
+++ b/sdks/python/apache_beam/io/gcp/gcsio.py
@@ -43,10 +43,10 @@ from google.cloud.storage.retry import DEFAULT_RETRY
from apache_beam import version as beam_version
from apache_beam.internal.gcp import auth
+from apache_beam.io.gcp import gcsio_retry
from apache_beam.metrics.metric import Metrics
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
-from apache_beam.utils import retry
from apache_beam.utils.annotations import deprecated
__all__ = ['GcsIO', 'create_storage_client']
@@ -155,6 +155,9 @@ class GcsIO(object):
self.client = storage_client
self._rewrite_cb = None
self.bucket_to_project_number = {}
+ self._storage_client_retry = gcsio_retry.get_retry(pipeline_options)
+ self._use_blob_generation = getattr(
+ google_cloud_options, 'enable_gcsio_blob_generation', False)
def get_project_number(self, bucket):
if bucket not in self.bucket_to_project_number:
@@ -167,7 +170,8 @@ class GcsIO(object):
def get_bucket(self, bucket_name, **kwargs):
"""Returns an object bucket from its name, or None if it does not exist."""
try:
- return self.client.lookup_bucket(bucket_name, **kwargs)
+ return self.client.lookup_bucket(
+ bucket_name, retry=self._storage_client_retry, **kwargs)
except NotFound:
return None
@@ -188,7 +192,7 @@ class GcsIO(object):
bucket_or_name=bucket,
project=project,
location=location,
- )
+ retry=self._storage_client_retry)
if kms_key:
bucket.default_kms_key_name(kms_key)
bucket.patch()
@@ -224,18 +228,18 @@ class GcsIO(object):
return BeamBlobReader(
blob,
chunk_size=read_buffer_size,
- enable_read_bucket_metric=self.enable_read_bucket_metric)
+ enable_read_bucket_metric=self.enable_read_bucket_metric,
+ retry=self._storage_client_retry)
elif mode == 'w' or mode == 'wb':
blob = bucket.blob(blob_name)
return BeamBlobWriter(
blob,
mime_type,
- enable_write_bucket_metric=self.enable_write_bucket_metric)
+ enable_write_bucket_metric=self.enable_write_bucket_metric,
+ retry=self._storage_client_retry)
else:
raise ValueError('Invalid file open mode: %s.' % mode)
- @retry.with_exponential_backoff(
- retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def delete(self, path):
"""Deletes the object at the given GCS path.
@@ -243,14 +247,24 @@ class GcsIO(object):
path: GCS file path pattern in the form gs://<bucket>/<name>.
"""
bucket_name, blob_name = parse_gcs_path(path)
+ bucket = self.client.bucket(bucket_name)
+ if self._use_blob_generation:
+ # blob can be None if not found
+ blob = bucket.get_blob(blob_name, retry=self._storage_client_retry)
+ generation = getattr(blob, "generation", None)
+ else:
+ generation = None
try:
- bucket = self.client.bucket(bucket_name)
- bucket.delete_blob(blob_name)
+ bucket.delete_blob(
+ blob_name,
+ if_generation_match=generation,
+ retry=self._storage_client_retry)
except NotFound:
return
def delete_batch(self, paths):
"""Deletes the objects at the given GCS paths.
+ Warning: any exception during batch delete will NOT be retried.
Args:
paths: List of GCS file path patterns or Dict with GCS file path patterns
@@ -287,8 +301,6 @@ class GcsIO(object):
return final_results
- @retry.with_exponential_backoff(
- retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def copy(self, src, dest):
"""Copies the given GCS object from src to dest.
@@ -297,19 +309,32 @@ class GcsIO(object):
dest: GCS file path pattern in the form gs://<bucket>/<name>.
Raises:
- TimeoutError: on timeout.
+ Any exceptions during copying
"""
src_bucket_name, src_blob_name = parse_gcs_path(src)
dest_bucket_name, dest_blob_name= parse_gcs_path(dest,
object_optional=True)
src_bucket = self.client.bucket(src_bucket_name)
- src_blob = src_bucket.blob(src_blob_name)
+ if self._use_blob_generation:
+ src_blob = src_bucket.get_blob(src_blob_name)
+ if src_blob is None:
+ raise NotFound("source blob %s not found during copying" % src)
+ src_generation = src_blob.generation
+ else:
+ src_blob = src_bucket.blob(src_blob_name)
+ src_generation = None
dest_bucket = self.client.bucket(dest_bucket_name)
if not dest_blob_name:
dest_blob_name = None
- src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_blob_name)
+ src_bucket.copy_blob(
+ src_blob,
+ dest_bucket,
+ new_name=dest_blob_name,
+ source_generation=src_generation,
+ retry=self._storage_client_retry)
def copy_batch(self, src_dest_pairs):
"""Copies the given GCS objects from src to dest.
+ Warning: any exception during batch copy will NOT be retried.
Args:
src_dest_pairs: list of (src, dest) tuples of gs://<bucket>/<name> files
@@ -450,8 +475,6 @@ class GcsIO(object):
file_status['size'] = gcs_object.size
return file_status
- @retry.with_exponential_backoff(
- retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def _gcs_object(self, path):
"""Returns a gcs object for the given path
@@ -462,7 +485,7 @@ class GcsIO(object):
"""
bucket_name, blob_name = parse_gcs_path(path)
bucket = self.client.bucket(bucket_name)
- blob = bucket.get_blob(blob_name)
+ blob = bucket.get_blob(blob_name, retry=self._storage_client_retry)
if blob:
return blob
else:
@@ -510,7 +533,8 @@ class GcsIO(object):
else:
_LOGGER.debug("Starting the size estimation of the input")
bucket = self.client.bucket(bucket_name)
- response = self.client.list_blobs(bucket, prefix=prefix)
+ response = self.client.list_blobs(
+ bucket, prefix=prefix, retry=self._storage_client_retry)
for item in response:
file_name = 'gs://%s/%s' % (item.bucket.name, item.name)
if file_name not in file_info:
@@ -546,8 +570,7 @@ class GcsIO(object):
def is_soft_delete_enabled(self, gcs_path):
try:
bucket_name, _ = parse_gcs_path(gcs_path)
- # set retry timeout to 5 seconds when checking soft delete policy
- bucket = self.get_bucket(bucket_name, retry=DEFAULT_RETRY.with_timeout(5))
+ bucket = self.get_bucket(bucket_name)
if (bucket.soft_delete_policy is not None and
bucket.soft_delete_policy.retention_duration_seconds > 0):
return True
@@ -563,8 +586,9 @@ class BeamBlobReader(BlobReader):
self,
blob,
chunk_size=DEFAULT_READ_BUFFER_SIZE,
- enable_read_bucket_metric=False):
- super().__init__(blob, chunk_size=chunk_size)
+ enable_read_bucket_metric=False,
+ retry=DEFAULT_RETRY):
+ super().__init__(blob, chunk_size=chunk_size, retry=retry)
self.enable_read_bucket_metric = enable_read_bucket_metric
self.mode = "r"
@@ -585,13 +609,14 @@ class BeamBlobWriter(BlobWriter):
content_type,
chunk_size=16 * 1024 * 1024,
ignore_flush=True,
- enable_write_bucket_metric=False):
+ enable_write_bucket_metric=False,
+ retry=DEFAULT_RETRY):
super().__init__(
blob,
content_type=content_type,
chunk_size=chunk_size,
ignore_flush=ignore_flush,
- retry=DEFAULT_RETRY)
+ retry=retry)
self.mode = "w"
self.enable_write_bucket_metric = enable_write_bucket_metric
diff --git a/sdks/python/apache_beam/io/gcp/gcsio_integration_test.py b/sdks/python/apache_beam/io/gcp/gcsio_integration_test.py
index fad63813680..07a5fb5df55 100644
--- a/sdks/python/apache_beam/io/gcp/gcsio_integration_test.py
+++ b/sdks/python/apache_beam/io/gcp/gcsio_integration_test.py
@@ -34,6 +34,7 @@ import uuid
import mock
import pytest
+from parameterized import parameterized_class
from apache_beam.io.filesystems import FileSystems
from apache_beam.options.pipeline_options import GoogleCloudOptions
@@ -51,6 +52,9 @@ except ImportError:
@unittest.skipIf(gcsio is None, 'GCP dependencies are not installed')
+@parameterized_class(
+ ('no_gcsio_throttling_counter', 'enable_gcsio_blob_generation'),
+ [(False, False), (False, True), (True, False), (True, True)])
class GcsIOIntegrationTest(unittest.TestCase):
INPUT_FILE = 'gs://dataflow-samples/shakespeare/kinglear.txt'
@@ -67,7 +71,6 @@ class GcsIOIntegrationTest(unittest.TestCase):
self.gcs_tempdir = (
self.test_pipeline.get_option('temp_location') + '/gcs_it-' +
str(uuid.uuid4()))
- self.gcsio = gcsio.GcsIO()
def tearDown(self):
FileSystems.delete([self.gcs_tempdir + '/'])
@@ -92,14 +95,47 @@ class GcsIOIntegrationTest(unittest.TestCase):
@pytest.mark.it_postcommit
def test_copy(self):
+ self.gcsio = gcsio.GcsIO(
+ pipeline_options={
+ "no_gcsio_throttling_counter": self.no_gcsio_throttling_counter,
+ "enable_gcsio_blob_generation": self.enable_gcsio_blob_generation
+ })
+ src = self.INPUT_FILE
+ dest = self.gcs_tempdir + '/test_copy'
+
+ self.gcsio.copy(src, dest)
+ self._verify_copy(src, dest)
+
+ unknown_src = self.test_pipeline.get_option('temp_location') + \
+ '/gcs_it-' + str(uuid.uuid4())
+ with self.assertRaises(NotFound):
+ self.gcsio.copy(unknown_src, dest)
+
+ @pytest.mark.it_postcommit
+ def test_copy_and_delete(self):
+ self.gcsio = gcsio.GcsIO(
+ pipeline_options={
+ "no_gcsio_throttling_counter": self.no_gcsio_throttling_counter,
+ "enable_gcsio_blob_generation": self.enable_gcsio_blob_generation
+ })
src = self.INPUT_FILE
dest = self.gcs_tempdir + '/test_copy'
self.gcsio.copy(src, dest)
self._verify_copy(src, dest)
+ self.gcsio.delete(dest)
+
+ # no exception if we delete an nonexistent file.
+ self.gcsio.delete(dest)
+
@pytest.mark.it_postcommit
def test_batch_copy_and_delete(self):
+ self.gcsio = gcsio.GcsIO(
+ pipeline_options={
+ "no_gcsio_throttling_counter": self.no_gcsio_throttling_counter,
+ "enable_gcsio_blob_generation": self.enable_gcsio_blob_generation
+ })
num_copies = 10
srcs = [self.INPUT_FILE] * num_copies
dests = [
@@ -152,6 +188,7 @@ class GcsIOIntegrationTest(unittest.TestCase):
@mock.patch('apache_beam.io.gcp.gcsio.default_gcs_bucket_name')
@unittest.skipIf(NotFound is None, 'GCP dependencies are not installed')
def test_create_default_bucket(self, mock_default_gcs_bucket_name):
+ self.gcsio = gcsio.GcsIO()
google_cloud_options = self.test_pipeline.options.view_as(
GoogleCloudOptions)
# overwrite kms option here, because get_or_create_default_gcs_bucket()
diff --git a/sdks/python/apache_beam/io/gcp/gcsio_retry.py b/sdks/python/apache_beam/io/gcp/gcsio_retry.py
new file mode 100644
index 00000000000..29fd71c5195
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/gcsio_retry.py
@@ -0,0 +1,71 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Throttling Handler for GCSIO
+"""
+
+import inspect
+import logging
+import math
+
+from google.api_core import exceptions as api_exceptions
+from google.api_core import retry
+from google.cloud.storage.retry import DEFAULT_RETRY
+from google.cloud.storage.retry import _should_retry # pylint: disable=protected-access
+
+from apache_beam.metrics.metric import Metrics
+from apache_beam.options.pipeline_options import GoogleCloudOptions
+
+_LOGGER = logging.getLogger(__name__)
+
+__all__ = ['DEFAULT_RETRY_WITH_THROTTLING_COUNTER']
+
+
+class ThrottlingHandler(object):
+ _THROTTLED_SECS = Metrics.counter('gcsio', "cumulativeThrottlingSeconds")
+
+ def __call__(self, exc):
+ if isinstance(exc, api_exceptions.TooManyRequests):
+ _LOGGER.debug('Caught GCS quota error (%s), retrying.', exc.reason)
+ # TODO: revisit the logic here when gcs client library supports error
+ # callbacks
+ frame = inspect.currentframe()
+ if frame is None:
+ _LOGGER.warning('cannot inspect the current stack frame')
+ return
+
+ prev_frame = frame.f_back
+ if prev_frame is None:
+ _LOGGER.warning('cannot inspect the caller stack frame')
+ return
+
+ # next_sleep is one of the arguments in the caller
+ # i.e. _retry_error_helper() in google/api_core/retry/retry_base.py
+ sleep_seconds = prev_frame.f_locals.get("next_sleep", 0)
+ ThrottlingHandler._THROTTLED_SECS.inc(math.ceil(sleep_seconds))
+
+
+DEFAULT_RETRY_WITH_THROTTLING_COUNTER = retry.Retry(
+ predicate=_should_retry, on_error=ThrottlingHandler())
+
+
+def get_retry(pipeline_options):
+ if pipeline_options.view_as(GoogleCloudOptions).no_gcsio_throttling_counter:
+ return DEFAULT_RETRY
+ else:
+ return DEFAULT_RETRY_WITH_THROTTLING_COUNTER
diff --git a/sdks/python/apache_beam/io/gcp/gcsio_retry_test.py b/sdks/python/apache_beam/io/gcp/gcsio_retry_test.py
new file mode 100644
index 00000000000..750879ae028
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/gcsio_retry_test.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for Throttling Handler of GCSIO."""
+
+import unittest
+from unittest.mock import Mock
+
+from apache_beam.metrics.execution import MetricsContainer
+from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.metrics.metricbase import MetricName
+from apache_beam.runners.worker import statesampler
+from apache_beam.utils import counters
+
+try:
+ from apache_beam.io.gcp import gcsio_retry
+ from google.api_core import exceptions as api_exceptions
+except ImportError:
+ gcsio_retry = None
+ api_exceptions = None
+
+
+@unittest.skipIf((gcsio_retry is None or api_exceptions is None),
+ 'GCP dependencies are not installed')
+class TestGCSIORetry(unittest.TestCase):
+ def test_retry_on_non_retriable(self):
+ mock = Mock(side_effect=[
+ Exception('Something wrong!'),
+ ])
+ retry = gcsio_retry.DEFAULT_RETRY_WITH_THROTTLING_COUNTER
+ with self.assertRaises(Exception):
+ retry(mock)()
+
+ def test_retry_on_throttling(self):
+ mock = Mock(
+ side_effect=[
+ api_exceptions.TooManyRequests("Slow down!"),
+ api_exceptions.TooManyRequests("Slow down again!"),
+ 12345
+ ])
+ retry = gcsio_retry.DEFAULT_RETRY_WITH_THROTTLING_COUNTER
+
+ sampler = statesampler.StateSampler('', counters.CounterFactory())
+ statesampler.set_current_tracker(sampler)
+ state = sampler.scoped_state(
+ 'my_step', 'my_state', metrics_container=MetricsContainer('my_step'))
+ try:
+ sampler.start()
+ with state:
+ container = MetricsEnvironment.current_container()
+
+ self.assertEqual(
+ container.get_counter(
+ MetricName('gcsio',
+ "cumulativeThrottlingSeconds")).get_cumulative(),
+ 0)
+
+ self.assertEqual(12345, retry(mock)())
+
+ self.assertGreater(
+ container.get_counter(
+ MetricName('gcsio',
+ "cumulativeThrottlingSeconds")).get_cumulative(),
+ 1)
+ finally:
+ sampler.stop()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py
index 407295f2fb3..19df15dcf7f 100644
--- a/sdks/python/apache_beam/io/gcp/gcsio_test.py
+++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py
@@ -20,6 +20,7 @@
import logging
import os
+import random
import unittest
from datetime import datetime
@@ -36,6 +37,7 @@ from apache_beam.utils import counters
try:
from apache_beam.io.gcp import gcsio
+ from apache_beam.io.gcp.gcsio_retry import DEFAULT_RETRY_WITH_THROTTLING_COUNTER
from google.cloud.exceptions import BadRequest, NotFound
except ImportError:
NotFound = None
@@ -85,7 +87,7 @@ class FakeGcsClient(object):
holder = folder.get_blob(blob.name)
return holder
- def list_blobs(self, bucket_or_path, prefix=None):
+ def list_blobs(self, bucket_or_path, prefix=None, **unused_kwargs):
bucket = self.get_bucket(bucket_or_path.name)
if not prefix:
return list(bucket.blobs.values())
@@ -120,7 +122,7 @@ class FakeBucket(object):
def blob(self, name):
return self._create_blob(name)
- def copy_blob(self, blob, dest, new_name=None):
+ def copy_blob(self, blob, dest, new_name=None, **kwargs):
if self.get_blob(blob.name) is None:
raise NotFound("source blob not found")
if not new_name:
@@ -129,7 +131,7 @@ class FakeBucket(object):
dest.add_blob(new_blob)
return new_blob
- def get_blob(self, blob_name):
+ def get_blob(self, blob_name, **unused_kwargs):
bucket = self._get_canonical_bucket()
if blob_name in bucket.blobs:
return bucket.blobs[blob_name]
@@ -146,7 +148,7 @@ class FakeBucket(object):
def set_default_kms_key_name(self, name):
self.default_kms_key_name = name
- def delete_blob(self, name):
+ def delete_blob(self, name, **kwargs):
bucket = self._get_canonical_bucket()
if name in bucket.blobs:
del bucket.blobs[name]
@@ -175,6 +177,7 @@ class FakeBlob(object):
self.updated = updated
self._fail_when_getting_metadata = fail_when_getting_metadata
self._fail_when_reading = fail_when_reading
+ self.generation = random.randint(0, (1 << 63) - 1)
def delete(self):
self.bucket.delete_blob(self.name)
@@ -532,7 +535,10 @@ class TestGCSIO(unittest.TestCase):
with mock.patch('apache_beam.io.gcp.gcsio.BeamBlobReader') as reader:
self.gcs.open(file_name, read_buffer_size=read_buffer_size)
reader.assert_called_with(
- blob, chunk_size=read_buffer_size, enable_read_bucket_metric=False)
+ blob,
+ chunk_size=read_buffer_size,
+ enable_read_bucket_metric=False,
+ retry=DEFAULT_RETRY_WITH_THROTTLING_COUNTER)
def test_file_write_call(self):
file_name = 'gs://gcsio-test/write_file'
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index 86bf119c8bf..50021c4610f 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -947,6 +947,18 @@ class GoogleCloudOptions(PipelineOptions):
help=
'Create metrics reporting the approximate number of bytes written per '
'bucket.')
+ parser.add_argument(
+ '--no_gcsio_throttling_counter',
+ default=False,
+ action='store_true',
+ help='Throttling counter in GcsIO is enabled by default. Set '
+ '--no_gcsio_throttling_counter to avoid it.')
+ parser.add_argument(
+ '--enable_gcsio_blob_generation',
+ default=False,
+ action='store_true',
+ help='Use blob generation when mutating blobs in GCSIO to '
+ 'mitigate race conditions at the cost of more HTTP requests.')
def _create_default_gcs_bucket(self):
try: