damccorm commented on code in PR #31052:
URL: https://github.com/apache/beam/pull/31052#discussion_r1575457163


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -795,6 +802,21 @@ def share_model_across_processes(self) -> bool:
       return self._unkeyed.share_model_across_processes()
     return True
 
+  def max_shared_model_copies(self) -> int:
+    if self._single_model:
+      return self._unkeyed.max_shared_model_copies()
+    for mh in self._id_to_mh_map.values():
+      if mh.max_shared_model_copies() != 1:
+        raise ValueError(
+            'KeyedModelHandler cannot map records to multiple '
+            'models if one or more of its ModelHandlers '
+            'require multiple model copies (set via'

Review Comment:
   Done!



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1434,19 +1482,28 @@ def load():
     if isinstance(side_input_model_path, str) and side_input_model_path != '':
       model_tag = side_input_model_path
     if self._model_handler.share_model_across_processes():
-      model = multi_process_shared.MultiProcessShared(
-          load, tag=model_tag, always_proxy=True).acquire()
+      # TODO - update this to populate a list of models of configurable length

Review Comment:
   Oh, that was left over from development and I forgot to remove it. Removed.



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1434,19 +1518,28 @@ def load():
     if isinstance(side_input_model_path, str) and side_input_model_path != '':
       model_tag = side_input_model_path
     if self._model_handler.share_model_across_processes():
-      model = multi_process_shared.MultiProcessShared(
-          load, tag=model_tag, always_proxy=True).acquire()
+      # TODO - update this to populate a list of models of configurable length
+      models = []
+      for i in range(self._model_handler.max_shared_model_copies()):
+        models.append(
+            multi_process_shared.MultiProcessShared(
+                load, tag=f'{model_tag}{i}', always_proxy=True).acquire())
+      model_wrapper = _CrossProcessModelWrapper(models, model_tag)
     else:
       model = self._shared_model_handle.acquire(load, tag=model_tag)
+      model_wrapper = _CrossProcessModelWrapper([model], model_tag)

Review Comment:
   I thought it was cleaner to have a single object here so that consumers 
don't need to care which path loaded the model. `_CrossProcessModelWrapper` 
should be agnostic to which loading method is used.
   
   Maybe this is just a naming issue - does `_SharedModelWrapper` sound cleaner?



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1434,19 +1518,28 @@ def load():
     if isinstance(side_input_model_path, str) and side_input_model_path != '':
       model_tag = side_input_model_path
     if self._model_handler.share_model_across_processes():
-      model = multi_process_shared.MultiProcessShared(
-          load, tag=model_tag, always_proxy=True).acquire()
+      # TODO - update this to populate a list of models of configurable length
+      models = []
+      for i in range(self._model_handler.max_shared_model_copies()):
+        models.append(
+            multi_process_shared.MultiProcessShared(
+                load, tag=f'{model_tag}{i}', always_proxy=True).acquire())
+      model_wrapper = _CrossProcessModelWrapper(models, model_tag)
     else:
       model = self._shared_model_handle.acquire(load, tag=model_tag)
+      model_wrapper = _CrossProcessModelWrapper([model], model_tag)

Review Comment:
   I tried this; let me know what you think.



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1378,6 +1421,45 @@ def update(
     self._inference_request_batch_byte_size.update(examples_byte_size)
 
 
+class _ModelRoutingStrategy():
+  """A class meant to sit in a shared location for mapping incoming batches to
+  different models. Currently only supports round-robin, but can be extended
+  to support other protocols if needed.
+  """
+  def __init__(self):
+    self._cur_index = 0
+
+  def next_model_index(self, num_models):
+    self._cur_index = (self._cur_index + 1) % num_models
+    return self._cur_index
+
+
+class _CrossProcessModelWrapper():
+  """A router class to map incoming calls to the correct model.
+  
+    This allows us to round robin calls to models sitting in different
+    processes so that we can more efficiently use resources (e.g. GPUs).
+  """
+  def __init__(self, models: List[Any], model_tag: str):
+    self.models = models
+    if len(models) > 0:

Review Comment:
   No, it should be `> 1`. With a single model there is nothing to route 
between, so we don't need this.
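
   To illustrate why `> 1` is the right guard, here is a minimal standalone 
sketch. The names `RoundRobinRouter` and `ModelWrapper` are hypothetical 
stand-ins, not the classes in this PR, and the router here is a plain local 
object rather than something shared across processes.

```python
from typing import Any, List, Optional


class RoundRobinRouter:
  """Hypothetical stand-in for the routing strategy shown in the diff."""
  def __init__(self) -> None:
    self._cur_index = 0

  def next_model_index(self, num_models: int) -> int:
    # Advance first, then wrap around, mirroring the diff above.
    self._cur_index = (self._cur_index + 1) % num_models
    return self._cur_index


class ModelWrapper:
  """Hypothetical wrapper that only builds a router when there is a choice."""
  def __init__(self, models: List[Any]) -> None:
    self.models = models
    self._router: Optional[RoundRobinRouter] = None
    if len(models) > 1:
      # With exactly one model there is nothing to route between.
      self._router = RoundRobinRouter()

  def next_model(self) -> Any:
    if self._router is None:
      return self.models[0]
    return self.models[self._router.next_model_index(len(self.models))]


single = ModelWrapper(['only'])
multi = ModelWrapper(['a', 'b', 'c'])
picks = [multi.next_model() for _ in range(4)]
```

   With `> 0`, the single-model case would also create a router it never 
needs. Because the index advances before it is returned, the rotation in this 
sketch starts at the second model.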



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1434,19 +1518,28 @@ def load():
     if isinstance(side_input_model_path, str) and side_input_model_path != '':
       model_tag = side_input_model_path
     if self._model_handler.share_model_across_processes():
-      model = multi_process_shared.MultiProcessShared(
-          load, tag=model_tag, always_proxy=True).acquire()
+      # TODO - update this to populate a list of models of configurable length
+      models = []
+      for i in range(self._model_handler.max_shared_model_copies()):
+        models.append(
+            multi_process_shared.MultiProcessShared(
+                load, tag=f'{model_tag}{i}', always_proxy=True).acquire())
+      model_wrapper = _CrossProcessModelWrapper(models, model_tag)
     else:
       model = self._shared_model_handle.acquire(load, tag=model_tag)
+      model_wrapper = _CrossProcessModelWrapper([model], model_tag)
     # since shared_model_handle is shared across threads, the model path
     # might not get updated in the model handler
     # because we directly get cached weak ref model from shared cache, instead
     # of calling load(). For sanity check, call update_model_path again.
     if isinstance(side_input_model_path, str):
       self._model_handler.update_model_path(side_input_model_path)
     else:
-      self._model_handler.update_model_paths(self._model, side_input_model_path)
-    return model
+      if self._model is not None:
+        models = self._model.all_models()
+        for m in models:
+          self._model_handler.update_model_paths(m, side_input_model_path)
+    return model_wrapper

Review Comment:
   We actually can now, since we always return a `_CrossProcessModelWrapper`, 
which is a nice improvement. Updated.



##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -79,11 +90,15 @@ def __init__(
     self._env_vars = kwargs.get('env_vars', {})
     self._multi_process_shared = multi_process_shared
     self._state = state
+    self._incrementing = incrementing
+    self._max_copies = max_copies
     self._num_bytes_per_element = num_bytes_per_element
 
   def load_model(self):
     if self._fake_clock:
       self._fake_clock.current_time_ns += 500_000_000  # 500ms
+    if self._incrementing:
+      return FakeIncrementingModel()

Review Comment:
   Done!



##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -234,6 +235,9 @@ def __init__(
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine. This can be useful if you exactly know your CPU or

Review Comment:
   > Possible wording suggestion:
   
   Updated
   
   > Maybe this should be a ValueError if a user specifies both?
   
   I don't think it should be a ValueError. If you change it to True and also 
set this param, that is fairly reasonable, and a no-op makes sense IMO since 
we're still honoring your choice.
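
   A rough sketch of the no-op behavior I have in mind. This is an 
illustration only: `resolve_sharing` and `share_flag` are hypothetical names, 
and the assumption that an explicit copy count implies sharing is mine, not 
something this thread confirms about the final implementation.

```python
from typing import Optional, Tuple


def resolve_sharing(share_flag: bool,
                    model_copies: Optional[int]) -> Tuple[bool, int]:
  """Hypothetical helper: returns (share_across_processes, num_copies)."""
  if model_copies is not None:
    # Assumption: an explicit copy count already implies sharing, so the
    # boolean flag adds nothing. Ignoring it still honors the request.
    return True, model_copies
  # Without an explicit count, fall back to a single copy.
  return share_flag, 1


both_set = resolve_sharing(True, 4)
only_copies = resolve_sharing(False, 4)
only_flag = resolve_sharing(True, None)
```

   Under that assumption, setting both options gives the same result as 
setting only the copy count, which is why raising an error seems unnecessary.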



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1378,6 +1421,45 @@ def update(
     self._inference_request_batch_byte_size.update(examples_byte_size)
 
 
+class _ModelRoutingStrategy():
+  """A class meant to sit in a shared location for mapping incoming batches to
+  different models. Currently only supports round-robin, but can be extended
+  to support other protocols if needed.
+  """
+  def __init__(self):
+    self._cur_index = 0
+
+  def next_model_index(self, num_models):
+    self._cur_index = (self._cur_index + 1) % num_models
+    return self._cur_index
+
+
+class _CrossProcessModelWrapper():
+  """A router class to map incoming calls to the correct model.
+  
+    This allows us to round robin calls to models sitting in different
+    processes so that we can more efficiently use resources (e.g. GPUs).
+  """
+  def __init__(self, models: List[Any], model_tag: str):
+    self.models = models
+    if len(models) > 0:

Review Comment:
   Fixed



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1416,8 +1498,10 @@ def load():
       if isinstance(side_input_model_path, str):
         self._model_handler.update_model_path(side_input_model_path)
       else:
-        self._model_handler.update_model_paths(
-            self._model, side_input_model_path)
+        if self._model is not None:

Review Comment:
   We initially assign it to None, and IIRC there are some update paths where 
this can get called before it has been assigned. Right now `update_model_paths` 
is expected to handle (and no-op) the `None` case, so this is at least a bit 
cleaner.
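
   A minimal sketch of the guard, assuming a wrapper that exposes 
`all_models()` as in the diff. `FakeWrapper` and `update_all` are hypothetical 
names used only for illustration.

```python
from typing import Any, Callable, List, Optional


class FakeWrapper:
  """Hypothetical stand-in for the wrapper returned by the load path."""
  def __init__(self, models: List[Any]) -> None:
    self._models = models

  def all_models(self) -> List[Any]:
    return self._models


def update_all(wrapper: Optional[FakeWrapper],
               update_one: Callable[[Any], None]) -> int:
  """Applies update_one to each wrapped model; returns the number updated."""
  if wrapper is None:
    # Nothing has been loaded yet, so there is nothing to update.
    return 0
  models = wrapper.all_models()
  for m in models:
    update_one(m)
  return len(models)


seen: List[Any] = []
before_load = update_all(None, seen.append)
after_load = update_all(FakeWrapper(['m0', 'm1']), seen.append)
```

   Without the check, calling `all_models()` on `None` would raise an 
`AttributeError` on any path that runs before the model has been assigned.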



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
