tvalentyn commented on code in PR #27603:
URL: https://github.com/apache/beam/pull/27603#discussion_r1279879864
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -739,8 +739,95 @@ def with_exception_handling(
return self
+class _ModelManager:
+ """
+ A class for efficiently managing copies of multiple models. Will load a
+ single copy of each model into a multi_process_shared object and then
+ return a lookup key for that object. Optionally takes in a max_models
+ parameter, if that is set it will only hold that many models in memory at
+ once before evicting one (using LRU logic).
+ """
+ def __init__(
+ self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+ """
+ Args:
+ mh_map: A map from keys to model handlers which can be used to load a
+ model.
+ max_models: The maximum number of models to load at any given time
+ before evicting 1 from memory (using LRU logic). Leave as None to
+ allow unlimited models.
+ """
+ self._max_models = max_models
+ self._mh_map = mh_map
+ self._loaded_keys = []
+ self._proxy_map = {}
+ self._tag_map = {}
+
+ def _get_tag(self, key: str) -> str:
+ """
+ Args:
+ key: the key associated with the model we'd like to load.
+ Returns:
+ the tag we can use to load the model using multi_process_shared.py
+ """
+ if key not in self._tag_map:
+ self._tag_map[key] = uuid.uuid4().hex
+ return self._tag_map[key]
+
+ def load(self, key: str) -> str:
+ """
+ Loads the appropriate model for the given key into memory.
+ Args:
+ key: the key associated with the model we'd like to load.
+ Returns:
+ the tag we can use to access the model using multi_process_shared.py.
+ """
+ tag = self._get_tag(key)
+ if key in self._loaded_keys:
+ if self._max_models is not None:
+ # move key to the back of the list
+ self._loaded_keys.append(
+ self._loaded_keys.pop(self._loaded_keys.index(key)))
+ return tag
+
+ mh = self._mh_map[key]
+ if self._max_models is not None and self._max_models <= len(
+ self._loaded_keys):
+ # If we're about to exceed our LRU size,
+ # remove the front model from the list from memory.
+ key_to_remove = self._loaded_keys[0]
+ tag_to_remove = self._get_tag(key_to_remove)
+ shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
+ shared_handle.release(model_to_remove)
+ del self._tag_map[key]
Review Comment:
Sounds like `self._tag_map` stores only the loaded keys. In that case you
could let go of `self._loaded_keys` (or make it a map that stores the tags)
and rely on dictionary insertion order to retrieve the oldest entry and
reinsert as necessary.
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -964,6 +982,117 @@ def test_child_class_without_env_vars(self):
actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
assert_that(actual, equal_to(expected), label='assert:inferences')
+ def test_model_manager_loads_shared_model(self):
Review Comment:
How about we add a test that existing model references are still valid
after we have loaded more models than `max_models`?
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -739,8 +739,95 @@ def with_exception_handling(
return self
+class _ModelManager:
+ """
+ A class for efficiently managing copies of multiple models. Will load a
+ single copy of each model into a multi_process_shared object and then
+ return a lookup key for that object. Optionally takes in a max_models
+ parameter, if that is set it will only hold that many models in memory at
+ once before evicting one (using LRU logic).
+ """
+ def __init__(
+ self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+ """
+ Args:
+ mh_map: A map from keys to model handlers which can be used to load a
+ model.
+ max_models: The maximum number of models to load at any given time
+ before evicting 1 from memory (using LRU logic). Leave as None to
+ allow unlimited models.
+ """
+ self._max_models = max_models
+ self._mh_map = mh_map
+ self._loaded_keys = []
+ self._proxy_map = {}
+ self._tag_map = {}
+
+ def _get_tag(self, key: str) -> str:
+ """
+ Args:
+ key: the key associated with the model we'd like to load.
+ Returns:
+ the tag we can use to load the model using multi_process_shared.py
+ """
+ if key not in self._tag_map:
+ self._tag_map[key] = uuid.uuid4().hex
+ return self._tag_map[key]
+
+ def load(self, key: str) -> str:
+ """
+ Loads the appropriate model for the given key into memory.
+ Args:
+ key: the key associated with the model we'd like to load.
+ Returns:
+ the tag we can use to access the model using multi_process_shared.py.
+ """
+ tag = self._get_tag(key)
+ if key in self._loaded_keys:
+ if self._max_models is not None:
+ # move key to the back of the list
+ self._loaded_keys.append(
+ self._loaded_keys.pop(self._loaded_keys.index(key)))
+ return tag
+
+ mh = self._mh_map[key]
+ if self._max_models is not None and self._max_models <= len(
+ self._loaded_keys):
+ # If we're about to exceed our LRU size,
+ # remove the front model from the list from memory.
+ key_to_remove = self._loaded_keys[0]
+ tag_to_remove = self._get_tag(key_to_remove)
+ shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
+ shared_handle.release(model_to_remove)
+ del self._tag_map[key]
Review Comment:
You could also use `collections.defaultdict(lambda: uuid.uuid4().hex)`
instead of `_get_tag()`.
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -964,6 +982,117 @@ def test_child_class_without_env_vars(self):
actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
assert_that(actual, equal_to(expected), label='assert:inferences')
+ def test_model_manager_loads_shared_model(self):
+ mhs = {
+ 'key1': FakeModelHandler(state=1),
+ 'key2': FakeModelHandler(state=2),
+ 'key3': FakeModelHandler(state=3)
+ }
+ mm = base._ModelManager(mh_map=mhs)
+ tag1 = mm.load('key1')
+ # Use bad_mh's load function to make sure we're actually loading the
+ # version already stored
+ bad_mh = FakeModelHandler(state=100)
+ model1 = multi_process_shared.MultiProcessShared(
+ bad_mh.load_model, tag=tag1).acquire()
+ self.assertEqual(1, model1.predict(10))
+
+ tag2 = mm.load('key2')
+ tag3 = mm.load('key3')
+ model2 = multi_process_shared.MultiProcessShared(
+ bad_mh.load_model, tag=tag2).acquire()
+ model3 = multi_process_shared.MultiProcessShared(
+ bad_mh.load_model, tag=tag3).acquire()
+ self.assertEqual(2, model2.predict(10))
+ self.assertEqual(3, model3.predict(10))
+
+ def test_model_manager_evicts_models(self):
+ mh1 = FakeModelHandler(state=1)
+ mh2 = FakeModelHandler(state=2)
+ mh3 = FakeModelHandler(state=3)
+ mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
+ mm = base._ModelManager(mh_map=mhs, max_models=2)
+ tag1 = mm.load('key1')
+ sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
+ model1 = sh1.acquire()
+ self.assertEqual(1, model1.predict(10))
+ model1.increment_state(5)
+ self.assertEqual(6, model1.predict(10))
+ sh1.release(model1)
+
+ tag2 = mm.load('key2')
+ tag3 = mm.load('key3')
+ sh2 = multi_process_shared.MultiProcessShared(mh2.load_model, tag=tag2)
+ model2 = sh2.acquire()
+ sh3 = multi_process_shared.MultiProcessShared(mh3.load_model, tag=tag3)
+ model3 = sh3.acquire()
+ model2.increment_state(5)
+ model3.increment_state(5)
+ self.assertEqual(7, model2.predict(10))
+ self.assertEqual(8, model3.predict(10))
+ sh2.release(model2)
+ sh3.release(model3)
+
+ # This should get recreated, so it shouldn't have the state updates
+ model1 = multi_process_shared.MultiProcessShared(
+ mh1.load_model, tag=tag1).acquire()
+ self.assertEqual(1, model1.predict(10))
+
+ # These should not get recreated, so they should have the state updates
+ model2 = multi_process_shared.MultiProcessShared(
+ mh2.load_model, tag=tag2).acquire()
+ self.assertEqual(7, model2.predict(10))
+ model3 = multi_process_shared.MultiProcessShared(
+ mh3.load_model, tag=tag3).acquire()
+ self.assertEqual(8, model3.predict(10))
+
+ def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
+ self):
+ mh1 = FakeModelHandler(state=1)
+ mh2 = FakeModelHandler(state=2)
+ mh3 = FakeModelHandler(state=3)
+ mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
+ mm = base._ModelManager(mh_map=mhs, max_models=1)
+ mm.increment_max_models(1)
+ tag1 = mm.load('key1')
+ sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
+ model1 = sh1.acquire()
+ self.assertEqual(1, model1.predict(10))
+ model1.increment_state(5)
+ self.assertEqual(6, model1.predict(10))
+ sh1.release(model1)
+
+ tag2 = mm.load('key2')
+ tag3 = mm.load('key3')
+ sh2 = multi_process_shared.MultiProcessShared(mh2.load_model, tag=tag2)
+ model2 = sh2.acquire()
+ sh3 = multi_process_shared.MultiProcessShared(mh3.load_model, tag=tag3)
+ model3 = sh3.acquire()
+ model2.increment_state(5)
+ model3.increment_state(5)
+ self.assertEqual(7, model2.predict(10))
+ self.assertEqual(8, model3.predict(10))
+ sh2.release(model2)
+ sh3.release(model3)
+
+ # This should get recreated, so it shouldn't have the state updates
+ model1 = multi_process_shared.MultiProcessShared(
+ mh1.load_model, tag=tag1).acquire()
+ self.assertEqual(1, model1.predict(10))
+
+ # These should not get recreated, so they should have the state updates
+ model2 = multi_process_shared.MultiProcessShared(
+ mh2.load_model, tag=tag2).acquire()
+ self.assertEqual(7, model2.predict(10))
+ model3 = multi_process_shared.MultiProcessShared(
+ mh3.load_model, tag=tag3).acquire()
+ self.assertEqual(8, model3.predict(10))
Review Comment:
What do you think about running this portion of the test scenario in a
subprocess?
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -964,6 +982,117 @@ def test_child_class_without_env_vars(self):
actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
assert_that(actual, equal_to(expected), label='assert:inferences')
+ def test_model_manager_loads_shared_model(self):
Review Comment:
How about we add a test that existing model references are still valid
after we have loaded more models than `max_models`?
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -964,6 +982,117 @@ def test_child_class_without_env_vars(self):
actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
assert_that(actual, equal_to(expected), label='assert:inferences')
+ def test_model_manager_loads_shared_model(self):
Review Comment:
How about we add a test that existing model references are still valid
after we have loaded more models than `max_models`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]