This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 14fd366f0ad Add support for sharing models across steps (#31665)
14fd366f0ad is described below

commit 14fd366f0ad760a94994932b978caf8d3a59f15a
Author: Danny McCormick <[email protected]>
AuthorDate: Fri Jun 21 16:51:02 2024 +0200

    Add support for sharing models across steps (#31665)
    
    * Add support for sharing models across steps
    
    * Tests + CHANGES
---
 CHANGES.md                                        |  2 +-
 sdks/python/apache_beam/ml/inference/base.py      | 15 +++++++--
 sdks/python/apache_beam/ml/inference/base_test.py | 39 +++++++++++++++++++++++
 3 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 0d4de2e3d58..ffff8159614 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -66,7 +66,7 @@
 
 ## New Features / Improvements
 
-* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
+* Multiple RunInference instances can now share the same model instance by setting the model_identifier parameter (Python) ([#31665](https://github.com/apache/beam/issues/31665)).
 
 ## Breaking Changes
 
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 6fe2d5acc5c..401b57fdb80 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -1130,6 +1130,7 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
       *,
       model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
       watch_model_pattern: Optional[str] = None,
+      model_identifier: Optional[str] = None,
       **kwargs):
     """
     A transform that takes a PCollection of examples (or features) for use
@@ -1154,6 +1155,12 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
           to the _RunInferenceDoFn.
         watch_model_pattern: A glob pattern used to watch a directory
           for automatic model refresh.
+        model_identifier: A string used to identify the model being loaded. You
+          can set this if you want to reuse the same model across multiple
+          RunInference steps and don't want to reload it twice. Note that using
+          the same tag for different models will lead to non-deterministic
+          results, so exercise caution when using this parameter. This only
+          impacts models which are already being shared across processes.
     """
     self._model_handler = model_handler
     self._inference_args = inference_args
@@ -1164,8 +1171,12 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
     self._watch_model_pattern = watch_model_pattern
     self._kwargs = kwargs
     # Generate a random tag to use for shared.py and multi_process_shared.py to
-    # allow us to effectively disambiguate in multi-model settings.
-    self._model_tag = uuid.uuid4().hex
+    # allow us to effectively disambiguate in multi-model settings. Only use
+    # the same tag if the model being loaded across multiple steps is actually
+    # the same.
+    self._model_tag = model_identifier
+    if model_identifier is None:
+      self._model_tag = uuid.uuid4().hex
 
   def annotations(self):
     return {
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index ec1664f494c..b392ebd3085 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -1592,6 +1592,45 @@ class RunInferenceBaseTest(unittest.TestCase):
         mh3.load_model, tag=tag3).acquire()
     self.assertEqual(8, model3.predict(10))
 
+  def test_run_inference_loads_different_models(self):
+    mh1 = FakeModelHandler(incrementing=True, min_batch_size=3)
+    with TestPipeline() as pipeline:
+      pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
+      actual = (
+          pcoll
+          | 'ri1' >> base.RunInference(mh1)
+          | 'ri2' >> base.RunInference(mh1))
+      assert_that(actual, equal_to([1, 2, 3]), label='assert:inferences')
+
+  def test_run_inference_loads_different_models_multi_process_shared(self):
+    mh1 = FakeModelHandler(
+        incrementing=True, min_batch_size=3, multi_process_shared=True)
+    with TestPipeline() as pipeline:
+      pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
+      actual = (
+          pcoll
+          | 'ri1' >> base.RunInference(mh1)
+          | 'ri2' >> base.RunInference(mh1))
+      assert_that(actual, equal_to([1, 2, 3]), label='assert:inferences')
+
+  def test_runinference_loads_same_model_with_identifier_multi_process_shared(
+      self):
+    mh1 = FakeModelHandler(
+        incrementing=True, min_batch_size=3, multi_process_shared=True)
+    with TestPipeline() as pipeline:
+      pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
+      actual = (
+          pcoll
+          | 'ri1' >> base.RunInference(
+              mh1,
+              model_identifier='same_model_with_identifier_multi_process_shared'
+          )
+          | 'ri2' >> base.RunInference(
+              mh1,
+              model_identifier='same_model_with_identifier_multi_process_shared'
+          ))
+      assert_that(actual, equal_to([4, 5, 6]), label='assert:inferences')
+
   def test_run_inference_watch_file_pattern_side_input_label(self):
     pipeline = TestPipeline()
     # label of the WatchPattern transform.

Reply via email to