This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e9c81de7453 Vertex AI Model Handler Private Endpoint Support (#27696)
e9c81de7453 is described below

commit e9c81de7453977f5162a41aeb8682aaef6714e55
Author: Jack McCluskey <[email protected]>
AuthorDate: Mon Jul 31 10:20:44 2023 -0400

    Vertex AI Model Handler Private Endpoint Support (#27696)
    
    * Vertex AI Model Handler Private Endpoint Support
    
    * linting
    
    * Trailing whitespace
    
    * import order
    
    * Extra context on experiments
    
    * Trailing whitespace
---
 .../inference/vertex_ai_image_classification.py    | 19 ++++++-
 .../ml/inference/vertex_ai_inference.py            | 60 +++++++++++++++-------
 .../ml/inference/vertex_ai_inference_test.py       | 18 +++++--
 3 files changed, 73 insertions(+), 24 deletions(-)

diff --git a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py
index f8f9e803f6f..73126569e98 100644
--- a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py
+++ b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py
@@ -74,11 +74,24 @@ def parse_known_args(argv):
       type=str,
       required=True,
       help='GCP location for the Endpoint')
+  parser.add_argument(
+      '--endpoint_network',
+      dest='vpc_network',
+      type=str,
+      required=False,
+      help='GCP network the endpoint is peered to')
   parser.add_argument(
       '--experiment',
       dest='experiment',
+      type=str,
       required=False,
-      help='GCP experiment to pass to init')
+      help='Vertex AI experiment label to apply to queries')
+  parser.add_argument(
+      '--private',
+      dest='private',
+      type=bool,
+      default=False,
+      help="True if the Vertex AI endpoint is a private endpoint")
   return parser.parse_known_args(argv)
 
 
@@ -130,7 +143,9 @@ def run(
       endpoint_id=known_args.endpoint,
       project=known_args.project,
       location=known_args.location,
-      experiment=known_args.experiment)
+      experiment=known_args.experiment,
+      network=known_args.vpc_network,
+      private=known_args.private)
 
   pipeline = test_pipeline
   if not test_pipeline:
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
index a0e0d9d3f8f..3c3414923c8 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -23,7 +23,7 @@ from typing import Iterable
 from typing import Optional
 from typing import Sequence
 
-from google.api_core.exceptions import ClientError
+from google.api_core.exceptions import ServerError
 from google.api_core.exceptions import TooManyRequests
 from google.cloud import aiplatform
 
@@ -41,20 +41,21 @@ LOGGER = logging.getLogger("VertexAIModelHandlerJSON")
 # pylint: disable=line-too-long
 
 
-def _retry_on_gcp_client_error(exception):
+def _retry_on_appropriate_gcp_error(exception):
   """
-  Retry filter that returns True if a returned HTTP error code is 4xx. This is
-  used to retry remote requests that fail, most notably 429 (TooManyRequests.)
-  This is used for GCP-specific client errors.
+  Retry filter that returns True if a returned HTTP error code is 5xx or 429.
+  This is used to retry remote requests that fail, most notably 429
+  (TooManyRequests.)
 
   Args:
     exception: the returned exception encountered during the request/response
       loop.
 
   Returns:
-    boolean indication whether or not the exception is a GCP ClientError.
+    boolean indication whether or not the exception is a Server Error (5xx) or
+      a TooManyRequests (429) error.
   """
