This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/shareModelAcrossSteps
in repository https://gitbox.apache.org/repos/asf/beam.git

commit db22dc41b8fa6e57898864a79a0c75b463993002
Author: Danny McCormick <[email protected]>
AuthorDate: Fri Jun 21 14:58:36 2024 +0200

    Add support for sharing models across steps
---
 sdks/python/apache_beam/ml/inference/base.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 6fe2d5acc5c..48b76434b4c 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -1130,6 +1130,7 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
       *,
       model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
       watch_model_pattern: Optional[str] = None,
+      model_identifier: Optional[str] = None,
       **kwargs):
     """
     A transform that takes a PCollection of examples (or features) for use
@@ -1154,6 +1155,11 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
           to the _RunInferenceDoFn.
         watch_model_pattern: A glob pattern used to watch a directory
           for automatic model refresh.
+        model_identifier: A string used to identify the model being loaded. You
+          can set this if you want to reuse the same model across multiple
+          RunInference steps and don't want to load it more than once. Note
+          that using the same identifier for different models will lead to
+          non-deterministic results, so exercise caution with this parameter.
     """
     self._model_handler = model_handler
     self._inference_args = inference_args
@@ -1164,8 +1170,12 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
     self._watch_model_pattern = watch_model_pattern
     self._kwargs = kwargs
     # Generate a random tag to use for shared.py and multi_process_shared.py to
-    # allow us to effectively disambiguate in multi-model settings.
-    self._model_tag = uuid.uuid4().hex
+    # allow us to effectively disambiguate in multi-model settings. Only use
+    # the same tag if the model being loaded across multiple steps is actually
+    # the same.
+    self._model_tag = model_identifier
+    if model_identifier is None:
+      self._model_tag = uuid.uuid4().hex
 
   def annotations(self):
     return {

Reply via email to