jrmccluskey commented on code in PR #34252:
URL: https://github.com/apache/beam/pull/34252#discussion_r1994255799


##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler[Any,
+                                            PredictionResult,
+                                            aiplatform.Endpoint]):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, 
+                 project_id: str,
+                 region: str,
+                 endpoint_name: str,
+                 location: str,
+                 payload_config: Optional[Dict[str,Any]] = None,
+                 private: bool = False,
+                 
+                 ):
+        self.project_id = project_id
+        self.region = region
+        self.endpoint_name = endpoint_name
+        self.endpoint_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/endpoints/{endpoint_name}:predict"
+        self.is_private = private
+        self.location = location
+        self.payload_config = payload_config if payload_config else {}
+        
+        # Configure AdaptiveThrottler and throttling metrics for client-side
+        # throttling behavior.
+        # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+        # for more details.
+        self.throttled_secs = Metrics.counter(
+        VertexAIModelHandlerJSON, "cumulativeThrottlingSeconds")
+        self.throttler = AdaptiveThrottler(
+        window_ms=1, bucket_ms=1, overload_ratio=2)
+
+    def load_model(self) -> aiplatform.Endpoint:
+        """Loads the Endpoint object used to build and send prediction request 
to
+            Vertex AI.
+        """
+    # Check to make sure the endpoint is still active since pipeline
+    # construction time
+        ep = self._retrieve_endpoint(
+        self.endpoint_name, self.location, self.is_private)
+        return ep
+
+    def _retrieve_endpoint(
+      self, endpoint_id: str,
+      location: str,
+      is_private: bool) -> aiplatform.Endpoint:
+      """Retrieves an AI Platform endpoint and queries it for liveness/deployed
+      models.
+
+      Args:
+        endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
+        is_private: a boolean indicating if the Vertex AI endpoint is a private
+          endpoint
+      Returns:
+        An aiplatform.Endpoint object
+      Raises:
+        ValueError: if endpoint is inactive or has no models deployed to it.
+      """
+      if is_private:
+        endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as private", endpoint_id)
+      else:
+        endpoint = aiplatform.Endpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as public", endpoint_id)
+
+      try:
+        mod_list = endpoint.list_models()
+      except Exception as e:
+        raise ValueError(
+            "Failed to contact endpoint %s, got exception: %s", endpoint_id, e)
+
+      if len(mod_list) == 0:
+        raise ValueError("Endpoint %s has no models deployed to it.", 
endpoint_id)
+
+      return endpoint
+
+    def run_inference(
+        self,
+        batch: Sequence[Any],
+        model: aiplatform.Endpoint,
+        inference_args: Optional[Dict[str, Any]] = None
+    ) -> Iterable[PredictionResult]:
+        """
+        Sends a prediction request with the Triton-specific payload structure.
+        """
+        
+        config = self.payload_config.copy()
+        if inference_args:
+          config.update(inference_args)

Review Comment:
   not used



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler[Any,
+                                            PredictionResult,
+                                            aiplatform.Endpoint]):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, 
+                 project_id: str,
+                 region: str,
+                 endpoint_name: str,
+                 location: str,
+                 payload_config: Optional[Dict[str,Any]] = None,
+                 private: bool = False,
+                 
+                 ):
+        self.project_id = project_id
+        self.region = region
+        self.endpoint_name = endpoint_name
+        self.endpoint_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/endpoints/{endpoint_name}:predict"
+        self.is_private = private
+        self.location = location
+        self.payload_config = payload_config if payload_config else {}
+        
+        # Configure AdaptiveThrottler and throttling metrics for client-side
+        # throttling behavior.
+        # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+        # for more details.
+        self.throttled_secs = Metrics.counter(
+        VertexAIModelHandlerJSON, "cumulativeThrottlingSeconds")
+        self.throttler = AdaptiveThrottler(
+        window_ms=1, bucket_ms=1, overload_ratio=2)
+
+    def load_model(self) -> aiplatform.Endpoint:
+        """Loads the Endpoint object used to build and send prediction request 
to
+            Vertex AI.
+        """
+    # Check to make sure the endpoint is still active since pipeline
+    # construction time
+        ep = self._retrieve_endpoint(
+        self.endpoint_name, self.location, self.is_private)
+        return ep
+
+    def _retrieve_endpoint(
+      self, endpoint_id: str,
+      location: str,
+      is_private: bool) -> aiplatform.Endpoint:
+      """Retrieves an AI Platform endpoint and queries it for liveness/deployed
+      models.
+
+      Args:
+        endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
+        is_private: a boolean indicating if the Vertex AI endpoint is a private
+          endpoint
+      Returns:
+        An aiplatform.Endpoint object
+      Raises:
+        ValueError: if endpoint is inactive or has no models deployed to it.
+      """
+      if is_private:
+        endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as private", endpoint_id)
+      else:
+        endpoint = aiplatform.Endpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as public", endpoint_id)
+
+      try:
+        mod_list = endpoint.list_models()
+      except Exception as e:
+        raise ValueError(
+            "Failed to contact endpoint %s, got exception: %s", endpoint_id, e)
+
+      if len(mod_list) == 0:
+        raise ValueError("Endpoint %s has no models deployed to it.", 
endpoint_id)
+
+      return endpoint
+
+    def run_inference(
+        self,
+        batch: Sequence[Any],
+        model: aiplatform.Endpoint,
+        inference_args: Optional[Dict[str, Any]] = None
+    ) -> Iterable[PredictionResult]:
+        """
+        Sends a prediction request with the Triton-specific payload structure.
+        """
+        
+        config = self.payload_config.copy()
+        if inference_args:
+          config.update(inference_args)
+
+        payload = {
+            "inputs": [
+                {
+                    "name": config.get("name", "name"),
+                    "shape": config.get("shape", [1, 1]),  
+                    "datatype": config.get("datatype", "BYTES"),
+                    "data": batch,
+                }
+            ]
+        }
+        client = aiplatform.gapic.PredictionServiceClient()
+        predict_response = client.predict(model_name=model, instances=[payload])
+        for inp, pred in zip(batch, predict_response.predictions):
+            yield PredictionResult(inp, pred)

Review Comment:
   just use _convert_to_result
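   
   For instance, the tail of run_inference could look roughly like this (a sketch assuming the same utils._convert_to_result helper the JSON handler uses; model_id=self.endpoint_name is a guess, untested):
   
       from apache_beam.ml.inference import utils
   
       # Pair each input element with its prediction and tag the result with the
       # endpoint id, instead of zipping and building PredictionResult by hand.
       return utils._convert_to_result(
           batch, predict_response.predictions, model_id=self.endpoint_name)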



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler[Any,
+                                            PredictionResult,
+                                            aiplatform.Endpoint]):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, 
+                 project_id: str,
+                 region: str,
+                 endpoint_name: str,
+                 location: str,
+                 payload_config: Optional[Dict[str,Any]] = None,
+                 private: bool = False,
+                 
+                 ):
+        self.project_id = project_id
+        self.region = region
+        self.endpoint_name = endpoint_name
+        self.endpoint_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/endpoints/{endpoint_name}:predict"
+        self.is_private = private

Review Comment:
   Are there distinctions between public and private Triton endpoints?



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler[Any,
+                                            PredictionResult,
+                                            aiplatform.Endpoint]):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, 
+                 project_id: str,
+                 region: str,
+                 endpoint_name: str,
+                 location: str,
+                 payload_config: Optional[Dict[str,Any]] = None,
+                 private: bool = False,
+                 
+                 ):
+        self.project_id = project_id
+        self.region = region
+        self.endpoint_name = endpoint_name
+        self.endpoint_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/endpoints/{endpoint_name}:predict"
+        self.is_private = private
+        self.location = location
+        self.payload_config = payload_config if payload_config else {}
+        
+        # Configure AdaptiveThrottler and throttling metrics for client-side
+        # throttling behavior.
+        # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+        # for more details.
+        self.throttled_secs = Metrics.counter(
+        VertexAIModelHandlerJSON, "cumulativeThrottlingSeconds")
+        self.throttler = AdaptiveThrottler(
+        window_ms=1, bucket_ms=1, overload_ratio=2)
+
+    def load_model(self) -> aiplatform.Endpoint:
+        """Loads the Endpoint object used to build and send prediction request 
to
+            Vertex AI.
+        """
+    # Check to make sure the endpoint is still active since pipeline
+    # construction time
+        ep = self._retrieve_endpoint(
+        self.endpoint_name, self.location, self.is_private)
+        return ep
+
+    def _retrieve_endpoint(
+      self, endpoint_id: str,
+      location: str,
+      is_private: bool) -> aiplatform.Endpoint:
+      """Retrieves an AI Platform endpoint and queries it for liveness/deployed
+      models.
+
+      Args:
+        endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
+        is_private: a boolean indicating if the Vertex AI endpoint is a private
+          endpoint
+      Returns:
+        An aiplatform.Endpoint object
+      Raises:
+        ValueError: if endpoint is inactive or has no models deployed to it.
+      """
+      if is_private:
+        endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as private", endpoint_id)
+      else:
+        endpoint = aiplatform.Endpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as public", endpoint_id)
+
+      try:
+        mod_list = endpoint.list_models()
+      except Exception as e:
+        raise ValueError(
+            "Failed to contact endpoint %s, got exception: %s", endpoint_id, e)
+
+      if len(mod_list) == 0:
+        raise ValueError("Endpoint %s has no models deployed to it.", 
endpoint_id)
+
+      return endpoint
+
+    def run_inference(
+        self,
+        batch: Sequence[Any],
+        model: aiplatform.Endpoint,

Review Comment:
   This does not align with usage; an Endpoint object is not the model name.
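   
   A rough sketch of what I'd expect here instead, assuming the installed google-cloud-aiplatform exposes Endpoint.raw_predict (the header and response handling are illustrative, untested):
   
       import json
   
       # Use the Endpoint object returned by load_model directly; a Triton-backed
       # endpoint takes the raw KServe v2 JSON body rather than predict instances.
       response = model.raw_predict(
           body=json.dumps(payload).encode("utf-8"),
           headers={"Content-Type": "application/json"})
       outputs = json.loads(response.text).get("outputs", [])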



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler[Any,
+                                            PredictionResult,
+                                            aiplatform.Endpoint]):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, 
+                 project_id: str,
+                 region: str,
+                 endpoint_name: str,
+                 location: str,
+                 payload_config: Optional[Dict[str,Any]] = None,
+                 private: bool = False,
+                 
+                 ):
+        self.project_id = project_id
+        self.region = region
+        self.endpoint_name = endpoint_name
+        self.endpoint_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/endpoints/{endpoint_name}:predict"
+        self.is_private = private
+        self.location = location
+        self.payload_config = payload_config if payload_config else {}
+        
+        # Configure AdaptiveThrottler and throttling metrics for client-side
+        # throttling behavior.
+        # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+        # for more details.
+        self.throttled_secs = Metrics.counter(
+        VertexAIModelHandlerJSON, "cumulativeThrottlingSeconds")

Review Comment:
   Wrong class
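   
   i.e. the counter namespace should point at the new handler, something like:
   
       self.throttled_secs = Metrics.counter(
           VertexAITritonModelHandler, "cumulativeThrottlingSeconds")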



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler[Any,
+                                            PredictionResult,
+                                            aiplatform.Endpoint]):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, 
+                 project_id: str,
+                 region: str,
+                 endpoint_name: str,
+                 location: str,
+                 payload_config: Optional[Dict[str,Any]] = None,
+                 private: bool = False,
+                 
+                 ):
+        self.project_id = project_id
+        self.region = region
+        self.endpoint_name = endpoint_name
+        self.endpoint_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/endpoints/{endpoint_name}:predict"
+        self.is_private = private
+        self.location = location
+        self.payload_config = payload_config if payload_config else {}
+        
+        # Configure AdaptiveThrottler and throttling metrics for client-side
+        # throttling behavior.
+        # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+        # for more details.
+        self.throttled_secs = Metrics.counter(
+        VertexAIModelHandlerJSON, "cumulativeThrottlingSeconds")
+        self.throttler = AdaptiveThrottler(
+        window_ms=1, bucket_ms=1, overload_ratio=2)
+
+    def load_model(self) -> aiplatform.Endpoint:
+        """Loads the Endpoint object used to build and send prediction request 
to
+            Vertex AI.
+        """
+    # Check to make sure the endpoint is still active since pipeline
+    # construction time
+        ep = self._retrieve_endpoint(
+        self.endpoint_name, self.location, self.is_private)
+        return ep
+
+    def _retrieve_endpoint(
+      self, endpoint_id: str,
+      location: str,
+      is_private: bool) -> aiplatform.Endpoint:
+      """Retrieves an AI Platform endpoint and queries it for liveness/deployed
+      models.
+
+      Args:
+        endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
+        is_private: a boolean indicating if the Vertex AI endpoint is a private
+          endpoint
+      Returns:
+        An aiplatform.Endpoint object
+      Raises:
+        ValueError: if endpoint is inactive or has no models deployed to it.
+      """
+      if is_private:
+        endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as private", endpoint_id)
+      else:
+        endpoint = aiplatform.Endpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as public", endpoint_id)
+
+      try:
+        mod_list = endpoint.list_models()
+      except Exception as e:
+        raise ValueError(
+            "Failed to contact endpoint %s, got exception: %s", endpoint_id, e)
+
+      if len(mod_list) == 0:
+        raise ValueError("Endpoint %s has no models deployed to it.", 
endpoint_id)
+
+      return endpoint

Review Comment:
   Do Triton endpoints function correctly in this way?



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler[Any,
+                                            PredictionResult,
+                                            aiplatform.Endpoint]):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, 
+                 project_id: str,
+                 region: str,
+                 endpoint_name: str,
+                 location: str,
+                 payload_config: Optional[Dict[str,Any]] = None,
+                 private: bool = False,
+                 

Review Comment:
   whitespace



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler[Any,
+                                            PredictionResult,
+                                            aiplatform.Endpoint]):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, 
+                 project_id: str,
+                 region: str,
+                 endpoint_name: str,
+                 location: str,
+                 payload_config: Optional[Dict[str,Any]] = None,
+                 private: bool = False,
+                 
+                 ):
+        self.project_id = project_id
+        self.region = region
+        self.endpoint_name = endpoint_name
+        self.endpoint_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/endpoints/{endpoint_name}:predict"
+        self.is_private = private
+        self.location = location
+        self.payload_config = payload_config if payload_config else {}
+        
+        # Configure AdaptiveThrottler and throttling metrics for client-side
+        # throttling behavior.
+        # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+        # for more details.
+        self.throttled_secs = Metrics.counter(
+        VertexAIModelHandlerJSON, "cumulativeThrottlingSeconds")
+        self.throttler = AdaptiveThrottler(
+        window_ms=1, bucket_ms=1, overload_ratio=2)
+
+    def load_model(self) -> aiplatform.Endpoint:
+        """Loads the Endpoint object used to build and send prediction request 
to
+            Vertex AI.
+        """
+    # Check to make sure the endpoint is still active since pipeline
+    # construction time
+        ep = self._retrieve_endpoint(
+        self.endpoint_name, self.location, self.is_private)
+        return ep
+
+    def _retrieve_endpoint(
+      self, endpoint_id: str,
+      location: str,
+      is_private: bool) -> aiplatform.Endpoint:
+      """Retrieves an AI Platform endpoint and queries it for liveness/deployed
+      models.
+
+      Args:
+        endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
+        is_private: a boolean indicating if the Vertex AI endpoint is a private
+          endpoint
+      Returns:
+        An aiplatform.Endpoint object
+      Raises:
+        ValueError: if endpoint is inactive or has no models deployed to it.
+      """
+      if is_private:
+        endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as private", endpoint_id)
+      else:
+        endpoint = aiplatform.Endpoint(
+            endpoint_name=endpoint_id, location=location)
+        LOGGER.debug("Treating endpoint %s as public", endpoint_id)
+
+      try:
+        mod_list = endpoint.list_models()
+      except Exception as e:
+        raise ValueError(
+            "Failed to contact endpoint %s, got exception: %s", endpoint_id, e)
+
+      if len(mod_list) == 0:
+        raise ValueError("Endpoint %s has no models deployed to it.", 
endpoint_id)
+
+      return endpoint
+
+    def run_inference(
+        self,
+        batch: Sequence[Any],
+        model: aiplatform.Endpoint,
+        inference_args: Optional[Dict[str, Any]] = None
+    ) -> Iterable[PredictionResult]:
+        """
+        Sends a prediction request with the Triton-specific payload structure.
+        """
+        
+        config = self.payload_config.copy()
+        if inference_args:
+          config.update(inference_args)
+
+        payload = {
+            "inputs": [
+                {
+                    "name": config.get("name", "name"),
+                    "shape": config.get("shape", [1, 1]),  

Review Comment:
   with batching this will be dynamic
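   
   e.g. derive the leading dimension from the runtime batch instead of hard-coding it (a sketch; the default input name "input" is just a placeholder):
   
       payload = {
           "inputs": [{
               "name": config.get("name", "input"),
               # BatchElements decides the batch size at runtime, so the shape
               # has to follow len(batch) rather than a fixed [1, 1].
               "shape": config.get("shape", [len(batch), 1]),
               "datatype": config.get("datatype", "BYTES"),
               "data": list(batch),
           }]
       }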


