This is an automated email from the ASF dual-hosted git repository. vterentev pushed a commit to branch oss-image-cpu in repository https://gitbox.apache.org/repos/asf/beam.git
commit 1b0882c098516312a64f7f68ac8056900724337b Author: Vitaly Terentyev <[email protected]> AuthorDate: Thu Jan 22 13:06:07 2026 +0400 Fix inference_fn --- .../inference/pytorch_image_object_detection.py | 27 ++++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py b/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py index 0bb5f53e5b0..cb66e0f9dba 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py @@ -31,12 +31,13 @@ import io import json import logging import time -from typing import Iterable -from typing import Optional -from typing import Tuple from typing import Any from typing import Dict +from typing import Iterable from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple import apache_beam as beam from apache_beam.io.filesystems import FileSystems @@ -135,12 +136,24 @@ class DecodePreprocessDoFn(beam.DoFn): def _torchvision_detection_inference_fn( - model, batch: List[torch.Tensor], device: str) -> List[Dict[str, Any]]: - """Custom inference for TorchVision detection models. - - TorchVision detection models expect: List[Tensor] (each: CHW float [0..1]). + batch: Sequence[torch.Tensor], + model: torch.nn.Module, + device: torch.device, + inference_args: Optional[dict[str, Any]] = None, + model_id: Optional[str] = None, +) -> List[Dict[str, Any]]: + """Inference function for TorchVision detection models. + + TorchVision detection models expect List[Tensor] where each tensor is: + - shape: [3, H, W] + - dtype: float32 + - values: [0..1] """ + del inference_args + del model_id + with torch.no_grad(): + # Ensure tensors are on device inputs = [] for t in batch: if isinstance(t, torch.Tensor):
