This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7b7cef5 MXNet-TRT: Add PrePartition param caching - move
init_tensorrt_params logic (#18490)
7b7cef5 is described below
commit 7b7cef5e0ba4a95fdfed55e6c8ba1aff59e6fc91
Author: Serge Panev <[email protected]>
AuthorDate: Tue Aug 4 18:23:48 2020 -0700
MXNet-TRT: Add PrePartition param caching - move init_tensorrt_params
logic (#18490)
* Update to TRT 7 API
Signed-off-by: Serge Panev <[email protected]>
* Add PrePartition param caching - move init_tensorrt_params logic
Signed-off-by: Serge Panev <[email protected]>
* Handle node with no defined input
Signed-off-by: Serge Panev <[email protected]>
* Remove tmp comment
Signed-off-by: Serge Panev <[email protected]>
---
src/operator/subgraph/build_subgraph.cc | 2 +-
src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc | 4 +-
src/operator/subgraph/tensorrt/tensorrt-inl.h | 48 +++++++++++++++++++---
src/operator/subgraph/tensorrt/tensorrt.cu | 6 +--
4 files changed, 49 insertions(+), 11 deletions(-)
diff --git a/src/operator/subgraph/build_subgraph.cc
b/src/operator/subgraph/build_subgraph.cc
index 93cb174..38038f2 100644
--- a/src/operator/subgraph/build_subgraph.cc
+++ b/src/operator/subgraph/build_subgraph.cc
@@ -430,7 +430,7 @@ void SortEntries(const std::unordered_map<const
nnvm::NodeEntry*, size_t>& entry
}
/*!
- * \brief Given a subgraph, find the output entries of a subgraph.
+ * \brief Given a subgraph, find the input entries of a subgraph.
* \param g pointer to the whole graph
* \param simple_nods vector of simple nodes in top sorted order
* \param subgraph_nodes vector of pointers of simples of a subgraph.
diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
index d82f754..4f5bdcb 100644
--- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
+++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
@@ -74,7 +74,9 @@ std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
auto trt_logger = std::unique_ptr<TRT_Logger>(new TRT_Logger(verbosity));
auto trt_builder = InferObject(nvinfer1::createInferBuilder(*trt_logger));
- auto trt_network = InferObject(trt_builder->createNetwork());
+ const auto explicitBatch = 1U << static_cast<uint32_t>(
+
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+ auto trt_network = InferObject(trt_builder->createNetworkV2(explicitBatch));
auto trt_parser = InferObject(nvonnxparser::createParser(*trt_network,
*trt_logger));
::ONNX_NAMESPACE::ModelProto parsed_model;
// We check for a valid parse, but the main effect is the side effect
diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h
b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index 16cc130..b35a171 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -267,6 +267,23 @@ class TensorrtProperty : public SubgraphProperty {
return std::make_shared<TensorrtProperty>();
}
+ void PrePartition(const nnvm::Graph& g,
+ const std::vector<std::pair<std::string, std::string>>& options_map)
override {
+ auto& in_arg_names = g.GetAttr<std::vector<std::string>>("in_arg_names");
+ auto& in_aux_names = g.GetAttr<std::vector<std::string>>("in_aux_names");
+ NDArray **in_args_ptr = g.GetAttr<NDArray**>("in_args");
+ NDArray **in_aux_ptr = g.GetAttr<NDArray**>("in_aux");
+ in_args_dict.clear();
+ in_aux_dict.clear();
+ // we trust the Python API, len(in_arg_names) == len(in_args_ptr)
+ for (unsigned i = 0; i < in_arg_names.size(); ++i) {
+ in_args_dict[in_arg_names[i]] = in_args_ptr[i];
+ }
+ for (unsigned i = 0; i < in_aux_names.size(); ++i) {
+ in_aux_dict[in_aux_names[i]] = in_aux_ptr[i];
+ }
+ }
+
nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
const int subgraph_id) const override {
nnvm::ObjectPtr n = nnvm::Node::Create();
@@ -280,16 +297,33 @@ class TensorrtProperty : public SubgraphProperty {
n->attrs.op = Op::Get("_TensorRT");
CHECK(n->attrs.op);
n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
+
+ // Mapping subgraph params with NDArrays
+ TRTParam param;
std::ostringstream params_oss;
- for (auto &e : new_sym.ListInputNames(nnvm::Symbol::kAll)) {
- params_oss << e << ";";
+ for (auto &param_name : new_sym.ListInputNames(nnvm::Symbol::kAll)) {
+ NDArray *cache;
+ auto it_args = in_args_dict.find(param_name);
+ if (it_args != in_args_dict.end()) {
+ cache = it_args->second;
+ } else {
+ auto it_aux = in_aux_dict.find(param_name);
+ if (it_aux != in_aux_dict.end()) {
+ cache = it_aux->second;
+ }
+ }
+ if (cache != nullptr) {
+ param.params_map.emplace(param_name, cache->Copy(Context()));
+ param.params_map[param_name].WaitToRead();
+ params_oss << param_name << ";";
+ }
}
auto tensorrt_params_names = params_oss.str();
- tensorrt_params_names.pop_back();
- n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
- TRTParam param;
+ if (!tensorrt_params_names.empty()) {
+ tensorrt_params_names.pop_back();
+ }
n->attrs.parsed = param;
- n->op()->attr_parser(&(n->attrs));
+ n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
return n;
}
@@ -328,6 +362,8 @@ class TensorrtProperty : public SubgraphProperty {
}
subgraph_node->attrs.parsed = std::move(_params);
}
+
+ std::unordered_map<std::string, NDArray*> in_args_dict, in_aux_dict;
};
diff --git a/src/operator/subgraph/tensorrt/tensorrt.cu
b/src/operator/subgraph/tensorrt/tensorrt.cu
index 4a5b23b..826f9a5 100644
--- a/src/operator/subgraph/tensorrt/tensorrt.cu
+++ b/src/operator/subgraph/tensorrt/tensorrt.cu
@@ -56,12 +56,12 @@ void TRTCompute(const OpStatePtr& state, const OpContext&
ctx,
param.bindings->at(i) = outputs[p.first].dptr_;
}
}
- const int batch_size = static_cast<int>(inputs[0].shape_[0]);
- param.trt_executor->enqueue(batch_size, param.bindings->data(), cuda_s,
nullptr);
+ param.trt_executor->enqueueV2(param.bindings->data(), cuda_s, nullptr);
}
NNVM_REGISTER_OP(_TensorRT)
-.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute);
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
} // namespace op
} // namespace mxnet