This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch v1.5.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.5.x by this push:
     new bd2b5a2  prevent TRT_Logger to be destroyed before TRT engine (#14898) (#15877)
bd2b5a2 is described below

commit bd2b5a28e775004c504c8154a90884bf9c06cd9d
Author: Kellen Sunderland <[email protected]>
AuthorDate: Fri Aug 16 08:20:32 2019 -0700

    prevent TRT_Logger to be destroyed before TRT engine (#14898) (#15877)
    
    * prevent TRT_Logger to be destroyed before TRT engine
    
    * use unique_ptr for trt_logger/parser/engine/executor ownership
    
    * reduce line length for lint
---
 src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc | 35 ++++--------
 src/operator/subgraph/tensorrt/onnx_to_tensorrt.h  | 66 ++++++++++++++--------
 src/operator/subgraph/tensorrt/tensorrt-inl.h      | 25 ++++----
 src/operator/subgraph/tensorrt/tensorrt.cc         |  4 +-
 4 files changed, 69 insertions(+), 61 deletions(-)
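
For reference, the ownership pattern the patch builds on is summarized below.
TensorRT objects are released through obj->destroy() rather than delete, so
std::unique_ptr needs a custom deleter; the names mirror the patch (this is a
condensed sketch of the helpers it moves into onnx_to_tensorrt.h, not new API):

    #include <memory>
    #include <stdexcept>

    // Deleter that releases TensorRT-style objects via destroy().
    struct InferDeleter {
      template <typename T>
      void operator()(T* obj) const {
        if (obj) {
          obj->destroy();
        }
      }
    };

    // Alias so TensorRT handles read like ordinary unique_ptrs.
    template <typename T>
    using unique_ptr = std::unique_ptr<T, InferDeleter>;

    // Wrap a freshly created object, failing loudly on nullptr.
    template <typename T>
    inline unique_ptr<T> InferObject(T* obj) {
      if (!obj) {
        throw std::runtime_error("Failed to create object");
      }
      return unique_ptr<T>(obj, InferDeleter());
    }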

diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
index 7dbc54b..27f6da4 100644
--- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
+++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
@@ -48,23 +48,6 @@ using std::endl;
 
 namespace onnx_to_tensorrt {
 
-struct InferDeleter {
-  template<typename T>
-    void operator()(T* obj) const {
-      if ( obj ) {
-        obj->destroy();
-      }
-    }
-};
-
-template<typename T>
-inline std::shared_ptr<T> InferObject(T* obj) {
-  if ( !obj ) {
-    throw std::runtime_error("Failed to create object");
-  }
-  return std::shared_ptr<T>(obj, InferDeleter());
-}
-
 std::string onnx_ir_version_string(int64_t ir_version = onnx::IR_VERSION) {
   int onnx_ir_major = ir_version / 1000000;
   int onnx_ir_minor = ir_version % 1000000 / 10000;
@@ -83,7 +66,9 @@ void PrintVersion() {
     << NV_TENSORRT_PATCH << endl;
 }
 
-std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*> onnxToTrtCtx(
+std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
+           unique_ptr<nvonnxparser::IParser>,
+           std::unique_ptr<TRT_Logger> > onnxToTrtCtx(
         const std::string& onnx_model,
         int32_t max_batch_size,
         size_t max_workspace_size,
@@ -91,10 +76,10 @@ std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*> onnxToTrtCtx(
         bool debug_builder) {
   GOOGLE_PROTOBUF_VERIFY_VERSION;
 
-  TRT_Logger trt_logger(verbosity);
-  auto trt_builder = InferObject(nvinfer1::createInferBuilder(trt_logger));
-  auto trt_network = InferObject(trt_builder->createNetwork());
-  auto trt_parser  = nvonnxparser::createParser(trt_network.get(), trt_logger);
+  auto trt_logger = std::unique_ptr<TRT_Logger>(new TRT_Logger(verbosity));
+  auto trt_builder = nvinfer1::createInferBuilder(*trt_logger);
+  auto trt_network = trt_builder->createNetwork();
+  auto trt_parser  = InferObject(nvonnxparser::createParser(trt_network, *trt_logger));
   ::ONNX_NAMESPACE::ModelProto parsed_model;
   // We check for a valid parse, but the main effect is the side effect
   // of populating parsed_model
@@ -139,8 +124,10 @@ std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*> onnxToTrtCtx(
   trt_builder->setMaxBatchSize(max_batch_size);
   trt_builder->setMaxWorkspaceSize(max_workspace_size);
   trt_builder->setDebugSync(debug_builder);
-  nvinfer1::ICudaEngine* trt_engine = trt_builder->buildCudaEngine(*trt_network.get());
-  return std::make_tuple(trt_engine, trt_parser);
+  auto trt_engine = InferObject(trt_builder->buildCudaEngine(*trt_network));
+  trt_builder->destroy();
+  trt_network->destroy();
+  return std::make_tuple(std::move(trt_engine), std::move(trt_parser), std::move(trt_logger));
 }
 
 }  // namespace onnx_to_tensorrt
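
The behavioral change above is worth spelling out: onnxToTrtCtx used to keep
TRT_Logger as a stack local, which died when the function returned even though
the engine TensorRT had wired to it lived on. The logger is now heap-allocated
and returned in the tuple. A hedged caller sketch (a fragment, not a full
program; onnx_model_str is a placeholder for a serialized model, and the sizes
shown are the defaults from the header):

    auto trt_tuple = onnx_to_tensorrt::onnxToTrtCtx(onnx_model_str,
                                                    /*max_batch_size=*/32,
                                                    /*max_workspace_size=*/1L << 30);
    auto trt_engine = std::move(std::get<0>(trt_tuple));
    auto trt_parser = std::move(std::get<1>(trt_tuple));
    auto trt_logger = std::move(std::get<2>(trt_tuple));
    // trt_logger must now outlive trt_engine, and because the caller owns
    // both, that ordering is in its hands rather than lost on return.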
diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h
index 3e8ea1b..b89422f 100644
--- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h
+++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h
@@ -32,6 +32,7 @@
 #include <NvInfer.h>
 
 #include <fstream>
+#include <memory>
 #include <iostream>
 #include <sstream>
 #include <string>
@@ -40,33 +41,51 @@
 
 namespace onnx_to_tensorrt {
 
+struct InferDeleter {
+  template<typename T>
+    void operator()(T* obj) const {
+      if ( obj ) {
+        obj->destroy();
+      }
+    }
+};
+
+template<typename T>
+using unique_ptr = std::unique_ptr<T, InferDeleter>;
+
+template<typename T>
+inline unique_ptr<T> InferObject(T* obj) {
+  if ( !obj ) {
+    throw std::runtime_error("Failed to create object");
+  }
+  return unique_ptr<T>(obj, InferDeleter());
+}
+
 class TRT_Logger : public nvinfer1::ILogger {
-        nvinfer1::ILogger::Severity _verbosity;
-        std::ostream* _ostream;
+  nvinfer1::ILogger::Severity _verbosity;
+  std::ostream* _ostream;
  public:
-        TRT_Logger(Severity verbosity = Severity::kWARNING,
-                   std::ostream& ostream = std::cout)
-                : _verbosity(verbosity), _ostream(&ostream) {}
-        void log(Severity severity, const char* msg) override {
-                if ( severity <= _verbosity ) {
-                        time_t rawtime = std::time(0);
-                        char buf[256];
-                        strftime(&buf[0], 256,
-                                 "%Y-%m-%d %H:%M:%S",
-                                 std::gmtime(&rawtime));
-                        const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? "    BUG" :
-                                              severity == Severity::kERROR          ? "  ERROR" :
-                                              severity == Severity::kWARNING        ? "WARNING" :
-                                              severity == Severity::kINFO           ? "   INFO" :
-                                              "UNKNOWN");
-                        (*_ostream) << "[" << buf << " " << sevstr << "] "
-                                    << msg
-                                    << std::endl;
-                }
-        }
+  TRT_Logger(Severity verbosity = Severity::kWARNING,
+             std::ostream& ostream = std::cout) :
+               _verbosity(verbosity), _ostream(&ostream) {}
+  void log(Severity severity, const char* msg) override {
+    if (severity <= _verbosity) {
+      time_t rawtime = std::time(0);
+      char buf[256];
+      strftime(&buf[0], 256, "%Y-%m-%d %H:%M:%S", std::gmtime(&rawtime));
+      const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? "    BUG" :
+                            severity == Severity::kERROR          ? "  ERROR" :
+                            severity == Severity::kWARNING        ? "WARNING" :
+                            severity == Severity::kINFO           ? "   INFO" :
+                            "UNKNOWN");
+      (*_ostream) << "[" << buf << " " << sevstr << "] " << msg << std::endl;
+    }
+  }
 };
 
-std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*> onnxToTrtCtx(
+std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
+           unique_ptr<nvonnxparser::IParser>,
+           std::unique_ptr<TRT_Logger> > onnxToTrtCtx(
         const std::string& onnx_model,
         int32_t max_batch_size = 32,
         size_t max_workspace_size = 1L << 30,
@@ -75,5 +94,4 @@ std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*> onnxToTrtCtx(
 }  // namespace onnx_to_tensorrt
 
 #endif  // MXNET_USE_TENSORRT
-
 #endif  // MXNET_OPERATOR_SUBGRAPH_TENSORRT_ONNX_TO_TENSORRT_H_
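
The header is also where the original bug (#14898) is easiest to see:
nvinfer1::createInferBuilder takes the ILogger by reference and TensorRT keeps
a pointer to it, so a logger with automatic storage in a factory function
dangles as soon as that function returns. A minimal mock of the hazard (no
TensorRT dependency; Logger, MockEngine, and make_engine are illustrative only):

    #include <iostream>

    struct Logger {                       // stands in for TRT_Logger
      void log(const char* msg) { std::cout << msg << '\n'; }
    };

    struct MockEngine {                   // stands in for ICudaEngine
      Logger* logger;                     // TensorRT-style stored pointer
      void work() { logger->log("engine event"); }
    };

    MockEngine make_engine() {
      Logger local_logger;                // the pre-fix pattern: stack local
      return MockEngine{&local_logger};   // pointer dangles on return
    }

    int main() {
      MockEngine e = make_engine();
      e.work();  // undefined behavior: logs through a destroyed logger
    }

Heap-allocating the logger and handing its unique_ptr back to the caller, as
the patch does, removes exactly this failure mode.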
diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index e258d89..c175ac4 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -51,10 +51,14 @@ struct TRTParam {
 };
 
 struct TRTEngineParam {
-  TRTEngineParam(nvinfer1::ICudaEngine* trt_engine,
-                 nvonnxparser::IParser* _parser,
-                 const std::unordered_map<std::string, uint32_t> input_map,
-                 const std::unordered_map<std::string, uint32_t> output_map) {
+  TRTEngineParam(onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine> _trt_engine,
+                 onnx_to_tensorrt::unique_ptr<nvonnxparser::IParser> _trt_parser,
+                 std::unique_ptr<onnx_to_tensorrt::TRT_Logger> _trt_logger,
+                 const std::unordered_map<std::string, uint32_t>& input_map,
+                 const std::unordered_map<std::string, uint32_t>& output_map) {
+    trt_engine = std::move(_trt_engine);
+    trt_logger = std::move(_trt_logger);
+    trt_parser = std::move(_trt_parser);
     binding_order = std::make_shared<std::vector<std::pair<uint32_t, bool> > >();
     bindings = std::make_shared<std::vector<void*> >();
     binding_order->reserve(trt_engine->getNbBindings());
@@ -67,16 +71,13 @@ struct TRTEngineParam {
         binding_order->emplace_back(output_map.at(binding_name), false);
       }
     }
-    trt_executor = trt_engine->createExecutionContext();
-    trt_parser = _parser;
+    trt_executor = onnx_to_tensorrt::InferObject(trt_engine->createExecutionContext());
   }
 
-  ~TRTEngineParam() {
-    trt_parser->destroy();
-    trt_executor->destroy();
-  }
-  nvinfer1::IExecutionContext* trt_executor;
-  nvonnxparser::IParser* trt_parser;
+  onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
+  onnx_to_tensorrt::unique_ptr<nvinfer1::IExecutionContext> trt_executor;
+  onnx_to_tensorrt::unique_ptr<nvonnxparser::IParser> trt_parser;
+  std::unique_ptr<onnx_to_tensorrt::TRT_Logger> trt_logger;
   std::shared_ptr<std::vector<std::pair<uint32_t, bool> > > binding_order;
   std::shared_ptr<std::vector<void*> > bindings;
 };
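
With every TensorRT handle held in a unique_ptr, the hand-written
~TRTEngineParam() above becomes unnecessary: the implicitly generated
destructor releases each member through InferDeleter. A self-contained mock of
that RAII behavior (MockParser and EngineParam are illustrative stand-ins):

    #include <cstdio>
    #include <memory>

    // Stand-in for a TensorRT object that must be released via destroy().
    struct MockParser {
      void destroy() {
        std::puts("parser released");
        delete this;
      }
    };

    struct InferDeleter {
      template <typename T>
      void operator()(T* obj) const {
        if (obj) obj->destroy();
      }
    };

    struct EngineParam {  // analogous to the new TRTEngineParam
      std::unique_ptr<MockParser, InferDeleter> parser{new MockParser()};
      // No manual destructor: the deleter calls destroy() automatically.
    };

    int main() {
      EngineParam p;
    }  // prints "parser released" as p's members are torn down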
diff --git a/src/operator/subgraph/tensorrt/tensorrt.cc b/src/operator/subgraph/tensorrt/tensorrt.cc
index 7652510..71b2981 100644
--- a/src/operator/subgraph/tensorrt/tensorrt.cc
+++ b/src/operator/subgraph/tensorrt/tensorrt.cc
@@ -312,7 +312,9 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx,
   graph.attrs["shape"]        = std::make_shared<nnvm::any>(std::move(shapes));
   auto onnx_graph = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(graph, &params_map);
   auto trt_tuple = ::onnx_to_tensorrt::onnxToTrtCtx(onnx_graph, max_batch_size, 1 << 30);
-  return OpStatePtr::Create<TRTEngineParam>(std::get<0>(trt_tuple), std::get<1>(trt_tuple),
+  return OpStatePtr::Create<TRTEngineParam>(std::move(std::get<0>(trt_tuple)),
+                                            std::move(std::get<1>(trt_tuple)),
+                                            std::move(std::get<2>(trt_tuple)),
                                             inputs_to_idx, outputs_to_idx);
 }
 
