damccorm commented on code in PR #25200:
URL: https://github.com/apache/beam/pull/25200#discussion_r1092132299
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -339,6 +375,79 @@ def validate_inference_args(
third_party_model_handler.batch_elements_kwargs()
third_party_model_handler.validate_inference_args({})
+ def test_run_inference_prediction_result_with_model_id(self):
+ examples = [1, 5, 3, 10]
+ expected = [
+ base.PredictionResult(
+ example=example,
+ inference=example + 1,
+ model_id='fake_model_id_default') for example in examples
+ ]
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ actual = pcoll | base.RunInference(
+ FakeModelHandlerReturnsPredictionResult())
+ assert_that(actual, equal_to(expected), label='assert:inferences')
+
+ @pytest.mark.it_postcommit
+ def test_run_inference_prediction_result_with_side_input(self):
+ test_pipeline = TestPipeline(is_integration_test=True)
Review Comment:
LGTM
##########
sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py:
##########
@@ -0,0 +1,165 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Used for internal testing. No backwards compatibility.
+"""
+
+import argparse
+import logging
+import time
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+
+import apache_beam as beam
+from apache_beam.ml.inference import base
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.transforms import trigger
+from apache_beam.transforms import window
+from apache_beam.transforms.periodicsequence import PeriodicImpulse
+from apache_beam.transforms.userstate import CombiningValueStateSpec
+
+
+# Create some fake models which return different inference results.
+class FakeModelDefault:
+ def predict(self, example: int) -> int:
+ return example
+
+
+class FakeModelAdd(FakeModelDefault):
+ def predict(self, example: int) -> int:
+ return example + 1
+
+
+class FakeModelSub(FakeModelDefault):
+ def predict(self, example: int) -> int:
+ return example - 1
+
+
+class FakeModelHandlerReturnsPredictionResult(
+ base.ModelHandler[int, base.PredictionResult, FakeModelDefault]):
+ def __init__(self, clock=None, model_id='fake_model_id_default'):
+ self.model_id = model_id
+ self._fake_clock = clock
+
+ def load_model(self):
+ if self._fake_clock:
+ self._fake_clock.current_time_ns += 500_000_000 # 500ms
+ if self.model_id == 'model_add.pkl':
+ return FakeModelAdd()
+ elif self.model_id == 'model_sub.pkl':
+ return FakeModelSub()
+ return FakeModelDefault()
+
+ def run_inference(
+ self,
+ batch: Sequence[int],
+ model: FakeModelDefault,
+ inference_args=None) -> Iterable[base.PredictionResult]:
+ for example in batch:
+ yield base.PredictionResult(
+ model_id=self.model_id,
+ example=example,
+ inference=model.predict(example))
+
+ def update_model_path(self, model_path: Optional[str] = None):
+ self.model_id = model_path if model_path else self.model_id
+
+
+def run(argv=None, save_main_session=True):
+ parser = argparse.ArgumentParser()
+ first_ts = time.time()
+ side_input_interval = 60
+ main_input_interval = 20
+ # give some time for dataflow to start.
+ last_ts = first_ts + 1200
+ mid_ts = (first_ts + last_ts) / 2
+
+ _, pipeline_args = parser.parse_known_args(argv)
+ options = PipelineOptions(pipeline_args)
+ options.view_as(SetupOptions).save_main_session = save_main_session
+
+ test_pipeline = beam.Pipeline(options=options)
+
+ class GetModel(beam.DoFn):
+ def process(self, element) -> Iterable[base.ModelMetdata]:
+ if time.time() > mid_ts:
+ yield base.ModelMetdata(
+ model_id='model_add.pkl', model_name='model_add')
+ else:
+ yield base.ModelMetdata(
+ model_id='model_sub.pkl', model_name='model_sub')
+
+ class _EmitSingletonSideInput(beam.DoFn):
+ COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)
+
+ def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
+ _, path = element
+ counter = count_state.read()
+ if counter == 0:
+ count_state.add(1)
+ yield path
+
+ def validate_prediction_result(x: base.PredictionResult):
+ model_id = x.model_id
+ if model_id == 'model_sub.pkl':
+ assert (x.example == 1 and x.inference == 0)
+
+ if model_id == 'model_add.pkl':
+ assert (x.example == 1 and x.inference == 2)
Review Comment:
Nit: Please add an if/assert for the default model
##########
sdks/python/apache_beam/pipeline.py:
##########
@@ -525,6 +525,14 @@ def run(self, test_runner_api='AUTO'):
self.contains_external_transforms = (
ExternalTransformFinder.contains_external_transforms(self))
+ self.contains_run_inference_transform = (
+ RunInferenceSideInputFinder.contains_run_inference_transform(self))
+
+ if (self.contains_run_inference_transform and
+ not self._options.view_as(StandardOptions).streaming):
+ raise RuntimeError(
Review Comment:
I think I like the error that you have
##########
sdks/python/apache_beam/pipeline.py:
##########
@@ -525,14 +525,31 @@ def run(self, test_runner_api='AUTO'):
self.contains_external_transforms = (
ExternalTransformFinder.contains_external_transforms(self))
- self.contains_run_inference_transform = (
- RunInferenceSideInputFinder.contains_run_inference_transform(self))
+ # Finds if RunInference has side inputs enabled.
+ # Also checks whether the side input window is global and has non-default
+ # triggers.
+ run_inference_visitor = RunInferenceVisitor().visit_run_inference(self)
+ self._run_inference_contains_side_input = (
+ run_inference_visitor.contains_run_inference_side_inputs)
- if (self.contains_run_inference_transform and
+ self.run_inference_global_window_non_default_trigger = (
+ run_inference_visitor.contains_global_windows_non_default_trigger)
+
+ if (self._run_inference_contains_side_input and
not self._options.view_as(StandardOptions).streaming):
raise RuntimeError(
"SideInputs to RunInference PTransform is only supported "
- "in streaming mode.")
+ "in streaming mode. To run in streaming mode, add the"
+ " --streaming pipeline option")
+
+ if (self._run_inference_contains_side_input and
+ not self.run_inference_global_window_non_default_trigger):
+ raise RuntimeError(
+ "The RunInference PTransform's SideInput is either using "
+ "GlobalWindows with a default trigger or no Windowing, which "
Review Comment:
```suggestion
"GlobalWindows with a default trigger, non-global windowing, or no Windowing, which "
```
(probably needs formatting + corresponding test changes)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]