This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a472d1a004c Add support for limiting number of models in memory (#28263)
a472d1a004c is described below

commit a472d1a004c306ab27217dc71761a66949370a85
Author: Danny McCormick <[email protected]>
AuthorDate: Thu Aug 31 17:16:37 2023 -0400

    Add support for limiting number of models in memory (#28263)
    
    * Add support for limiting number of models in memory
    
    * Add test to demonstrate behavior
    
    * doc
---
 sdks/python/apache_beam/ml/inference/base.py      | 35 ++++++++++-------
 sdks/python/apache_beam/ml/inference/base_test.py | 46 +++++++++++++++++++----
 2 files changed, 60 insertions(+), 21 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 80b73c00675..90d43cfddb9 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -304,21 +304,15 @@ class _ModelManager:
   """
   A class for efficiently managing copies of multiple models. Will load a
   single copy of each model into a multi_process_shared object and then
-  return a lookup key for that object. Optionally takes in a max_models
-  parameter, if that is set it will only hold that many models in memory at
-  once before evicting one (using LRU logic).
+  return a lookup key for that object.
   """
-  def __init__(
-      self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+  def __init__(self, mh_map: Dict[str, ModelHandler]):
     """
     Args:
       mh_map: A map from keys to model handlers which can be used to load a
         model.
-      max_models: The maximum number of models to load at any given time
-        before evicting 1 from memory (using LRU logic). Leave as None to
-        allow unlimited models.
     """
-    self._max_models = max_models
+    self._max_models = None
     # Map keys to model handlers
     self._mh_map: Dict[str, ModelHandler] = mh_map
     # Map keys to the last updated model path for that key
@@ -376,14 +370,12 @@ class _ModelManager:
   def increment_max_models(self, increment: int):
     """
     Increments the number of models that this instance of a _ModelManager is
-    able to hold.
+    able to hold. If it is never called, no limit is imposed.
     Args:
       increment: the amount by which we are incrementing the number of models.
     """
     if self._max_models is None:
-      raise ValueError(
-          "Cannot increment max_models if self._max_models is None (unlimited" +
-          " models mode).")
+      self._max_models = 0
     self._max_models += increment
 
   def update_model_handler(self, key: str, model_path: str, previous_key: str):
@@ -436,7 +428,8 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
       self,
       unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
                      List[KeyModelMapping[KeyT, ExampleT, PredictionT,
-                                          ModelT]]]):
+                                          ModelT]]],
+      max_models_per_worker_hint: Optional[int] = None):
     """A ModelHandler that takes keyed examples and returns keyed predictions.
 
     For example, if the original model is used with RunInference to take a
@@ -494,6 +487,11 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
       unkeyed: Either (a) an implementation of ModelHandler that does not
         require keys or (b) a list of KeyModelMappings mapping lists of keys to
         unkeyed ModelHandlers.
+      max_models_per_worker_hint: A hint to the runner indicating how many
+        models can be held in memory at one time per worker process. For
+        example, if your worker has 8 GB of memory provisioned and your workers
+        take up 1 GB each, you should set this to 7 to allow all models to sit
+        in memory with some buffer.
     """
     self._metrics_collectors: Dict[str, _MetricsCollector] = {}
     self._default_metrics_collector: _MetricsCollector = None
@@ -511,6 +509,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
       self._unkeyed = unkeyed
       return
 
+    self._max_models_per_worker_hint = max_models_per_worker_hint
     # To maintain an efficient representation, we will map all keys in a given
     # KeyModelMapping to a single id (the first key in the KeyModelMapping
     # list). We will then map that key to a ModelHandler. This will allow us to
@@ -587,6 +586,14 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
           keys,
           self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
 
