damccorm commented on code in PR #37506:
URL: https://github.com/apache/beam/pull/37506#discussion_r2775930893


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1803,31 +1814,70 @@ def load_model_status(
     return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag)

+class _ProxyLoader:
+  """
+  A helper callable to wrap the loader for MultiProcessShared.
+  """
+  def __init__(self, loader_func, model_tag):
+    self.loader_func = loader_func
+    self.model_tag = model_tag
+
+  def __call__(self):
+    unique_tag = self.model_tag + '_' + uuid.uuid4().hex

Review Comment:
   Why are we making this a unique tag every time? Won't this cause https://github.com/apache/beam/blob/f504c5309bcfe95ab99035d9d71c9b99f1ca2e43/sdks/python/apache_beam/ml/inference/model_manager.py#L558 to always miss?


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1803,31 +1814,70 @@ def load_model_status(
     return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag)

+class _ProxyLoader:
+  """
+  A helper callable to wrap the loader for MultiProcessShared.
+  """
+  def __init__(self, loader_func, model_tag):
+    self.loader_func = loader_func
+    self.model_tag = model_tag
+
+  def __call__(self):
+    unique_tag = self.model_tag + '_' + uuid.uuid4().hex
+    # Ensure that each model loaded in a different process for parallelism
+    multi_process_shared.MultiProcessShared(
+        self.loader_func, tag=unique_tag, always_proxy=True,
+        spawn_process=True).acquire()
+    # Only return the tag to avoid pickling issues with the model itself.
+    return unique_tag
+
+
 class _SharedModelWrapper():
   """A router class to map incoming calls to the correct model.

   This allows us to round robin calls to models sitting in different processes
   so that we can more efficiently use resources (e.g. GPUs).
   """
-  def __init__(self, models: list[Any], model_tag: str):
+  def __init__(
+      self,
+      models: Union[list[Any], ModelManager],
+      model_tag: str,
+      loader_func: Callable[[], Any] = None):

Review Comment:
   ```suggestion
         loader_func: Optional[Callable[[], Any]] = None):
   ```
   nit


##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -2279,5 +2279,65 @@ def test_max_batch_duration_secs_only(self):
     self.assertEqual(kwargs, {'max_batch_duration_secs': 60})

+class SimpleFakeModelHanlder(base.ModelHandler[int, int, FakeModel]):

Review Comment:
   Spelling nit:
   ```suggestion
   class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   ```


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1803,31 +1814,70 @@ def load_model_status(
     return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag)

+class _ProxyLoader:
+  """
+  A helper callable to wrap the loader for MultiProcessShared.
+  """
+  def __init__(self, loader_func, model_tag):
+    self.loader_func = loader_func
+    self.model_tag = model_tag
+
+  def __call__(self):
+    unique_tag = self.model_tag + '_' + uuid.uuid4().hex

Review Comment:
   Oh I see - it's because the unique_tag isn't the thing that is used for the LRU cache, it only gets referenced when we actually load the model (which we only want to do once per tag in the model manager). I think this is correct, but could you add some more high-level comments to the class explaining how the helper callable is used?
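For context on the last review comment, here is one possible shape of the requested high-level comments. This is only a sketch with hypothetical docstring wording, not the actual PR change; it assumes the cache/loader interplay described in the thread (the ModelManager keys its LRU cache on the stable model tag and invokes this loader only on a miss) and elides the real `multi_process_shared` call shown in the diff:

```python
import uuid


class _ProxyLoader:
  """Loader callable handed to the ModelManager for a given model tag.

  The ModelManager's LRU cache is keyed on the stable ``model_tag``, so this
  callable only runs on a cache miss, i.e. at most once per tag. When it does
  run, it loads the model into its own child process under a fresh
  ``model_tag + '_' + uuid4`` tag (so parallel copies land in separate
  processes) and returns that unique tag instead of the model object, which
  avoids pickling the model across process boundaries.
  """
  def __init__(self, loader_func, model_tag):
    self.loader_func = loader_func
    self.model_tag = model_tag

  def __call__(self):
    unique_tag = self.model_tag + '_' + uuid.uuid4().hex
    # The real implementation spawns the model into a separate process here
    # via multi_process_shared.MultiProcessShared(...).acquire(); elided in
    # this sketch.
    return unique_tag
```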
