This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch users/damccorm/flexAPI in repository https://gitbox.apache.org/repos/asf/beam.git
commit 896c642ae8143b95ace42e99c7df0b925d175ea2 Author: Danny Mccormick <[email protected]> AuthorDate: Tue Feb 3 13:33:30 2026 -0500 Reapply "Support Vertex Flex API in GeminiModelHandler (#36982)" (#37051) This reverts commit 585ad4181f955f6e349497806a22a2ab736013e3. --- .../apache_beam/ml/inference/gemini_inference.py | 20 +++++++++++++++-- .../ml/inference/gemini_inference_test.py | 25 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py index a79fbe8a555..c0edfa16761 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py @@ -25,6 +25,7 @@ from typing import Union from google import genai from google.genai import errors +from google.genai.types import HttpOptions from google.genai.types import Part from PIL.Image import Image @@ -108,6 +109,7 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult, api_key: Optional[str] = None, project: Optional[str] = None, location: Optional[str] = None, + use_vertex_flex_api: Optional[bool] = False, *, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, @@ -139,6 +141,7 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult, location: the GCP location to use for Vertex AI requests. Setting this parameter routes requests to Vertex AI. If this parameter is provided, project must also be provided and api_key should not be set. + use_vertex_flex_api: if true, use the Vertex Flex API. min_batch_size: optional. the minimum batch size to use when batching inputs. max_batch_size: optional. 
the maximum batch size to use when batching @@ -178,6 +181,8 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult, self.location = location self.use_vertex = True + self.use_vertex_flex_api = use_vertex_flex_api + super().__init__( namespace='GeminiModelHandler', retry_filter=_retry_on_appropriate_service_error, @@ -192,8 +197,19 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult, provided when the GeminiModelHandler class is instantiated. """ if self.use_vertex: - return genai.Client( - vertexai=True, project=self.project, location=self.location) + if self.use_vertex_flex_api: + return genai.Client( + vertexai=True, + project=self.project, + location=self.location, + http_options=HttpOptions( + api_version="v1", + headers={"X-Vertex-AI-LLM-Request-Type": "flex"}, + # Set timeout in the unit of millisecond. + timeout=600000)) + else: + return genai.Client( + vertexai=True, project=self.project, location=self.location) return genai.Client(api_key=self.api_key) def request( diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py index cb73c7de13f..012287e98f3 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py @@ -17,6 +17,7 @@ # pytype: skip-file import unittest +from unittest import mock try: from google.genai import errors @@ -81,5 +82,29 @@ class ModelHandlerArgConditions(unittest.TestCase): ) [email protected]('apache_beam.ml.inference.gemini_inference.genai.Client') [email protected]('apache_beam.ml.inference.gemini_inference.HttpOptions') +class TestGeminiModelHandler(unittest.TestCase): + def test_create_client_with_flex_api( + self, mock_http_options, mock_genai_client): + handler = GeminiModelHandler( + model_name="gemini-pro", + request_fn=generate_from_string, + project="test-project", + location="us-central1", + use_vertex_flex_api=True) + 
handler.create_client() + mock_http_options.assert_called_with( + api_version="v1", + headers={"X-Vertex-AI-LLM-Request-Type": "flex"}, + timeout=600000, + ) + mock_genai_client.assert_called_with( + vertexai=True, + project="test-project", + location="us-central1", + http_options=mock_http_options.return_value) + + if __name__ == '__main__': unittest.main()
