This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 4e218f0183e Add model manager that automatically manages models across
processes (#37113)
4e218f0183e is described below
commit 4e218f0183e5f0c69d98808a019b5f4940b24402
Author: RuiLong J. <[email protected]>
AuthorDate: Wed Feb 4 11:41:49 2026 -0800
Add model manager that automatically manages models across processes (#37113)
* Add model manager that automatically manages models across processes
* Add pydoc and move GPU detection to start
* Add comments and helper functions to make the code easier to understand
and clean up some code logic
* Add TODO for threading
* Remove tracked model proxy and have model manager store tags instead of
model instances
* Fix import order
* Clean up code and logs
* Add timeout for waiting too long on model acquire
* Throw error on timeout
* Add test for timeout and adjust
* Update sdks/python/apache_beam/ml/inference/model_manager.py
Co-authored-by: gemini-code-assist[bot]
<176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Update sdks/python/apache_beam/ml/inference/model_manager.py
Co-authored-by: gemini-code-assist[bot]
<176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Gemini clean up
* Update sdks/python/apache_beam/ml/inference/model_manager_test.py
Co-authored-by: gemini-code-assist[bot]
<176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Update GPU monitor test
* Format
* Clean up updating is_unknown logic
* Try to fix flake
* Fix import order
* Fix random seed to avoid flake
* Fix indentation
* Try fixing doc again
---------
Co-authored-by: gemini-code-assist[bot]
<176961590+gemini-code-assist[bot]@users.noreply.github.com>
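A minimal usage sketch (illustrative only; load_my_model is a hypothetical
loader, and the flow follows acquire_model/release_model as defined in
model_manager.py below):

    from apache_beam.ml.inference.model_manager import ModelManager

    def load_my_model():
      # Hypothetical loader; a real one would move model weights onto the GPU.
      return object()

    manager = ModelManager()
    # Reuses an idle cached copy when available; otherwise waits for capacity
    # (profiling unknown models in isolation) before invoking the loader.
    model = manager.acquire_model("my_model", load_my_model)
    try:
      pass  # run inference with the acquired model here
    finally:
      # Returns the instance to the idle LRU for reuse and records a memory
      # observation for the resource estimator.
      manager.release_model("my_model", model)
    manager.shutdown()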
---
.../apache_beam/ml/inference/model_manager.py | 747 +++++++++++++++++++++
.../apache_beam/ml/inference/model_manager_test.py | 622 +++++++++++++++++
2 files changed, 1369 insertions(+)
diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py
new file mode 100644
index 00000000000..cc9f833c268
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/model_manager.py
@@ -0,0 +1,747 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module for managing ML models in Apache Beam pipelines.
+
+This module provides classes and functions to efficiently manage multiple
+machine learning models within Apache Beam pipelines. It includes functionality
+for loading, caching, and updating models using multi-process shared memory,
+ensuring that models are reused across different workers to optimize resource
+usage and performance.
+"""
+
+import gc
+import heapq
+import itertools
+import logging
+import subprocess
+import threading
+import time
+from collections import Counter
+from collections import OrderedDict
+from collections import defaultdict
+from collections import deque
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from scipy.optimize import nnls
+
+from apache_beam.utils import multi_process_shared
+
+logger = logging.getLogger(__name__)
+
+
+class GPUMonitor:
+ """Monitors GPU memory usage in a separate thread using nvidia-smi.
+
+ This class continuously polls GPU memory statistics to track current usage
+ and peak usage over a sliding time window. It serves as the source of truth
+ for the ModelManager's resource decisions.
+
+ Attributes:
+ fallback_memory_mb: Default total memory if hardware detection fails.
+ poll_interval: Seconds between memory checks.
+ peak_window_seconds: Duration to track peak memory usage.
+ """
+ def __init__(
+ self,
+ fallback_memory_mb: float = 16000.0,
+ poll_interval: float = 0.5,
+ peak_window_seconds: float = 30.0):
+ self._current_usage = 0.0
+ self._peak_usage = 0.0
+ self._total_memory = fallback_memory_mb
+ self._poll_interval = poll_interval
+ self._peak_window_seconds = peak_window_seconds
+ self._memory_history = deque()
+ self._running = False
+ self._thread = None
+ self._lock = threading.Lock()
+
+ def _detect_hardware(self):
+ try:
+ cmd = [
+ "nvidia-smi",
+ "--query-gpu=memory.total",
+ "--format=csv,noheader,nounits"
+ ]
+ output = subprocess.check_output(cmd, text=True).strip()
+ self._total_memory = float(output)
+ return True
+ except (FileNotFoundError, subprocess.CalledProcessError):
+ logger.warning(
+ "nvidia-smi not found or failed. Defaulting total memory to %s MB",
+ self._total_memory)
+ return False
+ except Exception as e:
+ logger.warning(
+ "Error parsing nvidia-smi output: %s. "
+ "Defaulting total memory to %s MB",
+ e,
+ self._total_memory)
+ return False
+
+ def start(self):
+ self._gpu_available = self._detect_hardware()
+ if self._running or not self._gpu_available:
+ return
+ self._running = True
+ self._thread = threading.Thread(target=self._poll_loop, daemon=True)
+ self._thread.start()
+
+ def stop(self):
+ self._running = False
+ if self._thread:
+ self._thread.join()
+
+ def reset_peak(self):
+ with self._lock:
+ now = time.time()
+ self._memory_history.clear()
+ self._memory_history.append((now, self._current_usage))
+ self._peak_usage = self._current_usage
+
+ def get_stats(self) -> Tuple[float, float, float]:
+ with self._lock:
+ return self._current_usage, self._peak_usage, self._total_memory
+
+ def refresh(self):
+ """Forces an immediate poll of the GPU."""
+ usage = self._get_nvidia_smi_used()
+ now = time.time()
+ with self._lock:
+ self._current_usage = usage
+ self._memory_history.append((now, usage))
+ # Recalculate peak immediately
+ while self._memory_history and (now - self._memory_history[0][0]
+ > self._peak_window_seconds):
+ self._memory_history.popleft()
+ self._peak_usage = (
+ max(m for _, m in self._memory_history)
+ if self._memory_history else usage)
+
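+ # Usage is derived as total memory minus the memory.free value reported by
+ # nvidia-smi, so the reading reflects the whole device (including memory
+ # held by other processes), not just this worker.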
+ def _get_nvidia_smi_used(self) -> float:
+ try:
+ cmd = [
+ "nvidia-smi",
+ "--query-gpu=memory.free",
+ "--format=csv,noheader,nounits"
+ ]
+ output = subprocess.check_output(cmd, text=True).strip()
+ free_memory = float(output)
+ return self._total_memory - free_memory
+ except Exception as e:
+ logger.warning('Failed to get GPU memory usage: %s', e)
+ return 0.0
+
+ def _poll_loop(self):
+ while self._running:
+ usage = self._get_nvidia_smi_used()
+ now = time.time()
+ with self._lock:
+ self._current_usage = usage
+ self._memory_history.append((now, usage))
+ while self._memory_history and (now - self._memory_history[0][0]
+ > self._peak_window_seconds):
+ self._memory_history.popleft()
+ self._peak_usage = (
+ max(m for _, m in self._memory_history)
+ if self._memory_history else usage)
+ time.sleep(self._poll_interval)
+
+
+class ResourceEstimator:
+ """Estimates individual model memory usage using statistical observation.
+
+ Uses Non-Negative Least Squares (NNLS) to deduce the memory footprint of
+ individual models based on aggregate system memory readings and the
+ configuration of active models at that time.
+ """
+ def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5):
+ self.smoothing_factor = smoothing_factor
+ self.min_data_points = min_data_points
+ self.estimates: Dict[str, float] = {}
+ self.history = defaultdict(lambda: deque(maxlen=20))
+ self.known_models = set()
+ self._lock = threading.Lock()
+
+ def is_unknown(self, model_tag: str) -> bool:
+ with self._lock:
+ return model_tag not in self.estimates
+
+ def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float:
+ with self._lock:
+ return self.estimates.get(model_tag, default_mb)
+
+ def set_initial_estimate(self, model_tag: str, cost: float):
+ with self._lock:
+ self.estimates[model_tag] = cost
+ self.known_models.add(model_tag)
+ logger.info("Initial Profile for %s: %s MB", model_tag, cost)
+
+ def add_observation(
+ self, active_snapshot: Dict[str, int], peak_memory: float):
+ if active_snapshot:
+ model_list = "\n".join(
+ f"\t- {model}: {count}"
+ for model, count in sorted(active_snapshot.items()))
+ else:
+ model_list = "\t- None"
+
+ logger.info(
+ "Adding Observation:\n PeakMemory: %.1f MB\n Instances:\n%s",
+ peak_memory,
+ model_list)
+ if not active_snapshot:
+ return
+ with self._lock:
+ config_key = tuple(sorted(active_snapshot.items()))
+ self.history[config_key].append(peak_memory)
+ for tag in active_snapshot:
+ self.known_models.add(tag)
+ self._solve()
+
+ def _solve(self):
+ """
+ Solves Ax=b using raw readings (no pre-averaging) and NNLS.
+ This creates a 'tall' matrix A where every memory reading is
+ a separate equation.
+ """
+ unique = sorted(list(self.known_models))
+
+ # We need to build the matrix first to know if we have enough data points
+ A, b = [], []
+
+ for config_key, mem_values in self.history.items():
+ if not mem_values:
+ continue
+
+ # 1. Create the feature row for this configuration ONCE
+ # (It represents the model counts + bias)
+ counts = dict(config_key)
+ feature_row = [counts.get(model, 0) for model in unique]
+ feature_row.append(1) # Bias column
+
+ # 2. Add a separate row to the matrix for EVERY individual reading
+ # Instead of averaging, we flatten the history into the matrix
+ for reading in mem_values:
+ A.append(feature_row) # The inputs (models) stay the same
+ b.append(reading) # The output (memory) varies due to noise
+
+ # Convert to numpy for SciPy
+ A = np.array(A)
+ b = np.array(b)
+
+ if len(
+ self.history.keys()) < len(unique) + 1 or len(A) < self.min_data_points:
+ # Not enough data to solve yet
+ return
+
+ logger.info(
+ "Solving with %s total observations for %s models.",
+ len(A),
+ len(unique))
+
+ try:
+ # Solve using Non-Negative Least Squares
+ # x will be >= 0
+ x, _ = nnls(A, b)
+
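+ # x has one non-negative cost per model in `unique` order plus a
+ # trailing bias that absorbs overhead shared across observations
+ # (e.g. framework or runtime allocations).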
+ weights = x[:-1]
+ bias = x[-1]
+
+ for i, model in enumerate(unique):
+ calculated_cost = weights[i]
+
+ if model in self.estimates:
+ old = self.estimates[model]
+ new = (old * (1 - self.smoothing_factor)) + (
+ calculated_cost * self.smoothing_factor)
+ self.estimates[model] = new
+ else:
+ self.estimates[model] = calculated_cost
+
+ logger.info(
+ "Updated Estimate for %s: %.1f MB", model, self.estimates[model])
+ logger.info("System Bias: %s MB", bias)
+
+ except Exception as e:
+ logger.error("Solver failed: %s", e)
+
+
+class ModelManager:
+ """Manages model lifecycles, caching, and resource arbitration.
+
+ This class acts as the central controller for acquiring model instances.
+
+ 1. LRU Caching of idle models.
+ 2. Resource estimation and admission control (preventing OOM).
+ 3. Dynamic eviction of low-priority models, determined by count of
+ pending requests, when space is needed.
+ 4. 'Isolation Mode' for safely profiling unknown models.
+ """
+ def __init__(
+ self,
+ monitor: Optional['GPUMonitor'] = None,
+ slack_percentage: float = 0.10,
+ poll_interval: float = 0.5,
+ peak_window_seconds: float = 30.0,
+ min_data_points: int = 5,
+ smoothing_factor: float = 0.2,
+ eviction_cooldown_seconds: float = 10.0,
+ min_model_copies: int = 1,
+ wait_timeout_seconds: float = 300.0,
+ lock_timeout_seconds: float = 60.0):
+
+ self._estimator = ResourceEstimator(
+ min_data_points=min_data_points, smoothing_factor=smoothing_factor)
+ self._monitor = monitor if monitor else GPUMonitor(
+ poll_interval=poll_interval, peak_window_seconds=peak_window_seconds)
+ self._slack_percentage = slack_percentage
+
+ self._eviction_cooldown = eviction_cooldown_seconds
+ self._min_model_copies = min_model_copies
+ self._wait_timeout_seconds = wait_timeout_seconds
+ self._lock_timeout_seconds = lock_timeout_seconds
+
+ # Resource State
+ self._models = defaultdict(list)
+ # Idle LRU used to track released models that
+ # can be freed or reused upon request.
+ self._idle_lru = OrderedDict()
+ self._active_counts = Counter()
+ self._total_active_jobs = 0
+ self._pending_reservations = 0.0
+
+ # Isolation state used to profile unknown models,
+ # ensuring they run alone to get accurate readings.
+ # isolation_baseline represents the GPU usage before
+ # loading the unknown model.
+ self._isolation_mode = False
+ self._isolation_baseline = 0.0
+
+ # Waiting Queue and Ticketing to make sure we have fair ordering
+ # and also priority for unknown models.
+ self._wait_queue = []
+ self._ticket_counter = itertools.count()
+ # TODO: Consider making the wait to be smarter, i.e.
+ # splitting read/write etc. to avoid potential contention.
+ self._cv = threading.Condition()
+
+ self._monitor.start()
+
+ def all_models(self, tag) -> list[Any]:
+ return self._models[tag]
+
+ # Should hold _cv lock when calling
+ def try_enter_isolation_mode(self, tag: str, ticket_num: int) -> bool:
+ if self._total_active_jobs > 0:
+ logger.info(
+ "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num)
+ self._cv.wait(timeout=self._lock_timeout_seconds)
+ # return False since we have waited and need to re-evaluate
+ # in caller to make sure our priority is still valid.
+ return False
+
+ logger.info("Unknown model %s detected. Flushing GPU.", tag)
+ self._delete_all_models()
+
+ self._isolation_mode = True
+ self._total_active_jobs += 1
+ self._isolation_baseline, _, _ = self._monitor.get_stats()
+ self._monitor.reset_peak()
+ return True
+
+ # Should hold _cv lock when calling
+ def should_spawn_model(self, tag: str, ticket_num: int) -> bool:
+ curr, _, total = self._monitor.get_stats()
+ est_cost = self._estimator.get_estimate(tag)
+ limit = total * (1 - self._slack_percentage)
+
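+ # Admission check: current usage plus memory already reserved for models
+ # still loading plus this model's estimated cost must fit under the
+ # slack-adjusted limit; the reservation taken below is released in
+ # _spawn_new_model once the load finishes.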
+ # Use current usage for capacity check (ignore old spikes)
+ if (curr + self._pending_reservations + est_cost) <= limit:
+ self._pending_reservations += est_cost
+ self._total_active_jobs += 1
+ self._active_counts[tag] += 1
+ return True
+
+ # Evict to make space (passing tag to check demand/existence)
+ if self._evict_to_make_space(limit, est_cost, requesting_tag=tag):
+ return True
+
+ # Manually log status for debugging if we are going to wait
+ idle_count = 0
+ other_idle_count = 0
+ for item in self._idle_lru.items():
+ if item[1][0] == tag:
+ idle_count += 1
+ else:
+ other_idle_count += 1
+ total_model_count = 0
+ for _, instances in self._models.items():
+ total_model_count += len(instances)
+ curr, _, _ = self._monitor.get_stats()
+ logger.info(
+ "Waiting for resources to free up: "
+ "tag=%s ticket num%s model count=%s "
+ "idle count=%s resource usage=%.1f MB "
+ "total models count=%s other idle=%s",
+ tag,
+ ticket_num,
+ len(self._models[tag]),
+ idle_count,
+ curr,
+ total_model_count,
+ other_idle_count)
+ # Wait since we couldn't make space and
+ # added timeout to avoid missed notify call.
+ self._cv.wait(timeout=self._lock_timeout_seconds)
+ return False
+
+ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
+ current_priority = 0 if self._estimator.is_unknown(tag) else 1
+ ticket_num = next(self._ticket_counter)
+ my_id = object()
+
+ with self._cv:
+ # FAST PATH: Grab from idle LRU if available
+ if not self._isolation_mode:
+ cached_instance = self._try_grab_from_lru(tag)
+ if cached_instance:
+ return cached_instance
+
+ # SLOW PATH: Enqueue and wait for turn to acquire model,
+ # with unknown models having priority and order enforced
+ # by ticket number as FIFO.
+ logger.info(
+ "Acquire Queued: tag=%s, priority=%d "
+ "total models count=%s ticket num=%s",
+ tag,
+ current_priority,
+ len(self._models[tag]),
+ ticket_num)
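+ # Heap entries order by (priority, ticket): unknown models (priority 0)
+ # sort ahead of known ones, and the monotonically increasing ticket keeps
+ # FIFO order within a priority. Tickets are unique, so the sentinel
+ # my_id is never compared.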
+ heapq.heappush(
+ self._wait_queue, (current_priority, ticket_num, my_id, tag))
+
+ est_cost = 0.0
+ is_unknown = False
+ wait_time_start = time.time()
+
+ try:
+ while True:
+ wait_time_elapsed = time.time() - wait_time_start
+ if wait_time_elapsed > self._wait_timeout_seconds:
+ raise RuntimeError(
+ f"Timeout waiting to acquire model: {tag} "
+ f"after {wait_time_elapsed:.1f} seconds.")
+ if not self._wait_queue or self._wait_queue[0][2] is not my_id:
+ logger.info(
+ "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
+ self._cv.wait(timeout=self._lock_timeout_seconds)
+ continue
+
+ # Re-evaluate priority in case model became known during wait
+ is_unknown = self._estimator.is_unknown(tag)
+ real_priority = 0 if is_unknown else 1
+
+ # If priority changed, reinsert into queue and wait
+ if current_priority != real_priority:
+ heapq.heappop(self._wait_queue)
+ current_priority = real_priority
+ heapq.heappush(
+ self._wait_queue, (current_priority, ticket_num, my_id, tag))
+ self._cv.notify_all()
+ continue
+
+ # Try grab from LRU again in case model was released during wait
+ cached_instance = self._try_grab_from_lru(tag)
+ if cached_instance:
+ return cached_instance
+
+ # Path A: Isolation
+ if is_unknown:
+ if self.try_enter_isolation_mode(tag, ticket_num):
+ # We got isolation, can proceed to spawn
+ break
+ else:
+ # We waited, need to re-evaluate our turn
+ # because priority may have changed during the wait
+ continue
+
+ # Path B: Concurrent
+ else:
+ if self._isolation_mode:
+ logger.info(
+ "Waiting due to isolation in progress: tag=%s ticket num%s",
+ tag,
+ ticket_num)
+ self._cv.wait(timeout=self._lock_timeout_seconds)
+ continue
+
+ if self.should_spawn_model(tag, ticket_num):
+ est_cost = self._estimator.get_estimate(tag)
+ # We can proceed to spawn since we have resources
+ break
+ else:
+ # We waited, need to re-evaluate our turn
+ # because priority may have changed during the wait
+ continue
+
+ finally:
+ # Remove self from wait queue once done
+ if self._wait_queue and self._wait_queue[0][2] is my_id:
+ heapq.heappop(self._wait_queue)
+ else:
+ logger.warning(
+ "Item not at head of wait queue during cleanup"
+ ", this is not expected: tag=%s ticket num=%s",
+ tag,
+ ticket_num)
+ for i, item in enumerate(self._wait_queue):
+ if item[2] is my_id:
+ self._wait_queue.pop(i)
+ heapq.heapify(self._wait_queue)
+ self._cv.notify_all()
+
+ return self._spawn_new_model(tag, loader_func, is_unknown, est_cost)
+
+ def release_model(self, tag: str, instance: Any):
+ with self._cv:
+ try:
+ self._total_active_jobs -= 1
+ if self._active_counts[tag] > 0:
+ self._active_counts[tag] -= 1
+
+ self._idle_lru[id(instance)] = (tag, instance, time.time())
+
+ # Update estimator with latest stats
+ _, peak_during_job, _ = self._monitor.get_stats()
+
+ if self._isolation_mode and self._active_counts[tag] == 0:
+ # For isolation mode, we directly set the initial estimate
+ # so that we can quickly learn the model cost.
+ cost = max(0, peak_during_job - self._isolation_baseline)
+ self._estimator.set_initial_estimate(tag, cost)
+ self._isolation_mode = False
+ self._isolation_baseline = 0.0
+ else:
+ # Regular update for known models
+ snapshot = {
+ t: len(instances)
+ for t, instances in self._models.items() if len(instances) > 0
+ }
+ if snapshot:
+ self._estimator.add_observation(snapshot, peak_during_job)
+
+ finally:
+ self._cv.notify_all()
+
+ def _try_grab_from_lru(self, tag: str) -> Any:
+ target_key = None
+ target_instance = None
+
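+ # Scan newest-first so the most recently released instance of this tag
+ # is reused, keeping older idle copies as eviction candidates.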
+ for key, (t, instance, _) in reversed(self._idle_lru.items()):
+ if t == tag:
+ target_key = key
+ target_instance = instance
+ break
+
+ if target_instance:
+ # Found an idle model, remove from LRU and return
+ del self._idle_lru[target_key]
+ self._active_counts[tag] += 1
+ self._total_active_jobs += 1
+ return target_instance
+
+ logger.info("No idle model found for tag: %s", tag)
+ return None
+
+ def _evict_to_make_space(
+ self, limit: float, est_cost: float, requesting_tag: str) -> bool:
+ """
+ Evicts models based on Demand Magnitude + Tiers.
+ Crucially: If we have 0 active copies of 'requesting_tag', we FORCE eviction
+ of the lowest-demand candidate to avoid starvation.
+ Returns True if space was made, False otherwise.
+ """
+ curr, _, _ = self._monitor.get_stats()
+ projected_usage = curr + self._pending_reservations + est_cost
+
+ if projected_usage <= limit:
+ # Memory usage changed and we are already under limit
+ return True
+
+ now = time.time()
+
+ # Calculate the demand from the wait queue
+ # TODO: Also factor in the active counts to avoid thrashing
+ demand_map = Counter()
+ for item in self._wait_queue:
+ demand_map[item[3]] += 1
+
+ my_demand = demand_map[requesting_tag]
+ am_i_starving = len(self._models[requesting_tag]) == 0
+
+ candidates = []
+ for key, (tag, instance, release_time) in self._idle_lru.items():
+ candidate_demand = demand_map[tag]
+
+ # TODO: Try to avoid churn if demand is similar
+ if not am_i_starving and candidate_demand >= my_demand:
+ continue
+
+ # Attempts to score candidates based on hotness and manually
+ # specified minimum copies. Demand is weighted heavily to
+ # ensure we evict low-demand models first.
+ age = now - release_time
+ is_cold = age >= self._eviction_cooldown
+
+ total_copies = len(self._models[tag])
+ is_surplus = total_copies > self._min_model_copies
+
+ if is_cold and is_surplus: tier = 0
+ elif not is_cold and is_surplus: tier = 1
+ elif is_cold and not is_surplus: tier = 2
+ else: tier = 3
+
+ score = (candidate_demand * 10) + tier
+
+ candidates.append((score, release_time, key, tag, instance))
+
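+ # Sort ascending so the lowest-score (lowest-demand, coldest/most surplus)
+ # candidates are evicted first; demand is weighted by 10, so the tier only
+ # breaks ties among candidates with equal queue demand.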
+ candidates.sort(key=lambda x: (x[0], x[1]))
+
+ # Evict candidates until we are under limit
+ for score, _, key, tag, instance in candidates:
+ if projected_usage <= limit:
+ break
+
+ if key not in self._idle_lru: continue
+
+ self._perform_eviction(key, tag, instance, score)
+
+ curr, _, _ = self._monitor.get_stats()
+ projected_usage = curr + self._pending_reservations + est_cost
+
+ return projected_usage <= limit
+
+ def _delete_instance(self, instance: Any):
+ if isinstance(instance, str):
+ # If the instance is a string, it's a uuid used
+ # to retrieve the model from MultiProcessShared
+ multi_process_shared.MultiProcessShared(
+ lambda: "N/A", tag=instance).unsafe_hard_delete()
+ if hasattr(instance, 'mock_model_unsafe_hard_delete'):
+ # Call the mock unsafe hard delete method for testing
+ instance.mock_model_unsafe_hard_delete()
+ del instance
+
+ def _perform_eviction(self, key: str, tag: str, instance: Any, score: int):
+ logger.info("Evicting Model: %s (Score %d)", tag, score)
+ curr, _, _ = self._monitor.get_stats()
+ logger.info("Resource Usage Before Eviction: %.1f MB", curr)
+
+ if key in self._idle_lru:
+ del self._idle_lru[key]
+
+ for i, inst in enumerate(self._models[tag]):
+ if instance == inst:
+ del self._models[tag][i]
+ break
+
+ self._delete_instance(instance)
+ gc.collect()
+ torch.cuda.empty_cache()
+ self._monitor.refresh()
+ self._monitor.reset_peak()
+ curr, _, _ = self._monitor.get_stats()
+ logger.info("Resource Usage After Eviction: %.1f MB", curr)
+
+ def _spawn_new_model(
+ self,
+ tag: str,
+ loader_func: Callable[[], Any],
+ is_unknown: bool,
+ est_cost: float) -> Any:
+ try:
+ with self._cv:
+ logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
+ baseline_snap, _, _ = self._monitor.get_stats()
+ instance = loader_func()
+ _, peak_during_load, _ = self._monitor.get_stats()
+
+ snapshot = {tag: 1}
+ self._estimator.add_observation(
+ snapshot, peak_during_load - baseline_snap)
+
+ if not is_unknown:
+ self._pending_reservations = max(
+ 0.0, self._pending_reservations - est_cost)
+ self._models[tag].append(instance)
+ return instance
+
+ except Exception as e:
+ logger.error("Load Failed: %s. Error: %s", tag, e)
+ with self._cv:
+ self._total_active_jobs -= 1
+ if is_unknown:
+ self._isolation_mode = False
+ self._isolation_baseline = 0.0
+ else:
+ self._pending_reservations = max(
+ 0.0, self._pending_reservations - est_cost)
+ self._active_counts[tag] -= 1
+ self._cv.notify_all()
+ raise e
+
+ def _delete_all_models(self):
+ self._idle_lru.clear()
+ for _, instances in self._models.items():
+ for instance in instances:
+ self._delete_instance(instance)
+ self._models.clear()
+ self._active_counts.clear()
+ gc.collect()
+ torch.cuda.empty_cache()
+ self._monitor.refresh()
+ self._monitor.reset_peak()
+
+ def _force_reset(self):
+ logger.warning("Force Reset Triggered")
+ self._delete_all_models()
+ self._models = defaultdict(list)
+ self._idle_lru = OrderedDict()
+ self._active_counts = Counter()
+ self._wait_queue = []
+ self._total_active_jobs = 0
+ self._pending_reservations = 0.0
+ self._isolation_mode = False
+ self._isolation_baseline = 0.0
+
+ def shutdown(self):
+ self._delete_all_models()
+ self._monitor.stop()
+
+ def __del__(self):
+ self.shutdown()
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.shutdown()
diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py
new file mode 100644
index 00000000000..1bd8edd34d1
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py
@@ -0,0 +1,622 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import random
+import threading
+import time
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import TimeoutError
+from unittest.mock import patch
+
+from apache_beam.utils import multi_process_shared
+
+try:
+ from apache_beam.ml.inference.model_manager import GPUMonitor
+ from apache_beam.ml.inference.model_manager import ModelManager
+ from apache_beam.ml.inference.model_manager import ResourceEstimator
+except ImportError as e:
+ raise unittest.SkipTest("Model Manager dependencies are not installed")
+
+
+class MockGPUMonitor:
+ """
+ Simulates GPU hardware with cumulative memory tracking.
+ Allows simulating specific allocation spikes and baseline usage.
+ """
+ def __init__(self, total_memory=12000.0, peak_window: int = 5):
+ self._current = 0.0
+ self._peak = 0.0
+ self._total = total_memory
+ self._lock = threading.Lock()
+ self.running = False
+ self.history = []
+ self.peak_window = peak_window
+
+ def start(self):
+ self.running = True
+
+ def stop(self):
+ self.running = False
+
+ def get_stats(self):
+ with self._lock:
+ return self._current, self._peak, self._total
+
+ def reset_peak(self):
+ with self._lock:
+ self._peak = self._current
+ self.history = [self._current]
+
+ def set_usage(self, current_mb):
+ """Sets absolute usage (legacy helper)."""
+ with self._lock:
+ self._current = current_mb
+ self._peak = max(self._peak, current_mb)
+
+ def allocate(self, amount_mb):
+ """Simulates memory allocation (e.g., tensors loaded to VRAM)."""
+ with self._lock:
+ self._current += amount_mb
+ self.history.append(self._current)
+ if len(self.history) > self.peak_window:
+ self.history.pop(0)
+ self._peak = max(self.history)
+
+ def free(self, amount_mb):
+ """Simulates memory freeing (not used often if pooling is active)."""
+ with self._lock:
+ self._current = max(0.0, self._current - amount_mb)
+ self.history.append(self._current)
+ if len(self.history) > self.peak_window:
+ self.history.pop(0)
+ self._peak = max(self.history)
+
+ def refresh(self):
+ """Simulates a refresh of the monitor stats (no-op for mock)."""
+ pass
+
+
+class MockModel:
+ def __init__(self, name, size, monitor):
+ self.name = name
+ self.size = size
+ self.monitor = monitor
+ self.deleted = False
+ self.monitor.allocate(size)
+
+ def mock_model_unsafe_hard_delete(self):
+ if not self.deleted:
+ self.monitor.free(self.size)
+ self.deleted = True
+
+
+class Counter(object):
+ def __init__(self, start=0):
+ self.running = start
+ self.lock = threading.Lock()
+
+ def get(self):
+ return self.running
+
+ def increment(self, value=1):
+ with self.lock:
+ self.running += value
+ return self.running
+
+
+class TestModelManager(unittest.TestCase):
+ def setUp(self):
+ """Force reset the Singleton ModelManager before every test."""
+ self.mock_monitor = MockGPUMonitor()
+ self.manager = ModelManager(monitor=self.mock_monitor)
+
+ def tearDown(self):
+ self.manager.shutdown()
+
+ def test_model_manager_deletes_multiprocessshared_instances(self):
+ """Test that MultiProcessShared instances are deleted properly."""
+ model_name = "test_model_shared"
+ tag = f"model_manager_test_{model_name}"
+
+ def loader():
+ multi_process_shared.MultiProcessShared(
+ lambda: Counter, tag=tag, always_proxy=True)
+ return tag
+
+ instance = self.manager.acquire_model(model_name, loader)
+ instance_before = multi_process_shared.MultiProcessShared(
+ Counter, tag=tag, always_proxy=True).acquire()
+ instance_before.increment()
+ self.assertEqual(instance_before.get(), 1)
+ self.manager.release_model(model_name, instance)
+
+ # Force delete all models
+ self.manager._force_reset()
+
+ # Verify that the MultiProcessShared instance is deleted
+ # and the counter is reset
+ with self.assertRaises(Exception):
+ instance_before.get()
+ instance_after = multi_process_shared.MultiProcessShared(
+ Counter, tag=tag, always_proxy=True).acquire()
+ self.assertEqual(instance_after.get(), 0)
+
+ def test_model_manager_timeout_on_acquire(self):
+ """Test that acquiring a model times out properly."""
+ model_name = "timeout_model"
+ self.manager = ModelManager(
+ monitor=self.mock_monitor,
+ wait_timeout_seconds=1.0,
+ lock_timeout_seconds=1.0)
+
+ def loader():
+ self.mock_monitor.allocate(self.mock_monitor._total)
+ return model_name
+
+ # Acquire the model in one thread to block others
+ _ = self.manager.acquire_model(model_name, loader)
+
+ def acquire_model_with_timeout():
+ return self.manager.acquire_model(model_name, loader)
+
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(acquire_model_with_timeout)
+ with self.assertRaises(RuntimeError) as context:
+ future.result(timeout=5.0)
+ self.assertIn("Timeout waiting to acquire model", str(context.exception))
+
+ def test_model_manager_capacity_check(self):
+ """
+ Test that the manager blocks when spawning models exceeds the limit,
+ and unblocks when resources become available (via reuse).
+ """
+ model_name = "known_model"
+ model_cost = 3000.0
+ self.manager._estimator.set_initial_estimate(model_name, model_cost)
+ acquired_refs = []
+
+ def loader():
+ self.mock_monitor.allocate(model_cost)
+ return model_name
+
+ # 1. Saturate GPU with 3 models (9000 MB usage)
+ for _ in range(3):
+ inst = self.manager.acquire_model(model_name, loader)
+ acquired_refs.append(inst)
+
+ # 2. Spawn one more (Should Block because 9000 + 3000 > Limit)
+ def run_inference():
+ return self.manager.acquire_model(model_name, loader)
+
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(run_inference)
+ try:
+ future.result(timeout=0.5)
+ self.fail("Should have blocked due to capacity")
+ except TimeoutError:
+ pass
+
+ # 3. Release resources to unblock
+ item_to_release = acquired_refs.pop()
+ self.manager.release_model(model_name, item_to_release)
+
+ result = future.result(timeout=2.0)
+ self.assertIsNotNone(result)
+ self.assertEqual(result, item_to_release)
+
+ def test_model_manager_unknown_model_runs_isolated(self):
+ """Test that a model with no history runs in isolation."""
+ model_name = "unknown_model_v1"
+ self.assertTrue(self.manager._estimator.is_unknown(model_name))
+
+ def dummy_loader():
+ time.sleep(0.05)
+ return "model_instance"
+
+ instance = self.manager.acquire_model(model_name, dummy_loader)
+
+ self.assertTrue(self.manager._isolation_mode)
+ self.assertEqual(self.manager._total_active_jobs, 1)
+
+ self.manager.release_model(model_name, instance)
+ self.assertFalse(self.manager._isolation_mode)
+ self.assertFalse(self.manager._estimator.is_unknown(model_name))
+
+ def test_model_manager_concurrent_execution(self):
+ """Test that multiple small known models can run together."""
+ model_a = "small_model_a"
+ model_b = "small_model_b"
+
+ self.manager._estimator.set_initial_estimate(model_a, 1000.0)
+ self.manager._estimator.set_initial_estimate(model_b, 1000.0)
+ self.mock_monitor.set_usage(1000.0)
+
+ inst_a = self.manager.acquire_model(model_a, lambda: "A")
+ inst_b = self.manager.acquire_model(model_b, lambda: "B")
+
+ self.assertEqual(self.manager._total_active_jobs, 2)
+
+ self.manager.release_model(model_a, inst_a)
+ self.manager.release_model(model_b, inst_b)
+ self.assertEqual(self.manager._total_active_jobs, 0)
+
+ def test_model_manager_concurrent_mixed_workload_convergence(self):
+ """
+ Simulates a production environment with multiple model types running
+ concurrently. Verifies that the estimator converges.
+ """
+ TRUE_COSTS = {"model_small": 1500.0, "model_medium": 3000.0}
+
+ def run_job(model_name):
+ cost = TRUE_COSTS[model_name]
+
+ def loader():
+ model = MockModel(model_name, cost, self.mock_monitor)
+ return model
+
+ instance = self.manager.acquire_model(model_name, loader)
+ time.sleep(random.uniform(0.01, 0.05))
+ self.manager.release_model(model_name, instance)
+
+ # Create a workload stream
+ workload = ["model_small"] * 15 + ["model_medium"] * 15
+ random.shuffle(workload)
+
+ with ThreadPoolExecutor(max_workers=8) as executor:
+ futures = [executor.submit(run_job, name) for name in workload]
+ for f in futures:
+ f.result()
+
+ est_small = self.manager._estimator.get_estimate("model_small")
+ est_med = self.manager._estimator.get_estimate("model_medium")
+
+ self.assertAlmostEqual(est_small, TRUE_COSTS["model_small"], delta=100.0)
+ self.assertAlmostEqual(est_med, TRUE_COSTS["model_medium"], delta=100.0)
+
+ def test_model_manager_oom_recovery(self):
+ """Test that the manager recovers state if a loader crashes."""
+ model_name = "crasher_model"
+ self.manager._estimator.set_initial_estimate(model_name, 1000.0)
+
+ def crashing_loader():
+ raise RuntimeError("CUDA OOM or similar")
+
+ with self.assertRaises(RuntimeError):
+ self.manager.acquire_model(model_name, crashing_loader)
+
+ self.assertEqual(self.manager._total_active_jobs, 0)
+ self.assertEqual(self.manager._pending_reservations, 0.0)
+ self.assertFalse(self.manager._cv._is_owned())
+
+ def test_model_manager_force_reset_on_exception(self):
+ """Test that force_reset clears all models from the manager."""
+ model_name = "test_model"
+
+ def dummy_loader():
+ self.mock_monitor.allocate(1000.0)
+ raise RuntimeError("Simulated loader exception")
+
+ try:
+ instance = self.manager.acquire_model(
+ model_name, lambda: "model_instance")
+ self.manager.release_model(model_name, instance)
+ instance = self.manager.acquire_model(model_name, dummy_loader)
+ except RuntimeError:
+ self.manager._force_reset()
+ self.assertTrue(len(self.manager._models[model_name]) == 0)
+ self.assertEqual(self.manager._total_active_jobs, 0)
+ self.assertEqual(self.manager._pending_reservations, 0.0)
+ self.assertFalse(self.manager._isolation_mode)
+ pass
+
+ instance = self.manager.acquire_model(model_name, lambda: "model_instance")
+ self.manager.release_model(model_name, instance)
+
+ def test_single_model_convergence_with_fluctuations(self):
+ """
+ Tests that the estimator converges to the true usage with fluctuations.
+ """
+ model_name = "fluctuating_model"
+ model_cost = 3000.0
+ load_cost = 2500.0
+ # Fix random seed for reproducibility
+ random.seed(42)
+
+ def loader():
+ self.mock_monitor.allocate(load_cost)
+ return model_name
+
+ model = self.manager.acquire_model(model_name, loader)
+ self.manager.release_model(model_name, model)
+ initial_est = self.manager._estimator.get_estimate(model_name)
+ self.assertEqual(initial_est, load_cost)
+
+ def run_inference():
+ model = self.manager.acquire_model(model_name, loader)
+ noise = model_cost - load_cost + random.uniform(-300.0, 300.0)
+ self.mock_monitor.allocate(noise)
+ time.sleep(0.1)
+ self.mock_monitor.free(noise)
+ self.manager.release_model(model_name, model)
+ return
+
+ with ThreadPoolExecutor(max_workers=8) as executor:
+ futures = [executor.submit(run_inference) for _ in range(100)]
+
+ for f in futures:
+ f.result()
+
+ est_cost = self.manager._estimator.get_estimate(model_name)
+ self.assertAlmostEqual(est_cost, model_cost, delta=100.0)
+
+
+class TestModelManagerEviction(unittest.TestCase):
+ def setUp(self):
+ self.mock_monitor = MockGPUMonitor(total_memory=12000.0)
+ self.manager = ModelManager(
+ monitor=self.mock_monitor,
+ slack_percentage=0.0,
+ min_data_points=1,
+ eviction_cooldown_seconds=10.0,
+ min_model_copies=1)
+
+ def tearDown(self):
+ self.manager.shutdown()
+
+ def create_loader(self, name, size):
+ return lambda: MockModel(name, size, self.mock_monitor)
+
+ def test_basic_lru_eviction(self):
+ self.manager._estimator.set_initial_estimate("A", 4000)
+ self.manager._estimator.set_initial_estimate("B", 4000)
+ self.manager._estimator.set_initial_estimate("C", 5000)
+
+ model_a = self.manager.acquire_model("A", self.create_loader("A", 4000))
+ self.manager.release_model("A", model_a)
+
+ model_b = self.manager.acquire_model("B", self.create_loader("B", 4000))
+ self.manager.release_model("B", model_b)
+
+ key_a = list(self.manager._idle_lru.keys())[0]
+ self.manager._idle_lru[key_a] = ("A", model_a, time.time() - 20.0)
+
+ key_b = list(self.manager._idle_lru.keys())[1]
+ self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0)
+
+ model_a_again = self.manager.acquire_model(
+ "A", self.create_loader("A", 4000))
+ self.manager.release_model("A", model_a_again)
+
+ self.manager.acquire_model("C", self.create_loader("C", 5000))
+
+ self.assertEqual(len(self.manager.all_models("B")), 0)
+ self.assertEqual(len(self.manager.all_models("A")), 1)
+
+ def test_chained_eviction(self):
+ self.manager._estimator.set_initial_estimate("big_guy", 8000)
+ models = []
+ for i in range(4):
+ name = f"small_{i}"
+ m = self.manager.acquire_model(name, self.create_loader(name, 3000))
+ self.manager.release_model(name, m)
+ models.append(m)
+
+ self.manager.acquire_model("big_guy", self.create_loader("big_guy", 8000))
+
+ self.assertTrue(models[0].deleted)
+ self.assertTrue(models[1].deleted)
+ self.assertTrue(models[2].deleted)
+ self.assertFalse(models[3].deleted)
+
+ def test_active_models_are_protected(self):
+ self.manager._estimator.set_initial_estimate("A", 6000)
+ self.manager._estimator.set_initial_estimate("B", 4000)
+ self.manager._estimator.set_initial_estimate("C", 4000)
+
+ model_a = self.manager.acquire_model("A", self.create_loader("A", 6000))
+ model_b = self.manager.acquire_model("B", self.create_loader("B", 4000))
+ self.manager.release_model("B", model_b)
+
+ key_b = list(self.manager._idle_lru.keys())[0]
+ self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0)
+
+ def acquire_c():
+ return self.manager.acquire_model("C", self.create_loader("C", 4000))
+
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(acquire_c)
+ model_c = future.result(timeout=2.0)
+
+ self.assertTrue(model_b.deleted)
+ self.assertFalse(model_a.deleted)
+
+ self.manager.release_model("A", model_a)
+ self.manager.release_model("C", model_c)
+
+ def test_unknown_model_clears_memory(self):
+ self.manager._estimator.set_initial_estimate("A", 2000)
+ model_a = self.manager.acquire_model("A", self.create_loader("A", 2000))
+ self.manager.release_model("A", model_a)
+ self.assertFalse(model_a.deleted)
+
+ self.assertTrue(self.manager._estimator.is_unknown("X"))
+ model_x = self.manager.acquire_model("X", self.create_loader("X", 10000))
+
+ self.assertTrue(model_a.deleted, "Model A should be deleted for isolation")
+ self.assertEqual(len(self.manager.all_models("A")), 0)
+ self.assertTrue(self.manager._isolation_mode)
+ self.manager.release_model("X", model_x)
+
+ def test_concurrent_eviction_pressure(self):
+ def worker(idx):
+ name = f"model_{idx % 5}"
+ try:
+ m = self.manager.acquire_model(name, self.create_loader(name, 4000))
+ time.sleep(0.001)
+ self.manager.release_model(name, m)
+ except Exception:
+ pass
+
+ with ThreadPoolExecutor(max_workers=8) as executor:
+ futures = [executor.submit(worker, i) for i in range(50)]
+ for f in futures:
+ f.result()
+
+ curr, _, _ = self.mock_monitor.get_stats()
+ expected_usage = 0
+ for _, instances in self.manager._models.items():
+ expected_usage += len(instances) * 4000
+
+ self.assertAlmostEqual(curr, expected_usage)
+
+ def test_starvation_prevention_overrides_demand(self):
+ self.manager._estimator.set_initial_estimate("A", 12000)
+ m_a = self.manager.acquire_model("A", self.create_loader("A", 12000))
+ self.manager.release_model("A", m_a)
+
+ def cycle_a():
+ try:
+ m = self.manager.acquire_model("A", self.create_loader("A", 12000))
+ time.sleep(0.3)
+ self.manager.release_model("A", m)
+ except Exception:
+ pass
+
+ executor = ThreadPoolExecutor(max_workers=5)
+ for _ in range(5):
+ executor.submit(cycle_a)
+
+ def acquire_b():
+ return self.manager.acquire_model("B", self.create_loader("B", 4000))
+
+ b_future = executor.submit(acquire_b)
+ model_b = b_future.result()
+
+ self.assertTrue(m_a.deleted)
+ self.manager.release_model("B", model_b)
+ executor.shutdown(wait=True)
+
+
+class TestGPUMonitor(unittest.TestCase):
+ def setUp(self):
+ self.subprocess_patcher = patch('subprocess.check_output')
+ self.mock_subprocess = self.subprocess_patcher.start()
+
+ def tearDown(self):
+ self.subprocess_patcher.stop()
+
+ def test_init_hardware_detected(self):
+ """Test that init correctly reads total memory when nvidia-smi exists."""
+ self.mock_subprocess.return_value = "24576"
+ monitor = GPUMonitor()
+ monitor.start()
+ self.assertTrue(monitor._gpu_available)
+ self.assertEqual(monitor._total_memory, 24576.0)
+
+ def test_init_hardware_missing(self):
+ """Test fallback behavior when nvidia-smi is missing."""
+ self.mock_subprocess.side_effect = FileNotFoundError()
+ monitor = GPUMonitor(fallback_memory_mb=12000.0)
+ monitor.start()
+ self.assertFalse(monitor._gpu_available)
+ self.assertEqual(monitor._total_memory, 12000.0)
+
+ @patch('time.sleep')
+ def test_polling_updates_stats(self, mock_sleep):
+ """Test that the polling loop updates current and peak usage."""
+ def subprocess_side_effect(*args, **kwargs):
+ if isinstance(args[0], list) and "memory.total" in args[0][1]:
+ return "16000"
+
+ if isinstance(args[0], list) and any("memory.free" in part
+ for part in args[0]):
+ return "12000"
+
+ raise Exception("Unexpected command")
+
+ self.mock_subprocess.side_effect = subprocess_side_effect
+ self.mock_subprocess.return_value = None
+
+ monitor = GPUMonitor()
+ monitor.start()
+ time.sleep(0.1)
+ curr, peak, total = monitor.get_stats()
+ monitor.stop()
+
+ self.assertEqual(curr, 4000.0)
+ self.assertEqual(peak, 4000.0)
+ self.assertEqual(total, 16000.0)
+
+ def test_reset_peak(self):
+ """Test that resetting peak usage works."""
+ monitor = GPUMonitor()
+ monitor._gpu_available = True
+
+ with monitor._lock:
+ monitor._current_usage = 2000.0
+ monitor._peak_usage = 8000.0
+ monitor._memory_history.append((time.time(), 8000.0))
+ monitor._memory_history.append((time.time(), 2000.0))
+
+ monitor.reset_peak()
+
+ _, peak, _ = monitor.get_stats()
+ self.assertEqual(peak, 2000.0)
+
+
+class TestResourceEstimatorSolver(unittest.TestCase):
+ def setUp(self):
+ self.estimator = ResourceEstimator()
+
+ @patch('apache_beam.ml.inference.model_manager.nnls')
+ def test_solver_respects_min_data_points(self, mock_nnls):
+ mock_nnls.return_value = ([100.0, 50.0], 0.0)
+
+ self.estimator.add_observation({'model_A': 1}, 500)
+ self.estimator.add_observation({'model_B': 1}, 500)
+ self.assertFalse(mock_nnls.called)
+
+ self.estimator.add_observation({'model_A': 1, 'model_B': 1}, 1000)
+ self.assertFalse(mock_nnls.called)
+
+ self.estimator.add_observation({'model_A': 1}, 500)
+ self.assertFalse(mock_nnls.called)
+
+ self.estimator.add_observation({'model_B': 1}, 500)
+ self.assertTrue(mock_nnls.called)
+
+ @patch('apache_beam.ml.inference.model_manager.nnls')
+ def test_solver_respects_unique_model_constraint(self, mock_nnls):
+ mock_nnls.return_value = ([100.0, 100.0, 50.0], 0.0)
+
+ for _ in range(5):
+ self.estimator.add_observation({'model_A': 1, 'model_B': 1}, 800)
+
+ for _ in range(5):
+ self.estimator.add_observation({'model_C': 1}, 400)
+
+ self.assertFalse(mock_nnls.called)
+
+ self.estimator.add_observation({'model_A': 1}, 300)
+ self.estimator.add_observation({'model_B': 1}, 300)
+
+ self.assertTrue(mock_nnls.called)
+
+
+if __name__ == "__main__":
+ unittest.main()