jrmccluskey commented on code in PR #34252:
URL: https://github.com/apache/beam/pull/34252#discussion_r1991609113
##########
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py:
##########
@@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):

   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler):
+  """
+  A custom model handler for Vertex AI endpoints hosting Triton Inference Servers.
+  It constructs a payload that Triton expects and calls the raw predict endpoint.
+  """
+
+  def __init__(self, endpoint_url: str):
+    self.endpoint_url = endpoint_url
+
+  def load_model(self) -> str:
+    """
+    This method can load or return any model resource information. In the case of
+    a Triton endpoint, it could simply return the URL or any other required reference.
+    """
+    return self.endpoint_url
+
+  def run_inference(
+      self,
+      batch,
+      model,
+      inference_args=None
+  ) -> Iterable[PredictionResult]:
+    """
+    Sends a prediction request with the Triton-specific payload structure.
+    """
+
+    payload = {
+        "inputs": [
+            {
+                "name": "name",
+                "shape": [1, 1],
+                "datatype": "BYTES",
+                "data": batch,
+            }
+        ]
+    }
+    response = requests.post(model, json=payload)
+    response.raise_for_status()
+    prediction_response = response.json()
+
+    yield prediction_response

Review Comment (on `response = requests.post(model, json=payload)`):
   There are AI Platform libraries you should use to actually send the request, not the `requests` library. See [`aiplatform` and `aiplatform_v1`](https://cloud.google.com/python/docs/reference/aiplatform/latest). There's a [demo notebook](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/e128823a12039e3ed8449254bed29abf37eb0201/notebooks/community/vertex_endpoints/nvidia-triton/nvidia-triton-custom-container-prediction.ipynb) that explains what to do with those.

Review Comment (on the hard-coded `payload` dict):
   These fields should not be hard-coded: `shape` is a per-payload variable, while `name` and `datatype` should be configurable.
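For illustration, a rough sketch of what those two suggestions could look like together: the request goes through the `aiplatform_v1` client's `raw_predict` rather than `requests`, and the payload builder takes `input_name` and `datatype` from configuration while deriving `shape` from the batch. The helper names and the `[len(batch), 1]` shape are assumptions made for the sketch, not code from this PR:

```python
import json
from typing import Any, Dict, Sequence

from google.api import httpbody_pb2
from google.cloud import aiplatform_v1


def build_triton_payload(
    batch: Sequence[Any], input_name: str, datatype: str) -> Dict[str, Any]:
  # KServe/Triton-style request body; the shape is derived from the batch
  # (here assumed to be one element per row) instead of hard-coded [1, 1].
  return {
      "inputs": [{
          "name": input_name,
          "shape": [len(batch), 1],
          "datatype": datatype,
          "data": list(batch),
      }]
  }


def triton_raw_predict(
    client: aiplatform_v1.PredictionServiceClient,
    endpoint_path: str,
    payload: Dict[str, Any]) -> Dict[str, Any]:
  # rawPredict forwards the JSON body untouched to the Triton container
  # behind the Vertex AI endpoint and returns the response as an HttpBody.
  response = client.raw_predict(
      request=aiplatform_v1.RawPredictRequest(
          endpoint=endpoint_path,
          http_body=httpbody_pb2.HttpBody(
              content_type="application/json",
              data=json.dumps(payload).encode("utf-8"))))
  return json.loads(response.data)
```

A full handler would create the `PredictionServiceClient` once (for example in `load_model`) with a regional `api_endpoint` client option, rather than building a client per request.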
+ """ + + def __init__(self, endpoint_url: str): + self.endpoint_url = endpoint_url Review Comment: Taking the single URL is somewhat unfriendly to the user when we can construct the endpoint string consistently with project ID, region, and endpoint names. ########## sdks/python/apache_beam/ml/inference/vertex_ai_inference.py: ########## @@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): def batch_elements_kwargs(self) -> Mapping[str, Any]: return self._batching_kwargs + + +class VertexAITritonModelHandler(ModelHandler): + """ + A custom model handler for Vertex AI endpoints hosting Triton Inference Servers. + It constructs a payload that Triton expects and calls the raw predict endpoint. + """ + + def __init__(self, endpoint_url: str): + self.endpoint_url = endpoint_url + + def load_model(self) -> str: + """ + This method can load or return any model resource information. In the case of + a Triton endpoint, it could simply return the URL or any other required reference. + """ + return self.endpoint_url Review Comment: load_model in a remote context should probably check liveness of the remote endpoint, see _retrieve_endpoint() above ########## sdks/python/apache_beam/ml/inference/vertex_ai_inference.py: ########## @@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): def batch_elements_kwargs(self) -> Mapping[str, Any]: return self._batching_kwargs + + +class VertexAITritonModelHandler(ModelHandler): Review Comment: the types for the model handler need to be specified, e.g. `ModelHandler[InputType, PredictionResult, ModelType]` ########## sdks/python/apache_beam/ml/inference/vertex_ai_inference.py: ########## @@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): def batch_elements_kwargs(self) -> Mapping[str, Any]: return self._batching_kwargs + + +class VertexAITritonModelHandler(ModelHandler): Review Comment: this class should be instrumented with the client-side throttling code like the `VertexAIModelHandlerJSON` class. ########## sdks/python/apache_beam/ml/inference/vertex_ai_inference.py: ########## @@ -256,3 +257,46 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): def batch_elements_kwargs(self) -> Mapping[str, Any]: return self._batching_kwargs + + +class VertexAITritonModelHandler(ModelHandler): + """ + A custom model handler for Vertex AI endpoints hosting Triton Inference Servers. + It constructs a payload that Triton expects and calls the raw predict endpoint. + """ + + def __init__(self, endpoint_url: str): + self.endpoint_url = endpoint_url + + def load_model(self) -> str: + """ + This method can load or return any model resource information. In the case of + a Triton endpoint, it could simply return the URL or any other required reference. + """ + return self.endpoint_url + + def run_inference( + self, + batch, + model, + inference_args=None + ) -> Iterable[PredictionResult]: + """ + Sends a prediction request with the Triton-specific payload structure. + """ + + payload = { + "inputs": [ + { + "name": "name", + "shape": [1, 1], + "datatype": "BYTES", + "data": batch, + } + ] + } + response = requests.post(model, json=payload) + response.raise_for_status() + prediction_response = response.json() + + yield prediction_response Review Comment: this is not the correct return type, you need to create the `PredictionResult` objects which are tuples of the input and output. 
Review Comment (on the `run_inference` signature):
   These args need type hints.

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@beam.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org