+    # The first time a MultiProcessShared ModelManager is used for inference
+    # from this process, we should increment its max model count
+    if self._max_models_per_worker_hint is not None:
+      lock = threading.Lock()
+      if lock.acquire(blocking=False):
+        model.increment_max_models(self._max_models_per_worker_hint)
+      self._max_models_per_worker_hint = None
+
     batch_by_key = defaultdict(list)
     key_by_id = defaultdict(set)
     for key, example in batch:
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 3180c697a36..1b1a7393872 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -345,6 +345,41 @@ class RunInferenceBaseTest(unittest.TestCase):
     load_latency_dist_aggregate = metrics['distributions'][0]
     self.assertEqual(load_latency_dist_aggregate.committed.count, 2)
 
+  def test_run_inference_impl_with_keyed_examples_many_mhs_max_models_hint(
+      self):
+    pipeline = TestPipeline()
+    examples = [1, 5, 3, 10, 2, 4, 6, 8, 9, 7, 1, 5, 3, 10, 2, 4, 6, 8, 9, 7]
+    metrics_namespace = 'test_namespace'
+    keyed_examples = [(i, example) for i, example in enumerate(examples)]
+    pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+    mhs = [
+        base.KeyModelMapping([0, 2, 4, 6, 8],
+                             FakeModelHandler(
+                                 state=200, multi_process_shared=True)),
+        base.KeyModelMapping(
+            [1, 3, 5, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
+            FakeModelHandler(multi_process_shared=True))
+    ]
+    _ = pcoll | base.RunInference(
+        base.KeyedModelHandler(mhs, max_models_per_worker_hint=1),
+        metrics_namespace=metrics_namespace)
+    result = pipeline.run()
+    result.wait_until_finish()
+
+    metrics_filter = MetricsFilter().with_namespace(namespace=metrics_namespace)
+    metrics = result.metrics().query(metrics_filter)
+    assert len(metrics['counters']) != 0
+    assert len(metrics['distributions']) != 0
+
+    metrics_filter = MetricsFilter().with_name('load_model_latency_milli_secs')
+    metrics = result.metrics().query(metrics_filter)
+    load_latency_dist_aggregate = metrics['distributions'][0]
+    # We should flip back and forth between models a bit since
+    # max_models_per_worker_hint=1, but we shouldn't thrash forever
+    # since most examples belong to the second ModelMapping
+    self.assertGreater(load_latency_dist_aggregate.committed.count, 2)
+    self.assertLess(load_latency_dist_aggregate.committed.count, 12)
+
   def test_keyed_many_model_handlers_validation(self):
     def mult_two(example: str) -> int:
       return int(example) * 2
@@ -1366,7 +1401,8 @@ class RunInferenceBaseTest(unittest.TestCase):
     mh2 = FakeModelHandler(state=2)
     mh3 = FakeModelHandler(state=3)
     mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
-    mm = base._ModelManager(mh_map=mhs, max_models=2)
+    mm = base._ModelManager(mh_map=mhs)
+    mm.increment_max_models(2)
     tag1 = mm.load('key1').model_tag
     sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
     model1 = sh1.acquire()
@@ -1440,7 +1476,8 @@ class RunInferenceBaseTest(unittest.TestCase):
     mh2 = FakeModelHandler(state=2)
     mh3 = FakeModelHandler(state=3)
     mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
-    mm = base._ModelManager(mh_map=mhs, max_models=1)
+    mm = base._ModelManager(mh_map=mhs)
+    mm.increment_max_models(1)
     mm.increment_max_models(1)
     tag1 = mm.load('key1').model_tag
     sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
@@ -1476,11 +1513,6 @@ class RunInferenceBaseTest(unittest.TestCase):
         mh3.load_model, tag=tag3).acquire()
     self.assertEqual(8, model3.predict(10))
 
-  def test_model_manager_fails_if_no_default_initially(self):
-    mm = base._ModelManager(mh_map={})
-    with self.assertRaisesRegex(ValueError, r'self._max_models is None'):
-      mm.increment_max_models(5)
-
 
 if __name__ == '__main__':
   unittest.main()

Reply via email to