This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3df0917  [MXNET-703] Minor refactor of TensorRT code (#13311)
3df0917 is described below

commit 3df091705587eb83991dd9346a230085b585a2cc
Author: Kellen Sunderland <[email protected]>
AuthorDate: Fri Jan 25 16:59:21 2019 -0800

    [MXNET-703] Minor refactor of TensorRT code (#13311)
---
 src/executor/onnx_to_tensorrt.cc        |  4 ++--
 src/executor/trt_graph_executor.cc      |  7 +++----
 src/operator/contrib/nnvm_to_onnx-inl.h | 14 +++++++-------
 src/operator/contrib/nnvm_to_onnx.cc    |  4 ++--
 4 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/src/executor/onnx_to_tensorrt.cc b/src/executor/onnx_to_tensorrt.cc
index c37b856..f7fbc8f 100644
--- a/src/executor/onnx_to_tensorrt.cc
+++ b/src/executor/onnx_to_tensorrt.cc
@@ -100,8 +100,8 @@ nvinfer1::ICudaEngine* onnxToTrtCtx(
   }
 
   if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) {
-      int nerror = trt_parser->getNbErrors();
-      for ( int i=0; i < nerror; ++i ) {
+      size_t nerror = trt_parser->getNbErrors();
+      for ( size_t i=0; i < nerror; ++i ) {
         nvonnxparser::IParserError const* error = trt_parser->getError(i);
         if ( error->node() != -1 ) {
           ::ONNX_NAMESPACE::NodeProto const& node =
diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc
index 92bdcab..85ce168 100644
--- a/src/executor/trt_graph_executor.cc
+++ b/src/executor/trt_graph_executor.cc
@@ -133,7 +133,7 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol,
   }
 
   auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer);
-  for (auto trt_group : trt_groups) {
+  for (const auto &trt_group : trt_groups) {
     if (trt_group.size() > 1) {
       g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer);
       g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, 
arg_grad_ctxes,
@@ -142,7 +142,6 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol,
     }
   }
 
-
   InitArguments(g.indexed_graph(), g.GetAttr<nnvm::ShapeVector>("shape"),
                 g.GetAttr<nnvm::DTypeVector>("dtype"),
                 g.GetAttr<StorageTypeVector>("storage_type"),
@@ -188,7 +187,7 @@ void TrtGraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
     const uint32_t eid = idx.entry_id(nid, 0);
     const TShape& inferred_shape = inferred_shapes[eid];
     const int inferred_dtype = inferred_dtypes[eid];
-    const NDArrayStorageType inferred_stype = (NDArrayStorageType) 
inferred_stypes[eid];
+    const auto inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
     const std::string& arg_name = idx[nid].source->attrs.name;
     // aux_states
     if (mutable_nodes.count(nid)) {
@@ -427,7 +426,7 @@ Executor *TrtGraphExecutor::TensorRTBind(nnvm::Symbol symbol,
                                          std::unordered_map<std::string, 
NDArray> *shared_buffer,
                                          Executor *shared_exec) {
   auto exec = new exec::TrtGraphExecutor();
-  exec->Init(symbol, default_ctx, group2ctx,
+  exec->Init(std::move(symbol), default_ctx, group2ctx,
              in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
              arg_shape_map, arg_dtype_map, arg_stype_map,
              grad_req_types, param_names,
diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h
index e0c4d93..0994f7e 100644
--- a/src/operator/contrib/nnvm_to_onnx-inl.h
+++ b/src/operator/contrib/nnvm_to_onnx-inl.h
@@ -70,7 +70,7 @@ struct ONNXParam : public dmlc::Parameter<ONNXParam> {
   nnvm_to_onnx::InferenceMap_t output_map;
   ::onnx::ModelProto onnx_pb_graph;
 
-  ONNXParam() {}
+  ONNXParam() = default;
 
   ONNXParam(const ::onnx::ModelProto& onnx_graph,
            const nnvm_to_onnx::InferenceMap_t& input_map,
@@ -104,14 +104,14 @@ std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGra
 void ConvertPlaceholder(
   const std::string& node_name,
   const std::unordered_map<std::string, TShape>& placeholder_shapes,
-  GraphProto* const graph_proto);
+  GraphProto* graph_proto);
 
-void ConvertConstant(GraphProto* const graph_proto,
+void ConvertConstant(GraphProto* graph_proto,
   const std::string& node_name,
-  std::unordered_map<std::string, NDArray>* const shared_buffer);
+  std::unordered_map<std::string, NDArray>* shared_buffer);
 
-void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* const trt_output_map,
-                   GraphProto* const graph_proto,
+void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* trt_output_map,
+                   GraphProto* graph_proto,
                    const std::unordered_map<std::string, uint32_t>::iterator& 
out_iter,
                    const std::string& node_name,
                    const nnvm::Graph& g,
@@ -169,7 +169,7 @@ void ConvertElementwiseAdd(NodeProto *node_proto,
 
 ONNXParam ConvertNnvmGraphToOnnx(
     const nnvm::Graph &g,
-    std::unordered_map<std::string, NDArray> *const shared_buffer);
+    std::unordered_map<std::string, NDArray>* shared_buffer);
 
 static const std::unordered_map<std::string, ConverterFunction> converter_map 
= {
   {"Convolution", ConvertConvolution},
diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc
index ccb6e04..58a4654 100644
--- a/src/operator/contrib/nnvm_to_onnx.cc
+++ b/src/operator/contrib/nnvm_to_onnx.cc
@@ -263,7 +263,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
   AttributeProto* const kernel_shape = node_proto->add_attribute();
   kernel_shape->set_name("kernel_shape");
   kernel_shape->set_type(AttributeProto::INTS);
-  for (int kval : kernel) {
+  for (dim_t kval : kernel) {
     kernel_shape->add_ints(static_cast<int64>(kval));
   }
 
@@ -283,7 +283,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
   AttributeProto* const strides = node_proto->add_attribute();
   strides->set_name("strides");
   strides->set_type(AttributeProto::INTS);
-  for (int kval : stride) {
+  for (dim_t kval : stride) {
     strides->add_ints(static_cast<int64>(kval));
   }
 

Reply via email to