This is an automated email from the ASF dual-hosted git repository.
jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new e9c81de7453 Vertex AI Model Handler Private Endpoint Support (#27696)
e9c81de7453 is described below
commit e9c81de7453977f5162a41aeb8682aaef6714e55
Author: Jack McCluskey <[email protected]>
AuthorDate: Mon Jul 31 10:20:44 2023 -0400
Vertex AI Model Handler Private Endpoint Support (#27696)
* Vertex AI Model Handler Private Endpoint Support
* linting
* Trailing whitespace
* import order
* Extra context on experiments
* Trailing whitespace
---
.../inference/vertex_ai_image_classification.py | 19 ++++++-
.../ml/inference/vertex_ai_inference.py | 60 +++++++++++++++-------
.../ml/inference/vertex_ai_inference_test.py | 18 +++++--
3 files changed, 73 insertions(+), 24 deletions(-)
diff --git a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py
index f8f9e803f6f..73126569e98 100644
--- a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py
+++ b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py
@@ -74,11 +74,24 @@ def parse_known_args(argv):
type=str,
required=True,
help='GCP location for the Endpoint')
+ parser.add_argument(
+ '--endpoint_network',
+ dest='vpc_network',
+ type=str,
+ required=False,
+ help='GCP network the endpoint is peered to')
parser.add_argument(
'--experiment',
dest='experiment',
+ type=str,
required=False,
- help='GCP experiment to pass to init')
+ help='Vertex AI experiment label to apply to queries')
+ parser.add_argument(
+ '--private',
+ dest='private',
+ type=bool,
+ default=False,
+ help="True if the Vertex AI endpoint is a private endpoint")
return parser.parse_known_args(argv)
@@ -130,7 +143,9 @@ def run(
endpoint_id=known_args.endpoint,
project=known_args.project,
location=known_args.location,
- experiment=known_args.experiment)
+ experiment=known_args.experiment,
+ network=known_args.vpc_network,
+ private=known_args.private)
pipeline = test_pipeline
if not test_pipeline:
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
index a0e0d9d3f8f..3c3414923c8 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -23,7 +23,7 @@ from typing import Iterable
from typing import Optional
from typing import Sequence
-from google.api_core.exceptions import ClientError
+from google.api_core.exceptions import ServerError
from google.api_core.exceptions import TooManyRequests
from google.cloud import aiplatform
@@ -41,20 +41,21 @@ LOGGER = logging.getLogger("VertexAIModelHandlerJSON")
# pylint: disable=line-too-long
-def _retry_on_gcp_client_error(exception):
+def _retry_on_appropriate_gcp_error(exception):
"""
- Retry filter that returns True if a returned HTTP error code is 4xx. This is
- used to retry remote requests that fail, most notably 429 (TooManyRequests.)
- This is used for GCP-specific client errors.
+ Retry filter that returns True if a returned HTTP error code is 5xx or 429.
+ This is used to retry remote requests that fail, most notably 429
+ (TooManyRequests.)
Args:
exception: the returned exception encountered during the request/response
loop.
Returns:
- boolean indication whether or not the exception is a GCP ClientError.
+ boolean indication whether or not the exception is a Server Error (5xx) or
+ a TooManyRequests (429) error.
"""
- return isinstance(exception, ClientError)
+ return isinstance(exception, (TooManyRequests, ServerError))
class VertexAIModelHandlerJSON(ModelHandler[Any,
@@ -67,6 +68,7 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
location: str,
experiment: Optional[str] = None,
network: Optional[str] = None,
+ private: bool = False,
**kwargs):
"""Implementation of the ModelHandler interface for Vertex AI.
**NOTE:** This API and its implementation are under development and
@@ -76,21 +78,33 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
Vertex AI endpoint. In that way it functions more like a mid-pipeline
IO. Public Vertex AI endpoints have a maximum request size of 1.5 MB.
If you wish to make larger requests and use a private endpoint, provide
- the Compute Engine network you wish to use.
+ the Compute Engine network you wish to use and set `private=True`
Args:
endpoint_id: the numerical ID of the Vertex AI endpoint to query
project: the GCP project name where the endpoint is deployed
location: the GCP location where the endpoint is deployed
experiment: optional. experiment label to apply to the
- queries
+ queries. See
+
https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments
+ for more information.
network: optional. the full name of the Compute Engine
network the endpoint is deployed on; used for private
- endpoints only.
+ endpoints. The network or subnetwork Dataflow pipeline
+ option must be set and match this network for pipeline
+ execution.
Ex: "projects/12345/global/networks/myVPC"
+ private: optional. if the deployed Vertex AI endpoint is
+ private, set to true. Requires a network to be provided
+ as well.
"""
self._env_vars = kwargs.get('env_vars', {})
+
+ if private and network is None:
+ raise ValueError(
+ "A VPC network must be provided to use a private endpoint.")
+
# TODO: support the full list of options for aiplatform.init()
# See
https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform#google_cloud_aiplatform_init
aiplatform.init(
@@ -102,7 +116,9 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
# Check for liveness here but don't try to actually store the endpoint
# in the class yet
self.endpoint_name = endpoint_id
- _ = self._retrieve_endpoint(self.endpoint_name)
+ self.is_private = private
+
+ _ = self._retrieve_endpoint(self.endpoint_name, self.is_private)
# Configure AdaptiveThrottler and throttling metrics for client-side
# throttling behavior.
@@ -113,18 +129,27 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
self.throttler = AdaptiveThrottler(
window_ms=1, bucket_ms=1, overload_ratio=2)
- def _retrieve_endpoint(self, endpoint_id: str) -> aiplatform.Endpoint:
+ def _retrieve_endpoint(
+ self, endpoint_id: str, is_private: bool) -> aiplatform.Endpoint:
"""Retrieves an AI Platform endpoint and queries it for liveness/deployed
models.
Args:
endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
+ is_private: a boolean indicating if the Vertex AI endpoint is a private
+ endpoint
Returns:
An aiplatform.Endpoint object
Raises:
ValueError: if endpoint is inactive or has no models deployed to it.
"""
- endpoint = aiplatform.Endpoint(endpoint_name=endpoint_id)
+ if is_private:
+ endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
+ endpoint_name=endpoint_id)
+ LOGGER.debug("Treating endpoint %s as private", endpoint_id)
+ else:
+ endpoint = aiplatform.Endpoint(endpoint_name=endpoint_id)
+ LOGGER.debug("Treating endpoint %s as public", endpoint_id)
try:
mod_list = endpoint.list_models()
@@ -133,7 +158,7 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
"Failed to contact endpoint %s, got exception: %s", endpoint_id, e)
if len(mod_list) == 0:
- raise ValueError("Endpoint %s has no models deployed to it.")
+ raise ValueError("Endpoint %s has no models deployed to it.",
endpoint_id)
return endpoint
@@ -143,11 +168,11 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
"""
# Check to make sure the endpoint is still active since pipeline
# construction time
- ep = self._retrieve_endpoint(self.endpoint_name)
+ ep = self._retrieve_endpoint(self.endpoint_name, self.is_private)
return ep
@retry.with_exponential_backoff(
- num_retries=5, retry_filter=_retry_on_gcp_client_error)
+ num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
def get_request(
self,
batch: Sequence[Any],
@@ -170,9 +195,6 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
except TooManyRequests as e:
LOGGER.warning("request was limited by the service with code %i", e.code)
raise
- except ClientError as e:
- LOGGER.warning("request failed with error code %i", e.code)
- raise
except Exception as e:
LOGGER.error("unexpected exception raised as part of request, got %s", e)
raise
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
index 36f367e6d77..34c7927272d 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
@@ -19,7 +19,8 @@
import unittest
try:
- from apache_beam.ml.inference.vertex_ai_inference import
_retry_on_gcp_client_error
+ from apache_beam.ml.inference.vertex_ai_inference import
_retry_on_appropriate_gcp_error
+ from apache_beam.ml.inference.vertex_ai_inference import
VertexAIModelHandlerJSON
from google.api_core.exceptions import TooManyRequests
except ImportError:
raise unittest.SkipTest('VertexAI dependencies are not installed')
@@ -28,11 +29,22 @@ except ImportError:
class RetryOnClientErrorTest(unittest.TestCase):
def test_retry_on_client_error_positive(self):
e = TooManyRequests(message="fake service rate limiting")
- self.assertTrue(_retry_on_gcp_client_error(e))
+ self.assertTrue(_retry_on_appropriate_gcp_error(e))
def test_retry_on_client_error_negative(self):
e = ValueError()
- self.assertFalse(_retry_on_gcp_client_error(e))
+ self.assertFalse(_retry_on_appropriate_gcp_error(e))
+
+
+class ModelHandlerArgConditions(unittest.TestCase):
+ def test_exception_on_private_without_network(self):
+ self.assertRaises(
+ ValueError,
+ VertexAIModelHandlerJSON,
+ endpoint_id="1",
+ project="testproject",
+ location="us-central1",
+ private=True)
if __name__ == '__main__':