This is an automated email from the ASF dual-hosted git repository. vterentev pushed a commit to branch oss-image-detection in repository https://gitbox.apache.org/repos/asf/beam.git
commit 6a09efb8c8fd9836bd5ce360653edf0e67ea4461 Author: Vitaly Terentyev <[email protected]> AuthorDate: Thu Dec 25 12:13:08 2025 +0400 ML pipelines: RunInference - OSS Image Object detection --- .../beam_Inference_Python_Benchmarks_Dataflow.yml | 14 +- ...Benchmarks_Dataflow_Pytorch_Image_Detection.txt | 41 +++ .test-infra/tools/refresh_looker_metrics.py | 1 + .../inference/pytorch_image_object_detection.py | 410 +++++++++++++++++++++ .../pytorch_object_detection_requirements.txt | 24 ++ .../pytorch_object_detection_benchmarks.py | 42 +++ website/www/site/content/en/performance/_index.md | 1 + .../pytorchimageobjectdetection/_index.md | 42 +++ website/www/site/data/performance.yaml | 16 + 9 files changed, 590 insertions(+), 1 deletion(-) diff --git a/.github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml b/.github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml index ff7480c320a..e1da7670221 100644 --- a/.github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml +++ b/.github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml @@ -92,6 +92,7 @@ jobs: ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Sentiment_Streaming_DistilBert_Base_Uncased.txt ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Sentiment_Batch_DistilBert_Base_Uncased.txt ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_VLLM_Gemma_Batch.txt + ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Detection.txt # The env variables are created and populated in the test-arguments-action as "<github.job>_test_arguments_<argument_file_paths_index>" - name: get current time run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV @@ -189,4 +190,15 @@ jobs: -Prunner=DataflowRunner \ -PpythonVersion=3.10 
\ -PloadTest.requirementsTxtFile=apache_beam/ml/inference/torch_tests_requirements.txt \ - '-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_5 }} --job_name=benchmark-tests-pytorch-imagenet-python-gpu-${{env.NOW_UTC}} --output=gs://temp-storage-for-end-to-end-tests/torch/result_resnet152_gpu-${{env.NOW_UTC}}.txt' \ No newline at end of file + '-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_5 }} --job_name=benchmark-tests-pytorch-imagenet-python-gpu-${{env.NOW_UTC}} --output=gs://temp-storage-for-end-to-end-tests/torch/result_resnet152_gpu-${{env.NOW_UTC}}.txt' + - name: run PyTorch Image Object Detection Faster R-CNN ResNet-50 Batch + uses: ./.github/actions/gradle-command-self-hosted-action + timeout-minutes: 180 + with: + gradle-command: :sdks:python:apache_beam:testing:load_tests:run + arguments: | + -PloadTest.mainClass=apache_beam.testing.benchmarks.inference.pytorch_object_detection_benchmarks \ + -Prunner=DataflowRunner \ + -PpythonVersion=3.10 \ + -PloadTest.requirementsTxtFile=apache_beam/ml/inference/pytorch_object_detection_requirements.txt \ + '-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_9 }} --mode=batch --job_name=benchmark-tests-pytorch-object_detection-batch-${{env.NOW_UTC}} --output_table=apache-beam-testing.beam_run_inference.result_torch_inference_object_detection_batch' \ diff --git a/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Detection.txt b/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Detection.txt new file mode 100644 index 00000000000..2cb905f1ed5 --- /dev/null +++ b/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Detection.txt @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--region=us-central1 +--worker_machine_type=n1-standard-4 +--num_workers=50 +--disk_size_gb=50 +--autoscaling_algorithm=NONE +--staging_location=gs://temp-storage-for-perf-tests/loadtests +--temp_location=gs://temp-storage-for-perf-tests/loadtests +--requirements_file=apache_beam/ml/inference/pytorch_object_detection_requirements.txt +--publish_to_big_query=true +--metrics_dataset=beam_run_inference +--metrics_table=result_torch_inference_object_detection_batch +--input_options={} +--influx_measurement=result_torch_inference_object_detection_batch +--pretrained_model_name=fasterrcnn_resnet50_fpn +--device=GPU +--mode=batch +--inference_batch_size=8 +--resize_shorter_side=800 +--score_threshold=0.5 +--max_detections=50 +--input=gs://apache-beam-ml/testing/inputs/openimage_50k_benchmark.txt +--model_state_dict_path=gs://apache-beam-ml/models/torchvision.detection.fasterrcnn_resnet50_fpn.pth +--runner=DataflowRunner +--experiments=use_runner_v2 +--worker_accelerator=type=nvidia-tesla-t4,count=1,install-nvidia-driver=true diff --git a/.test-infra/tools/refresh_looker_metrics.py b/.test-infra/tools/refresh_looker_metrics.py index a4c6999be77..9b8296c56d3 100644 --- a/.test-infra/tools/refresh_looker_metrics.py +++ b/.test-infra/tools/refresh_looker_metrics.py @@ -43,6 +43,7 @@ 
LOOKS_TO_DOWNLOAD = [ ("82", ["263", "264", "265", "266", "267"]), # PyTorch Sentiment Streaming DistilBERT base uncased ("85", ["268", "269", "270", "271", "272"]), # PyTorch Sentiment Batch DistilBERT base uncased ("86", ["284", "285", "286", "287", "288"]), # VLLM Batch Gemma + #TODO: PyTorch Image Object Detection Faster R-CNN ResNet-50 Batch ] diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py b/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py new file mode 100644 index 00000000000..e70c4b157f4 --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py @@ -0,0 +1,410 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This batch pipeline performs object detection using an open-source PyTorch +TorchVision detection model (e.g., Faster R-CNN ResNet50 FPN) on GPU. + +It reads image URIs from a GCS input file, decodes and preprocesses images, +runs batched GPU inference via RunInference, post-processes detection outputs, +and writes results to BigQuery. + +The pipeline targets stable and reproducible performance measurements for +GPU inference workloads (no right-fitting; fixed batch size). 
+""" + +import argparse +import io +import json +import logging +import time +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import Any +from typing import Dict +from typing import List + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult + +import torch +import PIL.Image as PILImage + + +# ============ Utility & Preprocessing ============ + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def decode_to_tensor( + image_bytes: bytes, + resize_shorter_side: Optional[int] = None) -> torch.Tensor: + """Decode bytes -> RGB PIL -> optional resize -> float tensor [0..1], CHW. + + Note: TorchVision detection models apply their own normalization internally. + """ + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + + if resize_shorter_side and resize_shorter_side > 0: + w, h = img.size + # Resize so that shorter side == resize_shorter_side, keep aspect ratio. 
+    if w < h:
+      new_w = resize_shorter_side
+      new_h = int(h * (resize_shorter_side / float(w)))
+    else:
+      new_h = resize_shorter_side
+      new_w = int(w * (resize_shorter_side / float(h)))
+    img = img.resize((new_w, new_h))
+
+  import numpy as np
+  arr = np.asarray(img).astype("float32") / 255.0  # H,W,3 in [0..1]
+  arr = np.transpose(arr, (2, 0, 1))  # CHW
+  return torch.from_numpy(arr)
+
+
+def sha1_hex(s: str) -> str:
+  import hashlib
+  return hashlib.sha1(s.encode("utf-8")).hexdigest()
+
+
+# ============ DoFns ============
+
+class MakeKeyDoFn(beam.DoFn):
+  """Produce (image_id, uri) where image_id is stable for dedup and keys."""
+  def process(self, element: str):
+    uri = element
+    image_id = sha1_hex(uri)
+    yield image_id, uri
+
+
+class DecodePreprocessDoFn(beam.DoFn):
+  """Turn (image_id, uri) -> (image_id, tensor)."""
+  def __init__(self, resize_shorter_side: Optional[int] = None):
+    self.resize_shorter_side = resize_shorter_side
+
+  def process(self, kv: Tuple[str, str]):
+    image_id, uri = kv
+    start = now_millis()
+    try:
+      b = load_image_from_uri(uri)
+      tensor = decode_to_tensor(b, resize_shorter_side=self.resize_shorter_side)
+      preprocess_ms = now_millis() - start
+      yield image_id, {"tensor": tensor, "preprocess_ms": preprocess_ms, "uri": uri}
+    except Exception as e:
+      logging.warning("Decode failed for %s (%s): %s", image_id, uri, e)
+      return
+
+
+def _torchvision_detection_inference_fn(
+    batch: List[torch.Tensor],
+    model: torch.nn.Module,
+    device: str,
+    inference_args: Optional[Dict[str, Any]] = None,
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+  """Custom inference for TorchVision detection models.
+
+  TorchVision detection models expect a List[Tensor] (each tensor: CHW float
+  [0..1]), so the batch is passed as a list instead of being stacked. The
+  signature matches Beam's TensorInferenceFn contract
+  (batch, model, device, inference_args, model_id) and returns
+  PredictionResults, as RunInference requires.
+  """
+  with torch.no_grad():
+    inputs = []
+    for t in batch:
+      if isinstance(t, torch.Tensor):
+        inputs.append(t.to(device))
+      else:
+        # Defensive: if somehow a non-tensor slips through.
+        inputs.append(torch.as_tensor(t).to(device))
+    outputs = model(inputs)  # List[Dict[str, Tensor]]
+  return [PredictionResult(x, y) for x, y in zip(batch, outputs)]
+
+
+class PostProcessDoFn(beam.DoFn):
+  """PredictionResult -> dict row for BQ."""
+  def __init__(
+      self,
+      model_name: str,
+      score_threshold: float,
+      max_detections: int):
+    self.model_name = model_name
+    self.score_threshold = score_threshold
+    self.max_detections = max_detections
+
+  def _extract_detection(self, inference_obj: Any) -> Dict[str, Any]:
+    """Extract detection fields from torchvision output dict."""
+    # Expect: {'boxes': Tensor[N,4], 'labels': Tensor[N], 'scores': Tensor[N]}
+    boxes = inference_obj.get("boxes")
+    labels = inference_obj.get("labels")
+    scores = inference_obj.get("scores")
+
+    # Convert to CPU lists
+    if isinstance(scores, torch.Tensor):
+      scores_list = scores.detach().cpu().tolist()
+    else:
+      scores_list = list(scores) if scores is not None else []
+
+    if isinstance(labels, torch.Tensor):
+      labels_list = labels.detach().cpu().tolist()
+    else:
+      labels_list = list(labels) if labels is not None else []
+
+    if isinstance(boxes, torch.Tensor):
+      boxes_list = boxes.detach().cpu().tolist()
+    else:
+      boxes_list = list(boxes) if boxes is not None else []
+
+    # Filter by threshold and trim to max_detections
+    dets = []
+    for i in range(min(len(scores_list), len(labels_list), len(boxes_list))):
+      score = float(scores_list[i])
+      if score < self.score_threshold:
+        continue
+      box = boxes_list[i]  # [x1,y1,x2,y2]
+      dets.append({
+          "label_id": int(labels_list[i]),
+          "score": score,
+          "box": [float(box[0]), float(box[1]), float(box[2]), float(box[3])],
+      })
+      if len(dets) >= self.max_detections:
+        break
+
+    return {
+        "detections": dets,
+        "num_detections": len(dets),
+    }
+
+  def process(self, kv: Tuple[str, PredictionResult]):
+    image_id, pred = kv
+
+    # pred.inference should be the torchvision dict for this element,
+    # but keep a robust fallback.
+    inference_obj = pred.inference
+    if isinstance(inference_obj, list) and len(inference_obj) == 1:
+      inference_obj = inference_obj[0]
+
+    if not isinstance(inference_obj, dict):
+      logging.warning(
+          "Unexpected inference type for %s: %s", image_id, type(inference_obj)
+      )
+      yield {
+          "image_id": image_id,
+          "model_name": self.model_name,
+          "detections": json.dumps([]),
+          "num_detections": 0,
+          "infer_ms": now_millis(),
+      }
+      return
+
+    extracted = self._extract_detection(inference_obj)
+
+    yield {
+        "image_id": image_id,
+        "model_name": self.model_name,
+        "detections": json.dumps(extracted["detections"]),
+        "num_detections": int(extracted["num_detections"]),
+        "infer_ms": now_millis(),
+    }
+
+
+# ============ Args & Helpers ============
+
+def parse_known_args(argv):
+  parser = argparse.ArgumentParser()
+
+  # I/O & runtime
+  parser.add_argument('--mode', default='batch', choices=['batch'])
+  parser.add_argument(
+      '--output_table',
+      required=True,
+      help='BigQuery output table: dataset.table')
+  parser.add_argument(
+      '--publish_to_big_query', default='true', choices=['true', 'false'])
+  parser.add_argument(
+      '--input',
+      required=True,
+      help='GCS path to file with image URIs')
+
+  # Model & inference
+  parser.add_argument(
+      '--pretrained_model_name',
+      default='fasterrcnn_resnet50_fpn',
+      help=('TorchVision detection model name '
+            '(e.g., fasterrcnn_resnet50_fpn)'))
+  parser.add_argument(
+      '--model_state_dict_path',
+      required=True,
+      help='GCS path to a state_dict .pth for the chosen model')
+  parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU'])
+
+  # Batch sizing (no right-fitting)
+  parser.add_argument('--inference_batch_size', type=int, default=8)
+
+  # Preprocess
+  parser.add_argument('--resize_shorter_side', type=int, default=0)
+
+  # Postprocess
+  parser.add_argument('--score_threshold', type=float, default=0.5)
+  parser.add_argument('--max_detections', type=int, default=50)
+
+  known_args, pipeline_args = parser.parse_known_args(argv)
+  return known_args, pipeline_args
+
+
+def override_or_add(args, flag, value):
+  if flag in args:
+    idx = args.index(flag)
+    args[idx + 1] = str(value)
+  else:
+    args.extend([flag, str(value)])
+
+
+def create_torchvision_detection_model(model_name: str):
+  """Creates a TorchVision detection model instance.
+
+  Note: We load weights via state_dict_path (required by the Beam handler
+  when model_class is provided).
+  """
+  import torchvision
+
+  name = model_name.strip()
+
+  if name == "fasterrcnn_resnet50_fpn":
+    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
+  elif name == "retinanet_resnet50_fpn":
+    model = torchvision.models.detection.retinanet_resnet50_fpn(weights=None)
+  else:
+    raise ValueError(f"Unsupported detection model: {model_name}")
+
+  model.eval()
+  return model
+
+
+# ============ Main pipeline ============
+
+def run(
+    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
+  known_args, pipeline_args = parse_known_args(argv)
+
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+  pipeline_options.view_as(StandardOptions).streaming = False
+
+  # PytorchModelHandlerTensor expects 'GPU' or 'CPU', not torch device strings.
+  device = 'GPU' if known_args.device.upper() == 'GPU' else 'CPU'
+  resize_shorter_side = (
+      known_args.resize_shorter_side
+  ) if known_args.resize_shorter_side > 0 else None
+
+  # Fixed batch size (no right-fitting)
+  batch_size = int(known_args.inference_batch_size)
+
+  model_handler = PytorchModelHandlerTensor(
+      model_class=lambda: create_torchvision_detection_model(
+          known_args.pretrained_model_name
+      ),
+      model_params={},
+      state_dict_path=known_args.model_state_dict_path,
+      device=device,
+      min_batch_size=batch_size,
+      max_batch_size=batch_size,
+      inference_fn=_torchvision_detection_inference_fn,
+  )
+
+  pipeline = test_pipeline or beam.Pipeline(options=pipeline_options)
+
+  pcoll = (
+      pipeline
+      | 'ReadURIsBatch' >> beam.Create(
list(read_gcs_file_lines(known_args.input)) + ) + | 'FilterEmptyBatch' >> beam.Filter(lambda s: s.strip()) + ) + + keyed = ( + pcoll + | 'MakeKey' >> beam.ParDo(MakeKeyDoFn()) + ) + + # Batch exactly-once behavior: + # 1) Dedup by key within the run to ensure stable writes. + # 2) Use FILE_LOADS for BQ to avoid streaming insert duplicates in retries. + keyed = keyed | 'DistinctByKey' >> beam.Distinct() + + preprocessed = ( + keyed + | 'DecodePreprocess' >> beam.ParDo( + DecodePreprocessDoFn(resize_shorter_side=resize_shorter_side)) + ) + + to_infer = ( + preprocessed + | 'ToKeyedTensor' >> beam.Map(lambda kv: (kv[0], kv[1]["tensor"])) + ) + + predictions = ( + to_infer + | 'RunInference' >> RunInference(KeyedModelHandler(model_handler)) + ) + + results = ( + predictions + | 'PostProcess' >> beam.ParDo( + PostProcessDoFn( + model_name=known_args.pretrained_model_name, + score_threshold=known_args.score_threshold, + max_detections=known_args.max_detections)) + ) + + if known_args.publish_to_big_query == 'true': + _ = ( + results + | 'WriteToBigQuery' >> beam.io.WriteToBigQuery( + known_args.output_table, + schema=('image_id:STRING, model_name:STRING, ' + 'detections:STRING, num_detections:INT64, infer_ms:INT64'), + write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, + method=beam.io.WriteToBigQuery.Method.FILE_LOADS) + ) + + result = pipeline.run() + result.wait_until_finish(duration=1800000) # 30 min + try: + result.cancel() + except Exception: + pass + + return result + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/ml/inference/pytorch_object_detection_requirements.txt b/sdks/python/apache_beam/ml/inference/pytorch_object_detection_requirements.txt new file mode 100644 index 00000000000..c3ce392d1a4 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/pytorch_object_detection_requirements.txt @@ -0,0 +1,24 @@ +# +# 
Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +torch>=2.2.0,<2.8.0 +torchvision>=0.17.0,<0.21.0 +Pillow>=10.0.0 +numpy>=1.25.0 +google-cloud-monitoring>=2.27.0 +protobuf>=4.25.1 +requests>=2.31.0 diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_object_detection_benchmarks.py b/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_object_detection_benchmarks.py new file mode 100644 index 00000000000..b31112f68bf --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_object_detection_benchmarks.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pytype: skip-file

+import logging
+
+from apache_beam.examples.inference import pytorch_image_object_detection
+from apache_beam.testing.load_tests.dataflow_cost_benchmark import DataflowCostBenchmark
+
+
+class PytorchImageObjectDetectionBenchmarkTest(DataflowCostBenchmark):
+  def __init__(self):
+    self.metrics_namespace = 'BeamML_PyTorch'
+    super().__init__(
+        metrics_namespace=self.metrics_namespace,
+        pcollection='PostProcess.out0')
+
+  def test(self):
+    extra_opts = {}
+    # The pipeline-options file passes --input (not --input_file).
+    extra_opts['input'] = self.pipeline.get_option('input')
+    self.result = pytorch_image_object_detection.run(
+        self.pipeline.get_full_options_as_args(**extra_opts),
+        test_pipeline=self.pipeline)
+
+
+if __name__ == '__main__':
+  logging.basicConfig(level=logging.INFO)
+  PytorchImageObjectDetectionBenchmarkTest().run()
diff --git a/website/www/site/content/en/performance/_index.md b/website/www/site/content/en/performance/_index.md
index 17bdc6f3de0..a96e2ee03bd 100644
--- a/website/www/site/content/en/performance/_index.md
+++ b/website/www/site/content/en/performance/_index.md
@@ -57,3 +57,4 @@ See the following pages for performance measures recorded when running various B
 - [PyTorch Vision Classification Resnet 152 Tesla T4 GPU](/performance/pytorchresnet152tesla)
 - [TensorFlow MNIST Image Classification](/performance/tensorflowmnist)
 - [VLLM Gemma Batch Completion Tesla T4 GPU](/performance/vllmgemmabatchtesla)
+- [PyTorch Image Object Detection Faster R-CNN ResNet-50 Batch](/performance/pytorchimageobjectdetection)
diff --git
a/website/www/site/content/en/performance/pytorchimageobjectdetection/_index.md b/website/www/site/content/en/performance/pytorchimageobjectdetection/_index.md
new file mode 100644
index 00000000000..04e4d76f439
--- /dev/null
+++ b/website/www/site/content/en/performance/pytorchimageobjectdetection/_index.md
@@ -0,0 +1,42 @@
+---
+title: "PyTorch Image Object Detection Faster R-CNN ResNet-50 Batch Performance"
+---
+
+<!--
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
-->
+
+# PyTorch Image Object Detection Faster R-CNN ResNet-50 Batch
+
+**Model**: PyTorch Image Object Detection — Faster R-CNN ResNet-50 FPN (pretrained on COCO)
+**Accelerator**: Tesla T4 GPU (fixed batch size)
+**Host**: 50 × n1-standard-4 (4 vCPUs, 15 GB RAM)
+
+This batch pipeline performs object detection using an open-source PyTorch Faster R-CNN ResNet-50 FPN model on GPU.
+It reads image URIs from GCS, decodes and preprocesses images, and runs batched inference with a fixed batch size to measure stable GPU performance.
+The pipeline deduplicates inputs within a run and writes results to BigQuery using file-based loads, which avoids duplicate rows on retries and keeps performance measurements reproducible and comparable across runs.
+
+The following graphs show various metrics when running the PyTorch Image Object Detection Faster R-CNN ResNet-50 Batch pipeline.
+See the [glossary](/performance/glossary) for definitions.
+
+The full pipeline implementation is available [here](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py).
+
+## What is the estimated cost to run the pipeline?
+
+{{< performance_looks io="pytorchimageobjectdetection" read_or_write="write" section="cost" >}}
+
+## How have various metrics changed when running the pipeline for different Beam SDK versions?
+
+{{< performance_looks io="pytorchimageobjectdetection" read_or_write="write" section="version" >}}
+
+## How have various metrics changed over time when running the pipeline?
+
+{{< performance_looks io="pytorchimageobjectdetection" read_or_write="write" section="date" >}}
diff --git a/website/www/site/data/performance.yaml b/website/www/site/data/performance.yaml
index 17a6612160c..e669ca3b63c 100644
--- a/website/www/site/data/performance.yaml
+++ b/website/www/site/data/performance.yaml
@@ -250,3 +250,19 @@ looks:
         title: AvgThroughputBytesPerSec by Version
       - id: dKyJy5ZKhkBdSTXRY3wZR6fXzptSs2qm
         title: AvgThroughputElementsPerSec by Version
+  pytorchimageobjectdetection:
+    write:
+      folder: #TODO
+      cost:
+        - id: #TODO
+          title: RunTime and EstimatedCost
+      date:
+        - id: #TODO
+          title: AvgThroughputBytesPerSec by Date
+        - id: #TODO
+          title: AvgThroughputElementsPerSec by Date
+      version:
+        - id: #TODO
+          title: AvgThroughputBytesPerSec by Version
+        - id: #TODO
+          title: AvgThroughputElementsPerSec by Version
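A note for reviewers: the threshold-then-truncate logic in `PostProcessDoFn._extract_detection` can be sanity-checked in isolation, without torch or a Dataflow runner. The sketch below mirrors that logic on plain Python lists (the `filter_detections` helper name is illustrative, not part of this commit):

```python
# Mirrors PostProcessDoFn._extract_detection: keep detections whose score
# meets the threshold, and stop once max_detections have been collected.
def filter_detections(boxes, labels, scores, score_threshold, max_detections):
    dets = []
    for box, label, score in zip(boxes, labels, scores):
        if float(score) < score_threshold:
            continue
        dets.append({
            "label_id": int(label),
            "score": float(score),
            "box": [float(v) for v in box],
        })
        if len(dets) >= max_detections:
            break
    return dets

# Three candidate boxes; the middle one falls below the 0.5 threshold.
dets = filter_detections(
    boxes=[[0, 0, 10, 10], [5, 5, 20, 20], [1, 1, 2, 2]],
    labels=[1, 2, 3],
    scores=[0.9, 0.4, 0.8],
    score_threshold=0.5,
    max_detections=50,
)
# dets now holds the label-1 and label-3 detections, in input order.
```

TorchVision detection models return per-image results already sorted by descending score, so truncating after `max_detections` keeps the highest-scoring boxes; a model that did not sort its output would need an explicit sort before the cap.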
