gemini-code-assist[bot] commented on code in PR #37113:
URL: https://github.com/apache/beam/pull/37113#discussion_r2747477968


##########
sdks/python/apache_beam/ml/inference/model_manager.py:
##########
@@ -0,0 +1,730 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module for managing ML models in Apache Beam pipelines.
+
+This module provides classes and functions to efficiently manage multiple
+machine learning models within Apache Beam pipelines. It includes functionality
+for loading, caching, and updating models using multi-process shared memory,
+ensuring that models are reused across different workers to optimize resource
+usage and performance.
+"""
+
+import gc
+import heapq
+import itertools
+import logging
+import subprocess
+import threading
+import time
+from collections import Counter
+from collections import OrderedDict
+from collections import defaultdict
+from collections import deque
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from scipy.optimize import nnls
+
+from apache_beam.utils import multi_process_shared
+
+logger = logging.getLogger(__name__)
+
+
+class GPUMonitor:
+  """Monitors GPU memory usage in a separate thread using nvidia-smi.
+
+  This class continuously polls GPU memory statistics to track current usage
+  and peak usage over a sliding time window. It serves as the source of truth
+  for the ModelManager's resource decisions.
+
+  Attributes:
+    fallback_memory_mb: Default total memory if hardware detection fails.
+    poll_interval: Seconds between memory checks.
+    peak_window_seconds: Duration to track peak memory usage.
+  """
+  def __init__(
+      self,
+      fallback_memory_mb: float = 16000.0,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0):
+    self._current_usage = 0.0
+    self._peak_usage = 0.0
+    self._total_memory = fallback_memory_mb
+    self._poll_interval = poll_interval
+    self._peak_window_seconds = peak_window_seconds
+    self._memory_history = deque()
+    self._running = False
+    self._thread = None
+    self._lock = threading.Lock()
+
+  def _detect_hardware(self):
+    try:
+      cmd = [
+          "nvidia-smi",
+          "--query-gpu=memory.total",
+          "--format=csv,noheader,nounits"
+      ]
+      output = subprocess.check_output(cmd, text=True).strip()
+      self._total_memory = float(output)
+      return True
+    except (FileNotFoundError, subprocess.CalledProcessError):
+      logger.warning(
+          "nvidia-smi not found or failed. Defaulting total memory to %s MB",
+          self._total_memory)
+      return False
+    except Exception as e:
+      logger.warning(
+          "Error parsing nvidia-smi output: %s. "
+          "Defaulting total memory to %s MB",
+          e,
+          self._total_memory)
+      return False
+
+  def start(self):
+    self._gpu_available = self._detect_hardware()
+    if self._running or not self._gpu_available:
+      return
+    self._running = True
+    self._thread = threading.Thread(target=self._poll_loop, daemon=True)
+    self._thread.start()
+
+  def stop(self):
+    self._running = False
+    if self._thread:
+      self._thread.join()
+
+  def reset_peak(self):
+    with self._lock:
+      now = time.time()
+      self._memory_history.clear()
+      self._memory_history.append((now, self._current_usage))
+      self._peak_usage = self._current_usage
+
+  def get_stats(self) -> Tuple[float, float, float]:
+    with self._lock:
+      return self._current_usage, self._peak_usage, self._total_memory
+
+  def refresh(self):
+    """Forces an immediate poll of the GPU."""
+    usage = self._get_nvidia_smi_used()
+    now = time.time()
+    with self._lock:
+      self._current_usage = usage
+      self._memory_history.append((now, usage))
+      # Recalculate peak immediately
+      while self._memory_history and (now - self._memory_history[0][0]
+                                      > self._peak_window_seconds):
+        self._memory_history.popleft()
+      self._peak_usage = (
+          max(m for _, m in self._memory_history)
+          if self._memory_history else usage)
+
+  def _get_nvidia_smi_used(self) -> float:
+    try:
+      cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits"
+      output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip()

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   Using `shell=True` with `subprocess.check_output` can be a security risk and
is generally discouraged. The command string here is a constant, but it still
spawns an extra shell on every poll. It is also inconsistent with
`_detect_hardware`, which passes command arguments as a list. For better
security and consistency, please pass the arguments as a list and avoid
`shell=True`.
   
   ```suggestion
         cmd = [
             "nvidia-smi",
             "--query-gpu=memory.free",
             "--format=csv,noheader,nounits"
         ]
         output = subprocess.check_output(cmd, text=True).strip()
   ```



##########
sdks/python/apache_beam/ml/inference/model_manager.py:
##########
@@ -0,0 +1,730 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module for managing ML models in Apache Beam pipelines.
+
+This module provides classes and functions to efficiently manage multiple
+machine learning models within Apache Beam pipelines. It includes functionality
+for loading, caching, and updating models using multi-process shared memory,
+ensuring that models are reused across different workers to optimize resource
+usage and performance.
+"""
+
+import gc
+import heapq
+import itertools
+import logging
+import subprocess
+import threading
+import time
+from collections import Counter
+from collections import OrderedDict
+from collections import defaultdict
+from collections import deque
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from scipy.optimize import nnls
+
+from apache_beam.utils import multi_process_shared
+
+logger = logging.getLogger(__name__)
+
+
+class GPUMonitor:
+  """Monitors GPU memory usage in a separate thread using nvidia-smi.
+
+  This class continuously polls GPU memory statistics to track current usage
+  and peak usage over a sliding time window. It serves as the source of truth
+  for the ModelManager's resource decisions.
+
+  Attributes:
+    fallback_memory_mb: Default total memory if hardware detection fails.
+    poll_interval: Seconds between memory checks.
+    peak_window_seconds: Duration to track peak memory usage.
+  """
+  def __init__(
+      self,
+      fallback_memory_mb: float = 16000.0,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0):
+    self._current_usage = 0.0
+    self._peak_usage = 0.0
+    self._total_memory = fallback_memory_mb
+    self._poll_interval = poll_interval
+    self._peak_window_seconds = peak_window_seconds
+    self._memory_history = deque()
+    self._running = False
+    self._thread = None
+    self._lock = threading.Lock()
+
+  def _detect_hardware(self):
+    try:
+      cmd = [
+          "nvidia-smi",
+          "--query-gpu=memory.total",
+          "--format=csv,noheader,nounits"
+      ]
+      output = subprocess.check_output(cmd, text=True).strip()
+      self._total_memory = float(output)
+      return True
+    except (FileNotFoundError, subprocess.CalledProcessError):
+      logger.warning(
+          "nvidia-smi not found or failed. Defaulting total memory to %s MB",
+          self._total_memory)
+      return False
+    except Exception as e:
+      logger.warning(
+          "Error parsing nvidia-smi output: %s. "
+          "Defaulting total memory to %s MB",
+          e,
+          self._total_memory)
+      return False
+
+  def start(self):
+    self._gpu_available = self._detect_hardware()
+    if self._running or not self._gpu_available:
+      return
+    self._running = True
+    self._thread = threading.Thread(target=self._poll_loop, daemon=True)
+    self._thread.start()
+
+  def stop(self):
+    self._running = False
+    if self._thread:
+      self._thread.join()
+
+  def reset_peak(self):
+    with self._lock:
+      now = time.time()
+      self._memory_history.clear()
+      self._memory_history.append((now, self._current_usage))
+      self._peak_usage = self._current_usage
+
+  def get_stats(self) -> Tuple[float, float, float]:
+    with self._lock:
+      return self._current_usage, self._peak_usage, self._total_memory
+
+  def refresh(self):
+    """Forces an immediate poll of the GPU."""
+    usage = self._get_nvidia_smi_used()
+    now = time.time()
+    with self._lock:
+      self._current_usage = usage
+      self._memory_history.append((now, usage))
+      # Recalculate peak immediately
+      while self._memory_history and (now - self._memory_history[0][0]
+                                      > self._peak_window_seconds):
+        self._memory_history.popleft()
+      self._peak_usage = (
+          max(m for _, m in self._memory_history)
+          if self._memory_history else usage)
+
+  def _get_nvidia_smi_used(self) -> float:
+    try:
+      cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits"
+      output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip()
+      free_memory = float(output)
+      return self._total_memory - free_memory
+    except Exception:
+      return 0.0
+
+  def _poll_loop(self):
+    while self._running:
+      usage = self._get_nvidia_smi_used()
+      now = time.time()
+      with self._lock:
+        self._current_usage = usage
+        self._memory_history.append((now, usage))
+        while self._memory_history and (now - self._memory_history[0][0]
+                                        > self._peak_window_seconds):
+          self._memory_history.popleft()
+        self._peak_usage = (
+            max(m for _, m in self._memory_history)
+            if self._memory_history else usage)
+      time.sleep(self._poll_interval)
+
+
+class ResourceEstimator:
+  """Estimates individual model memory usage using statistical observation.
+
+  Uses Non-Negative Least Squares (NNLS) to deduce the memory footprint of
+  individual models based on aggregate system memory readings and the
+  configuration of active models at that time.
+  """
+  def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5):
+    self.smoothing_factor = smoothing_factor
+    self.min_data_points = min_data_points
+    self.estimates: Dict[str, float] = {}
+    self.history = defaultdict(lambda: deque(maxlen=20))
+    self.known_models = set()
+    self._lock = threading.Lock()
+
+  def is_unknown(self, model_tag: str) -> bool:
+    with self._lock:
+      return model_tag not in self.estimates
+
+  def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float:
+    with self._lock:
+      return self.estimates.get(model_tag, default_mb)
+
+  def set_initial_estimate(self, model_tag: str, cost: float):
+    with self._lock:
+      self.estimates[model_tag] = cost
+      self.known_models.add(model_tag)
+      logger.info("Initial Profile for %s: %s MB", model_tag, cost)
+
+  def add_observation(
+      self, active_snapshot: Dict[str, int], peak_memory: float):
+    if active_snapshot:
+      model_list = "\n".join(
+          f"\t- {model}: {count}"
+          for model, count in sorted(active_snapshot.items()))
+    else:
+      model_list = "\t- None"
+
+    logger.info(
+        "Adding Observation:\n PeakMemory: %.1f MB\n  Instances:\n%s",
+        peak_memory,
+        model_list)
+    if not active_snapshot:
+      return
+    with self._lock:
+      config_key = tuple(sorted(active_snapshot.items()))
+      self.history[config_key].append(peak_memory)
+      for tag in active_snapshot:
+        self.known_models.add(tag)
+      self._solve()
+
+  def _solve(self):
+    """
+    Solves Ax=b using raw readings (no pre-averaging) and NNLS.
+    This creates a 'tall' matrix A where every memory reading is
+    a separate equation.
+    """
+    unique = sorted(list(self.known_models))
+
+    # We need to build the matrix first to know if we have enough data points
+    A, b = [], []
+
+    for config_key, mem_values in self.history.items():
+      if not mem_values:
+        continue
+
+      # 1. Create the feature row for this configuration ONCE
+      # (It represents the model counts + bias)
+      counts = dict(config_key)
+      feature_row = [counts.get(model, 0) for model in unique]
+      feature_row.append(1)  # Bias column
+
+      # 2. Add a separate row to the matrix for EVERY individual reading
+      # Instead of averaging, we flatten the history into the matrix
+      for reading in mem_values:
+        A.append(feature_row)  # The inputs (models) stay the same
+        b.append(reading)  # The output (memory) varies due to noise
+
+    # Convert to numpy for SciPy
+    A = np.array(A)
+    b = np.array(b)
+
+    if len(
+        self.history.keys()) < len(unique) + 1 or len(A) < self.min_data_points:
+      # Not enough data to solve yet
+      return
+
+    logger.info(
+        "Solving with %s total observations for %s models.",
+        len(A),
+        len(unique))
+
+    try:
+      # Solve using Non-Negative Least Squares
+      # x will be >= 0
+      x, _ = nnls(A, b)
+
+      weights = x[:-1]
+      bias = x[-1]
+
+      for i, model in enumerate(unique):
+        calculated_cost = weights[i]
+
+        if model in self.estimates:
+          old = self.estimates[model]
+          new = (old * (1 - self.smoothing_factor)) + (
+              calculated_cost * self.smoothing_factor)
+          self.estimates[model] = new
+        else:
+          self.estimates[model] = calculated_cost
+
+        logger.info(
+            "Updated Estimate for %s: %.1f MB", model, self.estimates[model])
+      logger.info("System Bias: %s MB", bias)
+
+    except Exception as e:
+      logger.error("Solver failed: %s", e)
+
+
+class ModelManager:
+  """Manages model lifecycles, caching, and resource arbitration.
+
+  This class acts as the central controller for acquiring model instances.
+  It handles:
+  1. LRU Caching of idle models.
+  2. Resource estimation and admission control (preventing OOM).
+  3. Dynamic eviction of low-priority models, determined by count of 
+    pending requests, when space is needed.
+  4. 'Isolation Mode' for safely profiling unknown models.
+  """
+  _lock = threading.Lock()

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The class attribute `_lock = threading.Lock()` appears to be unused. The 
instance synchronization is handled by `self._cv`, which has its own internal 
lock. This unused attribute can be confusing and should be removed for code 
clarity.



##########
sdks/python/apache_beam/ml/inference/model_manager.py:
##########
@@ -0,0 +1,730 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module for managing ML models in Apache Beam pipelines.
+
+This module provides classes and functions to efficiently manage multiple
+machine learning models within Apache Beam pipelines. It includes functionality
+for loading, caching, and updating models using multi-process shared memory,
+ensuring that models are reused across different workers to optimize resource
+usage and performance.
+"""
+
+import gc
+import heapq
+import itertools
+import logging
+import subprocess
+import threading
+import time
+from collections import Counter
+from collections import OrderedDict
+from collections import defaultdict
+from collections import deque
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from scipy.optimize import nnls
+
+from apache_beam.utils import multi_process_shared
+
+logger = logging.getLogger(__name__)
+
+
+class GPUMonitor:
+  """Monitors GPU memory usage in a separate thread using nvidia-smi.
+
+  This class continuously polls GPU memory statistics to track current usage
+  and peak usage over a sliding time window. It serves as the source of truth
+  for the ModelManager's resource decisions.
+
+  Attributes:
+    fallback_memory_mb: Default total memory if hardware detection fails.
+    poll_interval: Seconds between memory checks.
+    peak_window_seconds: Duration to track peak memory usage.
+  """
+  def __init__(
+      self,
+      fallback_memory_mb: float = 16000.0,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0):
+    self._current_usage = 0.0
+    self._peak_usage = 0.0
+    self._total_memory = fallback_memory_mb
+    self._poll_interval = poll_interval
+    self._peak_window_seconds = peak_window_seconds
+    self._memory_history = deque()
+    self._running = False
+    self._thread = None
+    self._lock = threading.Lock()
+
+  def _detect_hardware(self):
+    try:
+      cmd = [
+          "nvidia-smi",
+          "--query-gpu=memory.total",
+          "--format=csv,noheader,nounits"
+      ]
+      output = subprocess.check_output(cmd, text=True).strip()
+      self._total_memory = float(output)
+      return True
+    except (FileNotFoundError, subprocess.CalledProcessError):
+      logger.warning(
+          "nvidia-smi not found or failed. Defaulting total memory to %s MB",
+          self._total_memory)
+      return False
+    except Exception as e:
+      logger.warning(
+          "Error parsing nvidia-smi output: %s. "
+          "Defaulting total memory to %s MB",
+          e,
+          self._total_memory)
+      return False
+
+  def start(self):
+    self._gpu_available = self._detect_hardware()
+    if self._running or not self._gpu_available:
+      return
+    self._running = True
+    self._thread = threading.Thread(target=self._poll_loop, daemon=True)
+    self._thread.start()
+
+  def stop(self):
+    self._running = False
+    if self._thread:
+      self._thread.join()
+
+  def reset_peak(self):
+    with self._lock:
+      now = time.time()
+      self._memory_history.clear()
+      self._memory_history.append((now, self._current_usage))
+      self._peak_usage = self._current_usage
+
+  def get_stats(self) -> Tuple[float, float, float]:
+    with self._lock:
+      return self._current_usage, self._peak_usage, self._total_memory
+
+  def refresh(self):
+    """Forces an immediate poll of the GPU."""
+    usage = self._get_nvidia_smi_used()
+    now = time.time()
+    with self._lock:
+      self._current_usage = usage
+      self._memory_history.append((now, usage))
+      # Recalculate peak immediately
+      while self._memory_history and (now - self._memory_history[0][0]
+                                      > self._peak_window_seconds):
+        self._memory_history.popleft()
+      self._peak_usage = (
+          max(m for _, m in self._memory_history)
+          if self._memory_history else usage)
+
+  def _get_nvidia_smi_used(self) -> float:
+    try:
+      cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits"
+      output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip()
+      free_memory = float(output)
+      return self._total_memory - free_memory
+    except Exception:
+      return 0.0

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
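   Catching a broad `Exception` and returning a default value without logging
can hide important errors and make debugging difficult. This matters here
because a failed poll reports 0 MB in use. The monitor is the source of truth
for the ModelManager's admission decisions, so a silent failure makes the GPU
look empty. It's better to log the exception to provide visibility into
potential issues with `nvidia-smi` or parsing its output. Since `_poll_loop`
calls this every `poll_interval` seconds, consider rate-limiting the warning.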
   
   ```suggestion
       except Exception as e:
         logger.warning('Failed to get GPU memory usage: %s', e)
         return 0.0
   ```



##########
sdks/python/apache_beam/ml/inference/model_manager_test.py:
##########
@@ -0,0 +1,596 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import random
+import threading
+import time
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import patch
+
+from apache_beam.utils import multi_process_shared
+
+try:
+  from apache_beam.ml.inference.model_manager import GPUMonitor
+  from apache_beam.ml.inference.model_manager import ModelManager
+  from apache_beam.ml.inference.model_manager import ResourceEstimator
+except ImportError as e:
+  raise unittest.SkipTest("Model Manager dependencies are not installed")
+
+
+class MockGPUMonitor:
+  """
+  Simulates GPU hardware with cumulative memory tracking.
+  Allows simulating specific allocation spikes and baseline usage.
+  """
+  def __init__(self, total_memory=12000.0, peak_window: int = 5):
+    self._current = 0.0
+    self._peak = 0.0
+    self._total = total_memory
+    self._lock = threading.Lock()
+    self.running = False
+    self.history = []
+    self.peak_window = peak_window
+
+  def start(self):
+    self.running = True
+
+  def stop(self):
+    self.running = False
+
+  def get_stats(self):
+    with self._lock:
+      return self._current, self._peak, self._total
+
+  def reset_peak(self):
+    with self._lock:
+      self._peak = self._current
+      self.history = [self._current]
+
+  def set_usage(self, current_mb):
+    """Sets absolute usage (legacy helper)."""
+    with self._lock:
+      self._current = current_mb
+      self._peak = max(self._peak, current_mb)
+
+  def allocate(self, amount_mb):
+    """Simulates memory allocation (e.g., tensors loaded to VRAM)."""
+    with self._lock:
+      self._current += amount_mb
+      self.history.append(self._current)
+      if len(self.history) > self.peak_window:
+        self.history.pop(0)
+      self._peak = max(self.history)
+
+  def free(self, amount_mb):
+    """Simulates memory freeing (not used often if pooling is active)."""
+    with self._lock:
+      self._current = max(0.0, self._current - amount_mb)
+      self.history.append(self._current)
+      if len(self.history) > self.peak_window:
+        self.history.pop(0)
+      self._peak = max(self.history)
+
+  def refresh(self):
+    """Simulates a refresh of the monitor stats (no-op for mock)."""
+    pass
+
+
+class MockModel:
+  def __init__(self, name, size, monitor):
+    self.name = name
+    self.size = size
+    self.monitor = monitor
+    self.deleted = False
+    self.monitor.allocate(size)
+
+  def mock_model_unsafe_hard_delete(self):
+    if not self.deleted:
+      self.monitor.free(self.size)
+      self.deleted = True
+
+
+class Counter(object):
+  def __init__(self, start=0):
+    self.running = start
+    self.lock = threading.Lock()
+
+  def get(self):
+    return self.running
+
+  def increment(self, value=1):
+    with self.lock:
+      self.running += value
+      return self.running
+
+
+class TestModelManager(unittest.TestCase):
+  def setUp(self):
+    """Force reset the Singleton ModelManager before every test."""
+    ModelManager._instance = None
+    self.mock_monitor = MockGPUMonitor()
+    self.manager = ModelManager(monitor=self.mock_monitor)
+
+  def tearDown(self):
+    self.manager.shutdown()
+
+  def test_model_manager_deletes_multiprocessshared_instances(self):
+    """Test that MultiProcessShared instances are deleted properly."""
+    model_name = "test_model_shared"
+    tag = f"model_manager_test_{model_name}"
+
+    def loader():
+      multi_process_shared.MultiProcessShared(
+          lambda: Counter, tag=tag, always_proxy=True)
+      return tag
+
+    instance = self.manager.acquire_model(model_name, loader)
+    instance_before = multi_process_shared.MultiProcessShared(
+        Counter, tag=tag, always_proxy=True).acquire()
+    instance_before.increment()
+    self.assertEqual(instance_before.get(), 1)
+    self.manager.release_model(model_name, instance)
+
+    # Force delete all models
+    self.manager._force_reset()
+
+    # Verify that the MultiProcessShared instance is deleted
+    # and the counter is reseted
+    with self.assertRaises(Exception):
+      instance_before.get()
+    instance_after = multi_process_shared.MultiProcessShared(
+        Counter, tag=tag, always_proxy=True).acquire()
+    self.assertEqual(instance_after.get(), 0)
+
+  def test_model_manager_capacity_check(self):
+    """
+    Test that the manager blocks when spawning models exceeds the limit,
+    and unblocks when resources become available (via reuse).
+    """
+    model_name = "known_model"
+    model_cost = 3000.0
+    self.manager._estimator.set_initial_estimate(model_name, model_cost)
+    acquired_refs = []
+
+    def loader():
+      self.mock_monitor.allocate(model_cost)
+      return model_name
+
+    # 1. Saturate GPU with 3 models (9000 MB usage)
+    for _ in range(3):
+      inst = self.manager.acquire_model(model_name, loader)
+      acquired_refs.append(inst)
+
+    # 2. Spawn one more (Should Block because 9000 + 3000 > Limit)
+    def run_inference():
+      return self.manager.acquire_model(model_name, loader)
+
+    with ThreadPoolExecutor(max_workers=1) as executor:
+      future = executor.submit(run_inference)
+      try:
+        future.result(timeout=0.5)
+        self.fail("Should have blocked due to capacity")
+      except TimeoutError:
+        pass
+
+      # 3. Release resources to unblock
+      item_to_release = acquired_refs.pop()
+      self.manager.release_model(model_name, item_to_release)
+
+      result = future.result(timeout=2.0)
+      self.assertIsNotNone(result)
+      self.assertEqual(result, item_to_release)
+
+  def test_model_manager_unknown_model_runs_isolated(self):
+    """Test that a model with no history runs in isolation."""
+    model_name = "unknown_model_v1"
+    self.assertTrue(self.manager._estimator.is_unknown(model_name))
+
+    def dummy_loader():
+      time.sleep(0.05)
+      return "model_instance"
+
+    instance = self.manager.acquire_model(model_name, dummy_loader)
+
+    self.assertTrue(self.manager._isolation_mode)
+    self.assertEqual(self.manager._total_active_jobs, 1)
+
+    self.manager.release_model(model_name, instance)
+    self.assertFalse(self.manager._isolation_mode)
+    self.assertFalse(self.manager._estimator.is_unknown(model_name))
+
+  def test_model_manager_concurrent_execution(self):
+    """Test that multiple small known models can run together."""
+    model_a = "small_model_a"
+    model_b = "small_model_b"
+
+    self.manager._estimator.set_initial_estimate(model_a, 1000.0)
+    self.manager._estimator.set_initial_estimate(model_b, 1000.0)
+    self.mock_monitor.set_usage(1000.0)
+
+    inst_a = self.manager.acquire_model(model_a, lambda: "A")
+    inst_b = self.manager.acquire_model(model_b, lambda: "B")
+
+    self.assertEqual(self.manager._total_active_jobs, 2)
+
+    self.manager.release_model(model_a, inst_a)
+    self.manager.release_model(model_b, inst_b)
+    self.assertEqual(self.manager._total_active_jobs, 0)
+
+  def test_model_manager_concurrent_mixed_workload_convergence(self):
+    """
+    Simulates a production environment with multiple model types running
+    concurrently. Verifies that the estimator converges.
+    """
+    TRUE_COSTS = {"model_small": 1500.0, "model_medium": 3000.0}
+
+    def run_job(model_name):
+      cost = TRUE_COSTS[model_name]
+
+      def loader():
+        model = MockModel(model_name, cost, self.mock_monitor)
+        return model
+
+      instance = self.manager.acquire_model(model_name, loader)
+      time.sleep(random.uniform(0.01, 0.05))
+      self.manager.release_model(model_name, instance)
+
+    # Create a workload stream
+    workload = ["model_small"] * 15 + ["model_medium"] * 15
+    random.shuffle(workload)
+
+    with ThreadPoolExecutor(max_workers=8) as executor:
+      futures = [executor.submit(run_job, name) for name in workload]
+      for f in futures:
+        f.result()
+
+    est_small = self.manager._estimator.get_estimate("model_small")
+    est_med = self.manager._estimator.get_estimate("model_medium")
+
+    self.assertAlmostEqual(est_small, TRUE_COSTS["model_small"], delta=100.0)
+    self.assertAlmostEqual(est_med, TRUE_COSTS["model_medium"], delta=100.0)
+
+  def test_model_manager_oom_recovery(self):
+    """Test that the manager recovers state if a loader crashes."""
+    model_name = "crasher_model"
+    self.manager._estimator.set_initial_estimate(model_name, 1000.0)
+
+    def crashing_loader():
+      raise RuntimeError("CUDA OOM or similar")
+
+    with self.assertRaises(RuntimeError):
+      self.manager.acquire_model(model_name, crashing_loader)
+
+    self.assertEqual(self.manager._total_active_jobs, 0)
+    self.assertEqual(self.manager._pending_reservations, 0.0)
+    self.assertFalse(self.manager._cv._is_owned())
+
+  def test_model_managaer_force_reset_on_exception(self):

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   There is a typo in the test method name. It should be 
`test_model_manager_...` instead of `test_model_managaer_...`.
   
   ```suggestion
     def test_model_manager_force_reset_on_exception(self):
   ```



##########
sdks/python/apache_beam/ml/inference/model_manager_test.py:
##########
@@ -0,0 +1,596 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import random
+import threading
+import time
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import patch
+
+from apache_beam.utils import multi_process_shared
+
+try:
+  from apache_beam.ml.inference.model_manager import GPUMonitor
+  from apache_beam.ml.inference.model_manager import ModelManager
+  from apache_beam.ml.inference.model_manager import ResourceEstimator
+except ImportError as e:
+  raise unittest.SkipTest("Model Manager dependencies are not installed")
+
+
+class MockGPUMonitor:
+  """
+  Simulates GPU hardware with cumulative memory tracking.
+  Allows simulating specific allocation spikes and baseline usage.
+  """
+  def __init__(self, total_memory=12000.0, peak_window: int = 5):
+    self._current = 0.0
+    self._peak = 0.0
+    self._total = total_memory
+    self._lock = threading.Lock()
+    self.running = False
+    self.history = []
+    self.peak_window = peak_window
+
+  def start(self):
+    self.running = True
+
+  def stop(self):
+    self.running = False
+
+  def get_stats(self):
+    with self._lock:
+      return self._current, self._peak, self._total
+
+  def reset_peak(self):
+    with self._lock:
+      self._peak = self._current
+      self.history = [self._current]
+
+  def set_usage(self, current_mb):
+    """Sets absolute usage (legacy helper)."""
+    with self._lock:
+      self._current = current_mb
+      self._peak = max(self._peak, current_mb)
+
+  def allocate(self, amount_mb):
+    """Simulates memory allocation (e.g., tensors loaded to VRAM)."""
+    with self._lock:
+      self._current += amount_mb
+      self.history.append(self._current)
+      if len(self.history) > self.peak_window:
+        self.history.pop(0)
+      self._peak = max(self.history)
+
+  def free(self, amount_mb):
+    """Simulates memory freeing (not used often if pooling is active)."""
+    with self._lock:
+      self._current = max(0.0, self._current - amount_mb)
+      self.history.append(self._current)
+      if len(self.history) > self.peak_window:
+        self.history.pop(0)
+      self._peak = max(self.history)
+
+  def refresh(self):
+    """Simulates a refresh of the monitor stats (no-op for mock)."""
+    pass
+
+
+class MockModel:
+  def __init__(self, name, size, monitor):
+    self.name = name
+    self.size = size
+    self.monitor = monitor
+    self.deleted = False
+    self.monitor.allocate(size)
+
+  def mock_model_unsafe_hard_delete(self):
+    if not self.deleted:
+      self.monitor.free(self.size)
+      self.deleted = True
+
+
+class Counter(object):
+  def __init__(self, start=0):
+    self.running = start
+    self.lock = threading.Lock()
+
+  def get(self):
+    return self.running
+
+  def increment(self, value=1):
+    with self.lock:
+      self.running += value
+      return self.running
+
+
+class TestModelManager(unittest.TestCase):
+  def setUp(self):
+    """Force reset the Singleton ModelManager before every test."""
+    ModelManager._instance = None

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `ModelManager` class as implemented is not a singleton, so assigning
   `ModelManager._instance = None` has no effect on how instances are created;
   it only sets a class attribute. This is confusing and seems unnecessary, as
   each test creates its own `ModelManager` instance. This line should probably
   be removed to avoid confusion about the design of `ModelManager`. The same
   applies to `TestModelManagerEviction.setUp` on line 345.



##########
sdks/python/apache_beam/ml/inference/model_manager.py:
##########
@@ -0,0 +1,730 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module for managing ML models in Apache Beam pipelines.
+
+This module provides classes and functions to efficiently manage multiple
+machine learning models within Apache Beam pipelines. It includes functionality
+for loading, caching, and updating models using multi-process shared memory,
+ensuring that models are reused across different workers to optimize resource
+usage and performance.
+"""
+
+import gc
+import heapq
+import itertools
+import logging
+import subprocess
+import threading
+import time
+from collections import Counter
+from collections import OrderedDict
+from collections import defaultdict
+from collections import deque
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from scipy.optimize import nnls
+
+from apache_beam.utils import multi_process_shared
+
+logger = logging.getLogger(__name__)
+
+
+class GPUMonitor:
+  """Monitors GPU memory usage in a separate thread using nvidia-smi.
+
+  This class continuously polls GPU memory statistics to track current usage
+  and peak usage over a sliding time window. It serves as the source of truth
+  for the ModelManager's resource decisions.
+
+  Attributes:
+    fallback_memory_mb: Default total memory if hardware detection fails.
+    poll_interval: Seconds between memory checks.
+    peak_window_seconds: Duration to track peak memory usage.
+  """
+  def __init__(
+      self,
+      fallback_memory_mb: float = 16000.0,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0):
+    self._current_usage = 0.0
+    self._peak_usage = 0.0
+    self._total_memory = fallback_memory_mb
+    self._poll_interval = poll_interval
+    self._peak_window_seconds = peak_window_seconds
+    self._memory_history = deque()
+    self._running = False
+    self._thread = None
+    self._lock = threading.Lock()
+
+  def _detect_hardware(self):
+    try:
+      cmd = [
+          "nvidia-smi",
+          "--query-gpu=memory.total",
+          "--format=csv,noheader,nounits"
+      ]
+      output = subprocess.check_output(cmd, text=True).strip()
+      self._total_memory = float(output)
+      return True
+    except (FileNotFoundError, subprocess.CalledProcessError):
+      logger.warning(
+          "nvidia-smi not found or failed. Defaulting total memory to %s MB",
+          self._total_memory)
+      return False
+    except Exception as e:
+      logger.warning(
+          "Error parsing nvidia-smi output: %s. "
+          "Defaulting total memory to %s MB",
+          e,
+          self._total_memory)
+      return False
+
+  def start(self):
+    self._gpu_available = self._detect_hardware()
+    if self._running or not self._gpu_available:
+      return
+    self._running = True
+    self._thread = threading.Thread(target=self._poll_loop, daemon=True)
+    self._thread.start()
+
+  def stop(self):
+    self._running = False
+    if self._thread:
+      self._thread.join()
+
+  def reset_peak(self):
+    with self._lock:
+      now = time.time()
+      self._memory_history.clear()
+      self._memory_history.append((now, self._current_usage))
+      self._peak_usage = self._current_usage
+
+  def get_stats(self) -> Tuple[float, float, float]:
+    with self._lock:
+      return self._current_usage, self._peak_usage, self._total_memory
+
+  def refresh(self):
+    """Forces an immediate poll of the GPU."""
+    usage = self._get_nvidia_smi_used()
+    now = time.time()
+    with self._lock:
+      self._current_usage = usage
+      self._memory_history.append((now, usage))
+      # Recalculate peak immediately
+      while self._memory_history and (now - self._memory_history[0][0]
+                                      > self._peak_window_seconds):
+        self._memory_history.popleft()
+      self._peak_usage = (
+          max(m for _, m in self._memory_history)
+          if self._memory_history else usage)
+
+  def _get_nvidia_smi_used(self) -> float:
+    try:
+      cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits"
+      output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip()
+      free_memory = float(output)
+      return self._total_memory - free_memory
+    except Exception:
+      return 0.0
+
+  def _poll_loop(self):
+    while self._running:
+      usage = self._get_nvidia_smi_used()
+      now = time.time()
+      with self._lock:
+        self._current_usage = usage
+        self._memory_history.append((now, usage))
+        while self._memory_history and (now - self._memory_history[0][0]
+                                        > self._peak_window_seconds):
+          self._memory_history.popleft()
+        self._peak_usage = (
+            max(m for _, m in self._memory_history)
+            if self._memory_history else usage)
+      time.sleep(self._poll_interval)
+
+
+class ResourceEstimator:
+  """Estimates individual model memory usage using statistical observation.
+
+  Uses Non-Negative Least Squares (NNLS) to deduce the memory footprint of
+  individual models based on aggregate system memory readings and the
+  configuration of active models at that time.
+  """
+  def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5):
+    self.smoothing_factor = smoothing_factor
+    self.min_data_points = min_data_points
+    self.estimates: Dict[str, float] = {}
+    self.history = defaultdict(lambda: deque(maxlen=20))
+    self.known_models = set()
+    self._lock = threading.Lock()
+
+  def is_unknown(self, model_tag: str) -> bool:
+    with self._lock:
+      return model_tag not in self.estimates
+
+  def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float:
+    with self._lock:
+      return self.estimates.get(model_tag, default_mb)
+
+  def set_initial_estimate(self, model_tag: str, cost: float):
+    with self._lock:
+      self.estimates[model_tag] = cost
+      self.known_models.add(model_tag)
+      logger.info("Initial Profile for %s: %s MB", model_tag, cost)
+
+  def add_observation(
+      self, active_snapshot: Dict[str, int], peak_memory: float):
+    if active_snapshot:
+      model_list = "\n".join(
+          f"\t- {model}: {count}"
+          for model, count in sorted(active_snapshot.items()))
+    else:
+      model_list = "\t- None"
+
+    logger.info(
+        "Adding Observation:\n PeakMemory: %.1f MB\n  Instances:\n%s",
+        peak_memory,
+        model_list)
+    if not active_snapshot:
+      return
+    with self._lock:
+      config_key = tuple(sorted(active_snapshot.items()))
+      self.history[config_key].append(peak_memory)
+      for tag in active_snapshot:
+        self.known_models.add(tag)
+      self._solve()
+
+  def _solve(self):
+    """
+    Solves Ax=b using raw readings (no pre-averaging) and NNLS.
+    This creates a 'tall' matrix A where every memory reading is
+    a separate equation.
+    """
+    unique = sorted(list(self.known_models))
+
+    # We need to build the matrix first to know if we have enough data points
+    A, b = [], []
+
+    for config_key, mem_values in self.history.items():
+      if not mem_values:
+        continue
+
+      # 1. Create the feature row for this configuration ONCE
+      # (It represents the model counts + bias)
+      counts = dict(config_key)
+      feature_row = [counts.get(model, 0) for model in unique]
+      feature_row.append(1)  # Bias column
+
+      # 2. Add a separate row to the matrix for EVERY individual reading
+      # Instead of averaging, we flatten the history into the matrix
+      for reading in mem_values:
+        A.append(feature_row)  # The inputs (models) stay the same
+        b.append(reading)  # The output (memory) varies due to noise
+
+    # Convert to numpy for SciPy
+    A = np.array(A)
+    b = np.array(b)
+
+    if len(
+        self.history.keys()) < len(unique) + 1 or len(A) < self.min_data_points:
+      # Not enough data to solve yet
+      return
+
+    logger.info(
+        "Solving with %s total observations for %s models.",
+        len(A),
+        len(unique))
+
+    try:
+      # Solve using Non-Negative Least Squares
+      # x will be >= 0
+      x, _ = nnls(A, b)
+
+      weights = x[:-1]
+      bias = x[-1]
+
+      for i, model in enumerate(unique):
+        calculated_cost = weights[i]
+
+        if model in self.estimates:
+          old = self.estimates[model]
+          new = (old * (1 - self.smoothing_factor)) + (
+              calculated_cost * self.smoothing_factor)
+          self.estimates[model] = new
+        else:
+          self.estimates[model] = calculated_cost
+
+        logger.info(
+            "Updated Estimate for %s: %.1f MB", model, self.estimates[model])
+      logger.info("System Bias: %s MB", bias)
+
+    except Exception as e:
+      logger.error("Solver failed: %s", e)
+
+
+class ModelManager:
+  """Manages model lifecycles, caching, and resource arbitration.
+
+  This class acts as the central controller for acquiring model instances.
+  It handles:
+  1. LRU Caching of idle models.
+  2. Resource estimation and admission control (preventing OOM).
+  3. Dynamic eviction of low-priority models, determined by count of 
+    pending requests, when space is needed.
+  4. 'Isolation Mode' for safely profiling unknown models.
+  """
+  _lock = threading.Lock()
+
+  def __init__(
+      self,
+      monitor: Optional['GPUMonitor'] = None,
+      slack_percentage: float = 0.10,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0,
+      min_data_points: int = 5,
+      smoothing_factor: float = 0.2,
+      eviction_cooldown_seconds: float = 10.0,
+      min_model_copies: int = 1):
+
+    self._estimator = ResourceEstimator(
+        min_data_points=min_data_points, smoothing_factor=smoothing_factor)
+    self._monitor = monitor if monitor else GPUMonitor(
+        poll_interval=poll_interval, peak_window_seconds=peak_window_seconds)
+    self._slack_percentage = slack_percentage
+
+    self._eviction_cooldown = eviction_cooldown_seconds
+    self._min_model_copies = min_model_copies
+
+    # Resource State
+    self._models = defaultdict(list)
+    # Idle LRU used to track released models that
+    # can be freed or reused upon request.
+    self._idle_lru = OrderedDict()
+    self._active_counts = Counter()
+    self._total_active_jobs = 0
+    self._pending_reservations = 0.0
+
+    # Isolation state used to profile unknown models,
+    # ensuring they run alone to get accurate readings.
+    # isolation_baseline represents the GPU usage before
+    # loading the unknown model.
+    self._isolation_mode = False
+    self._isolation_baseline = 0.0
+
+    # Waiting Queue and Ticketing to make sure we have fair ordering
+    # and also priority for unknown models.
+    self._wait_queue = []
+    self._ticket_counter = itertools.count()
+    # TODO: Consider making the wait to be smarter, i.e.
+    # splitting read/write etc. to avoid potential contention.
+    self._cv = threading.Condition()
+
+    self._monitor.start()
+
+  def all_models(self, tag) -> list[Any]:
+    return self._models[tag]
+
+  def enter_isolation_mode(self, tag: str, ticket_num: int) -> bool:
+    if self._total_active_jobs > 0:
+      logger.info(
+          "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num)
+      self._cv.wait()
+      # return False since we have waited and need to re-evaluate
+      # in caller to make sure our priority is still valid.
+      return False
+
+    logger.info("Unknown model %s detected. Flushing GPU.", tag)
+    self._delete_all_models()
+
+    self._isolation_mode = True
+    self._total_active_jobs += 1
+    self._isolation_baseline, _, _ = self._monitor.get_stats()
+    self._monitor.reset_peak()
+    return True
+
+  def should_spawn_model(self, tag: str, ticket_num: int) -> bool:
+    curr, _, total = self._monitor.get_stats()
+    est_cost = self._estimator.get_estimate(tag)
+    limit = total * (1 - self._slack_percentage)
+
+    # Use current usage for capacity check (ignore old spikes)
+    if (curr + self._pending_reservations + est_cost) <= limit:
+      self._pending_reservations += est_cost
+      self._total_active_jobs += 1
+      self._active_counts[tag] += 1
+      return True
+
+    # Evict to make space (passing tag to check demand/existence)
+    if self._evict_to_make_space(limit, est_cost, requesting_tag=tag):
+      return True
+
+    # Manually log status for debugging if we are going to wait
+    idle_count = 0
+    other_idle_count = 0
+    for item in self._idle_lru.items():
+      if item[1][0] == tag:
+        idle_count += 1
+      else:
+        other_idle_count += 1
+    total_model_count = 0
+    for _, instances in self._models.items():
+      total_model_count += len(instances)
+    curr, _, _ = self._monitor.get_stats()
+    logger.info(
+        "Waiting for resources to free up: "
+        "tag=%s ticket num%s model count=%s "
+        "idle count=%s resource usage=%.1f MB "
+        "total models count=%s other idle=%s",
+        tag,
+        ticket_num,
+        len(self._models[tag]),
+        idle_count,
+        curr,
+        total_model_count,
+        other_idle_count)
+    self._cv.wait(timeout=10.0)
+    return False
+
+  def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
+    current_priority = 0 if self._estimator.is_unknown(tag) else 1
+    ticket_num = next(self._ticket_counter)
+    my_id = object()
+
+    with self._cv:
+      # FAST PATH: Grab from idle LRU if available
+      if not self._isolation_mode:
+        cached_instance = self._try_grab_from_lru(tag)
+        if cached_instance:
+          return cached_instance
+
+      # SLOW PATH: Enqueue and wait for turn to acquire model,
+      # with unknown models having priority and order enforced
+      # by ticket number as FIFO.
+      logger.info(
+          "Acquire Queued: tag=%s, priority=%d "
+          "total models count=%s ticket num=%s",
+          tag,
+          current_priority,
+          len(self._models[tag]),
+          ticket_num)
+      heapq.heappush(
+          self._wait_queue, (current_priority, ticket_num, my_id, tag))
+
+      should_spawn = False
+      est_cost = 0.0
+      is_unknown = False
+
+      try:
+        while True:
+          if not self._wait_queue or self._wait_queue[0][2] is not my_id:
+            logger.info(
+                "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
+            self._cv.wait()
+            continue
+
+          # Re-evaluate priority in case model became known during wait
+          real_is_unknown = self._estimator.is_unknown(tag)
+          real_priority = 0 if real_is_unknown else 1
+
+          # If priority changed, reinsert into queue and wait
+          if current_priority != real_priority:
+            heapq.heappop(self._wait_queue)
+            current_priority = real_priority
+            heapq.heappush(
+                self._wait_queue, (current_priority, ticket_num, my_id, tag))
+            self._cv.notify_all()
+            continue
+
+          # Try grab from LRU again in case model was released during wait
+          cached_instance = self._try_grab_from_lru(tag)
+          if cached_instance:
+            return cached_instance
+
+          is_unknown = real_is_unknown
+
+          # Path A: Isolation
+          if is_unknown:
+            if self.enter_isolation_mode(tag, ticket_num):
+              should_spawn = True
+              break
+            else:
+              # We waited, need to re-evaluate our turn
+              # because priority may have changed during the wait
+              continue
+
+          # Path B: Concurrent
+          else:
+            if self._isolation_mode:
+              logger.info(
+                  "Waiting due to isolation in progress: tag=%s ticket num%s",
+                  tag,
+                  ticket_num)
+              self._cv.wait()
+              continue
+
+            if self.should_spawn_model(tag, ticket_num):
+              should_spawn = True
+              est_cost = self._estimator.get_estimate(tag)
+              break
+            else:
+              # We waited, need to re-evaluate our turn
+              # because priority may have changed during the wait
+              continue
+
+      finally:
+        # Remove self from wait queue once done
+        if self._wait_queue and self._wait_queue[0][2] is my_id:
+          heapq.heappop(self._wait_queue)
+        else:
+          for i, item in enumerate(self._wait_queue):
+            if item[2] is my_id:
+              self._wait_queue.pop(i)
+              heapq.heapify(self._wait_queue)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
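   The linear scan to find and remove an item from `_wait_queue`, followed by
   `heapq.heapify()`, is an O(N) operation, which can be inefficient if the
   queue becomes large. A more performant approach for removing items from a
   priority queue is lazy deletion: mark the entry as invalid and simply ignore
   it when it is popped from the heap. This avoids the linear scan. You could
   maintain a separate set of completed or cancelled request IDs to track this.
   Note that waiters would then need to discard invalid entries at the head of
   the heap before checking whose turn it is; otherwise a stale entry would
   block every other waiter. Also, the normal exit paths appear to leave the
   entry at the head of the heap (handled by the `heappop` branch), so the scan
   mainly matters when the wait loop exits abnormally, e.g. via an exception.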



##########
sdks/python/apache_beam/ml/inference/model_manager.py:
##########
@@ -0,0 +1,730 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module for managing ML models in Apache Beam pipelines.
+
+This module provides classes and functions to efficiently manage multiple
+machine learning models within Apache Beam pipelines. It includes functionality
+for loading, caching, and updating models using multi-process shared memory,
+ensuring that models are reused across different workers to optimize resource
+usage and performance.
+"""
+
+import gc
+import heapq
+import itertools
+import logging
+import subprocess
+import threading
+import time
+from collections import Counter
+from collections import OrderedDict
+from collections import defaultdict
+from collections import deque
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from scipy.optimize import nnls
+
+from apache_beam.utils import multi_process_shared
+
+logger = logging.getLogger(__name__)
+
+
+class GPUMonitor:
+  """Monitors GPU memory usage in a separate thread using nvidia-smi.
+
+  This class continuously polls GPU memory statistics to track current usage
+  and peak usage over a sliding time window. It serves as the source of truth
+  for the ModelManager's resource decisions.
+
+  Attributes:
+    fallback_memory_mb: Default total memory if hardware detection fails.
+    poll_interval: Seconds between memory checks.
+    peak_window_seconds: Duration to track peak memory usage.
+  """
+  def __init__(
+      self,
+      fallback_memory_mb: float = 16000.0,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0):
+    self._current_usage = 0.0
+    self._peak_usage = 0.0
+    self._total_memory = fallback_memory_mb
+    self._poll_interval = poll_interval
+    self._peak_window_seconds = peak_window_seconds
+    self._memory_history = deque()
+    self._running = False
+    self._thread = None
+    self._lock = threading.Lock()
+
+  def _detect_hardware(self):
+    try:
+      cmd = [
+          "nvidia-smi",
+          "--query-gpu=memory.total",
+          "--format=csv,noheader,nounits"
+      ]
+      output = subprocess.check_output(cmd, text=True).strip()
+      self._total_memory = float(output)
+      return True
+    except (FileNotFoundError, subprocess.CalledProcessError):
+      logger.warning(
+          "nvidia-smi not found or failed. Defaulting total memory to %s MB",
+          self._total_memory)
+      return False
+    except Exception as e:
+      logger.warning(
+          "Error parsing nvidia-smi output: %s. "
+          "Defaulting total memory to %s MB",
+          e,
+          self._total_memory)
+      return False
+
+  def start(self):
+    self._gpu_available = self._detect_hardware()
+    if self._running or not self._gpu_available:
+      return
+    self._running = True
+    self._thread = threading.Thread(target=self._poll_loop, daemon=True)
+    self._thread.start()
+
+  def stop(self):
+    self._running = False
+    if self._thread:
+      self._thread.join()
+
+  def reset_peak(self):
+    with self._lock:
+      now = time.time()
+      self._memory_history.clear()
+      self._memory_history.append((now, self._current_usage))
+      self._peak_usage = self._current_usage
+
+  def get_stats(self) -> Tuple[float, float, float]:
+    with self._lock:
+      return self._current_usage, self._peak_usage, self._total_memory
+
+  def refresh(self):
+    """Forces an immediate poll of the GPU."""
+    usage = self._get_nvidia_smi_used()
+    now = time.time()
+    with self._lock:
+      self._current_usage = usage
+      self._memory_history.append((now, usage))
+      # Recalculate peak immediately
+      while self._memory_history and (now - self._memory_history[0][0]
+                                      > self._peak_window_seconds):
+        self._memory_history.popleft()
+      self._peak_usage = (
+          max(m for _, m in self._memory_history)
+          if self._memory_history else usage)
+
+  def _get_nvidia_smi_used(self) -> float:
+    try:
+      cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits"
+      output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip()
+      free_memory = float(output)
+      return self._total_memory - free_memory
+    except Exception:
+      return 0.0
+
+  def _poll_loop(self):
+    while self._running:
+      usage = self._get_nvidia_smi_used()
+      now = time.time()
+      with self._lock:
+        self._current_usage = usage
+        self._memory_history.append((now, usage))
+        while self._memory_history and (now - self._memory_history[0][0]
+                                        > self._peak_window_seconds):
+          self._memory_history.popleft()
+        self._peak_usage = (
+            max(m for _, m in self._memory_history)
+            if self._memory_history else usage)
+      time.sleep(self._poll_interval)
+
+
+class ResourceEstimator:
+  """Estimates individual model memory usage using statistical observation.
+
+  Uses Non-Negative Least Squares (NNLS) to deduce the memory footprint of
+  individual models based on aggregate system memory readings and the
+  configuration of active models at that time.
+  """
+  def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5):
+    self.smoothing_factor = smoothing_factor
+    self.min_data_points = min_data_points
+    self.estimates: Dict[str, float] = {}
+    self.history = defaultdict(lambda: deque(maxlen=20))
+    self.known_models = set()
+    self._lock = threading.Lock()
+
+  def is_unknown(self, model_tag: str) -> bool:
+    with self._lock:
+      return model_tag not in self.estimates
+
+  def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float:
+    with self._lock:
+      return self.estimates.get(model_tag, default_mb)
+
+  def set_initial_estimate(self, model_tag: str, cost: float):
+    with self._lock:
+      self.estimates[model_tag] = cost
+      self.known_models.add(model_tag)
+      logger.info("Initial Profile for %s: %s MB", model_tag, cost)
+
+  def add_observation(
+      self, active_snapshot: Dict[str, int], peak_memory: float):
+    if active_snapshot:
+      model_list = "\n".join(
+          f"\t- {model}: {count}"
+          for model, count in sorted(active_snapshot.items()))
+    else:
+      model_list = "\t- None"
+
+    logger.info(
+        "Adding Observation:\n PeakMemory: %.1f MB\n  Instances:\n%s",
+        peak_memory,
+        model_list)
+    if not active_snapshot:
+      return
+    with self._lock:
+      config_key = tuple(sorted(active_snapshot.items()))
+      self.history[config_key].append(peak_memory)
+      for tag in active_snapshot:
+        self.known_models.add(tag)
+      self._solve()
+
+  def _solve(self):
+    """
+    Solves Ax=b using raw readings (no pre-averaging) and NNLS.
+    This creates a 'tall' matrix A where every memory reading is
+    a separate equation.
+    """
+    unique = sorted(list(self.known_models))
+
+    # We need to build the matrix first to know if we have enough data points
+    A, b = [], []
+
+    for config_key, mem_values in self.history.items():
+      if not mem_values:
+        continue
+
+      # 1. Create the feature row for this configuration ONCE
+      # (It represents the model counts + bias)
+      counts = dict(config_key)
+      feature_row = [counts.get(model, 0) for model in unique]
+      feature_row.append(1)  # Bias column
+
+      # 2. Add a separate row to the matrix for EVERY individual reading
+      # Instead of averaging, we flatten the history into the matrix
+      for reading in mem_values:
+        A.append(feature_row)  # The inputs (models) stay the same
+        b.append(reading)  # The output (memory) varies due to noise
+
+    # Convert to numpy for SciPy
+    A = np.array(A)
+    b = np.array(b)
+
+    if len(
+        self.history.keys()) < len(unique) + 1 or len(A) < self.min_data_points:
+      # Not enough data to solve yet
+      return
+
+    logger.info(
+        "Solving with %s total observations for %s models.",
+        len(A),
+        len(unique))
+
+    try:
+      # Solve using Non-Negative Least Squares
+      # x will be >= 0
+      x, _ = nnls(A, b)
+
+      weights = x[:-1]
+      bias = x[-1]
+
+      for i, model in enumerate(unique):
+        calculated_cost = weights[i]
+
+        if model in self.estimates:
+          old = self.estimates[model]
+          new = (old * (1 - self.smoothing_factor)) + (
+              calculated_cost * self.smoothing_factor)
+          self.estimates[model] = new
+        else:
+          self.estimates[model] = calculated_cost
+
+        logger.info(
+            "Updated Estimate for %s: %.1f MB", model, self.estimates[model])
+      logger.info("System Bias: %s MB", bias)
+
+    except Exception as e:
+      logger.error("Solver failed: %s", e)
+
+
+class ModelManager:
+  """Manages model lifecycles, caching, and resource arbitration.
+
+  This class acts as the central controller for acquiring model instances.
+  It handles:
+  1. LRU Caching of idle models.
+  2. Resource estimation and admission control (preventing OOM).
+  3. Dynamic eviction of low-priority models, determined by count of 
+    pending requests, when space is needed.
+  4. 'Isolation Mode' for safely profiling unknown models.
+  """
+  _lock = threading.Lock()
+
+  def __init__(
+      self,
+      monitor: Optional['GPUMonitor'] = None,
+      slack_percentage: float = 0.10,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0,
+      min_data_points: int = 5,
+      smoothing_factor: float = 0.2,
+      eviction_cooldown_seconds: float = 10.0,
+      min_model_copies: int = 1):
+
+    self._estimator = ResourceEstimator(
+        min_data_points=min_data_points, smoothing_factor=smoothing_factor)
+    self._monitor = monitor if monitor else GPUMonitor(
+        poll_interval=poll_interval, peak_window_seconds=peak_window_seconds)
+    self._slack_percentage = slack_percentage
+
+    self._eviction_cooldown = eviction_cooldown_seconds
+    self._min_model_copies = min_model_copies
+
+    # Resource State
+    self._models = defaultdict(list)
+    # Idle LRU used to track released models that
+    # can be freed or reused upon request.
+    self._idle_lru = OrderedDict()
+    self._active_counts = Counter()
+    self._total_active_jobs = 0
+    self._pending_reservations = 0.0
+
+    # Isolation state used to profile unknown models,
+    # ensuring they run alone to get accurate readings.
+    # isolation_baseline represents the GPU usage before
+    # loading the unknown model.
+    self._isolation_mode = False
+    self._isolation_baseline = 0.0
+
+    # Waiting Queue and Ticketing to make sure we have fair ordering
+    # and also priority for unknown models.
+    self._wait_queue = []
+    self._ticket_counter = itertools.count()
+    # TODO: Consider making the wait to be smarter, i.e.
+    # splitting read/write etc. to avoid potential contention.
+    self._cv = threading.Condition()
+
+    self._monitor.start()
+
+  def all_models(self, tag) -> list[Any]:
+    return self._models[tag]
+
+  def enter_isolation_mode(self, tag: str, ticket_num: int) -> bool:
+    """Attempts to start profiling an unknown model with the GPU to itself.
+
+    Must be called while holding ``self._cv``: ``Condition.wait`` raises
+    RuntimeError if the lock is not held.
+
+    Args:
+      tag: Tag of the unknown model about to be loaded.
+      ticket_num: Caller's queue ticket, used only for logging.
+
+    Returns:
+      True if isolation mode was entered. In that case every tracked model
+      has been deleted, the job is counted as active, and the current usage
+      is stored as the isolation baseline. False if other jobs were still
+      active. In that case this waits on ``self._cv`` once, and the caller
+      must re-check its queue position and priority.
+    """
+    if self._total_active_jobs > 0:
+      logger.info(
+          "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num)
+      self._cv.wait()
+      # return False since we have waited and need to re-evaluate
+      # in caller to make sure our priority is still valid.
+      return False
+
+    logger.info("Unknown model %s detected. Flushing GPU.", tag)
+    # With no active jobs, all tracked instances should be idle and can be
+    # freed so the unknown model is measured alone.
+    self._delete_all_models()
+
+    self._isolation_mode = True
+    self._total_active_jobs += 1
+    # Read the baseline after the flush (which refreshes the monitor), then
+    # restart peak tracking so the profiled peak starts from here.
+    self._isolation_baseline, _, _ = self._monitor.get_stats()
+    self._monitor.reset_peak()
+    return True
+
+  def should_spawn_model(self, tag: str, ticket_num: int) -> bool:
+    curr, _, total = self._monitor.get_stats()
+    est_cost = self._estimator.get_estimate(tag)
+    limit = total * (1 - self._slack_percentage)
+
+    # Use current usage for capacity check (ignore old spikes)
+    if (curr + self._pending_reservations + est_cost) <= limit:
+      self._pending_reservations += est_cost
+      self._total_active_jobs += 1
+      self._active_counts[tag] += 1
+      return True
+
+    # Evict to make space (passing tag to check demand/existence)
+    if self._evict_to_make_space(limit, est_cost, requesting_tag=tag):
+      return True
+
+    # Manually log status for debugging if we are going to wait
+    idle_count = 0
+    other_idle_count = 0
+    for item in self._idle_lru.items():
+      if item[1][0] == tag:
+        idle_count += 1
+      else:
+        other_idle_count += 1
+    total_model_count = 0
+    for _, instances in self._models.items():
+      total_model_count += len(instances)
+    curr, _, _ = self._monitor.get_stats()
+    logger.info(
+        "Waiting for resources to free up: "
+        "tag=%s ticket num%s model count=%s "
+        "idle count=%s resource usage=%.1f MB "
+        "total models count=%s other idle=%s",
+        tag,
+        ticket_num,
+        len(self._models[tag]),
+        idle_count,
+        curr,
+        total_model_count,
+        other_idle_count)
+    self._cv.wait(timeout=10.0)
+    return False
+
+  def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
+    current_priority = 0 if self._estimator.is_unknown(tag) else 1
+    ticket_num = next(self._ticket_counter)
+    my_id = object()
+
+    with self._cv:
+      # FAST PATH: Grab from idle LRU if available
+      if not self._isolation_mode:
+        cached_instance = self._try_grab_from_lru(tag)
+        if cached_instance:
+          return cached_instance
+
+      # SLOW PATH: Enqueue and wait for turn to acquire model,
+      # with unknown models having priority and order enforced
+      # by ticket number as FIFO.
+      logger.info(
+          "Acquire Queued: tag=%s, priority=%d "
+          "total models count=%s ticket num=%s",
+          tag,
+          current_priority,
+          len(self._models[tag]),
+          ticket_num)
+      heapq.heappush(
+          self._wait_queue, (current_priority, ticket_num, my_id, tag))
+
+      should_spawn = False
+      est_cost = 0.0
+      is_unknown = False
+
+      try:
+        while True:
+          if not self._wait_queue or self._wait_queue[0][2] is not my_id:
+            logger.info(
+                "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
+            self._cv.wait()
+            continue
+
+          # Re-evaluate priority in case model became known during wait
+          real_is_unknown = self._estimator.is_unknown(tag)
+          real_priority = 0 if real_is_unknown else 1
+
+          # If priority changed, reinsert into queue and wait
+          if current_priority != real_priority:
+            heapq.heappop(self._wait_queue)
+            current_priority = real_priority
+            heapq.heappush(
+                self._wait_queue, (current_priority, ticket_num, my_id, tag))
+            self._cv.notify_all()
+            continue
+
+          # Try grab from LRU again in case model was released during wait
+          cached_instance = self._try_grab_from_lru(tag)
+          if cached_instance:
+            return cached_instance
+
+          is_unknown = real_is_unknown
+
+          # Path A: Isolation
+          if is_unknown:
+            if self.enter_isolation_mode(tag, ticket_num):
+              should_spawn = True
+              break
+            else:
+              # We waited, need to re-evaluate our turn
+              # because priority may have changed during the wait
+              continue
+
+          # Path B: Concurrent
+          else:
+            if self._isolation_mode:
+              logger.info(
+                  "Waiting due to isolation in progress: tag=%s ticket num%s",
+                  tag,
+                  ticket_num)
+              self._cv.wait()
+              continue
+
+            if self.should_spawn_model(tag, ticket_num):
+              should_spawn = True
+              est_cost = self._estimator.get_estimate(tag)
+              break
+            else:
+              # We waited, need to re-evaluate our turn
+              # because priority may have changed during the wait
+              continue
+
+      finally:
+        # Remove self from wait queue once done
+        if self._wait_queue and self._wait_queue[0][2] is my_id:
+          heapq.heappop(self._wait_queue)
+        else:
+          for i, item in enumerate(self._wait_queue):
+            if item[2] is my_id:
+              self._wait_queue.pop(i)
+              heapq.heapify(self._wait_queue)
+        self._cv.notify_all()
+
+      if should_spawn:
+        return self._spawn_new_model(tag, loader_func, is_unknown, est_cost)
+
+  def release_model(self, tag: str, instance: Any):
+    with self._cv:
+      try:
+        self._total_active_jobs -= 1
+        if self._active_counts[tag] > 0:
+          self._active_counts[tag] -= 1
+
+        self._idle_lru[id(instance)] = (tag, instance, time.time())
+
+        # Update estimator with latest stats
+        _, peak_during_job, _ = self._monitor.get_stats()
+
+        if self._isolation_mode and self._active_counts[tag] == 0:
+          # For isolation mode, we directly set the initial estimate
+          # so that we can quickly learn the model cost.
+          cost = max(0, peak_during_job - self._isolation_baseline)
+          self._estimator.set_initial_estimate(tag, cost)
+          self._isolation_mode = False
+          self._isolation_baseline = 0.0
+        else:
+          # Regular update for known models
+          snapshot = {
+              t: len(instances)
+              for t, instances in self._models.items() if len(instances) > 0
+          }
+          if snapshot:
+            self._estimator.add_observation(snapshot, peak_during_job)
+
+      finally:
+        self._cv.notify_all()
+
+  def _try_grab_from_lru(self, tag: str) -> Any:
+    target_key = None
+    target_instance = None
+
+    for key, (t, instance, _) in reversed(self._idle_lru.items()):
+      if t == tag:
+        target_key = key
+        target_instance = instance
+        break
+
+    if target_instance:
+      # Found an idle model, remove from LRU and return
+      del self._idle_lru[target_key]
+      self._active_counts[tag] += 1
+      self._total_active_jobs += 1
+      return target_instance
+
+    logger.info("No idle model found for tag: %s", tag)
+    return None
+
+  def _evict_to_make_space(
+      self, limit: float, est_cost: float, requesting_tag: str) -> bool:
+    """
+    Evicts models based on Demand Magnitude + Tiers.
+    Crucially: If 'requesting_tag' has no loaded copies at all (active or
+    idle), every idle model becomes a candidate regardless of its demand,
+    to avoid starvation.
+    Returns True if space was made (or usage already fit), False otherwise.
+
+    Only idle instances (entries of the idle LRU) are candidates. Unless
+    the requester is starving, an idle model is considered only when its
+    waiting demand is strictly lower than the requester's. Candidates are
+    ordered by score = demand * 10 + tier, then by release time (oldest
+    first). Tier is 0 for cold surplus copies, 1 for warm surplus, 2 for
+    cold non-surplus and 3 for warm non-surplus. An instance is cold when it
+    has been idle for at least the eviction cooldown, and surplus when its
+    tag has more loaded copies than ``min_model_copies``. Eviction stops as
+    soon as the projected usage fits under ``limit``.
+
+    Args:
+      limit: Ceiling for current usage + pending reservations + est_cost.
+      est_cost: Estimated cost of the model being admitted.
+      requesting_tag: Tag of the model being admitted.
+    """
+    curr, _, _ = self._monitor.get_stats()
+    projected_usage = curr + self._pending_reservations + est_cost
+
+    if projected_usage <= limit:
+      # Memory usage changed and we are already under limit
+      return True
+
+    now = time.time()
+
+    # Calculate the demand from the wait queue
+    demand_map = Counter()
+    for item in self._wait_queue:
+      demand_map[item[3]] += 1
+
+    my_demand = demand_map[requesting_tag]
+    # Starving means no loaded copy of the requested tag exists at all.
+    am_i_starving = len(self._models[requesting_tag]) == 0
+
+    candidates = []
+    for key, (tag, instance, release_time) in self._idle_lru.items():
+      candidate_demand = demand_map[tag]
+
+      # TODO: Try to avoid churn if demand is similar
+      # Unless starving, only models with strictly lower demand qualify.
+      if not am_i_starving and candidate_demand >= my_demand:
+        continue
+
+      # Attempts to score candidates based on hotness and manually
+      # specified minimum copies. Demand is weighted heavily to
+      # ensure we evict low-demand models first.
+      age = now - release_time
+      is_cold = age >= self._eviction_cooldown
+
+      total_copies = len(self._models[tag])
+      is_surplus = total_copies > self._min_model_copies
+
+      if is_cold and is_surplus: tier = 0
+      elif not is_cold and is_surplus: tier = 1
+      elif is_cold and not is_surplus: tier = 2
+      else: tier = 3
+
+      score = (candidate_demand * 10) + tier
+
+      candidates.append((score, release_time, key, tag, instance))
+
+    # Lowest score first; ties broken by oldest release time.
+    candidates.sort(key=lambda x: (x[0], x[1]))
+
+    # Evict candidates until we are under limit
+    for score, _, key, tag, instance in candidates:
+      if projected_usage <= limit:
+        break
+
+      if key not in self._idle_lru: continue
+
+      self._perform_eviction(key, tag, instance, score)
+
+      # Re-measure real usage after the eviction rather than subtracting
+      # an estimated cost.
+      curr, _, _ = self._monitor.get_stats()
+      projected_usage = curr + self._pending_reservations + est_cost
+
+    return projected_usage <= limit
+
+  def _delete_instance(self, instance: Any):
+    if isinstance(instance, str):
+      # If the instance is a string, it's a uuid used
+      # to retrieve the model from MultiProcessShared
+      multi_process_shared.MultiProcessShared(
+          lambda: "N/A", tag=instance).unsafe_hard_delete()
+    if hasattr(instance, 'mock_model_unsafe_hard_delete'):
+      # Call the mock unsafe hard delete method for testing
+      instance.mock_model_unsafe_hard_delete()
+    del instance
+
+  def _perform_eviction(self, key: str, tag: str, instance: Any, score: int):
+    logger.info("Evicting Model: %s (Score %d)", tag, score)
+    curr, _, _ = self._monitor.get_stats()
+    logger.info("Resource Usage Before Eviction: %.1f MB", curr)
+
+    if key in self._idle_lru:
+      del self._idle_lru[key]
+
+    for i, inst in enumerate(self._models[tag]):
+      if instance == inst:
+        del self._models[tag][i]
+        break
+
+    self._delete_instance(instance)
+    gc.collect()
+    torch.cuda.empty_cache()
+    self._monitor.refresh()
+    self._monitor.reset_peak()
+    curr, _, _ = self._monitor.get_stats()
+    logger.info("Resource Usage After Eviction: %.1f MB", curr)
+
+  def _spawn_new_model(
+      self,
+      tag: str,
+      loader_func: Callable[[], Any],
+      is_unknown: bool,
+      est_cost: float) -> Any:
+    """Loads a new instance of ``tag`` and registers it as loaded.
+
+    The load runs while holding ``self._cv``, so other acquire/release calls
+    block until it finishes. ``acquire_model`` calls this while already
+    holding ``self._cv``. Re-entering works because ``threading.Condition()``
+    defaults to a reentrant lock. The memory delta seen during the load is
+    recorded as an observation for a single instance of ``tag``.
+
+    Args:
+      tag: Model tag being loaded.
+      loader_func: Zero-argument callable returning the new instance.
+      is_unknown: True when loading in isolation mode; isolation state is
+        cleared if the load fails.
+      est_cost: For known models, the amount subtracted from
+        ``self._pending_reservations`` (clamped at 0) once loading ends.
+
+    Returns:
+      The newly loaded instance.
+
+    Raises:
+      Re-raises any exception from ``loader_func``. Before re-raising it
+      undoes the admission bookkeeping and notifies waiters.
+    """
+    try:
+      with self._cv:
+        logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
+        baseline_snap, _, _ = self._monitor.get_stats()
+        instance = loader_func()
+        _, peak_during_load, _ = self._monitor.get_stats()
+
+        # Record the load-time memory increase as the cost of one instance.
+        snapshot = {tag: 1}
+        self._estimator.add_observation(
+            snapshot, peak_during_load - baseline_snap)
+
+        # Release the reservation held for known models now that the
+        # instance is loaded.
+        if not is_unknown:
+          self._pending_reservations = max(
+              0.0, self._pending_reservations - est_cost)
+        self._models[tag].append(instance)
+      return instance
+
+    except Exception as e:
+      logger.error("Load Failed: %s. Error: %s", tag, e)
+      with self._cv:
+        # Roll back the admission made before the load.
+        self._total_active_jobs -= 1
+        if is_unknown:
+          self._isolation_mode = False
+          self._isolation_baseline = 0.0
+        else:
+          self._pending_reservations = max(
+              0.0, self._pending_reservations - est_cost)
+          self._active_counts[tag] -= 1
+        self._cv.notify_all()
+      raise e
+
+  def _delete_all_models(self):
+    """Deletes every tracked instance, active or idle, and frees GPU memory.
+
+    Clears the idle LRU, the per-tag instance lists and the per-tag active
+    counts; ``_total_active_jobs`` is not touched. It then collects garbage,
+    empties the CUDA cache, refreshes the monitor and resets its peak.
+    """
+    self._idle_lru.clear()
+    for _, instances in self._models.items():
+      for instance in instances:
+        self._delete_instance(instance)
+    self._models.clear()
+    self._active_counts.clear()
+    gc.collect()
+    torch.cuda.empty_cache()
+    # Take a fresh reading so later get_stats() calls see the freed memory,
+    # and restart peak tracking from it.
+    self._monitor.refresh()
+    self._monitor.reset_peak()
+
+  def _force_reset(self):
+    logger.warning("Force Reset Triggered")
+    self._delete_all_models()
+    self._models = defaultdict(list)
+    self._idle_lru = OrderedDict()
+    self._active_counts = Counter()
+    self._wait_queue = []
+    self._total_active_jobs = 0
+    self._pending_reservations = 0.0
+    self._isolation_mode = False
+    self._isolation_baseline = 0.0
+
+  def shutdown(self):
+    self._delete_all_models()
+    gc.collect()
+    torch.cuda.empty_cache()

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   `shutdown()` first calls `_delete_all_models()`, which already runs
   `gc.collect()` and `torch.cuda.empty_cache()`. Calling them again here is
   redundant, so these two lines can be removed.



##########
sdks/python/apache_beam/ml/inference/model_manager.py:
##########
@@ -0,0 +1,730 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module for managing ML models in Apache Beam pipelines.
+
+This module provides classes and functions to efficiently manage multiple
+machine learning models within Apache Beam pipelines. It includes functionality
+for loading, caching, and updating models using multi-process shared memory,
+ensuring that models are reused across different workers to optimize resource
+usage and performance.
+"""
+
+import gc
+import heapq
+import itertools
+import logging
+import subprocess
+import threading
+import time
+from collections import Counter
+from collections import OrderedDict
+from collections import defaultdict
+from collections import deque
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from scipy.optimize import nnls
+
+from apache_beam.utils import multi_process_shared
+
+logger = logging.getLogger(__name__)
+
+
+class GPUMonitor:
+  """Monitors GPU memory usage in a separate thread using nvidia-smi.
+
+  This class continuously polls GPU memory statistics to track current usage
+  and peak usage over a sliding time window. It serves as the source of truth
+  for the ModelManager's resource decisions.
+
+  Attributes:
+    fallback_memory_mb: Default total memory if hardware detection fails.
+    poll_interval: Seconds between memory checks.
+    peak_window_seconds: Duration to track peak memory usage.
+  """
+  def __init__(
+      self,
+      fallback_memory_mb: float = 16000.0,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0):
+    self._current_usage = 0.0
+    self._peak_usage = 0.0
+    self._total_memory = fallback_memory_mb
+    self._poll_interval = poll_interval
+    self._peak_window_seconds = peak_window_seconds
+    self._memory_history = deque()
+    self._running = False
+    self._thread = None
+    self._lock = threading.Lock()
+
+  def _detect_hardware(self):
+    try:
+      cmd = [
+          "nvidia-smi",
+          "--query-gpu=memory.total",
+          "--format=csv,noheader,nounits"
+      ]
+      output = subprocess.check_output(cmd, text=True).strip()
+      self._total_memory = float(output)
+      return True
+    except (FileNotFoundError, subprocess.CalledProcessError):
+      logger.warning(
+          "nvidia-smi not found or failed. Defaulting total memory to %s MB",
+          self._total_memory)
+      return False
+    except Exception as e:
+      logger.warning(
+          "Error parsing nvidia-smi output: %s. "
+          "Defaulting total memory to %s MB",
+          e,
+          self._total_memory)
+      return False
+
+  def start(self):
+    self._gpu_available = self._detect_hardware()
+    if self._running or not self._gpu_available:
+      return
+    self._running = True
+    self._thread = threading.Thread(target=self._poll_loop, daemon=True)
+    self._thread.start()
+
+  def stop(self):
+    self._running = False
+    if self._thread:
+      self._thread.join()
+
+  def reset_peak(self):
+    with self._lock:
+      now = time.time()
+      self._memory_history.clear()
+      self._memory_history.append((now, self._current_usage))
+      self._peak_usage = self._current_usage
+
+  def get_stats(self) -> Tuple[float, float, float]:
+    with self._lock:
+      return self._current_usage, self._peak_usage, self._total_memory
+
+  def refresh(self):
+    """Forces an immediate poll of the GPU."""
+    usage = self._get_nvidia_smi_used()
+    now = time.time()
+    with self._lock:
+      self._current_usage = usage
+      self._memory_history.append((now, usage))
+      # Recalculate peak immediately
+      while self._memory_history and (now - self._memory_history[0][0]
+                                      > self._peak_window_seconds):
+        self._memory_history.popleft()
+      self._peak_usage = (
+          max(m for _, m in self._memory_history)
+          if self._memory_history else usage)
+
+  def _get_nvidia_smi_used(self) -> float:
+    try:
+      cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits"
+      output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip()
+      free_memory = float(output)
+      return self._total_memory - free_memory
+    except Exception:
+      return 0.0
+
+  def _poll_loop(self):
+    while self._running:
+      usage = self._get_nvidia_smi_used()
+      now = time.time()
+      with self._lock:
+        self._current_usage = usage
+        self._memory_history.append((now, usage))
+        while self._memory_history and (now - self._memory_history[0][0]
+                                        > self._peak_window_seconds):
+          self._memory_history.popleft()
+        self._peak_usage = (
+            max(m for _, m in self._memory_history)
+            if self._memory_history else usage)
+      time.sleep(self._poll_interval)
+
+
+class ResourceEstimator:
+  """Estimates individual model memory usage using statistical observation.
+
+  Uses Non-Negative Least Squares (NNLS) to deduce the memory footprint of
+  individual models based on aggregate system memory readings and the
+  configuration of active models at that time.
+  """
+  def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5):
+    self.smoothing_factor = smoothing_factor
+    self.min_data_points = min_data_points
+    self.estimates: Dict[str, float] = {}
+    self.history = defaultdict(lambda: deque(maxlen=20))
+    self.known_models = set()
+    self._lock = threading.Lock()
+
+  def is_unknown(self, model_tag: str) -> bool:
+    with self._lock:
+      return model_tag not in self.estimates
+
+  def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float:
+    with self._lock:
+      return self.estimates.get(model_tag, default_mb)
+
+  def set_initial_estimate(self, model_tag: str, cost: float):
+    with self._lock:
+      self.estimates[model_tag] = cost
+      self.known_models.add(model_tag)
+      logger.info("Initial Profile for %s: %s MB", model_tag, cost)
+
+  def add_observation(
+      self, active_snapshot: Dict[str, int], peak_memory: float):
+    if active_snapshot:
+      model_list = "\n".join(
+          f"\t- {model}: {count}"
+          for model, count in sorted(active_snapshot.items()))
+    else:
+      model_list = "\t- None"
+
+    logger.info(
+        "Adding Observation:\n PeakMemory: %.1f MB\n  Instances:\n%s",
+        peak_memory,
+        model_list)
+    if not active_snapshot:
+      return
+    with self._lock:
+      config_key = tuple(sorted(active_snapshot.items()))
+      self.history[config_key].append(peak_memory)
+      for tag in active_snapshot:
+        self.known_models.add(tag)
+      self._solve()
+
+  def _solve(self):
+    """
+    Solves Ax=b using raw readings (no pre-averaging) and NNLS.
+    This creates a 'tall' matrix A where every memory reading is
+    a separate equation.
+    """
+    unique = sorted(list(self.known_models))
+
+    # We need to build the matrix first to know if we have enough data points
+    A, b = [], []
+
+    for config_key, mem_values in self.history.items():
+      if not mem_values:
+        continue
+
+      # 1. Create the feature row for this configuration ONCE
+      # (It represents the model counts + bias)
+      counts = dict(config_key)
+      feature_row = [counts.get(model, 0) for model in unique]
+      feature_row.append(1)  # Bias column
+
+      # 2. Add a separate row to the matrix for EVERY individual reading
+      # Instead of averaging, we flatten the history into the matrix
+      for reading in mem_values:
+        A.append(feature_row)  # The inputs (models) stay the same
+        b.append(reading)  # The output (memory) varies due to noise
+
+    # Convert to numpy for SciPy
+    A = np.array(A)
+    b = np.array(b)
+
+    if len(
+        self.history.keys()) < len(unique) + 1 or len(A) < 
self.min_data_points:
+      # Not enough data to solve yet
+      return
+
+    logger.info(
+        "Solving with %s total observations for %s models.",
+        len(A),
+        len(unique))
+
+    try:
+      # Solve using Non-Negative Least Squares
+      # x will be >= 0
+      x, _ = nnls(A, b)
+
+      weights = x[:-1]
+      bias = x[-1]
+
+      for i, model in enumerate(unique):
+        calculated_cost = weights[i]
+
+        if model in self.estimates:
+          old = self.estimates[model]
+          new = (old * (1 - self.smoothing_factor)) + (
+              calculated_cost * self.smoothing_factor)
+          self.estimates[model] = new
+        else:
+          self.estimates[model] = calculated_cost
+
+        logger.info(
+            "Updated Estimate for %s: %.1f MB", model, self.estimates[model])
+      logger.info("System Bias: %s MB", bias)
+
+    except Exception as e:
+      logger.error("Solver failed: %s", e)
+
+
+class ModelManager:
+  """Manages model lifecycles, caching, and resource arbitration.
+
+  This class acts as the central controller for acquiring model instances.
+  It handles:
+  1. LRU Caching of idle models.
+  2. Resource estimation and admission control (preventing OOM).
+  3. Dynamic eviction of low-priority models, determined by count of 
+    pending requests, when space is needed.
+  4. 'Isolation Mode' for safely profiling unknown models.
+  """
+  _lock = threading.Lock()
+
+  def __init__(
+      self,
+      monitor: Optional['GPUMonitor'] = None,
+      slack_percentage: float = 0.10,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0,
+      min_data_points: int = 5,
+      smoothing_factor: float = 0.2,
+      eviction_cooldown_seconds: float = 10.0,
+      min_model_copies: int = 1):
+
+    self._estimator = ResourceEstimator(
+        min_data_points=min_data_points, smoothing_factor=smoothing_factor)
+    self._monitor = monitor if monitor else GPUMonitor(
+        poll_interval=poll_interval, peak_window_seconds=peak_window_seconds)
+    self._slack_percentage = slack_percentage
+
+    self._eviction_cooldown = eviction_cooldown_seconds
+    self._min_model_copies = min_model_copies
+
+    # Resource State
+    self._models = defaultdict(list)
+    # Idle LRU used to track released models that
+    # can be freed or reused upon request.
+    self._idle_lru = OrderedDict()
+    self._active_counts = Counter()
+    self._total_active_jobs = 0
+    self._pending_reservations = 0.0
+
+    # Isolation state used to profile unknown models,
+    # ensuring they run alone to get accurate readings.
+    # isolation_baseline represents the GPU usage before
+    # loading the unknown model.
+    self._isolation_mode = False
+    self._isolation_baseline = 0.0
+
+    # Waiting Queue and Ticketing to make sure we have fair ordering
+    # and also priority for unknown models.
+    self._wait_queue = []
+    self._ticket_counter = itertools.count()
+    # TODO: Consider making the wait to be smarter, i.e.
+    # splitting read/write etc. to avoid potential contention.
+    self._cv = threading.Condition()
+
+    self._monitor.start()
+
+  def all_models(self, tag) -> list[Any]:
+    return self._models[tag]
+
+  def enter_isolation_mode(self, tag: str, ticket_num: int) -> bool:
+    """Attempts to start profiling an unknown model with the GPU to itself.
+
+    Must be called while holding ``self._cv``: ``Condition.wait`` raises
+    RuntimeError if the lock is not held.
+
+    Args:
+      tag: Tag of the unknown model about to be loaded.
+      ticket_num: Caller's queue ticket, used only for logging.
+
+    Returns:
+      True if isolation mode was entered. In that case every tracked model
+      has been deleted, the job is counted as active, and the current usage
+      is stored as the isolation baseline. False if other jobs were still
+      active. In that case this waits on ``self._cv`` once, and the caller
+      must re-check its queue position and priority.
+    """
+    if self._total_active_jobs > 0:
+      logger.info(
+          "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num)
+      self._cv.wait()
+      # return False since we have waited and need to re-evaluate
+      # in caller to make sure our priority is still valid.
+      return False
+
+    logger.info("Unknown model %s detected. Flushing GPU.", tag)
+    # With no active jobs, all tracked instances should be idle and can be
+    # freed so the unknown model is measured alone.
+    self._delete_all_models()
+
+    self._isolation_mode = True
+    self._total_active_jobs += 1
+    # Read the baseline after the flush (which refreshes the monitor), then
+    # restart peak tracking so the profiled peak starts from here.
+    self._isolation_baseline, _, _ = self._monitor.get_stats()
+    self._monitor.reset_peak()
+    return True
+
+  def should_spawn_model(self, tag: str, ticket_num: int) -> bool:
+    curr, _, total = self._monitor.get_stats()
+    est_cost = self._estimator.get_estimate(tag)
+    limit = total * (1 - self._slack_percentage)
+
+    # Use current usage for capacity check (ignore old spikes)
+    if (curr + self._pending_reservations + est_cost) <= limit:
+      self._pending_reservations += est_cost
+      self._total_active_jobs += 1
+      self._active_counts[tag] += 1
+      return True
+
+    # Evict to make space (passing tag to check demand/existence)
+    if self._evict_to_make_space(limit, est_cost, requesting_tag=tag):
+      return True
+
+    # Manually log status for debugging if we are going to wait
+    idle_count = 0
+    other_idle_count = 0
+    for item in self._idle_lru.items():
+      if item[1][0] == tag:
+        idle_count += 1
+      else:
+        other_idle_count += 1
+    total_model_count = 0
+    for _, instances in self._models.items():
+      total_model_count += len(instances)
+    curr, _, _ = self._monitor.get_stats()
+    logger.info(
+        "Waiting for resources to free up: "
+        "tag=%s ticket num%s model count=%s "
+        "idle count=%s resource usage=%.1f MB "
+        "total models count=%s other idle=%s",
+        tag,
+        ticket_num,
+        len(self._models[tag]),
+        idle_count,
+        curr,
+        total_model_count,
+        other_idle_count)
+    self._cv.wait(timeout=10.0)
+    return False
+
+  def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
+    current_priority = 0 if self._estimator.is_unknown(tag) else 1
+    ticket_num = next(self._ticket_counter)
+    my_id = object()
+
+    with self._cv:
+      # FAST PATH: Grab from idle LRU if available
+      if not self._isolation_mode:
+        cached_instance = self._try_grab_from_lru(tag)
+        if cached_instance:
+          return cached_instance
+
+      # SLOW PATH: Enqueue and wait for turn to acquire model,
+      # with unknown models having priority and order enforced
+      # by ticket number as FIFO.
+      logger.info(
+          "Acquire Queued: tag=%s, priority=%d "
+          "total models count=%s ticket num=%s",
+          tag,
+          current_priority,
+          len(self._models[tag]),
+          ticket_num)
+      heapq.heappush(
+          self._wait_queue, (current_priority, ticket_num, my_id, tag))
+
+      should_spawn = False
+      est_cost = 0.0
+      is_unknown = False
+
+      try:
+        while True:
+          if not self._wait_queue or self._wait_queue[0][2] is not my_id:
+            logger.info(
+                "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
+            self._cv.wait()
+            continue
+
+          # Re-evaluate priority in case model became known during wait
+          real_is_unknown = self._estimator.is_unknown(tag)
+          real_priority = 0 if real_is_unknown else 1
+
+          # If priority changed, reinsert into queue and wait
+          if current_priority != real_priority:
+            heapq.heappop(self._wait_queue)
+            current_priority = real_priority
+            heapq.heappush(
+                self._wait_queue, (current_priority, ticket_num, my_id, tag))
+            self._cv.notify_all()
+            continue
+
+          # Try grab from LRU again in case model was released during wait
+          cached_instance = self._try_grab_from_lru(tag)
+          if cached_instance:
+            return cached_instance
+
+          is_unknown = real_is_unknown
+
+          # Path A: Isolation
+          if is_unknown:
+            if self.enter_isolation_mode(tag, ticket_num):
+              should_spawn = True
+              break
+            else:
+              # We waited, need to re-evaluate our turn
+              # because priority may have changed during the wait
+              continue
+
+          # Path B: Concurrent
+          else:
+            if self._isolation_mode:
+              logger.info(
+                  "Waiting due to isolation in progress: tag=%s ticket num%s",
+                  tag,
+                  ticket_num)
+              self._cv.wait()
+              continue
+
+            if self.should_spawn_model(tag, ticket_num):
+              should_spawn = True
+              est_cost = self._estimator.get_estimate(tag)
+              break
+            else:
+              # We waited, need to re-evaluate our turn
+              # because priority may have changed during the wait
+              continue
+
+      finally:
+        # Remove self from wait queue once done
+        if self._wait_queue and self._wait_queue[0][2] is my_id:
+          heapq.heappop(self._wait_queue)
+        else:
+          for i, item in enumerate(self._wait_queue):
+            if item[2] is my_id:
+              self._wait_queue.pop(i)
+              heapq.heapify(self._wait_queue)
+        self._cv.notify_all()
+
+      if should_spawn:
+        return self._spawn_new_model(tag, loader_func, is_unknown, est_cost)
+
+  def release_model(self, tag: str, instance: Any):
+    with self._cv:
+      try:
+        self._total_active_jobs -= 1
+        if self._active_counts[tag] > 0:
+          self._active_counts[tag] -= 1
+
+        self._idle_lru[id(instance)] = (tag, instance, time.time())
+
+        # Update estimator with latest stats
+        _, peak_during_job, _ = self._monitor.get_stats()
+
+        if self._isolation_mode and self._active_counts[tag] == 0:
+          # For isolation mode, we directly set the initial estimate
+          # so that we can quickly learn the model cost.
+          cost = max(0, peak_during_job - self._isolation_baseline)
+          self._estimator.set_initial_estimate(tag, cost)
+          self._isolation_mode = False
+          self._isolation_baseline = 0.0
+        else:
+          # Regular update for known models
+          snapshot = {
+              t: len(instances)
+              for t, instances in self._models.items() if len(instances) > 0
+          }
+          if snapshot:
+            self._estimator.add_observation(snapshot, peak_during_job)
+
+      finally:
+        self._cv.notify_all()
+
+  def _try_grab_from_lru(self, tag: str) -> Any:
+    target_key = None
+    target_instance = None
+
+    for key, (t, instance, _) in reversed(self._idle_lru.items()):
+      if t == tag:
+        target_key = key
+        target_instance = instance
+        break
+
+    if target_instance:
+      # Found an idle model, remove from LRU and return
+      del self._idle_lru[target_key]
+      self._active_counts[tag] += 1
+      self._total_active_jobs += 1
+      return target_instance
+
+    logger.info("No idle model found for tag: %s", tag)
+    return None
+
+  def _evict_to_make_space(
+      self, limit: float, est_cost: float, requesting_tag: str) -> bool:
+    """
+    Evicts models based on Demand Magnitude + Tiers.
+    Crucially: If we have 0 active copies of 'requesting_tag', we FORCE 
eviction
+    of the lowest-demand candidate to avoid starvation.
+    Returns True if space was made, False otherwise.
+    """
+    curr, _, _ = self._monitor.get_stats()
+    projected_usage = curr + self._pending_reservations + est_cost
+
+    if projected_usage <= limit:
+      # Memory usage changed and we are already under limit
+      return True
+
+    now = time.time()
+
+    # Calculate the demand from the wait queue
+    demand_map = Counter()
+    for item in self._wait_queue:
+      demand_map[item[3]] += 1
+
+    my_demand = demand_map[requesting_tag]
+    am_i_starving = len(self._models[requesting_tag]) == 0
+
+    candidates = []
+    for key, (tag, instance, release_time) in self._idle_lru.items():
+      candidate_demand = demand_map[tag]
+
+      # TODO: Try to avoid churn if demand is similar

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `TODO` here is valid. To avoid churn when demand is similar, you could 
introduce a demand threshold or a hysteresis mechanism. For example, only evict 
if `my_demand > candidate_demand + demand_margin`. This would prevent evicting 
a model with slightly lower demand just to make space for another, which might 
be evicted soon after.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to