This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 581ec8bb17f Always mark the instruction as cleaned up in the GRPC data channel when processing an instruction fails. (#36367)
581ec8bb17f is described below

commit 581ec8bb17fa96234cbf9fca552546aeb2e4fd4a
Author: tvalentyn <[email protected]>
AuthorDate: Thu Oct 16 10:09:54 2025 -0700

    Always mark the instruction as cleaned up in the GRPC data channel when processing an instruction fails. (#36367)
    
    * Mark instructions as cleaned up in the GRPC data channel if processing an instruction fails.
    
    * Invoke cleanup even if BP failed to create.
    
    * Address feedback.
    
    * Add a test
---
 .../apache_beam/runners/worker/data_plane.py       | 19 +++++++++-
 .../apache_beam/runners/worker/sdk_worker.py       | 13 ++++---
 .../apache_beam/runners/worker/sdk_worker_test.py  | 44 ++++++++++++++++++++--
 3 files changed, 65 insertions(+), 11 deletions(-)

diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
index e4cf4f185ad..cbd28f8b0a3 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -502,7 +502,11 @@ class _GrpcDataChannel(DataChannel):
     instruction_id cannot be reused for new queue.
     """
     with self._receive_lock:
-      self._received.pop(instruction_id)
+      # Per-instruction read queue may or may not be created yet when
+      # we mark an instruction as 'cleaned up' when creating
+      # a bundle processor failed, e.g. due to a flake in DoFn.setup().
+      # We want to mark an instruction as cleaned up regardless.
+      self._received.pop(instruction_id, None)
       self._cleaned_instruction_ids[instruction_id] = True
       while len(self._cleaned_instruction_ids) > _MAX_CLEANED_INSTRUCTIONS:
         self._cleaned_instruction_ids.popitem(last=False)
@@ -787,6 +791,12 @@ class DataChannelFactory(metaclass=abc.ABCMeta):
     """Close all channels that this factory owns."""
     raise NotImplementedError(type(self))
 
+  def cleanup(self, instruction_id):
+    # type: (str) -> None
+
+    """Clean up resources for a given instruction."""
+    pass
+
 
 class GrpcClientDataChannelFactory(DataChannelFactory):
   """A factory for ``GrpcClientDataChannel``.
@@ -851,10 +861,15 @@ class GrpcClientDataChannelFactory(DataChannelFactory):
   def close(self):
     # type: () -> None
     _LOGGER.info('Closing all cached grpc data channels.')
-    for _, channel in self._data_channel_cache.items():
+    for channel in list(self._data_channel_cache.values()):
       channel.close()
     self._data_channel_cache.clear()
 
+  def cleanup(self, instruction_id):
+    # type: (str) -> None
+    for channel in list(self._data_channel_cache.values()):
+      channel._clean_receiving_queue(instruction_id)
+
 
 class InMemoryDataChannelFactory(DataChannelFactory):
   """A singleton factory for ``InMemoryDataChannel``."""
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index c520740038e..6060ff8d54a 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -563,15 +563,18 @@ class BundleProcessorCache(object):
     """
     Marks the instruction id as failed shutting down the ``BundleProcessor``.
     """
+    processor = None
     with self._lock:
       self.failed_instruction_ids[instruction_id] = exception
       while len(self.failed_instruction_ids) > MAX_FAILED_INSTRUCTIONS:
         self.failed_instruction_ids.popitem(last=False)
-      processor = self.active_bundle_processors[instruction_id][1]
-      del self.active_bundle_processors[instruction_id]
+      if instruction_id in self.active_bundle_processors:
+        processor = self.active_bundle_processors.pop(instruction_id)[1]
 
     # Perform the shutdown while not holding the lock.
-    processor.shutdown()
+    if processor:
+      processor.shutdown()
+    self.data_channel_factory.cleanup(instruction_id)
 
   def release(self, instruction_id):
     # type: (str) -> None
@@ -694,9 +697,9 @@ class SdkWorker(object):
       instruction_id  # type: str
   ):
     # type: (...) -> beam_fn_api_pb2.InstructionResponse
-    bundle_processor = self.bundle_processor_cache.get(
-        instruction_id, request.process_bundle_descriptor_id)
     try:
+      bundle_processor = self.bundle_processor_cache.get(
+          instruction_id, request.process_bundle_descriptor_id)
       with bundle_processor.state_handler.process_instruction_id(
           instruction_id, request.cache_tokens):
         with self.maybe_profile(instruction_id):
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index 0ab04ff256c..7b53f274cac 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -37,6 +37,7 @@ from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_fn_api_pb2_grpc
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.portability.api import metrics_pb2
+from apache_beam.runners.worker import data_plane
 from apache_beam.runners.worker import sdk_worker
 from apache_beam.runners.worker import statecache
 from apache_beam.runners.worker.sdk_worker import BundleProcessorCache
@@ -126,7 +127,10 @@ class SdkWorkerTest(unittest.TestCase):
 
   def test_inactive_bundle_processor_returns_empty_progress_response(self):
     bundle_processor = mock.MagicMock()
-    bundle_processor_cache = BundleProcessorCache(None, None, None, {})
+    data_channel_factory = mock.create_autospec(
+        data_plane.GrpcClientDataChannelFactory)
+    bundle_processor_cache = BundleProcessorCache(
+        None, None, data_channel_factory, {})
     bundle_processor_cache.activate('instruction_id')
     worker = SdkWorker(bundle_processor_cache)
     split_request = beam_fn_api_pb2.InstructionRequest(
@@ -153,7 +157,10 @@ class SdkWorkerTest(unittest.TestCase):
 
   def test_failed_bundle_processor_returns_failed_progress_response(self):
     bundle_processor = mock.MagicMock()
-    bundle_processor_cache = BundleProcessorCache(None, None, None, {})
+    data_channel_factory = mock.create_autospec(
+        data_plane.GrpcClientDataChannelFactory)
+    bundle_processor_cache = BundleProcessorCache(
+        None, None, data_channel_factory, {})
     bundle_processor_cache.activate('instruction_id')
     worker = SdkWorker(bundle_processor_cache)
 
@@ -176,7 +183,10 @@ class SdkWorkerTest(unittest.TestCase):
 
   def test_inactive_bundle_processor_returns_empty_split_response(self):
     bundle_processor = mock.MagicMock()
-    bundle_processor_cache = BundleProcessorCache(None, None, None, {})
+    data_channel_factory = mock.create_autospec(
+        data_plane.GrpcClientDataChannelFactory)
+    bundle_processor_cache = BundleProcessorCache(
+        None, None, data_channel_factory, {})
     bundle_processor_cache.activate('instruction_id')
     worker = SdkWorker(bundle_processor_cache)
     split_request = beam_fn_api_pb2.InstructionRequest(
@@ -262,7 +272,10 @@ class SdkWorkerTest(unittest.TestCase):
 
   def test_failed_bundle_processor_returns_failed_split_response(self):
     bundle_processor = mock.MagicMock()
-    bundle_processor_cache = BundleProcessorCache(None, None, None, {})
+    data_channel_factory = mock.create_autospec(
+        data_plane.GrpcClientDataChannelFactory)
+    bundle_processor_cache = BundleProcessorCache(
+        None, None, data_channel_factory, {})
     bundle_processor_cache.activate('instruction_id')
     worker = SdkWorker(bundle_processor_cache)
 
@@ -338,6 +351,29 @@ class SdkWorkerTest(unittest.TestCase):
 
     self.assertEqual(response, expected_response)
 
+  def test_bundle_processor_creation_failure_cleans_up_grpc_data_channel(self):
+    data_channel_factory = data_plane.GrpcClientDataChannelFactory()
+    channel = data_channel_factory.create_data_channel_from_url('some_url')
+    state_handler_factory = mock.create_autospec(
+        sdk_worker.GrpcStateHandlerFactory)
+    bundle_processor_cache = BundleProcessorCache(
+        frozenset(), state_handler_factory, data_channel_factory, {})
+    if bundle_processor_cache.periodic_shutdown:
+      bundle_processor_cache.periodic_shutdown.cancel()
+
+    bundle_processor_cache.get = mock.MagicMock(
+        side_effect=RuntimeError('test error'))
+
+    worker = SdkWorker(bundle_processor_cache)
+    instruction_id = 'instruction_id'
+    request = beam_fn_api_pb2.ProcessBundleRequest(
+        process_bundle_descriptor_id='descriptor_id')
+
+    with self.assertRaises(RuntimeError):
+      worker.process_bundle(request, instruction_id)
+
+    self.assertIn(instruction_id, channel._cleaned_instruction_ids)
+
 
 class CachingStateHandlerTest(unittest.TestCase):
   def test_caching(self):

Reply via email to