This is an automated email from the ASF dual-hosted git repository.

vterentev pushed a commit to branch oss-image-cpu
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 1b0882c098516312a64f7f68ac8056900724337b
Author: Vitaly Terentyev <[email protected]>
AuthorDate: Thu Jan 22 13:06:07 2026 +0400

    Fix inference_fn
---
 .../inference/pytorch_image_object_detection.py    | 27 ++++++++++++++++------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py b/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py
index 0bb5f53e5b0..cb66e0f9dba 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py
@@ -31,12 +31,13 @@ import io
 import json
 import logging
 import time
-from typing import Iterable
-from typing import Optional
-from typing import Tuple
 from typing import Any
 from typing import Dict
+from typing import Iterable
 from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
 
 import apache_beam as beam
 from apache_beam.io.filesystems import FileSystems
@@ -135,12 +136,24 @@ class DecodePreprocessDoFn(beam.DoFn):
 
 
 def _torchvision_detection_inference_fn(
-    model, batch: List[torch.Tensor], device: str) -> List[Dict[str, Any]]:
-  """Custom inference for TorchVision detection models.
-
-  TorchVision detection models expect: List[Tensor] (each: CHW float [0..1]).
+    batch: Sequence[torch.Tensor],
+    model: torch.nn.Module,
+    device: torch.device,
+    inference_args: Optional[dict[str, Any]] = None,
+    model_id: Optional[str] = None,
+) -> List[Dict[str, Any]]:
+  """Inference function for TorchVision detection models.
+
+  TorchVision detection models expect List[Tensor] where each tensor is:
+    - shape: [3, H, W]
+    - dtype: float32
+    - values: [0..1]
   """
+  del inference_args
+  del model_id
+
   with torch.no_grad():
+    # Ensure tensors are on device
     inputs = []
     for t in batch:
       if isinstance(t, torch.Tensor):

Reply via email to