damccorm commented on code in PR #35677:
URL: https://github.com/apache/beam/pull/35677#discussion_r2330932442


##########
sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py:
##########
@@ -281,3 +297,222 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
     return RunInference(
         model_handler=_ImageEmbeddingHandler(self),
         inference_args=self.inference_args)
+
+
+@dataclass
+class VertexImage:
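+  """Pairs a vertexai Image with its computed embedding, if any."""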
+  image_content: Image
+  embedding: Optional[list[float]] = None
+
+
+@dataclass
+class VertexVideo:
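+  """A vertexai Video plus segment config and its per-segment embeddings."""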
+  video_content: Video
+  config: VideoSegmentConfig
+  embeddings: Optional[list[VideoEmbedding]] = None
+
+
+@dataclass
+class VertexAIMultiModalInput:
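+  """One request: any combination of image, video, and contextual text."""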
+  image: Optional[VertexImage] = None
+  video: Optional[VertexVideo] = None
+  contextual_text: Optional[Chunk] = None
+
+
+class _VertexAIMultiModalEmbeddingHandler(RemoteModelHandler):
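+  """Embeds multi-modal inputs via a remote Vertex AI embedding model."""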
+  def __init__(
+      self,
+      model_name: str,
+      dimension: Optional[int] = None,
+      project: Optional[str] = None,
+      location: Optional[str] = None,
+      credentials: Optional[Credentials] = None,
+      **kwargs):
+    vertexai.init(project=project, location=location, credentials=credentials)
+    self.model_name = model_name
+    self.dimension = dimension
+
+    super().__init__(
+        namespace='VertexAIMultiModalEmbeddingHandler',
+        retry_filter=_retry_on_appropriate_gcp_error,
+        **kwargs)
+
+  def request(
+      self,
+      batch: Sequence[VertexAIMultiModalInput],
+      model: MultiModalEmbeddingModel,
+      inference_args: Optional[dict[str, Any]] = None):
+    embeddings = []
+    # Max request size for multi-modal embedding models is 1
+    for mm_input in batch:
+      image_content: Optional[Image] = None
+      video_content: Optional[Video] = None
+      text_content: Optional[str] = None
+      video_config: Optional[VideoSegmentConfig] = None
+
+      if mm_input.image:
+        image_content = mm_input.image.image_content
+      if mm_input.video:
+        video_content = mm_input.video.video_content
+        video_config = mm_input.video.config
+      if mm_input.contextual_text:
+        text_content = mm_input.contextual_text.content.text
+
+      prediction = model.get_embeddings(
+          image=image_content,
+          video=video_content,
+          contextual_text=text_content,
+          dimension=self.dimension,
+          video_segment_config=video_config)
+      embeddings.append(prediction)
+    return embeddings
+
+  def create_client(self) -> MultiModalEmbeddingModel:
+    model = MultiModalEmbeddingModel.from_pretrained(self.model_name)
+    return model
+
+  def __repr__(self):
+    # The ModelHandler is internal and not exposed to the user,
+    # so we override __repr__ to surface a readable class name.
+    return 'VertexAIMultiModalEmbeddings'
+
+
+def _multimodal_dict_input_fn(
+    image_column: Optional[str],
+    video_column: Optional[str],
+    text_column: Optional[str],
+    batch: Sequence[dict[str, Any]]) -> list[VertexAIMultiModalInput]:
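+  """Maps dict rows to VertexAIMultiModalInput, reading each modality from
+  its configured column."""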
+  multimodal_inputs: list[VertexAIMultiModalInput] = []
+  for item in batch:
+    img: Optional[VertexImage] = None
+    vid: Optional[VertexVideo] = None
+    text: Optional[Chunk] = None
+    if image_column:
+      img = item[image_column]
+    if video_column:
+      vid = item[video_column]
+    if text_column:
+      text = item[text_column]
+    multimodal_inputs.append(
+        VertexAIMultiModalInput(image=img, video=vid, contextual_text=text))
+  return multimodal_inputs
+
+
+def _multimodal_dict_output_fn(
+    image_column: Optional[str],
+    video_column: Optional[str],
+    text_column: Optional[str],
+    batch: Sequence[dict[str, Any]],
+    embeddings: Sequence[MultiModalEmbeddingResponse]) -> list[dict[str, Any]]:

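For orientation, a minimal sketch (not part of the diff) of how these dataclasses could be fed through RunInference with the new handler. The file paths, project/location values, and direct handler construction are illustrative assumptions; the PR presumably wires the handler up through its own transform, and the loaders are the vertexai vision_models ones the diff already imports.

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from vertexai.vision_models import Image, Video, VideoSegmentConfig

# Illustrative inputs; paths and model name are placeholders.
image_input = VertexAIMultiModalInput(
    image=VertexImage(image_content=Image.load_from_file('cat.png')))
video_input = VertexAIMultiModalInput(
    video=VertexVideo(
        video_content=Video.load_from_file('cat.mp4'),
        config=VideoSegmentConfig(end_offset_sec=10)))

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([image_input, video_input])
      | RunInference(
          model_handler=_VertexAIMultiModalEmbeddingHandler(
              model_name='multimodalembedding@001',
              project='my-project',
              location='us-central1')))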
Review Comment:
   Yeah, you're right. Ok, thanks


