AMOOOMA commented on code in PR #37113:
URL: https://github.com/apache/beam/pull/37113#discussion_r2747629755


##########
sdks/python/apache_beam/ml/inference/model_manager.py:
##########
@@ -0,0 +1,730 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module for managing ML models in Apache Beam pipelines.
+
+This module provides classes and functions to efficiently manage multiple
+machine learning models within Apache Beam pipelines. It includes functionality
+for loading, caching, and updating models using multi-process shared memory,
+ensuring that models are reused across different workers to optimize resource
+usage and performance.
+"""
+
+import gc
+import heapq
+import itertools
+import logging
+import subprocess
+import threading
+import time
+from collections import Counter
+from collections import OrderedDict
+from collections import defaultdict
+from collections import deque
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from scipy.optimize import nnls
+
+from apache_beam.utils import multi_process_shared
+
+logger = logging.getLogger(__name__)
+
+
+class GPUMonitor:
+  """Monitors GPU memory usage in a separate thread using nvidia-smi.
+
+  This class continuously polls GPU memory statistics to track current usage
+  and peak usage over a sliding time window. It serves as the source of truth
+  for the ModelManager's resource decisions.
+
+  Attributes:
+    fallback_memory_mb: Default total memory if hardware detection fails.
+    poll_interval: Seconds between memory checks.
+    peak_window_seconds: Duration to track peak memory usage.
+  """
+  def __init__(
+      self,
+      fallback_memory_mb: float = 16000.0,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0):
+    self._current_usage = 0.0
+    self._peak_usage = 0.0
+    self._total_memory = fallback_memory_mb
+    self._poll_interval = poll_interval
+    self._peak_window_seconds = peak_window_seconds
+    self._memory_history = deque()
+    self._running = False
+    self._thread = None
+    self._lock = threading.Lock()
+
+  def _detect_hardware(self):
+    try:
+      cmd = [
+          "nvidia-smi",
+          "--query-gpu=memory.total",
+          "--format=csv,noheader,nounits"
+      ]
+      output = subprocess.check_output(cmd, text=True).strip()
+      self._total_memory = float(output)
+      return True
+    except (FileNotFoundError, subprocess.CalledProcessError):
+      logger.warning(
+          "nvidia-smi not found or failed. Defaulting total memory to %s MB",
+          self._total_memory)
+      return False
+    except Exception as e:
+      logger.warning(
+          "Error parsing nvidia-smi output: %s. "
+          "Defaulting total memory to %s MB",
+          e,
+          self._total_memory)
+      return False
+
+  def start(self):
+    self._gpu_available = self._detect_hardware()
+    if self._running or not self._gpu_available:
+      return
+    self._running = True
+    self._thread = threading.Thread(target=self._poll_loop, daemon=True)
+    self._thread.start()
+
+  def stop(self):
+    self._running = False
+    if self._thread:
+      self._thread.join()
+
+  def reset_peak(self):
+    with self._lock:
+      now = time.time()
+      self._memory_history.clear()
+      self._memory_history.append((now, self._current_usage))
+      self._peak_usage = self._current_usage
+
+  def get_stats(self) -> Tuple[float, float, float]:
+    with self._lock:
+      return self._current_usage, self._peak_usage, self._total_memory
+
+  def refresh(self):
+    """Forces an immediate poll of the GPU."""
+    usage = self._get_nvidia_smi_used()
+    now = time.time()
+    with self._lock:
+      self._current_usage = usage
+      self._memory_history.append((now, usage))
+      # Recalculate peak immediately
+      while self._memory_history and (now - self._memory_history[0][0]
+                                      > self._peak_window_seconds):
+        self._memory_history.popleft()
+      self._peak_usage = (
+          max(m for _, m in self._memory_history)
+          if self._memory_history else usage)
+
+  def _get_nvidia_smi_used(self) -> float:
+    try:
+      cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits"
+      output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip()
+      free_memory = float(output)
+      return self._total_memory - free_memory
+    except Exception:
+      return 0.0
+
+  def _poll_loop(self):
+    while self._running:
+      usage = self._get_nvidia_smi_used()
+      now = time.time()
+      with self._lock:
+        self._current_usage = usage
+        self._memory_history.append((now, usage))
+        while self._memory_history and (now - self._memory_history[0][0]
+                                        > self._peak_window_seconds):
+          self._memory_history.popleft()
+        self._peak_usage = (
+            max(m for _, m in self._memory_history)
+            if self._memory_history else usage)
+      time.sleep(self._poll_interval)
+
+
+class ResourceEstimator:
+  """Estimates individual model memory usage using statistical observation.
+
+  Uses Non-Negative Least Squares (NNLS) to deduce the memory footprint of
+  individual models based on aggregate system memory readings and the
+  configuration of active models at that time.
+  """
+  def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5):
+    self.smoothing_factor = smoothing_factor
+    self.min_data_points = min_data_points
+    self.estimates: Dict[str, float] = {}
+    self.history = defaultdict(lambda: deque(maxlen=20))
+    self.known_models = set()
+    self._lock = threading.Lock()
+
+  def is_unknown(self, model_tag: str) -> bool:
+    with self._lock:
+      return model_tag not in self.estimates
+
+  def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float:
+    with self._lock:
+      return self.estimates.get(model_tag, default_mb)
+
+  def set_initial_estimate(self, model_tag: str, cost: float):
+    with self._lock:
+      self.estimates[model_tag] = cost
+      self.known_models.add(model_tag)
+      logger.info("Initial Profile for %s: %s MB", model_tag, cost)
+
+  def add_observation(
+      self, active_snapshot: Dict[str, int], peak_memory: float):
+    if active_snapshot:
+      model_list = "\n".join(
+          f"\t- {model}: {count}"
+          for model, count in sorted(active_snapshot.items()))
+    else:
+      model_list = "\t- None"
+
+    logger.info(
+        "Adding Observation:\n PeakMemory: %.1f MB\n  Instances:\n%s",
+        peak_memory,
+        model_list)
+    if not active_snapshot:
+      return
+    with self._lock:
+      config_key = tuple(sorted(active_snapshot.items()))
+      self.history[config_key].append(peak_memory)
+      for tag in active_snapshot:
+        self.known_models.add(tag)
+      self._solve()
+
+  def _solve(self):
+    """
+    Solves Ax=b using raw readings (no pre-averaging) and NNLS.
+    This creates a 'tall' matrix A where every memory reading is
+    a separate equation.
+    """
+    unique = sorted(list(self.known_models))
+
+    # We need to build the matrix first to know if we have enough data points
+    A, b = [], []
+
+    for config_key, mem_values in self.history.items():
+      if not mem_values:
+        continue
+
+      # 1. Create the feature row for this configuration ONCE
+      # (It represents the model counts + bias)
+      counts = dict(config_key)
+      feature_row = [counts.get(model, 0) for model in unique]
+      feature_row.append(1)  # Bias column
+
+      # 2. Add a separate row to the matrix for EVERY individual reading
+      # Instead of averaging, we flatten the history into the matrix
+      for reading in mem_values:
+        A.append(feature_row)  # The inputs (models) stay the same
+        b.append(reading)  # The output (memory) varies due to noise
+
+    # Convert to numpy for SciPy
+    A = np.array(A)
+    b = np.array(b)
+
+    if (len(self.history.keys()) < len(unique) + 1 or
+        len(A) < self.min_data_points):
+      # Not enough data to solve yet
+      return
+
+    logger.info(
+        "Solving with %s total observations for %s models.",
+        len(A),
+        len(unique))
+
+    try:
+      # Solve using Non-Negative Least Squares
+      # x will be >= 0
+      x, _ = nnls(A, b)
+
+      weights = x[:-1]
+      bias = x[-1]
+
+      for i, model in enumerate(unique):
+        calculated_cost = weights[i]
+
+        if model in self.estimates:
+          old = self.estimates[model]
+          new = (old * (1 - self.smoothing_factor)) + (
+              calculated_cost * self.smoothing_factor)
+          self.estimates[model] = new
+        else:
+          self.estimates[model] = calculated_cost
+
+        logger.info(
+            "Updated Estimate for %s: %.1f MB", model, self.estimates[model])
+      logger.info("System Bias: %s MB", bias)
+
+    except Exception as e:
+      logger.error("Solver failed: %s", e)
+
+
+class ModelManager:
+  """Manages model lifecycles, caching, and resource arbitration.
+
+  This class acts as the central controller for acquiring model instances.
+  It handles:
+  1. LRU Caching of idle models.
+  2. Resource estimation and admission control (preventing OOM).
+  3. Dynamic eviction of low-priority models (priority determined by the
+     count of pending requests) when space is needed.
+  4. 'Isolation Mode' for safely profiling unknown models.
+  """
+  _lock = threading.Lock()
+
+  def __init__(
+      self,
+      monitor: Optional['GPUMonitor'] = None,
+      slack_percentage: float = 0.10,
+      poll_interval: float = 0.5,
+      peak_window_seconds: float = 30.0,
+      min_data_points: int = 5,
+      smoothing_factor: float = 0.2,
+      eviction_cooldown_seconds: float = 10.0,
+      min_model_copies: int = 1):
+
+    self._estimator = ResourceEstimator(
+        min_data_points=min_data_points, smoothing_factor=smoothing_factor)
+    self._monitor = monitor if monitor else GPUMonitor(
+        poll_interval=poll_interval, peak_window_seconds=peak_window_seconds)
+    self._slack_percentage = slack_percentage
+
+    self._eviction_cooldown = eviction_cooldown_seconds
+    self._min_model_copies = min_model_copies
+
+    # Resource State
+    self._models = defaultdict(list)
+    # Idle LRU used to track released models that
+    # can be freed or reused upon request.
+    self._idle_lru = OrderedDict()
+    self._active_counts = Counter()
+    self._total_active_jobs = 0
+    self._pending_reservations = 0.0
+
+    # Isolation state used to profile unknown models,
+    # ensuring they run alone to get accurate readings.
+    # isolation_baseline represents the GPU usage before
+    # loading the unknown model.
+    self._isolation_mode = False
+    self._isolation_baseline = 0.0
+
+    # Waiting Queue and Ticketing to make sure we have fair ordering
+    # and also priority for unknown models.
+    self._wait_queue = []
+    self._ticket_counter = itertools.count()
+    # TODO: Consider making the wait smarter, e.g.
+    # splitting read/write etc., to avoid potential contention.
+    self._cv = threading.Condition()
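+    # Wait-queue entries are (priority, ticket_num, id, tag): unknown models
+    # get priority 0 and known models priority 1, so for example (0, 7, ...)
+    # is served before (1, 3, ...); ties fall back to FIFO ticket order.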
+
+    self._monitor.start()
+
+  def all_models(self, tag: str) -> list[Any]:
+    return self._models[tag]
+
+  def enter_isolation_mode(self, tag: str, ticket_num: int) -> bool:
+    if self._total_active_jobs > 0:
+      logger.info(
+          "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num)
+      self._cv.wait()
+      # return False since we have waited and need to re-evaluate
+      # in caller to make sure our priority is still valid.
+      return False
+
+    logger.info("Unknown model %s detected. Flushing GPU.", tag)
+    self._delete_all_models()
+
+    self._isolation_mode = True
+    self._total_active_jobs += 1
+    self._isolation_baseline, _, _ = self._monitor.get_stats()
+    self._monitor.reset_peak()
+    return True
+
+  def should_spawn_model(self, tag: str, ticket_num: int) -> bool:
+    curr, _, total = self._monitor.get_stats()
+    est_cost = self._estimator.get_estimate(tag)
+    limit = total * (1 - self._slack_percentage)
+
+    # Use current usage for capacity check (ignore old spikes)
+    if (curr + self._pending_reservations + est_cost) <= limit:
+      self._pending_reservations += est_cost
+      self._total_active_jobs += 1
+      self._active_counts[tag] += 1
+      return True
+
+    # Evict to make space (passing tag to check demand/existence)
+    if self._evict_to_make_space(limit, est_cost, requesting_tag=tag):
+      return True
+
+    # Manually log status for debugging if we are going to wait
+    idle_count = 0
+    other_idle_count = 0
+    for item in self._idle_lru.items():
+      if item[1][0] == tag:
+        idle_count += 1
+      else:
+        other_idle_count += 1
+    total_model_count = 0
+    for _, instances in self._models.items():
+      total_model_count += len(instances)
+    curr, _, _ = self._monitor.get_stats()
+    logger.info(
+        "Waiting for resources to free up: "
+        "tag=%s ticket num%s model count=%s "
+        "idle count=%s resource usage=%.1f MB "
+        "total models count=%s other idle=%s",
+        tag,
+        ticket_num,
+        len(self._models[tag]),
+        idle_count,
+        curr,
+        total_model_count,
+        other_idle_count)
+    self._cv.wait(timeout=10.0)
+    return False
+
+  def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
+    current_priority = 0 if self._estimator.is_unknown(tag) else 1
+    ticket_num = next(self._ticket_counter)
+    my_id = object()
+
+    with self._cv:
+      # FAST PATH: Grab from idle LRU if available
+      if not self._isolation_mode:
+        cached_instance = self._try_grab_from_lru(tag)
+        if cached_instance:
+          return cached_instance
+
+      # SLOW PATH: Enqueue and wait for turn to acquire model,
+      # with unknown models having priority and order enforced
+      # by ticket number as FIFO.
+      logger.info(
+          "Acquire Queued: tag=%s, priority=%d "
+          "total models count=%s ticket num=%s",
+          tag,
+          current_priority,
+          len(self._models[tag]),
+          ticket_num)
+      heapq.heappush(
+          self._wait_queue, (current_priority, ticket_num, my_id, tag))
+
+      should_spawn = False
+      est_cost = 0.0
+      is_unknown = False
+
+      try:
+        while True:
+          if not self._wait_queue or self._wait_queue[0][2] is not my_id:
+            logger.info(
+                "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
+            self._cv.wait()
+            continue
+
+          # Re-evaluate priority in case model became known during wait
+          real_is_unknown = self._estimator.is_unknown(tag)
+          real_priority = 0 if real_is_unknown else 1
+
+          # If priority changed, reinsert into queue and wait
+          if current_priority != real_priority:
+            heapq.heappop(self._wait_queue)
+            current_priority = real_priority
+            heapq.heappush(
+                self._wait_queue, (current_priority, ticket_num, my_id, tag))
+            self._cv.notify_all()
+            continue
+
+          # Try grab from LRU again in case model was released during wait
+          cached_instance = self._try_grab_from_lru(tag)
+          if cached_instance:
+            return cached_instance
+
+          is_unknown = real_is_unknown
+
+          # Path A: Isolation
+          if is_unknown:
+            if self.enter_isolation_mode(tag, ticket_num):
+              should_spawn = True
+              break
+            else:
+              # We waited, need to re-evaluate our turn
+              # because priority may have changed during the wait
+              continue
+
+          # Path B: Concurrent
+          else:
+            if self._isolation_mode:
+              logger.info(
+                  "Waiting due to isolation in progress: tag=%s ticket num%s",
+                  tag,
+                  ticket_num)
+              self._cv.wait()
+              continue
+
+            if self.should_spawn_model(tag, ticket_num):
+              should_spawn = True
+              est_cost = self._estimator.get_estimate(tag)
+              break
+            else:
+              # We waited, need to re-evaluate our turn
+              # because priority may have changed during the wait
+              continue
+
+      finally:
+        # Remove self from wait queue once done
+        if self._wait_queue and self._wait_queue[0][2] is my_id:
+          heapq.heappop(self._wait_queue)
+        else:
+          for i, item in enumerate(self._wait_queue):
+            if item[2] is my_id:
+              self._wait_queue.pop(i)
+              heapq.heapify(self._wait_queue)

Review Comment:
   This only happens when things are not working, so I'm leaving it as is.


