mkolod commented on a change in pull request #11325: Added TensorRT runtime 
integration
URL: https://github.com/apache/incubator-mxnet/pull/11325#discussion_r201778476
 
 

 ##########
 File path: src/operator/contrib/tensorrt.cc
 ##########
 @@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file trt.cc
+ * \brief TensorRT operation registration
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./tensorrt-inl.h"
+
+#include <mxnet/base.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "../../common/serialization.h"
+#include "../../common/utils.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(TRTParam);
+
+OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
+                         tensorrt::NameToIdx_t input_map,
+                         tensorrt::NameToIdx_t output_map) {
+  TRTEngineParam param;
+  for (int b = 0; b < trt_engine->getNbBindings(); ++b) {
+    const std::string& binding_name = trt_engine->getBindingName(b);
+    if (trt_engine->bindingIsInput(b)) {
+      param.binding_map.emplace_back(input_map[binding_name],
+                                     tensorrt::TypeIO::Inputs);
+    } else {
+      param.binding_map.emplace_back(output_map[binding_name],
+                                     tensorrt::TypeIO::Outputs);
+    }
+  }
+  param.trt_executor = trt_engine->createExecutionContext();
+  return OpStatePtr::Create<TRTEngineParam>(param);
+}
+
+OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx,
+                          const std::vector<TShape>& ishape,
+                          const std::vector<int>& itype) {
+  const TRTParam& node_param = nnvm::get<TRTParam>(attrs.parsed);
+
+  ::onnx::ModelProto model_proto;
+  bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph);
+  if (!success) {
+    LOG(FATAL) << "Problems parsing serialized ONNX model.";
+  }
+  auto graph = model_proto.graph();
+  auto first_input_type = graph.input(0).type().tensor_type();
+  auto dim_value = first_input_type.shape().dim(0).dim_value();
+  uint64_t batch_size = static_cast<uint64_t>(dim_value);
+  // Need to set up max workspace size based on device properties
+  nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx(
+      node_param.serialized_onnx_graph, batch_size, 1 << 30);
+
+  LOG(INFO) << "TensorRT engine instantiated!!!";
+
+  tensorrt::NameToIdx_t output_map;
+  for (auto& el : node_param.output_map) {
+    output_map[el.first] = std::get<0>(el.second);
+  }
+  return GetPtrMapping(trt_engine, node_param.input_map, output_map);
+}
+
+void TRTParamParser(nnvm::NodeAttrs* attrs) {
+  using namespace mshadow;
+
+  TRTParam param_;
+
+  try {
+    param_.Init(attrs->dict);
+    common::Deserialize(&param_.input_map, param_.serialized_input_map);
+    common::Deserialize(&param_.output_map, param_.serialized_output_map);
+    param_.onnx_pb_graph.ParseFromString(param_.serialized_onnx_graph);
+  } catch (const dmlc::ParamError& e) {
+    std::ostringstream os;
+    os << e.what();
+    os << ", in operator " << attrs->op->name << "("
+       << "name=\"" << attrs->name << "\"";
+    for (const auto& k : attrs->dict) {
+      os << ", " << k.first << "=\"" << k.second << "\"";
+    }
+    os << ")";
+    throw dmlc::ParamError(os.str());
+  }
+
+  attrs->parsed = std::move(param_);
+}
+
+template <>
+void TRTCompute<cpu>(const OpStatePtr& state, const OpContext& ctx,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  LOG(FATAL) << "TRTCompute not implemented on the CPU";
+}
+
+inline bool TRTInferShape(const NodeAttrs& attrs, std::vector<TShape>* 
in_shape,
+                          std::vector<TShape>* out_shape) {
+  const auto node_param = nnvm::get<TRTParam>(attrs.parsed);
+  for (auto& el : node_param.output_map) {
+    (*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second);
+  }
+  return true;
+}
+
+inline bool TRTInferStorageType(const NodeAttrs& attrs, const int dev_mask,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int>* in_storage_type,
+                                std::vector<int>* out_storage_type) {
+  return storage_type_assign(out_storage_type, mxnet::kDefaultStorage,
+                             dispatch_mode, DispatchMode::kFCompute);
+}
+
+inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* in_dtype,
+                         std::vector<int>* out_dtype) {
+  const auto node_param = nnvm::get<TRTParam>(attrs.parsed);
+  for (auto& el : node_param.output_map) {
+    (*out_dtype)[std::get<0>(el.second)] = std::get<3>(el.second);
+  }
+  return true;
+}
+
+inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
 
 Review comment:
   @reminisce It's a good point regarding the Module. Unfortunately this 
initial PR does not support Module, partly because it relies on Symbol's 
`simple_bind` method to provide the shared buffer, which is necessary because 
TensorRT stores copies of the weights. Binding from Module hides away the 
access to the weights from the user, and the user providing the weights during 
binding makes this approach work not only with symbolic models, but also Gluon 
models converted to symbolic from the pre-trained GluonCV model zoo. The 
concern over Module is valid, and the support for it would probably require 
some eventual re-working on the integration, but I though it would be good to 
at least get this initial integration out there, and then broaden the scope to 
the more common Module approach.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to