This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4e218f0183e Add model manager that automatically manage model across processes (#37113)
4e218f0183e is described below

commit 4e218f0183e5f0c69d98808a019b5f4940b24402
Author: RuiLong J. <[email protected]>
AuthorDate: Wed Feb 4 11:41:49 2026 -0800

    Add model manager that automatically manage model across processes (#37113)
    
    * Add model manager that automatically manage model across processes
    
    * Add pydoc and move gpu detection to start
    
    * Add comments and helper function to make it easier to understand the code
      and clean up some code logic
    
    * Add TODO for threading
    
    * Remove tracked model proxy and have model manager store tags instead of
      model instances
    
    * Fix import order
    
    * Clean up and logs
    
    * Added timeout for waiting too long on model acquire
    
    * Throw error if timeout
    
    * Add test for timeout and adjust
    
    * Update sdks/python/apache_beam/ml/inference/model_manager.py
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
    
    * Update sdks/python/apache_beam/ml/inference/model_manager.py
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
    
    * Gemini clean up
    
    * Update sdks/python/apache_beam/ml/inference/model_manager_test.py
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
    
    * Update GPU monitor test
    
    * Format
    
    * Clean up updating is_unknown logic
    
    * Try to fix flake
    
    * Fix import order
    
    * Fix random seed to avoid flake
    
    * Fix indentation
    
    * Try fixing doc again
    
    ---------
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .../apache_beam/ml/inference/model_manager.py      | 747 +++++++++++++++++++++
 .../apache_beam/ml/inference/model_manager_test.py | 622 +++++++++++++++++
 2 files changed, 1369 insertions(+)

diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py
new file mode 100644
index 00000000000..cc9f833c268
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/model_manager.py
@@ -0,0 +1,747 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module for managing ML models in Apache Beam pipelines.
+
+This module provides classes and functions to efficiently manage multiple
+machine learning models within Apache Beam pipelines. It includes functionality
+for loading, caching, and updating models using multi-process shared memory,
+ensuring that models are reused across different workers to optimize resource
+usage and performance.
+"""
+
+import gc
+import heapq
+import itertools
+import logging
+import subprocess
+import threading
+import time
+from collections import Counter
+from collections import OrderedDict
+from collections import defaultdict
+from collections import deque
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from scipy.optimize import nnls
+
+from apache_beam.utils import multi_process_shared
+
+logger = logging.getLogger(__name__)
+
+
+class GPUMonitor:
+  """Monitors GPU memory usage in a separate thread using nvidia-smi.
+
+  This class continuously polls GPU memory statistics to track current usage
+  and peak usage over a sliding time window. It serves as the source of truth
+  for the ModelManager's resource decisions.
+
+  Attributes:
+    fallback_memory_mb: Default total memory if hardware detection fails.
+    poll_interval: Seconds between memory checks.
+    peak_window_seconds: Duration to track peak memory usage.
+  """
+  def __init__(
+      self,
+      fallback_memory_mb: float = 16000.0,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0):
+    self._current_usage = 0.0
+    self._peak_usage = 0.0
+    self._total_memory = fallback_memory_mb
+    self._poll_interval = poll_interval
+    self._peak_window_seconds = peak_window_seconds
+    self._memory_history = deque()
+    self._running = False
+    self._thread = None
+    self._lock = threading.Lock()
+
+  def _detect_hardware(self):
+    try:
+      cmd = [
+          "nvidia-smi",
+          "--query-gpu=memory.total",
+          "--format=csv,noheader,nounits"
+      ]
+      output = subprocess.check_output(cmd, text=True).strip()
+      self._total_memory = float(output)
+      return True
+    except (FileNotFoundError, subprocess.CalledProcessError):
+      logger.warning(
+          "nvidia-smi not found or failed. Defaulting total memory to %s MB",
+          self._total_memory)
+      return False
+    except Exception as e:
+      logger.warning(
+          "Error parsing nvidia-smi output: %s. "
+          "Defaulting total memory to %s MB",
+          e,
+          self._total_memory)
+      return False
+
+  def start(self):
+    self._gpu_available = self._detect_hardware()
+    if self._running or not self._gpu_available:
+      return
+    self._running = True
+    self._thread = threading.Thread(target=self._poll_loop, daemon=True)
+    self._thread.start()
+
+  def stop(self):
+    self._running = False
+    if self._thread:
+      self._thread.join()
+
+  def reset_peak(self):
+    with self._lock:
+      now = time.time()
+      self._memory_history.clear()
+      self._memory_history.append((now, self._current_usage))
+      self._peak_usage = self._current_usage
+
+  def get_stats(self) -> Tuple[float, float, float]:
+    with self._lock:
+      return self._current_usage, self._peak_usage, self._total_memory
+
+  def refresh(self):
+    """Forces an immediate poll of the GPU."""
+    usage = self._get_nvidia_smi_used()
+    now = time.time()
+    with self._lock:
+      self._current_usage = usage
+      self._memory_history.append((now, usage))
+      # Recalculate peak immediately
+      while self._memory_history and (now - self._memory_history[0][0]
+                                      > self._peak_window_seconds):
+        self._memory_history.popleft()
+      self._peak_usage = (
+          max(m for _, m in self._memory_history)
+          if self._memory_history else usage)
+
+  def _get_nvidia_smi_used(self) -> float:
+    try:
+      cmd = [
+          "nvidia-smi",
+          "--query-gpu=memory.free",
+          "--format=csv,noheader,nounits"
+      ]
+      output = subprocess.check_output(cmd, text=True).strip()
+      free_memory = float(output)
+      return self._total_memory - free_memory
+    except Exception as e:
+      logger.warning('Failed to get GPU memory usage: %s', e)
+      return 0.0
+
+  def _poll_loop(self):
+    while self._running:
+      usage = self._get_nvidia_smi_used()
+      now = time.time()
+      with self._lock:
+        self._current_usage = usage
+        self._memory_history.append((now, usage))
+        while self._memory_history and (now - self._memory_history[0][0]
+                                        > self._peak_window_seconds):
+          self._memory_history.popleft()
+        self._peak_usage = (
+            max(m for _, m in self._memory_history)
+            if self._memory_history else usage)
+      time.sleep(self._poll_interval)
+
+
+class ResourceEstimator:
+  """Estimates individual model memory usage using statistical observation.
+
+  Uses Non-Negative Least Squares (NNLS) to deduce the memory footprint of
+  individual models based on aggregate system memory readings and the
+  configuration of active models at that time.
+  """
+  def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5):
+    self.smoothing_factor = smoothing_factor
+    self.min_data_points = min_data_points
+    self.estimates: Dict[str, float] = {}
+    self.history = defaultdict(lambda: deque(maxlen=20))
+    self.known_models = set()
+    self._lock = threading.Lock()
+
+  def is_unknown(self, model_tag: str) -> bool:
+    with self._lock:
+      return model_tag not in self.estimates
+
+  def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float:
+    with self._lock:
+      return self.estimates.get(model_tag, default_mb)
+
+  def set_initial_estimate(self, model_tag: str, cost: float):
+    with self._lock:
+      self.estimates[model_tag] = cost
+      self.known_models.add(model_tag)
+      logger.info("Initial Profile for %s: %s MB", model_tag, cost)
+
+  def add_observation(
+      self, active_snapshot: Dict[str, int], peak_memory: float):
+    if active_snapshot:
+      model_list = "\n".join(
+          f"\t- {model}: {count}"
+          for model, count in sorted(active_snapshot.items()))
+    else:
+      model_list = "\t- None"
+
+    logger.info(
+        "Adding Observation:\n PeakMemory: %.1f MB\n  Instances:\n%s",
+        peak_memory,
+        model_list)
+    if not active_snapshot:
+      return
+    with self._lock:
+      config_key = tuple(sorted(active_snapshot.items()))
+      self.history[config_key].append(peak_memory)
+      for tag in active_snapshot:
+        self.known_models.add(tag)
+      self._solve()
+
+  def _solve(self):
+    """
+    Solves Ax=b using raw readings (no pre-averaging) and NNLS.
+    This creates a 'tall' matrix A where every memory reading is
+    a separate equation.
+    """
+    unique = sorted(list(self.known_models))
+
+    # We need to build the matrix first to know if we have enough data points
+    A, b = [], []
+
+    for config_key, mem_values in self.history.items():
+      if not mem_values:
+        continue
+
+      # 1. Create the feature row for this configuration ONCE
+      # (It represents the model counts + bias)
+      counts = dict(config_key)
+      feature_row = [counts.get(model, 0) for model in unique]
+      feature_row.append(1)  # Bias column
+
+      # 2. Add a separate row to the matrix for EVERY individual reading
+      # Instead of averaging, we flatten the history into the matrix
+      for reading in mem_values:
+        A.append(feature_row)  # The inputs (models) stay the same
+        b.append(reading)  # The output (memory) varies due to noise
+
+    # Convert to numpy for SciPy
+    A = np.array(A)
+    b = np.array(b)
+
+    if len(
+        self.history.keys()) < len(unique) + 1 or len(A) < 
self.min_data_points:
+      # Not enough data to solve yet
+      return
+
+    logger.info(
+        "Solving with %s total observations for %s models.",
+        len(A),
+        len(unique))
+
+    try:
+      # Solve using Non-Negative Least Squares
+      # x will be >= 0
+      x, _ = nnls(A, b)
+
+      weights = x[:-1]
+      bias = x[-1]
+
+      for i, model in enumerate(unique):
+        calculated_cost = weights[i]
+
+        if model in self.estimates:
+          old = self.estimates[model]
+          new = (old * (1 - self.smoothing_factor)) + (
+              calculated_cost * self.smoothing_factor)
+          self.estimates[model] = new
+        else:
+          self.estimates[model] = calculated_cost
+
+        logger.info(
+            "Updated Estimate for %s: %.1f MB", model, self.estimates[model])
+      logger.info("System Bias: %s MB", bias)
+
+    except Exception as e:
+      logger.error("Solver failed: %s", e)
+
+
+class ModelManager:
+  """Manages model lifecycles, caching, and resource arbitration.
+
+  This class acts as the central controller for acquiring model instances.
+
+  1. LRU Caching of idle models.
+  2. Resource estimation and admission control (preventing OOM).
+  3. Dynamic eviction of low-priority models, determined by count of 
+     pending requests, when space is needed.
+  4. 'Isolation Mode' for safely profiling unknown models.
+  """
+  def __init__(
+      self,
+      monitor: Optional['GPUMonitor'] = None,
+      slack_percentage: float = 0.10,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0,
+      min_data_points: int = 5,
+      smoothing_factor: float = 0.2,
+      eviction_cooldown_seconds: float = 10.0,
+      min_model_copies: int = 1,
+      wait_timeout_seconds: float = 300.0,
+      lock_timeout_seconds: float = 60.0):
+
+    self._estimator = ResourceEstimator(
+        min_data_points=min_data_points, smoothing_factor=smoothing_factor)
+    self._monitor = monitor if monitor else GPUMonitor(
+        poll_interval=poll_interval, peak_window_seconds=peak_window_seconds)
+    self._slack_percentage = slack_percentage
+
+    self._eviction_cooldown = eviction_cooldown_seconds
+    self._min_model_copies = min_model_copies
+    self._wait_timeout_seconds = wait_timeout_seconds
+    self._lock_timeout_seconds = lock_timeout_seconds
+
+    # Resource State
+    self._models = defaultdict(list)
+    # Idle LRU used to track released models that
+    # can be freed or reused upon request.
+    self._idle_lru = OrderedDict()
+    self._active_counts = Counter()
+    self._total_active_jobs = 0
+    self._pending_reservations = 0.0
+
+    # Isolation state used to profile unknown models,
+    # ensuring they run alone to get accurate readings.
+    # isolation_baseline represents the GPU usage before
+    # loading the unknown model.
+    self._isolation_mode = False
+    self._isolation_baseline = 0.0
+
+    # Waiting Queue and Ticketing to make sure we have fair ordering
+    # and also priority for unknown models.
+    self._wait_queue = []
+    self._ticket_counter = itertools.count()
+    # TODO: Consider making the wait to be smarter, i.e.
+    # splitting read/write etc. to avoid potential contention.
+    self._cv = threading.Condition()
+
+    self._monitor.start()
+
+  def all_models(self, tag) -> list[Any]:
+    return self._models[tag]
+
+  # Should hold _cv lock when calling
+  def try_enter_isolation_mode(self, tag: str, ticket_num: int) -> bool:
+    if self._total_active_jobs > 0:
+      logger.info(
+          "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num)
+      self._cv.wait(timeout=self._lock_timeout_seconds)
+      # return False since we have waited and need to re-evaluate
+      # in caller to make sure our priority is still valid.
+      return False
+
+    logger.info("Unknown model %s detected. Flushing GPU.", tag)
+    self._delete_all_models()
+
+    self._isolation_mode = True
+    self._total_active_jobs += 1
+    self._isolation_baseline, _, _ = self._monitor.get_stats()
+    self._monitor.reset_peak()
+    return True
+
+  # Should hold _cv lock when calling
+  def should_spawn_model(self, tag: str, ticket_num: int) -> bool:
+    curr, _, total = self._monitor.get_stats()
+    est_cost = self._estimator.get_estimate(tag)
+    limit = total * (1 - self._slack_percentage)
+
+    # Use current usage for capacity check (ignore old spikes)
+    if (curr + self._pending_reservations + est_cost) <= limit:
+      self._pending_reservations += est_cost
+      self._total_active_jobs += 1
+      self._active_counts[tag] += 1
+      return True
+
+    # Evict to make space (passing tag to check demand/existence)
+    if self._evict_to_make_space(limit, est_cost, requesting_tag=tag):
+      return True
+
+    # Manually log status for debugging if we are going to wait
+    idle_count = 0
+    other_idle_count = 0
+    for item in self._idle_lru.items():
+      if item[1][0] == tag:
+        idle_count += 1
+      else:
+        other_idle_count += 1
+    total_model_count = 0
+    for _, instances in self._models.items():
+      total_model_count += len(instances)
+    curr, _, _ = self._monitor.get_stats()
+    logger.info(
+        "Waiting for resources to free up: "
+        "tag=%s ticket num%s model count=%s "
+        "idle count=%s resource usage=%.1f MB "
+        "total models count=%s other idle=%s",
+        tag,
+        ticket_num,
+        len(self._models[tag]),
+        idle_count,
+        curr,
+        total_model_count,
+        other_idle_count)
+    # Wait since we couldn't make space and
+    # added timeout to avoid missed notify call.
+    self._cv.wait(timeout=self._lock_timeout_seconds)
+    return False
+
+  def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
+    current_priority = 0 if self._estimator.is_unknown(tag) else 1
+    ticket_num = next(self._ticket_counter)
+    my_id = object()
+
+    with self._cv:
+      # FAST PATH: Grab from idle LRU if available
+      if not self._isolation_mode:
+        cached_instance = self._try_grab_from_lru(tag)
+        if cached_instance:
+          return cached_instance
+
+      # SLOW PATH: Enqueue and wait for turn to acquire model,
+      # with unknown models having priority and order enforced
+      # by ticket number as FIFO.
+      logger.info(
+          "Acquire Queued: tag=%s, priority=%d "
+          "total models count=%s ticket num=%s",
+          tag,
+          current_priority,
+          len(self._models[tag]),
+          ticket_num)
+      heapq.heappush(
+          self._wait_queue, (current_priority, ticket_num, my_id, tag))
+
+      est_cost = 0.0
+      is_unknown = False
+      wait_time_start = time.time()
+
+      try:
+        while True:
+          wait_time_elapsed = time.time() - wait_time_start
+          if wait_time_elapsed > self._wait_timeout_seconds:
+            raise RuntimeError(
+                f"Timeout waiting to acquire model: {tag} "
+                f"after {wait_time_elapsed:.1f} seconds.")
+          if not self._wait_queue or self._wait_queue[0][2] is not my_id:
+            logger.info(
+                "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
+            self._cv.wait(timeout=self._lock_timeout_seconds)
+            continue
+
+          # Re-evaluate priority in case model became known during wait
+          is_unknown = self._estimator.is_unknown(tag)
+          real_priority = 0 if is_unknown else 1
+
+          # If priority changed, reinsert into queue and wait
+          if current_priority != real_priority:
+            heapq.heappop(self._wait_queue)
+            current_priority = real_priority
+            heapq.heappush(
+                self._wait_queue, (current_priority, ticket_num, my_id, tag))
+            self._cv.notify_all()
+            continue
+
+          # Try grab from LRU again in case model was released during wait
+          cached_instance = self._try_grab_from_lru(tag)
+          if cached_instance:
+            return cached_instance
+
+          # Path A: Isolation
+          if is_unknown:
+            if self.try_enter_isolation_mode(tag, ticket_num):
+              # We got isolation, can proceed to spawn
+              break
+            else:
+              # We waited, need to re-evaluate our turn
+              # because priority may have changed during the wait
+              continue
+
+          # Path B: Concurrent
+          else:
+            if self._isolation_mode:
+              logger.info(
+                  "Waiting due to isolation in progress: tag=%s ticket num%s",
+                  tag,
+                  ticket_num)
+              self._cv.wait(timeout=self._lock_timeout_seconds)
+              continue
+
+            if self.should_spawn_model(tag, ticket_num):
+              est_cost = self._estimator.get_estimate(tag)
+              # We can proceed to spawn since we have resources
+              break
+            else:
+              # We waited, need to re-evaluate our turn
+              # because priority may have changed during the wait
+              continue
+
+      finally:
+        # Remove self from wait queue once done
+        if self._wait_queue and self._wait_queue[0][2] is my_id:
+          heapq.heappop(self._wait_queue)
+        else:
+          logger.warning(
+              "Item not at head of wait queue during cleanup"
+              ", this is not expected: tag=%s ticket num=%s",
+              tag,
+              ticket_num)
+          for i, item in enumerate(self._wait_queue):
+            if item[2] is my_id:
+              self._wait_queue.pop(i)
+              heapq.heapify(self._wait_queue)
+        self._cv.notify_all()
+
+      return self._spawn_new_model(tag, loader_func, is_unknown, est_cost)
+
+  def release_model(self, tag: str, instance: Any):
+    with self._cv:
+      try:
+        self._total_active_jobs -= 1
+        if self._active_counts[tag] > 0:
+          self._active_counts[tag] -= 1
+
+        self._idle_lru[id(instance)] = (tag, instance, time.time())
+
+        # Update estimator with latest stats
+        _, peak_during_job, _ = self._monitor.get_stats()
+
+        if self._isolation_mode and self._active_counts[tag] == 0:
+          # For isolation mode, we directly set the initial estimate
+          # so that we can quickly learn the model cost.
+          cost = max(0, peak_during_job - self._isolation_baseline)
+          self._estimator.set_initial_estimate(tag, cost)
+          self._isolation_mode = False
+          self._isolation_baseline = 0.0
+        else:
+          # Regular update for known models
+          snapshot = {
+              t: len(instances)
+              for t, instances in self._models.items() if len(instances) > 0
+          }
+          if snapshot:
+            self._estimator.add_observation(snapshot, peak_during_job)
+
+      finally:
+        self._cv.notify_all()
+
+  def _try_grab_from_lru(self, tag: str) -> Any:
+    target_key = None
+    target_instance = None
+
+    for key, (t, instance, _) in reversed(self._idle_lru.items()):
+      if t == tag:
+        target_key = key
+        target_instance = instance
+        break
+
+    if target_instance:
+      # Found an idle model, remove from LRU and return
+      del self._idle_lru[target_key]
+      self._active_counts[tag] += 1
+      self._total_active_jobs += 1
+      return target_instance
+
+    logger.info("No idle model found for tag: %s", tag)
+    return None
+
+  def _evict_to_make_space(
+      self, limit: float, est_cost: float, requesting_tag: str) -> bool:
+    """
+    Evicts models based on Demand Magnitude + Tiers.
+    Crucially: If we have 0 active copies of 'requesting_tag', we FORCE 
eviction
+    of the lowest-demand candidate to avoid starvation.
+    Returns True if space was made, False otherwise.
+    """
+    curr, _, _ = self._monitor.get_stats()
+    projected_usage = curr + self._pending_reservations + est_cost
+
+    if projected_usage <= limit:
+      # Memory usage changed and we are already under limit
+      return True
+
+    now = time.time()
+
+    # Calculate the demand from the wait queue
+    # TODO: Also factor in the active counts to avoid thrashing
+    demand_map = Counter()
+    for item in self._wait_queue:
+      demand_map[item[3]] += 1
+
+    my_demand = demand_map[requesting_tag]
+    am_i_starving = len(self._models[requesting_tag]) == 0
+
+    candidates = []
+    for key, (tag, instance, release_time) in self._idle_lru.items():
+      candidate_demand = demand_map[tag]
+
+      # TODO: Try to avoid churn if demand is similar
+      if not am_i_starving and candidate_demand >= my_demand:
+        continue
+
+      # Attempts to score candidates based on hotness and manually
+      # specified minimum copies. Demand is weighted heavily to
+      # ensure we evict low-demand models first.
+      age = now - release_time
+      is_cold = age >= self._eviction_cooldown
+
+      total_copies = len(self._models[tag])
+      is_surplus = total_copies > self._min_model_copies
+
+      if is_cold and is_surplus: tier = 0
+      elif not is_cold and is_surplus: tier = 1
+      elif is_cold and not is_surplus: tier = 2
+      else: tier = 3
+
+      score = (candidate_demand * 10) + tier
+
+      candidates.append((score, release_time, key, tag, instance))
+
+    candidates.sort(key=lambda x: (x[0], x[1]))
+
+    # Evict candidates until we are under limit
+    for score, _, key, tag, instance in candidates:
+      if projected_usage <= limit:
+        break
+
+      if key not in self._idle_lru: continue
+
+      self._perform_eviction(key, tag, instance, score)
+
+      curr, _, _ = self._monitor.get_stats()
+      projected_usage = curr + self._pending_reservations + est_cost
+
+    return projected_usage <= limit
+
+  def _delete_instance(self, instance: Any):
+    if isinstance(instance, str):
+      # If the instance is a string, it's a uuid used
+      # to retrieve the model from MultiProcessShared
+      multi_process_shared.MultiProcessShared(
+          lambda: "N/A", tag=instance).unsafe_hard_delete()
+    if hasattr(instance, 'mock_model_unsafe_hard_delete'):
+      # Call the mock unsafe hard delete method for testing
+      instance.mock_model_unsafe_hard_delete()
+    del instance
+
+  def _perform_eviction(self, key: str, tag: str, instance: Any, score: int):
+    logger.info("Evicting Model: %s (Score %d)", tag, score)
+    curr, _, _ = self._monitor.get_stats()
+    logger.info("Resource Usage Before Eviction: %.1f MB", curr)
+
+    if key in self._idle_lru:
+      del self._idle_lru[key]
+
+    for i, inst in enumerate(self._models[tag]):
+      if instance == inst:
+        del self._models[tag][i]
+        break
+
+    self._delete_instance(instance)
+    gc.collect()
+    torch.cuda.empty_cache()
+    self._monitor.refresh()
+    self._monitor.reset_peak()
+    curr, _, _ = self._monitor.get_stats()
+    logger.info("Resource Usage After Eviction: %.1f MB", curr)
+
+  def _spawn_new_model(
+      self,
+      tag: str,
+      loader_func: Callable[[], Any],
+      is_unknown: bool,
+      est_cost: float) -> Any:
+    try:
+      with self._cv:
+        logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
+        baseline_snap, _, _ = self._monitor.get_stats()
+        instance = loader_func()
+        _, peak_during_load, _ = self._monitor.get_stats()
+
+        snapshot = {tag: 1}
+        self._estimator.add_observation(
+            snapshot, peak_during_load - baseline_snap)
+
+        if not is_unknown:
+          self._pending_reservations = max(
+              0.0, self._pending_reservations - est_cost)
+        self._models[tag].append(instance)
+      return instance
+
+    except Exception as e:
+      logger.error("Load Failed: %s. Error: %s", tag, e)
+      with self._cv:
+        self._total_active_jobs -= 1
+        if is_unknown:
+          self._isolation_mode = False
+          self._isolation_baseline = 0.0
+        else:
+          self._pending_reservations = max(
+              0.0, self._pending_reservations - est_cost)
+          self._active_counts[tag] -= 1
+        self._cv.notify_all()
+      raise e
+
+  def _delete_all_models(self):
+    self._idle_lru.clear()
+    for _, instances in self._models.items():
+      for instance in instances:
+        self._delete_instance(instance)
+    self._models.clear()
+    self._active_counts.clear()
+    gc.collect()
+    torch.cuda.empty_cache()
+    self._monitor.refresh()
+    self._monitor.reset_peak()
+
+  def _force_reset(self):
+    logger.warning("Force Reset Triggered")
+    self._delete_all_models()
+    self._models = defaultdict(list)
+    self._idle_lru = OrderedDict()
+    self._active_counts = Counter()
+    self._wait_queue = []
+    self._total_active_jobs = 0
+    self._pending_reservations = 0.0
+    self._isolation_mode = False
+    self._isolation_baseline = 0.0
+
+  def shutdown(self):
+    self._delete_all_models()
+    self._monitor.stop()
+
+  def __del__(self):
+    self.shutdown()
+
+  def __exit__(self, exc_type, exc_value, traceback):
+    self.shutdown()
diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py
new file mode 100644
index 00000000000..1bd8edd34d1
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py
@@ -0,0 +1,622 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import random
+import threading
+import time
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import TimeoutError
+from unittest.mock import patch
+
+from apache_beam.utils import multi_process_shared
+
+try:
+  from apache_beam.ml.inference.model_manager import GPUMonitor
+  from apache_beam.ml.inference.model_manager import ModelManager
+  from apache_beam.ml.inference.model_manager import ResourceEstimator
+except ImportError as e:
+  raise unittest.SkipTest("Model Manager dependencies are not installed")
+
+
+class MockGPUMonitor:
+  """
+  Test double for GPUMonitor that tracks memory usage in-process.
+
+  Usage is cumulative: ``allocate``/``free`` move the current reading, and
+  the reported peak is the maximum over the last ``peak_window`` samples.
+  """
+  def __init__(self, total_memory=12000.0, peak_window: int = 5):
+    self._lock = threading.Lock()
+    self._total = total_memory
+    self._current = 0.0
+    self._peak = 0.0
+    self.peak_window = peak_window
+    self.history = []
+    self.running = False
+
+  def start(self):
+    self.running = True
+
+  def stop(self):
+    self.running = False
+
+  def get_stats(self):
+    """Returns a ``(current, peak, total)`` tuple."""
+    with self._lock:
+      snapshot = (self._current, self._peak, self._total)
+    return snapshot
+
+  def reset_peak(self):
+    """Collapses the peak window down to the current reading."""
+    with self._lock:
+      current = self._current
+      self._peak = current
+      self.history = [current]
+
+  def set_usage(self, current_mb):
+    """Sets absolute usage (legacy helper)."""
+    with self._lock:
+      self._current = current_mb
+      if current_mb > self._peak:
+        self._peak = current_mb
+
+  def _append_sample(self, value):
+    # Caller must hold self._lock. Records one sample, drops the oldest one
+    # once the window is exceeded, and recomputes the windowed peak.
+    self.history.append(value)
+    if len(self.history) > self.peak_window:
+      del self.history[0]
+    self._peak = max(self.history)
+
+  def allocate(self, amount_mb):
+    """Simulates memory allocation (e.g., tensors loaded to VRAM)."""
+    with self._lock:
+      self._current += amount_mb
+      self._append_sample(self._current)
+
+  def free(self, amount_mb):
+    """Simulates memory freeing (not used often if pooling is active)."""
+    with self._lock:
+      self._current = max(0.0, self._current - amount_mb)
+      self._append_sample(self._current)
+
+  def refresh(self):
+    """Simulates a refresh of the monitor stats (no-op for mock)."""
+    pass
+
+
+class MockModel:
+  """Fake model whose construction charges ``size`` to ``monitor``."""
+  def __init__(self, name, size, monitor):
+    self.name, self.size = name, size
+    self.monitor = monitor
+    self.deleted = False
+    monitor.allocate(size)
+
+  def mock_model_unsafe_hard_delete(self):
+    """Returns the memory to the monitor once; repeated calls are no-ops."""
+    if self.deleted:
+      return
+    self.monitor.free(self.size)
+    self.deleted = True
+
+
+class Counter:
+  """Minimal lock-protected counter used as a MultiProcessShared payload."""
+  def __init__(self, start=0):
+    self.lock = threading.Lock()
+    self.running = start
+
+  def get(self):
+    return self.running
+
+  def increment(self, value=1):
+    with self.lock:
+      self.running += value
+      result = self.running
+    return result
+
+
+class TestModelManager(unittest.TestCase):
+  """Admission, isolation and estimation behaviour of ModelManager."""
+  def setUp(self):
+    """Create a fresh ModelManager backed by a mock GPU monitor."""
+    self.mock_monitor = MockGPUMonitor()
+    self.manager = ModelManager(monitor=self.mock_monitor)
+
+  def tearDown(self):
+    self.manager.shutdown()
+
+  def test_model_manager_deletes_multiprocessshared_instances(self):
+    """Test that MultiProcessShared instances are deleted properly."""
+    model_name = "test_model_shared"
+    tag = f"model_manager_test_{model_name}"
+
+    def loader():
+      # Pass the class itself as the constructor: if this handle is ever
+      # acquired it must build a Counter instance, not share the class.
+      multi_process_shared.MultiProcessShared(
+          Counter, tag=tag, always_proxy=True)
+      return tag
+
+    instance = self.manager.acquire_model(model_name, loader)
+    instance_before = multi_process_shared.MultiProcessShared(
+        Counter, tag=tag, always_proxy=True).acquire()
+    instance_before.increment()
+    self.assertEqual(instance_before.get(), 1)
+    self.manager.release_model(model_name, instance)
+
+    # Force delete all models
+    self.manager._force_reset()
+
+    # Verify that the MultiProcessShared instance is deleted
+    # and the counter is reset
+    with self.assertRaises(Exception):
+      instance_before.get()
+    instance_after = multi_process_shared.MultiProcessShared(
+        Counter, tag=tag, always_proxy=True).acquire()
+    self.assertEqual(instance_after.get(), 0)
+
+  def test_model_manager_timeout_on_acquire(self):
+    """Test that acquiring a model times out properly."""
+    model_name = "timeout_model"
+    # Replace the manager built in setUp with one using short timeouts; shut
+    # the old one down explicitly rather than relying on garbage collection.
+    self.manager.shutdown()
+    self.manager = ModelManager(
+        monitor=self.mock_monitor,
+        wait_timeout_seconds=1.0,
+        lock_timeout_seconds=1.0)
+
+    def loader():
+      self.mock_monitor.allocate(self.mock_monitor._total)
+      return model_name
+
+    # Acquire the model in one thread to block others
+    _ = self.manager.acquire_model(model_name, loader)
+
+    def acquire_model_with_timeout():
+      return self.manager.acquire_model(model_name, loader)
+
+    with ThreadPoolExecutor(max_workers=1) as executor:
+      future = executor.submit(acquire_model_with_timeout)
+      with self.assertRaises(RuntimeError) as context:
+        future.result(timeout=5.0)
+      self.assertIn("Timeout waiting to acquire model", str(context.exception))
+
+  def test_model_manager_capacity_check(self):
+    """
+    Test that the manager blocks when spawning models exceeds the limit,
+    and unblocks when resources become available (via reuse).
+    """
+    model_name = "known_model"
+    model_cost = 3000.0
+    self.manager._estimator.set_initial_estimate(model_name, model_cost)
+    acquired_refs = []
+
+    def loader():
+      self.mock_monitor.allocate(model_cost)
+      return model_name
+
+    # 1. Saturate GPU with 3 models (9000 MB usage)
+    for _ in range(3):
+      inst = self.manager.acquire_model(model_name, loader)
+      acquired_refs.append(inst)
+
+    # 2. Spawn one more (Should Block because 9000 + 3000 > Limit)
+    def run_inference():
+      return self.manager.acquire_model(model_name, loader)
+
+    with ThreadPoolExecutor(max_workers=1) as executor:
+      future = executor.submit(run_inference)
+      with self.assertRaises(
+          TimeoutError, msg="Should have blocked due to capacity"):
+        future.result(timeout=0.5)
+
+      # 3. Release resources to unblock
+      item_to_release = acquired_refs.pop()
+      self.manager.release_model(model_name, item_to_release)
+
+      result = future.result(timeout=2.0)
+      self.assertIsNotNone(result)
+      self.assertEqual(result, item_to_release)
+
+  def test_model_manager_unknown_model_runs_isolated(self):
+    """Test that a model with no history runs in isolation."""
+    model_name = "unknown_model_v1"
+    self.assertTrue(self.manager._estimator.is_unknown(model_name))
+
+    def dummy_loader():
+      time.sleep(0.05)
+      return "model_instance"
+
+    instance = self.manager.acquire_model(model_name, dummy_loader)
+
+    self.assertTrue(self.manager._isolation_mode)
+    self.assertEqual(self.manager._total_active_jobs, 1)
+
+    self.manager.release_model(model_name, instance)
+    self.assertFalse(self.manager._isolation_mode)
+    self.assertFalse(self.manager._estimator.is_unknown(model_name))
+
+  def test_model_manager_concurrent_execution(self):
+    """Test that multiple small known models can run together."""
+    model_a = "small_model_a"
+    model_b = "small_model_b"
+
+    self.manager._estimator.set_initial_estimate(model_a, 1000.0)
+    self.manager._estimator.set_initial_estimate(model_b, 1000.0)
+    self.mock_monitor.set_usage(1000.0)
+
+    inst_a = self.manager.acquire_model(model_a, lambda: "A")
+    inst_b = self.manager.acquire_model(model_b, lambda: "B")
+
+    self.assertEqual(self.manager._total_active_jobs, 2)
+
+    self.manager.release_model(model_a, inst_a)
+    self.manager.release_model(model_b, inst_b)
+    self.assertEqual(self.manager._total_active_jobs, 0)
+
+  def test_model_manager_concurrent_mixed_workload_convergence(self):
+    """
+    Simulates a production environment with multiple model types running
+    concurrently. Verifies that the estimator converges.
+    """
+    TRUE_COSTS = {"model_small": 1500.0, "model_medium": 3000.0}
+
+    def run_job(model_name):
+      cost = TRUE_COSTS[model_name]
+
+      def loader():
+        model = MockModel(model_name, cost, self.mock_monitor)
+        return model
+
+      instance = self.manager.acquire_model(model_name, loader)
+      time.sleep(random.uniform(0.01, 0.05))
+      self.manager.release_model(model_name, instance)
+
+    # Create a workload stream
+    workload = ["model_small"] * 15 + ["model_medium"] * 15
+    random.shuffle(workload)
+
+    with ThreadPoolExecutor(max_workers=8) as executor:
+      futures = [executor.submit(run_job, name) for name in workload]
+      for f in futures:
+        f.result()
+
+    est_small = self.manager._estimator.get_estimate("model_small")
+    est_med = self.manager._estimator.get_estimate("model_medium")
+
+    self.assertAlmostEqual(est_small, TRUE_COSTS["model_small"], delta=100.0)
+    self.assertAlmostEqual(est_med, TRUE_COSTS["model_medium"], delta=100.0)
+
+  def test_model_manager_oom_recovery(self):
+    """Test that the manager recovers state if a loader crashes."""
+    model_name = "crasher_model"
+    self.manager._estimator.set_initial_estimate(model_name, 1000.0)
+
+    def crashing_loader():
+      raise RuntimeError("CUDA OOM or similar")
+
+    with self.assertRaises(RuntimeError):
+      self.manager.acquire_model(model_name, crashing_loader)
+
+    self.assertEqual(self.manager._total_active_jobs, 0)
+    self.assertEqual(self.manager._pending_reservations, 0.0)
+    self.assertFalse(self.manager._cv._is_owned())
+
+  def test_model_manager_force_reset_on_exception(self):
+    """Test that force_reset clears all models from the manager."""
+    model_name = "test_model"
+
+    def dummy_loader():
+      self.mock_monitor.allocate(1000.0)
+      raise RuntimeError("Simulated loader exception")
+
+    # NOTE(review): the assertions in the except branch only run if the
+    # second acquire actually invokes dummy_loader. If the manager hands back
+    # the idle instance released just before, no exception is raised and
+    # these checks are skipped silently - confirm which path is taken.
+    try:
+      instance = self.manager.acquire_model(
+          model_name, lambda: "model_instance")
+      self.manager.release_model(model_name, instance)
+      instance = self.manager.acquire_model(model_name, dummy_loader)
+    except RuntimeError:
+      self.manager._force_reset()
+      self.assertEqual(len(self.manager._models[model_name]), 0)
+      self.assertEqual(self.manager._total_active_jobs, 0)
+      self.assertEqual(self.manager._pending_reservations, 0.0)
+      self.assertFalse(self.manager._isolation_mode)
+
+    instance = self.manager.acquire_model(model_name, lambda: "model_instance")
+    self.manager.release_model(model_name, instance)
+
+  def test_single_model_convergence_with_fluctuations(self):
+    """
+    Tests that the estimator converges to the true usage with fluctuations.
+    """
+    model_name = "fluctuating_model"
+    model_cost = 3000.0
+    load_cost = 2500.0
+    # Fixed seed for reproducibility, on a private RNG so the seed does not
+    # leak into the global random state used by other tests.
+    rng = random.Random(42)
+
+    def loader():
+      self.mock_monitor.allocate(load_cost)
+      return model_name
+
+    model = self.manager.acquire_model(model_name, loader)
+    self.manager.release_model(model_name, model)
+    initial_est = self.manager._estimator.get_estimate(model_name)
+    self.assertEqual(initial_est, load_cost)
+
+    def run_inference():
+      model = self.manager.acquire_model(model_name, loader)
+      noise = model_cost - load_cost + rng.uniform(-300.0, 300.0)
+      self.mock_monitor.allocate(noise)
+      time.sleep(0.1)
+      self.mock_monitor.free(noise)
+      self.manager.release_model(model_name, model)
+
+    with ThreadPoolExecutor(max_workers=8) as executor:
+      futures = [executor.submit(run_inference) for _ in range(100)]
+
+    for f in futures:
+      f.result()
+
+    est_cost = self.manager._estimator.get_estimate(model_name)
+    self.assertAlmostEqual(est_cost, model_cost, delta=100.0)
+
+
+class TestModelManagerEviction(unittest.TestCase):
+  """Eviction behaviour with zero slack on a 12000 MB mock GPU."""
+  def setUp(self):
+    self.mock_monitor = MockGPUMonitor(total_memory=12000.0)
+    self.manager = ModelManager(
+        monitor=self.mock_monitor,
+        slack_percentage=0.0,
+        min_data_points=1,
+        eviction_cooldown_seconds=10.0,
+        min_model_copies=1)
+
+  def tearDown(self):
+    self.manager.shutdown()
+
+  def create_loader(self, name, size):
+    return lambda: MockModel(name, size, self.mock_monitor)
+
+  def test_basic_lru_eviction(self):
+    self.manager._estimator.set_initial_estimate("A", 4000)
+    self.manager._estimator.set_initial_estimate("B", 4000)
+    self.manager._estimator.set_initial_estimate("C", 5000)
+
+    model_a = self.manager.acquire_model("A", self.create_loader("A", 4000))
+    self.manager.release_model("A", model_a)
+
+    model_b = self.manager.acquire_model("B", self.create_loader("B", 4000))
+    self.manager.release_model("B", model_b)
+
+    key_a = list(self.manager._idle_lru.keys())[0]
+    self.manager._idle_lru[key_a] = ("A", model_a, time.time() - 20.0)
+
+    key_b = list(self.manager._idle_lru.keys())[1]
+    self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0)
+
+    model_a_again = self.manager.acquire_model(
+        "A", self.create_loader("A", 4000))
+    self.manager.release_model("A", model_a_again)
+
+    self.manager.acquire_model("C", self.create_loader("C", 5000))
+
+    self.assertEqual(len(self.manager.all_models("B")), 0)
+    self.assertEqual(len(self.manager.all_models("A")), 1)
+
+  def test_chained_eviction(self):
+    self.manager._estimator.set_initial_estimate("big_guy", 8000)
+    models = []
+    for i in range(4):
+      name = f"small_{i}"
+      m = self.manager.acquire_model(name, self.create_loader(name, 3000))
+      self.manager.release_model(name, m)
+      models.append(m)
+
+    self.manager.acquire_model("big_guy", self.create_loader("big_guy", 8000))
+
+    self.assertTrue(models[0].deleted)
+    self.assertTrue(models[1].deleted)
+    self.assertTrue(models[2].deleted)
+    self.assertFalse(models[3].deleted)
+
+  def test_active_models_are_protected(self):
+    self.manager._estimator.set_initial_estimate("A", 6000)
+    self.manager._estimator.set_initial_estimate("B", 4000)
+    self.manager._estimator.set_initial_estimate("C", 4000)
+
+    model_a = self.manager.acquire_model("A", self.create_loader("A", 6000))
+    model_b = self.manager.acquire_model("B", self.create_loader("B", 4000))
+    self.manager.release_model("B", model_b)
+
+    key_b = list(self.manager._idle_lru.keys())[0]
+    self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0)
+
+    def acquire_c():
+      return self.manager.acquire_model("C", self.create_loader("C", 4000))
+
+    with ThreadPoolExecutor(max_workers=1) as executor:
+      future = executor.submit(acquire_c)
+      model_c = future.result(timeout=2.0)
+
+    self.assertTrue(model_b.deleted)
+    self.assertFalse(model_a.deleted)
+
+    self.manager.release_model("A", model_a)
+    self.manager.release_model("C", model_c)
+
+  def test_unknown_model_clears_memory(self):
+    self.manager._estimator.set_initial_estimate("A", 2000)
+    model_a = self.manager.acquire_model("A", self.create_loader("A", 2000))
+    self.manager.release_model("A", model_a)
+    self.assertFalse(model_a.deleted)
+
+    self.assertTrue(self.manager._estimator.is_unknown("X"))
+    model_x = self.manager.acquire_model("X", self.create_loader("X", 10000))
+
+    self.assertTrue(model_a.deleted, "Model A should be deleted for isolation")
+    self.assertEqual(len(self.manager.all_models("A")), 0)
+    self.assertTrue(self.manager._isolation_mode)
+    self.manager.release_model("X", model_x)
+
+  def test_concurrent_eviction_pressure(self):
+    def worker(idx):
+      name = f"model_{idx % 5}"
+      try:
+        m = self.manager.acquire_model(name, self.create_loader(name, 4000))
+        time.sleep(0.001)
+        self.manager.release_model(name, m)
+      except Exception:
+        pass
+
+    with ThreadPoolExecutor(max_workers=8) as executor:
+      futures = [executor.submit(worker, i) for i in range(50)]
+      for f in futures:
+        f.result()
+
+    curr, _, _ = self.mock_monitor.get_stats()
+    expected_usage = 0
+    for _, instances in self.manager._models.items():
+      expected_usage += len(instances) * 4000
+
+    self.assertAlmostEqual(curr, expected_usage)
+
+  def test_starvation_prevention_overrides_demand(self):
+    self.manager._estimator.set_initial_estimate("A", 12000)
+    m_a = self.manager.acquire_model("A", self.create_loader("A", 12000))
+    self.manager.release_model("A", m_a)
+
+    def cycle_a():
+      try:
+        m = self.manager.acquire_model("A", self.create_loader("A", 12000))
+        time.sleep(0.3)
+        self.manager.release_model("A", m)
+      except Exception:
+        pass
+
+    def acquire_b():
+      return self.manager.acquire_model("B", self.create_loader("B", 4000))
+
+    # The context manager joins the worker threads even if the assertion
+    # fails or acquire_b raises, so they cannot keep running against the
+    # manager after tearDown. B is released in all cases so the cycle_a
+    # workers are not left waiting on its memory.
+    with ThreadPoolExecutor(max_workers=5) as executor:
+      for _ in range(5):
+        executor.submit(cycle_a)
+      model_b = executor.submit(acquire_b).result()
+      try:
+        self.assertTrue(m_a.deleted)
+      finally:
+        self.manager.release_model("B", model_b)
+
+
+class TestGPUMonitor(unittest.TestCase):
+  def setUp(self):
+    self.subprocess_patcher = patch('subprocess.check_output')
+    self.mock_subprocess = self.subprocess_patcher.start()
+    # Cleanups run LIFO after the test, so monitors registered by individual
+    # tests are stopped before this patch is removed. Any background polling
+    # they started therefore never reaches the real subprocess.check_output.
+    self.addCleanup(self.subprocess_patcher.stop)
+
+  def test_init_hardware_detected(self):
+    """Test that init correctly reads total memory when nvidia-smi exists."""
+    self.mock_subprocess.return_value = "24576"
+    monitor = GPUMonitor()
+    monitor.start()
+    self.addCleanup(monitor.stop)
+    self.assertTrue(monitor._gpu_available)
+    self.assertEqual(monitor._total_memory, 24576.0)
+
+  def test_init_hardware_missing(self):
+    """Test fallback behavior when nvidia-smi is missing."""
+    self.mock_subprocess.side_effect = FileNotFoundError()
+    monitor = GPUMonitor(fallback_memory_mb=12000.0)
+    monitor.start()
+    self.addCleanup(monitor.stop)
+    self.assertFalse(monitor._gpu_available)
+    self.assertEqual(monitor._total_memory, 12000.0)
+
+  @patch('time.sleep')
+  def test_polling_updates_stats(self, mock_sleep):
+    """Test that the polling loop updates current and peak usage."""
+    def subprocess_side_effect(*args, **kwargs):
+      if isinstance(args[0], list) and "memory.total" in args[0][1]:
+        return "16000"
+
+      if isinstance(args[0], list) and any("memory.free" in part
+                                           for part in args[0]):
+        return "12000"
+
+      raise Exception("Unexpected command")
+
+    self.mock_subprocess.side_effect = subprocess_side_effect
+
+    monitor = GPUMonitor()
+    monitor.start()
+    # NOTE(review): time.sleep is patched for this test, so sleeping here
+    # would not actually wait. The assertions rely on start() having already
+    # recorded a sample - confirm that is guaranteed rather than a race.
+    curr, peak, total = monitor.get_stats()
+    monitor.stop()
+
+    self.assertEqual(curr, 4000.0)
+    self.assertEqual(peak, 4000.0)
+    self.assertEqual(total, 16000.0)
+
+  def test_reset_peak(self):
+    """Test that resetting peak usage works."""
+    monitor = GPUMonitor()
+    monitor._gpu_available = True
+
+    with monitor._lock:
+      monitor._current_usage = 2000.0
+      monitor._peak_usage = 8000.0
+      monitor._memory_history.append((time.time(), 8000.0))
+      monitor._memory_history.append((time.time(), 2000.0))
+
+    monitor.reset_peak()
+
+    _, peak, _ = monitor.get_stats()
+    self.assertEqual(peak, 2000.0)
+
+
+class TestResourceEstimatorSolver(unittest.TestCase):
+  """Checks when ResourceEstimator decides to invoke the nnls solver."""
+  def setUp(self):
+    self.estimator = ResourceEstimator()
+
+  @patch('apache_beam.ml.inference.model_manager.nnls')
+  def test_solver_respects_min_data_points(self, mock_nnls):
+    mock_nnls.return_value = ([100.0, 50.0], 0.0)
+    observe = self.estimator.add_observation
+
+    observe({'model_A': 1}, 500)
+    observe({'model_B': 1}, 500)
+    mock_nnls.assert_not_called()
+
+    observe({'model_A': 1, 'model_B': 1}, 1000)
+    mock_nnls.assert_not_called()
+
+    observe({'model_A': 1}, 500)
+    mock_nnls.assert_not_called()
+
+    # The fifth observation is the first one expected to trigger a solve.
+    observe({'model_B': 1}, 500)
+    self.assertTrue(mock_nnls.called)
+
+  @patch('apache_beam.ml.inference.model_manager.nnls')
+  def test_solver_respects_unique_model_constraint(self, mock_nnls):
+    mock_nnls.return_value = ([100.0, 100.0, 50.0], 0.0)
+    observe = self.estimator.add_observation
+
+    # Many observations, but model_A and model_B only ever appear together.
+    for _ in range(5):
+      observe({'model_A': 1, 'model_B': 1}, 800)
+    for _ in range(5):
+      observe({'model_C': 1}, 400)
+    mock_nnls.assert_not_called()
+
+    # After solo samples for A and B, the solver is expected to run.
+    observe({'model_A': 1}, 300)
+    observe({'model_B': 1}, 300)
+    self.assertTrue(mock_nnls.called)
+
+
+if __name__ == "__main__":
+  # Allow running this test module directly with the Python interpreter.
+  unittest.main()

Reply via email to