This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch users/damccorm/shareModelAcrossSteps in repository https://gitbox.apache.org/repos/asf/beam.git
commit db22dc41b8fa6e57898864a79a0c75b463993002 Author: Danny McCormick <[email protected]> AuthorDate: Fri Jun 21 14:58:36 2024 +0200 Add support for sharing models across steps --- sdks/python/apache_beam/ml/inference/base.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 6fe2d5acc5c..48b76434b4c 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1130,6 +1130,7 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT, *, model_metadata_pcoll: beam.PCollection[ModelMetadata] = None, watch_model_pattern: Optional[str] = None, + model_identifier: Optional[str] = None, **kwargs): """ A transform that takes a PCollection of examples (or features) for use @@ -1154,6 +1155,11 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT, to the _RunInferenceDoFn. watch_model_pattern: A glob pattern used to watch a directory for automatic model refresh. + model_identifier: A string used to identify the model being loaded. You + can set this if you want to reuse the same model across multiple + RunInference steps without loading it more than once. Note that using + the same identifier for different models will lead to non-deterministic + results, so exercise caution when using this parameter. """ self._model_handler = model_handler self._inference_args = inference_args @@ -1164,8 +1170,12 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT, self._watch_model_pattern = watch_model_pattern self._kwargs = kwargs # Generate a random tag to use for shared.py and multi_process_shared.py to - # allow us to effectively disambiguate in multi-model settings. - self._model_tag = uuid.uuid4().hex + # allow us to effectively disambiguate in multi-model settings. Only use + # the same tag if the model being loaded across multiple steps is actually + # the same.
+ self._model_tag = model_identifier + if model_identifier is None: + self._model_tag = uuid.uuid4().hex def annotations(self): return {
