gemini-code-assist[bot] commented on code in PR #37186: URL: https://github.com/apache/beam/pull/37186#discussion_r3499840966
########## sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py: ########## @@ -0,0 +1,536 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This pipeline performs image classification using an open-source +PyTorch EfficientNet-B0 model optimized for T4 GPUs. +It reads image URIs from Pub/Sub, decodes and preprocesses them in parallel, +and runs inference with adaptive batch sizing for optimal GPU utilization. +The pipeline targets stable and reproducible performance measurements under +continuous load. +Resources like Pub/Sub topic/subscription cleanup is handled programmatically. 
+""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Optional +from typing import Tuple + +import torch +import torch.nn.functional as F + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def decode_and_preprocess(image_bytes: bytes, size: int = 224) -> torch.Tensor: + """Decode bytes->RGB PIL->resize shorter side->center crop->normalize.""" + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + + resize_size = 256 + w, h = img.size + if w < h: + new_w = resize_size + new_h = int(h * resize_size / w) + else: + new_h = resize_size + new_w = int(w * resize_size / h) + + img = img.resize((new_w, new_h)) + + w, h = img.size + left = (w - size) // 2 + top = (h - size) // 2 + img = img.crop((left, top, left + size, top + size)) + + import numpy as np + mean = np.array(IMAGENET_MEAN, dtype=np.float32) + std = np.array(IMAGENET_STD, dtype=np.float32) + + arr = np.asarray(img).astype("float32") / 
255.0 + arr = (arr - mean) / std + arr = np.transpose(arr, (2, 0, 1)).astype("float32") + return torch.from_numpy(arr).float() + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, payload) stable for dedup & BQ insertId.""" + def __init__(self, input_mode: str): + self.input_mode = input_mode + + def process(self, element: str | bytes): + # Input can be raw bytes from Pub/Sub or a GCS URI string, depends on mode + if self.input_mode == "bytes": + # element is bytes message, assume it includes + # {"image_id": "...", "bytes": base64?} or just raw bytes. + import hashlib + b = element if isinstance(element, + (bytes, + bytearray)) else element.encode('utf-8') + image_id = hashlib.sha1(b).hexdigest() + yield image_id, b + else: + # gcs_uris: element is uri string; image_id = sha1(uri) + import hashlib + uri = element.decode("utf-8") if isinstance( + element, (bytes, bytearray)) else str(element) + image_id = hashlib.sha1(uri.encode("utf-8")).hexdigest() + yield image_id, uri + + +class DecodePreprocessDoFn(beam.DoFn): + """Turn (image_id, bytes|uri) -> (image_id, torch.Tensor)""" + def __init__(self, input_mode: str, image_size: int = 224): + self.input_mode = input_mode + self.image_size = image_size + + def process(self, kv: Tuple[str, object]): + image_id, payload = kv + start = now_millis() + + try: + if self.input_mode == "bytes": + b = payload if isinstance(payload, + (bytes, bytearray)) else bytes(payload) + else: + uri = payload if isinstance(payload, str) else payload.decode("utf-8") + b = load_image_from_uri(uri) + + tensor = decode_and_preprocess(b, self.image_size) + preprocess_ms = now_millis() - start + yield image_id, {"tensor": tensor, "preprocess_ms": preprocess_ms} + except Exception as e: + logging.warning("Decode failed for %s: %s", image_id, e) + return + + +class PostProcessDoFn(beam.DoFn): + """PredictionResult -> dict row for BQ.""" + def __init__(self, top_k: int, model_name: str): + self.top_k = top_k + self.model_name = model_name + + 
def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + + # pred can be PredictionResult OR raw inference object. + inference_obj = pred.inference if hasattr(pred, "inference") else pred + + # inference_obj can be dict {'logits': tensor} OR tensor directly. + if isinstance(inference_obj, dict): + logits = inference_obj.get("logits", None) + if logits is None: + raise ValueError( + f"Unable to find 'logits' in model output. " + f"Available keys: {list(inference_obj.keys())}") + else: + logits = inference_obj + + if not isinstance(logits, torch.Tensor): + logging.warning( + "Unexpected logits type for %s: %s", image_id, type(logits)) + return + + # Ensure shape [1, C] + if logits.ndim == 1: + logits = logits.unsqueeze(0) + + probs = F.softmax(logits, dim=-1) # [B, C] + values, indices = torch.topk( + probs, k=min(self.top_k, probs.shape[-1]), dim=-1 + ) + + topk = [{ + "class_id": int(idx.item()), "score": float(val.item()) + } for idx, val in zip(indices[0], values[0])] + + yield { + "image_id": image_id, + "model_name": self.model_name, + "topk": json.dumps(topk), + "infer_ms": now_millis(), + } + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + # I/O & runtime + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + parser.add_argument( + '--input_mode', default='gcs_uris', choices=['gcs_uris', 'bytes']) + parser.add_argument( + '--input', + required=True, + help='GCS path to file with URIs (for load) OR unused for bytes') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + 
'--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--feeder_start_delay_sec', + type=int, + default=900, + help=( + 'Delay before starting the feeder pipeline that reads URIs from GCS ' + 'and publishes them to Pub/Sub. This delay allows the main streaming ' + 'pipeline workers to start and scale before data ingestion begins.'), + ) + + # Model & inference + parser.add_argument( + '--pretrained_model_name', + default='efficientnet_b0', + help='OSS model name (e.g., efficientnet_b0|mobilenetv3_large_100)') + parser.add_argument( + '--model_state_dict_path', + default=None, + help='Optional state_dict to load') + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + parser.add_argument('--image_size', type=int, default=224) + parser.add_argument('--top_k', type=int, default=5) + parser.add_argument( + '--inference_batch_size', + default='auto', + help='int or "auto"; auto tries 64→32→16') + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except NotFound: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except NotFound: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + 
+def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + logging.info(f"Deleted subscription: {subscription_name}") + except NotFound: + logging.info(f"Subscription already deleted: {subscription_name}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + logging.info(f"Deleted topic: {topic_name}") + except NotFound: + logging.info(f"Topic already deleted: {topic_name}") + + +def override_or_add(args, flag, value): + if flag in args: + idx = args.index(flag) + args[idx + 1] = str(value) + else: + args.extend([flag, str(value)]) + + +# ============ Model factory (timm) ============ + + +def create_timm_m(model_name: str, num_classes: int = 1000): + import timm + model = timm.create_model( + model_name, pretrained=True, num_classes=num_classes) + model.eval() + return model + + +def pick_batch_size(arg: str) -> Optional[int]: + if isinstance(arg, str) and arg.lower() == 'auto': + return None + try: + return int(arg) + except Exception: + return None + + +class RightFittingPytorchModelHandlerTensor(PytorchModelHandlerTensor): + def __init__(self, batch_sizes_to_try, image_size, *args, **kwargs): + self._batch_sizes_to_try = batch_sizes_to_try + self._rightfit_image_size = image_size + super().__init__(*args, **kwargs) + + def load_model(self): + model = super().load_model() + last_err = None + + for bs in self._batch_sizes_to_try: + try: + model_device = next(model.parameters()).device + dummy = torch.zeros( + (bs, 3, self._rightfit_image_size, self._rightfit_image_size), + 
dtype=torch.float32, + device=model_device) + + with torch.no_grad(): + model(dummy) + + self._batch_size = bs + self._inference_batch_size = bs + logging.info("Selected inference batch size: %s", bs) + return model + except RuntimeError as e: + last_err = e + logging.warning("Batch size %s failed during worker warmup: %s", bs, e) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + raise RuntimeError( + f"No valid inference batch size found from {self._batch_sizes_to_try}" + ) from last_err Review Comment:  The `BatchElements` transform is applied on the client side during pipeline construction using the initial maximum batch size (e.g., 64). Therefore, even if `load_model` dynamically selects a smaller batch size (e.g., 16) on the worker, the runner will still send batches of up to 64 elements to `run_inference`, which will cause GPU OOM. To fix this, `RightFittingPytorchModelHandlerTensor` must override `run_inference` to sub-batch the incoming batch to the dynamically selected `self._inference_batch_size`. 
```python class RightFittingPytorchModelHandlerTensor(PytorchModelHandlerTensor): def __init__(self, batch_sizes_to_try, image_size, *args, **kwargs): self._batch_sizes_to_try = batch_sizes_to_try self._rightfit_image_size = image_size super().__init__(*args, **kwargs) def load_model(self): model = super().load_model() last_err = None for bs in self._batch_sizes_to_try: try: model_device = next(model.parameters()).device dummy = torch.zeros( (bs, 3, self._rightfit_image_size, self._rightfit_image_size), dtype=torch.float32, device=model_device) with torch.no_grad(): model(dummy) self._batch_size = bs self._inference_batch_size = bs logging.info("Selected inference batch size: %s", bs) return model except RuntimeError as e: last_err = e logging.warning("Batch size %s failed during worker warmup: %s", bs, e) if torch.cuda.is_available(): torch.cuda.empty_cache() raise RuntimeError( f"No valid inference batch size found from {self._batch_sizes_to_try}" ) from last_err def run_inference(self, batch, model, inference_args=None): size = self._inference_batch_size sub_batches = [batch[i:i + size] for i in range(0, len(batch), size)] results = [] for sub_batch in sub_batches: results.extend(super().run_inference(sub_batch, model, inference_args)) return results ``` ########## sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py: ########## @@ -0,0 +1,533 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs object detection using an open-source PyTorch +TorchVision detection model (e.g., Faster R-CNN ResNet50 FPN) on GPU. + +It reads image URIs from a GCS input file, decodes and preprocesses images, +runs batched GPU inference via RunInference, post-processes detection outputs, +and writes results to BigQuery. + +The pipeline targets stable and reproducible performance measurements for +GPU inference workloads (no right-fitting; fixed batch size). +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + + +def now_millis() -> int: + 
return int(time.time() * 1000) + + +def decode_to_tens(image_bytes: bytes, image_size: int = 800) -> torch.Tensor: + """Decode bytes -> RGB PIL -> resize/pad square -> float tensor [0..1], CHW. + + TorchVision detection models accept float tensors in [0..1]. We force a fixed + square shape so PytorchModelHandlerTensor can batch tensors with torch.stack. + """ + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + + w, h = img.size + scale = min(image_size / float(w), image_size / float(h)) + new_w = max(1, int(round(w * scale))) + new_h = max(1, int(round(h * scale))) + + img = img.resize((new_w, new_h)) + + padded = PILImage.new("RGB", (image_size, image_size), color=(0, 0, 0)) + left = (image_size - new_w) // 2 + top = (image_size - new_h) // 2 + padded.paste(img, (left, top)) + + import numpy as np + arr = np.asarray(padded).astype("float32") / 255.0 + arr = np.transpose(arr, (2, 0, 1)) + return torch.from_numpy(arr).float() + + +# ============ DoFns ============ + + +class MakeKeyDoFn(beam.DoFn): + """Produce (uri, uri) where the URI is used as the stable key.""" + def process(self, element: str): + uri = element + yield uri, uri + + +class DecodePreprocessDoFn(beam.DoFn): + """Turn (uri, uri) -> (uri, tensor).""" + def __init__(self, image_size: int = 800): + self.image_size = image_size + + def process(self, kv: Tuple[str, str]): + uri, _ = kv + start = now_millis() + try: + with FileSystems.open(uri) as f: + image_bytes = f.read() + tensor = decode_to_tens(image_bytes, image_size=self.image_size) + preprocess_ms = now_millis() - start + yield uri, {"tensor": tensor, "preprocess_ms": preprocess_ms} + except (OSError, ValueError): + logging.exception("Decode failed for %s", uri) + return + + +def _torchvision_detection_inference_fn( + batch: Sequence[torch.Tensor], + model: torch.nn.Module, + device: torch.device, + inference_args: Optional[dict[str, Any]] = None, + model_id: Optional[str] = None, +) -> List[Dict[str, Any]]: + 
"""Inference function for TorchVision detection models. + + TorchVision detection models expect List[Tensor] where each tensor is: + - shape: [3, H, W] + - dtype: float32 + - values: [0..1] + """ + del inference_args + del model_id + + with torch.no_grad(): + # Ensure tensors are on device + inputs = [] + for t in batch: + if isinstance(t, torch.Tensor): + inputs.append(t.to(device)) + else: + # Defensive: if somehow non-tensor slips through. + inputs.append(torch.as_tensor(t).to(device)) + outputs = model(inputs) # List[Dict[str, Tensor]] + return outputs + + +class PostProcessDoFn(beam.DoFn): + """PredictionResult -> dict row for BQ.""" + def __init__( + self, model_name: str, score_threshold: float, max_detections: int): + self.model_name = model_name + self.score_threshold = score_threshold + self.max_detections = max_detections + + def _extract_detection(self, inference_obj: Any) -> Dict[str, Any]: + """Extract detection fields from torchvision output dict.""" + # Expect: {'boxes': Tensor[N,4], 'labels': Tensor[N], 'scores': Tensor[N]} + boxes = inference_obj.get("boxes") + labels = inference_obj.get("labels") + scores = inference_obj.get("scores") + + # Convert to CPU lists + if isinstance(scores, torch.Tensor): + scores_list = scores.detach().cpu().tolist() + else: + scores_list = list(scores) if scores is not None else [] + + if isinstance(labels, torch.Tensor): + labels_list = labels.detach().cpu().tolist() + else: + labels_list = list(labels) if labels is not None else [] + + if isinstance(boxes, torch.Tensor): + boxes_list = boxes.detach().cpu().tolist() + else: + boxes_list = list(boxes) if boxes is not None else [] + + # Filter by threshold and trim to max_detections + dets = [] + for i in range(min(len(scores_list), len(labels_list), len(boxes_list))): + score = float(scores_list[i]) + if score < self.score_threshold: + continue + box = boxes_list[i] # [x1,y1,x2,y2] + dets.append({ + "label_id": int(labels_list[i]), + "score": score, + "box": 
[float(box[0]), float(box[1]), float(box[2]), float(box[3])], + }) + if len(dets) >= self.max_detections: + break + + return { + "detections": dets, + "num_detections": len(dets), + } + + def process(self, kv: Tuple[str, PredictionResult]): + image_uri, pred = kv + + # pred can be PredictionResult OR raw torchvision dict. + if hasattr(pred, "inference"): + inference_obj = pred.inference + else: + inference_obj = pred + + if isinstance(inference_obj, list) and len(inference_obj) == 1: + inference_obj = inference_obj[0] + + if not isinstance(inference_obj, dict): + logging.warning( + "Unexpected inf-ce type for %s: %s", image_uri, type(inference_obj)) + yield { + "image_id": image_uri, + "model_name": self.model_name, + "detections": json.dumps([]), + "num_detections": 0, + "infer_ms": now_millis(), + } + return + + extracted = self._extract_detection(inference_obj) + + yield { + "image_id": image_uri, + "model_name": self.model_name, + "detections": json.dumps(extracted["detections"]), + "num_detections": int(extracted["num_detections"]), + "infer_ms": now_millis(), + } + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + + # I/O & runtime + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + parser.add_argument( + '--input', required=True, help='GCS path to file with image URIs') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + '--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--feeder_start_delay_sec', + type=int, + 
default=900, + help=( + 'Delay before starting the feeder pipeline that reads URIs from GCS ' + 'and publishes them to Pub/Sub. This delay allows the main streaming ' + 'pipeline workers to start and scale before data ingestion begins.'), + ) + + # Model & inference + parser.add_argument( + '--pretrained_model_name', + default='fasterrcnn_resnet50_fpn', + help=( + 'TorchVision detection model name ' + '(e.g., fasterrcnn_resnet50_fpn)')) + parser.add_argument( + '--model_state_dict_path', + required=True, + help='GCS path to a state_dict .pth for the chosen model') + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + + # Batch sizing (no right-fitting) + parser.add_argument('--inference_batch_size', type=int, default=8) + + # Preprocess + parser.add_argument('--image_size', type=int, default=800) + + # Postprocess + parser.add_argument('--score_threshold', type=float, default=0.5) + parser.add_argument('--max_detections', type=int, default=50) + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except NotFound: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except NotFound: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + +def 
cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + logging.info(f"Deleted subscription: {subscription_name}") + except NotFound: + logging.info(f"Subscription already deleted: {subscription_name}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + logging.info(f"Deleted topic: {topic_name}") + except NotFound: + logging.info(f"Topic already deleted: {topic_name}") + Review Comment:  Splitting the topic/subscription path and reconstructing it with `project` causes a mismatch if the user specified a fully qualified path in a different project. Since Beam's Pub/Sub IOs require fully qualified paths anyway, we should just use `topic_path` and `subscription_path` directly. 
```python def ensure_pubsub_resources( project: str, topic_path: str, subscription_path: str): publisher = pubsub_v1.PublisherClient() subscriber = pubsub_v1.SubscriberClient() try: publisher.get_topic(request={"topic": topic_path}) except NotFound: publisher.create_topic(name=topic_path) try: subscriber.get_subscription( request={"subscription": subscription_path}) except NotFound: subscriber.create_subscription( name=subscription_path, topic=topic_path) def cleanup_pubsub_resources( project: str, topic_path: str, subscription_path: str): publisher = pubsub_v1.PublisherClient() subscriber = pubsub_v1.SubscriberClient() try: subscriber.delete_subscription( request={"subscription": subscription_path}) logging.info(f"Deleted subscription: {subscription_path}") except NotFound: logging.info(f"Subscription already deleted: {subscription_path}") try: publisher.delete_topic(request={"topic": topic_path}) logging.info(f"Deleted topic: {topic_path}") except NotFound: logging.info(f"Topic already deleted: {topic_path}") ``` ########## sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py: ########## @@ -0,0 +1,536 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""This pipeline performs image classification using an open-source +PyTorch EfficientNet-B0 model optimized for T4 GPUs. +It reads image URIs from Pub/Sub, decodes and preprocesses them in parallel, +and runs inference with adaptive batch sizing for optimal GPU utilization. +The pipeline targets stable and reproducible performance measurements under +continuous load. +Resources like Pub/Sub topic/subscription cleanup is handled programmatically. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Optional +from typing import Tuple + +import torch +import torch.nn.functional as F + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def decode_and_preprocess(image_bytes: bytes, size: int = 224) -> torch.Tensor: + """Decode bytes->RGB PIL->resize shorter side->center crop->normalize.""" + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + + resize_size = 256 + w, h = img.size + if w < h: + 
new_w = resize_size + new_h = int(h * resize_size / w) + else: + new_h = resize_size + new_w = int(w * resize_size / h) + + img = img.resize((new_w, new_h)) + + w, h = img.size + left = (w - size) // 2 + top = (h - size) // 2 + img = img.crop((left, top, left + size, top + size)) + + import numpy as np + mean = np.array(IMAGENET_MEAN, dtype=np.float32) + std = np.array(IMAGENET_STD, dtype=np.float32) + + arr = np.asarray(img).astype("float32") / 255.0 + arr = (arr - mean) / std + arr = np.transpose(arr, (2, 0, 1)).astype("float32") + return torch.from_numpy(arr).float() + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, payload) stable for dedup & BQ insertId.""" + def __init__(self, input_mode: str): + self.input_mode = input_mode + + def process(self, element: str | bytes): + # Input can be raw bytes from Pub/Sub or a GCS URI string, depends on mode + if self.input_mode == "bytes": + # element is bytes message, assume it includes + # {"image_id": "...", "bytes": base64?} or just raw bytes. 
+ import hashlib + b = element if isinstance(element, + (bytes, + bytearray)) else element.encode('utf-8') + image_id = hashlib.sha1(b).hexdigest() + yield image_id, b + else: + # gcs_uris: element is uri string; image_id = sha1(uri) + import hashlib + uri = element.decode("utf-8") if isinstance( + element, (bytes, bytearray)) else str(element) + image_id = hashlib.sha1(uri.encode("utf-8")).hexdigest() + yield image_id, uri + + +class DecodePreprocessDoFn(beam.DoFn): + """Turn (image_id, bytes|uri) -> (image_id, torch.Tensor)""" + def __init__(self, input_mode: str, image_size: int = 224): + self.input_mode = input_mode + self.image_size = image_size + + def process(self, kv: Tuple[str, object]): + image_id, payload = kv + start = now_millis() + + try: + if self.input_mode == "bytes": + b = payload if isinstance(payload, + (bytes, bytearray)) else bytes(payload) + else: + uri = payload if isinstance(payload, str) else payload.decode("utf-8") + b = load_image_from_uri(uri) + + tensor = decode_and_preprocess(b, self.image_size) + preprocess_ms = now_millis() - start + yield image_id, {"tensor": tensor, "preprocess_ms": preprocess_ms} + except Exception as e: + logging.warning("Decode failed for %s: %s", image_id, e) + return + + +class PostProcessDoFn(beam.DoFn): + """PredictionResult -> dict row for BQ.""" + def __init__(self, top_k: int, model_name: str): + self.top_k = top_k + self.model_name = model_name + + def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + + # pred can be PredictionResult OR raw inference object. + inference_obj = pred.inference if hasattr(pred, "inference") else pred + + # inference_obj can be dict {'logits': tensor} OR tensor directly. + if isinstance(inference_obj, dict): + logits = inference_obj.get("logits", None) + if logits is None: + raise ValueError( + f"Unable to find 'logits' in model output. 
" + f"Available keys: {list(inference_obj.keys())}") + else: + logits = inference_obj + + if not isinstance(logits, torch.Tensor): + logging.warning( + "Unexpected logits type for %s: %s", image_id, type(logits)) + return + + # Ensure shape [1, C] + if logits.ndim == 1: + logits = logits.unsqueeze(0) + + probs = F.softmax(logits, dim=-1) # [B, C] + values, indices = torch.topk( + probs, k=min(self.top_k, probs.shape[-1]), dim=-1 + ) + + topk = [{ + "class_id": int(idx.item()), "score": float(val.item()) + } for idx, val in zip(indices[0], values[0])] + + yield { + "image_id": image_id, + "model_name": self.model_name, + "topk": json.dumps(topk), + "infer_ms": now_millis(), + } + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + # I/O & runtime + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + parser.add_argument( + '--input_mode', default='gcs_uris', choices=['gcs_uris', 'bytes']) + parser.add_argument( + '--input', + required=True, + help='GCS path to file with URIs (for load) OR unused for bytes') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + '--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--feeder_start_delay_sec', + type=int, + default=900, + help=( + 'Delay before starting the feeder pipeline that reads URIs from GCS ' + 'and publishes them to Pub/Sub. 
This delay allows the main streaming ' + 'pipeline workers to start and scale before data ingestion begins.'), + ) + + # Model & inference + parser.add_argument( + '--pretrained_model_name', + default='efficientnet_b0', + help='OSS model name (e.g., efficientnet_b0|mobilenetv3_large_100)') + parser.add_argument( + '--model_state_dict_path', + default=None, + help='Optional state_dict to load') + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + parser.add_argument('--image_size', type=int, default=224) + parser.add_argument('--top_k', type=int, default=5) + parser.add_argument( + '--inference_batch_size', + default='auto', + help='int or "auto"; auto tries 64→32→16') + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except NotFound: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except NotFound: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + +def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = 
publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + logging.info(f"Deleted subscription: {subscription_name}") + except NotFound: + logging.info(f"Subscription already deleted: {subscription_name}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + logging.info(f"Deleted topic: {topic_name}") + except NotFound: + logging.info(f"Topic already deleted: {topic_name}") + + +def override_or_add(args, flag, value): + if flag in args: + idx = args.index(flag) + args[idx + 1] = str(value) + else: + args.extend([flag, str(value)]) + + +# ============ Model factory (timm) ============ + + +def create_timm_m(model_name: str, num_classes: int = 1000): + import timm + model = timm.create_model( + model_name, pretrained=True, num_classes=num_classes) + model.eval() + return model + + +def pick_batch_size(arg: str) -> Optional[int]: + if isinstance(arg, str) and arg.lower() == 'auto': + return None + try: + return int(arg) + except Exception: + return None + + +class RightFittingPytorchModelHandlerTensor(PytorchModelHandlerTensor): + def __init__(self, batch_sizes_to_try, image_size, *args, **kwargs): + self._batch_sizes_to_try = batch_sizes_to_try + self._rightfit_image_size = image_size + super().__init__(*args, **kwargs) + + def load_model(self): + model = super().load_model() + last_err = None + + for bs in self._batch_sizes_to_try: + try: + model_device = next(model.parameters()).device + dummy = torch.zeros( + (bs, 3, self._rightfit_image_size, self._rightfit_image_size), + dtype=torch.float32, + device=model_device) + + with torch.no_grad(): + model(dummy) + + self._batch_size = bs + self._inference_batch_size = bs + logging.info("Selected inference batch size: %s", bs) + return model + except RuntimeError as e: + last_err = e + logging.warning("Batch size %s failed during worker 
warmup: %s", bs, e) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + raise RuntimeError( + f"No valid inference batch size found from {self._batch_sizes_to_try}" + ) from last_err + + +# ============ Load pipeline ============ + + +def run_load_pipeline(known_args, pipeline_args): + """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" + # enforce smaller/CPU-only defaults for feeder + override_or_add(pipeline_args, '--device', 'CPU') + override_or_add(pipeline_args, '--num_workers', '5') + override_or_add(pipeline_args, '--max_num_workers', '10') + override_or_add( + pipeline_args, '--job_name', f"images-load-pubsub-{int(time.time())}") + override_or_add(pipeline_args, '--project', known_args.project) + pipeline_args = [ + arg for arg in pipeline_args if not arg.startswith("--experiments") + ] + + pipeline_options = PipelineOptions(pipeline_args) + pipeline = beam.Pipeline(options=pipeline_options) + + _ = ( + pipeline + | 'ReadGCSFile' >> beam.io.ReadFromText(known_args.input) + | 'FilterEmpty' >> beam.Filter(lambda line: line.strip()) + | 'ToBytes' >> beam.Map(lambda line: line.encode('utf-8')) + | 'ToPubSub' >> beam.io.WriteToPubSub(topic=known_args.pubsub_topic)) + return pipeline.run() + + +# ============ Main pipeline ============ + + +def run( + argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult: + known_args, pipeline_args = parse_known_args(argv) + + if known_args.mode == 'streaming': + ensure_pubsub_resources( + project=known_args.project, + topic_path=known_args.pubsub_topic, + subscription_path=known_args.pubsub_subscription) + + # Start feeder thread that reads URIs from GCS and fills Pub/Sub. + # Delay is used to allow the main streaming pipeline workers to start + # and autoscale before the feeder pipeline begins publishing messages. 
+ threading.Thread( + target=lambda: ( + time.sleep(known_args.feeder_start_delay_sec), run_load_pipeline( + known_args, pipeline_args)), + daemon=True).start() + + # StandardOptions + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + pipeline_options.view_as(StandardOptions).streaming = ( + known_args.mode == 'streaming') + + # Build model handler with right-fitting batch size + desired_batch = pick_batch_size(known_args.inference_batch_size) + + # Device + device = 'GPU' if known_args.device.upper() == 'GPU' else 'CPU' + + tried = [64, 32, 16, 8] if desired_batch is None else [desired_batch] + + model_handler = RightFittingPytorchModelHandlerTensor( + batch_sizes_to_try=tried, + image_size=known_args.image_size, + device=device, + model_class=lambda: create_timm_m(known_args.pretrained_model_name), + model_params={}, + state_dict_path=known_args.model_state_dict_path, + inference_batch_size=tried[0], + ) + + pipeline = test_pipeline or beam.Pipeline(options=pipeline_options) + + if known_args.mode == 'batch': + pcoll = ( + pipeline + | 'ReadURIsBatch' >> beam.io.ReadFromText(known_args.input) + | 'FilterEmptyBatch' >> beam.Filter(lambda s: s.strip())) + else: + pcoll = ( + pipeline + | 'ReadFromPubSub' >> + beam.io.ReadFromPubSub(subscription=known_args.pubsub_subscription) + | 'DecodeUTF8' >> beam.Map(lambda x: x.decode('utf-8')) + | 'Window' >> beam.WindowInto( + window.FixedWindows(known_args.window_sec), + trigger=beam.trigger.AfterProcessingTime( + known_args.trigger_proc_time_sec), + accumulation_mode=beam.trigger.AccumulationMode.DISCARDING, + allowed_lateness=0)) + + keyed = ( + pcoll + | 'MakeKey' >> beam.ParDo(MakeKeyDoFn(input_mode=known_args.input_mode))) + + preprocessed = ( + keyed + | 'DecodePreprocess' >> beam.ParDo( + DecodePreprocessDoFn( + input_mode=known_args.input_mode, + image_size=known_args.image_size))) + + to_infer = ( + preprocessed + | + 'ToKeyedTensor' 
>> beam.Map(lambda kv: (kv[0], kv[1]["tensor"].float()))) + + predictions = ( + to_infer + | 'Reshuffle' >> beam.Reshuffle() + | 'RunInference' >> RunInference( + KeyedModelHandler(model_handler)).with_resource_hints( + accelerator="type:nvidia-tesla-t4;count:1;install-nvidia-driver")) Review Comment:  If the user runs the pipeline with `--device CPU`, the pipeline will still request a Tesla T4 GPU on Dataflow because of the hardcoded `.with_resource_hints` call. We should conditionally apply the resource hints only when running on GPU. ```python inference_transform = RunInference(KeyedModelHandler(model_handler)) if device == 'GPU': inference_transform = inference_transform.with_resource_hints( accelerator="type:nvidia-tesla-t4;count:1;install-nvidia-driver") predictions = ( to_infer | 'Reshuffle' >> beam.Reshuffle() | 'RunInference' >> inference_transform) ``` ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,651 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. 
+ +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class MakeKeyDoFn(beam.DoFn): + """Produce (uri, uri) so the URI is used as the stable key.""" + def process(self, element: str): + uri = element + yield uri, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (uri, uri) -> (uri, dict(image_bytes)).""" + def process(self, kv: Tuple[str, str]): + uri, _ = kv + try: + with FileSystems.open(uri) as f: + image_bytes = f.read() + yield uri, {"image_bytes": image_bytes} + except OSError as e: + logging.warning("Failed to read image %s: %s", uri, e) + return + + +class 
DecodeImageDoFn(beam.DoFn): + """Turn (uri, dict(image_bytes)) -> (uri, dict(image)).""" + def process(self, kv: Tuple[str, Dict[str, Any]]): + uri, value = kv + image_bytes = value["image_bytes"] + + try: + image = decode_pil(image_bytes) + except (OSError, ValueError) as e: + logging.warning("Failed to decode image %s: %s", uri, e) + image = PILImage.new("RGB", (224, 224), color=(0, 0, 0)) + + yield uri, {"image": image} + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + uri, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": uri, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + "scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + 
self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + processor = BlipProcessor.from_pretrained(self.model_name) + model = BlipForConditionalGeneration.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start = now_millis() + + images = [x["image"] for x in batch] + + # Processor makes pixel_values + inputs = processor(images=images, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(self.device) + + # Generate captions + # We use num_return_sequences to generate multiple candidates per image. + # Note: this will produce (B * num_captions) sequences. + with torch.no_grad(): + generated_ids = model.generate( + pixel_values=pixel_values, + max_new_tokens=self.max_new_tokens, + num_beams=max(self.num_beams, self.num_captions), + num_return_sequences=self.num_captions, + do_sample=False, + ) + + captions_all = processor.batch_decode( + generated_ids, skip_special_tokens=True) + + # Group candidates per image + candidates_per_image = [] + idx = 0 + for _ in range(len(batch)): + candidates_per_image.append(captions_all[idx:idx + self.num_captions]) + idx += self.num_captions + + blip_ms = now_millis() - start + + results = [] + for i in range(len(batch)): + results.append({ + "image": images[i], + "candidates": candidates_per_image[i], + "blip_ms": blip_ms, + }) + return results + + def get_metrics_namespace(self) -> str: + return "blip_captioning" + + +class ClipRankModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + score_normalize: bool): + self.model_name = model_name + self.device = device + 
self.batch_size = batch_size + self.score_normalize = score_normalize + + def load_model(self): + from transformers import CLIPModel, CLIPProcessor + processor = CLIPProcessor.from_pretrained(self.model_name) + model = CLIPModel.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start_batch = now_millis() + + # Flat lists for a single batched CLIP forward pass + images: List[PILImage.Image] = [] + texts: List[str] = [] + offsets: List[Tuple[int, int, int]] = [] + candidates_list: List[List[str]] = [] + blip_ms_list: List[Optional[int]] = [] + + for x in batch: + img = x["image"] + candidates = [str(c) for c in (x.get("candidates", []) or [])] + candidates_list.append(candidates) + blip_ms_list.append(x.get("blip_ms", None)) + + image_idx = len(images) + images.append(img) + + start_i = len(texts) + texts.extend(candidates) + end_i = len(texts) + offsets.append((image_idx, start_i, end_i)) + + results: List[Dict[str, Any]] = [] + + # Fast path: no candidates at all + if not texts: + for blip_ms in blip_ms_list: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + return results + + with torch.no_grad(): + image_inputs = processor( + images=images, + return_tensors="pt", + ) + image_inputs = { + k: (v.to(self.device) if torch.is_tensor(v) else v) + for k, v in image_inputs.items() + } + + text_inputs = processor( + text=texts, + return_tensors="pt", + padding=True, + truncation=True, + ) + text_inputs = { + k: (v.to(self.device) if torch.is_tensor(v) else v) + for k, v in text_inputs.items() + } + + image_features = 
model.get_image_features( + pixel_values=image_inputs["pixel_values"]) + text_features = model.get_text_features( + input_ids=text_inputs["input_ids"], + attention_mask=text_inputs.get("attention_mask"), + ) + + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + logit_scale = model.logit_scale.exp() + + batch_ms = now_millis() - start_batch + total_pairs = len(texts) + + items = zip(offsets, candidates_list, blip_ms_list) + for (image_idx, start_i, end_i), candidates, blip_ms in items: + if start_i == end_i: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + continue + + candidate_features = text_features[start_i:end_i] + image_feature = image_features[image_idx].unsqueeze(0) + + pair_scores = (candidate_features * + image_feature).sum(dim=-1) * logit_scale + + scores = pair_scores.detach().cpu().tolist() + + if self.score_normalize: + scores_t = torch.tensor(scores, dtype=torch.float32) + scores = torch.softmax(scores_t, dim=0).tolist() + + best_idx = max(range(len(scores)), key=lambda i, s=scores: s[i]) + + pairs = end_i - start_i + clip_ms_elem = int(batch_ms * (pairs / max(1, total_pairs))) + if pairs > 0: + clip_ms_elem = max(1, clip_ms_elem) + + total_ms = int(blip_ms) + clip_ms_elem if blip_ms is not None else None + results.append({ + "best_caption": candidates[best_idx], + "best_score": float(scores[best_idx]), + "candidates": candidates, + "scores": scores, + "blip_ms": blip_ms, + "clip_ms": clip_ms_elem, + "total_ms": total_ms, + }) + + return results + + def get_metrics_namespace(self) -> str: + return "clip_ranking" + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + + # I/O & runtime + parser.add_argument( 
+ '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--input', required=True, help='GCS path to file with image URIs') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + '--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + parser.add_argument( + '--feeder_start_delay_sec', + type=int, + default=900, + help=( + 'Delay before starting the feeder pipeline that reads URIs from GCS ' + 'and publishes them to Pub/Sub. This delay allows the main streaming ' + 'pipeline workers to start and scale before data ingestion begins.'), + ) + + # Device + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + + # BLIP + parser.add_argument( + '--blip_model_name', default='Salesforce/blip-image-captioning-base') + parser.add_argument('--blip_batch_size', type=int, default=4) + parser.add_argument('--num_captions', type=int, default=5) + parser.add_argument('--max_new_tokens', type=int, default=30) + parser.add_argument('--num_beams', type=int, default=5) + + # CLIP + parser.add_argument( + '--clip_model_name', default='openai/clip-vit-base-patch32') + parser.add_argument('--clip_batch_size', type=int, default=8) + parser.add_argument( + '--clip_score_normalize', default='false', choices=['true', 'false']) + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, 
subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except NotFound: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except NotFound: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + +def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + logging.info(f"Deleted subscription: {subscription_name}") + except NotFound: + logging.info(f"Subscription already deleted: {subscription_name}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + logging.info(f"Deleted topic: {topic_name}") + except NotFound: + logging.info(f"Topic already deleted: {topic_name}") + Review Comment:  Splitting the topic/subscription path and reconstructing it with `project` causes a mismatch if the user specified a fully qualified path in a different project. Since Beam's Pub/Sub IOs require fully qualified paths anyway, we should just use `topic_path` and `subscription_path` directly. 
```python def ensure_pubsub_resources( project: str, topic_path: str, subscription_path: str): publisher = pubsub_v1.PublisherClient() subscriber = pubsub_v1.SubscriberClient() try: publisher.get_topic(request={"topic": topic_path}) except NotFound: publisher.create_topic(name=topic_path) try: subscriber.get_subscription( request={"subscription": subscription_path}) except NotFound: subscriber.create_subscription( name=subscription_path, topic=topic_path) def cleanup_pubsub_resources( project: str, topic_path: str, subscription_path: str): publisher = pubsub_v1.PublisherClient() subscriber = pubsub_v1.SubscriberClient() try: subscriber.delete_subscription( request={"subscription": subscription_path}) logging.info(f"Deleted subscription: {subscription_path}") except NotFound: logging.info(f"Subscription already deleted: {subscription_path}") try: publisher.delete_topic(request={"topic": topic_path}) logging.info(f"Deleted topic: {topic_path}") except NotFound: logging.info(f"Topic already deleted: {topic_path}") ``` ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,651 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. + +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class MakeKeyDoFn(beam.DoFn): + """Produce (uri, uri) so the URI is used as the stable key.""" + def process(self, element: str): + uri = element + yield uri, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (uri, uri) -> (uri, dict(image_bytes)).""" + def process(self, kv: Tuple[str, str]): + uri, _ = kv + try: + with FileSystems.open(uri) as f: + image_bytes 
= f.read() + yield uri, {"image_bytes": image_bytes} + except OSError as e: + logging.warning("Failed to read image %s: %s", uri, e) + return + + +class DecodeImageDoFn(beam.DoFn): + """Turn (uri, dict(image_bytes)) -> (uri, dict(image)).""" + def process(self, kv: Tuple[str, Dict[str, Any]]): + uri, value = kv + image_bytes = value["image_bytes"] + + try: + image = decode_pil(image_bytes) + except (OSError, ValueError) as e: + logging.warning("Failed to decode image %s: %s", uri, e) + image = PILImage.new("RGB", (224, 224), color=(0, 0, 0)) + + yield uri, {"image": image} + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + uri, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": uri, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + "scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: 
int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + processor = BlipProcessor.from_pretrained(self.model_name) + model = BlipForConditionalGeneration.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start = now_millis() + + images = [x["image"] for x in batch] + + # Processor makes pixel_values + inputs = processor(images=images, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(self.device) + + # Generate captions + # We use num_return_sequences to generate multiple candidates per image. + # Note: this will produce (B * num_captions) sequences. 
+ with torch.no_grad(): + generated_ids = model.generate( + pixel_values=pixel_values, + max_new_tokens=self.max_new_tokens, + num_beams=max(self.num_beams, self.num_captions), + num_return_sequences=self.num_captions, + do_sample=False, + ) + + captions_all = processor.batch_decode( + generated_ids, skip_special_tokens=True) + + # Group candidates per image + candidates_per_image = [] + idx = 0 + for _ in range(len(batch)): + candidates_per_image.append(captions_all[idx:idx + self.num_captions]) + idx += self.num_captions + + blip_ms = now_millis() - start + + results = [] + for i in range(len(batch)): + results.append({ + "image": images[i], + "candidates": candidates_per_image[i], + "blip_ms": blip_ms, + }) + return results + + def get_metrics_namespace(self) -> str: + return "blip_captioning" + + +class ClipRankModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + score_normalize: bool): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.score_normalize = score_normalize + + def load_model(self): + from transformers import CLIPModel, CLIPProcessor + processor = CLIPProcessor.from_pretrained(self.model_name) + model = CLIPModel.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start_batch = now_millis() + + # Flat lists for a single batched CLIP forward pass + images: List[PILImage.Image] = [] + texts: List[str] = [] + offsets: List[Tuple[int, int, int]] = [] + candidates_list: List[List[str]] = [] + blip_ms_list: List[Optional[int]] = [] + + for x in batch: + img = x["image"] + candidates = [str(c) for c in (x.get("candidates", []) or [])] + candidates_list.append(candidates) + 
blip_ms_list.append(x.get("blip_ms", None)) + + image_idx = len(images) + images.append(img) + + start_i = len(texts) + texts.extend(candidates) + end_i = len(texts) + offsets.append((image_idx, start_i, end_i)) + + results: List[Dict[str, Any]] = [] + + # Fast path: no candidates at all + if not texts: + for blip_ms in blip_ms_list: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + return results + + with torch.no_grad(): + image_inputs = processor( + images=images, + return_tensors="pt", + ) + image_inputs = { + k: (v.to(self.device) if torch.is_tensor(v) else v) + for k, v in image_inputs.items() + } + + text_inputs = processor( + text=texts, + return_tensors="pt", + padding=True, + truncation=True, + ) + text_inputs = { + k: (v.to(self.device) if torch.is_tensor(v) else v) + for k, v in text_inputs.items() + } + + image_features = model.get_image_features( + pixel_values=image_inputs["pixel_values"]) + text_features = model.get_text_features( + input_ids=text_inputs["input_ids"], + attention_mask=text_inputs.get("attention_mask"), + ) + + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + logit_scale = model.logit_scale.exp() + + batch_ms = now_millis() - start_batch + total_pairs = len(texts) + + items = zip(offsets, candidates_list, blip_ms_list) + for (image_idx, start_i, end_i), candidates, blip_ms in items: + if start_i == end_i: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + continue + + candidate_features = text_features[start_i:end_i] + image_feature = image_features[image_idx].unsqueeze(0) + + 
pair_scores = (candidate_features * + image_feature).sum(dim=-1) * logit_scale + + scores = pair_scores.detach().cpu().tolist() + + if self.score_normalize: + scores_t = torch.tensor(scores, dtype=torch.float32) + scores = torch.softmax(scores_t, dim=0).tolist() + + best_idx = max(range(len(scores)), key=lambda i, s=scores: s[i]) + + pairs = end_i - start_i + clip_ms_elem = int(batch_ms * (pairs / max(1, total_pairs))) + if pairs > 0: + clip_ms_elem = max(1, clip_ms_elem) + + total_ms = int(blip_ms) + clip_ms_elem if blip_ms is not None else None + results.append({ + "best_caption": candidates[best_idx], + "best_score": float(scores[best_idx]), + "candidates": candidates, + "scores": scores, + "blip_ms": blip_ms, + "clip_ms": clip_ms_elem, + "total_ms": total_ms, + }) + + return results + + def get_metrics_namespace(self) -> str: + return "clip_ranking" + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + + # I/O & runtime + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--input', required=True, help='GCS path to file with image URIs') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + '--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + parser.add_argument( + '--feeder_start_delay_sec', + type=int, + default=900, + help=( + 'Delay before starting the feeder pipeline that reads URIs from GCS ' + 'and publishes them to Pub/Sub. 
This delay allows the main streaming ' + 'pipeline workers to start and scale before data ingestion begins.'), + ) + + # Device + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + + # BLIP + parser.add_argument( + '--blip_model_name', default='Salesforce/blip-image-captioning-base') + parser.add_argument('--blip_batch_size', type=int, default=4) + parser.add_argument('--num_captions', type=int, default=5) + parser.add_argument('--max_new_tokens', type=int, default=30) + parser.add_argument('--num_beams', type=int, default=5) + + # CLIP + parser.add_argument( + '--clip_model_name', default='openai/clip-vit-base-patch32') + parser.add_argument('--clip_batch_size', type=int, default=8) + parser.add_argument( + '--clip_score_normalize', default='false', choices=['true', 'false']) + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except NotFound: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except NotFound: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + +def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + 
topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + logging.info(f"Deleted subscription: {subscription_name}") + except NotFound: + logging.info(f"Subscription already deleted: {subscription_name}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + logging.info(f"Deleted topic: {topic_name}") + except NotFound: + logging.info(f"Topic already deleted: {topic_name}") + + +def override_or_add(args, flag, value): + if flag in args: + idx = args.index(flag) + args[idx + 1] = str(value) + else: + args.extend([flag, str(value)]) + + +# ============ Load pipeline ============ + + +def run_load_pipeline(known_args, pipeline_args): + """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" + # enforce smaller/CPU-only defaults for feeder + override_or_add(pipeline_args, '--device', 'CPU') Review Comment:  `pipeline_args` is a shared list passed from the main thread. Modifying it in-place in a background thread is a race condition risk. We should copy it first using `pipeline_args = list(pipeline_args)`. ```suggestion def run_load_pipeline(known_args, pipeline_args): """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" pipeline_args = list(pipeline_args) # enforce smaller/CPU-only defaults for feeder override_or_add(pipeline_args, '--device', 'CPU') ``` ########## sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py: ########## @@ -0,0 +1,533 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs object detection using an open-source PyTorch +TorchVision detection model (e.g., Faster R-CNN ResNet50 FPN) on GPU. + +It reads image URIs from a GCS input file, decodes and preprocesses images, +runs batched GPU inference via RunInference, post-processes detection outputs, +and writes results to BigQuery. + +The pipeline targets stable and reproducible performance measurements for +GPU inference workloads (no right-fitting; fixed batch size). 
+""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def decode_to_tens(image_bytes: bytes, image_size: int = 800) -> torch.Tensor: + """Decode bytes -> RGB PIL -> resize/pad square -> float tensor [0..1], CHW. + + TorchVision detection models accept float tensors in [0..1]. We force a fixed + square shape so PytorchModelHandlerTensor can batch tensors with torch.stack. 
+ """ + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + + w, h = img.size + scale = min(image_size / float(w), image_size / float(h)) + new_w = max(1, int(round(w * scale))) + new_h = max(1, int(round(h * scale))) + + img = img.resize((new_w, new_h)) + + padded = PILImage.new("RGB", (image_size, image_size), color=(0, 0, 0)) + left = (image_size - new_w) // 2 + top = (image_size - new_h) // 2 + padded.paste(img, (left, top)) + + import numpy as np + arr = np.asarray(padded).astype("float32") / 255.0 + arr = np.transpose(arr, (2, 0, 1)) + return torch.from_numpy(arr).float() + + +# ============ DoFns ============ + + +class MakeKeyDoFn(beam.DoFn): + """Produce (uri, uri) where the URI is used as the stable key.""" + def process(self, element: str): + uri = element + yield uri, uri + + +class DecodePreprocessDoFn(beam.DoFn): + """Turn (uri, uri) -> (uri, tensor).""" + def __init__(self, image_size: int = 800): + self.image_size = image_size + + def process(self, kv: Tuple[str, str]): + uri, _ = kv + start = now_millis() + try: + with FileSystems.open(uri) as f: + image_bytes = f.read() + tensor = decode_to_tens(image_bytes, image_size=self.image_size) + preprocess_ms = now_millis() - start + yield uri, {"tensor": tensor, "preprocess_ms": preprocess_ms} + except (OSError, ValueError): + logging.exception("Decode failed for %s", uri) + return + + +def _torchvision_detection_inference_fn( + batch: Sequence[torch.Tensor], + model: torch.nn.Module, + device: torch.device, + inference_args: Optional[dict[str, Any]] = None, + model_id: Optional[str] = None, +) -> List[Dict[str, Any]]: + """Inference function for TorchVision detection models. 
+ + TorchVision detection models expect List[Tensor] where each tensor is: + - shape: [3, H, W] + - dtype: float32 + - values: [0..1] + """ + del inference_args + del model_id + + with torch.no_grad(): + # Ensure tensors are on device + inputs = [] + for t in batch: + if isinstance(t, torch.Tensor): + inputs.append(t.to(device)) + else: + # Defensive: if somehow non-tensor slips through. + inputs.append(torch.as_tensor(t).to(device)) + outputs = model(inputs) # List[Dict[str, Tensor]] + return outputs + + +class PostProcessDoFn(beam.DoFn): + """PredictionResult -> dict row for BQ.""" + def __init__( + self, model_name: str, score_threshold: float, max_detections: int): + self.model_name = model_name + self.score_threshold = score_threshold + self.max_detections = max_detections + + def _extract_detection(self, inference_obj: Any) -> Dict[str, Any]: + """Extract detection fields from torchvision output dict.""" + # Expect: {'boxes': Tensor[N,4], 'labels': Tensor[N], 'scores': Tensor[N]} + boxes = inference_obj.get("boxes") + labels = inference_obj.get("labels") + scores = inference_obj.get("scores") + + # Convert to CPU lists + if isinstance(scores, torch.Tensor): + scores_list = scores.detach().cpu().tolist() + else: + scores_list = list(scores) if scores is not None else [] + + if isinstance(labels, torch.Tensor): + labels_list = labels.detach().cpu().tolist() + else: + labels_list = list(labels) if labels is not None else [] + + if isinstance(boxes, torch.Tensor): + boxes_list = boxes.detach().cpu().tolist() + else: + boxes_list = list(boxes) if boxes is not None else [] + + # Filter by threshold and trim to max_detections + dets = [] + for i in range(min(len(scores_list), len(labels_list), len(boxes_list))): + score = float(scores_list[i]) + if score < self.score_threshold: + continue + box = boxes_list[i] # [x1,y1,x2,y2] + dets.append({ + "label_id": int(labels_list[i]), + "score": score, + "box": [float(box[0]), float(box[1]), float(box[2]), float(box[3])], + 
}) + if len(dets) >= self.max_detections: + break + + return { + "detections": dets, + "num_detections": len(dets), + } + + def process(self, kv: Tuple[str, PredictionResult]): + image_uri, pred = kv + + # pred can be PredictionResult OR raw torchvision dict. + if hasattr(pred, "inference"): + inference_obj = pred.inference + else: + inference_obj = pred + + if isinstance(inference_obj, list) and len(inference_obj) == 1: + inference_obj = inference_obj[0] + + if not isinstance(inference_obj, dict): + logging.warning( + "Unexpected inf-ce type for %s: %s", image_uri, type(inference_obj)) + yield { + "image_id": image_uri, + "model_name": self.model_name, + "detections": json.dumps([]), + "num_detections": 0, + "infer_ms": now_millis(), + } + return + + extracted = self._extract_detection(inference_obj) + + yield { + "image_id": image_uri, + "model_name": self.model_name, + "detections": json.dumps(extracted["detections"]), + "num_detections": int(extracted["num_detections"]), + "infer_ms": now_millis(), + } + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + + # I/O & runtime + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + parser.add_argument( + '--input', required=True, help='GCS path to file with image URIs') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + '--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--feeder_start_delay_sec', + type=int, + default=900, + help=( + 'Delay before starting the feeder pipeline 
that reads URIs from GCS ' + 'and publishes them to Pub/Sub. This delay allows the main streaming ' + 'pipeline workers to start and scale before data ingestion begins.'), + ) + + # Model & inference + parser.add_argument( + '--pretrained_model_name', + default='fasterrcnn_resnet50_fpn', + help=( + 'TorchVision detection model name ' + '(e.g., fasterrcnn_resnet50_fpn)')) + parser.add_argument( + '--model_state_dict_path', + required=True, + help='GCS path to a state_dict .pth for the chosen model') + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + + # Batch sizing (no right-fitting) + parser.add_argument('--inference_batch_size', type=int, default=8) + + # Preprocess + parser.add_argument('--image_size', type=int, default=800) + + # Postprocess + parser.add_argument('--score_threshold', type=float, default=0.5) + parser.add_argument('--max_detections', type=int, default=50) + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except NotFound: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except NotFound: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + +def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: 
str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + logging.info(f"Deleted subscription: {subscription_name}") + except NotFound: + logging.info(f"Subscription already deleted: {subscription_name}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + logging.info(f"Deleted topic: {topic_name}") + except NotFound: + logging.info(f"Topic already deleted: {topic_name}") + + +def override_or_add(args, flag, value): + if flag in args: + idx = args.index(flag) + args[idx + 1] = str(value) + else: + args.extend([flag, str(value)]) + + +def create_torchvision_detection_model(model_name: str): + """Creates a TorchVision detection model instance. + + Note: We will load weights via state_dict_path (required by Beam handler when + model_class is provided). + """ + import torchvision + + name = model_name.strip() + + if name == "fasterrcnn_resnet50_fpn": + model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None) + elif name == "retinanet_resnet50_fpn": + model = torchvision.models.detection.retinanet_resnet50_fpn(weights=None) + else: + raise ValueError(f"Unsupported detection model: {model_name}") + + model.eval() + return model + + +# ============ Load pipeline ============ + + +def run_load_pipeline(known_args, pipeline_args): + """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" + # enforce smaller/CPU-only defaults for feeder + override_or_add(pipeline_args, '--device', 'CPU') Review Comment:  `pipeline_args` is a shared list passed from the main thread. Modifying it in-place in a background thread is a race condition risk. 
We should copy it first using `pipeline_args = list(pipeline_args)`. ```suggestion def run_load_pipeline(known_args, pipeline_args): """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" pipeline_args = list(pipeline_args) # enforce smaller/CPU-only defaults for feeder override_or_add(pipeline_args, '--device', 'CPU') ``` ########## sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py: ########## @@ -0,0 +1,536 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This pipeline performs image classification using an open-source +PyTorch EfficientNet-B0 model optimized for T4 GPUs. +It reads image URIs from Pub/Sub, decodes and preprocesses them in parallel, +and runs inference with adaptive batch sizing for optimal GPU utilization. +The pipeline targets stable and reproducible performance measurements under +continuous load. +Resources like Pub/Sub topic/subscription cleanup is handled programmatically. 
+""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Optional +from typing import Tuple + +import torch +import torch.nn.functional as F + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def decode_and_preprocess(image_bytes: bytes, size: int = 224) -> torch.Tensor: + """Decode bytes->RGB PIL->resize shorter side->center crop->normalize.""" + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + + resize_size = 256 + w, h = img.size + if w < h: + new_w = resize_size + new_h = int(h * resize_size / w) + else: + new_h = resize_size + new_w = int(w * resize_size / h) + + img = img.resize((new_w, new_h)) + + w, h = img.size + left = (w - size) // 2 + top = (h - size) // 2 + img = img.crop((left, top, left + size, top + size)) + + import numpy as np + mean = np.array(IMAGENET_MEAN, dtype=np.float32) + std = np.array(IMAGENET_STD, dtype=np.float32) + + arr = np.asarray(img).astype("float32") / 
255.0 + arr = (arr - mean) / std + arr = np.transpose(arr, (2, 0, 1)).astype("float32") + return torch.from_numpy(arr).float() + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, payload) stable for dedup & BQ insertId.""" + def __init__(self, input_mode: str): + self.input_mode = input_mode + + def process(self, element: str | bytes): + # Input can be raw bytes from Pub/Sub or a GCS URI string, depends on mode + if self.input_mode == "bytes": + # element is bytes message, assume it includes + # {"image_id": "...", "bytes": base64?} or just raw bytes. + import hashlib + b = element if isinstance(element, + (bytes, + bytearray)) else element.encode('utf-8') + image_id = hashlib.sha1(b).hexdigest() + yield image_id, b + else: + # gcs_uris: element is uri string; image_id = sha1(uri) + import hashlib + uri = element.decode("utf-8") if isinstance( + element, (bytes, bytearray)) else str(element) + image_id = hashlib.sha1(uri.encode("utf-8")).hexdigest() + yield image_id, uri + + +class DecodePreprocessDoFn(beam.DoFn): + """Turn (image_id, bytes|uri) -> (image_id, torch.Tensor)""" + def __init__(self, input_mode: str, image_size: int = 224): + self.input_mode = input_mode + self.image_size = image_size + + def process(self, kv: Tuple[str, object]): + image_id, payload = kv + start = now_millis() + + try: + if self.input_mode == "bytes": + b = payload if isinstance(payload, + (bytes, bytearray)) else bytes(payload) + else: + uri = payload if isinstance(payload, str) else payload.decode("utf-8") + b = load_image_from_uri(uri) + + tensor = decode_and_preprocess(b, self.image_size) + preprocess_ms = now_millis() - start + yield image_id, {"tensor": tensor, "preprocess_ms": preprocess_ms} + except Exception as e: + logging.warning("Decode failed for %s: %s", image_id, e) + return + + +class PostProcessDoFn(beam.DoFn): + """PredictionResult -> dict row for BQ.""" + def __init__(self, top_k: int, model_name: str): + self.top_k = top_k + self.model_name = model_name + + 
def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + + # pred can be PredictionResult OR raw inference object. + inference_obj = pred.inference if hasattr(pred, "inference") else pred + + # inference_obj can be dict {'logits': tensor} OR tensor directly. + if isinstance(inference_obj, dict): + logits = inference_obj.get("logits", None) + if logits is None: + raise ValueError( + f"Unable to find 'logits' in model output. " + f"Available keys: {list(inference_obj.keys())}") + else: + logits = inference_obj + + if not isinstance(logits, torch.Tensor): + logging.warning( + "Unexpected logits type for %s: %s", image_id, type(logits)) + return + + # Ensure shape [1, C] + if logits.ndim == 1: + logits = logits.unsqueeze(0) + + probs = F.softmax(logits, dim=-1) # [B, C] + values, indices = torch.topk( + probs, k=min(self.top_k, probs.shape[-1]), dim=-1 + ) + + topk = [{ + "class_id": int(idx.item()), "score": float(val.item()) + } for idx, val in zip(indices[0], values[0])] + + yield { + "image_id": image_id, + "model_name": self.model_name, + "topk": json.dumps(topk), + "infer_ms": now_millis(), + } + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + # I/O & runtime + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + parser.add_argument( + '--input_mode', default='gcs_uris', choices=['gcs_uris', 'bytes']) + parser.add_argument( + '--input', + required=True, + help='GCS path to file with URIs (for load) OR unused for bytes') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + 
'--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--feeder_start_delay_sec', + type=int, + default=900, + help=( + 'Delay before starting the feeder pipeline that reads URIs from GCS ' + 'and publishes them to Pub/Sub. This delay allows the main streaming ' + 'pipeline workers to start and scale before data ingestion begins.'), + ) + + # Model & inference + parser.add_argument( + '--pretrained_model_name', + default='efficientnet_b0', + help='OSS model name (e.g., efficientnet_b0|mobilenetv3_large_100)') + parser.add_argument( + '--model_state_dict_path', + default=None, + help='Optional state_dict to load') + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + parser.add_argument('--image_size', type=int, default=224) + parser.add_argument('--top_k', type=int, default=5) + parser.add_argument( + '--inference_batch_size', + default='auto', + help='int or "auto"; auto tries 64→32→16') + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except NotFound: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except NotFound: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + 
+def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + logging.info(f"Deleted subscription: {subscription_name}") + except NotFound: + logging.info(f"Subscription already deleted: {subscription_name}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + logging.info(f"Deleted topic: {topic_name}") + except NotFound: + logging.info(f"Topic already deleted: {topic_name}") + + +def override_or_add(args, flag, value): + if flag in args: + idx = args.index(flag) + args[idx + 1] = str(value) + else: + args.extend([flag, str(value)]) + + +# ============ Model factory (timm) ============ + + +def create_timm_m(model_name: str, num_classes: int = 1000): + import timm + model = timm.create_model( + model_name, pretrained=True, num_classes=num_classes) + model.eval() + return model + + +def pick_batch_size(arg: str) -> Optional[int]: + if isinstance(arg, str) and arg.lower() == 'auto': + return None + try: + return int(arg) + except Exception: + return None + + +class RightFittingPytorchModelHandlerTensor(PytorchModelHandlerTensor): + def __init__(self, batch_sizes_to_try, image_size, *args, **kwargs): + self._batch_sizes_to_try = batch_sizes_to_try + self._rightfit_image_size = image_size + super().__init__(*args, **kwargs) + + def load_model(self): + model = super().load_model() + last_err = None + + for bs in self._batch_sizes_to_try: + try: + model_device = next(model.parameters()).device + dummy = torch.zeros( + (bs, 3, self._rightfit_image_size, self._rightfit_image_size), + 
dtype=torch.float32, + device=model_device) + + with torch.no_grad(): + model(dummy) + + self._batch_size = bs + self._inference_batch_size = bs + logging.info("Selected inference batch size: %s", bs) + return model + except RuntimeError as e: + last_err = e + logging.warning("Batch size %s failed during worker warmup: %s", bs, e) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + raise RuntimeError( + f"No valid inference batch size found from {self._batch_sizes_to_try}" + ) from last_err + + +# ============ Load pipeline ============ + + +def run_load_pipeline(known_args, pipeline_args): + """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" + # enforce smaller/CPU-only defaults for feeder + override_or_add(pipeline_args, '--device', 'CPU') Review Comment:  `pipeline_args` is a shared list passed from the main thread. Modifying it in-place in a background thread is a race condition risk. We should copy it first using `pipeline_args = list(pipeline_args)`. ```suggestion def run_load_pipeline(known_args, pipeline_args): """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" pipeline_args = list(pipeline_args) # enforce smaller/CPU-only defaults for feeder override_or_add(pipeline_args, '--device', 'CPU') ``` ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,651 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. + +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + 
+ +# ============ DoFns ============ + + +class MakeKeyDoFn(beam.DoFn): + """Produce (uri, uri) so the URI is used as the stable key.""" + def process(self, element: str): + uri = element + yield uri, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (uri, uri) -> (uri, dict(image_bytes)).""" + def process(self, kv: Tuple[str, str]): + uri, _ = kv + try: + with FileSystems.open(uri) as f: + image_bytes = f.read() + yield uri, {"image_bytes": image_bytes} + except OSError as e: + logging.warning("Failed to read image %s: %s", uri, e) + return + + +class DecodeImageDoFn(beam.DoFn): + """Turn (uri, dict(image_bytes)) -> (uri, dict(image)).""" + def process(self, kv: Tuple[str, Dict[str, Any]]): + uri, value = kv + image_bytes = value["image_bytes"] + + try: + image = decode_pil(image_bytes) + except (OSError, ValueError) as e: + logging.warning("Failed to decode image %s: %s", uri, e) + image = PILImage.new("RGB", (224, 224), color=(0, 0, 0)) + + yield uri, {"image": image} + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + uri, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": uri, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + 
"scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + processor = BlipProcessor.from_pretrained(self.model_name) + model = BlipForConditionalGeneration.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start = now_millis() + + images = [x["image"] for x in batch] + + # Processor makes pixel_values + inputs = processor(images=images, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(self.device) + + # Generate captions + # We use num_return_sequences to generate multiple candidates per image. + # Note: this will produce (B * num_captions) sequences. 
+ with torch.no_grad(): + generated_ids = model.generate( + pixel_values=pixel_values, + max_new_tokens=self.max_new_tokens, + num_beams=max(self.num_beams, self.num_captions), + num_return_sequences=self.num_captions, + do_sample=False, + ) + + captions_all = processor.batch_decode( + generated_ids, skip_special_tokens=True) + + # Group candidates per image + candidates_per_image = [] + idx = 0 + for _ in range(len(batch)): + candidates_per_image.append(captions_all[idx:idx + self.num_captions]) + idx += self.num_captions + + blip_ms = now_millis() - start + + results = [] + for i in range(len(batch)): + results.append({ + "image": images[i], + "candidates": candidates_per_image[i], + "blip_ms": blip_ms, + }) + return results + + def get_metrics_namespace(self) -> str: + return "blip_captioning" + + +class ClipRankModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + score_normalize: bool): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.score_normalize = score_normalize + + def load_model(self): + from transformers import CLIPModel, CLIPProcessor + processor = CLIPProcessor.from_pretrained(self.model_name) + model = CLIPModel.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start_batch = now_millis() + + # Flat lists for a single batched CLIP forward pass + images: List[PILImage.Image] = [] + texts: List[str] = [] + offsets: List[Tuple[int, int, int]] = [] + candidates_list: List[List[str]] = [] + blip_ms_list: List[Optional[int]] = [] + + for x in batch: + img = x["image"] + candidates = [str(c) for c in (x.get("candidates", []) or [])] + candidates_list.append(candidates) + 
blip_ms_list.append(x.get("blip_ms", None)) + + image_idx = len(images) + images.append(img) + + start_i = len(texts) + texts.extend(candidates) + end_i = len(texts) + offsets.append((image_idx, start_i, end_i)) + + results: List[Dict[str, Any]] = [] + + # Fast path: no candidates at all + if not texts: + for blip_ms in blip_ms_list: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + return results + + with torch.no_grad(): + image_inputs = processor( + images=images, + return_tensors="pt", + ) + image_inputs = { + k: (v.to(self.device) if torch.is_tensor(v) else v) + for k, v in image_inputs.items() + } Review Comment:  Hugging Face `BatchEncoding` / `BatchFeature` objects have a built-in `.to(device)` method that cleanly moves all internal tensors to the specified device. We can replace the dict comprehension with `.to(self.device)`. ```suggestion image_inputs = processor( images=images, return_tensors="pt", ).to(self.device) ``` ########## sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py: ########## @@ -0,0 +1,536 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""This pipeline performs image classification using an open-source +PyTorch EfficientNet-B0 model optimized for T4 GPUs. +It reads image URIs from Pub/Sub, decodes and preprocesses them in parallel, +and runs inference with adaptive batch sizing for optimal GPU utilization. +The pipeline targets stable and reproducible performance measurements under +continuous load. +Resources like Pub/Sub topic/subscription cleanup is handled programmatically. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Optional +from typing import Tuple + +import torch +import torch.nn.functional as F + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def decode_and_preprocess(image_bytes: bytes, size: int = 224) -> torch.Tensor: + """Decode bytes->RGB PIL->resize shorter side->center crop->normalize.""" + with 
PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + + resize_size = 256 + w, h = img.size + if w < h: + new_w = resize_size + new_h = int(h * resize_size / w) + else: + new_h = resize_size + new_w = int(w * resize_size / h) + + img = img.resize((new_w, new_h)) + + w, h = img.size + left = (w - size) // 2 + top = (h - size) // 2 + img = img.crop((left, top, left + size, top + size)) + + import numpy as np + mean = np.array(IMAGENET_MEAN, dtype=np.float32) + std = np.array(IMAGENET_STD, dtype=np.float32) + + arr = np.asarray(img).astype("float32") / 255.0 + arr = (arr - mean) / std + arr = np.transpose(arr, (2, 0, 1)).astype("float32") + return torch.from_numpy(arr).float() + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, payload) stable for dedup & BQ insertId.""" + def __init__(self, input_mode: str): + self.input_mode = input_mode + + def process(self, element: str | bytes): + # Input can be raw bytes from Pub/Sub or a GCS URI string, depends on mode + if self.input_mode == "bytes": + # element is bytes message, assume it includes + # {"image_id": "...", "bytes": base64?} or just raw bytes. 
+ import hashlib + b = element if isinstance(element, + (bytes, + bytearray)) else element.encode('utf-8') + image_id = hashlib.sha1(b).hexdigest() + yield image_id, b + else: + # gcs_uris: element is uri string; image_id = sha1(uri) + import hashlib + uri = element.decode("utf-8") if isinstance( + element, (bytes, bytearray)) else str(element) + image_id = hashlib.sha1(uri.encode("utf-8")).hexdigest() + yield image_id, uri + + +class DecodePreprocessDoFn(beam.DoFn): + """Turn (image_id, bytes|uri) -> (image_id, torch.Tensor)""" + def __init__(self, input_mode: str, image_size: int = 224): + self.input_mode = input_mode + self.image_size = image_size + + def process(self, kv: Tuple[str, object]): + image_id, payload = kv + start = now_millis() + + try: + if self.input_mode == "bytes": + b = payload if isinstance(payload, + (bytes, bytearray)) else bytes(payload) + else: + uri = payload if isinstance(payload, str) else payload.decode("utf-8") + b = load_image_from_uri(uri) + + tensor = decode_and_preprocess(b, self.image_size) + preprocess_ms = now_millis() - start + yield image_id, {"tensor": tensor, "preprocess_ms": preprocess_ms} + except Exception as e: + logging.warning("Decode failed for %s: %s", image_id, e) + return + + +class PostProcessDoFn(beam.DoFn): + """PredictionResult -> dict row for BQ.""" + def __init__(self, top_k: int, model_name: str): + self.top_k = top_k + self.model_name = model_name + + def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + + # pred can be PredictionResult OR raw inference object. + inference_obj = pred.inference if hasattr(pred, "inference") else pred + + # inference_obj can be dict {'logits': tensor} OR tensor directly. + if isinstance(inference_obj, dict): + logits = inference_obj.get("logits", None) + if logits is None: + raise ValueError( + f"Unable to find 'logits' in model output. 
" + f"Available keys: {list(inference_obj.keys())}") + else: + logits = inference_obj + + if not isinstance(logits, torch.Tensor): + logging.warning( + "Unexpected logits type for %s: %s", image_id, type(logits)) + return + + # Ensure shape [1, C] + if logits.ndim == 1: + logits = logits.unsqueeze(0) + + probs = F.softmax(logits, dim=-1) # [B, C] + values, indices = torch.topk( + probs, k=min(self.top_k, probs.shape[-1]), dim=-1 + ) + + topk = [{ + "class_id": int(idx.item()), "score": float(val.item()) + } for idx, val in zip(indices[0], values[0])] + + yield { + "image_id": image_id, + "model_name": self.model_name, + "topk": json.dumps(topk), + "infer_ms": now_millis(), + } + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + # I/O & runtime + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + parser.add_argument( + '--input_mode', default='gcs_uris', choices=['gcs_uris', 'bytes']) + parser.add_argument( + '--input', + required=True, + help='GCS path to file with URIs (for load) OR unused for bytes') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + '--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--feeder_start_delay_sec', + type=int, + default=900, + help=( + 'Delay before starting the feeder pipeline that reads URIs from GCS ' + 'and publishes them to Pub/Sub. 
This delay allows the main streaming ' + 'pipeline workers to start and scale before data ingestion begins.'), + ) + + # Model & inference + parser.add_argument( + '--pretrained_model_name', + default='efficientnet_b0', + help='OSS model name (e.g., efficientnet_b0|mobilenetv3_large_100)') + parser.add_argument( + '--model_state_dict_path', + default=None, + help='Optional state_dict to load') + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + parser.add_argument('--image_size', type=int, default=224) + parser.add_argument('--top_k', type=int, default=5) + parser.add_argument( + '--inference_batch_size', + default='auto', + help='int or "auto"; auto tries 64→32→16') + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except NotFound: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except NotFound: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + +def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = 
publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + logging.info(f"Deleted subscription: {subscription_name}") + except NotFound: + logging.info(f"Subscription already deleted: {subscription_name}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + logging.info(f"Deleted topic: {topic_name}") + except NotFound: + logging.info(f"Topic already deleted: {topic_name}") + Review Comment:  Splitting the topic/subscription path and reconstructing it with `project` causes a mismatch if the user specified a fully qualified path in a different project. Since Beam's Pub/Sub IOs require fully qualified paths anyway, we should just use `topic_path` and `subscription_path` directly. ```python def ensure_pubsub_resources( project: str, topic_path: str, subscription_path: str): publisher = pubsub_v1.PublisherClient() subscriber = pubsub_v1.SubscriberClient() try: publisher.get_topic(request={"topic": topic_path}) except NotFound: publisher.create_topic(name=topic_path) try: subscriber.get_subscription( request={"subscription": subscription_path}) except NotFound: subscriber.create_subscription( name=subscription_path, topic=topic_path) def cleanup_pubsub_resources( project: str, topic_path: str, subscription_path: str): publisher = pubsub_v1.PublisherClient() subscriber = pubsub_v1.SubscriberClient() try: subscriber.delete_subscription( request={"subscription": subscription_path}) logging.info(f"Deleted subscription: {subscription_path}") except NotFound: logging.info(f"Subscription already deleted: {subscription_path}") try: publisher.delete_topic(request={"topic": topic_path}) logging.info(f"Deleted topic: {topic_path}") except NotFound: logging.info(f"Topic already deleted: {topic_path}") ``` ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## 
@@ -0,0 +1,651 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. + +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. 
+""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class MakeKeyDoFn(beam.DoFn): + """Produce (uri, uri) so the URI is used as the stable key.""" + def process(self, element: str): + uri = element + yield uri, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (uri, uri) -> (uri, dict(image_bytes)).""" + def process(self, kv: Tuple[str, str]): + uri, _ = kv + try: + with FileSystems.open(uri) as f: + image_bytes = f.read() + yield uri, {"image_bytes": image_bytes} + except OSError as e: + logging.warning("Failed to read image %s: %s", uri, e) + return + + +class DecodeImageDoFn(beam.DoFn): + """Turn (uri, dict(image_bytes)) -> (uri, dict(image)).""" + def process(self, kv: Tuple[str, Dict[str, Any]]): + uri, value = kv + image_bytes = value["image_bytes"] + + 
try: + image = decode_pil(image_bytes) + except (OSError, ValueError) as e: + logging.warning("Failed to decode image %s: %s", uri, e) + image = PILImage.new("RGB", (224, 224), color=(0, 0, 0)) + + yield uri, {"image": image} + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + uri, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": uri, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + "scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + 
processor = BlipProcessor.from_pretrained(self.model_name) + model = BlipForConditionalGeneration.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start = now_millis() + + images = [x["image"] for x in batch] + + # Processor makes pixel_values + inputs = processor(images=images, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(self.device) + + # Generate captions + # We use num_return_sequences to generate multiple candidates per image. + # Note: this will produce (B * num_captions) sequences. + with torch.no_grad(): + generated_ids = model.generate( + pixel_values=pixel_values, + max_new_tokens=self.max_new_tokens, + num_beams=max(self.num_beams, self.num_captions), + num_return_sequences=self.num_captions, + do_sample=False, + ) + + captions_all = processor.batch_decode( + generated_ids, skip_special_tokens=True) + + # Group candidates per image + candidates_per_image = [] + idx = 0 + for _ in range(len(batch)): + candidates_per_image.append(captions_all[idx:idx + self.num_captions]) + idx += self.num_captions + + blip_ms = now_millis() - start + + results = [] + for i in range(len(batch)): + results.append({ + "image": images[i], + "candidates": candidates_per_image[i], + "blip_ms": blip_ms, + }) + return results + + def get_metrics_namespace(self) -> str: + return "blip_captioning" + + +class ClipRankModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + score_normalize: bool): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.score_normalize = score_normalize + + def load_model(self): + from transformers import CLIPModel, CLIPProcessor + processor = 
CLIPProcessor.from_pretrained(self.model_name) + model = CLIPModel.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start_batch = now_millis() + + # Flat lists for a single batched CLIP forward pass + images: List[PILImage.Image] = [] + texts: List[str] = [] + offsets: List[Tuple[int, int, int]] = [] + candidates_list: List[List[str]] = [] + blip_ms_list: List[Optional[int]] = [] + + for x in batch: + img = x["image"] + candidates = [str(c) for c in (x.get("candidates", []) or [])] + candidates_list.append(candidates) + blip_ms_list.append(x.get("blip_ms", None)) + + image_idx = len(images) + images.append(img) + + start_i = len(texts) + texts.extend(candidates) + end_i = len(texts) + offsets.append((image_idx, start_i, end_i)) + + results: List[Dict[str, Any]] = [] + + # Fast path: no candidates at all + if not texts: + for blip_ms in blip_ms_list: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + return results + + with torch.no_grad(): + image_inputs = processor( + images=images, + return_tensors="pt", + ) + image_inputs = { + k: (v.to(self.device) if torch.is_tensor(v) else v) + for k, v in image_inputs.items() + } + + text_inputs = processor( + text=texts, + return_tensors="pt", + padding=True, + truncation=True, + ) + text_inputs = { + k: (v.to(self.device) if torch.is_tensor(v) else v) + for k, v in text_inputs.items() + } Review Comment:  Hugging Face `BatchEncoding` / `BatchFeature` objects have a built-in `.to(device)` method that cleanly moves all internal tensors to the specified device. 
We can replace the dict comprehension with `.to(self.device)`.

```suggestion
      text_inputs = processor(
          text=texts,
          return_tensors="pt",
          padding=True,
          truncation=True,
      ).to(self.device)
```

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
