This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a0d27de943f Update queuing logic to avoid lock notify contention (#37528)
a0d27de943f is described below

commit a0d27de943fba84d04c82e83482ab03090995c00
Author: RuiLong J. <[email protected]>
AuthorDate: Sun Feb 8 18:37:08 2026 -0800

    Update queuing logic to avoid lock notify contention (#37528)
    
    * Draft of updated lock notify
    
    * Complete queue ticket implementation
    
    * Remove redundant warning log
---
 .../apache_beam/ml/inference/model_manager.py      | 66 +++++++++++++++-------
 .../apache_beam/ml/inference/model_manager_test.py | 14 ++++-
 2 files changed, 56 insertions(+), 24 deletions(-)
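For context before the diff: the previous logic woke every waiter with Condition.notify_all() on each state change, so most threads woke up only to find they were not at the head of the queue and went back to sleep. The patch instead gives each waiter a ticket carrying its own threading.Event, keeps the tickets in a priority heap, and wakes only the head ticket. Below is a minimal standalone sketch of that pattern; it is an illustrative simplification, not the Beam implementation, and the Ticket/TicketQueue names are hypothetical.

import heapq
import itertools
import threading


class Ticket:
  """One waiter in line, ordered by (priority, arrival order)."""
  def __init__(self, priority, ticket_num):
    self.priority = priority
    self.ticket_num = ticket_num
    self.wake_event = threading.Event()

  def __lt__(self, other):
    return (self.priority, self.ticket_num) < (other.priority, other.ticket_num)


class TicketQueue:
  """Hypothetical helper: wake only the head waiter instead of notify_all()."""
  def __init__(self):
    self._lock = threading.Lock()
    self._heap = []
    self._counter = itertools.count()
    self._cancelled = set()

  def enter(self, priority):
    with self._lock:
      ticket = Ticket(priority, next(self._counter))
      heapq.heappush(self._heap, ticket)
      return ticket

  def wait_for_turn(self, ticket, timeout=None):
    while True:
      with self._lock:
        self._purge_cancelled()
        if self._heap and self._heap[0] is ticket:
          return  # This waiter has reached the head of the line.
      # Sleep on this waiter's own event; all other waiters stay asleep.
      if not ticket.wake_event.wait(timeout=timeout):
        raise TimeoutError('timed out waiting for turn')
      ticket.wake_event.clear()

  def leave(self, ticket):
    with self._lock:
      if self._heap and self._heap[0] is ticket:
        heapq.heappop(self._heap)
      else:
        # Not at the head: mark it cancelled and skip it once it surfaces.
        self._cancelled.add(ticket.ticket_num)
      self._purge_cancelled()
      if self._heap:
        # Wake only the next waiter in line, not every waiter.
        self._heap[0].wake_event.set()

  def _purge_cancelled(self):
    while self._heap and self._heap[0].ticket_num in self._cancelled:
      self._cancelled.remove(heapq.heappop(self._heap).ticket_num)
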

diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py
index cc9f833c268..bf7c6a43ba6 100644
--- a/sdks/python/apache_beam/ml/inference/model_manager.py
+++ b/sdks/python/apache_beam/ml/inference/model_manager.py
@@ -288,6 +288,17 @@ class ResourceEstimator:
       logger.error("Solver failed: %s", e)
 
 
+class QueueTicket:
+  def __init__(self, priority, ticket_num, tag):
+    self.priority = priority
+    self.ticket_num = ticket_num
+    self.tag = tag
+    self.wake_event = threading.Event()
+
+  def __lt__(self, other):
+    return (self.priority, self.ticket_num) < (other.priority, other.ticket_num)
+
+
 class ModelManager:
   """Manages model lifecycles, caching, and resource arbitration.
 
@@ -343,6 +354,7 @@ class ModelManager:
     # and also priority for unknown models.
     self._wait_queue = []
     self._ticket_counter = itertools.count()
+    self._cancelled_tickets = set()
     # TODO: Consider making the wait to be smarter, i.e.
     # splitting read/write etc. to avoid potential contention.
     self._cv = threading.Condition()
@@ -417,10 +429,28 @@ class ModelManager:
     self._cv.wait(timeout=self._lock_timeout_seconds)
     return False
 
+  def _wake_next_in_queue(self):
+    if self._wait_queue:
+      # Clean up cancelled tickets at head of queue
+      while self._wait_queue and self._wait_queue[
+          0].ticket_num in self._cancelled_tickets:
+        self._cancelled_tickets.remove(self._wait_queue[0].ticket_num)
+        heapq.heappop(self._wait_queue)
+      next_inline = self._wait_queue[0]
+      next_inline.wake_event.set()
+
+  def _wait_in_queue(self, ticket: QueueTicket):
+    self._cv.release()
+    try:
+      ticket.wake_event.wait(timeout=self._lock_timeout_seconds)
+      ticket.wake_event.clear()
+    finally:
+      self._cv.acquire()
+
   def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
     current_priority = 0 if self._estimator.is_unknown(tag) else 1
     ticket_num = next(self._ticket_counter)
-    my_id = object()
+    my_ticket = QueueTicket(current_priority, ticket_num, tag)
 
     with self._cv:
       # FAST PATH: Grab from idle LRU if available
@@ -439,8 +469,7 @@ class ModelManager:
           current_priority,
           len(self._models[tag]),
           ticket_num)
-      heapq.heappush(
-          self._wait_queue, (current_priority, ticket_num, my_id, tag))
+      heapq.heappush(self._wait_queue, my_ticket)
 
       est_cost = 0.0
       is_unknown = False
@@ -453,10 +482,11 @@ class ModelManager:
             raise RuntimeError(
                 f"Timeout waiting to acquire model: {tag} "
                 f"after {wait_time_elapsed:.1f} seconds.")
-          if not self._wait_queue or self._wait_queue[0][2] is not my_id:
+          if not self._wait_queue or self._wait_queue[
+              0].ticket_num != ticket_num:
             logger.info(
                 "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
-            self._cv.wait(timeout=self._lock_timeout_seconds)
+            self._wait_in_queue(my_ticket)
             continue
 
           # Re-evaluate priority in case model became known during wait
@@ -467,9 +497,9 @@ class ModelManager:
           if current_priority != real_priority:
             heapq.heappop(self._wait_queue)
             current_priority = real_priority
-            heapq.heappush(
-                self._wait_queue, (current_priority, ticket_num, my_id, tag))
-            self._cv.notify_all()
+            my_ticket = QueueTicket(current_priority, ticket_num, tag)
+            heapq.heappush(self._wait_queue, my_ticket)
+            self._wake_next_in_queue()
             continue
 
           # Try grab from LRU again in case model was released during wait
@@ -494,7 +524,7 @@ class ModelManager:
                   "Waiting due to isolation in progress: tag=%s ticket num%s",
                   tag,
                   ticket_num)
-              self._cv.wait(timeout=self._lock_timeout_seconds)
+              self._wait_in_queue(my_ticket)
               continue
 
             if self.should_spawn_model(tag, ticket_num):
@@ -508,19 +538,12 @@ class ModelManager:
 
       finally:
         # Remove self from wait queue once done
-        if self._wait_queue and self._wait_queue[0][2] is my_id:
+        if self._wait_queue and self._wait_queue[0].ticket_num == ticket_num:
           heapq.heappop(self._wait_queue)
         else:
-          logger.warning(
-              "Item not at head of wait queue during cleanup"
-              ", this is not expected: tag=%s ticket num=%s",
-              tag,
-              ticket_num)
-          for i, item in enumerate(self._wait_queue):
-            if item[2] is my_id:
-              self._wait_queue.pop(i)
-              heapq.heapify(self._wait_queue)
-        self._cv.notify_all()
+          # Mark as cancelled so that we skip it once it reaches the head
+          self._cancelled_tickets.add(ticket_num)
+        self._wake_next_in_queue()
 
       return self._spawn_new_model(tag, loader_func, is_unknown, est_cost)
 
@@ -553,6 +576,7 @@ class ModelManager:
             self._estimator.add_observation(snapshot, peak_during_job)
 
       finally:
+        self._wake_next_in_queue()
         self._cv.notify_all()
 
   def _try_grab_from_lru(self, tag: str) -> Any:
@@ -596,7 +620,7 @@ class ModelManager:
     # TODO: Also factor in the active counts to avoid thrashing
     demand_map = Counter()
     for item in self._wait_queue:
-      demand_map[item[3]] += 1
+      demand_map[item.tag] += 1
 
     my_demand = demand_map[requesting_tag]
     am_i_starving = len(self._models[requesting_tag]) == 0
diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py
index 7cfb73cb668..270401857e0 100644
--- a/sdks/python/apache_beam/ml/inference/model_manager_test.py
+++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py
@@ -174,12 +174,20 @@ class TestModelManager(unittest.TestCase):
     def acquire_model_with_timeout():
       return self.manager.acquire_model(model_name, loader)
 
-    with ThreadPoolExecutor(max_workers=1) as executor:
-      future = executor.submit(acquire_model_with_timeout)
+    with ThreadPoolExecutor(max_workers=1000) as executor:
+      futures = [
+          executor.submit(acquire_model_with_timeout) for i in range(1000)
+      ]
       with self.assertRaises(RuntimeError) as context:
-        future.result(timeout=5.0)
+        for future in futures:
+          future.result()
       self.assertIn("Timeout waiting to acquire model", str(context.exception))
 
+    # Release the initially acquired model and try to acquire again
+    # to make sure the manager is still functional
+    self.manager.release_model(model_name, model_name)
+    _ = self.manager.acquire_model(model_name, loader)
+
   def test_model_manager_capacity_check(self):
     """
     Test that the manager blocks when spawning models exceeds the limit,

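A hedged usage sketch of the hypothetical TicketQueue above, echoing the shape of the updated test: many threads line up, each sleeps only on its own event, and each departure wakes exactly one successor rather than the whole herd.

from concurrent.futures import ThreadPoolExecutor

def use_shared_resource(queue):
  ticket = queue.enter(priority=1)
  try:
    # Block until this ticket is at the head; only one event fires per release.
    queue.wait_for_turn(ticket, timeout=5.0)
    return ticket.ticket_num  # Pretend to hold the resource here.
  finally:
    queue.leave(ticket)

queue = TicketQueue()
with ThreadPoolExecutor(max_workers=100) as executor:
  futures = [executor.submit(use_shared_resource, queue) for _ in range(100)]
  assert sorted(f.result() for f in futures) == list(range(100))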