tvalentyn commented on code in PR #27603:
URL: https://github.com/apache/beam/pull/27603#discussion_r1283392520


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -739,8 +740,81 @@ def with_exception_handling(
     return self
 
 
+class _ModelManager:
+  """
+  A class for efficiently managing copies of multiple models. Will load a
+  single copy of each model into a multi_process_shared object and then
+  return a lookup key for that object. Optionally takes in a max_models
+  parameter; if that is set, it will only hold that many models in memory at
+  once before evicting one (using LRU logic).
+  """
+  def __init__(
+      self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+    """
+    Args:
+      mh_map: A map from keys to model handlers which can be used to load a
+        model.
+      max_models: The maximum number of models to load at any given time
+        before evicting one from memory (using LRU logic). Leave as None to
+        allow unlimited models.
+    """
+    self._max_models = max_models
+    self._mh_map = mh_map
+    self._proxy_map = {}

Review Comment:
   Could we add typehints to these dictionaries?
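   For illustration, something like this could work (a sketch, not a
   definitive suggestion; I'm guessing `_proxy_map` ends up holding
   `MultiProcessShared` handles, so the exact value type may differ):

   ```python
   # Hypothetical annotations; the value type of _proxy_map is a guess.
   self._mh_map: Dict[str, ModelHandler] = mh_map
   self._proxy_map: Dict[str, multi_process_shared.MultiProcessShared] = {}
   ```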



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -739,8 +740,81 @@ def with_exception_handling(
     return self
 
 
+class _ModelManager:
+  """
+  A class for efficiently managing copies of multiple models. Will load a
+  single copy of each model into a multi_process_shared object and then

Review Comment:
   ```suggestion
     single copy of each model into a MultiProcessShared object and then
   ```



##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -964,6 +982,117 @@ def test_child_class_without_env_vars(self):
       actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_model_manager_loads_shared_model(self):
+    mhs = {
+        'key1': FakeModelHandler(state=1),
+        'key2': FakeModelHandler(state=2),
+        'key3': FakeModelHandler(state=3)
+    }
+    mm = base._ModelManager(mh_map=mhs)
+    tag1 = mm.load('key1')
+    # Use bad_mh's load function to make sure we're actually loading the
+    # version already stored
+    bad_mh = FakeModelHandler(state=100)
+    model1 = multi_process_shared.MultiProcessShared(
+        bad_mh.load_model, tag=tag1).acquire()
+    self.assertEqual(1, model1.predict(10))
+
+    tag2 = mm.load('key2')

Review Comment:
   > Note that this is actually an important property if we have multiple
   > model managers operating at once (e.g. in a pipeline that does multiple
   > sets of sequential inference)

   Sounds like two different model managers will not share the models, even
   if they are the same. That's important to keep in mind when planning
   memory consumption if users do sequential inferences.
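
   A rough sketch of the implication (hypothetical, assuming `load` hands
   out a manager-specific tag as in the test above):

   ```python
   mh = FakeModelHandler(state=1)
   mm_a = base._ModelManager(mh_map={'key': mh})
   mm_b = base._ModelManager(mh_map={'key': mh})

   # Each manager loads and tags its own copy of the same underlying model,
   # so both copies are resident at once and memory use roughly doubles.
   tag_a = mm_a.load('key')
   tag_b = mm_b.load('key')
   assert tag_a != tag_b
   ```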


