This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/flexAPI
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 896c642ae8143b95ace42e99c7df0b925d175ea2
Author: Danny Mccormick <[email protected]>
AuthorDate: Tue Feb 3 13:33:30 2026 -0500

    Reapply "Support Vertex Flex API in GeminiModelHandler (#36982)" (#37051)
    
    This reverts commit 585ad4181f955f6e349497806a22a2ab736013e3.
---
 .../apache_beam/ml/inference/gemini_inference.py   | 20 +++++++++++++++--
 .../ml/inference/gemini_inference_test.py          | 25 ++++++++++++++++++++++
 2 files changed, 43 insertions(+), 2 deletions(-)
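
In short: when the new use_vertex_flex_api flag is set, create_client()
builds the genai.Client with an explicit HttpOptions that tags each request
as a flex request. A standalone sketch of the resulting client construction
(the project and location values are placeholders; the option values are
copied from the diff below):

    from google import genai
    from google.genai.types import HttpOptions

    client = genai.Client(
        vertexai=True,
        project="my-project",        # placeholder
        location="us-central1",      # placeholder
        http_options=HttpOptions(
            api_version="v1",
            headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
            # Timeout is in milliseconds (here, 10 minutes).
            timeout=600000))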

diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py
index a79fbe8a555..c0edfa16761 100644
--- a/sdks/python/apache_beam/ml/inference/gemini_inference.py
+++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py
@@ -25,6 +25,7 @@ from typing import Union
 
 from google import genai
 from google.genai import errors
+from google.genai.types import HttpOptions
 from google.genai.types import Part
 from PIL.Image import Image
 
@@ -108,6 +109,7 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult,
       api_key: Optional[str] = None,
       project: Optional[str] = None,
       location: Optional[str] = None,
+      use_vertex_flex_api: Optional[bool] = False,
       *,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
@@ -139,6 +141,7 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult,
       location: the GCP location to use for Vertex AI requests. Setting this
         parameter routes requests to Vertex AI. If this parameter is provided,
         project must also be provided and api_key should not be set.
+      use_vertex_flex_api: if true, use the Vertex Flex API.
       min_batch_size: optional. the minimum batch size to use when batching
         inputs.
       max_batch_size: optional. the maximum batch size to use when batching
@@ -178,6 +181,8 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult,
       self.location = location
       self.use_vertex = True
 
+    self.use_vertex_flex_api = use_vertex_flex_api
+
     super().__init__(
         namespace='GeminiModelHandler',
         retry_filter=_retry_on_appropriate_service_error,
@@ -192,8 +197,19 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult,
     provided when the GeminiModelHandler class is instantiated.
     """
     if self.use_vertex:
-      return genai.Client(
-          vertexai=True, project=self.project, location=self.location)
+      if self.use_vertex_flex_api:
+        return genai.Client(
+            vertexai=True,
+            project=self.project,
+            location=self.location,
+            http_options=HttpOptions(
+                api_version="v1",
+                headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
+                # Timeout is specified in milliseconds (10 minutes).
+                timeout=600000))
+      else:
+        return genai.Client(
+            vertexai=True, project=self.project, location=self.location)
     return genai.Client(api_key=self.api_key)
 
   def request(
diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py
index cb73c7de13f..012287e98f3 100644
--- a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py
@@ -17,6 +17,7 @@
 # pytype: skip-file
 
 import unittest
+from unittest import mock
 
 try:
   from google.genai import errors
@@ -81,5 +82,29 @@ class ModelHandlerArgConditions(unittest.TestCase):
     )
 
 
[email protected]('apache_beam.ml.inference.gemini_inference.genai.Client')
[email protected]('apache_beam.ml.inference.gemini_inference.HttpOptions')
+class TestGeminiModelHandler(unittest.TestCase):
+  def test_create_client_with_flex_api(
+      self, mock_http_options, mock_genai_client):
+    handler = GeminiModelHandler(
+        model_name="gemini-pro",
+        request_fn=generate_from_string,
+        project="test-project",
+        location="us-central1",
+        use_vertex_flex_api=True)
+    handler.create_client()
+    mock_http_options.assert_called_with(
+        api_version="v1",
+        headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
+        timeout=600000,
+    )
+    mock_genai_client.assert_called_with(
+        vertexai=True,
+        project="test-project",
+        location="us-central1",
+        http_options=mock_http_options.return_value)
+
+
 if __name__ == '__main__':
   unittest.main()
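
For reference, a minimal end-to-end sketch of the new flag in a pipeline
(hypothetical; the project/location values are placeholders, and
generate_from_string is the same request function the test above uses):

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.gemini_inference import GeminiModelHandler
    from apache_beam.ml.inference.gemini_inference import generate_from_string

    # create_client() attaches the X-Vertex-AI-LLM-Request-Type: flex header
    # when use_vertex_flex_api=True (see the first hunk above).
    handler = GeminiModelHandler(
        model_name="gemini-pro",
        request_fn=generate_from_string,
        project="my-project",        # placeholder
        location="us-central1",      # placeholder
        use_vertex_flex_api=True)

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create(["What is Apache Beam?"])
          | RunInference(handler))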
