damccorm commented on code in PR #37557:
URL: https://github.com/apache/beam/pull/37557#discussion_r2795624503
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -2338,6 +2339,28 @@ def test_run_inference_impl_with_model_manager_args(self):
           })
       assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  @unittest.skipIf(
+      not try_import_model_manager(), 'Model Manager not available')
+  def test_run_inference_impl_with_model_manager_oom(self):
+    class OOMFakeModelHandler(SimpleFakeModelHandler):
+      def run_inference(
+          self,
+          batch: Sequence[int],
+          model: FakeModel,
+          inference_args=None) -> Iterable[int]:
+        if random.random() < 0.8:
+          raise MemoryError("Simulated OOM")
+        for example in batch:
+          yield model.predict(example)
+
+    with self.assertRaises(Exception):
+      with TestPipeline() as pipeline:
+        examples = [1, 5, 3, 10]
+        pcoll = pipeline | 'start' >> beam.Create(examples)
+        actual = pcoll | base.RunInference(
+            OOMFakeModelHandler(), use_model_manager=True)
+        assert_that(actual, equal_to([2, 6, 4, 11]), label='assert:inferences')
Review Comment:
I don't think the suggestion is good, but there is a 20% chance this test succeeds: all four examples land in a single batch, so `run_inference` is called only once and there is a 20% chance it never raises the simulated `MemoryError`. Could we drop the batch size to 1? That would help, since we'd get 4 `run_inference` calls instead of just one, and the chance of never hitting the simulated OOM would drop to 0.2^4 = 0.16%.
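
If it helps, here's a rough sketch of what I mean, reusing the names from the diff above. I'm assuming we can force single-element batches through the standard `ModelHandler.batch_elements_kwargs()` hook (whose return value is passed to `beam.BatchElements`); if `SimpleFakeModelHandler` already overrides that hook, the override below would need to be merged with it:

```python
class OOMFakeModelHandler(SimpleFakeModelHandler):
  def batch_elements_kwargs(self):
    # Force single-element batches so each of the 4 examples gets its own
    # run_inference call; the chance that none of them raises the simulated
    # MemoryError drops from 0.2 to 0.2**4.
    return {'min_batch_size': 1, 'max_batch_size': 1}

  def run_inference(
      self,
      batch: Sequence[int],
      model: FakeModel,
      inference_args=None) -> Iterable[int]:
    if random.random() < 0.8:
      raise MemoryError("Simulated OOM")
    for example in batch:
      yield model.predict(example)
```

Keeping the batching knob inside the fake handler means the test body itself doesn't change, and the flake rate goes from 20% to roughly 0.16%.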