This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new a0d27de943f Update queuing logic to avoid lock notify contention (#37528)
a0d27de943f is described below
commit a0d27de943fba84d04c82e83482ab03090995c00
Author: RuiLong J. <[email protected]>
AuthorDate: Sun Feb 8 18:37:08 2026 -0800
Update queuing logic to avoid lock notify contention (#37528)
* Draft of updated lock notify
* Complete queue ticket implementation
* Remove redundant warning log
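
For readers skimming the change: the commit reduces reliance on broadcasting with Condition.notify_all(), which wakes every waiting thread only for most of them to contend for the lock and go back to sleep. Waiters are instead tracked as tickets on a priority heap, each carrying its own threading.Event, so the thread at the head of the queue can be woken directly. The sketch below is a minimal, self-contained illustration of that pattern, not the Beam implementation; names such as Ticket, TicketQueue and wait_for_turn are invented for the example.

import heapq
import itertools
import threading


class Ticket:
  """One waiter, ordered by (priority, arrival order)."""
  def __init__(self, priority, ticket_num):
    self.priority = priority
    self.ticket_num = ticket_num
    self.wake_event = threading.Event()

  def __lt__(self, other):
    return (self.priority, self.ticket_num) < (
        other.priority, other.ticket_num)


class TicketQueue:
  """Illustrative per-waiter wakeup queue; not the Beam API."""
  def __init__(self):
    self._lock = threading.Lock()
    self._heap = []
    self._counter = itertools.count()

  def enter(self, priority=0):
    with self._lock:
      ticket = Ticket(priority, next(self._counter))
      heapq.heappush(self._heap, ticket)
      return ticket

  def is_my_turn(self, ticket):
    with self._lock:
      return bool(self._heap) and self._heap[0] is ticket

  def wait_for_turn(self, ticket, timeout=None):
    # Each waiter blocks on its own Event, so a wakeup disturbs nobody else.
    return ticket.wake_event.wait(timeout=timeout)

  def leave(self, ticket):
    with self._lock:
      if self._heap and self._heap[0] is ticket:
        heapq.heappop(self._heap)
      if self._heap:
        # Wake exactly one thread: the new head of the queue.
        self._heap[0].wake_event.set()

A waiter would typically enter(), loop on is_my_turn() and wait_for_turn() until it reaches the head, then call leave() when done so the next ticket is woken; this mirrors the _wait_in_queue/_wake_next_in_queue split introduced in the diff below.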
---
.../apache_beam/ml/inference/model_manager.py | 66 +++++++++++++++-------
.../apache_beam/ml/inference/model_manager_test.py | 14 ++++-
2 files changed, 56 insertions(+), 24 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py
index cc9f833c268..bf7c6a43ba6 100644
--- a/sdks/python/apache_beam/ml/inference/model_manager.py
+++ b/sdks/python/apache_beam/ml/inference/model_manager.py
@@ -288,6 +288,17 @@ class ResourceEstimator:
logger.error("Solver failed: %s", e)
+class QueueTicket:
+ def __init__(self, priority, ticket_num, tag):
+ self.priority = priority
+ self.ticket_num = ticket_num
+ self.tag = tag
+ self.wake_event = threading.Event()
+
+ def __lt__(self, other):
+ return (self.priority, self.ticket_num) < (other.priority, other.ticket_num)
+
+
class ModelManager:
"""Manages model lifecycles, caching, and resource arbitration.
@@ -343,6 +354,7 @@ class ModelManager:
# and also priority for unknown models.
self._wait_queue = []
self._ticket_counter = itertools.count()
+ self._cancelled_tickets = set()
# TODO: Consider making the wait to be smarter, i.e.
# splitting read/write etc. to avoid potential contention.
self._cv = threading.Condition()
@@ -417,10 +429,28 @@ class ModelManager:
self._cv.wait(timeout=self._lock_timeout_seconds)
return False
+ def _wake_next_in_queue(self):
+ if self._wait_queue:
+ # Clean up cancelled tickets at head of queue
+ while self._wait_queue and self._wait_queue[
+ 0].ticket_num in self._cancelled_tickets:
+ self._cancelled_tickets.remove(self._wait_queue[0].ticket_num)
+ heapq.heappop(self._wait_queue)
+ next_inline = self._wait_queue[0]
+ next_inline.wake_event.set()
+
+ def _wait_in_queue(self, ticket: QueueTicket):
+ self._cv.release()
+ try:
+ ticket.wake_event.wait(timeout=self._lock_timeout_seconds)
+ ticket.wake_event.clear()
+ finally:
+ self._cv.acquire()
+
def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
current_priority = 0 if self._estimator.is_unknown(tag) else 1
ticket_num = next(self._ticket_counter)
- my_id = object()
+ my_ticket = QueueTicket(current_priority, ticket_num, tag)
with self._cv:
# FAST PATH: Grab from idle LRU if available
@@ -439,8 +469,7 @@ class ModelManager:
current_priority,
len(self._models[tag]),
ticket_num)
- heapq.heappush(
- self._wait_queue, (current_priority, ticket_num, my_id, tag))
+ heapq.heappush(self._wait_queue, my_ticket)
est_cost = 0.0
is_unknown = False
@@ -453,10 +482,11 @@ class ModelManager:
raise RuntimeError(
f"Timeout waiting to acquire model: {tag} "
f"after {wait_time_elapsed:.1f} seconds.")
- if not self._wait_queue or self._wait_queue[0][2] is not my_id:
+ if not self._wait_queue or self._wait_queue[0].ticket_num != ticket_num:
logger.info(
"Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
- self._cv.wait(timeout=self._lock_timeout_seconds)
+ self._wait_in_queue(my_ticket)
continue
# Re-evaluate priority in case model became known during wait
@@ -467,9 +497,9 @@ class ModelManager:
if current_priority != real_priority:
heapq.heappop(self._wait_queue)
current_priority = real_priority
- heapq.heappush(
- self._wait_queue, (current_priority, ticket_num, my_id, tag))
- self._cv.notify_all()
+ my_ticket = QueueTicket(current_priority, ticket_num, tag)
+ heapq.heappush(self._wait_queue, my_ticket)
+ self._wake_next_in_queue()
continue
# Try grab from LRU again in case model was released during wait
@@ -494,7 +524,7 @@ class ModelManager:
"Waiting due to isolation in progress: tag=%s ticket num%s",
tag,
ticket_num)
- self._cv.wait(timeout=self._lock_timeout_seconds)
+ self._wait_in_queue(my_ticket)
continue
if self.should_spawn_model(tag, ticket_num):
@@ -508,19 +538,12 @@ class ModelManager:
finally:
# Remove self from wait queue once done
- if self._wait_queue and self._wait_queue[0][2] is my_id:
+ if self._wait_queue and self._wait_queue[0].ticket_num == ticket_num:
heapq.heappop(self._wait_queue)
else:
- logger.warning(
- "Item not at head of wait queue during cleanup"
- ", this is not expected: tag=%s ticket num=%s",
- tag,
- ticket_num)
- for i, item in enumerate(self._wait_queue):
- if item[2] is my_id:
- self._wait_queue.pop(i)
- heapq.heapify(self._wait_queue)
- self._cv.notify_all()
+ # Mark as cancelled so that we skip it when it reaches the head later
+ self._cancelled_tickets.add(ticket_num)
+ self._wake_next_in_queue()
return self._spawn_new_model(tag, loader_func, is_unknown, est_cost)
@@ -553,6 +576,7 @@ class ModelManager:
self._estimator.add_observation(snapshot, peak_during_job)
finally:
+ self._wake_next_in_queue()
self._cv.notify_all()
def _try_grab_from_lru(self, tag: str) -> Any:
@@ -596,7 +620,7 @@ class ModelManager:
# TODO: Also factor in the active counts to avoid thrashing
demand_map = Counter()
for item in self._wait_queue:
- demand_map[item[3]] += 1
+ demand_map[item.tag] += 1
my_demand = demand_map[requesting_tag]
am_i_starving = len(self._models[requesting_tag]) == 0
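
One detail worth calling out from the hunks above: a waiter that gives up no longer removes itself from the middle of the heap (the old path did a linear scan, a pop, a re-heapify, and a notify_all). Instead its ticket number goes into a cancelled set and is discarded lazily once it reaches the head. Below is a small, self-contained sketch of that tombstone pattern with invented names (Ticket, wake_next), assuming each heap entry exposes a ticket_num; it is not the Beam code.

import collections
import heapq

# Hypothetical minimal ticket; sorts by (priority, ticket_num).
Ticket = collections.namedtuple('Ticket', ['priority', 'ticket_num'])


def wake_next(heap, cancelled):
  """Drop cancelled tickets at the head, then return the next live one.

  Returns None when every remaining ticket was cancelled; callers should
  guard for that case before signalling.
  """
  while heap and heap[0].ticket_num in cancelled:
    cancelled.discard(heapq.heappop(heap).ticket_num)
  return heap[0] if heap else None


heap = [Ticket(0, 1), Ticket(0, 2), Ticket(1, 0)]
heapq.heapify(heap)
print(wake_next(heap, {1, 2}))  # Ticket(priority=1, ticket_num=0)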
diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py
index 7cfb73cb668..270401857e0 100644
--- a/sdks/python/apache_beam/ml/inference/model_manager_test.py
+++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py
@@ -174,12 +174,20 @@ class TestModelManager(unittest.TestCase):
def acquire_model_with_timeout():
return self.manager.acquire_model(model_name, loader)
- with ThreadPoolExecutor(max_workers=1) as executor:
- future = executor.submit(acquire_model_with_timeout)
+ with ThreadPoolExecutor(max_workers=1000) as executor:
+ futures = [
+ executor.submit(acquire_model_with_timeout) for i in range(1000)
+ ]
with self.assertRaises(RuntimeError) as context:
- future.result(timeout=5.0)
+ for future in futures:
+ future.result()
self.assertIn("Timeout waiting to acquire model", str(context.exception))
+ # Release the initially acquired model and try to acquire again
+ # to make sure the manager is still functional
+ self.manager.release_model(model_name, model_name)
+ _ = self.manager.acquire_model(model_name, loader)
+
def test_model_manager_capacity_check(self):
"""
Test that the manager blocks when spawning models exceeds the limit,