damccorm commented on code in PR #37186: URL: https://github.com/apache/beam/pull/37186#discussion_r2828491340
########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,690 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. + +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def sha1_hex(s: str) -> str: + import hashlib + return hashlib.sha1(s.encode("utf-8")).hexdigest() + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class RateLimitDoFn(beam.DoFn): Review Comment: Why do we have this function? It will not effectively provide a global rate limit since multiple instances of this will be running in parallel ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,690 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. + +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def sha1_hex(s: str) -> str: + import hashlib + return hashlib.sha1(s.encode("utf-8")).hexdigest() + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class RateLimitDoFn(beam.DoFn): + def __init__(self, rate_per_sec: float): + self.delay = 1.0 / rate_per_sec + + def process(self, element): + time.sleep(self.delay) + yield element + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, uri) where image_id is stable for dedup and keys.""" + def process(self, element: str): + uri = element + image_id = sha1_hex(uri) + yield image_id, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (image_id, uri) -> (image_id, dict(image_bytes, uri)).""" + def process(self, kv: Tuple[str, str]): + image_id, uri = kv + try: + b = load_image_from_uri(uri) + yield image_id, {"image_bytes": b, "uri": uri} + except Exception as e: + logging.warning("Failed to read image %s (%s): %s", image_id, uri, e) + return + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + 
self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": image_id, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + "scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + self._model = None + self._processor = None + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + self._processor = BlipProcessor.from_pretrained(self.model_name) + self._model = BlipForConditionalGeneration.from_pretrained(self.model_name) + self._model.eval() + self._model.to(self.device) + return self._model + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model, inference_args=None): + + if model is not None: + self._model = model + self._model.to(self.device) + self._model.eval() + if self._processor is None: + from transformers import BlipProcessor + self._processor = BlipProcessor.from_pretrained(self.model_name) + if self._model is None: + self._model = self.load_model() + + start = now_millis() + + images = [] + uris = [] + bytes_list = [] + for x in batch: + b = x["image_bytes"] + bytes_list.append(b) + uris.append(x.get("uri", "")) + try: + images.append(decode_pil(b)) + except Exception: + # fallback: a blank image (so pipeline keeps going) + images.append(PILImage.new("RGB", (224, 224), color=(0, 0, 0))) + + # Processor makes pixel_values + inputs = self._processor(images=images, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(self.device) + + # Generate captions + # We use num_return_sequences to generate multiple candidates per image. + # Note: this will produce (B * num_captions) sequences. 
+ with torch.no_grad(): + generated_ids = self._model.generate( + pixel_values=pixel_values, + max_new_tokens=self.max_new_tokens, + num_beams=max(self.num_beams, self.num_captions), + num_return_sequences=self.num_captions, + do_sample=False, + ) + + captions_all = self._processor.batch_decode( + generated_ids, skip_special_tokens=True) + + # Group candidates per image + candidates_per_image = [] + idx = 0 + for _ in range(len(batch)): + candidates_per_image.append(captions_all[idx:idx + self.num_captions]) + idx += self.num_captions + + blip_ms = now_millis() - start + + results = [] + for i in range(len(batch)): + results.append({ + "image_bytes": bytes_list[i], + "uri": uris[i], + "candidates": candidates_per_image[i], + "blip_ms": blip_ms, + }) + return results + + def get_metrics_namespace(self) -> str: + return "blip_captioning" + + +class ClipRankModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + score_normalize: bool): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.score_normalize = score_normalize + + self._model = None + self._processor = None + + def load_model(self): + from transformers import CLIPModel, CLIPProcessor + self._processor = CLIPProcessor.from_pretrained(self.model_name) + self._model = CLIPModel.from_pretrained(self.model_name) + self._model.eval() + self._model.to(self.device) + return self._model + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model, inference_args=None): + + if model is not None: + self._model = model + self._model.to(self.device) + self._model.eval() + if self._processor is None: + from transformers import CLIPProcessor + self._processor = CLIPProcessor.from_pretrained(self.model_name) + if self._model is None: + self._model = self.load_model() + + start_batch = now_millis() + + # Flat lists for a single batched CLIP forward pass + images: List[PILImage.Image] = [] + texts: List[str] = [] + offsets: List[Tuple[int, int]] = [] + # per element -> [start, end) in flat arrays + candidates_list: List[List[str]] = [] + blip_ms_list: List[Optional[int]] = [] + + for x in batch: + image_bytes = x["image_bytes"] + candidates = [str(c) for c in (x.get("candidates", []) or [])] + candidates_list.append(candidates) + blip_ms_list.append(x.get("blip_ms", None)) + + try: + img = decode_pil(image_bytes) + except Exception: + img = PILImage.new("RGB", (224, 224), color=(0, 0, 0)) + + start_i = len(texts) + for c in candidates: + images.append(img) + texts.append(c) + end_i = len(texts) + offsets.append((start_i, end_i)) + + results: List[Dict[str, Any]] = [] + + # Fast path: no candidates at all + if not texts: + for blip_ms in blip_ms_list: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + return results + + with torch.no_grad(): + inputs = self._processor( + text=texts, + images=images, + return_tensors="pt", + padding=True, + truncation=True, + ) + inputs = { + k: (v.to(self.device) if torch.is_tensor(v) else v) + for k, v in inputs.items() + } + + # avoid NxN logits inside CLIPModel.forward() + img = self._model.get_image_features( + pixel_values=inputs["pixel_values"]) # [N, D] + txt = self._model.get_text_features( + input_ids=inputs["input_ids"], + 
attention_mask=inputs.get("attention_mask"), + ) # [N, D] + + img = img / img.norm(dim=-1, keepdim=True) + txt = txt / txt.norm(dim=-1, keepdim=True) + + logit_scale = self._model.logit_scale.exp() # scalar tensor + pair_scores = (img * txt).sum(dim=-1) * logit_scale # [N] + pair_scores_cpu = pair_scores.detach().cpu().tolist() + + batch_ms = now_millis() - start_batch + total_pairs = len(texts) + + items = zip(offsets, candidates_list, blip_ms_list) + for (start_i, end_i), candidates, blip_ms in items: + if start_i == end_i: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + continue + + scores = [float(pair_scores_cpu[j]) for j in range(start_i, end_i)] + + if self.score_normalize: + scores_t = torch.tensor(scores, dtype=torch.float32) + scores = torch.softmax(scores_t, dim=0).tolist() + + best_idx = max(range(len(scores)), key=lambda i, s=scores: s[i]) + + pairs = end_i - start_i + clip_ms_elem = int(batch_ms * (pairs / max(1, total_pairs))) + if pairs > 0: + clip_ms_elem = max(1, clip_ms_elem) + + total_ms = int(blip_ms) + clip_ms_elem if blip_ms is not None else None + results.append({ + "best_caption": candidates[best_idx], + "best_score": float(scores[best_idx]), + "candidates": candidates, + "scores": scores, + "blip_ms": blip_ms, + "clip_ms": clip_ms_elem, + "total_ms": total_ms, + }) + + return results + + def get_metrics_namespace(self) -> str: + return "clip_ranking" + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + + # I/O & runtime + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--input', required=True, help='GCS path to file with image URIs') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + '--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--rate_limit', + type=float, + default=None, + help='Elements per second for load pipeline') + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + + # Device + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + + # BLIP + parser.add_argument( + '--blip_model_name', default='Salesforce/blip-image-captioning-base') + parser.add_argument('--blip_batch_size', type=int, default=4) + parser.add_argument('--num_captions', type=int, default=5) + parser.add_argument('--max_new_tokens', type=int, default=30) + parser.add_argument('--num_beams', type=int, default=5) + + # CLIP + parser.add_argument( + '--clip_model_name', default='openai/clip-vit-base-patch32') + parser.add_argument('--clip_batch_size', type=int, default=8) + parser.add_argument( + '--clip_score_normalize', default='false', choices=['true', 'false']) + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, 
subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except Exception: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except Exception: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + +def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + print(f"Deleted subscription: {subscription_name}") + except Exception as e: + print(f"Failed to delete subscription: {e}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + print(f"Deleted topic: {topic_name}") + except Exception as e: + print(f"Failed to delete topic: {e}") + + +def override_or_add(args, flag, value): + if flag in args: + idx = args.index(flag) + args[idx + 1] = str(value) + else: + args.extend([flag, str(value)]) + + +# ============ Load pipeline ============ + + +def run_load_pipeline(known_args, pipeline_args): + """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" + # enforce smaller/CPU-only defaults for feeder + override_or_add(pipeline_args, '--device', 'CPU') + override_or_add(pipeline_args, '--num_workers', '5') + override_or_add(pipeline_args, '--max_num_workers', '10') + override_or_add( + pipeline_args, '--job_name', f"images-load-pubsub-{int(time.time())}") + override_or_add(pipeline_args, '--project', known_args.project) + pipeline_args = [ + arg for arg in pipeline_args if not arg.startswith("--experiments") + ] + + pipeline_options = PipelineOptions(pipeline_args) + pipeline = beam.Pipeline(options=pipeline_options) + + lines = ( + pipeline + | + 'ReadGCSFile' >> beam.Create(list(read_gcs_file_lines(known_args.input))) Review Comment: Or alternately, if we're doing it locally, can we just do all of our Pub/Sub hydration without a beam pipeline ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,690 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
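The reviewer's follow-up above suggests skipping the feeder Beam pipeline entirely. A minimal sketch of that idea, hydrating the topic with the google-cloud-pubsub client directly; the helper name `publish_uris_to_pubsub` is illustrative and not part of the PR, and it assumes the same one-URI-per-line file format used elsewhere in the example:

```python
from apache_beam.io.filesystems import FileSystems
from google.cloud import pubsub_v1


def publish_uris_to_pubsub(gcs_path: str, topic_path: str) -> None:
  """Publishes each non-empty line of the URI file as one Pub/Sub message."""
  publisher = pubsub_v1.PublisherClient()
  futures = []
  with FileSystems.open(gcs_path) as f:
    for line in f.read().decode("utf-8").splitlines():
      uri = line.strip()
      if uri:
        # publish() returns a future; keep it so we can block until delivery.
        futures.append(publisher.publish(topic_path, uri.encode("utf-8")))
  for future in futures:
    future.result()
```

The client batches publishes internally, so this is usually fast enough for a benchmark feeder and avoids launching (and paying for) a separate Dataflow job whose only purpose is to copy lines into Pub/Sub.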
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. + +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def sha1_hex(s: str) -> str: + import hashlib + return hashlib.sha1(s.encode("utf-8")).hexdigest() + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class RateLimitDoFn(beam.DoFn): + def __init__(self, rate_per_sec: float): + self.delay = 1.0 / rate_per_sec + + def process(self, element): + time.sleep(self.delay) + yield element + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, uri) where image_id is stable for dedup and keys.""" + def process(self, element: str): + uri = element + image_id = sha1_hex(uri) + yield image_id, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (image_id, uri) -> (image_id, dict(image_bytes, uri)).""" + def process(self, kv: Tuple[str, str]): + image_id, uri = kv + try: + b = load_image_from_uri(uri) + yield image_id, {"image_bytes": b, "uri": uri} + except Exception as e: + logging.warning("Failed to read image %s (%s): %s", image_id, uri, e) + return + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + 
blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": image_id, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + "scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + self._model = None + self._processor = None + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + self._processor = BlipProcessor.from_pretrained(self.model_name) + self._model = BlipForConditionalGeneration.from_pretrained(self.model_name) + self._model.eval() + self._model.to(self.device) + return self._model + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model, inference_args=None): + + if model is not None: + self._model = model + self._model.to(self.device) + self._model.eval() Review Comment: Haven't we already called this in load_model? ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,690 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. + +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. 
+""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def sha1_hex(s: str) -> str: + import hashlib + return hashlib.sha1(s.encode("utf-8")).hexdigest() + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class RateLimitDoFn(beam.DoFn): + def __init__(self, rate_per_sec: float): + self.delay = 1.0 / rate_per_sec + + def process(self, element): + time.sleep(self.delay) + yield element + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, uri) where image_id is stable for dedup and keys.""" + def process(self, element: str): + uri = element + image_id = sha1_hex(uri) + yield image_id, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (image_id, uri) -> (image_id, dict(image_bytes, uri)).""" + def process(self, kv: Tuple[str, str]): + image_id, uri = kv + try: + b = load_image_from_uri(uri) + yield image_id, {"image_bytes": b, "uri": uri} + except Exception as e: + logging.warning("Failed to read image %s (%s): %s", image_id, uri, e) + return + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": image_id, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + "scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": 
int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + self._model = None + self._processor = None + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + self._processor = BlipProcessor.from_pretrained(self.model_name) + self._model = BlipForConditionalGeneration.from_pretrained(self.model_name) + self._model.eval() + self._model.to(self.device) + return self._model + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model, inference_args=None): + + if model is not None: Review Comment: When could model be none? ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,690 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. + +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. 
+""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def sha1_hex(s: str) -> str: + import hashlib + return hashlib.sha1(s.encode("utf-8")).hexdigest() + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class RateLimitDoFn(beam.DoFn): Review Comment: Creating a rate limit is doable, but we would need to use a stateful DoFn to effectively do this. Basically the idea would be: 1) Key all the incoming data with a single (non-unique) key 2) For each incoming piece of data: - check stored state to see if it is ready to be released, and if not sleep until it is - Yield the element - Store the next release time (current time + delay) in state Because this functionally single-threads the output, it may be too slow to achieve the target rate; if that's the case, in step (1) you can partition to N keys, and do the same thing for each of them, yielding at a rate of `rate_per_sec/N` ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,690 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. 
+ +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def sha1_hex(s: str) -> str: + import hashlib + return hashlib.sha1(s.encode("utf-8")).hexdigest() + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class RateLimitDoFn(beam.DoFn): + def __init__(self, rate_per_sec: float): + self.delay = 1.0 / rate_per_sec + + def process(self, element): + time.sleep(self.delay) + yield element + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, uri) where image_id is stable for dedup and keys.""" + def process(self, element: str): + uri = element + image_id = sha1_hex(uri) + yield image_id, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (image_id, uri) -> (image_id, dict(image_bytes, uri)).""" + def process(self, kv: Tuple[str, str]): + image_id, uri = kv + try: + b = load_image_from_uri(uri) + yield image_id, {"image_bytes": b, "uri": uri} + except Exception as e: + logging.warning("Failed to read image %s (%s): %s", image_id, uri, e) + return + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": image_id, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": 
float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + "scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + self._model = None + self._processor = None + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + self._processor = BlipProcessor.from_pretrained(self.model_name) + self._model = BlipForConditionalGeneration.from_pretrained(self.model_name) + self._model.eval() + self._model.to(self.device) + return self._model + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model, inference_args=None): + + if model is not None: + self._model = model + self._model.to(self.device) + self._model.eval() + if self._processor is None: + from transformers import BlipProcessor + self._processor = BlipProcessor.from_pretrained(self.model_name) + if self._model is None: + self._model = self.load_model() + + start = now_millis() + + images = [] + uris = [] + bytes_list = [] + for x in batch: + b = x["image_bytes"] + bytes_list.append(b) + uris.append(x.get("uri", "")) + try: + images.append(decode_pil(b)) + except Exception: + # fallback: a blank image (so pipeline keeps going) + images.append(PILImage.new("RGB", (224, 224), color=(0, 0, 0))) + + # Processor makes pixel_values + inputs = self._processor(images=images, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(self.device) + + # Generate captions + # We use num_return_sequences to generate multiple candidates per image. + # Note: this will produce (B * num_captions) sequences. + with torch.no_grad(): + generated_ids = self._model.generate( + pixel_values=pixel_values, + max_new_tokens=self.max_new_tokens, + num_beams=max(self.num_beams, self.num_captions), + num_return_sequences=self.num_captions, + do_sample=False, + ) + + captions_all = self._processor.batch_decode( + generated_ids, skip_special_tokens=True) + + # Group candidates per image + candidates_per_image = [] + idx = 0 + for _ in range(len(batch)): + candidates_per_image.append(captions_all[idx:idx + self.num_captions]) + idx += self.num_captions + + blip_ms = now_millis() - start + + results = [] + for i in range(len(batch)): + results.append({ + "image_bytes": bytes_list[i], + "uri": uris[i], + "candidates": candidates_per_image[i], + "blip_ms": blip_ms, + }) + return results + + def get_metrics_namespace(self) -> str: + return "blip_captioning" + + +class ClipRankModelHandler(ModelHandler): Review Comment: Same general comments as the Blip model handler apply here ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,690 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. + +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def sha1_hex(s: str) -> str: + import hashlib + return hashlib.sha1(s.encode("utf-8")).hexdigest() + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class RateLimitDoFn(beam.DoFn): + def __init__(self, rate_per_sec: float): + self.delay = 1.0 / rate_per_sec + + def process(self, element): + time.sleep(self.delay) + yield element + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, uri) where image_id is stable for dedup and keys.""" + def process(self, element: str): + uri = element + image_id = sha1_hex(uri) + yield image_id, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (image_id, uri) -> (image_id, dict(image_bytes, uri)).""" + def process(self, kv: Tuple[str, str]): + image_id, uri = kv + try: + b = load_image_from_uri(uri) + yield image_id, {"image_bytes": b, "uri": uri} + except Exception as e: + logging.warning("Failed to read image %s (%s): %s", image_id, uri, e) + return + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + 
self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": image_id, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + "scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + self._model = None + self._processor = None + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + self._processor = BlipProcessor.from_pretrained(self.model_name) + self._model = BlipForConditionalGeneration.from_pretrained(self.model_name) + self._model.eval() + self._model.to(self.device) + return self._model + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model, inference_args=None): + + if model is not None: + self._model = model + self._model.to(self.device) + self._model.eval() + if self._processor is None: + from transformers import BlipProcessor + self._processor = BlipProcessor.from_pretrained(self.model_name) + if self._model is None: + self._model = self.load_model() + + start = now_millis() + + images = [] + uris = [] + bytes_list = [] + for x in batch: + b = x["image_bytes"] + bytes_list.append(b) + uris.append(x.get("uri", "")) + try: + images.append(decode_pil(b)) + except Exception: + # fallback: a blank image (so pipeline keeps going) + images.append(PILImage.new("RGB", (224, 224), color=(0, 0, 0))) + + # Processor makes pixel_values + inputs = self._processor(images=images, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(self.device) + + # Generate captions + # We use num_return_sequences to generate multiple candidates per image. + # Note: this will produce (B * num_captions) sequences. 
+ with torch.no_grad(): + generated_ids = self._model.generate( + pixel_values=pixel_values, + max_new_tokens=self.max_new_tokens, + num_beams=max(self.num_beams, self.num_captions), + num_return_sequences=self.num_captions, + do_sample=False, + ) + + captions_all = self._processor.batch_decode( + generated_ids, skip_special_tokens=True) + + # Group candidates per image + candidates_per_image = [] + idx = 0 + for _ in range(len(batch)): + candidates_per_image.append(captions_all[idx:idx + self.num_captions]) + idx += self.num_captions + + blip_ms = now_millis() - start + + results = [] + for i in range(len(batch)): + results.append({ + "image_bytes": bytes_list[i], + "uri": uris[i], + "candidates": candidates_per_image[i], + "blip_ms": blip_ms, + }) + return results + + def get_metrics_namespace(self) -> str: + return "blip_captioning" + + +class ClipRankModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + score_normalize: bool): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.score_normalize = score_normalize + + self._model = None + self._processor = None + + def load_model(self): + from transformers import CLIPModel, CLIPProcessor + self._processor = CLIPProcessor.from_pretrained(self.model_name) + self._model = CLIPModel.from_pretrained(self.model_name) + self._model.eval() + self._model.to(self.device) + return self._model + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model, inference_args=None): + + if model is not None: + self._model = model + self._model.to(self.device) + self._model.eval() + if self._processor is None: + from transformers import CLIPProcessor + self._processor = CLIPProcessor.from_pretrained(self.model_name) + if self._model is None: + self._model = self.load_model() + + start_batch = now_millis() + + # Flat lists for a single batched CLIP forward pass + images: List[PILImage.Image] = [] + texts: List[str] = [] + offsets: List[Tuple[int, int]] = [] + # per element -> [start, end) in flat arrays + candidates_list: List[List[str]] = [] + blip_ms_list: List[Optional[int]] = [] + + for x in batch: + image_bytes = x["image_bytes"] + candidates = [str(c) for c in (x.get("candidates", []) or [])] + candidates_list.append(candidates) + blip_ms_list.append(x.get("blip_ms", None)) + + try: + img = decode_pil(image_bytes) + except Exception: + img = PILImage.new("RGB", (224, 224), color=(0, 0, 0)) + + start_i = len(texts) + for c in candidates: + images.append(img) + texts.append(c) + end_i = len(texts) + offsets.append((start_i, end_i)) + + results: List[Dict[str, Any]] = [] + + # Fast path: no candidates at all + if not texts: + for blip_ms in blip_ms_list: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + return results + + with torch.no_grad(): + inputs = self._processor( + text=texts, + images=images, + return_tensors="pt", + padding=True, + truncation=True, + ) + inputs = { + k: (v.to(self.device) if torch.is_tensor(v) else v) + for k, v in inputs.items() + } + + # avoid NxN logits inside CLIPModel.forward() + img = self._model.get_image_features( + pixel_values=inputs["pixel_values"]) # [N, D] + txt = self._model.get_text_features( + input_ids=inputs["input_ids"], + 
attention_mask=inputs.get("attention_mask"), + ) # [N, D] + + img = img / img.norm(dim=-1, keepdim=True) + txt = txt / txt.norm(dim=-1, keepdim=True) + + logit_scale = self._model.logit_scale.exp() # scalar tensor + pair_scores = (img * txt).sum(dim=-1) * logit_scale # [N] + pair_scores_cpu = pair_scores.detach().cpu().tolist() + + batch_ms = now_millis() - start_batch + total_pairs = len(texts) + + items = zip(offsets, candidates_list, blip_ms_list) + for (start_i, end_i), candidates, blip_ms in items: + if start_i == end_i: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + continue + + scores = [float(pair_scores_cpu[j]) for j in range(start_i, end_i)] + + if self.score_normalize: + scores_t = torch.tensor(scores, dtype=torch.float32) + scores = torch.softmax(scores_t, dim=0).tolist() + + best_idx = max(range(len(scores)), key=lambda i, s=scores: s[i]) + + pairs = end_i - start_i + clip_ms_elem = int(batch_ms * (pairs / max(1, total_pairs))) + if pairs > 0: + clip_ms_elem = max(1, clip_ms_elem) + + total_ms = int(blip_ms) + clip_ms_elem if blip_ms is not None else None + results.append({ + "best_caption": candidates[best_idx], + "best_score": float(scores[best_idx]), + "candidates": candidates, + "scores": scores, + "blip_ms": blip_ms, + "clip_ms": clip_ms_elem, + "total_ms": total_ms, + }) + + return results + + def get_metrics_namespace(self) -> str: + return "clip_ranking" + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + + # I/O & runtime + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--input', required=True, help='GCS path to file with image URIs') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + '--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--rate_limit', + type=float, + default=None, + help='Elements per second for load pipeline') + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + + # Device + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + + # BLIP + parser.add_argument( + '--blip_model_name', default='Salesforce/blip-image-captioning-base') + parser.add_argument('--blip_batch_size', type=int, default=4) + parser.add_argument('--num_captions', type=int, default=5) + parser.add_argument('--max_new_tokens', type=int, default=30) + parser.add_argument('--num_beams', type=int, default=5) + + # CLIP + parser.add_argument( + '--clip_model_name', default='openai/clip-vit-base-patch32') + parser.add_argument('--clip_batch_size', type=int, default=8) + parser.add_argument( + '--clip_score_normalize', default='false', choices=['true', 'false']) + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, 
subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except Exception: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except Exception: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + +def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + print(f"Deleted subscription: {subscription_name}") + except Exception as e: + print(f"Failed to delete subscription: {e}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + print(f"Deleted topic: {topic_name}") + except Exception as e: + print(f"Failed to delete topic: {e}") + + +def override_or_add(args, flag, value): + if flag in args: + idx = args.index(flag) + args[idx + 1] = str(value) + else: + args.extend([flag, str(value)]) + + +# ============ Load pipeline ============ + + +def run_load_pipeline(known_args, pipeline_args): + """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" + # enforce smaller/CPU-only defaults for feeder + override_or_add(pipeline_args, '--device', 'CPU') + override_or_add(pipeline_args, '--num_workers', '5') + override_or_add(pipeline_args, '--max_num_workers', '10') + override_or_add( + pipeline_args, '--job_name', f"images-load-pubsub-{int(time.time())}") + override_or_add(pipeline_args, '--project', known_args.project) + pipeline_args = [ + arg for arg in pipeline_args if not arg.startswith("--experiments") + ] + + pipeline_options = PipelineOptions(pipeline_args) + pipeline = beam.Pipeline(options=pipeline_options) + + lines = ( + pipeline + | + 'ReadGCSFile' >> beam.Create(list(read_gcs_file_lines(known_args.input))) Review Comment: Can we just use built in Beam transforms to read from gcs instead of doing it all locally? ########## .github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml: ########## @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: Inference Python Benchmarks Dataflow +name: Inference Python Benchmarks Dataflow (1 part) Review Comment: Or if there is a reason, could you add a comment explaining it? (maybe what I'm suggesting would exhaust resources and we need different cron schedules?) ########## .github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml: ########## @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: Inference Python Benchmarks Dataflow +name: Inference Python Benchmarks Dataflow (1 part) Review Comment: Is there a reason to split these into different workflows? 
If it is just about minimizing the time it takes to run, could we do one workflow with 2 jobs? ########## sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py: ########## @@ -0,0 +1,563 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs object detection using an open-source PyTorch +TorchVision detection model (e.g., Faster R-CNN ResNet50 FPN) on GPU. + +It reads image URIs from a GCS input file, decodes and preprocesses images, +runs batched GPU inference via RunInference, post-processes detection outputs, +and writes results to BigQuery. + +The pipeline targets stable and reproducible performance measurements for +GPU inference workloads (no right-fitting; fixed batch size). +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def decode_to_tens( + image_bytes: bytes, + resize_shorter_side: Optional[int] = None) -> torch.Tensor: + """Decode bytes -> RGB PIL -> optional resize -> float tensor [0..1], CHW. + + Note: TorchVision detection models apply their own normalization internally. + """ + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + + if resize_shorter_side and resize_shorter_side > 0: + w, h = img.size + # Resize so that shorter side == resize_shorter_side, keep aspect ratio. 
+ if w < h: + new_w = resize_shorter_side + new_h = int(h * (resize_shorter_side / float(w))) + else: + new_h = resize_shorter_side + new_w = int(w * (resize_shorter_side / float(h))) + img = img.resize((new_w, new_h)) + + import numpy as np + arr = np.asarray(img).astype("float32") / 255.0 # H,W,3 in [0..1] + arr = np.transpose(arr, (2, 0, 1)) # CHW + return torch.from_numpy(arr) + + +def sha1_hex(s: str) -> str: + import hashlib + return hashlib.sha1(s.encode("utf-8")).hexdigest() + + +# ============ DoFns ============ + + +class RateLimitDoFn(beam.DoFn): Review Comment: I have the same questions about this file as the previous one - also, can we split the shared functions out into a helper class? ########## sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py: ########## @@ -0,0 +1,552 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This pipeline performs image classification using an open-source +PyTorch EfficientNet-B0 model optimized for T4 GPUs. +It reads image URIs from Pub/Sub, decodes and preprocesses them in parallel, +and runs inference with adaptive batch sizing for optimal GPU utilization. +The pipeline ensures exactly-once semantics via stateful deduplication and +idempotent BigQuery writes, allowing stable and reproducible performance +measurements under continuous load. +Resources like Pub/Sub topic/subscription cleanup is handled programmatically. 
+""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Iterable +from typing import Optional +from typing import Tuple + +import torch +import torch.nn.functional as F + +import apache_beam as beam +from apache_beam.coders import BytesCoder +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import userstate +from apache_beam.transforms import window + +from google.cloud import pubsub_v1 +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def decode_and_preprocess(image_bytes: bytes, size: int = 224) -> torch.Tensor: + """Decode bytes->RGB PIL->resize/crop->tensor->normalize.""" + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.thumbnail((256, 256)) + w, h = img.size + left = (w - size) // 2 + top = (h - size) // 2 + img = img.crop( + (max(0, left), max(0, top), min(w, left + size), min(h, top + size))) + + # To tensor [0..1] + import numpy as np + mean = np.array(IMAGENET_MEAN, dtype=np.float32) + std = np.array(IMAGENET_STD, dtype=np.float32) + arr = np.asarray(img).astype("float32") / 255.0 # H,W,3 + # Normalize + arr = (arr - mean) / std + # HWC -> CHW + arr = np.transpose(arr, (2, 0, 1)).astype("float32") + return torch.from_numpy(arr).float() # float32, shape (3,224,224) + + +class RateLimitDoFn(beam.DoFn): Review Comment: Same general questions about this file ########## sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py: ########## @@ -0,0 +1,552 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This pipeline performs image classification using an open-source +PyTorch EfficientNet-B0 model optimized for T4 GPUs. 
+It reads image URIs from Pub/Sub, decodes and preprocesses them in parallel, +and runs inference with adaptive batch sizing for optimal GPU utilization. +The pipeline ensures exactly-once semantics via stateful deduplication and +idempotent BigQuery writes, allowing stable and reproducible performance +measurements under continuous load. +Resources like Pub/Sub topic/subscription cleanup is handled programmatically. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Iterable +from typing import Optional +from typing import Tuple + +import torch +import torch.nn.functional as F + +import apache_beam as beam +from apache_beam.coders import BytesCoder +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import userstate +from apache_beam.transforms import window + +from google.cloud import pubsub_v1 +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def decode_and_preprocess(image_bytes: bytes, size: int = 224) -> torch.Tensor: + """Decode bytes->RGB PIL->resize/crop->tensor->normalize.""" + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.thumbnail((256, 256)) + w, h = img.size + left = (w - size) // 2 + top = (h - size) // 2 + img = img.crop( + (max(0, left), max(0, top), min(w, left + size), min(h, top + size))) + + # To tensor [0..1] + import numpy as np + mean = np.array(IMAGENET_MEAN, dtype=np.float32) + std = np.array(IMAGENET_STD, dtype=np.float32) + arr = np.asarray(img).astype("float32") / 255.0 # H,W,3 + # Normalize + arr = (arr - mean) / std + # HWC -> CHW + arr = np.transpose(arr, (2, 0, 1)).astype("float32") + return torch.from_numpy(arr).float() # float32, shape (3,224,224) + + +class RateLimitDoFn(beam.DoFn): + def __init__(self, rate_per_sec: float): + self.delay = 1.0 / rate_per_sec + + def process(self, element): + time.sleep(self.delay) + yield element + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, payload) stable for dedup & BQ insertId.""" + def __init__(self, input_mode: str): + self.input_mode = input_mode + + def process(self, element: str | bytes): + # Input can be raw bytes from Pub/Sub or a GCS URI string, depends on mode + if self.input_mode == "bytes": + # element is bytes message, assume it includes + # {"image_id": "...", "bytes": base64?} or just raw bytes. 
+ import hashlib + b = element if isinstance(element, (bytes, bytearray)) else bytes(element) + image_id = hashlib.sha1(b).hexdigest() + yield image_id, b + else: + # gcs_uris: element is uri string; image_id = sha1(uri) + import hashlib + uri = element.decode("utf-8") if isinstance( + element, (bytes, bytearray)) else str(element) + image_id = hashlib.sha1(uri.encode("utf-8")).hexdigest() + yield image_id, uri + + +class DedupDoFn(beam.DoFn): + seen = userstate.ReadModifyWriteStateSpec('seen', BytesCoder()) + + def process(self, element, seen=beam.DoFn.StateParam(seen)): + if seen.read() == b'1': + return + seen.write(b'1') + yield element + + +class DecodePreprocessDoFn(beam.DoFn): + """Turn (image_id, bytes|uri) -> (image_id, torch.Tensor)""" + def __init__( + self, input_mode: str, image_size: int = 224, decode_threads: int = 4): + self.input_mode = input_mode + self.image_size = image_size + self.decode_threads = decode_threads + + def process(self, kv: Tuple[str, object]): + image_id, payload = kv + start = now_millis() + + try: + if self.input_mode == "bytes": + b = payload if isinstance(payload, + (bytes, bytearray)) else bytes(payload) + else: + uri = payload if isinstance(payload, str) else payload.decode("utf-8") + b = load_image_from_uri(uri) + + tensor = decode_and_preprocess(b, self.image_size) + preprocess_ms = now_millis() - start + yield image_id, {"tensor": tensor, "preprocess_ms": preprocess_ms} + except Exception as e: + logging.warning("Decode failed for %s: %s", image_id, e) + return + + +class PostProcessDoFn(beam.DoFn): + """PredictionResult -> dict row for BQ.""" + def __init__(self, top_k: int, model_name: str): + self.top_k = top_k + self.model_name = model_name + + def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + + # pred can be PredictionResult OR raw inference object. + inference_obj = pred.inference if hasattr(pred, "inference") else pred + + # inference_obj can be dict {'logits': tensor} OR tensor directly. 
+ if isinstance(inference_obj, dict): + logits = inference_obj.get("logits", None) + if logits is None: + # fallback: try first value if dict shape differs + try: + logits = next(iter(inference_obj.values())) + except Exception: + logits = None + else: + logits = inference_obj + + if not isinstance(logits, torch.Tensor): + logging.warning( + "Unexpected logits type for %s: %s", image_id, type(logits)) + return + + # Ensure shape [1, C] + if logits.ndim == 1: + logits = logits.unsqueeze(0) + + probs = F.softmax(logits, dim=-1) # [B, C] + values, indices = torch.topk( + probs, k=min(self.top_k, probs.shape[-1]), dim=-1 + ) + + topk = [{ + "class_id": int(idx.item()), "score": float(val.item()) + } for idx, val in zip(indices[0], values[0])] + + yield { + "image_id": image_id, + "model_name": self.model_name, + "topk": json.dumps(topk), + "infer_ms": now_millis(), + } + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + # I/O & runtime + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + parser.add_argument( + '--input_mode', default='gcs_uris', choices=['gcs_uris', 'bytes']) + parser.add_argument( + '--input', + required=True, + help='GCS path to file with URIs (for load) OR unused for bytes') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + '--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--rate_limit', + type=float, + default=None, + help='Elements per second for load pipeline') + + # Model & inference + parser.add_argument( + '--pretrained_model_name', + default='efficientnet_b0', + help='OSS model name (e.g., efficientnet_b0|mobilenetv3_large_100)') + parser.add_argument( + '--model_state_dict_path', + default=None, + help='Optional state_dict to load') + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + parser.add_argument('--image_size', type=int, default=224) + parser.add_argument('--top_k', type=int, default=5) + parser.add_argument( + '--inference_batch_size', + default='auto', + help='int or "auto"; auto tries 64→32→16') + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + # Dedup + parser.add_argument( + '--enable_dedup', default='false', choices=['true', 'false']) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except Exception: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except 
Exception: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + +def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + print(f"Deleted subscription: {subscription_name}") + except Exception as e: + print(f"Failed to delete subscription: {e}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + print(f"Deleted topic: {topic_name}") + except Exception as e: + print(f"Failed to delete topic: {e}") + + +def override_or_add(args, flag, value): + if flag in args: + idx = args.index(flag) + args[idx + 1] = str(value) + else: + args.extend([flag, str(value)]) + + +# ============ Model factory (timm) ============ + + +def create_timm_m(model_name: str, num_classes: int = 1000): + import timm + model = timm.create_model( + model_name, pretrained=True, num_classes=num_classes) + model.eval() + return model + + +def pick_batch_size(arg: str) -> Optional[int]: + if isinstance(arg, str) and arg.lower() == 'auto': + return None + try: + return int(arg) + except Exception: + return None + + +# ============ Load pipeline ============ + + +def run_load_pipeline(known_args, pipeline_args): + """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" + # enforce smaller/CPU-only defaults for feeder + override_or_add(pipeline_args, '--device', 'CPU') + override_or_add(pipeline_args, '--num_workers', '5') + override_or_add(pipeline_args, '--max_num_workers', '10') + override_or_add( + pipeline_args, '--job_name', f"images-load-pubsub-{int(time.time())}") + override_or_add(pipeline_args, '--project', known_args.project) + pipeline_args = [ + arg for arg in pipeline_args if not arg.startswith("--experiments") + ] + + pipeline_options = PipelineOptions(pipeline_args) + pipeline = beam.Pipeline(options=pipeline_options) + + lines = ( + pipeline + | + 'ReadGCSFile' >> beam.Create(list(read_gcs_file_lines(known_args.input))) + | 'FilterEmpty' >> beam.Filter(lambda line: line.strip())) + if known_args.rate_limit: + lines = lines | 'RateLimit' >> beam.ParDo( + RateLimitDoFn(rate_per_sec=known_args.rate_limit)) + + _ = ( + lines + | 'ToBytes' >> beam.Map(lambda line: line.encode('utf-8')) + | 'ToPubSub' >> beam.io.WriteToPubSub(topic=known_args.pubsub_topic)) + return pipeline.run() + + +# ============ Main pipeline ============ + + +def run( Review Comment: I see this is a rightfitting pipeline, but does it actually use resource hints? https://docs.cloud.google.com/dataflow/docs/guides/right-fitting#python ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,690 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. + +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def read_gcs_file_lines(gcs_path: str) -> Iterable[str]: + """Reads text lines from a GCS file.""" + with FileSystems.open(gcs_path) as f: + for line in f.read().decode("utf-8").splitlines(): + yield line.strip() + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def sha1_hex(s: str) -> str: + import hashlib + return hashlib.sha1(s.encode("utf-8")).hexdigest() + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class RateLimitDoFn(beam.DoFn): + def __init__(self, rate_per_sec: float): + self.delay = 1.0 / rate_per_sec + + def process(self, element): + time.sleep(self.delay) + yield element + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, uri) where image_id is stable for dedup and keys.""" + def process(self, element: str): + uri = element + image_id = sha1_hex(uri) + yield image_id, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (image_id, uri) -> (image_id, dict(image_bytes, uri)).""" + def process(self, kv: Tuple[str, str]): + image_id, uri = kv + try: + b = load_image_from_uri(uri) + yield image_id, {"image_bytes": b, "uri": uri} + except Exception as e: + logging.warning("Failed to read image %s (%s): %s", image_id, uri, e) + return + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + 
else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": image_id, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + "scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + self._model = None + self._processor = None + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + self._processor = BlipProcessor.from_pretrained(self.model_name) + self._model = BlipForConditionalGeneration.from_pretrained(self.model_name) + self._model.eval() + self._model.to(self.device) + return self._model + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model, inference_args=None): + + if model is not None: + self._model = model + self._model.to(self.device) + self._model.eval() + if self._processor is None: + from transformers import BlipProcessor + self._processor = BlipProcessor.from_pretrained(self.model_name) Review Comment: A better pattern here might just be to return a class which contains both the processor and the model from `load_model` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
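A minimal sketch of the built-in-transform approach suggested for the load pipeline's `ReadGCSFile` step: `beam.io.ReadFromText` accepts `gs://` paths and reads the file on the workers, so the launcher no longer materializes the URI list locally before constructing the pipeline. `known_args` and the Pub/Sub sink mirror the example under review; the rest of the wiring here is an assumption, not the PR's code.

```python
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions


def run_load_pipeline(known_args, pipeline_args):
  """Feeder sketch: stream URIs from the GCS input file to Pub/Sub."""
  pipeline = beam.Pipeline(options=PipelineOptions(pipeline_args))
  _ = (
      pipeline
      # ReadFromText splits the input across workers, so the URI list
      # never has to fit into the launcher process's memory.
      | 'ReadGCSFile' >> beam.io.ReadFromText(known_args.input)
      | 'FilterEmpty' >> beam.Filter(lambda line: line.strip())
      | 'ToBytes' >> beam.Map(lambda line: line.strip().encode('utf-8'))
      | 'ToPubSub' >> beam.io.WriteToPubSub(topic=known_args.pubsub_topic))
  return pipeline.run()
```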
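For the question about splitting the shared functions out of the three example files, one hedged option is a small shared module rather than a class, since the duplicated pieces are free functions. The module name `inference_benchmark_utils` is made up here; the bodies just mirror the utilities repeated across the files above, and each example would then import them instead of re-declaring them.

```python
# inference_benchmark_utils.py (hypothetical shared module)
import hashlib
import time
from typing import Iterable

from apache_beam.io.filesystems import FileSystems


def now_millis() -> int:
  return int(time.time() * 1000)


def sha1_hex(s: str) -> str:
  return hashlib.sha1(s.encode("utf-8")).hexdigest()


def read_gcs_file_lines(gcs_path: str) -> Iterable[str]:
  """Reads stripped text lines from a GCS file."""
  with FileSystems.open(gcs_path) as f:
    for line in f.read().decode("utf-8").splitlines():
      yield line.strip()


def load_image_from_uri(uri: str) -> bytes:
  with FileSystems.open(uri) as f:
    return f.read()
```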
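On the right-fitting question, a sketch of attaching Dataflow resource hints to just the RunInference stage, following the linked guide, so only that stage is scheduled onto GPU workers while decode/preprocess stays on the default pool. The helper name and the specific hint values (`min_ram`, T4 accelerator string) are illustrative assumptions, not something the PR currently does.

```python
from apache_beam.ml.inference.base import RunInference


def apply_gpu_inference(preprocessed, keyed_model_handler):
  """Attach right-fitting hints to the inference stage only."""
  # Only this transform requests an accelerator; upstream steps keep the
  # default (CPU-only) worker resources.
  inference = RunInference(keyed_model_handler).with_resource_hints(
      min_ram='16GB',
      accelerator='type:nvidia-tesla-t4;count:1;install-nvidia-driver')
  return preprocessed | 'RunInference' >> inference
```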
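Finally, a sketch of the `load_model` pattern suggested in the last comment: return one object that carries both the processor and the model, so `run_inference` works entirely off what the framework hands it instead of rebuilding the processor or falling back to `self._model`. The bundle class, the trimmed constructor, and the assumption that batch elements already carry decoded PIL images under an `"image"` key are all illustrative, not the PR's code.

```python
from dataclasses import dataclass
from typing import Any, Dict, List

import torch
from apache_beam.ml.inference.base import ModelHandler


@dataclass
class BlipBundle:
  """What load_model returns: the processor and model travel together."""
  processor: Any
  model: Any


class BlipCaptionModelHandler(ModelHandler):
  def __init__(self, model_name: str, device: str, batch_size: int):
    self.model_name = model_name
    self.device = device
    self.batch_size = batch_size

  def load_model(self) -> BlipBundle:
    from transformers import BlipForConditionalGeneration, BlipProcessor
    processor = BlipProcessor.from_pretrained(self.model_name)
    model = BlipForConditionalGeneration.from_pretrained(self.model_name)
    model.eval()
    model.to(self.device)
    return BlipBundle(processor=processor, model=model)

  def batch_elements_kwargs(self):
    return {"max_batch_size": self.batch_size}

  def run_inference(
      self, batch: List[Dict[str, Any]], model: BlipBundle,
      inference_args=None):
    # `model` is the bundle returned by load_model, so both pieces are
    # available without re-creating the processor inside run_inference.
    processor, blip = model.processor, model.model
    images = [x["image"] for x in batch]  # assumes decoded PIL images upstream
    inputs = processor(images=images, return_tensors="pt").to(self.device)
    with torch.no_grad():
      generated_ids = blip.generate(**inputs)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)
```

Because everything loaded state lives in the returned bundle, the handler itself stays stateless, which also tends to play better with model sharing across threads and with model updates.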
