AMOOOMA commented on code in PR #37113: URL: https://github.com/apache/beam/pull/37113#discussion_r2747636002
########## sdks/python/apache_beam/ml/inference/model_manager_test.py: ##########
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import random
import threading
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import patch

from apache_beam.utils import multi_process_shared

try:
  from apache_beam.ml.inference.model_manager import GPUMonitor
  from apache_beam.ml.inference.model_manager import ModelManager
  from apache_beam.ml.inference.model_manager import ResourceEstimator
except ImportError as e:
  # Skip the whole module when the optional model-manager dependencies are
  # absent; chain the original ImportError so the root cause stays visible.
  raise unittest.SkipTest(
      "Model Manager dependencies are not installed") from e


class MockGPUMonitor:
  """
  Simulates GPU hardware with cumulative memory tracking.
  Allows simulating specific allocation spikes and baseline usage.

  Peak usage is tracked over a sliding window of the last ``peak_window``
  samples recorded by allocate()/free(), mirroring a sampling monitor.
  """
  def __init__(self, total_memory=12000.0, peak_window: int = 5):
    self._current = 0.0  # current simulated usage in MB
    self._peak = 0.0  # max usage observed within the sliding window
    self._total = total_memory  # simulated total device memory in MB
    self._lock = threading.Lock()
    self.running = False
    # Sliding window of recent usage samples; _peak is its maximum.
    self.history = []
    self.peak_window = peak_window

  def start(self):
    self.running = True

  def stop(self):
    self.running = False

  def get_stats(self):
    """Returns (current_mb, peak_mb, total_mb) under the lock."""
    with self._lock:
      return self._current, self._peak, self._total

  def reset_peak(self):
    """Collapses the peak and history window to the current reading."""
    with self._lock:
      self._peak = self._current
      self.history = [self._current]

  def set_usage(self, current_mb):
    """Sets absolute usage (legacy helper)."""
    with self._lock:
      self._current = current_mb
      self._peak = max(self._peak, current_mb)

  def _record(self):
    """Appends the current reading to the window and refreshes the peak.

    Caller must already hold ``self._lock``.
    """
    self.history.append(self._current)
    if len(self.history) > self.peak_window:
      self.history.pop(0)
    self._peak = max(self.history)

  def allocate(self, amount_mb):
    """Simulates memory allocation (e.g., tensors loaded to VRAM)."""
    with self._lock:
      self._current += amount_mb
      self._record()

  def free(self, amount_mb):
    """Simulates memory freeing (not used often if pooling is active)."""
    with self._lock:
      # Usage can never go negative on real hardware.
      self._current = max(0.0, self._current - amount_mb)
      self._record()

  def refresh(self):
    """Simulates a refresh of the monitor stats (no-op for mock)."""
    pass


class MockModel:
  """Fake model that charges its size against a MockGPUMonitor on load."""
  def __init__(self, name, size, monitor):
    self.name = name
    self.size = size
    self.monitor = monitor
    self.deleted = False
    # Loading the model immediately consumes `size` MB of simulated VRAM.
    self.monitor.allocate(size)

  def mock_model_unsafe_hard_delete(self):
    """Releases the model's memory exactly once; later calls are no-ops."""
    if not self.deleted:
      self.monitor.free(self.size)
      self.deleted = True


class Counter:
  """Thread-safe integer counter used to observe concurrent test activity."""
  def __init__(self, start=0):
    self.running = start
    self.lock = threading.Lock()

  def get(self):
    # Read under the lock so a concurrent increment is never observed
    # mid-update.
    with self.lock:
      return self.running

  def increment(self, value=1):
    """Atomically adds `value` and returns the new total."""
    with self.lock:
      self.running += value
      return self.running
TestModelManager(unittest.TestCase): + def setUp(self): + """Force reset the Singleton ModelManager before every test.""" + ModelManager._instance = None Review Comment: Done. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
