AMOOOMA commented on code in PR #37557:
URL: https://github.com/apache/beam/pull/37557#discussion_r2795731817
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -2338,6 +2339,28 @@ def test_run_inference_impl_with_model_manager_args(self):
})
assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  @unittest.skipIf(
+      not try_import_model_manager(), 'Model Manager not available')
+  def test_run_inference_impl_with_model_manager_oom(self):
+    class OOMFakeModelHandler(SimpleFakeModelHandler):
+      def run_inference(
+          self,
+          batch: Sequence[int],
+          model: FakeModel,
+          inference_args=None) -> Iterable[int]:
+        if random.random() < 0.8:
+          raise MemoryError("Simulated OOM")
+        for example in batch:
+          yield model.predict(example)
+
+    with self.assertRaises(Exception):
+      with TestPipeline() as pipeline:
+        examples = [1, 5, 3, 10]
+        pcoll = pipeline | 'start' >> beam.Create(examples)
+        actual = pcoll | base.RunInference(
+            OOMFakeModelHandler(), use_model_manager=True)
+        assert_that(actual, equal_to([2, 6, 4, 11]), label='assert:inferences')
Review Comment:
Ah yes, good catch! I didn't notice the batch size, so I was assuming a
0.2^4 chance of no OOM. Will update!
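
For context: because all four examples land in a single batch, `run_inference` is invoked once per pipeline run, so the chance that no `MemoryError` is raised is 0.2 per run rather than 0.2^4 ≈ 0.0016, and the `assertRaises` check would flake roughly one run in five. Below is a minimal sketch of one way to remove the randomness entirely (purely illustrative; the name `DeterministicOOMFakeModelHandler` is hypothetical and this may not match the PR's eventual update):

```python
# Illustrative sketch only: raising unconditionally means assertRaises
# can never flake, regardless of how elements are batched.
class DeterministicOOMFakeModelHandler(SimpleFakeModelHandler):
  def run_inference(
      self,
      batch: Sequence[int],
      model: FakeModel,
      inference_args=None) -> Iterable[int]:
    raise MemoryError("Simulated OOM")
```

If the successful-path outputs are still wanted (the `assert_that` on `[2, 6, 4, 11]`), an alternative is a handler that fails only on its first call and succeeds afterwards, though whether those outputs ever materialize depends on how the model manager recovers after the OOM.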
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]