This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 22d1c576272 Add OOM handling for RunInference with model manager (#37557)
22d1c576272 is described below

commit 22d1c576272af62d691ca3124c99b9c563fbbb8c
Author: RuiLong J. <[email protected]>
AuthorDate: Thu Feb 12 06:51:11 2026 -0800

    Add OOM handling for RunInference with model manager (#37557)
    
    * Add OOM protection handling for RunInference
    
    * Make sure we release the model regardless
    
    * Add testing coverage
    
    * Lint and make logging optional
    
    * Pass verbose logging setting from model manager to estimator
    
    * Lint
    
    * Prevent flakes on EOFError
    
    * Enforce batch size
---
 sdks/python/apache_beam/ml/inference/base.py       | 39 ++++++++++++--
 sdks/python/apache_beam/ml/inference/base_test.py  | 26 +++++++++
 .../apache_beam/ml/inference/model_manager.py      | 63 ++++++++++++++--------
 3 files changed, 102 insertions(+), 26 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 1c3f0918baf..ef5d15264b5 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -1330,6 +1330,30 @@ class _PostProcessingModelHandler(Generic[ExampleT,
     return self._base.get_postprocess_fns() + [self._postprocess_fn]
 
 
+class OOMProtectedFn:
+  def __init__(self, func):
+    self.func = func
+
+  def __call__(self, *args, **kwargs):
+    try:
+      return self.func(*args, **kwargs)
+    except Exception as e:
+      # Check string to avoid hard import dependency
+      if 'out of memory' in str(e) and 'CUDA' in str(e):
+        logging.warning("Caught CUDA OOM during operation. Cleaning memory.")
+        try:
+          import gc
+
+          import torch
+          gc.collect()
+          torch.cuda.empty_cache()
+        except ImportError:
+          pass
+        except Exception as cleanup_error:
+          logging.error("Failed to clean up CUDA memory: %s", cleanup_error)
+      raise e
+
+
 class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
                                                           Iterable[ExampleT]]],
                                    beam.PCollection[PredictionT]]):
@@ -1831,7 +1855,9 @@ class _ProxyLoader:
     unique_tag = self.model_tag + '_' + uuid.uuid4().hex
     # Ensure that each model loaded in a different process for parallelism
     multi_process_shared.MultiProcessShared(
-        self.loader_func, tag=unique_tag, always_proxy=True,
+        OOMProtectedFn(self.loader_func),
+        tag=unique_tag,
+        always_proxy=True,
         spawn_process=True).acquire()
     # Only return the tag to avoid pickling issues with the model itself.
     return unique_tag
@@ -2021,10 +2047,13 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
         unique_tag = model
         model = multi_process_shared.MultiProcessShared(
             lambda: None, tag=model, always_proxy=True).acquire()
-      result_generator = self._model_handler.run_inference(
-          batch, model, inference_args)
-      if self.use_model_manager:
-        self._model.release_model(self._model_tag, unique_tag)
+      try:
+        result_generator = (OOMProtectedFn(self._model_handler.run_inference))(
+            batch, model, inference_args)
+      finally:
+        # Always release the model so that it can be reloaded.
+        if self.use_model_manager:
+          self._model.release_model(self._model_tag, unique_tag)
     except BaseException as e:
       if self._metrics_collector:
         self._metrics_collector.failed_batches_counter.inc()
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index feccd8b0f12..8236ac5c1e5 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -20,6 +20,7 @@ import math
 import multiprocessing
 import os
 import pickle
+import random
 import sys
 import tempfile
 import time
