This is an automated email from the ASF dual-hosted git repository.
taolv pushed a commit to branch v1.5.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.5.x by this push:
new bd2b5a2 prevent TRT_Logger to be destroyed before TRT engine (#14898)
(#15877)
bd2b5a2 is described below
commit bd2b5a28e775004c504c8154a90884bf9c06cd9d
Author: Kellen Sunderland <[email protected]>
AuthorDate: Fri Aug 16 08:20:32 2019 -0700
prevent TRT_Logger to be destroyed before TRT engine (#14898) (#15877)
* prevent TRT_Logger to be destroyed before TRT engine
* use unique_ptr for trt_logger/parser/engine/executor ownership
* reduce line length for lint
---
src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc | 35 ++++--------
src/operator/subgraph/tensorrt/onnx_to_tensorrt.h | 66 ++++++++++++++--------
src/operator/subgraph/tensorrt/tensorrt-inl.h | 25 ++++----
src/operator/subgraph/tensorrt/tensorrt.cc | 4 +-
4 files changed, 69 insertions(+), 61 deletions(-)
diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
index 7dbc54b..27f6da4 100644
--- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
+++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
@@ -48,23 +48,6 @@ using std::endl;
namespace onnx_to_tensorrt {
-struct InferDeleter {
- template<typename T>
- void operator()(T* obj) const {
- if ( obj ) {
- obj->destroy();
- }
- }
-};
-
-template<typename T>
-inline std::shared_ptr<T> InferObject(T* obj) {
- if ( !obj ) {
- throw std::runtime_error("Failed to create object");
- }
- return std::shared_ptr<T>(obj, InferDeleter());
-}
-
std::string onnx_ir_version_string(int64_t ir_version = onnx::IR_VERSION) {
int onnx_ir_major = ir_version / 1000000;
int onnx_ir_minor = ir_version % 1000000 / 10000;
@@ -83,7 +66,9 @@ void PrintVersion() {
<< NV_TENSORRT_PATCH << endl;
}
-std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*> onnxToTrtCtx(
+std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
+ unique_ptr<nvonnxparser::IParser>,
+ std::unique_ptr<TRT_Logger> > onnxToTrtCtx(
const std::string& onnx_model,
int32_t max_batch_size,
size_t max_workspace_size,
@@ -91,10 +76,10 @@ std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*>
onnxToTrtCtx(
bool debug_builder) {
GOOGLE_PROTOBUF_VERIFY_VERSION;
- TRT_Logger trt_logger(verbosity);
- auto trt_builder = InferObject(nvinfer1::createInferBuilder(trt_logger));
- auto trt_network = InferObject(trt_builder->createNetwork());
- auto trt_parser = nvonnxparser::createParser(trt_network.get(), trt_logger);
+ auto trt_logger = std::unique_ptr<TRT_Logger>(new TRT_Logger(verbosity));
+ auto trt_builder = nvinfer1::createInferBuilder(*trt_logger);
+ auto trt_network = trt_builder->createNetwork();
+ auto trt_parser = InferObject(nvonnxparser::createParser(trt_network,
*trt_logger));
::ONNX_NAMESPACE::ModelProto parsed_model;
// We check for a valid parse, but the main effect is the side effect
// of populating parsed_model
@@ -139,8 +124,10 @@ std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*>
onnxToTrtCtx(
trt_builder->setMaxBatchSize(max_batch_size);
trt_builder->setMaxWorkspaceSize(max_workspace_size);
trt_builder->setDebugSync(debug_builder);
- nvinfer1::ICudaEngine* trt_engine =
trt_builder->buildCudaEngine(*trt_network.get());
- return std::make_tuple(trt_engine, trt_parser);
+ auto trt_engine = InferObject(trt_builder->buildCudaEngine(*trt_network));
+ trt_builder->destroy();
+ trt_network->destroy();
+ return std::make_tuple(std::move(trt_engine), std::move(trt_parser),
std::move(trt_logger));
}
} // namespace onnx_to_tensorrt
diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h
b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h
index 3e8ea1b..b89422f 100644
--- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h
+++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h
@@ -32,6 +32,7 @@
#include <NvInfer.h>
#include <fstream>
+#include <memory>
#include <iostream>
#include <sstream>
#include <string>
@@ -40,33 +41,51 @@
namespace onnx_to_tensorrt {
+struct InferDeleter {
+ template<typename T>
+ void operator()(T* obj) const {
+ if ( obj ) {
+ obj->destroy();
+ }
+ }
+};
+
+template<typename T>
+using unique_ptr = std::unique_ptr<T, InferDeleter>;
+
+template<typename T>
+inline unique_ptr<T> InferObject(T* obj) {
+ if ( !obj ) {
+ throw std::runtime_error("Failed to create object");
+ }
+ return unique_ptr<T>(obj, InferDeleter());
+}
+
class TRT_Logger : public nvinfer1::ILogger {
- nvinfer1::ILogger::Severity _verbosity;
- std::ostream* _ostream;
+ nvinfer1::ILogger::Severity _verbosity;
+ std::ostream* _ostream;
public:
- TRT_Logger(Severity verbosity = Severity::kWARNING,
- std::ostream& ostream = std::cout)
- : _verbosity(verbosity), _ostream(&ostream) {}
- void log(Severity severity, const char* msg) override {
- if ( severity <= _verbosity ) {
- time_t rawtime = std::time(0);
- char buf[256];
- strftime(&buf[0], 256,
- "%Y-%m-%d %H:%M:%S",
- std::gmtime(&rawtime));
- const char* sevstr = (severity ==
Severity::kINTERNAL_ERROR ? " BUG" :
- severity == Severity::kERROR
? " ERROR" :
- severity == Severity::kWARNING
? "WARNING" :
- severity == Severity::kINFO
? " INFO" :
- "UNKNOWN");
- (*_ostream) << "[" << buf << " " << sevstr << "] "
- << msg
- << std::endl;
- }
- }
+ TRT_Logger(Severity verbosity = Severity::kWARNING,
+ std::ostream& ostream = std::cout) :
+ _verbosity(verbosity), _ostream(&ostream) {}
+ void log(Severity severity, const char* msg) override {
+ if (severity <= _verbosity) {
+ time_t rawtime = std::time(0);
+ char buf[256];
+ strftime(&buf[0], 256, "%Y-%m-%d %H:%M:%S", std::gmtime(&rawtime));
+ const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? " BUG" :
+ severity == Severity::kERROR ? " ERROR" :
+ severity == Severity::kWARNING ? "WARNING" :
+ severity == Severity::kINFO ? " INFO" :
+ "UNKNOWN");
+ (*_ostream) << "[" << buf << " " << sevstr << "] " << msg << std::endl;
+ }
+ }
};
-std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*> onnxToTrtCtx(
+std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
+ unique_ptr<nvonnxparser::IParser>,
+ std::unique_ptr<TRT_Logger> > onnxToTrtCtx(
const std::string& onnx_model,
int32_t max_batch_size = 32,
size_t max_workspace_size = 1L << 30,
@@ -75,5 +94,4 @@ std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*>
onnxToTrtCtx(
} // namespace onnx_to_tensorrt
#endif // MXNET_USE_TENSORRT
-
#endif // MXNET_OPERATOR_SUBGRAPH_TENSORRT_ONNX_TO_TENSORRT_H_
diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h
b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index e258d89..c175ac4 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -51,10 +51,14 @@ struct TRTParam {
};
struct TRTEngineParam {
- TRTEngineParam(nvinfer1::ICudaEngine* trt_engine,
- nvonnxparser::IParser* _parser,
- const std::unordered_map<std::string, uint32_t> input_map,
- const std::unordered_map<std::string, uint32_t> output_map) {
+ TRTEngineParam(onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine>
_trt_engine,
+ onnx_to_tensorrt::unique_ptr<nvonnxparser::IParser>
_trt_parser,
+ std::unique_ptr<onnx_to_tensorrt::TRT_Logger> _trt_logger,
+ const std::unordered_map<std::string, uint32_t>& input_map,
+ const std::unordered_map<std::string, uint32_t>& output_map) {
+ trt_engine = std::move(_trt_engine);
+ trt_logger = std::move(_trt_logger);
+ trt_parser = std::move(_trt_parser);
binding_order = std::make_shared<std::vector<std::pair<uint32_t, bool> >
>();
bindings = std::make_shared<std::vector<void*> >();
binding_order->reserve(trt_engine->getNbBindings());
@@ -67,16 +71,13 @@ struct TRTEngineParam {
binding_order->emplace_back(output_map.at(binding_name), false);
}
}
- trt_executor = trt_engine->createExecutionContext();
- trt_parser = _parser;
+ trt_executor =
onnx_to_tensorrt::InferObject(trt_engine->createExecutionContext());
}
- ~TRTEngineParam() {
- trt_parser->destroy();
- trt_executor->destroy();
- }
- nvinfer1::IExecutionContext* trt_executor;
- nvonnxparser::IParser* trt_parser;
+ onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
+ onnx_to_tensorrt::unique_ptr<nvinfer1::IExecutionContext> trt_executor;
+ onnx_to_tensorrt::unique_ptr<nvonnxparser::IParser> trt_parser;
+ std::unique_ptr<onnx_to_tensorrt::TRT_Logger> trt_logger;
std::shared_ptr<std::vector<std::pair<uint32_t, bool> > > binding_order;
std::shared_ptr<std::vector<void*> > bindings;
};
diff --git a/src/operator/subgraph/tensorrt/tensorrt.cc
b/src/operator/subgraph/tensorrt/tensorrt.cc
index 7652510..71b2981 100644
--- a/src/operator/subgraph/tensorrt/tensorrt.cc
+++ b/src/operator/subgraph/tensorrt/tensorrt.cc
@@ -312,7 +312,9 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs,
Context ctx,
graph.attrs["shape"] = std::make_shared<nnvm::any>(std::move(shapes));
auto onnx_graph = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(graph,
&params_map);
auto trt_tuple = ::onnx_to_tensorrt::onnxToTrtCtx(onnx_graph,
max_batch_size, 1 << 30);
- return OpStatePtr::Create<TRTEngineParam>(std::get<0>(trt_tuple),
std::get<1>(trt_tuple),
+ return OpStatePtr::Create<TRTEngineParam>(std::move(std::get<0>(trt_tuple)),
+ std::move(std::get<1>(trt_tuple)),
+ std::move(std::get<2>(trt_tuple)),
inputs_to_idx, outputs_to_idx);
}