This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 6602f1e125f Update RunInference to work with model manager (#37506)
6602f1e125f is described below

commit 6602f1e125ff93be1278442110bb19678ff508ab
Author: RuiLong J. <[email protected]>
AuthorDate: Sun Feb 8 18:36:54 2026 -0800

    Update RunInference to work with model manager (#37506)
    
    * Update RunInference to work with model manager
    
    * Fix lint
    
    * More lint
    
    * Add unittest main call
    
    * Update sdks/python/apache_beam/ml/inference/base.py
    
    Co-authored-by: Danny McCormick <[email protected]>
    
    * Update sdks/python/apache_beam/ml/inference/base_test.py
    
    Co-authored-by: Danny McCormick <[email protected]>
    
    * Add some comments explaining the model loading logistics
    
    * Update name in tests as well
    
    ---------
    
    Co-authored-by: Danny McCormick <[email protected]>
---
 sdks/python/apache_beam/ml/inference/base.py       | 106 ++++++++++--
 sdks/python/apache_beam/ml/inference/base_test.py  |  78 ++++++++-
 .../ml/inference/model_manager_it_test.py          | 191 +++++++++++++++++++++
 3 files changed, 350 insertions(+), 25 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index ad2e2f8d0e3..1c3f0918baf 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -68,8 +68,11 @@ except ImportError:
 try:
   # pylint: disable=wrong-import-order, wrong-import-position
   import resource
+
+  from apache_beam.ml.inference.model_manager import ModelManager
 except ImportError:
   resource = None  # type: ignore[assignment]
+  ModelManager = None  # type: ignore[assignment]
 
 _NANOSECOND_TO_MILLISECOND = 1_000_000
 _NANOSECOND_TO_MICROSECOND = 1_000
@@ -533,11 +536,12 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
     raise NotImplementedError(type(self))
 
 
-class _ModelManager:
+class _ModelHandlerManager:
   """
-  A class for efficiently managing copies of multiple models. Will load a
-  single copy of each model into a multi_process_shared object and then
-  return a lookup key for that object.
+  A class for efficiently managing copies of multiple model handlers.
+  Will load a single copy of each model from the model handler into a
+  multi_process_shared object and then return a lookup key for that
+  object. Used for KeyedModelHandler only.
   """
   def __init__(self, mh_map: dict[str, ModelHandler]):
     """
@@ -602,8 +606,9 @@ class _ModelManager:
 
   def increment_max_models(self, increment: int):
     """
-    Increments the number of models that this instance of a _ModelManager is
-    able to hold. If it is never called, no limit is imposed.
+    Increments the number of models that this instance of a
+    _ModelHandlerManager is able to hold. If it is never called,
+    no limit is imposed.
     Args:
       increment: the amount by which we are incrementing the number of models.
     """
@@ -656,7 +661,7 @@ class KeyModelMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[tuple[KeyT, ExampleT],
                                      tuple[KeyT, PredictionT],
-                                     Union[ModelT, _ModelManager]]):
+                                     Union[ModelT, _ModelHandlerManager]]):
   def __init__(
       self,
       unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
@@ -809,15 +814,15 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
               'to exactly one model handler.')
         self._key_to_id_map[key] = keys[0]
 
-  def load_model(self) -> Union[ModelT, _ModelManager]:
+  def load_model(self) -> Union[ModelT, _ModelHandlerManager]:
     if self._single_model:
       return self._unkeyed.load_model()
-    return _ModelManager(self._id_to_mh_map)
+    return _ModelHandlerManager(self._id_to_mh_map)
 
   def run_inference(
       self,
       batch: Sequence[tuple[KeyT, ExampleT]],
-      model: Union[ModelT, _ModelManager],
+      model: Union[ModelT, _ModelHandlerManager],
       inference_args: Optional[dict[str, Any]] = None
   ) -> Iterable[tuple[KeyT, PredictionT]]:
     if self._single_model:
@@ -919,7 +924,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
 
   def update_model_paths(
       self,
-      model: Union[ModelT, _ModelManager],
+      model: Union[ModelT, _ModelHandlerManager],
       model_paths: list[KeyModelPathMapping[KeyT]] = None):
     # When there are many models, the keyed model handler is responsible for
     # reorganizing the model handlers into cohorts and telling the model
@@ -1338,6 +1343,8 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
       model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
       watch_model_pattern: Optional[str] = None,
       model_identifier: Optional[str] = None,
+      use_model_manager: bool = False,
+      model_manager_args: Optional[dict[str, Any]] = None,
       **kwargs):
     """
     A transform that takes a PCollection of examples (or features) for use
@@ -1378,6 +1385,8 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
     self._exception_handling_timeout = None
     self._timeout = None
     self._watch_model_pattern = watch_model_pattern
+    self._use_model_manager = use_model_manager
+    self._model_manager_args = model_manager_args
     self._kwargs = kwargs
     # Generate a random tag to use for shared.py and multi_process_shared.py to
     # allow us to effectively disambiguate in multi-model settings. Only use
@@ -1490,7 +1499,9 @@ class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
             self._clock,
             self._metrics_namespace,
             load_model_at_runtime,
-            self._model_tag),
+            self._model_tag,
+            self._use_model_manager,
+            self._model_manager_args),
         self._inference_args,
         beam.pvalue.AsSingleton(
             self._model_metadata_pcoll,
@@ -1803,21 +1814,55 @@ def load_model_status(
   return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag)
 
 
+class _ProxyLoader:
+  """
+  A helper callable to wrap the loader for MultiProcessShared.
+  """
+  def __init__(self, loader_func, model_tag):
+    self.loader_func = loader_func
+    self.model_tag = model_tag
+
+  def __call__(self):
+    # Generate a unique tag for the model being loaded so that
+    # we will have unique instances of the model in multi_process_shared
+    # space instead of reusing the same instance over. The instance will
+    # be initialized and left running as a separate process, which then
+    # can be grabbed again using the unique tag if needed during inference.
+    unique_tag = self.model_tag + '_' + uuid.uuid4().hex
+    # Ensure that each model loaded in a different process for parallelism
+    multi_process_shared.MultiProcessShared(
+        self.loader_func, tag=unique_tag, always_proxy=True,
+        spawn_process=True).acquire()
+    # Only return the tag to avoid pickling issues with the model itself.
+    return unique_tag
+
+
 class _SharedModelWrapper():
   """A router class to map incoming calls to the correct model.
 
     This allows us to round robin calls to models sitting in different
     processes so that we can more efficiently use resources (e.g. GPUs).
   """
-  def __init__(self, models: list[Any], model_tag: str):
+  def __init__(
+      self,
+      models: Union[list[Any], ModelManager],
+      model_tag: str,
+      loader_func: Optional[Callable[[], Any]] = None):
     self.models = models
-    if len(models) > 1:
+    self.use_model_manager = not isinstance(models, list)
+    self.model_tag = model_tag
+    self.loader_func = loader_func
+    if not self.use_model_manager and len(models) > 1:
       self.model_router = multi_process_shared.MultiProcessShared(
           lambda: _ModelRoutingStrategy(),
           tag=f'{model_tag}_counter',
           always_proxy=True).acquire()
 
   def next_model(self):
+    if self.use_model_manager:
+      loader_wrapper = _ProxyLoader(self.loader_func, self.model_tag)
+      return self.models.acquire_model(self.model_tag, loader_wrapper)
+
     if len(self.models) == 1:
       # Short circuit if there's no routing strategy needed in order to
       # avoid the cross-process call
@@ -1825,9 +1870,19 @@ class _SharedModelWrapper():
 
     return self.models[self.model_router.next_model_index(len(self.models))]
 
+  def release_model(self, model_tag: str, model: Any):
+    if self.use_model_manager:
+      self.models.release_model(model_tag, model)
+
   def all_models(self):
+    if self.use_model_manager:
+      return self.models.all_models()[self.model_tag]
     return self.models
 
+  def force_reset(self):
+    if self.use_model_manager:
+      self.models.force_reset()
+
 
 class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
   def __init__(
@@ -1836,7 +1891,9 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
       clock,
       metrics_namespace,
       load_model_at_runtime: bool = False,
-      model_tag: str = "RunInference"):
+      model_tag: str = "RunInference",
+      use_model_manager: bool = False,
+      model_manager_args: Optional[dict[str, Any]] = None):
     """A DoFn implementation generic to frameworks.
 
       Args:
@@ -1860,6 +1917,8 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
     # _cur_tag is the tag of the actually loaded model
     self._model_tag = model_tag
     self._cur_tag = model_tag
+    self.use_model_manager = use_model_manager
+    self._model_manager_args = model_manager_args or {}
 
   def _load_model(
       self,
@@ -1894,7 +1953,15 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
       model_tag = side_input_model_path
     # Ensure the tag we're loading is valid, if not replace it with a valid tag
     self._cur_tag = self._model_metadata.get_valid_tag(model_tag)
-    if self._model_handler.share_model_across_processes():
+    if self.use_model_manager:
+      logging.info("Using Model Manager to manage models automatically.")
+      model_manager = multi_process_shared.MultiProcessShared(
+          lambda: ModelManager(**self._model_manager_args),
+          tag='model_manager',
+          always_proxy=True).acquire()
+      model_wrapper = _SharedModelWrapper(
+          model_manager, self._cur_tag, self._model_handler.load_model)
+    elif self._model_handler.share_model_across_processes():
       models = []
       for copy_tag in _get_tags_for_copies(self._cur_tag,
                                            self._model_handler.model_copies()):
@@ -1949,8 +2016,15 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
     start_time = _to_microseconds(self._clock.time_ns())
     try:
       model = self._model.next_model()
+      if isinstance(model, str):
+        # ModelManager with MultiProcessShared returns the model tag
+        unique_tag = model
+        model = multi_process_shared.MultiProcessShared(
+            lambda: None, tag=model, always_proxy=True).acquire()
       result_generator = self._model_handler.run_inference(
           batch, model, inference_args)
+      if self.use_model_manager:
+        self._model.release_model(self._model_tag, unique_tag)
     except BaseException as e:
       if self._metrics_collector:
         self._metrics_collector.failed_batches_counter.inc()
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 55784166ad5..feccd8b0f12 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -17,6 +17,7 @@
 
 """Tests for apache_beam.ml.base."""
 import math
+import multiprocessing
 import os
 import pickle
 import sys
@@ -1599,13 +1600,13 @@ class RunInferenceBaseTest(unittest.TestCase):
       actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
-  def test_model_manager_loads_shared_model(self):
+  def test_model_handler_manager_loads_shared_model(self):
     mhs = {
         'key1': FakeModelHandler(state=1),
         'key2': FakeModelHandler(state=2),
         'key3': FakeModelHandler(state=3)
     }
-    mm = base._ModelManager(mh_map=mhs)
+    mm = base._ModelHandlerManager(mh_map=mhs)
     tag1 = mm.load('key1').model_tag
     # Use bad_mh's load function to make sure we're actually loading the
     # version already stored
@@ -1623,12 +1624,12 @@ class RunInferenceBaseTest(unittest.TestCase):
     self.assertEqual(2, model2.predict(10))
     self.assertEqual(3, model3.predict(10))
 
-  def test_model_manager_evicts_models(self):
+  def test_model_handler_manager_evicts_models(self):
     mh1 = FakeModelHandler(state=1)
     mh2 = FakeModelHandler(state=2)
     mh3 = FakeModelHandler(state=3)
     mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
-    mm = base._ModelManager(mh_map=mhs)
+    mm = base._ModelHandlerManager(mh_map=mhs)
     mm.increment_max_models(2)
     tag1 = mm.load('key1').model_tag
     sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
@@ -1667,10 +1668,10 @@ class RunInferenceBaseTest(unittest.TestCase):
         mh3.load_model, tag=tag3).acquire()
     self.assertEqual(8, model3.predict(10))
 
-  def test_model_manager_evicts_models_after_update(self):
+  def test_model_handler_manager_evicts_models_after_update(self):
     mh1 = FakeModelHandler(state=1)
     mhs = {'key1': mh1}
-    mm = base._ModelManager(mh_map=mhs)
+    mm = base._ModelHandlerManager(mh_map=mhs)
     tag1 = mm.load('key1').model_tag
     sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
     model1 = sh1.acquire()
@@ -1697,13 +1698,12 @@ class RunInferenceBaseTest(unittest.TestCase):
     self.assertEqual(6, model1.predict(10))
     sh1.release(model1)
 
-  def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
-      self):
+  def test_model_handler_manager_evicts_models_after_being_incremented(self):
     mh1 = FakeModelHandler(state=1)
     mh2 = FakeModelHandler(state=2)
     mh3 = FakeModelHandler(state=3)
     mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
-    mm = base._ModelManager(mh_map=mhs)
+    mm = base._ModelHandlerManager(mh_map=mhs)
     mm.increment_max_models(1)
     mm.increment_max_models(1)
     tag1 = mm.load('key1').model_tag
@@ -2279,5 +2279,65 @@ class ModelHandlerBatchingArgsTest(unittest.TestCase):
     self.assertEqual(kwargs, {'max_batch_duration_secs': 60})
 
 
+class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
+  def load_model(self):
+    return FakeModel()
+
+  def run_inference(
+      self,
+      batch: Sequence[int],
+      model: FakeModel,
+      inference_args=None) -> Iterable[int]:
+    for example in batch:
+      yield model.predict(example)
+
+
+def try_import_model_manager():
+  try:
+    # pylint: disable=unused-import
+    from apache_beam.ml.inference.model_manager import ModelManager
+    return True
+  except ImportError:
+    return False
+
+
+class ModelManagerTest(unittest.TestCase):
+  """Tests for RunInference with Model Manager integration."""
+  def tearDown(self):
+    for p in multiprocessing.active_children():
+      p.terminate()
+      p.join()
+
+  @unittest.skipIf(
+      not try_import_model_manager(), 'Model Manager not available')
+  def test_run_inference_impl_with_model_manager(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          SimpleFakeModelHandler(), use_model_manager=True)
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  @unittest.skipIf(
+      not try_import_model_manager(), 'Model Manager not available')
+  def test_run_inference_impl_with_model_manager_args(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          SimpleFakeModelHandler(),
+          use_model_manager=True,
+          model_manager_args={
+              'slack_percentage': 0.2,
+              'poll_interval': 1.0,
+              'peak_window_seconds': 10.0,
+              'min_data_points': 10,
+              'smoothing_factor': 0.5
+          })
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py
new file mode 100644
index 00000000000..eaa645b1216
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py
@@ -0,0 +1,191 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import apache_beam as beam
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+# pylint: disable=ungrouped-imports
+try:
+  import torch
+
+  from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
+except ImportError as e:
+  raise unittest.SkipTest(
+      "HuggingFace model handler dependencies are not installed")
+
+
+class HuggingFaceGpuTest(unittest.TestCase):
+
+  # Skips the test if you run it on a machine without a GPU
+  @unittest.skipIf(
+      not torch.cuda.is_available(), "No GPU detected, skipping GPU test")
+  def test_sentiment_analysis_on_gpu_large_input(self):
+    """
+    Runs inference on a GPU (device=0) with a larger set of inputs.
+    """
+    model_handler = HuggingFacePipelineModelHandler(
+        task="sentiment-analysis",
+        model="distilbert-base-uncased-finetuned-sst-2-english",
+        device=0,
+        inference_args={"batch_size": 4})
+    DUPLICATE_FACTOR = 2
+
+    with TestPipeline() as pipeline:
+      examples = [
+          "I absolutely love this product, it's a game changer!",
+          "This is the worst experience I have ever had.",
+          "The weather is okay, but I wish it were sunnier.",
+          "Apache Beam makes parallel processing incredibly efficient.",
+          "I am extremely disappointed with the service.",
+          "Logic and reason are the pillars of good debugging.",
+          "I'm so happy today!",
+          "This error message is confusing and unhelpful.",
+          "The movie was fantastic and the acting was superb.",
+          "I hate waiting in line for so long."
+      ] * DUPLICATE_FACTOR
+
+      pcoll = pipeline | 'CreateInputs' >> beam.Create(examples)
+
+      predictions = pcoll | 'RunInference' >> RunInference(
+          model_handler, use_model_manager=True)
+
+      actual_labels = predictions | beam.Map(lambda x: x.inference['label'])
+
+      expected_labels = [
+          'POSITIVE',  # "love this product"
+          'NEGATIVE',  # "worst experience"
+          'NEGATIVE',  # "weather is okay, but..."
+          'POSITIVE',  # "incredibly efficient"
+          'NEGATIVE',  # "disappointed"
+          'POSITIVE',  # "pillars of good debugging"
+          'POSITIVE',  # "so happy"
+          'NEGATIVE',  # "confusing and unhelpful"
+          'POSITIVE',  # "fantastic"
+          'NEGATIVE'  # "hate waiting"
+      ] * DUPLICATE_FACTOR
+
+      assert_that(
+          actual_labels, equal_to(expected_labels), label='CheckPredictions')
+
+  @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected")
+  def test_sentiment_analysis_large_roberta_gpu(self):
+    """
+    Runs inference using a Large architecture (RoBERTa-Large, ~355M params).
+    This tests if the GPU can handle larger weights and requires more VRAM.
+    """
+
+    model_handler = HuggingFacePipelineModelHandler(
+        task="sentiment-analysis",
+        model="Siebert/sentiment-roberta-large-english",
+        device=0,
+        inference_args={"batch_size": 2})
+
+    DUPLICATE_FACTOR = 2
+
+    with TestPipeline() as pipeline:
+      examples = [
+          "I absolutely love this product, it's a game changer!",
+          "This is the worst experience I have ever had.",
+          "Apache Beam scales effortlessly to massive datasets.",
+          "I am somewhat annoyed by the delay.",
+          "The nuanced performance of this large model is impressive.",
+          "I regret buying this immediately.",
+          "The sunset looks beautiful tonight.",
+          "This documentation is sparse and misleading.",
+          "Winning the championship felt surreal.",
+          "I'm feeling very neutral about this whole situation."
+      ] * DUPLICATE_FACTOR
+
+      pcoll = pipeline | 'CreateInputs' >> beam.Create(examples)
+      predictions = pcoll | 'RunInference' >> RunInference(
+          model_handler, use_model_manager=True)
+      actual_labels = predictions | beam.Map(lambda x: x.inference['label'])
+
+      expected_labels = [
+          'POSITIVE',  # love
+          'NEGATIVE',  # worst
+          'POSITIVE',  # scales effortlessly
+          'NEGATIVE',  # annoyed
+          'POSITIVE',  # impressive
+          'NEGATIVE',  # regret
+          'POSITIVE',  # beautiful
+          'NEGATIVE',  # misleading
+          'POSITIVE',  # surreal
+          'NEGATIVE'  # "neutral"
+      ] * DUPLICATE_FACTOR
+
+      assert_that(
+          actual_labels,
+          equal_to(expected_labels),
+          label='CheckPredictionsLarge')
+
+  @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected")
+  def test_parallel_inference_branches(self):
+    """
+    Tests a branching pipeline where one input source feeds two 
+    RunInference transforms running in parallel.
+    
+    Topology:
+            [ Input Data ]
+                    |
+        +--------+--------+
+        |                 |
+    [ Translation ]   [ Sentiment ]
+    """
+
+    translator_handler = HuggingFacePipelineModelHandler(
+        task="translation_en_to_es",
+        model="Helsinki-NLP/opus-mt-en-es",
+        device=0,
+        inference_args={"batch_size": 8})
+    sentiment_handler = HuggingFacePipelineModelHandler(
+        task="sentiment-analysis",
+        model="nlptown/bert-base-multilingual-uncased-sentiment",
+        device=0,
+        inference_args={"batch_size": 8})
+    base_examples = [
+        "I love this product.",
+        "This is terrible.",
+        "Hello world.",
+        "The service was okay.",
+        "I am extremely angry."
+    ]
+    MULTIPLIER = 10
+    examples = base_examples * MULTIPLIER
+
+    with TestPipeline() as pipeline:
+      inputs = pipeline | 'CreateInputs' >> beam.Create(examples)
+      _ = (
+          inputs
+          | 'RunTranslation' >> RunInference(
+              translator_handler, use_model_manager=True)
+          | 'ExtractSpanish' >>
+          beam.Map(lambda x: x.inference['translation_text']))
+      _ = (
+          inputs
+          | 'RunSentiment' >> RunInference(
+              sentiment_handler, use_model_manager=True)
+          | 'ExtractLabel' >> beam.Map(lambda x: x.inference['label']))
+
+
+if __name__ == "__main__":
+  unittest.main()

Reply via email to