This is an automated email from the ASF dual-hosted git repository.

shunping pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c33bc972f15 Fix race condition in UserPipelineTracker.clear() and various problems (#38537)
c33bc972f15 is described below

commit c33bc972f150f89ace5d2f3ef708cdfc1c231095
Author: Shunping Huang <[email protected]>
AuthorDate: Wed May 20 12:16:33 2026 -0400

    Fix race condition in UserPipelineTracker.clear() and various problems (#38537)
    
    * Fix race condition in UserPipelineTracker.clear()
    
    * Address comments.
    
    * Fix lints.
    
    * Apply suggestions from code review
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
    
    * Fix failed test RecordingTest.test_describe
    
    * Fix failed tests test_instrument_example_pipeline_to_write_cache and test_instrument_example_pipeline_to_read_cache.
    
    * Formatting.
    
    * Fix InteractiveBeamTest.test_recordings_clear and test_recordings_record.
    
    ---------
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .../runners/interactive/interactive_environment.py |  10 +-
 .../runners/interactive/user_pipeline_tracker.py   | 109 ++++++++++++---------
 .../interactive/user_pipeline_tracker_test.py      |  48 +++++++++
 3 files changed, 119 insertions(+), 48 deletions(-)

diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment.py b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
index bfb1a7f1190..b243d20ff85 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
@@ -365,11 +365,17 @@ class InteractiveEnvironment(object):
     if self.get_cache_manager(pipeline) is cache_manager:
       # NOOP if setting to the same cache_manager.
       return
+    # Check if the pipeline is already tracked as a user pipeline before cleanup.
+    is_user_pipeline = self._tracked_user_pipelines.get_user_pipeline(
+        pipeline) is pipeline
     if self.get_cache_manager(pipeline):
       # Invoke cleanup routine when a new cache_manager is forcefully set and
       # current cache_manager is not None.
       self.cleanup(pipeline)
     self._cache_managers[str(id(pipeline))] = cache_manager
+    if is_user_pipeline:
+      # Re-track the user pipeline because the self.cleanup() call above evicts it.
+      self.add_user_pipeline(pipeline)
 
   def get_cache_manager(self, pipeline, create_if_absent=False):
     """Gets the cache manager held by current Interactive Environment for the
@@ -468,8 +474,8 @@ class InteractiveEnvironment(object):
   def describe_all_recordings(self):
     """Returns a description of the recording for all watched pipelnes."""
     return {
-        self.pipeline_id_to_pipeline(pid): rm.describe()
-        for pid, rm in self._recording_managers.items()
+        rm.user_pipeline: rm.describe()
+        for rm in self._recording_managers.values()
     }
 
   def set_pipeline_result(self, pipeline, result):
diff --git a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py
index 53ee54ac8a3..4c7871c02be 100644
--- a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py
+++ b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py
@@ -25,6 +25,7 @@ that derived pipelines can link back to the parent user pipeline.
 """
 
 import shutil
+import threading
 from typing import Iterator
 from typing import Optional
 
@@ -39,13 +40,16 @@ class UserPipelineTracker:
   derived pipelines.
   """
   def __init__(self):
+    self._lock = threading.RLock()
     self._user_pipelines: dict[beam.Pipeline, list[beam.Pipeline]] = {}
-    self._derived_pipelines: dict[beam.Pipeline] = {}
-    self._pid_to_pipelines: dict[beam.Pipeline] = {}
+    self._derived_pipelines: dict[beam.Pipeline, beam.Pipeline] = {}
+    self._pid_to_pipelines: dict[str, beam.Pipeline] = {}
 
   def __iter__(self) -> Iterator[beam.Pipeline]:
     """Iterates through all the user pipelines."""
-    for p in self._user_pipelines:
+    with self._lock:
+      pipelines = list(self._user_pipelines.keys())
+    for p in pipelines:
       yield p
 
   def _key(self, pipeline: beam.Pipeline) -> str:
@@ -57,45 +61,57 @@ class UserPipelineTracker:
     Removes the given pipeline and derived pipelines if a user pipeline.
     Otherwise, removes the given derived pipeline.
     """
-    user_pipeline = self.get_user_pipeline(pipeline)
-    if user_pipeline:
-      for d in self._user_pipelines[user_pipeline]:
-        del self._derived_pipelines[d]
-      del self._user_pipelines[user_pipeline]
-    elif pipeline in self._derived_pipelines:
-      del self._derived_pipelines[pipeline]
+    with self._lock:
+      if pipeline in self._user_pipelines:
+        for d in self._user_pipelines[pipeline]:
+          self._derived_pipelines.pop(d, None)
+          self._pid_to_pipelines.pop(self._key(d), None)
+        self._user_pipelines.pop(pipeline, None)
+      elif pipeline in self._derived_pipelines:
+        user_pipeline = self._derived_pipelines.pop(pipeline, None)
+        if user_pipeline in self._user_pipelines:
+          try:
+            self._user_pipelines[user_pipeline].remove(pipeline)
+          except ValueError:
+            pass
+      self._pid_to_pipelines.pop(self._key(pipeline), None)
 
   def clear(self) -> None:
     """Clears the tracker of all user and derived pipelines."""
     # Remove all local_tempdir of created pipelines.
-    for p in self._pid_to_pipelines.values():
-      shutil.rmtree(p.local_tempdir, ignore_errors=True)
+    with self._lock:
+      pipelines = list(self._pid_to_pipelines.values())
+      self._user_pipelines.clear()
+      self._derived_pipelines.clear()
+      self._pid_to_pipelines.clear()
 
-    self._user_pipelines.clear()
-    self._derived_pipelines.clear()
-    self._pid_to_pipelines.clear()
+    for p in pipelines:
+      shutil.rmtree(p.local_tempdir, ignore_errors=True)
 
   def get_pipeline(self, pid: str) -> Optional[beam.Pipeline]:
     """Returns the pipeline corresponding to the given pipeline id."""
-    return self._pid_to_pipelines.get(pid, None)
+    with self._lock:
+      return self._pid_to_pipelines.get(pid, None)
 
   def add_user_pipeline(self, p: beam.Pipeline) -> beam.Pipeline:
     """Adds a user pipeline with an empty set of derived pipelines."""
-    self._memoize_pipieline(p)
+    with self._lock:
+      self._memoize_pipeline(p)
 
-    # Create a new node for the user pipeline if it doesn't exist already.
-    user_pipeline = self.get_user_pipeline(p)
-    if not user_pipeline:
-      user_pipeline = p
-      self._user_pipelines[p] = []
+      # Create a new node for the user pipeline if it doesn't exist already.
+      user_pipeline = self.get_user_pipeline(p)
+      if not user_pipeline:
+        user_pipeline = p
+        self._user_pipelines[p] = []
 
-    return user_pipeline
+      return user_pipeline
 
-  def _memoize_pipieline(self, p: beam.Pipeline) -> None:
+  def _memoize_pipeline(self, p: beam.Pipeline) -> None:
     """Memoizes the pid of the pipeline to the pipeline object."""
     pid = self._key(p)
-    if pid not in self._pid_to_pipelines:
-      self._pid_to_pipelines[pid] = p
+    with self._lock:
+      if pid not in self._pid_to_pipelines:
+        self._pid_to_pipelines[pid] = p
 
   def add_derived_pipeline(
       self, maybe_user_pipeline: beam.Pipeline,
@@ -119,20 +135,21 @@ class UserPipelineTracker:
     # Returns p.
     ut.get_user_pipeline(derived2)
     """
-    self._memoize_pipieline(maybe_user_pipeline)
-    self._memoize_pipieline(derived_pipeline)
+    with self._lock:
+      self._memoize_pipeline(maybe_user_pipeline)
+      self._memoize_pipeline(derived_pipeline)
 
-    # Cannot add a derived pipeline twice.
-    assert derived_pipeline not in self._derived_pipelines
+      # Cannot add a derived pipeline twice.
+      assert derived_pipeline not in self._derived_pipelines
 
-    # Get the "true" user pipeline. This allows for the user to derive a
-    # pipeline from another derived pipeline, use both as arguments, and this
-    # method will still get the correct user pipeline.
-    user = self.add_user_pipeline(maybe_user_pipeline)
+      # Get the "true" user pipeline. This allows for the user to derive a
+      # pipeline from another derived pipeline, use both as arguments, and this
+      # method will still get the correct user pipeline.
+      user = self.add_user_pipeline(maybe_user_pipeline)
 
-    # Map the derived pipeline to the user pipeline.
-    self._derived_pipelines[derived_pipeline] = user
-    self._user_pipelines[user].append(derived_pipeline)
+      # Map the derived pipeline to the user pipeline.
+      self._derived_pipelines[derived_pipeline] = user
+      self._user_pipelines[user].append(derived_pipeline)
 
   def get_user_pipeline(self, p: beam.Pipeline) -> Optional[beam.Pipeline]:
     """Returns the user pipeline of the given pipeline.
@@ -142,14 +159,14 @@ class UserPipelineTracker:
     returns the same pipeline. If the given pipeline is a derived pipeline then
     this returns the user pipeline.
     """
+    with self._lock:
+      # If `p` is a user pipeline then return it.
+      if p in self._user_pipelines:
+        return p
 
-    # If `p` is a user pipeline then return it.
-    if p in self._user_pipelines:
-      return p
-
-    # If `p` exists then return its user pipeline.
-    if p in self._derived_pipelines:
-      return self._derived_pipelines[p]
+      # If `p` exists then return its user pipeline.
+      if p in self._derived_pipelines:
+        return self._derived_pipelines[p]
 
-    # Otherwise, `p` is not in this tracker.
-    return None
+      # Otherwise, `p` is not in this tracker.
+      return None
diff --git a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py
index f7025b8b75b..6fb8e4dbad9 100644
--- a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py
+++ b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py
@@ -15,7 +15,9 @@
 # limitations under the License.
 #
 
+import threading
 import unittest
+from unittest.mock import patch
 
 import apache_beam as beam
 from apache_beam.runners.interactive.user_pipeline_tracker import UserPipelineTracker
@@ -202,6 +204,52 @@ class UserPipelineTrackerTest(unittest.TestCase):
     self.assertIs(user2, ut.get_user_pipeline(derived21))
     self.assertIs(user2, ut.get_user_pipeline(derived22))
 
+  def test_clear_race_condition(self):
+    ut = UserPipelineTracker()
+    # Add a pipeline so clear() has at least one element to iterate over.
+    p1 = beam.Pipeline()
+    derived1 = beam.Pipeline()
+    ut.add_derived_pipeline(p1, derived1)
+
+    # Set by the mock when clear() enters its loop. Signals the background
+    # worker to mutate.
+    in_loop_event = threading.Event()
+    # Set by the worker when mutation is complete. Signals mock that it can
+    # safely resume clear().
+    mutate_done_event = threading.Event()
+
+    def mock_rmtree(path, ignore_errors=False):
+      # Signal the worker that clear() is iterating.
+      in_loop_event.set()
+      # Pause here to give the worker thread time to perform the mutation.
+      mutate_done_event.wait(timeout=5)
+
+    def worker():
+      # Wait for clear() to start iterating.
+      if in_loop_event.wait(timeout=5):
+        # Concurrently mutate the tracker dictionary.
+        p2 = beam.Pipeline()
+        derived2 = beam.Pipeline()
+        try:
+          ut.add_derived_pipeline(p2, derived2)
+        finally:
+          # Resume the main thread.
+          mutate_done_event.set()
+
+    thread = threading.Thread(target=worker)
+    thread.start()
+
+    try:
+      # Intercept shutil.rmtree inside clear() to orchestrate the concurrent
+      # mutation.
+      with patch('shutil.rmtree', side_effect=mock_rmtree):
+        ut.clear()
+    finally:
+      # Avoid hanging tests if events are missed.
+      in_loop_event.set()
+      mutate_done_event.set()
+      thread.join()
+
 
 if __name__ == '__main__':
   unittest.main()

Reply via email to