This is an automated email from the ASF dual-hosted git repository.

kellen pushed a commit to branch v1.5.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.5.x by this push:
     new afe4a71  add deconv in TRT subgraph (#15666) (#16043)
afe4a71 is described below

commit afe4a713210749f8fe300f023ecc2678fbee502d
Author: Kellen Sunderland <[email protected]>
AuthorDate: Sat Aug 31 09:24:05 2019 -0700

    add deconv in TRT subgraph (#15666) (#16043)
---
 src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h | 19 ++++++-
 src/operator/subgraph/tensorrt/nnvm_to_onnx.cc    | 46 ++++++++++++-----
 src/operator/subgraph/tensorrt/tensorrt-inl.h     |  2 +
 tests/python/tensorrt/test_tensorrt_deconv.py     | 63 +++++++++++++++++++++++
 4 files changed, 116 insertions(+), 14 deletions(-)

diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
index 55b3d93..5a433f1 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
@@ -41,6 +41,8 @@ namespace mxnet {
 namespace op {
 namespace nnvm_to_onnx {
 
+enum ConvDeconvType {Convolution, Deconvolution};
+
 using namespace nnvm;
 using namespace ::onnx;
 using int64 = ::google::protobuf::int64;
@@ -48,8 +50,7 @@ using int64 = ::google::protobuf::int64;
 std::unordered_map<std::string, mxnet::TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs,
     const nnvm::IndexedGraph& ig);
 
-std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector&
-dtype_inputs,
+std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector& dtype_inputs,
     const nnvm::IndexedGraph& ig);
 
 std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGraph& ig);
@@ -74,12 +75,25 @@ typedef void (*ConverterFunction)(NodeProto *node_proto,
                                   const nnvm::IndexedGraph &ig,
                                   const array_view<IndexedGraph::NodeEntry> &inputs);
 
+template <class ConvDeconvParam>
+void ConvDeconvConvertHelper(NodeProto *node_proto,
+                             const NodeAttrs &attrs,
+                             const nnvm::IndexedGraph &ig,
+                             const array_view<IndexedGraph::NodeEntry> &inputs,
+                             const ConvDeconvParam& param,
+                             ConvDeconvType type);
+
 // Forward declarations
 void ConvertConvolution(NodeProto *node_proto,
                         const NodeAttrs &attrs,
                         const nnvm::IndexedGraph &ig,
                         const array_view<IndexedGraph::NodeEntry> &inputs);
 
+void ConvertDeconvolution(NodeProto *node_proto,
+                        const NodeAttrs &attrs,
+                        const nnvm::IndexedGraph &ig,
+                        const array_view<IndexedGraph::NodeEntry> &inputs);
+
 void ConvertPooling(NodeProto *node_proto,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
@@ -158,6 +172,7 @@ static const std::unordered_map<std::string, ConverterFunction> converter_map =
   {"BatchNorm", ConvertBatchNorm},
   {"clip", ConvertClip},
   {"Convolution", ConvertConvolution},
+  {"Deconvolution", ConvertDeconvolution},
   {"Concat", ConvertConcatenate},
   {"Dropout", ConvertDropout},
   {"elemwise_add", ConvertElementwiseAdd},
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
index 6116f29..84580d0 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
@@ -31,6 +31,7 @@
 #include <mxnet/base.h>
 #include <nnvm/graph.h>
 #include <nnvm/pass_functions.h>
+#include <operator/nn/deconvolution-inl.h>
 
 #include "../../../common/utils.h"
 #include "../../../ndarray/ndarray_function.h"
@@ -170,20 +171,25 @@ std::string ConvertNnvmGraphToOnnx(
   return serialized_onnx_graph;
 }
 
-void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
-                        const nnvm::IndexedGraph& /*ig*/,
-                        const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
-  const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
-
-  node_proto->set_op_type("Conv");
+template <class ConvDeconvParam>
+void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
+                             const nnvm::IndexedGraph& /*ig*/,
+                             const array_view<IndexedGraph::NodeEntry>& /*input*/,
+                             const ConvDeconvParam& param,
+                             ConvDeconvType type) {
+  if (type == ConvDeconvType::Convolution) {
+    node_proto->set_op_type("Conv");
+  } else {
+    node_proto->set_op_type("ConvTranspose");
+  }
 
-  const mxnet::TShape kernel = conv_param.kernel;
-  const mxnet::TShape stride = conv_param.stride;
-  const mxnet::TShape dilate = conv_param.dilate;
-  const mxnet::TShape pad = conv_param.pad;
-  const uint32_t num_group = conv_param.num_group;
+  const mxnet::TShape kernel = param.kernel;
+  const mxnet::TShape stride = param.stride;
+  const mxnet::TShape dilate = param.dilate;
+  const mxnet::TShape pad = param.pad;
+  const uint32_t num_group = param.num_group;
   // const bool no_bias = conv_param.no_bias;
-  const dmlc::optional<int> layout = conv_param.layout;
+  const dmlc::optional<int> layout = param.layout;
 
   // dilations
   AttributeProto* const dilations = node_proto->add_attribute();
@@ -226,8 +232,24 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
   for (const dim_t kval : stride) {
     strides->add_ints(static_cast<int64>(kval));
   }
+}
+
+void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+                        const nnvm::IndexedGraph& ig,
+                        const array_view<IndexedGraph::NodeEntry>& inputs) {
+  const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
+  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, conv_param,
+      ConvDeconvType::Convolution);
 }  // end ConvertConvolution
 
