This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7b7cef5   MXNet-TRT: Add PrePartition param caching - move init_tensorrt_params logic  (#18490)
7b7cef5 is described below

commit 7b7cef5e0ba4a95fdfed55e6c8ba1aff59e6fc91
Author: Serge Panev <[email protected]>
AuthorDate: Tue Aug 4 18:23:48 2020 -0700

     MXNet-TRT: Add PrePartition param caching - move init_tensorrt_params logic  (#18490)
    
    * Update to TRT 7 API
    
    Signed-off-by: Serge Panev <[email protected]>
    
    * Add PrePartition param caching - move init_tensorrt_params logic
    
    Signed-off-by: Serge Panev <[email protected]>
    
    * Handle node with no defined input
    
    Signed-off-by: Serge Panev <[email protected]>
    
    * Remove tmp comment
    
    Signed-off-by: Serge Panev <[email protected]>
---
 src/operator/subgraph/build_subgraph.cc            |  2 +-
 src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc |  4 +-
 src/operator/subgraph/tensorrt/tensorrt-inl.h      | 48 +++++++++++++++++++---
 src/operator/subgraph/tensorrt/tensorrt.cu         |  6 +--
 4 files changed, 49 insertions(+), 11 deletions(-)

diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc
index 93cb174..38038f2 100644
--- a/src/operator/subgraph/build_subgraph.cc
+++ b/src/operator/subgraph/build_subgraph.cc
@@ -430,7 +430,7 @@ void SortEntries(const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry
 }
 
 /*!
- * \brief Given a subgraph, find the output entries of a subgraph.
+ * \brief Given a subgraph, find the input entries of a subgraph.
  * \param g pointer to the whole graph
  * \param simple_nods vector of simple nodes in top sorted order
  * \param subgraph_nodes vector of pointers of simples of a subgraph.
diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
index d82f754..4f5bdcb 100644
--- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
+++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
@@ -74,7 +74,9 @@ std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
 
   auto trt_logger = std::unique_ptr<TRT_Logger>(new TRT_Logger(verbosity));
   auto trt_builder = InferObject(nvinfer1::createInferBuilder(*trt_logger));
-  auto trt_network = InferObject(trt_builder->createNetwork());
+  const auto explicitBatch = 1U << static_cast<uint32_t>(
+                             
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+  auto trt_network = InferObject(trt_builder->createNetworkV2(explicitBatch));
   auto trt_parser  = InferObject(nvonnxparser::createParser(*trt_network, 
*trt_logger));
   ::ONNX_NAMESPACE::ModelProto parsed_model;
   // We check for a valid parse, but the main effect is the side effect
diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index 16cc130..b35a171 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -267,6 +267,23 @@ class TensorrtProperty : public SubgraphProperty {
     return std::make_shared<TensorrtProperty>();
   }
 
+  void PrePartition(const nnvm::Graph& g,
+    const std::vector<std::pair<std::string, std::string>>& options_map) 
override {
+    auto& in_arg_names = g.GetAttr<std::vector<std::string>>("in_arg_names");
+    auto& in_aux_names = g.GetAttr<std::vector<std::string>>("in_aux_names");
+    NDArray **in_args_ptr = g.GetAttr<NDArray**>("in_args");
+    NDArray **in_aux_ptr = g.GetAttr<NDArray**>("in_aux");
+    in_args_dict.clear();
+    in_aux_dict.clear();
+    // we trust the Python API, len(in_arg_names) == len(in_args_ptr)
+    for (unsigned i = 0; i < in_arg_names.size(); ++i) {
+      in_args_dict[in_arg_names[i]] = in_args_ptr[i];
+    }
+    for (unsigned i = 0; i < in_aux_names.size(); ++i) {
+      in_aux_dict[in_aux_names[i]] = in_aux_ptr[i];
+    }
+  }
+
   nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
                                    const int subgraph_id) const override {
     nnvm::ObjectPtr n = nnvm::Node::Create();
@@ -280,16 +297,33 @@ class TensorrtProperty : public SubgraphProperty {
     n->attrs.op = Op::Get("_TensorRT");
     CHECK(n->attrs.op);
     n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
+
+    // Mapping subgraph params with NDArrays
+    TRTParam param;
     std::ostringstream params_oss;
-    for (auto &e : new_sym.ListInputNames(nnvm::Symbol::kAll)) {
-      params_oss << e << ";";
+    for (auto &param_name : new_sym.ListInputNames(nnvm::Symbol::kAll)) {
+      NDArray *cache;
+      auto it_args = in_args_dict.find(param_name);
+      if (it_args != in_args_dict.end()) {
+        cache = it_args->second;
+      } else {
+        auto it_aux = in_aux_dict.find(param_name);
+        if (it_aux != in_aux_dict.end()) {
+          cache = it_aux->second;
+        }
+      }
+      if (cache != nullptr) {
+        param.params_map.emplace(param_name, cache->Copy(Context()));
+        param.params_map[param_name].WaitToRead();
+        params_oss << param_name << ";";
+      }
     }
     auto tensorrt_params_names = params_oss.str();
-    tensorrt_params_names.pop_back();
-    n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
-    TRTParam param;
+    if (!tensorrt_params_names.empty()) {
+      tensorrt_params_names.pop_back();
+    }
     n->attrs.parsed = param;
-    n->op()->attr_parser(&(n->attrs));
+    n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
     return n;
   }
 
@@ -328,6 +362,8 @@ class TensorrtProperty : public SubgraphProperty {
     }
     subgraph_node->attrs.parsed = std::move(_params);
   }
+
+  std::unordered_map<std::string, NDArray*> in_args_dict, in_aux_dict;
 };
 
 
diff --git a/src/operator/subgraph/tensorrt/tensorrt.cu b/src/operator/subgraph/tensorrt/tensorrt.cu
index 4a5b23b..826f9a5 100644
--- a/src/operator/subgraph/tensorrt/tensorrt.cu
+++ b/src/operator/subgraph/tensorrt/tensorrt.cu
@@ -56,12 +56,12 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
       param.bindings->at(i) = outputs[p.first].dptr_;
     }
   }
-  const int batch_size = static_cast<int>(inputs[0].shape_[0]);
-  param.trt_executor->enqueue(batch_size, param.bindings->data(), cuda_s, 
nullptr);
+  param.trt_executor->enqueueV2(param.bindings->data(), cuda_s, nullptr);
 }
 
 NNVM_REGISTER_OP(_TensorRT)
-.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute);
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 }  // namespace op
 }  // namespace mxnet

Reply via email to