pranavm-nvidia commented on code in PR #22131: URL: https://github.com/apache/beam/pull/22131#discussion_r948016495
########## sdks/python/apache_beam/ml/inference/tensorrt_inference.py: ########## @@ -0,0 +1,281 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# pytype: skip-file + +import logging +import sys +from typing import Any +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Tuple + +import numpy as np + +import tensorrt as trt +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from cuda import cuda + +TRT_LOGGER = trt.Logger(trt.Logger.INFO) + +logging.basicConfig(level=logging.INFO) +logging.getLogger("TensorRTEngineHandlerNumPy").setLevel(logging.INFO) +log = logging.getLogger("TensorRTEngineHandlerNumPy") + + +def _load_engine(engine_path): + file = FileSystems.open(engine_path, 'rb') + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(file.read()) + assert engine + return engine + + +def _load_onnx(onnx_path): + builder = trt.Builder(TRT_LOGGER) + network = builder.create_network( + flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + parser = trt.OnnxParser(network, TRT_LOGGER) + with 
FileSystems.open(onnx_path) as f: + if not parser.parse(f.read()): + log.error("Failed to load ONNX file: %s", onnx_path) + for error in range(parser.num_errors): + log.error(parser.get_error(error)) + sys.exit(1) + return network, builder + + +def _build_engine(network, builder): + config = builder.create_builder_config() + runtime = trt.Runtime(TRT_LOGGER) + plan = builder.build_serialized_network(network, config) + engine = runtime.deserialize_cuda_engine(plan) + builder.reset() + return engine + + +def _validate_inference_args(inference_args): + """Confirms that inference_args is None. + + TensorRT engines do not need extra arguments in their execute_v2() call. + However, since inference_args is an argument in the RunInference interface, + we want to make sure it is not passed here in TensorRT's implementation of + RunInference. + """ + if inference_args: + raise ValueError( + 'inference_args were provided, but should be None because TensorRT ' + 'engines do not need extra arguments in their execute_v2() call.') + + +def ASSERT_DRV(args): + """CUDA error checking.""" + err, ret = args[0], args[1:] + if isinstance(err, cuda.CUresult): + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("Cuda Error: {}".format(err)) + else: + raise RuntimeError("Unknown error type: {}".format(err)) + # Special case so that no unpacking is needed at call-site. + if len(ret) == 1: + return ret[0] + return ret + + +class TensorRTEngine: + def __init__(self, engine: trt.ICudaEngine): + """Implementation of the TensorRTEngine class which handles + allocations associated with TensorRT engine. 
+ + Example Usage:: + + TensorRTEngine(engine) + + Args: + engine: trt.ICudaEngine object that contains TensorRT engine + """ + self.engine = engine + self.context = engine.create_execution_context() + self.inputs = [] + self.outputs = [] + self.gpu_allocations = [] + self.cpu_allocations = [] + """Setup I/O bindings.""" + for i in range(self.engine.num_bindings): + name = self.engine.get_binding_name(i) + dtype = self.engine.get_binding_dtype(i) + shape = self.engine.get_binding_shape(i) + size = trt.volume(shape) * dtype.itemsize + allocation = ASSERT_DRV(cuda.cuMemAlloc(size)) + binding = { + 'index': i, + 'name': name, + 'dtype': np.dtype(trt.nptype(dtype)), + 'shape': list(shape), + 'allocation': allocation, + 'size': size + } + self.gpu_allocations.append(allocation) + if self.engine.binding_is_input(i): + self.inputs.append(binding) + else: + self.outputs.append(binding) + + assert self.context + assert len(self.inputs) > 0 + assert len(self.outputs) > 0 + assert len(self.gpu_allocations) > 0 + + for output in self.outputs: + self.cpu_allocations.append(np.zeros(output['shape'], output['dtype'])) + # Create CUDA Stream. + self.stream = ASSERT_DRV(cuda.cuStreamCreate(0)) + + def get_engine_attrs(self): + """Returns TensorRT engine attributes.""" + return ( + self.engine, + self.context, + self.inputs, + self.outputs, + self.gpu_allocations, + self.cpu_allocations, + self.stream) + + +class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray, + PredictionResult, + TensorRTEngine]): + def __init__(self, min_batch_size: int, max_batch_size: int, **kwargs): + """Implementation of the ModelHandler interface for TensorRT. + + Example Usage:: + + pcoll | RunInference( + TensorRTEngineHandlerNumPy( + min_batch_size=1, + max_batch_size=1, + engine_path="my_uri")) + + Args: + min_batch_size: minimum accepted batch size. + max_batch_size: maximum accepted batch size. + kwargs: Additional arguments like 'engine_path' and 'onnx_path' are + currently supported. 
+ + See https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/ + for details + """ + self.min_batch_size = min_batch_size + self.max_batch_size = max_batch_size + if 'engine_path' in kwargs: + self.engine_path = kwargs.get('engine_path') + elif 'onnx_path' in kwargs: + self.onnx_path = kwargs.get('onnx_path') + + trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="") + + def batch_elements_kwargs(self): + """Sets min_batch_size and max_batch_size of a TensorRT engine.""" + return { + 'min_batch_size': self.min_batch_size, + 'max_batch_size': self.max_batch_size + } + + def load_model(self) -> TensorRTEngine: + """Loads and initializes a TensorRT engine for processing.""" + engine = _load_engine(self.engine_path) + return TensorRTEngine(engine) + + def load_onnx(self) -> Tuple[trt.INetworkDefinition, trt.Builder]: + """Loads and parses an onnx model for processing.""" + return _load_onnx(self.onnx_path) + + def build_engine( + self, network: trt.INetworkDefinition, + builder: trt.Builder) -> TensorRTEngine: + """Build an engine according to parsed/created network.""" + engine = _build_engine(network, builder) + return TensorRTEngine(engine) + + def run_inference( Review Comment: Like @azhurkevich said, we'll have more details after our internal sync, but I can provide a little background information until then. The context mainly allocates three things: 1. Persistent host/device memory: This is for things like look-up tables to speed up computation. 2. Workspace/scratch memory: This is memory that is reused across layers and used during computation - e.g. if we compute Convolution as a GEMM, the generated matrix would be written to scratch memory. 3. Activation memory: This is to store intermediate tensors. TensorRT tries to reuse this across layers as much as possible, but the degree of reusability depends on the network architecture (e.g. parallel branches would require separate memory). 
You can determine the total size of (2) and (3) by checking `engine.device_memory_size`, and you can optionally manage both of these yourself by using `engine.create_execution_context_without_device_memory()` and setting `context.device_memory` to point to your own memory buffer. Otherwise, TensorRT will allocate them on a per-context basis.

Finally, to answer your questions:

> How many contexts can we create and how do you recommend to manage memory b/w them?

There's some nuance here, but effectively you're only limited by system resources. Regarding memory management, it depends on the use case. Assuming you have N contexts:

- If you plan to use all N concurrently, then I'd just let TensorRT manage the per-context memory.
- If you plan to use only some fraction, K, of them concurrently, then you can allocate enough memory for only K contexts and reuse that. The thing to be careful of is that contexts executing concurrently should not use the same memory.

> Are context objects small?

Generally no.

> Does TensorRT support concurrent execution on multiple processes on the same GPU? I believe, I asked this and answer was yes, just, double checking... If yes, how GPU memory management is solved?

Yes, it does, but the CUDA context-switching overhead might make this impractical (I'm not totally sure though; I will need to double-check this). For GPU memory management, I think it might be possible to use `nvshmem` to share memory across processes, but I don't have any experience with that.

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at: [email protected]
