This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 22d1c576272 Add OOM handling for RunInference with model manager (#37557)
22d1c576272 is described below
commit 22d1c576272af62d691ca3124c99b9c563fbbb8c
Author: RuiLong J. <[email protected]>
AuthorDate: Thu Feb 12 06:51:11 2026 -0800
Add OOM handling for RunInference with model manager (#37557)
* Add OOM protection for RunInference (pattern sketched below, after the diffstat)
* Make sure we release the model regardless
* Add testing coverage
* Lint and make logging optional
* Pass verbose logging setting from model manager to estimator
* Lint
* Prevent flakes on EOFError
* Enforce batch size
---
sdks/python/apache_beam/ml/inference/base.py | 39 ++++++++++++--
sdks/python/apache_beam/ml/inference/base_test.py | 26 +++++++++
.../apache_beam/ml/inference/model_manager.py | 63 ++++++++++++++--------
3 files changed, 102 insertions(+), 26 deletions(-)
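
The core of the change is a small callable wrapper, OOMProtectedFn, that
frees cached GPU memory when a CUDA out-of-memory error surfaces and then
re-raises so the bundle can be retried against a clean device. A condensed,
standalone sketch of the pattern (simplified from the base.py hunk below;
cleanup error handling is elided and the wrapped callable is arbitrary):

    import gc
    import logging

    class OOMProtectedFn:
      def __init__(self, func):
        self.func = func

      def __call__(self, *args, **kwargs):
        try:
          return self.func(*args, **kwargs)
        except Exception as e:
          # Match on the message string so torch stays an optional dependency.
          if 'out of memory' in str(e) and 'CUDA' in str(e):
            logging.warning('Caught CUDA OOM during operation. Cleaning memory.')
            try:
              import torch
              gc.collect()
              torch.cuda.empty_cache()
            except ImportError:
              pass
          raise

    # Hypothetical usage (load_model is illustrative, not a Beam API):
    # protected = OOMProtectedFn(load_model); model = protected()
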
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 1c3f0918baf..ef5d15264b5 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -1330,6 +1330,30 @@ class _PostProcessingModelHandler(Generic[ExampleT,
return self._base.get_postprocess_fns() + [self._postprocess_fn]
+class OOMProtectedFn:
+ def __init__(self, func):
+ self.func = func
+
+ def __call__(self, *args, **kwargs):
+ try:
+ return self.func(*args, **kwargs)
+ except Exception as e:
+ # Check string to avoid hard import dependency
+ if 'out of memory' in str(e) and 'CUDA' in str(e):
+ logging.warning("Caught CUDA OOM during operation. Cleaning memory.")
+ try:
+ import gc
+
+ import torch
+ gc.collect()
+ torch.cuda.empty_cache()
+ except ImportError:
+ pass
+ except Exception as cleanup_error:
+ logging.error("Failed to clean up CUDA memory: %s", cleanup_error)
+ raise e
+
+
class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
Iterable[ExampleT]]],
beam.PCollection[PredictionT]]):
@@ -1831,7 +1855,9 @@ class _ProxyLoader:
unique_tag = self.model_tag + '_' + uuid.uuid4().hex
# Ensure that each model is loaded in a different process for parallelism
multi_process_shared.MultiProcessShared(
- self.loader_func, tag=unique_tag, always_proxy=True,
+ OOMProtectedFn(self.loader_func),
+ tag=unique_tag,
+ always_proxy=True,
spawn_process=True).acquire()
# Only return the tag to avoid pickling issues with the model itself.
return unique_tag
@@ -2021,10 +2047,13 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
unique_tag = model
model = multi_process_shared.MultiProcessShared(
lambda: None, tag=model, always_proxy=True).acquire()
- result_generator = self._model_handler.run_inference(
- batch, model, inference_args)
- if self.use_model_manager:
- self._model.release_model(self._model_tag, unique_tag)
+ try:
+ result_generator = (OOMProtectedFn(self._model_handler.run_inference))(
+ batch, model, inference_args)
+ finally:
+ # Always release the model so that it can be reloaded.
+ if self.use_model_manager:
+ self._model.release_model(self._model_tag, unique_tag)
except BaseException as e:
if self._metrics_collector:
self._metrics_collector.failed_batches_counter.inc()
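
A note on the hunk above: run_inference is now wrapped in try/finally so the
model-manager lease is released even when the batch raises. Schematically,
with local names simplified from the diff (handler, manager, and the tags
stand in for the DoFn's attributes):

    try:
      result_generator = OOMProtectedFn(handler.run_inference)(
          batch, model, inference_args)
    finally:
      # Release even on failure so the manager can evict or reload the model.
      if use_model_manager:
        manager.release_model(model_tag, unique_tag)
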
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index feccd8b0f12..8236ac5c1e5 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -20,6 +20,7 @@ import math
import multiprocessing
import os
import pickle
+import random
import sys
import tempfile
import time
@@ -2338,6 +2339,31 @@ class ModelManagerTest(unittest.TestCase):
})
assert_that(actual, equal_to(expected), label='assert:inferences')
+ @unittest.skipIf(
+ not try_import_model_manager(), 'Model Manager not available')
+ def test_run_inference_impl_with_model_manager_oom(self):
+ class OOMFakeModelHandler(SimpleFakeModelHandler):
+ def run_inference(
+ self,
+ batch: Sequence[int],
+ model: FakeModel,
+ inference_args=None) -> Iterable[int]:
+ if random.random() < 0.8:
+ raise MemoryError("Simulated OOM")
+ for example in batch:
+ yield model.predict(example)
+
+ def batch_elements_kwargs(self):
+ return {'min_batch_size': 1, 'max_batch_size': 1}
+
+ with self.assertRaises(Exception):
+ with TestPipeline() as pipeline:
+ examples = [1, 5, 3, 10]
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ actual = pcoll | base.RunInference(
+ OOMFakeModelHandler(), use_model_manager=True)
+ assert_that(actual, equal_to([2, 6, 4, 11]), label='assert:inferences')
+
if __name__ == '__main__':
unittest.main()
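
The new test pins batching to single-element batches so that every element
independently exercises the simulated OOM path; per the hunk above, a handler
opts into fixed batch sizes by overriding batch_elements_kwargs:

    def batch_elements_kwargs(self):
      # One element per batch: each inference call independently has an
      # 80% chance of raising the simulated MemoryError.
      return {'min_batch_size': 1, 'max_batch_size': 1}
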
diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py
index bf7c6a43ba6..186611984df 100644
--- a/sdks/python/apache_beam/ml/inference/model_manager.py
+++ b/sdks/python/apache_beam/ml/inference/model_manager.py
@@ -176,14 +176,23 @@ class ResourceEstimator:
individual models based on aggregate system memory readings and the
configuration of active models at that time.
"""
- def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5):
+ def __init__(
+ self,
+ smoothing_factor: float = 0.2,
+ min_data_points: int = 5,
+ verbose_logging: bool = False):
self.smoothing_factor = smoothing_factor
self.min_data_points = min_data_points
+ self.verbose_logging = verbose_logging
self.estimates: Dict[str, float] = {}
self.history = defaultdict(lambda: deque(maxlen=20))
self.known_models = set()
self._lock = threading.Lock()
+ def logging_info(self, message: str, *args):
+ if self.verbose_logging:
+ logger.info(message, *args)
+
def is_unknown(self, model_tag: str) -> bool:
with self._lock:
return model_tag not in self.estimates
@@ -196,7 +205,7 @@ class ResourceEstimator:
with self._lock:
self.estimates[model_tag] = cost
self.known_models.add(model_tag)
- logger.info("Initial Profile for %s: %s MB", model_tag, cost)
+ self.logging_info("Initial Profile for %s: %s MB", model_tag, cost)
def add_observation(
self, active_snapshot: Dict[str, int], peak_memory: float):
@@ -207,7 +216,7 @@ class ResourceEstimator:
else:
model_list = "\t- None"
- logger.info(
+ self.logging_info(
"Adding Observation:\n PeakMemory: %.1f MB\n Instances:\n%s",
peak_memory,
model_list)
@@ -256,7 +265,7 @@ class ResourceEstimator:
# Not enough data to solve yet
return
- logger.info(
+ self.logging_info(
"Solving with %s total observations for %s models.",
len(A),
len(unique))
@@ -280,9 +289,9 @@ class ResourceEstimator:
else:
self.estimates[model] = calculated_cost
- logger.info(
+ self.logging_info(
"Updated Estimate for %s: %.1f MB", model, self.estimates[model])
- logger.info("System Bias: %s MB", bias)
+ self.logging_info("System Bias: %s MB", bias)
except Exception as e:
logger.error("Solver failed: %s", e)
@@ -321,10 +330,13 @@ class ModelManager:
eviction_cooldown_seconds: float = 10.0,
min_model_copies: int = 1,
wait_timeout_seconds: float = 300.0,
- lock_timeout_seconds: float = 60.0):
+ lock_timeout_seconds: float = 60.0,
+ verbose_logging: bool = False):
self._estimator = ResourceEstimator(
- min_data_points=min_data_points, smoothing_factor=smoothing_factor)
+ min_data_points=min_data_points,
+ smoothing_factor=smoothing_factor,
+ verbose_logging=verbose_logging)
self._monitor = monitor if monitor else GPUMonitor(
poll_interval=poll_interval, peak_window_seconds=peak_window_seconds)
self._slack_percentage = slack_percentage
@@ -333,6 +345,7 @@ class ModelManager:
self._min_model_copies = min_model_copies
self._wait_timeout_seconds = wait_timeout_seconds
self._lock_timeout_seconds = lock_timeout_seconds
+ self._verbose_logging = verbose_logging
# Resource State
self._models = defaultdict(list)
@@ -361,20 +374,24 @@ class ModelManager:
self._monitor.start()
+ def logging_info(self, message: str, *args):
+ if self._verbose_logging:
+ logger.info(message, *args)
+
def all_models(self, tag) -> list[Any]:
return self._models[tag]
# Should hold _cv lock when calling
def try_enter_isolation_mode(self, tag: str, ticket_num: int) -> bool:
if self._total_active_jobs > 0:
- logger.info(
+ self.logging_info(
"Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num)
self._cv.wait(timeout=self._lock_timeout_seconds)
# return False since we have waited and need to re-evaluate
# in caller to make sure our priority is still valid.
return False
- logger.info("Unknown model %s detected. Flushing GPU.", tag)
+ self.logging_info("Unknown model %s detected. Flushing GPU.", tag)
self._delete_all_models()
self._isolation_mode = True
@@ -412,7 +429,7 @@ class ModelManager:
for _, instances in self._models.items():
total_model_count += len(instances)
curr, _, _ = self._monitor.get_stats()
- logger.info(
+ self.logging_info(
"Waiting for resources to free up: "
"tag=%s ticket num%s model count=%s "
"idle count=%s resource usage=%.1f MB "
@@ -462,7 +479,7 @@ class ModelManager:
# SLOW PATH: Enqueue and wait for turn to acquire model,
# with unknown models having priority and order enforced
# by ticket number as FIFO.
- logger.info(
+ self.logging_info(
"Acquire Queued: tag=%s, priority=%d "
"total models count=%s ticket num=%s",
tag,
@@ -484,7 +501,7 @@ class ModelManager:
f"after {wait_time_elapsed:.1f} seconds.")
if not self._wait_queue or self._wait_queue[
0].ticket_num != ticket_num:
- logger.info(
+ self.logging_info(
"Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
self._wait_in_queue(my_ticket)
continue
@@ -520,7 +537,7 @@ class ModelManager:
# Path B: Concurrent
else:
if self._isolation_mode:
- logger.info(
+ self.logging_info(
"Waiting due to isolation in progress: tag=%s ticket num%s",
tag,
ticket_num)
@@ -596,7 +613,7 @@ class ModelManager:
self._total_active_jobs += 1
return target_instance
- logger.info("No idle model found for tag: %s", tag)
+ self.logging_info("No idle model found for tag: %s", tag)
return None
def _evict_to_make_space(
@@ -671,17 +688,21 @@ class ModelManager:
if isinstance(instance, str):
# If the instance is a string, it's a uuid used
# to retrieve the model from MultiProcessShared
- multi_process_shared.MultiProcessShared(
- lambda: "N/A", tag=instance).unsafe_hard_delete()
+ try:
+ multi_process_shared.MultiProcessShared(
+ lambda: "N/A", tag=instance).unsafe_hard_delete()
+ except (EOFError, OSError, BrokenPipeError):
+ # This can happen even in normal operation.
+ pass
if hasattr(instance, 'mock_model_unsafe_hard_delete'):
# Call the mock unsafe hard delete method for testing
instance.mock_model_unsafe_hard_delete()
del instance
def _perform_eviction(self, key: str, tag: str, instance: Any, score: int):
- logger.info("Evicting Model: %s (Score %d)", tag, score)
+ self.logging_info("Evicting Model: %s (Score %d)", tag, score)
curr, _, _ = self._monitor.get_stats()
- logger.info("Resource Usage Before Eviction: %.1f MB", curr)
+ self.logging_info("Resource Usage Before Eviction: %.1f MB", curr)
if key in self._idle_lru:
del self._idle_lru[key]
@@ -697,7 +718,7 @@ class ModelManager:
self._monitor.refresh()
self._monitor.reset_peak()
curr, _, _ = self._monitor.get_stats()
- logger.info("Resource Usage After Eviction: %.1f MB", curr)
+ self.logging_info("Resource Usage After Eviction: %.1f MB", curr)
def _spawn_new_model(
self,
@@ -707,7 +728,7 @@ class ModelManager:
est_cost: float) -> Any:
try:
with self._cv:
- logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
+ self.logging_info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
baseline_snap, _, _ = self._monitor.get_stats()
instance = loader_func()
_, peak_during_load, _ = self._monitor.get_stats()
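
The model_manager.py changes thread a single verbose_logging flag from
ModelManager into ResourceEstimator, and both classes route their chatty
messages through a logging_info gate so the manager stays quiet by default.
A hypothetical caller (assuming the constructor arguments not shown in the
hunks above keep their defaults):

    from apache_beam.ml.inference.model_manager import ModelManager

    # Flip on verbose logging when diagnosing eviction, isolation, or
    # estimator behavior; leave it off (the default) in production.
    manager = ModelManager(verbose_logging=True)
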