This is an automated email from the ASF dual-hosted git repository.
tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 581ec8bb17f Always mark the instruction as cleaned up in the GRPC data
channel when processing an instruction fails. (#36367)
581ec8bb17f is described below
commit 581ec8bb17fa96234cbf9fca552546aeb2e4fd4a
Author: tvalentyn <[email protected]>
AuthorDate: Thu Oct 16 10:09:54 2025 -0700
Always mark the instruction as cleaned up in the GRPC data channel when
processing an instruction fails. (#36367)
* Mark instructions as cleaned up in the gRPC data channel if processing an
instruction fails.
* Invoke cleanup even if the bundle processor (BP) could not be created.
* Address review feedback.
* Add a test.
---
.../apache_beam/runners/worker/data_plane.py | 19 +++++++++-
.../apache_beam/runners/worker/sdk_worker.py | 13 ++++---
.../apache_beam/runners/worker/sdk_worker_test.py | 44 ++++++++++++++++++++--
3 files changed, 65 insertions(+), 11 deletions(-)
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py
b/sdks/python/apache_beam/runners/worker/data_plane.py
index e4cf4f185ad..cbd28f8b0a3 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -502,7 +502,11 @@ class _GrpcDataChannel(DataChannel):
instruction_id cannot be reused for new queue.
"""
with self._receive_lock:
- self._received.pop(instruction_id)
+ # Per-instruction read queue may or may not be created yet when
+ # we mark an instruction as 'cleaned up' when creating
+ # a bundle processor failed, e.g. due to a flake in DoFn.setup().
+ # We want to mark an instruction as cleaned up regardless.
+ self._received.pop(instruction_id, None)
self._cleaned_instruction_ids[instruction_id] = True
while len(self._cleaned_instruction_ids) > _MAX_CLEANED_INSTRUCTIONS:
self._cleaned_instruction_ids.popitem(last=False)
@@ -787,6 +791,12 @@ class DataChannelFactory(metaclass=abc.ABCMeta):
"""Close all channels that this factory owns."""
raise NotImplementedError(type(self))
+ def cleanup(self, instruction_id):
+ # type: (str) -> None
+
+ """Clean up resources for a given instruction."""
+ pass
+
class GrpcClientDataChannelFactory(DataChannelFactory):
"""A factory for ``GrpcClientDataChannel``.
@@ -851,10 +861,15 @@ class GrpcClientDataChannelFactory(DataChannelFactory):
def close(self):
# type: () -> None
_LOGGER.info('Closing all cached grpc data channels.')
- for _, channel in self._data_channel_cache.items():
+ for channel in list(self._data_channel_cache.values()):
channel.close()
self._data_channel_cache.clear()
+ def cleanup(self, instruction_id):
+ # type: (str) -> None
+ for channel in list(self._data_channel_cache.values()):
+ channel._clean_receiving_queue(instruction_id)
+
class InMemoryDataChannelFactory(DataChannelFactory):
"""A singleton factory for ``InMemoryDataChannel``."""
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py
b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index c520740038e..6060ff8d54a 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -563,15 +563,18 @@ class BundleProcessorCache(object):
"""
Marks the instruction id as failed shutting down the ``BundleProcessor``.
"""
+ processor = None
with self._lock:
self.failed_instruction_ids[instruction_id] = exception
while len(self.failed_instruction_ids) > MAX_FAILED_INSTRUCTIONS:
self.failed_instruction_ids.popitem(last=False)
- processor = self.active_bundle_processors[instruction_id][1]
- del self.active_bundle_processors[instruction_id]
+ if instruction_id in self.active_bundle_processors:
+ processor = self.active_bundle_processors.pop(instruction_id)[1]
# Perform the shutdown while not holding the lock.
- processor.shutdown()
+ if processor:
+ processor.shutdown()
+ self.data_channel_factory.cleanup(instruction_id)
def release(self, instruction_id):
# type: (str) -> None
@@ -694,9 +697,9 @@ class SdkWorker(object):
instruction_id # type: str
):
# type: (...) -> beam_fn_api_pb2.InstructionResponse
- bundle_processor = self.bundle_processor_cache.get(
- instruction_id, request.process_bundle_descriptor_id)
try:
+ bundle_processor = self.bundle_processor_cache.get(
+ instruction_id, request.process_bundle_descriptor_id)
with bundle_processor.state_handler.process_instruction_id(
instruction_id, request.cache_tokens):
with self.maybe_profile(instruction_id):
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index 0ab04ff256c..7b53f274cac 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -37,6 +37,7 @@ from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import metrics_pb2
+from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker import sdk_worker
from apache_beam.runners.worker import statecache
from apache_beam.runners.worker.sdk_worker import BundleProcessorCache
@@ -126,7 +127,10 @@ class SdkWorkerTest(unittest.TestCase):
def test_inactive_bundle_processor_returns_empty_progress_response(self):
bundle_processor = mock.MagicMock()
- bundle_processor_cache = BundleProcessorCache(None, None, None, {})
+ data_channel_factory = mock.create_autospec(
+ data_plane.GrpcClientDataChannelFactory)
+ bundle_processor_cache = BundleProcessorCache(
+ None, None, data_channel_factory, {})
bundle_processor_cache.activate('instruction_id')
worker = SdkWorker(bundle_processor_cache)
split_request = beam_fn_api_pb2.InstructionRequest(
@@ -153,7 +157,10 @@ class SdkWorkerTest(unittest.TestCase):
def test_failed_bundle_processor_returns_failed_progress_response(self):
bundle_processor = mock.MagicMock()
- bundle_processor_cache = BundleProcessorCache(None, None, None, {})
+ data_channel_factory = mock.create_autospec(
+ data_plane.GrpcClientDataChannelFactory)
+ bundle_processor_cache = BundleProcessorCache(
+ None, None, data_channel_factory, {})
bundle_processor_cache.activate('instruction_id')
worker = SdkWorker(bundle_processor_cache)
@@ -176,7 +183,10 @@ class SdkWorkerTest(unittest.TestCase):
def test_inactive_bundle_processor_returns_empty_split_response(self):
bundle_processor = mock.MagicMock()
- bundle_processor_cache = BundleProcessorCache(None, None, None, {})
+ data_channel_factory = mock.create_autospec(
+ data_plane.GrpcClientDataChannelFactory)
+ bundle_processor_cache = BundleProcessorCache(
+ None, None, data_channel_factory, {})
bundle_processor_cache.activate('instruction_id')
worker = SdkWorker(bundle_processor_cache)
split_request = beam_fn_api_pb2.InstructionRequest(
@@ -262,7 +272,10 @@ class SdkWorkerTest(unittest.TestCase):
def test_failed_bundle_processor_returns_failed_split_response(self):
bundle_processor = mock.MagicMock()
- bundle_processor_cache = BundleProcessorCache(None, None, None, {})
+ data_channel_factory = mock.create_autospec(
+ data_plane.GrpcClientDataChannelFactory)
+ bundle_processor_cache = BundleProcessorCache(
+ None, None, data_channel_factory, {})
bundle_processor_cache.activate('instruction_id')
worker = SdkWorker(bundle_processor_cache)
@@ -338,6 +351,29 @@ class SdkWorkerTest(unittest.TestCase):
self.assertEqual(response, expected_response)
+ def test_bundle_processor_creation_failure_cleans_up_grpc_data_channel(self):
+ data_channel_factory = data_plane.GrpcClientDataChannelFactory()
+ channel = data_channel_factory.create_data_channel_from_url('some_url')
+ state_handler_factory = mock.create_autospec(
+ sdk_worker.GrpcStateHandlerFactory)
+ bundle_processor_cache = BundleProcessorCache(
+ frozenset(), state_handler_factory, data_channel_factory, {})
+ if bundle_processor_cache.periodic_shutdown:
+ bundle_processor_cache.periodic_shutdown.cancel()
+
+ bundle_processor_cache.get = mock.MagicMock(
+ side_effect=RuntimeError('test error'))
+
+ worker = SdkWorker(bundle_processor_cache)
+ instruction_id = 'instruction_id'
+ request = beam_fn_api_pb2.ProcessBundleRequest(
+ process_bundle_descriptor_id='descriptor_id')
+
+ with self.assertRaises(RuntimeError):
+ worker.process_bundle(request, instruction_id)
+
+ self.assertIn(instruction_id, channel._cleaned_instruction_ids)
+
class CachingStateHandlerTest(unittest.TestCase):
def test_caching(self):