jrmccluskey commented on code in PR #34252:
URL: https://github.com/apache/beam/pull/34252#discussion_r2080257920


##########
sdks/python/apache_beam/ml/inference/test_triton_model_handler.py:
##########
@@ -0,0 +1,409 @@
+import unittest
+from unittest.mock import patch, MagicMock, ANY, call
+import json
+from google.cloud import aiplatform
+from apache_beam.ml.inference.vertex_ai_inference import 
VertexAITritonModelHandler
+from apache_beam.ml.inference import utils 
+from apache_beam.ml.inference.base import PredictionResult
+import numpy as np
+import base64 

Review Comment:
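   The import order is wrong. The linting/formatting checks should report the 
correct order, but for reference: imports should be grouped into at least two 
distinct blocks separated by a blank line, with standard-library imports first 
and third-party imports second, and each block should be in alphabetical order. 
For example, `base64` and `json` belong in the first block, while `numpy` and 
`google.cloud.aiplatform` belong in the second; Beam's own `apache_beam` imports 
go in their own block after those.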



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -14,20 +14,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import json
 import logging
 from collections.abc import Iterable
 from collections.abc import Mapping
 from collections.abc import Sequence
 from typing import Any
 from typing import Optional
+from typing import Dict

Review Comment:
   Use the built-in `dict` type for type hints (e.g. `dict[str, Any]`) instead of 
`typing.Dict`; built-in generics have been supported since Python 3.9 (PEP 585).



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -219,109 +217,81 @@ def validate_inference_args(self, inference_args: 
Optional[dict[str, Any]]):
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
 
-InputT = Any
-class VertexAITritonModelHandler(ModelHandler[InputT,
-                                            PredictionResult,
-                                            aiplatform.Endpoint]):
-    """
-    A custom model handler for Vertex AI endpoints hosting Triton Inference 
Servers.
-    It constructs a payload that Triton expects and calls the raw predict 
endpoint.
-    """
-    
-    def __init__(self, 
-                 project_id: str,
-                 location: str,
-                 endpoint_name: str,
-                 input_name: str,
-                 input_datatype: str, 
-                 input_tensor_shape: List[int], 
-                 output_tensor_name: Optional[str] = None,
-                 network: Optional[str] = None,
-                 private: bool = False,
-                 experiment: Optional[str] = None,
-                 *, 
-                 min_batch_size: Optional[int] = None,
-                 max_batch_size: Optional[int] = None,
-                 max_batch_duration_secs: Optional[int] = None,
-                 **kwargs
-                 ):
-        self.project_id = project_id
-        self.location = location 
-        self.endpoint_name = endpoint_name
-        self.input_name = input_name
-        self.input_datatype = input_datatype
-        self.input_tensor_shape = input_tensor_shape 
-        self.output_tensor_name = output_tensor_name
-        self.network = network
-        self.private = private
-        self.experiment = experiment
+class VertexAITritonModelHandler(RemoteModelHandler[Any, PredictionResult, 
aiplatform.Endpoint]):
+    def __init__(
+        self,
+        endpoint_id: str,
+        project: str,
+        location: str,
+        input_name: str,
+        datatype: str,
+        experiment: Optional[str] = None,
+        network: Optional[str] = None,
+        private: bool = False,
+        min_batch_size: Optional[int] = None,
+        max_batch_size: Optional[int] = None,
+        max_batch_duration_secs: Optional[int] = None,
+        **kwargs
+    ):
+        """
+        Initialize the handler for Triton Inference Server on Vertex AI.
 
+        Args:
+            endpoint_id: Vertex AI endpoint ID.
+            project: GCP project name.
+            location: GCP location.
+            input_name: Name of the input tensor for Triton.
+            datatype: Data type of the input (e.g., 'FP32', 'BYTES').
+            experiment: Optional experiment label.
+            network: Optional VPC network for private endpoints.
+            private: Whether the endpoint is private.
+            min_batch_size: Minimum batch size for batching.
+            max_batch_size: Maximum batch size for batching.
+            max_batch_duration_secs: Max buffering time for batches.
+            **kwargs: Additional arguments passed to the base class.
+        """
+        self.input_name = input_name
+        self.datatype = datatype
         self._batching_kwargs = {}
-        self._env_vars = kwargs.get('env_vars', {}) 
         if min_batch_size is not None:
             self._batching_kwargs["min_batch_size"] = min_batch_size
         if max_batch_size is not None:
             self._batching_kwargs["max_batch_size"] = max_batch_size
         if max_batch_duration_secs is not None:
             self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
 
-        if self.private and self.network is None:
-            raise ValueError(
-                "A VPC network must be provided ('network' arg) to use a 
private endpoint.")
-        try:
-             aiplatform.init(
-                project=self.project_id,
-                location=self.location,
-                experiment=self.experiment,
-                network=self.network)
-             LOGGER.info(
-                 "Initialized aiplatform client for project=%s, location=%s",
-                 self.project_id, self.location)
-        except Exception as e:
-            LOGGER.error("Failed to initialize aiplatform client: %s", e, 
exc_info=True)
-            raise RuntimeError(f"Could not initialize Google Cloud AI Platform 
client: {e}") from e
-        
-        try:
-            LOGGER.info("Performing initial liveness check for endpoint 
%s...", self.endpoint_name)
-            full_endpoint_id_for_check = self.endpoint_name
-            if not full_endpoint_id_for_check.startswith('projects/'):
-                full_endpoint_id_for_check = 
f"projects/{self.project_id}/locations/{self.location}/endpoints/{self.endpoint_name}"
-            _ = self._retrieve_endpoint(
-                endpoint_id=full_endpoint_id_for_check, 
-                project=self.project_id, 
-                location=self.location,
-                is_private=self.private)
-            LOGGER.info("Initial liveness check successful.")
-        except ValueError as e:
-            LOGGER.warning("Initial liveness check for endpoint %s failed: %s. 
"
-                           "Will retry in load_model.", self.endpoint_name, e)
-        except Exception as e:
-             LOGGER.warning("Unexpected error during initial liveness check 
for endpoint %s: %s. "
-                           "Will retry in load_model.", self.endpoint_name, e, 
exc_info=True)
-        # Configure AdaptiveThrottler and throttling metrics for client-side
-        # throttling behavior.
-        # See 
https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
-        # for more details.
-        self.throttled_secs = Metrics.counter(
-        VertexAITritonModelHandler, "cumulativeThrottlingSeconds")
-        self.throttler = AdaptiveThrottler(
-        window_ms=1, bucket_ms=1, overload_ratio=2) 
-    def load_model(self) -> aiplatform.Endpoint:
-        """Loads the Endpoint object for inference, performing liveness 
check."""
-        
-        endpoint_id_to_load = self.endpoint_name
-        if not endpoint_id_to_load.startswith('projects/'):
-             endpoint_id_to_load = 
f"projects/{self.project_id}/locations/{self.location}/endpoints/{self.endpoint_name}"
-
-        LOGGER.info("Loading/Retrieving Vertex AI endpoint: %s", 
endpoint_id_to_load)
-        ep = self._retrieve_endpoint(
-            endpoint_id=endpoint_id_to_load,
-            project=self.project_id,
-            location=self.location,
-            is_private=self.private)
-        LOGGER.info("Successfully retrieved endpoint object: %s", ep.name)
-        return ep
+        if private and network is None:
+            raise ValueError("A VPC network must be provided for a private 
endpoint.")
+
+        aiplatform.init(
+            project=project,
+            location=location,
+            experiment=experiment,
+            network=network
+        )
+        self.project = project
+        self.endpoint_name = endpoint_id
+        self.location = location
+        self.is_private = private
+
+        self.endpoint = self._retrieve_endpoint(  
+            self.endpoint_name, self.project, self.location, self.is_private
+        )
+
+        super().__init__(
+            namespace='VertexAITritonModelHandler',
+            retry_filter=_retry_on_appropriate_gcp_error,
+            **kwargs
+        )
+    def create_client(self) -> aiplatform.Endpoint:
+        """
+        Create the client for inference.
 
+        Returns:
+            aiplatform.Endpoint object.
+        """
+        return self.endpoint
+    
     def _retrieve_endpoint(

Review Comment:
   I can't find any discussion of public versus private Triton endpoints, but as 
I've said before, the `aiplatform.Endpoint` classes aren't what you should be 
using here anyway.



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -14,20 +14,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import json
 import logging
 from collections.abc import Iterable
 from collections.abc import Mapping
 from collections.abc import Sequence
 from typing import Any
 from typing import Optional
+from typing import Dict
 
 from google.api_core.exceptions import ServerError
 from google.api_core.exceptions import TooManyRequests
+from google.api_core.exceptions import GoogleAPICallError
 from google.cloud import aiplatform
 
 from apache_beam.ml.inference import utils
 from apache_beam.ml.inference.base import PredictionResult
+import numpy as np
+MSEC_TO_SEC = 1000
 from apache_beam.ml.inference.base import RemoteModelHandler

Review Comment:
   `MSEC_TO_SEC` should not be defined in the middle of the import block; move it 
below the imports.



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: 
Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler[Any,
+                                            PredictionResult,
+                                            aiplatform.Endpoint]):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference 
Servers.
+    It constructs a payload that Triton expects and calls the raw predict 
endpoint.
+    """
+    
+    def __init__(self, 
+                 project_id: str,
+                 region: str,
+                 endpoint_name: str,
+                 location: str,
+                 payload_config: Optional[Dict[str,Any]] = None,
+                 private: bool = False,
+                 
+                 ):
+        self.project_id = project_id
+        self.region = region
+        self.endpoint_name = endpoint_name
+        self.endpoint_url = 
f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/endpoints/{endpoint_name}:predict";
+        self.is_private = private
+        self.location = location
+        self.payload_config = payload_config if payload_config else {}
+        
+        # Configure AdaptiveThrottler and throttling metrics for client-side
+        # throttling behavior.
+        # See 
https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+        # for more details.
+        self.throttled_secs = Metrics.counter(
+        VertexAIModelHandlerJSON, "cumulativeThrottlingSeconds")
+        self.throttler = AdaptiveThrottler(
+        window_ms=1, bucket_ms=1, overload_ratio=2)
+
+    def load_model(self) -> aiplatform.Endpoint:
+        """Loads the Endpoint object used to build and send prediction request 
to
+            Vertex AI.
+        """
+    # Check to make sure the endpoint is still active since pipeline
+    # construction time
+        ep = self._retrieve_endpoint(
+        self.endpoint_name, self.location, self.is_private)
+        return ep
+
+    def _retrieve_endpoint(
+      self, endpoint_id: str,
+      location: str,
+      is_private: bool) -> aiplatform.Endpoint:
+      """Retrieves an AI Platform endpoint and queries it for liveness/deployed
+      models.
+
+      Args:
+        endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
+        is_private: a boolean indicating if the Vertex AI endpoint is a private
+          endpoint
+      Returns:
+        An aiplatform.Endpoint object
+      Raises:
+        ValueError: if endpoint is inactive or has no models deployed to it.
+      """
+      if is_private:
+        endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as private", endpoint_id)
+      else:
+        endpoint = aiplatform.Endpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as public", endpoint_id)
+
+      try:
+        mod_list = endpoint.list_models()
+      except Exception as e:
+        raise ValueError(
+            "Failed to contact endpoint %s, got exception: %s", endpoint_id, e)
+
+      if len(mod_list) == 0:
+        raise ValueError("Endpoint %s has no models deployed to it.", 
endpoint_id)
+
+      return endpoint
+
+    def run_inference(
+        self,
+        batch: Sequence[Any],
+        model: aiplatform.Endpoint,

Review Comment:
   You're still deploying the model to a Vertex AI endpoint, but the SDK's 
abstraction for that object (`aiplatform.Endpoint`) is not useful here.



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: 
Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler[Any,
+                                            PredictionResult,
+                                            aiplatform.Endpoint]):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference 
Servers.
+    It constructs a payload that Triton expects and calls the raw predict 
endpoint.
+    """
+    
+    def __init__(self, 
+                 project_id: str,
+                 region: str,
+                 endpoint_name: str,
+                 location: str,
+                 payload_config: Optional[Dict[str,Any]] = None,
+                 private: bool = False,
+                 
+                 ):
+        self.project_id = project_id
+        self.region = region
+        self.endpoint_name = endpoint_name
+        self.endpoint_url = 
f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/endpoints/{endpoint_name}:predict";
+        self.is_private = private
+        self.location = location
+        self.payload_config = payload_config if payload_config else {}
+        
+        # Configure AdaptiveThrottler and throttling metrics for client-side
+        # throttling behavior.
+        # See 
https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+        # for more details.
+        self.throttled_secs = Metrics.counter(
+        VertexAIModelHandlerJSON, "cumulativeThrottlingSeconds")
+        self.throttler = AdaptiveThrottler(
+        window_ms=1, bucket_ms=1, overload_ratio=2)
+
+    def load_model(self) -> aiplatform.Endpoint:
+        """Loads the Endpoint object used to build and send prediction request 
to
+            Vertex AI.
+        """
+    # Check to make sure the endpoint is still active since pipeline
+    # construction time
+        ep = self._retrieve_endpoint(
+        self.endpoint_name, self.location, self.is_private)
+        return ep
+
+    def _retrieve_endpoint(
+      self, endpoint_id: str,
+      location: str,
+      is_private: bool) -> aiplatform.Endpoint:
+      """Retrieves an AI Platform endpoint and queries it for liveness/deployed
+      models.
+
+      Args:
+        endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
+        is_private: a boolean indicating if the Vertex AI endpoint is a private
+          endpoint
+      Returns:
+        An aiplatform.Endpoint object
+      Raises:
+        ValueError: if endpoint is inactive or has no models deployed to it.
+      """
+      if is_private:
+        endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as private", endpoint_id)
+      else:
+        endpoint = aiplatform.Endpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as public", endpoint_id)
+
+      try:
+        mod_list = endpoint.list_models()
+      except Exception as e:
+        raise ValueError(
+            "Failed to contact endpoint %s, got exception: %s", endpoint_id, e)
+
+      if len(mod_list) == 0:
+        raise ValueError("Endpoint %s has no models deployed to it.", 
endpoint_id)
+
+      return endpoint
+
+    def run_inference(
+        self,
+        batch: Sequence[Any],
+        model: aiplatform.Endpoint,

Review Comment:
   `raw_predict` doesn't use an endpoint object; it uses a 
`PredictionServiceClient` 
(https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions#raw-predict-request). 
That is the client you need here, because Triton deployments are required to use 
the `raw_predict` API 
(https://cloud.google.com/vertex-ai/docs/predictions/using-nvidia-triton#deploy_the_model_to_endpoint).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@beam.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to