This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new a472d1a004c Add support for limiting number of models in memory (#28263)
a472d1a004c is described below
commit a472d1a004c306ab27217dc71761a66949370a85
Author: Danny McCormick <[email protected]>
AuthorDate: Thu Aug 31 17:16:37 2023 -0400
Add support for limiting number of models in memory (#28263)
* Add support for limiting number of models in memory
* Add test to demonstrate behavior
* doc
---
sdks/python/apache_beam/ml/inference/base.py | 35 ++++++++++-------
sdks/python/apache_beam/ml/inference/base_test.py | 46 +++++++++++++++++++----
2 files changed, 60 insertions(+), 21 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 80b73c00675..90d43cfddb9 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -304,21 +304,15 @@ class _ModelManager:
"""
A class for efficiently managing copies of multiple models. Will load a
single copy of each model into a multi_process_shared object and then
- return a lookup key for that object. Optionally takes in a max_models
- parameter, if that is set it will only hold that many models in memory at
- once before evicting one (using LRU logic).
+ return a lookup key for that object.
"""
- def __init__(
- self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+ def __init__(self, mh_map: Dict[str, ModelHandler]):
"""
Args:
mh_map: A map from keys to model handlers which can be used to load a
model.
- max_models: The maximum number of models to load at any given time
- before evicting 1 from memory (using LRU logic). Leave as None to
- allow unlimited models.
"""
- self._max_models = max_models
+ self._max_models = None
# Map keys to model handlers
self._mh_map: Dict[str, ModelHandler] = mh_map
# Map keys to the last updated model path for that key
@@ -376,14 +370,12 @@ class _ModelManager:
def increment_max_models(self, increment: int):
"""
Increments the number of models that this instance of a _ModelManager is
- able to hold.
+ able to hold. If it is never called, no limit is imposed.
Args:
increment: the amount by which we are incrementing the number of models.
"""
if self._max_models is None:
- raise ValueError(
- "Cannot increment max_models if self._max_models is None (unlimited" +
- " models mode).")
+ self._max_models = 0
self._max_models += increment
def update_model_handler(self, key: str, model_path: str, previous_key: str):
@@ -436,7 +428,8 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
self,
unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
List[KeyModelMapping[KeyT, ExampleT, PredictionT,
- ModelT]]]):
+ ModelT]]],
+ max_models_per_worker_hint: Optional[int] = None):
"""A ModelHandler that takes keyed examples and returns keyed predictions.
For example, if the original model is used with RunInference to take a
@@ -494,6 +487,11 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
unkeyed: Either (a) an implementation of ModelHandler that does not
require keys or (b) a list of KeyModelMappings mapping lists of keys to
unkeyed ModelHandlers.
+ max_models_per_worker_hint: A hint to the runner indicating how many
+ models can be held in memory at one time per worker process. For
+ example, if your worker has 8 GB of memory provisioned and your models
+ take up 1 GB each, you should set this to 7 to allow all models to sit
+ in memory with some buffer.
"""
self._metrics_collectors: Dict[str, _MetricsCollector] = {}
self._default_metrics_collector: _MetricsCollector = None
@@ -511,6 +509,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
self._unkeyed = unkeyed
return
+ self._max_models_per_worker_hint = max_models_per_worker_hint
# To maintain an efficient representation, we will map all keys in a given
# KeyModelMapping to a single id (the first key in the KeyModelMapping
# list). We will then map that key to a ModelHandler. This will allow us to
@@ -587,6 +586,14 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
keys,
self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
+ # The first time a MultiProcessShared ModelManager is used for inference
+ # from this process, we should increment its max model count
+ if self._max_models_per_worker_hint is not None:
+ lock = threading.Lock()
+ if lock.acquire(blocking=False):
+ model.increment_max_models(self._max_models_per_worker_hint)
+ self._max_models_per_worker_hint = None
+
batch_by_key = defaultdict(list)
key_by_id = defaultdict(set)
for key, example in batch:
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 3180c697a36..1b1a7393872 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -345,6 +345,41 @@ class RunInferenceBaseTest(unittest.TestCase):
load_latency_dist_aggregate = metrics['distributions'][0]
self.assertEqual(load_latency_dist_aggregate.committed.count, 2)
+ def test_run_inference_impl_with_keyed_examples_many_mhs_max_models_hint(
+ self):
+ pipeline = TestPipeline()
+ examples = [1, 5, 3, 10, 2, 4, 6, 8, 9, 7, 1, 5, 3, 10, 2, 4, 6, 8, 9, 7]
+ metrics_namespace = 'test_namespace'
+ keyed_examples = [(i, example) for i, example in enumerate(examples)]
+ pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+ mhs = [
+ base.KeyModelMapping([0, 2, 4, 6, 8],
+ FakeModelHandler(
+ state=200, multi_process_shared=True)),
+ base.KeyModelMapping(
+ [1, 3, 5, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
+ FakeModelHandler(multi_process_shared=True))
+ ]
+ _ = pcoll | base.RunInference(
+ base.KeyedModelHandler(mhs, max_models_per_worker_hint=1),
+ metrics_namespace=metrics_namespace)
+ result = pipeline.run()
+ result.wait_until_finish()
+
+ metrics_filter = MetricsFilter().with_namespace(namespace=metrics_namespace)
+ metrics = result.metrics().query(metrics_filter)
+ assert len(metrics['counters']) != 0
+ assert len(metrics['distributions']) != 0
+
+ metrics_filter = MetricsFilter().with_name('load_model_latency_milli_secs')
+ metrics = result.metrics().query(metrics_filter)
+ load_latency_dist_aggregate = metrics['distributions'][0]
+ # We should flip back and forth between models a bit since
+ # max_models_per_worker_hint=1, but we shouldn't thrash forever
+ # since most examples belong to the second ModelMapping
+ self.assertGreater(load_latency_dist_aggregate.committed.count, 2)
+ self.assertLess(load_latency_dist_aggregate.committed.count, 12)
+
def test_keyed_many_model_handlers_validation(self):
def mult_two(example: str) -> int:
return int(example) * 2
@@ -1366,7 +1401,8 @@ class RunInferenceBaseTest(unittest.TestCase):
mh2 = FakeModelHandler(state=2)
mh3 = FakeModelHandler(state=3)
mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
- mm = base._ModelManager(mh_map=mhs, max_models=2)
+ mm = base._ModelManager(mh_map=mhs)
+ mm.increment_max_models(2)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
model1 = sh1.acquire()
@@ -1440,7 +1476,8 @@ class RunInferenceBaseTest(unittest.TestCase):
mh2 = FakeModelHandler(state=2)
mh3 = FakeModelHandler(state=3)
mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
- mm = base._ModelManager(mh_map=mhs, max_models=1)
+ mm = base._ModelManager(mh_map=mhs)
+ mm.increment_max_models(1)
mm.increment_max_models(1)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
@@ -1476,11 +1513,6 @@ class RunInferenceBaseTest(unittest.TestCase):
mh3.load_model, tag=tag3).acquire()
self.assertEqual(8, model3.predict(10))
- def test_model_manager_fails_if_no_default_initially(self):
- mm = base._ModelManager(mh_map={})
- with self.assertRaisesRegex(ValueError, r'self._max_models is None'):
- mm.increment_max_models(5)
-
if __name__ == '__main__':
unittest.main()