This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 14fd366f0ad Add support for sharing models across steps (#31665)
14fd366f0ad is described below
commit 14fd366f0ad760a94994932b978caf8d3a59f15a
Author: Danny McCormick <[email protected]>
AuthorDate: Fri Jun 21 16:51:02 2024 +0200
Add support for sharing models across steps (#31665)
* Add support for sharing models across steps
* Tests + CHANGES
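A minimal usage sketch of the new parameter (the handler and tag below are
illustrative placeholders, not part of this commit): two RunInference steps
pass the same model_identifier so the second step can reuse the model already
loaded by the first, provided the handler shares its model across processes.

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference

    # Hypothetical handler; any ModelHandler whose model is shared across
    # processes (e.g. a large model) is the intended use case.
    handler = MySharedModelHandler()

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([1, 2, 3])
          | 'ri1' >> RunInference(handler, model_identifier='shared-model-tag')
          | 'ri2' >> RunInference(handler, model_identifier='shared-model-tag'))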
---
CHANGES.md | 2 +-
sdks/python/apache_beam/ml/inference/base.py | 15 +++++++--
sdks/python/apache_beam/ml/inference/base_test.py | 39 +++++++++++++++++++++++
3 files changed, 53 insertions(+), 3 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 0d4de2e3d58..ffff8159614 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -66,7 +66,7 @@
## New Features / Improvements
-* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
+* Multiple RunInference instances can now share the same model instance by setting the model_identifier parameter (Python) ([#31665](https://github.com/apache/beam/issues/31665)).
## Breaking Changes
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 6fe2d5acc5c..401b57fdb80 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -1130,6 +1130,7 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
*,
model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
watch_model_pattern: Optional[str] = None,
+ model_identifier: Optional[str] = None,
**kwargs):
"""
A transform that takes a PCollection of examples (or features) for use
@@ -1154,6 +1155,12 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
to the _RunInferenceDoFn.
watch_model_pattern: A glob pattern used to watch a directory
for automatic model refresh.
+ model_identifier: A string used to identify the model being loaded. You
+ can set this if you want to reuse the same model across multiple
+ RunInference steps and don't want to reload it twice. Note that using
+ the same tag for different models will lead to non-deterministic
+ results, so exercise caution when using this parameter. This only
+ impacts models which are already being shared across processes.
"""
self._model_handler = model_handler
self._inference_args = inference_args
@@ -1164,8 +1171,12 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
self._watch_model_pattern = watch_model_pattern
self._kwargs = kwargs
# Generate a random tag to use for shared.py and multi_process_shared.py to
- # allow us to effectively disambiguate in multi-model settings.
- self._model_tag = uuid.uuid4().hex
+ # allow us to effectively disambiguate in multi-model settings. Only use
+ # the same tag if the model being loaded across multiple steps is actually
+ # the same.
+ self._model_tag = model_identifier
+ if model_identifier is None:
+ self._model_tag = uuid.uuid4().hex
def annotations(self):
return {
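A rough sketch of the resulting tag selection (condensed from the hunk above;
behavior only, not the exact patch text):

    # A caller-supplied model_identifier becomes the shared tag; otherwise a
    # fresh random tag keeps each RunInference step's model separate.
    self._model_tag = (
        model_identifier if model_identifier is not None else uuid.uuid4().hex)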
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index ec1664f494c..b392ebd3085 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -1592,6 +1592,45 @@ class RunInferenceBaseTest(unittest.TestCase):
mh3.load_model, tag=tag3).acquire()
self.assertEqual(8, model3.predict(10))
+ def test_run_inference_loads_different_models(self):
+ mh1 = FakeModelHandler(incrementing=True, min_batch_size=3)
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
+ actual = (
+ pcoll
+ | 'ri1' >> base.RunInference(mh1)
+ | 'ri2' >> base.RunInference(mh1))
+ assert_that(actual, equal_to([1, 2, 3]), label='assert:inferences')
+
+ def test_run_inference_loads_different_models_multi_process_shared(self):
+ mh1 = FakeModelHandler(
+ incrementing=True, min_batch_size=3, multi_process_shared=True)
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
+ actual = (
+ pcoll
+ | 'ri1' >> base.RunInference(mh1)
+ | 'ri2' >> base.RunInference(mh1))
+ assert_that(actual, equal_to([1, 2, 3]), label='assert:inferences')
+
+ def test_runinference_loads_same_model_with_identifier_multi_process_shared(
+ self):
+ mh1 = FakeModelHandler(
+ incrementing=True, min_batch_size=3, multi_process_shared=True)
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
+ actual = (
+ pcoll
+ | 'ri1' >> base.RunInference(
+ mh1,
+ model_identifier='same_model_with_identifier_multi_process_shared'
+ )
+ | 'ri2' >> base.RunInference(
+ mh1,
+ model_identifier='same_model_with_identifier_multi_process_shared'
+ ))
+ assert_that(actual, equal_to([4, 5, 6]), label='assert:inferences')
+
def test_run_inference_watch_file_pattern_side_input_label(self):
pipeline = TestPipeline()
# label of the WatchPattern transform.