-  return isinstance(exception, ClientError)
+  return isinstance(exception, (TooManyRequests, ServerError))
 
 
 class VertexAIModelHandlerJSON(ModelHandler[Any,
@@ -67,6 +68,7 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
       location: str,
       experiment: Optional[str] = None,
       network: Optional[str] = None,
+      private: bool = False,
       **kwargs):
     """Implementation of the ModelHandler interface for Vertex AI.
     **NOTE:** This API and its implementation are under development and
@@ -76,21 +78,33 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
     Vertex AI endpoint. In that way it functions more like a mid-pipeline
     IO. Public Vertex AI endpoints have a maximum request size of 1.5 MB.
     If you wish to make larger requests and use a private endpoint, provide
-    the Compute Engine network you wish to use.
+    the Compute Engine network you wish to use and set `private=True`
 
     Args:
       endpoint_id: the numerical ID of the Vertex AI endpoint to query
       project: the GCP project name where the endpoint is deployed
       location: the GCP location where the endpoint is deployed
       experiment: optional. experiment label to apply to the
-        queries
+        queries. See
+        https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments
+        for more information.
       network: optional. the full name of the Compute Engine
         network the endpoint is deployed on; used for private
-        endpoints only.
+        endpoints. The network or subnetwork Dataflow pipeline
+        option must be set and match this network for pipeline
+        execution.
         Ex: "projects/12345/global/networks/myVPC"
+      private: optional. if the deployed Vertex AI endpoint is
+        private, set to true. Requires a network to be provided
+        as well.
     """
 
     self._env_vars = kwargs.get('env_vars', {})
+
+    if private and network is None:
+      raise ValueError(
+          "A VPC network must be provided to use a private endpoint.")
+
     # TODO: support the full list of options for aiplatform.init()
     # See https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform#google_cloud_aiplatform_init
     aiplatform.init(
@@ -102,7 +116,9 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
     # Check for liveness here but don't try to actually store the endpoint
     # in the class yet
     self.endpoint_name = endpoint_id
-    _ = self._retrieve_endpoint(self.endpoint_name)
+    self.is_private = private
+
+    _ = self._retrieve_endpoint(self.endpoint_name, self.is_private)
 
     # Configure AdaptiveThrottler and throttling metrics for client-side
     # throttling behavior.
@@ -113,18 +129,27 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
     self.throttler = AdaptiveThrottler(
         window_ms=1, bucket_ms=1, overload_ratio=2)
 
-  def _retrieve_endpoint(self, endpoint_id: str) -> aiplatform.Endpoint:
+  def _retrieve_endpoint(
+      self, endpoint_id: str, is_private: bool) -> aiplatform.Endpoint:
     """Retrieves an AI Platform endpoint and queries it for liveness/deployed
     models.
 
     Args:
       endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
+      is_private: a boolean indicating if the Vertex AI endpoint is a private
+        endpoint
     Returns:
       An aiplatform.Endpoint object
     Raises:
       ValueError: if endpoint is inactive or has no models deployed to it.
     """
-    endpoint = aiplatform.Endpoint(endpoint_name=endpoint_id)
+    if is_private:
+      endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
+          endpoint_name=endpoint_id)
+      LOGGER.debug("Treating endpoint %s as private", endpoint_id)
+    else:
+      endpoint = aiplatform.Endpoint(endpoint_name=endpoint_id)
+      LOGGER.debug("Treating endpoint %s as public", endpoint_id)
 
     try:
       mod_list = endpoint.list_models()
@@ -133,7 +158,7 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
           "Failed to contact endpoint %s, got exception: %s", endpoint_id, e)
 
     if len(mod_list) == 0:
-      raise ValueError("Endpoint %s has no models deployed to it.")
+      raise ValueError("Endpoint %s has no models deployed to it.", 
endpoint_id)
 
     return endpoint
 
@@ -143,11 +168,11 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
     """
     # Check to make sure the endpoint is still active since pipeline
     # construction time
-    ep = self._retrieve_endpoint(self.endpoint_name)
+    ep = self._retrieve_endpoint(self.endpoint_name, self.is_private)
     return ep
 
   @retry.with_exponential_backoff(
-      num_retries=5, retry_filter=_retry_on_gcp_client_error)
+      num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
   def get_request(
       self,
       batch: Sequence[Any],
@@ -170,9 +195,6 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
     except TooManyRequests as e:
       LOGGER.warning("request was limited by the service with code %i", e.code)
       raise
-    except ClientError as e:
-      LOGGER.warning("request failed with error code %i", e.code)
-      raise
     except Exception as e:
       LOGGER.error("unexpected exception raised as part of request, got %s", e)
       raise
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
index 36f367e6d77..34c7927272d 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
@@ -19,7 +19,8 @@
 import unittest
 
 try:
-  from apache_beam.ml.inference.vertex_ai_inference import 
_retry_on_gcp_client_error
+  from apache_beam.ml.inference.vertex_ai_inference import 
_retry_on_appropriate_gcp_error
+  from apache_beam.ml.inference.vertex_ai_inference import 
VertexAIModelHandlerJSON
   from google.api_core.exceptions import TooManyRequests
 except ImportError:
   raise unittest.SkipTest('VertexAI dependencies are not installed')
@@ -28,11 +29,22 @@ except ImportError:
 class RetryOnClientErrorTest(unittest.TestCase):
   def test_retry_on_client_error_positive(self):
     e = TooManyRequests(message="fake service rate limiting")
-    self.assertTrue(_retry_on_gcp_client_error(e))
+    self.assertTrue(_retry_on_appropriate_gcp_error(e))
 
   def test_retry_on_client_error_negative(self):
     e = ValueError()
-    self.assertFalse(_retry_on_gcp_client_error(e))
+    self.assertFalse(_retry_on_appropriate_gcp_error(e))
+
+
+class ModelHandlerArgConditions(unittest.TestCase):
+  def test_exception_on_private_without_network(self):
+    self.assertRaises(
+        ValueError,
+        VertexAIModelHandlerJSON,
+        endpoint_id="1",
+        project="testproject",
+        location="us-central1",
+        private=True)
 
 
 if __name__ == '__main__':

Reply via email to