This is an automated email from the ASF dual-hosted git repository. jrmccluskey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new fe188e3635b [#34236] Add Vertex AI Multi-Modal embedding handler (#35677) fe188e3635b is described below commit fe188e3635b894fef4f6d2b2f7eda0c09608556f Author: Jack McCluskey <34928439+jrmcclus...@users.noreply.github.com> AuthorDate: Mon Sep 8 16:51:33 2025 -0400 [#34236] Add Vertex AI Multi-Modal embedding handler (#35677) * Prototype Vertex MultiModal embedding handler * remove unused types * change temp file and artifact paths to use dedicated directories * formatting * quick unit tests for the base multimodal embedding handler * Migrate to input adapter, add testing for video * linting * isort * made segment configuration per-video instance * fix corrected video input type * speed up video test by passing a GCS URI instead of loading the video * formatting * move to wrapped inputs * clarify types in dict_input_fn * linting * fix chunk construction * update main input to use wrappers --- sdks/python/apache_beam/ml/transforms/base.py | 39 ++++ sdks/python/apache_beam/ml/transforms/base_test.py | 117 ++++++++++ .../ml/transforms/embeddings/vertex_ai.py | 241 ++++++++++++++++++++- .../ml/transforms/embeddings/vertex_ai_test.py | 107 +++++++++ 4 files changed, 501 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py index 3b95ed719e5..4031777ce15 100644 --- a/sdks/python/apache_beam/ml/transforms/base.py +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -810,3 +810,42 @@ class _ImageEmbeddingHandler(_EmbeddingHandler): return ( self._underlying.get_metrics_namespace() or 'BeamML_ImageEmbeddingHandler') + + +class _MultiModalEmbeddingHandler(_EmbeddingHandler): + """ + A ModelHandler intended to work on + list[dict[str, TypedDict(Image, Video, str)]] inputs. + + The inputs to the model handler are expected to be a list of dicts. + + For example, if the original model is used with RunInference to take a + PCollection[E] to a PCollection[P], this ModelHandler would take a + PCollection[dict[str, E]] to a PCollection[dict[str, P]]. + + _MultiModalEmbeddingHandler will accept an EmbeddingsManager instance, which + contains the details of the model to be loaded and the inference_fn to be + used. The purpose of _MultiModalEmbeddingHandler is to generate embeddings + for image, video, and text inputs using the EmbeddingsManager instance. + + If an input column contains only primitive types, a TypeError will be + raised. + + This is an internal class and offers no backwards compatibility guarantees. + + Args: + embeddings_manager: An EmbeddingsManager instance. + """ + def _validate_column_data(self, batch): + # Don't want to require framework-specific imports + # here, so just catch columns of primitives for now. + if isinstance(batch[0], (int, str, float, bool)): + raise TypeError( + 'Embeddings can only be generated on ' + ' dict[str, dataclass] types. 
' + f'Got dict[str, {type(batch[0])}] instead.') + + def get_metrics_namespace(self) -> str: + return ( + self._underlying.get_metrics_namespace() or + 'BeamML_MultiModalEmbeddingHandler') diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 309c085f08f..190381cc2f3 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -23,6 +23,7 @@ import tempfile import time import unittest from collections.abc import Sequence +from dataclasses import dataclass from typing import Any from typing import Optional @@ -629,6 +630,122 @@ class TestImageEmbeddingHandler(unittest.TestCase): ) +@dataclass +class FakeMultiModalInput: + image: Optional[PIL_Image] = None + video: Optional[Any] = None + text: Optional[str] = None + + +class FakeMultiModalModel: + def __call__(self, + example: list[FakeMultiModalInput]) -> list[FakeMultiModalInput]: + for i in range(len(example)): + if not isinstance(example[i], FakeMultiModalInput): + raise TypeError('Input must be a MultiModalInput') + return example + + +class FakeMultiModalModelHandler(ModelHandler): + def run_inference( + self, + batch: Sequence[FakeMultiModalInput], + model: Any, + inference_args: Optional[dict[str, Any]] = None): + return model(batch) + + def load_model(self): + return FakeMultiModalModel() + + +class FakeMultiModalEmbeddingsManager(base.EmbeddingsManager): + def __init__(self, columns, **kwargs): + super().__init__(columns=columns, **kwargs) + + def get_model_handler(self) -> ModelHandler: + FakeModelHandler.__repr__ = lambda x: 'FakeMultiModalEmbeddingsManager' # type: ignore[method-assign] + return FakeMultiModalModelHandler() + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + return (RunInference(model_handler=base._MultiModalEmbeddingHandler(self))) + + def __repr__(self): + return 'FakeMultiModalEmbeddingsManager' + + +class TestMultiModalEmbeddingHandler(unittest.TestCase): + def setUp(self) -> None: + self.embedding_config = FakeMultiModalEmbeddingsManager(columns=['x']) + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + @unittest.skipIf(PIL is None, 'PIL module is not installed.') + def test_handler_with_non_dict_datatype(self): + image_handler = base._MultiModalEmbeddingHandler( + embeddings_manager=self.embedding_config) + data = [ + ('x', 'hi there'), + ('x', 'not an image'), + ('x', 'image_path.jpg'), + ] + with self.assertRaises(TypeError): + image_handler.run_inference(data, None, None) + + @unittest.skipIf(PIL is None, 'PIL module is not installed.') + def test_handler_with_incorrect_datatype(self): + image_handler = base._MultiModalEmbeddingHandler( + embeddings_manager=self.embedding_config) + data = [ + { + 'x': 'hi there' + }, + { + 'x': 'not an image' + }, + { + 'x': 'image_path.jpg' + }, + ] + with self.assertRaises(TypeError): + image_handler.run_inference(data, None, None) + + @unittest.skipIf(PIL is None, 'PIL module is not installed.') + def test_handler_with_dict_inputs(self): + input_one = FakeMultiModalInput( + image=PIL.Image.new(mode='RGB', size=(1, 1)), text="test image one") + input_two = FakeMultiModalInput( + image=PIL.Image.new(mode='RGB', size=(1, 1)), text="test image two") + input_three = FakeMultiModalInput( + image=PIL.Image.new(mode='RGB', size=(1, 1)), + video=bytes.fromhex('2Ef0 F1f2 '), + text="test image three with video") + data = [ + { + 'x': input_one + }, + { + 
'x': input_two + }, + { + 'x': input_three + }, + ] + expected_data = [{key: value for key, value in d.items()} for d in data] + with beam.Pipeline() as p: + result = ( + p + | beam.Create(data) + | base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + self.embedding_config)) + assert_that( + result, + equal_to(expected_data), + ) + + class TestUtilFunctions(unittest.TestCase): def test_dict_input_fn_normal(self): input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py index a645ce32e2a..c7c46d246b9 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -19,10 +19,14 @@ # Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long # to install Vertex AI Python SDK. +import functools import logging +from collections.abc import Callable from collections.abc import Sequence +from dataclasses import dataclass from typing import Any from typing import Optional +from typing import cast from google.api_core.exceptions import ServerError from google.api_core.exceptions import TooManyRequests @@ -33,15 +37,28 @@ import vertexai from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import RemoteModelHandler from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Embedding from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import EmbeddingTypeAdapter from apache_beam.ml.transforms.base import _ImageEmbeddingHandler +from apache_beam.ml.transforms.base import _MultiModalEmbeddingHandler from apache_beam.ml.transforms.base import _TextEmbeddingHandler from vertexai.language_models import TextEmbeddingInput from vertexai.language_models import TextEmbeddingModel from vertexai.vision_models import Image from vertexai.vision_models import MultiModalEmbeddingModel - -__all__ = ["VertexAITextEmbeddings", "VertexAIImageEmbeddings"] +from vertexai.vision_models import MultiModalEmbeddingResponse +from vertexai.vision_models import Video +from vertexai.vision_models import VideoEmbedding +from vertexai.vision_models import VideoSegmentConfig + +__all__ = [ + "VertexAITextEmbeddings", + "VertexAIImageEmbeddings", + "VertexAIMultiModalEmbeddings", + "VertexAIMultiModalInput", +] DEFAULT_TASK_TYPE = "RETRIEVAL_DOCUMENT" # TODO: https://github.com/apache/beam/issues/29356 @@ -54,7 +71,6 @@ TASK_TYPE_INPUTS = [ "CLUSTERING" ] _BATCH_SIZE = 5 # Vertex AI limits requests to 5 at a time. 
-_MSEC_TO_SEC = 1000 LOGGER = logging.getLogger("VertexAIEmbeddings") @@ -281,3 +297,222 @@ class VertexAIImageEmbeddings(EmbeddingsManager): return RunInference( model_handler=_ImageEmbeddingHandler(self), inference_args=self.inference_args) + + +@dataclass +class VertexImage: + image_content: Image + embedding: Optional[list[float]] = None + + +@dataclass +class VertexVideo: + video_content: Video + config: VideoSegmentConfig + embeddings: Optional[list[VideoEmbedding]] = None + + +@dataclass +class VertexAIMultiModalInput: + image: Optional[VertexImage] = None + video: Optional[VertexVideo] = None + contextual_text: Optional[Chunk] = None + + +class _VertexAIMultiModalEmbeddingHandler(RemoteModelHandler): + def __init__( + self, + model_name: str, + dimension: Optional[int] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[Credentials] = None, + **kwargs): + vertexai.init(project=project, location=location, credentials=credentials) + self.model_name = model_name + self.dimension = dimension + + super().__init__( + namespace='VertexAIMultiModalEmbeddingHandler', + retry_filter=_retry_on_appropriate_gcp_error, + **kwargs) + + def request( + self, + batch: Sequence[VertexAIMultiModalInput], + model: MultiModalEmbeddingModel, + inference_args: Optional[dict[str, Any]] = None): + embeddings = [] + # Max request size for multi-modal embedding models is 1, so issue one call per input + for input in batch: + image_content: Optional[Image] = None + video_content: Optional[Video] = None + text_content: Optional[str] = None + video_config: Optional[VideoSegmentConfig] = None + + if input.image: + image_content = input.image.image_content + if input.video: + video_content = input.video.video_content + video_config = input.video.config + if input.contextual_text: + text_content = input.contextual_text.content.text + + prediction = model.get_embeddings( + image=image_content, + video=video_content, + contextual_text=text_content, + dimension=self.dimension, + video_segment_config=video_config) + embeddings.append(prediction) + return embeddings + + def create_client(self) -> MultiModalEmbeddingModel: + model = MultiModalEmbeddingModel.from_pretrained(self.model_name) + return model + + def __repr__(self): + # The ModelHandler is internal and is not exposed to the user. + # Hence we need to override the __repr__ method to expose + # the name of the class. 
+ return 'VertexAIMultiModalEmbeddings' + + +def _multimodal_dict_input_fn( + image_column: Optional[str], + video_column: Optional[str], + text_column: Optional[str], + batch: Sequence[dict[str, Any]]) -> list[VertexAIMultiModalInput]: + multimodal_inputs: list[VertexAIMultiModalInput] = [] + for item in batch: + img: Optional[VertexImage] = None + vid: Optional[VertexVideo] = None + text: Optional[Chunk] = None + if image_column: + img = item[image_column] + if video_column: + vid = item[video_column] + if text_column: + text = item[text_column] + multimodal_inputs.append( + VertexAIMultiModalInput(image=img, video=vid, contextual_text=text)) + return multimodal_inputs + + +def _multimodal_dict_output_fn( + image_column: Optional[str], + video_column: Optional[str], + text_column: Optional[str], + batch: Sequence[dict[str, Any]], + embeddings: Sequence[MultiModalEmbeddingResponse]) -> list[dict[str, Any]]: + results = [] + for batch_idx, item in enumerate(batch): + mm_embedding = embeddings[batch_idx] + if image_column: + item[image_column].embedding = mm_embedding.image_embedding + if video_column: + item[video_column].embeddings = mm_embedding.video_embeddings + if text_column: + item[text_column].embedding = Embedding( + dense_embedding=mm_embedding.text_embedding) + results.append(item) + return results + + +def _create_multimodal_dict_adapter( + image_column: Optional[str], + video_column: Optional[str], + text_column: Optional[str] +) -> EmbeddingTypeAdapter[dict[str, Any], dict[str, Any]]: + return EmbeddingTypeAdapter[dict[str, Any], dict[str, Any]]( + input_fn=cast( + Callable[[Sequence[dict[str, Any]]], list[str]], + functools.partial( + _multimodal_dict_input_fn, + image_column, + video_column, + text_column)), + output_fn=cast( + Callable[[Sequence[dict[str, Any]], Sequence[Any]], + list[dict[str, Any]]], + functools.partial( + _multimodal_dict_output_fn, + image_column, + video_column, + text_column))) + + +class VertexAIMultiModalEmbeddings(EmbeddingsManager): + def __init__( + self, + model_name: str, + image_column: Optional[str] = None, + video_column: Optional[str] = None, + text_column: Optional[str] = None, + dimension: Optional[int] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[Credentials] = None, + **kwargs): + """ + Embedding Config for Vertex AI Multi-Modal Embedding models following + https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-multimodal-embeddings # pylint: disable=line-too-long + Multi-Modal Embeddings are generated for a batch of image, video, and + string groupings using the Vertex AI API. Embeddings are returned in a list + of MultiModalEmbeddingResponses, one per input in the batch. This + transform makes remote calls to the Vertex AI service and may incur costs + for use. + + Args: + model_name: The name of the Vertex AI Multi-Modal Embedding model. + image_column: The column containing image data to be embedded. This data + is expected to be formatted as VertexImage objects, containing a Vertex + Image object. + video_column: The column containing video data to be embedded. This data + is expected to be formatted as VertexVideo objects, containing a Vertex + Video object and a VideoSegmentConfig object. + text_column: The column containing text data to be embedded. This data is + expected to be formatted as Chunk objects, containing the string to be + embedded in the Chunk's content field. + dimension: The length of the embedding vector to generate. 
Must be one of + 128, 256, 512, or 1408. If not set, Vertex AI's default value is 1408. + If submitting video content, dimension *musst* be 1408. + project: The default GCP project for API calls. + location: The default location for API calls. + credentials: Custom credentials for API calls. + Defaults to environment credentials. + """ + self.model_name = model_name + self.project = project + self.location = location + self.credentials = credentials + self.kwargs = kwargs + if dimension is not None and dimension not in (128, 256, 512, 1408): + raise ValueError( + "dimension argument must be one of 128, 256, 512, or 1408") + self.dimension = dimension + if not image_column and not video_column and not text_column: + raise ValueError("at least one input column must be specified") + if video_column is not None and dimension != 1408: + raise ValueError( + "Vertex AI does not support custom dimensions for video input, want dimension = 1408, got ", + dimension) + self.type_adapter = _create_multimodal_dict_adapter( + image_column=image_column, + video_column=video_column, + text_column=text_column) + super().__init__(type_adapter=self.type_adapter, **kwargs) + + def get_model_handler(self) -> ModelHandler: + return _VertexAIMultiModalEmbeddingHandler( + model_name=self.model_name, + dimension=self.dimension, + project=self.project, + location=self.location, + credentials=self.credentials, + **self.kwargs) + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + return RunInference( + model_handler=_MultiModalEmbeddingHandler(self), + inference_args=self.inference_args) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py index 1a47f81b665..ba43ea32508 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py @@ -26,10 +26,18 @@ from apache_beam.ml.transforms import base from apache_beam.ml.transforms.base import MLTransform try: + from apache_beam.ml.rag.types import Chunk + from apache_beam.ml.rag.types import Content + from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAIMultiModalEmbeddings from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAITextEmbeddings from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAIImageEmbeddings + from apache_beam.ml.transforms.embeddings.vertex_ai import VertexImage + from apache_beam.ml.transforms.embeddings.vertex_ai import VertexVideo from vertexai.vision_models import Image + from vertexai.vision_models import Video + from vertexai.vision_models import VideoSegmentConfig except ImportError: + VertexAIMultiModalEmbeddings = None # type: ignore VertexAITextEmbeddings = None # type: ignore VertexAIImageEmbeddings = None # type: ignore @@ -286,5 +294,104 @@ class VertexAIImageEmbeddingsTest(unittest.TestCase): dimension=127) +image_feature_column: str = "img_feature" +text_feature_column: str = "txt_feature" +video_feature_column: str = "vid_feature" + + +def _make_text_chunk(input: str) -> Chunk: + return Chunk(content=Content(text=input)) + + +@unittest.skipIf( + VertexAIMultiModalEmbeddings is None, + 'Vertex AI Python SDK is not installed.') +class VertexAIMultiModalEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp( + prefix='_vertex_ai_multi_modal_test') + self.gcs_artifact_location = os.path.join( + 
'gs://temp-storage-for-perf-tests/vertex_ai_multi_modal', + uuid.uuid4().hex) + self.model_name = "multimodalembedding" + self.image_path = "gs://apache-beam-ml/testing/inputs/vertex_images/sunflowers/1008566138_6927679c8a.jpg" # pylint: disable=line-too-long + self.video_path = "gs://cloud-samples-data/vertex-ai-vision/highway_vehicles.mp4" # pylint: disable=line-too-long + self.video_segment_config = VideoSegmentConfig(end_offset_sec=1) + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + def test_vertex_ai_multimodal_embedding_img_and_text(self): + embedding_config = VertexAIMultiModalEmbeddings( + model_name=self.model_name, + image_column=image_feature_column, + text_column=text_feature_column, + dimension=128, + project="apache-beam-testing", + location="us-central1") + with beam.Pipeline() as pipeline: + transformed_pcoll = ( + pipeline | "CreateData" >> beam.Create([{ + image_feature_column: VertexImage( + image_content=Image(gcs_uri=self.image_path)), + text_feature_column: _make_text_chunk("an image of sunflowers"), + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def assert_element(element): + assert len(element[image_feature_column].embedding) == 128 + assert len( + element[text_feature_column].embedding.dense_embedding) == 128 + + _ = (transformed_pcoll | beam.Map(assert_element)) + + def test_vertex_ai_multimodal_embedding_video(self): + embedding_config = VertexAIMultiModalEmbeddings( + model_name=self.model_name, + video_column=video_feature_column, + dimension=1408, + project="apache-beam-testing", + location="us-central1") + with beam.Pipeline() as pipeline: + transformed_pcoll = ( + pipeline | "CreateData" >> beam.Create([{ + video_feature_column: VertexVideo( + video_content=Video(gcs_uri=self.video_path), + config=self.video_segment_config) + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def assert_element(element): + # Videos are returned in VideoEmbedding objects, must unroll + # for each segment. + for segment in element[video_feature_column].embeddings: + assert len(segment.embedding) == 1408 + + _ = (transformed_pcoll | beam.Map(assert_element)) + + def test_improper_dimension(self): + with self.assertRaises(ValueError): + _ = VertexAIMultiModalEmbeddings( + model_name=self.model_name, + image_column="fake_img_column", + dimension=127) + + def test_missing_columns(self): + with self.assertRaises(ValueError): + _ = VertexAIMultiModalEmbeddings( + model_name=self.model_name, dimension=128) + + def test_improper_video_dimension(self): + with self.assertRaises(ValueError): + _ = VertexAIMultiModalEmbeddings( + model_name=self.model_name, + video_column=video_feature_column, + dimension=128) + + if __name__ == '__main__': unittest.main()
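For reviewers trying the new transform: each input element is a dict keyed by the configured columns, where image values are VertexImage wrappers, video values are VertexVideo wrappers (each carrying its own VideoSegmentConfig), and text values are RAG Chunk objects. Below is a minimal construction sketch based on the tests above; the GCS URIs, project ID, and column names are illustrative placeholders, not values from this commit.

    # Sketch only: URIs, project, and column names are placeholders.
    from apache_beam.ml.rag.types import Chunk, Content
    from apache_beam.ml.transforms.embeddings.vertex_ai import (
        VertexAIMultiModalEmbeddings, VertexImage, VertexVideo)
    from vertexai.vision_models import Image, Video, VideoSegmentConfig

    # One element of the input PCollection: a dict keyed by column name.
    row = {
        "img_feature": VertexImage(
            image_content=Image(gcs_uri="gs://my-bucket/sunflowers.jpg")),
        "vid_feature": VertexVideo(
            video_content=Video(gcs_uri="gs://my-bucket/highway.mp4"),
            config=VideoSegmentConfig(end_offset_sec=1)),
        "txt_feature": Chunk(content=Content(text="an image of sunflowers")),
    }

    # dimension must be 128, 256, 512, or 1408; video input requires 1408.
    embedding_config = VertexAIMultiModalEmbeddings(
        model_name="multimodalembedding",
        image_column="img_feature",
        video_column="vid_feature",
        text_column="txt_feature",
        dimension=1408,
        project="my-gcp-project",  # placeholder
        location="us-central1")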
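Continuing the sketch above, the config plugs into MLTransform the same way the integration tests wire it up. The RunInference step makes remote calls to the Vertex AI multimodal embedding API (one request per element) and will incur costs; the artifact location is a placeholder.

    import apache_beam as beam
    from apache_beam.ml.transforms.base import MLTransform

    with beam.Pipeline() as pipeline:
        _ = (
            pipeline
            | "CreateData" >> beam.Create([row])
            | "MLTransform" >> MLTransform(
                write_artifact_location="gs://my-bucket/artifacts"  # placeholder
            ).with_transform(embedding_config)
            # Each output dict mirrors the input, with embeddings populated:
            #   element["img_feature"].embedding -> list[float]
            #   element["vid_feature"].embeddings -> list[VideoEmbedding], one per segment
            #   element["txt_feature"].embedding.dense_embedding -> list[float]
            | beam.Map(print))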