damccorm commented on code in PR #37506:
URL: https://github.com/apache/beam/pull/37506#discussion_r2775930893


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1803,31 +1814,70 @@ def load_model_status(
   return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag)
 
 
+class _ProxyLoader:
+  """
+  A helper callable to wrap the loader for MultiProcessShared.
+  """
+  def __init__(self, loader_func, model_tag):
+    self.loader_func = loader_func
+    self.model_tag = model_tag
+
+  def __call__(self):
+    unique_tag = self.model_tag + '_' + uuid.uuid4().hex

Review Comment:
   Why are we making this a unique tag every time? Won't this cause https://github.com/apache/beam/blob/f504c5309bcfe95ab99035d9d71c9b99f1ca2e43/sdks/python/apache_beam/ml/inference/model_manager.py#L558 to always miss?



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1803,31 +1814,70 @@ def load_model_status(
   return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag)
 
 
+class _ProxyLoader:
+  """
+  A helper callable to wrap the loader for MultiProcessShared.
+  """
+  def __init__(self, loader_func, model_tag):
+    self.loader_func = loader_func
+    self.model_tag = model_tag
+
+  def __call__(self):
+    unique_tag = self.model_tag + '_' + uuid.uuid4().hex
+    # Ensure that each model is loaded in a different process for parallelism
+    multi_process_shared.MultiProcessShared(
+        self.loader_func, tag=unique_tag, always_proxy=True,
+        spawn_process=True).acquire()
+    # Only return the tag to avoid pickling issues with the model itself.
+    return unique_tag
+
+
 class _SharedModelWrapper():
   """A router class to map incoming calls to the correct model.
 
     This allows us to round robin calls to models sitting in different
     processes so that we can more efficiently use resources (e.g. GPUs).
   """
-  def __init__(self, models: list[Any], model_tag: str):
+  def __init__(
+      self,
+      models: Union[list[Any], ModelManager],
+      model_tag: str,
+      loader_func: Callable[[], Any] = None):

Review Comment:
   ```suggestion
         loader_func: Optional[Callable[[], Any]] = None):
   ```
   
   nit



##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -2279,5 +2279,65 @@ def test_max_batch_duration_secs_only(self):
     self.assertEqual(kwargs, {'max_batch_duration_secs': 60})
 
 
+class SimpleFakeModelHanlder(base.ModelHandler[int, int, FakeModel]):

Review Comment:
   Spelling nit:
   
   
   ```suggestion
   class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   ```



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1803,31 +1814,70 @@ def load_model_status(
   return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag)
 
 
+class _ProxyLoader:
+  """
+  A helper callable to wrap the loader for MultiProcessShared.
+  """
+  def __init__(self, loader_func, model_tag):
+    self.loader_func = loader_func
+    self.model_tag = model_tag
+
+  def __call__(self):
+    unique_tag = self.model_tag + '_' + uuid.uuid4().hex

Review Comment:
   Oh I see - it's because the unique_tag isn't the thing that is used for the LRU cache; it only gets referenced when we actually load the model (which we only want to do once per tag in the model manager).
   
   I think this is correct, but could you add some more high-level comments to the class explaining how the helper callable is used?
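
   For example, here is a rough sketch of the kind of high-level comment that could go on the class. The docstring wording is only an illustration based on the discussion above, not the author's; the `__init__`/`__call__` bodies are the PR's code restated for context:
   
   ```python
   class _ProxyLoader:
     """Loader callable handed to the model manager for a given model_tag.
   
     The manager's cache is not keyed on the unique tag generated here; this
     callable only runs when the manager actually needs to load the model for
     its model_tag. When it runs, it loads the model into a separate process
     via MultiProcessShared under a fresh per-load unique tag (so concurrent
     loads land in different processes for parallelism) and returns that tag
     rather than the model itself, avoiding pickling the model across process
     boundaries.
     """
     def __init__(self, loader_func, model_tag):
       self.loader_func = loader_func
       self.model_tag = model_tag
   
     def __call__(self):
       unique_tag = self.model_tag + '_' + uuid.uuid4().hex
       # Load in a separate process; only the tag crosses process boundaries.
       multi_process_shared.MultiProcessShared(
           self.loader_func, tag=unique_tag, always_proxy=True,
           spawn_process=True).acquire()
       return unique_tag
   ```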



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
