damccorm commented on code in PR #27603:
URL: https://github.com/apache/beam/pull/27603#discussion_r1281147101
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -739,8 +739,95 @@ def with_exception_handling(
return self
+class _ModelManager:
+ """
+ A class for efficiently managing copies of multiple models. Will load a
+ single copy of each model into a multi_process_shared object and then
+ return a lookup key for that object. Optionally takes in a max_models
+ parameter, if that is set it will only hold that many models in memory at
+ once before evicting one (using LRU logic).
+ """
+  def __init__(
+      self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+    """
+    Args:
+      mh_map: A map from keys to model handlers which can be used to load a
+        model.
+      max_models: The maximum number of models to load at any given time
+        before evicting one from memory (using LRU logic). Leave as None to
+        allow unlimited models.
+    """
+    self._max_models = max_models
+    self._mh_map = mh_map
+    self._loaded_keys = []
+    self._proxy_map = {}
+    self._tag_map = {}
+
+  def _get_tag(self, key: str) -> str:
+    """
+    Args:
+      key: the key associated with the model we'd like to load.
+    Returns:
+      the tag we can use to load the model using multi_process_shared.py
+    """
+    if key not in self._tag_map:
+      self._tag_map[key] = uuid.uuid4().hex
+    return self._tag_map[key]
+
+  def load(self, key: str) -> str:
+    """
+    Loads the appropriate model for the given key into memory.
+    Args:
+      key: the key associated with the model we'd like to load.
+    Returns:
+      the tag we can use to access the model using multi_process_shared.py.
+    """
+    tag = self._get_tag(key)
+    if key in self._loaded_keys:
+      if self._max_models is not None:
+        # move key to the back of the list
+        self._loaded_keys.append(
+            self._loaded_keys.pop(self._loaded_keys.index(key)))
+      return tag
+
+    mh = self._mh_map[key]
+    if self._max_models is not None and self._max_models <= len(
+        self._loaded_keys):
+      # If we're about to exceed our LRU size,
+      # remove the front model from the list from memory.
+      key_to_remove = self._loaded_keys[0]
+      tag_to_remove = self._get_tag(key_to_remove)
+      shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
+      shared_handle.release(model_to_remove)
+      del self._tag_map[key_to_remove]
Review Comment:
I like the idea of relying on dictionary order; that requires an OrderedDict,
though, so I can't use the defaultdict, right? I updated to do this.
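
For context on the OrderedDict point: plain dicts preserve insertion order on
Python 3.7+, but collections.OrderedDict additionally offers move_to_end() and
popitem(last=False), which map directly onto the "mark as most recently used"
and "evict the oldest" steps above. A minimal sketch of that bookkeeping, with
illustrative names (MAX_MODELS, touch) that are not taken from the PR:

    from collections import OrderedDict

    MAX_MODELS = 2
    # key -> tag, ordered from least recently used to most recently used.
    lru = OrderedDict()

    def touch(key, tag):
      # Mark `key` as most recently used; evict the oldest entry if full.
      if key in lru:
        lru.move_to_end(key)
        return
      if len(lru) >= MAX_MODELS:
        evicted_key, evicted_tag = lru.popitem(last=False)
        # ...release the multi_process_shared model held under evicted_tag...
      lru[key] = tag

    touch("model_a", "tag_a")
    touch("model_b", "tag_b")
    touch("model_a", "tag_a")  # model_a is now most recently used
    touch("model_c", "tag_c")  # evicts model_b, the least recently used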