+void ConvertDeconvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+                          const nnvm::IndexedGraph& ig,
+                          const array_view<IndexedGraph::NodeEntry>& inputs) {
+  const auto& deconv_param = nnvm::get<op::DeconvolutionParam>(attrs.parsed);
+  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, deconv_param,
+      ConvDeconvType::Deconvolution);
+}  // end ConvertDeconvolution
+
 void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
                     const nnvm::IndexedGraph& /*ig*/,
                     const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index c175ac4..7b0dcc1 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -89,6 +89,7 @@ class TensorrtSelector : public SubgraphSelector {
     "clip",
     "Concat",
     "Convolution",
+    "Deconvolution",
     "Dropout",
     "elemwise_add",
     "elemwise_sub",
@@ -105,6 +106,7 @@ class TensorrtSelector : public SubgraphSelector {
   const std::unordered_set<std::string> withWeightsOps = {
     "BatchNorm",
     "Convolution",
+    "Deconvolution",
     "FullyConnected"
   };
 
diff --git a/tests/python/tensorrt/test_tensorrt_deconv.py b/tests/python/tensorrt/test_tensorrt_deconv.py
new file mode 100644
index 0000000..ef567d1
--- /dev/null
+++ b/tests/python/tensorrt/test_tensorrt_deconv.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.test_utils import assert_almost_equal
+
+def get_params():
+    arg_params = {}
+    aux_params = {}
+    arg_params["trt_bn_test_conv_weight"] = mx.nd.ones((1, 1, 3, 3))
+    arg_params["trt_bn_test_deconv_weight"] = mx.nd.ones((1, 1, 3, 3))
+    return arg_params, aux_params
+
+def get_symbol():
+    data = mx.sym.Variable("data")
+    conv = mx.sym.Convolution(data=data, kernel=(3,3), no_bias=True, num_filter=1, num_group=1,
+                              name="trt_bn_test_conv")
+    deconv = mx.sym.Deconvolution(data=conv, kernel=(3, 3), no_bias=True, num_filter=1,
+                                  num_group=1, name="trt_bn_test_deconv")
+    return deconv
+
+def test_deconvolution_produce_same_output_as_tensorrt():
+    arg_params, aux_params = get_params()
+    arg_params_trt, aux_params_trt = get_params()
+
+    sym = get_symbol()
+    sym_trt = get_symbol().get_backend_symbol("TensorRT")
+
+    mx.contrib.tensorrt.init_tensorrt_params(sym_trt, arg_params_trt, aux_params_trt)
+
+    executor = sym.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null', force_rebind=True)
+    executor.copy_params_from(arg_params, aux_params)
+
+    executor_trt = sym_trt.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null',
+                                  force_rebind=True)
+    executor_trt.copy_params_from(arg_params_trt, aux_params_trt)
+
+    input_data = mx.nd.random.uniform(low=0, high=1, shape=(1, 1, 3, 3))
+
+    y = executor.forward(is_train=False, data=input_data)
+    y_trt = executor_trt.forward(is_train=False, data=input_data)
+
+    print(y[0].asnumpy())
+    print(y_trt[0].asnumpy())
+    assert_almost_equal(y[0].asnumpy(), y_trt[0].asnumpy(), 1e-4, 1e-4)
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()

Reply via email to