gemini-code-assist[bot] commented on code in PR #37557:
URL: https://github.com/apache/beam/pull/37557#discussion_r2795614657


##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -2338,6 +2339,28 @@ def test_run_inference_impl_with_model_manager_args(self):
           })
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  @unittest.skipIf(
+      not try_import_model_manager(), 'Model Manager not available')
+  def test_run_inference_impl_with_model_manager_oom(self):
+    class OOMFakeModelHandler(SimpleFakeModelHandler):
+      def run_inference(
+          self,
+          batch: Sequence[int],
+          model: FakeModel,
+          inference_args=None) -> Iterable[int]:
+        if random.random() < 0.8:
+          raise MemoryError("Simulated OOM")
+        for example in batch:
+          yield model.predict(example)
+
+    with self.assertRaises(Exception):
+      with TestPipeline() as pipeline:
+        examples = [1, 5, 3, 10]
+        pcoll = pipeline | 'start' >> beam.Create(examples)
+        actual = pcoll | base.RunInference(
+            OOMFakeModelHandler(), use_model_manager=True)
+        assert_that(actual, equal_to([2, 6, 4, 11]), label='assert:inferences')

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This new test has a few issues that could be improved:
   
   1.  **Non-determinism**: The use of `random.random()` makes this test
       non-deterministic, which can lead to flaky builds. It's better to make
       tests deterministic.
   
   2.  **Confusing structure**: The `assert_that` call is inside the
       `with self.assertRaises(Exception):` block, so the output assertion only
       matters when no exception is raised, and in that case the test fails
       anyway. Separate tests for the success and failure cases would be
       clearer.
   
   3.  **Incomplete test coverage**: The `OOMProtectedFn` specifically looks
       for `'out of memory'` and `'CUDA'` in the exception string to trigger the
       memory cleanup logic. The `MemoryError("Simulated OOM")` raised here
       contains neither substring, so the cleanup path is not actually being
       tested.
   
   I'm suggesting a change to make this test deterministic and to correctly
   test the OOM cleanup path by raising a more specific error. This will make
   the test more reliable and ensure the new functionality is properly
   verified.
   
   ```python
       class OOMFakeModelHandler(SimpleFakeModelHandler):
         def run_inference(
             self,
             batch: Sequence[int],
             model: FakeModel,
             inference_args=None) -> Iterable[int]:
           # This will always raise to test the OOM path.
           raise MemoryError("CUDA out of memory. Simulated OOM.")
   
       with self.assertRaises(MemoryError):
         with TestPipeline() as pipeline:
           examples = [1, 5, 3, 10]
           pcoll = pipeline | 'start' >> beam.Create(examples)
           # The pipeline will fail, so we don't need to check the output.
           _ = pcoll | base.RunInference(
               OOMFakeModelHandler(), use_model_manager=True)
   ```
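
If a recovery case is wanted alongside the always-failing one, failure injection can stay deterministic by counting calls instead of sampling. The following is a minimal, Beam-free sketch of that idea, not the PR's implementation: `looks_like_oom` and `FailFirstN` are hypothetical names, and the substring check only approximates what item 3 describes (it assumes both substrings must be present; the real `OOMProtectedFn` logic may differ).

```python
from typing import List, Sequence

# Message chosen to contain both substrings mentioned in item 3.
OOM_MESSAGE = "CUDA out of memory. Simulated OOM."


def looks_like_oom(exc: BaseException) -> bool:
    """Hypothetical stand-in for the substring check described in item 3.

    Assumption: both 'out of memory' and 'CUDA' must appear in str(exc).
    The real OOMProtectedFn check may be different.
    """
    message = str(exc)
    return "out of memory" in message and "CUDA" in message


class FailFirstN:
    """Raises an OOM-like MemoryError on the first `n` calls, then succeeds."""

    def __init__(self, n: int) -> None:
        self.remaining_failures = n
        self.calls = 0

    def __call__(self, batch: Sequence[int]) -> List[int]:
        self.calls += 1
        if self.remaining_failures > 0:
            self.remaining_failures -= 1
            raise MemoryError(OOM_MESSAGE)
        # Mirrors the expected outputs in the original test (x -> x + 1).
        return [x + 1 for x in batch]
```

Inside a real pipeline, state held on a handler instance may not survive pickling or be shared across workers, so this shows the counting idea rather than a drop-in model handler.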



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.