This is an automated email from the ASF dual-hosted git repository.
kellen pushed a commit to branch v1.5.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.5.x by this push:
new afe4a71 add deconv in TRT subgraph (#15666) (#16043)
afe4a71 is described below
commit afe4a713210749f8fe300f023ecc2678fbee502d
Author: Kellen Sunderland <[email protected]>
AuthorDate: Sat Aug 31 09:24:05 2019 -0700
add deconv in TRT subgraph (#15666) (#16043)
---
src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h | 19 ++++++-
src/operator/subgraph/tensorrt/nnvm_to_onnx.cc | 46 ++++++++++++-----
src/operator/subgraph/tensorrt/tensorrt-inl.h | 2 +
tests/python/tensorrt/test_tensorrt_deconv.py | 63 +++++++++++++++++++++++
4 files changed, 116 insertions(+), 14 deletions(-)
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
index 55b3d93..5a433f1 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
@@ -41,6 +41,8 @@ namespace mxnet {
namespace op {
namespace nnvm_to_onnx {
+enum ConvDeconvType {Convolution, Deconvolution};
+
using namespace nnvm;
using namespace ::onnx;
using int64 = ::google::protobuf::int64;
@@ -48,8 +50,7 @@ using int64 = ::google::protobuf::int64;
std::unordered_map<std::string, mxnet::TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs,
const nnvm::IndexedGraph& ig);
-std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector&
-dtype_inputs,
+std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector& dtype_inputs,
const nnvm::IndexedGraph& ig);
std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGraph& ig);
@@ -74,12 +75,25 @@ typedef void (*ConverterFunction)(NodeProto *node_proto,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
+template <class ConvDeconvParam>
+void ConvDeconvConvertHelper(NodeProto *node_proto,
+ const NodeAttrs &attrs,
+ const nnvm::IndexedGraph &ig,
+ const array_view<IndexedGraph::NodeEntry> &inputs,
+ const ConvDeconvParam& param,
+ ConvDeconvType type);
+
// Forward declarations
void ConvertConvolution(NodeProto *node_proto,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
+void ConvertDeconvolution(NodeProto *node_proto,
+ const NodeAttrs &attrs,
+ const nnvm::IndexedGraph &ig,
+ const array_view<IndexedGraph::NodeEntry> &inputs);
+
void ConvertPooling(NodeProto *node_proto,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
@@ -158,6 +172,7 @@ static const std::unordered_map<std::string, ConverterFunction> converter_map =
{"BatchNorm", ConvertBatchNorm},
{"clip", ConvertClip},
{"Convolution", ConvertConvolution},
+ {"Deconvolution", ConvertDeconvolution},
{"Concat", ConvertConcatenate},
{"Dropout", ConvertDropout},
{"elemwise_add", ConvertElementwiseAdd},
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
index 6116f29..84580d0 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
@@ -31,6 +31,7 @@
#include <mxnet/base.h>
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>
+#include <operator/nn/deconvolution-inl.h>
#include "../../../common/utils.h"
#include "../../../ndarray/ndarray_function.h"
@@ -170,20 +171,25 @@ std::string ConvertNnvmGraphToOnnx(
return serialized_onnx_graph;
}
-void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
- const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
-
- node_proto->set_op_type("Conv");
+template <class ConvDeconvParam>
+void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& /*ig*/,
+ const array_view<IndexedGraph::NodeEntry>& /*input*/,
+ const ConvDeconvParam& param,
+ ConvDeconvType type) {
+ if (type == ConvDeconvType::Convolution) {
+ node_proto->set_op_type("Conv");
+ } else {
+ node_proto->set_op_type("ConvTranspose");
+ }
- const mxnet::TShape kernel = conv_param.kernel;
- const mxnet::TShape stride = conv_param.stride;
- const mxnet::TShape dilate = conv_param.dilate;
- const mxnet::TShape pad = conv_param.pad;
- const uint32_t num_group = conv_param.num_group;
+ const mxnet::TShape kernel = param.kernel;
+ const mxnet::TShape stride = param.stride;
+ const mxnet::TShape dilate = param.dilate;
+ const mxnet::TShape pad = param.pad;
+ const uint32_t num_group = param.num_group;
// const bool no_bias = conv_param.no_bias;
- const dmlc::optional<int> layout = conv_param.layout;
+ const dmlc::optional<int> layout = param.layout;
// dilations
AttributeProto* const dilations = node_proto->add_attribute();
@@ -226,8 +232,24 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
for (const dim_t kval : stride) {
strides->add_ints(static_cast<int64>(kval));
}
+}
+
+void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
+ ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, conv_param,
+ ConvDeconvType::Convolution);
} // end ConvertConvolution
+void ConvertDeconvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ const auto& deconv_param = nnvm::get<op::DeconvolutionParam>(attrs.parsed);
+ ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, deconv_param,
+ ConvDeconvType::Deconvolution);
+} // end ConvertDeconvolution
+
void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
const nnvm::IndexedGraph& /*ig*/,
const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index c175ac4..7b0dcc1 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -89,6 +89,7 @@ class TensorrtSelector : public SubgraphSelector {
"clip",
"Concat",
"Convolution",
+ "Deconvolution",
"Dropout",
"elemwise_add",
"elemwise_sub",
@@ -105,6 +106,7 @@ class TensorrtSelector : public SubgraphSelector {
const std::unordered_set<std::string> withWeightsOps = {
"BatchNorm",
"Convolution",
+ "Deconvolution",
"FullyConnected"
};
diff --git a/tests/python/tensorrt/test_tensorrt_deconv.py b/tests/python/tensorrt/test_tensorrt_deconv.py
new file mode 100644
index 0000000..ef567d1
--- /dev/null
+++ b/tests/python/tensorrt/test_tensorrt_deconv.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.test_utils import assert_almost_equal
+
+def get_params():
+ arg_params = {}
+ aux_params = {}
+ arg_params["trt_bn_test_conv_weight"] = mx.nd.ones((1, 1, 3, 3))
+ arg_params["trt_bn_test_deconv_weight"] = mx.nd.ones((1, 1, 3, 3))
+ return arg_params, aux_params
+
+def get_symbol():
+ data = mx.sym.Variable("data")
+ conv = mx.sym.Convolution(data=data, kernel=(3,3), no_bias=True, num_filter=1, num_group=1,
+ name="trt_bn_test_conv")
+ deconv = mx.sym.Deconvolution(data=conv, kernel=(3, 3), no_bias=True, num_filter=1,
+ num_group=1, name="trt_bn_test_deconv")
+ return deconv
+
+def test_deconvolution_produce_same_output_as_tensorrt():
+ arg_params, aux_params = get_params()
+ arg_params_trt, aux_params_trt = get_params()
+
+ sym = get_symbol()
+ sym_trt = get_symbol().get_backend_symbol("TensorRT")
+
+ mx.contrib.tensorrt.init_tensorrt_params(sym_trt, arg_params_trt, aux_params_trt)
+
+ executor = sym.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null', force_rebind=True)
+ executor.copy_params_from(arg_params, aux_params)
+
+ executor_trt = sym_trt.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null',
+ force_rebind=True)
+ executor_trt.copy_params_from(arg_params_trt, aux_params_trt)
+
+ input_data = mx.nd.random.uniform(low=0, high=1, shape=(1, 1, 3, 3))
+
+ y = executor.forward(is_train=False, data=input_data)
+ y_trt = executor_trt.forward(is_train=False, data=input_data)
+
+ print(y[0].asnumpy())
+ print(y_trt[0].asnumpy())
+ assert_almost_equal(y[0].asnumpy(), y_trt[0].asnumpy(), 1e-4, 1e-4)
+
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()