@@ -2338,6 +2339,31 @@ class ModelManagerTest(unittest.TestCase):
           })
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  @unittest.skipIf(
+      not try_import_model_manager(), 'Model Manager not available')
+  def test_run_inference_impl_with_model_manager_oom(self):
+    class OOMFakeModelHandler(SimpleFakeModelHandler):
+      def run_inference(
+          self,
+          batch: Sequence[int],
+          model: FakeModel,
+          inference_args=None) -> Iterable[int]:
+        if random.random() < 0.8:
+          raise MemoryError("Simulated OOM")
+        for example in batch:
+          yield model.predict(example)
+
+      def batch_elements_kwargs(self):
+        return {'min_batch_size': 1, 'max_batch_size': 1}
+
+    with self.assertRaises(Exception):
+      with TestPipeline() as pipeline:
+        examples = [1, 5, 3, 10]
+        pcoll = pipeline | 'start' >> beam.Create(examples)
+        actual = pcoll | base.RunInference(
+            OOMFakeModelHandler(), use_model_manager=True)
+        assert_that(actual, equal_to([2, 6, 4, 11]), label='assert:inferences')
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py
index bf7c6a43ba6..186611984df 100644
--- a/sdks/python/apache_beam/ml/inference/model_manager.py
+++ b/sdks/python/apache_beam/ml/inference/model_manager.py
@@ -176,14 +176,23 @@ class ResourceEstimator:
   individual models based on aggregate system memory readings and the
   configuration of active models at that time.
   """
-  def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5):
+  def __init__(
+      self,
+      smoothing_factor: float = 0.2,
+      min_data_points: int = 5,
+      verbose_logging: bool = False):
     self.smoothing_factor = smoothing_factor
     self.min_data_points = min_data_points
+    self.verbose_logging = verbose_logging
     self.estimates: Dict[str, float] = {}
     self.history = defaultdict(lambda: deque(maxlen=20))
     self.known_models = set()
     self._lock = threading.Lock()
 
+  def logging_info(self, message: str, *args):
+    if self.verbose_logging:
+      logger.info(message, *args)
+
   def is_unknown(self, model_tag: str) -> bool:
     with self._lock:
       return model_tag not in self.estimates
@@ -196,7 +205,7 @@ class ResourceEstimator:
     with self._lock:
       self.estimates[model_tag] = cost
       self.known_models.add(model_tag)
-      logger.info("Initial Profile for %s: %s MB", model_tag, cost)
+      self.logging_info("Initial Profile for %s: %s MB", model_tag, cost)
 
   def add_observation(
       self, active_snapshot: Dict[str, int], peak_memory: float):
@@ -207,7 +216,7 @@ class ResourceEstimator:
     else:
       model_list = "\t- None"
 
-    logger.info(
+    self.logging_info(
         "Adding Observation:\n PeakMemory: %.1f MB\n  Instances:\n%s",
         peak_memory,
         model_list)
@@ -256,7 +265,7 @@ class ResourceEstimator:
       # Not enough data to solve yet
       return
 
-    logger.info(
+    self.logging_info(
         "Solving with %s total observations for %s models.",
         len(A),
         len(unique))
@@ -280,9 +289,9 @@ class ResourceEstimator:
         else:
           self.estimates[model] = calculated_cost
 
-        logger.info(
+        self.logging_info(
             "Updated Estimate for %s: %.1f MB", model, self.estimates[model])
-      logger.info("System Bias: %s MB", bias)
+      self.logging_info("System Bias: %s MB", bias)
 
     except Exception as e:
       logger.error("Solver failed: %s", e)
@@ -321,10 +330,13 @@ class ModelManager:
       eviction_cooldown_seconds: float = 10.0,
       min_model_copies: int = 1,
       wait_timeout_seconds: float = 300.0,
-      lock_timeout_seconds: float = 60.0):
+      lock_timeout_seconds: float = 60.0,
+      verbose_logging: bool = False):
 
     self._estimator = ResourceEstimator(
-        min_data_points=min_data_points, smoothing_factor=smoothing_factor)
+        min_data_points=min_data_points,
+        smoothing_factor=smoothing_factor,
+        verbose_logging=verbose_logging)
     self._monitor = monitor if monitor else GPUMonitor(
         poll_interval=poll_interval, peak_window_seconds=peak_window_seconds)
     self._slack_percentage = slack_percentage
@@ -333,6 +345,7 @@ class ModelManager:
     self._min_model_copies = min_model_copies
     self._wait_timeout_seconds = wait_timeout_seconds
     self._lock_timeout_seconds = lock_timeout_seconds
+    self._verbose_logging = verbose_logging
 
     # Resource State
     self._models = defaultdict(list)
@@ -361,20 +374,24 @@ class ModelManager:
 
     self._monitor.start()
 
+  def logging_info(self, message: str, *args):
+    if self._verbose_logging:
+      logger.info(message, *args)
+
   def all_models(self, tag) -> list[Any]:
     return self._models[tag]
 
   # Should hold _cv lock when calling
   def try_enter_isolation_mode(self, tag: str, ticket_num: int) -> bool:
     if self._total_active_jobs > 0:
-      logger.info(
+      self.logging_info(
           "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num)
       self._cv.wait(timeout=self._lock_timeout_seconds)
       # return False since we have waited and need to re-evaluate
       # in caller to make sure our priority is still valid.
       return False
 
-    logger.info("Unknown model %s detected. Flushing GPU.", tag)
+    self.logging_info("Unknown model %s detected. Flushing GPU.", tag)
     self._delete_all_models()
 
     self._isolation_mode = True
@@ -412,7 +429,7 @@ class ModelManager:
     for _, instances in self._models.items():
       total_model_count += len(instances)
     curr, _, _ = self._monitor.get_stats()
-    logger.info(
+    self.logging_info(
         "Waiting for resources to free up: "
         "tag=%s ticket num%s model count=%s "
         "idle count=%s resource usage=%.1f MB "
@@ -462,7 +479,7 @@ class ModelManager:
       # SLOW PATH: Enqueue and wait for turn to acquire model,
       # with unknown models having priority and order enforced
       # by ticket number as FIFO.
-      logger.info(
+      self.logging_info(
           "Acquire Queued: tag=%s, priority=%d "
           "total models count=%s ticket num=%s",
           tag,
@@ -484,7 +501,7 @@ class ModelManager:
                 f"after {wait_time_elapsed:.1f} seconds.")
           if not self._wait_queue or self._wait_queue[
               0].ticket_num != ticket_num:
-            logger.info(
+            self.logging_info(
                 "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
             self._wait_in_queue(my_ticket)
             continue
@@ -520,7 +537,7 @@ class ModelManager:
           # Path B: Concurrent
           else:
             if self._isolation_mode:
-              logger.info(
+              self.logging_info(
                   "Waiting due to isolation in progress: tag=%s ticket num%s",
                   tag,
                   ticket_num)
@@ -596,7 +613,7 @@ class ModelManager:
       self._total_active_jobs += 1
       return target_instance
 
-    logger.info("No idle model found for tag: %s", tag)
+    self.logging_info("No idle model found for tag: %s", tag)
     return None
 
   def _evict_to_make_space(
@@ -671,17 +688,21 @@ class ModelManager:
     if isinstance(instance, str):
       # If the instance is a string, it's a uuid used
       # to retrieve the model from MultiProcessShared
-      multi_process_shared.MultiProcessShared(
-          lambda: "N/A", tag=instance).unsafe_hard_delete()
+      try:
+        multi_process_shared.MultiProcessShared(
+            lambda: "N/A", tag=instance).unsafe_hard_delete()
+      except (EOFError, OSError, BrokenPipeError):
+        # This can happen even in normal operation.
+        pass
     if hasattr(instance, 'mock_model_unsafe_hard_delete'):
       # Call the mock unsafe hard delete method for testing
       instance.mock_model_unsafe_hard_delete()
     del instance
 
   def _perform_eviction(self, key: str, tag: str, instance: Any, score: int):
-    logger.info("Evicting Model: %s (Score %d)", tag, score)
+    self.logging_info("Evicting Model: %s (Score %d)", tag, score)
     curr, _, _ = self._monitor.get_stats()
-    logger.info("Resource Usage Before Eviction: %.1f MB", curr)
+    self.logging_info("Resource Usage Before Eviction: %.1f MB", curr)
 
     if key in self._idle_lru:
       del self._idle_lru[key]
@@ -697,7 +718,7 @@ class ModelManager:
     self._monitor.refresh()
     self._monitor.reset_peak()
     curr, _, _ = self._monitor.get_stats()
-    logger.info("Resource Usage After Eviction: %.1f MB", curr)
+    self.logging_info("Resource Usage After Eviction: %.1f MB", curr)
 
   def _spawn_new_model(
       self,
@@ -707,7 +728,7 @@ class ModelManager:
       est_cost: float) -> Any:
     try:
       with self._cv:
-        logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
+        self.logging_info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
         baseline_snap, _, _ = self._monitor.get_stats()
         instance = loader_func()
         _, peak_during_load, _ = self._monitor.get_stats()

Reply via email to