This is an automated email from the ASF dual-hosted git repository.
jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 697d2992aeb feat: Add support for custom prediction routes in Vertex AI inference (#37155)
697d2992aeb is described below
commit 697d2992aeb975b871107a62e2b196a08d541a04
Author: liferoad <[email protected]>
AuthorDate: Tue Jan 13 10:09:28 2026 -0500
feat: Add support for custom prediction routes in Vertex AI inference (#37155)
* feat: Add support for custom prediction routes in Vertex AI inference
using the `invoke_route` parameter and custom response parsing.
* lint
* lint 2
* fix: ensure invoke response is bytes and add type hint for request_body
* test: mock `aiplatform.init` in VertexAI inference tests to prevent
global state pollution.
* lint
* added the IT
* added license
* lint
* updated endpoint
* trigger postcommit
* lint
* lint
* lint
* fixed the response
---
.github/trigger_files/beam_PostCommit_Python.json | 5 +-
.../vertex_ai_custom_prediction/Dockerfile | 21 +++++
.../vertex_ai_custom_prediction/README.md | 103 +++++++++++++++++++++
.../vertex_ai_custom_prediction/echo_server.py | 43 +++++++++
.../ml/inference/vertex_ai_inference.py | 72 +++++++++++++-
.../ml/inference/vertex_ai_inference_it_test.py | 47 ++++++++++
.../ml/inference/vertex_ai_inference_test.py | 70 ++++++++++++++
sdks/python/apache_beam/yaml/yaml_ml.py | 9 ++
8 files changed, 364 insertions(+), 6 deletions(-)
diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json
index 47e479f18a9..e43868bf4f2 100644
--- a/.github/trigger_files/beam_PostCommit_Python.json
+++ b/.github/trigger_files/beam_PostCommit_Python.json
@@ -1,6 +1,5 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to
run.",
"pr": "36271",
- "modification": 36
-}
-
+ "modification": 37
+}
\ No newline at end of file
diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/Dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/Dockerfile
new file mode 100644
index 00000000000..a62b9edd406
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/Dockerfile
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM python:3.10-slim
+WORKDIR /app
+RUN pip install flask gunicorn
+COPY echo_server.py main.py
+CMD ["gunicorn", "--bind", "0.0.0.0:8080", "main:app"]
diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/README.md b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/README.md
new file mode 100644
index 00000000000..834a27be7f7
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/README.md
@@ -0,0 +1,103 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Vertex AI Custom Prediction Route Test Setup
+
+To run the `test_vertex_ai_custom_prediction_route` in [vertex_ai_inference_it_test.py](../../vertex_ai_inference_it_test.py), you need a dedicated Vertex AI endpoint with an invoke-enabled model deployed.
+
+## Resource Setup Steps
+
+Run these commands in the `apache-beam-testing` project (or your own test project).
+
+### 1. Build and Push Container
+
+From this directory:
+
+```bash
+# on Linux
+export PROJECT_ID="apache-beam-testing" # Or your project
+export IMAGE_URI="gcr.io/${PROJECT_ID}/beam-ml/beam-invoke-echo-model:latest"
+
+docker build -t ${IMAGE_URI} .
+docker push ${IMAGE_URI}
+```
+
+### 2. Upload Model and Deploy Endpoint
+
+Use the Python SDK to deploy (easier than gcloud for specific invocation flags).
+
+```python
+from google.cloud import aiplatform
+
+PROJECT_ID = "apache-beam-testing"
+REGION = "us-central1"
+IMAGE_URI = f"gcr.io/{PROJECT_ID}/beam-ml/beam-invoke-echo-model:latest"
+
+aiplatform.init(project=PROJECT_ID, location=REGION)
+
+# 1. Upload Model with invoke route enabled
+model = aiplatform.Model.upload(
+ display_name="beam-invoke-echo-model",
+ serving_container_image_uri=IMAGE_URI,
+ serving_container_invoke_route_prefix="/*", # <--- Critical for custom routes
+ serving_container_health_route="/health",
+ sync=True,
+)
+
+# 2. Create Dedicated Endpoint (required for invoke)
+endpoint = aiplatform.Endpoint.create(
+ display_name="beam-invoke-test-endpoint",
+ dedicated_endpoint_enabled=True,
+ sync=True,
+)
+
+# 3. Deploy Model
+# NOTE: Set min_replica_count=0 to save costs when not testing
+endpoint.deploy(
+ model=model,
+ traffic_percentage=100,
+ machine_type="n1-standard-2",
+ min_replica_count=0,
+ max_replica_count=1,
+ sync=True,
+)
+
+print(f"Deployment Complete!")
+print(f"Endpoint ID: {endpoint.name}")
+```
+
+### 3. Update Test Configuration
+
+1. Copy the **Endpoint ID** printed above (e.g., `1234567890`).
+2. Update `_INVOKE_ENDPOINT_ID` in `apache_beam/ml/inference/vertex_ai_inference_it_test.py`.
+
+## Cleanup
+
+To avoid costs, undeploy and delete resources when finished:
+
+```bash
+# Undeploy model from endpoint
+gcloud ai endpoints undeploy-model <ENDPOINT_ID> --deployed-model-id <DEPLOYED_MODEL_ID> --region=us-central1
+
+# Delete endpoint
+gcloud ai endpoints delete <ENDPOINT_ID> --region=us-central1
+
+# Delete model
+gcloud ai models delete <MODEL_ID> --region=us-central1
+```
diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/echo_server.py b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/echo_server.py
new file mode 100644
index 00000000000..6e48e62a2a7
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/echo_server.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from flask import Flask
+from flask import jsonify
+from flask import request
+
+app = Flask(__name__)
+
+
[email protected]('/predict', methods=['POST'])
+def predict():
+ data = request.get_json()
+ # Echo back the instances
+ return jsonify({
+ "predictions": [{
+ "echo": inst
+ } for inst in data.get('instances', [])],
+ "deployedModelId": "echo-model"
+ })
+
+
[email protected]('/health', methods=['GET'])
+def health():
+ return 'OK', 200
+
+
+if __name__ == '__main__':
+ app.run(host='0.0.0.0', port=8080)
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
index 9858b59039c..cd3d0beb593 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import json
import logging
from collections.abc import Iterable
from collections.abc import Mapping
@@ -63,6 +64,7 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
experiment: Optional[str] = None,
network: Optional[str] = None,
private: bool = False,
+ invoke_route: Optional[str] = None,
*,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
@@ -95,6 +97,12 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
private: optional. if the deployed Vertex AI endpoint is
private, set to true. Requires a network to be provided
as well.
+ invoke_route: optional. the custom route path to use when invoking
+ endpoints with arbitrary prediction routes. When specified, uses
+ `Endpoint.invoke()` instead of `Endpoint.predict()`. The route
+ should start with a forward slash, e.g., "/predict/v1".
+ See https://cloud.google.com/vertex-ai/docs/predictions/use-arbitrary-custom-routes
+ for more information.
min_batch_size: optional. the minimum batch size to use when batching
inputs.
max_batch_size: optional. the maximum batch size to use when batching
@@ -104,6 +112,7 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
"""
self._batching_kwargs = {}
self._env_vars = kwargs.get('env_vars', {})
+ self._invoke_route = invoke_route
if min_batch_size is not None:
self._batching_kwargs["min_batch_size"] = min_batch_size
if max_batch_size is not None:
@@ -203,9 +212,66 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
Returns:
An iterable of Predictions.
"""
- prediction = model.predict(instances=list(batch), parameters=inference_args)
- return utils._convert_to_result(
- batch, prediction.predictions, prediction.deployed_model_id)
+ if self._invoke_route:
+ # Use invoke() for endpoints with custom prediction routes
+ request_body: dict[str, Any] = {"instances": list(batch)}
+ if inference_args:
+ request_body["parameters"] = inference_args
+ response = model.invoke(
+ request_path=self._invoke_route,
+ body=json.dumps(request_body).encode("utf-8"),
+ headers={"Content-Type": "application/json"})
+ if hasattr(response, "content"):
+ return self._parse_invoke_response(batch, response.content)
+ return self._parse_invoke_response(batch, bytes(response))
+ else:
+ prediction = model.predict(
+ instances=list(batch), parameters=inference_args)
+ return utils._convert_to_result(
+ batch, prediction.predictions, prediction.deployed_model_id)
+
+ def _parse_invoke_response(self, batch: Sequence[Any],
+ response: bytes) -> Iterable[PredictionResult]:
+ """Parses the response from Endpoint.invoke() into PredictionResults.
+
+ Args:
+ batch: the original batch of inputs.
+ response: the raw bytes response from invoke().
+
+ Returns:
+ An iterable of PredictionResults.
+ """
+ try:
+ response_json = json.loads(response.decode("utf-8"))
+ except (json.JSONDecodeError, UnicodeDecodeError) as e:
+ LOGGER.warning(
+ "Failed to decode invoke response as JSON, returning raw bytes: %s",
+ e)
+ # Return raw response for each batch item
+ return [
+ PredictionResult(example=example, inference=response)
+ for example in batch
+ ]
+
+ # Handle standard Vertex AI response format with "predictions" key
+ if isinstance(response_json, dict) and "predictions" in response_json:
+ predictions = response_json["predictions"]
+ model_id = response_json.get("deployedModelId")
+ return utils._convert_to_result(batch, predictions, model_id)
+
+ # Handle response as a list of predictions (one per input)
+ if isinstance(response_json, list) and len(response_json) == len(batch):
+ return utils._convert_to_result(batch, response_json, None)
+
+ # Handle single prediction response
+ if len(batch) == 1:
+ return [PredictionResult(example=batch[0], inference=response_json)]
+
+ # Fallback: return the full response for each batch item
+ return [
+ PredictionResult(example=example, inference=response_json)
+ for example in batch
+ ]
def batch_elements_kwargs(self) -> Mapping[str, Any]:
return self._batching_kwargs
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py
index c6d62eb3e3e..11643992c39 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py
@@ -23,12 +23,15 @@ import uuid
import pytest
+import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import RunInference
from apache_beam.testing.test_pipeline import TestPipeline
# pylint: disable=ungrouped-imports
try:
from apache_beam.examples.inference import vertex_ai_image_classification
+ from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
except ImportError as e:
raise unittest.SkipTest(
"Vertex AI model handler dependencies are not installed")
@@ -42,6 +45,13 @@ _ENDPOINT_NETWORK = "projects/844138762903/global/networks/beam-test-vpc"
# pylint: disable=line-too-long
_SUBNETWORK = "https://www.googleapis.com/compute/v1/projects/apache-beam-testing/regions/us-central1/subnetworks/beam-test-vpc"
+# Constants for custom prediction routes (invoke) test
+# Follow beam/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/README.md
+# to get endpoint ID after deploying invoke-enabled model
+_INVOKE_ENDPOINT_ID = "6890840581900075008"
+_INVOKE_ROUTE = "/predict"
+_INVOKE_OUTPUT_DIR = "gs://apache-beam-ml/testing/outputs/vertex_invoke"
+
class VertexAIInference(unittest.TestCase):
@pytest.mark.vertex_ai_postcommit
@@ -63,6 +73,43 @@ class VertexAIInference(unittest.TestCase):
test_pipeline.get_full_options_as_args(**extra_opts))
self.assertEqual(FileSystems().exists(output_file), True)
+ @pytest.mark.vertex_ai_postcommit
+ @unittest.skipIf(
+ not _INVOKE_ENDPOINT_ID,
+ "Invoke endpoint not configured. Set _INVOKE_ENDPOINT_ID.")
+ def test_vertex_ai_custom_prediction_route(self):
+ """Test custom prediction routes using invoke_route parameter.
+
+ This test verifies that VertexAIModelHandlerJSON correctly uses
+ Endpoint.invoke() instead of Endpoint.predict() when invoke_route
+ is specified, enabling custom prediction routes.
+ """
+ output_file = '/'.join(
+ [_INVOKE_OUTPUT_DIR, str(uuid.uuid4()), 'output.txt'])
+
+ test_pipeline = TestPipeline(is_integration_test=True)
+
+ model_handler = VertexAIModelHandlerJSON(
+ endpoint_id=_INVOKE_ENDPOINT_ID,
+ project=_ENDPOINT_PROJECT,
+ location=_ENDPOINT_REGION,
+ invoke_route=_INVOKE_ROUTE)
+
+ # Test inputs - simple data to echo back
+ test_inputs = [{"value": 1}, {"value": 2}, {"value": 3}]
+
+ with test_pipeline as p:
+ results = (
+ p
+ | "CreateInputs" >> beam.Create(test_inputs)
+ | "RunInference" >> RunInference(model_handler)
+ | "ExtractResults" >>
+ beam.Map(lambda result: f"{result.example}:{result.inference}"))
+ _ = results | "WriteOutput" >> beam.io.WriteToText(
+ output_file, shard_name_template='')
+
+ self.assertTrue(FileSystems().exists(output_file))
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.DEBUG)
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
index 91a3b82cf76..8aa638ebe7c 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
@@ -48,5 +48,75 @@ class ModelHandlerArgConditions(unittest.TestCase):
private=True)
+class ParseInvokeResponseTest(unittest.TestCase):
+ """Tests for _parse_invoke_response method."""
+ def _create_handler_with_invoke_route(self, invoke_route="/test"):
+ """Creates a mock handler with invoke_route for testing."""
+ import unittest.mock as mock
+
+ # Mock both _retrieve_endpoint and aiplatform.init to prevent test
+ # pollution of global aiplatform state
+ with mock.patch.object(VertexAIModelHandlerJSON,
+ '_retrieve_endpoint',
+ return_value=None):
+ with mock.patch('google.cloud.aiplatform.init'):
+ handler = VertexAIModelHandlerJSON(
+ endpoint_id="1",
+ project="testproject",
+ location="us-central1",
+ invoke_route=invoke_route)
+ return handler
+
+ def test_parse_invoke_response_with_predictions_key(self):
+ """Test parsing response with standard 'predictions' key."""
+ handler = self._create_handler_with_invoke_route()
+ batch = [{"input": "test1"}, {"input": "test2"}]
+ response = (
+ b'{"predictions": ["result1", "result2"], '
+ b'"deployedModelId": "model123"}')
+
+ results = list(handler._parse_invoke_response(batch, response))
+
+ self.assertEqual(len(results), 2)
+ self.assertEqual(results[0].example, {"input": "test1"})
+ self.assertEqual(results[0].inference, "result1")
+ self.assertEqual(results[1].example, {"input": "test2"})
+ self.assertEqual(results[1].inference, "result2")
+
+ def test_parse_invoke_response_list_format(self):
+ """Test parsing response as a list of predictions."""
+ handler = self._create_handler_with_invoke_route()
+ batch = [{"input": "test1"}, {"input": "test2"}]
+ response = b'["result1", "result2"]'
+
+ results = list(handler._parse_invoke_response(batch, response))
+
+ self.assertEqual(len(results), 2)
+ self.assertEqual(results[0].inference, "result1")
+ self.assertEqual(results[1].inference, "result2")
+
+ def test_parse_invoke_response_single_prediction(self):
+ """Test parsing response with a single prediction."""
+ handler = self._create_handler_with_invoke_route()
+ batch = [{"input": "test1"}]
+ response = b'{"output": "single result"}'
+
+ results = list(handler._parse_invoke_response(batch, response))
+
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0].inference, {"output": "single result"})
+
+ def test_parse_invoke_response_non_json(self):
+ """Test handling non-JSON response."""
+ handler = self._create_handler_with_invoke_route()
+ batch = [{"input": "test1"}]
+ response = b'not valid json'
+
+ results = list(handler._parse_invoke_response(batch, response))
+
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0].inference, response)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py
index e5a88f54eba..4e750b79ce3 100644
--- a/sdks/python/apache_beam/yaml/yaml_ml.py
+++ b/sdks/python/apache_beam/yaml/yaml_ml.py
@@ -168,6 +168,7 @@ class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
experiment: Optional[str] = None,
network: Optional[str] = None,
private: bool = False,
+ invoke_route: Optional[str] = None,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
max_batch_duration_secs: Optional[int] = None):
@@ -236,6 +237,13 @@ class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
private: If the deployed Vertex AI endpoint is
private, set to true. Requires a network to be provided
as well.
+ invoke_route: The custom route path to use when invoking
+ endpoints with arbitrary prediction routes. When specified, uses
+ `Endpoint.invoke()` instead of `Endpoint.predict()`. The route
+ should start with a forward slash, e.g., "/predict/v1".
+ See
+ https://cloud.google.com/vertex-ai/docs/predictions/use-arbitrary-custom-routes
+ for more information.
min_batch_size: The minimum batch size to use when batching
inputs.
max_batch_size: The maximum batch size to use when batching
@@ -258,6 +266,7 @@ class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
experiment=experiment,
network=network,
private=private,
+ invoke_route=invoke_route,
min_batch_size=min_batch_size,
max_batch_size=max_batch_size,
max_batch_duration_secs=max_batch_duration_secs)
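
For reference, a minimal usage sketch of the new `invoke_route` option, mirroring the integration test added above. The endpoint ID, project, and output path below are placeholders, not values from this change:

```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON

# Placeholder endpoint/project values; use a dedicated endpoint deployed with
# an invoke-enabled container (see the README added in this commit).
model_handler = VertexAIModelHandlerJSON(
    endpoint_id="1234567890",
    project="my-project",
    location="us-central1",
    invoke_route="/predict")  # requests go through Endpoint.invoke()

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{"value": 1}, {"value": 2}])
      | RunInference(model_handler)
      | beam.Map(lambda r: f"{r.example}:{r.inference}")
      | beam.io.WriteToText("gs://my-bucket/vertex_invoke/output.txt"))
```

When `invoke_route` is not set, the handler continues to call `Endpoint.predict()`, so existing pipelines are unaffected.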