jrmccluskey commented on code in PR #34252:
URL: https://github.com/apache/beam/pull/34252#discussion_r1991609113


##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, endpoint_url: str):
+        self.endpoint_url = endpoint_url
+
+    def load_model(self) -> str:
+        """
+        This method can load or return any model resource information. In the case of
+        a Triton endpoint, it could simply return the URL or any other required reference.
+        """
+        return self.endpoint_url
+
+    def run_inference(
+        self,
+        batch,
+        model,
+        inference_args=None
+    ) -> Iterable[PredictionResult]:
+        """
+        Sends a prediction request with the Triton-specific payload structure.
+        """
+        
+        payload = {
+            "inputs": [
+                {
+                    "name": "name",
+                    "shape": [1, 1],
+                    "datatype": "BYTES",
+                    "data": batch,  
+                }
+            ]
+        }
+        response = requests.post(model, json=payload)

Review Comment:
   There are AI Platform client libraries you should use to actually send the request rather than the `requests` library; see [`aiplatform` and `aiplatform_v1`](https://cloud.google.com/python/docs/reference/aiplatform/latest).
   There's also a [demo notebook](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/e128823a12039e3ed8449254bed29abf37eb0201/notebooks/community/vertex_endpoints/nvidia-triton/nvidia-triton-custom-container-prediction.ipynb) that explains how to use them.
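
   For instance, a minimal sketch of what that could look like with `aiplatform_v1` (the function name, the endpoint resource name, and the payload handling here are placeholders, not the final design):

   ```python
   import json

   from google.api import httpbody_pb2
   from google.cloud import aiplatform_v1


   def triton_raw_predict(endpoint: str, location: str, payload: dict) -> dict:
     # endpoint is the full resource name:
     # projects/{project}/locations/{location}/endpoints/{endpoint_id}
     client = aiplatform_v1.PredictionServiceClient(
         client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"})
     body = httpbody_pb2.HttpBody(
         data=json.dumps(payload).encode("utf-8"),
         content_type="application/json")
     response = client.raw_predict(
         request=aiplatform_v1.RawPredictRequest(endpoint=endpoint, http_body=body))
     return json.loads(response.data)
   ```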



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, endpoint_url: str):
+        self.endpoint_url = endpoint_url
+
+    def load_model(self) -> str:
+        """
+        This method can load or return any model resource information. In the case of
+        a Triton endpoint, it could simply return the URL or any other required reference.
+        """
+        return self.endpoint_url
+
+    def run_inference(
+        self,
+        batch,
+        model,
+        inference_args=None
+    ) -> Iterable[PredictionResult]:
+        """
+        Sends a prediction request with the Triton-specific payload structure.
+        """
+        
+        payload = {
+            "inputs": [
+                {
+                    "name": "name",
+                    "shape": [1, 1],
+                    "datatype": "BYTES",
+                    "data": batch,  
+                }
+            ]
+        }

Review Comment:
   These fields should not be hard-coded: `shape` is a per-payload variable, while `name` and `datatype` should be user-configurable.
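
   Something along these lines (a sketch only; the function and parameter names are illustrative, and the single-dimension shape assumes a flat list of inputs):

   ```python
   from typing import Any, Dict, List


   def build_triton_payload(
       batch: List[Any], input_name: str, datatype: str) -> Dict[str, Any]:
     # The tensor name and datatype come from handler configuration; the shape
     # is derived from the batch rather than being hard coded.
     return {
         "inputs": [{
             "name": input_name,
             "shape": [len(batch)],
             "datatype": datatype,
             "data": batch,
         }]
     }
   ```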



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, endpoint_url: str):
+        self.endpoint_url = endpoint_url

Review Comment:
   Taking a single pre-built URL is somewhat unfriendly to the user when we can construct the endpoint string consistently from the project ID, region, and endpoint name.
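
   For example, a sketch of assembling the rawPredict URL from those pieces (the function and parameter names are illustrative):

   ```python
   def build_raw_predict_url(project: str, location: str, endpoint_id: str) -> str:
     # Standard Vertex AI rawPredict REST path, built from its components so
     # the user doesn't have to hand-assemble the URL.
     return (
         f"https://{location}-aiplatform.googleapis.com/v1/"
         f"projects/{project}/locations/{location}/"
         f"endpoints/{endpoint_id}:rawPredict")
   ```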



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, endpoint_url: str):
+        self.endpoint_url = endpoint_url
+
+    def load_model(self) -> str:
+        """
+        This method can load or return any model resource information. In the case of
+        a Triton endpoint, it could simply return the URL or any other required reference.
+        """
+        return self.endpoint_url

Review Comment:
   `load_model` in a remote context should probably check the liveness of the remote endpoint; see `_retrieve_endpoint()` above.
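
   A rough sketch of what that could look like (this assumes the handler stores `project`, `location`, and `endpoint_id`, which are not in the current code):

   ```python
   from google.cloud import aiplatform


   def load_model(self) -> aiplatform.Endpoint:
     aiplatform.init(project=self.project, location=self.location)
     endpoint = aiplatform.Endpoint(endpoint_name=self.endpoint_id)
     # Fail fast if nothing is deployed behind the endpoint.
     if not endpoint.list_models():
       raise ValueError(
           "Endpoint %s has no deployed models." % self.endpoint_id)
     return endpoint
   ```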



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler):

Review Comment:
   The generic types for the model handler need to be specified, e.g. `ModelHandler[InputType, PredictionResult, ModelType]`.
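
   For example (the concrete parameters below are just one plausible choice, not a prescription):

   ```python
   from google.cloud import aiplatform

   from apache_beam.ml.inference.base import ModelHandler
   from apache_beam.ml.inference.base import PredictionResult


   class VertexAITritonModelHandler(
       ModelHandler[str, PredictionResult, aiplatform.Endpoint]):
     ...
   ```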



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler):

Review Comment:
   This class should be instrumented with the same client-side throttling code as the `VertexAIModelHandlerJSON` class.
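
   A rough sketch of that pattern (the class name, constants, and retry delay here are assumptions, not the exact values used by `VertexAIModelHandlerJSON`):

   ```python
   import time

   from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
   from apache_beam.metrics.metric import Metrics


   class _ThrottledCaller:
     def __init__(self):
       self.throttler = AdaptiveThrottler(
           window_ms=1, bucket_ms=1, overload_ratio=2)
       self.throttled_secs = Metrics.counter(
           self.__class__, "cumulativeThrottlingSeconds")

     def call(self, send_request, retry_delay_secs=5):
       # Back off while the throttler reports too many recent failures, then
       # record the successful request so the throttler can adapt.
       while self.throttler.throttle_request(time.time() * 1000):
         time.sleep(retry_delay_secs)
         self.throttled_secs.inc(retry_delay_secs)
       response = send_request()
       self.throttler.successful_request(time.time() * 1000)
       return response
   ```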



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, endpoint_url: str):
+        self.endpoint_url = endpoint_url
+
+    def load_model(self) -> str:
+        """
+        This method can load or return any model resource information. In the case of
+        a Triton endpoint, it could simply return the URL or any other required reference.
+        """
+        return self.endpoint_url
+
+    def run_inference(
+        self,
+        batch,
+        model,
+        inference_args=None
+    ) -> Iterable[PredictionResult]:
+        """
+        Sends a prediction request with the Triton-specific payload structure.
+        """
+        
+        payload = {
+            "inputs": [
+                {
+                    "name": "name",
+                    "shape": [1, 1],
+                    "datatype": "BYTES",
+                    "data": batch,  
+                }
+            ]
+        }
+        response = requests.post(model, json=payload)
+        response.raise_for_status()
+        prediction_response = response.json()
+        
+        yield prediction_response

Review Comment:
   This is not the correct return type; you need to create `PredictionResult` objects, which are tuples of the input and the output.
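
   Roughly like this (a sketch; the structure of the Triton response parsed below is an assumption):

   ```python
   from typing import Any, Iterable, Sequence

   from apache_beam.ml.inference.base import PredictionResult


   def to_prediction_results(
       batch: Sequence[Any],
       prediction_response: dict) -> Iterable[PredictionResult]:
     # Pair each input element with its corresponding output value so the
     # handler returns PredictionResult tuples rather than the raw JSON dict.
     outputs = prediction_response["outputs"][0]["data"]
     return [
         PredictionResult(example, inference)
         for example, inference in zip(batch, outputs)
     ]
   ```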



##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler):
+    """
+    A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+    It constructs a payload that Triton expects and calls the raw predict endpoint.
+    """
+    
+    def __init__(self, endpoint_url: str):
+        self.endpoint_url = endpoint_url
+
+    def load_model(self) -> str:
+        """
+        This method can load or return any model resource information. In the case of
+        a Triton endpoint, it could simply return the URL or any other required reference.
+        """
+        return self.endpoint_url
+
+    def run_inference(
+        self,
+        batch,
+        model,
+        inference_args=None
+    ) -> Iterable[PredictionResult]:

Review Comment:
   These args need type hints
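
   e.g. something like the following (the concrete types are placeholders that depend on how the handler ends up being parameterized):

   ```python
   from typing import Any, Dict, Iterable, Optional, Sequence

   from google.cloud import aiplatform

   from apache_beam.ml.inference.base import PredictionResult


   def run_inference(
       self,
       batch: Sequence[str],
       model: aiplatform.Endpoint,
       inference_args: Optional[Dict[str, Any]] = None,
   ) -> Iterable[PredictionResult]:
     ...
   ```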



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@beam.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
