This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new bf2d5b1 Add 1:many conversions in nnvm_to_onnx and non-flatten GEMM (#19615)
bf2d5b1 is described below
commit bf2d5b1d1635963e2072aedcfeccbec558e381ee
Author: Serge Panev <[email protected]>
AuthorDate: Mon Jan 4 17:27:36 2021 -0800
Add 1:many conversions in nnvm_to_onnx and non-flatten GEMM (#19615)
Signed-off-by: Serge Panev <[email protected]>
---
src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h | 72 +++--
src/operator/subgraph/tensorrt/nnvm_to_onnx.cc | 378 +++++++++++++++++-----
2 files changed, 348 insertions(+), 102 deletions(-)
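
The central change: converters no longer receive a pre-created NodeProto; they
get the whole GraphProto plus the node name, so one MXNet op can be lowered to
several ONNX nodes (hence "1:many"). A minimal sketch of the new converter
shape (illustrative only; ConvertExample is not part of the patch):

    void ConvertExample(GraphProto* graph_proto, const std::string& node_name,
                        const NodeAttrs& attrs, const nnvm::IndexedGraph& ig,
                        const array_view<IndexedGraph::NodeEntry>& inputs) {
      // Each converter now creates its own NodeProto(s)...
      NodeProto* node_proto = graph_proto->add_node();
      node_proto->set_name(node_name);
      node_proto->set_op_type("Identity");
      // ...and wires the graph edges itself via the new helper.
      DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
    }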
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
index 2d13357..721eb67 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
@@ -47,7 +47,8 @@ using namespace nnvm;
using namespace ::onnx;
using int64 = ::google::protobuf::int64;
-std::unordered_map<std::string, mxnet::TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs,
+std::unordered_map<std::string, mxnet::TShape> GetPlaceholderShapes(
+ const ShapeVector& shape_inputs,
const nnvm::IndexedGraph& ig);
std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector& dtype_inputs,
@@ -70,7 +71,12 @@ void ConvertOutput(GraphProto* graph_proto,
const std::string& node_name, const ShapeVector& shapes,
const DTypeVector& dtypes, const nnvm::IndexedGraph &ig);
-typedef void (*ConverterFunction)(NodeProto *node_proto,
+void DefaultConnectInputsOutputs(NodeProto *node_proto,
+ const array_view<IndexedGraph::NodeEntry>& inputs,
+ const nnvm::IndexedGraph& ig,
+ const std::string& node_name);
+
+typedef void (*ConverterFunction)(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
@@ -84,83 +90,112 @@ void ConvDeconvConvertHelper(NodeProto *node_proto,
ConvDeconvType type);
// Forward declarations
-void ConvertIdentity(NodeProto* node_proto,
+void ConvertIdentity(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph& ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
void ConvertConvolution(
- NodeProto *node_proto,
+ GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertDeconvolution(NodeProto *node_proto,
+void ConvertDeconvolution(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertPooling(NodeProto *node_proto,
+void ConvertPooling(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertRelu(NodeProto *node_proto,
+void ConvertRelu(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertActivation(NodeProto *node_proto,
+void ConvertActivation(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertFullyConnected(NodeProto *node_proto,
+void ConvertFullyConnected(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertFlatten(NodeProto *node_proto,
+
+void ConvertSlice(GraphProto *graph_proto,
+ const std::string& node_name,
+ const NodeAttrs &attrs,
+ const nnvm::IndexedGraph &ig,
+ const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertSoftmaxOutput(GraphProto *graph_proto,
+ const std::string& node_name,
+ const NodeAttrs &attrs,
+ const nnvm::IndexedGraph &ig,
+ const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertFlatten(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertDropout(NodeProto *node_proto,
+void ConvertDropout(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertBatchNorm(NodeProto *node_proto,
+void ConvertBatchNorm(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertElementwiseAdd(NodeProto *node_proto,
+void ConvertElementwiseAdd(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertElementwiseMul(NodeProto *node_proto,
+void ConvertElementwiseMul(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertElementwiseSub(NodeProto *node_proto,
+void ConvertElementwiseSub(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertConcatenate(NodeProto *node_proto,
+void ConvertConcatenate(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertClip(NodeProto *node_proto,
+void ConvertClip(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
-void ConvertPad(NodeProto* node_proto,
+void ConvertPad(GraphProto *graph_proto,
+ const std::string& node_name,
const NodeAttrs & attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);
@@ -185,6 +220,7 @@ static const std::unordered_map<std::string, ConverterFunction> converter_map =
{"Pad", ConvertPad},
{"Pooling", ConvertPooling},
{"relu", ConvertRelu},
+ {"slice", ConvertSlice}
};
typedef void (*PreprocessFunction)(const NodeAttrs &attrs,
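
Converters are looked up in converter_map by MXNet op name and invoked with the
new signature. A sketch of the dispatch (the real call site appears in
nnvm_to_onnx.cc below; op, node_name, attrs, ig and node are assumed in scope):

    auto it = converter_map.find(op->name);
    CHECK(it != converter_map.end()) << "No ONNX converter for " << op->name;
    it->second(graph_proto, node_name, attrs, ig, node.inputs);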
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
index 3c03126..cd83cb7 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
@@ -129,8 +129,6 @@ std::string ConvertNnvmGraphToOnnx(
} // is_placeholder
} else {
// It's an op, rather than a "variable" (constant or placeholder)
- NodeProto* node_proto = graph_proto->add_node();
- node_proto->set_name(node_name);
if (converter_map.count(op->name) == 0) {
LOG(FATAL) << "Conversion for node of type " << op->name << " (node "
<< node_name << ") "
@@ -139,19 +137,7 @@ std::string ConvertNnvmGraphToOnnx(
// Find function ptr to a converter based on the op name, and invoke the converter. This
// looks unsafe because find may not succeed, but it does because we're in the operator
// logic after testing that this node name does not represent a variable.
- converter_map.find(op->name)->second(node_proto, attrs, ig, node.inputs);
- // Add all inputs to the current node (i.e. add graph edges)
- for (const nnvm::IndexedGraph::NodeEntry& entry : node.inputs) {
- std::string in_node_name = ig[entry.node_id].source->attrs.name;
- // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less
- // hacky way to do it than name matching.
- if (in_node_name.find("label") != std::string::npos) {
- continue;
- }
- node_proto->add_input(in_node_name);
- }
- // The node's output will have the same name as the node name.
- node_proto->add_output(node_name);
+ converter_map.find(op->name)->second(graph_proto, node_name, attrs, ig, node.inputs);
// See if the current node is an output node
auto out_iter = output_lookup.find(node_name);
// We found an output
@@ -170,16 +156,113 @@ std::string ConvertNnvmGraphToOnnx(
return serialized_onnx_graph;
}
-void ConvertIdentity(NodeProto* node_proto, const NodeAttrs& attrs,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void DefaultConnectInputsOutputs(NodeProto *node_proto,
+ const array_view<IndexedGraph::NodeEntry>& inputs,
+ const nnvm::IndexedGraph& ig,
+ const std::string& node_name) {
+ for (const nnvm::IndexedGraph::NodeEntry& entry : inputs) {
+ std::string in_node_name = ig[entry.node_id].source->attrs.name;
+ // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less
+ // hacky way to do it than name matching.
+ if (in_node_name.find("label") != std::string::npos) {
+ continue;
+ }
+ node_proto->add_input(in_node_name);
+ }
+ // The node's output will have the same name as the node name.
+ node_proto->add_output(node_name);
+}
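
(Inputs whose producer name contains "label" are skipped because training-only
label inputs, e.g. SoftmaxOutput's label, have no counterpart in the
inference-time ONNX graph; the name matching is the same hack the old inline
code used.)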
+
+TensorProto* const Make1DTensor(GraphProto* const graph_proto, const int64_t& size,
+ const std::string& name, const TensorProto_DataType& dtype) {
+ TensorProto* const initializer_proto = graph_proto->add_initializer();
+ initializer_proto->set_name(name);
+ initializer_proto->set_data_type(dtype);
+ initializer_proto->add_dims(static_cast<int64>(size));
+
+ ValueInfoProto* const input_proto = graph_proto->add_input();
+ input_proto->set_name(name);
+ auto var = input_proto->mutable_type()->mutable_tensor_type();
+ var->set_elem_type(dtype);
+ var->mutable_shape()->add_dim()->set_dim_value(static_cast<int64>(size));
+ return initializer_proto;
+}
+
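
Make1DTensor registers both an initializer and a matching graph input,
presumably because the ONNX IR version in use still requires initializers to
also appear as graph inputs. A usage sketch, mirroring the commented-out
ConvertSlice below (values are placeholders):

    // Build a 1-D INT64 tensor "<node>_starts" with one entry per sliced axis.
    TensorProto* starts = Make1DTensor(graph_proto, nb_slices,
                                       node_name + "_starts",
                                       TensorProto_DataType_INT64);
    for (int64_t i = 0; i < nb_slices; ++i)
      starts->add_int64_data(0);  // placeholder start indices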
+// Kept for when the ONNX version is updated
+/*
+void ConvertSlice(GraphProto* const graph_proto, const Node* node, const Graph& g) {
+ const auto& params = nnvm::get<SliceParam>(node->attrs.parsed);
+ int64 nb_slices = static_cast<int64>(params.begin.ndim());
+
+ // starts
+ auto init_starts = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_starts",
+ TensorProto_DataType_INT64);
+ for (auto& opt : params.begin) {
+ if (opt.has_value()) {
+ init_starts->add_int64_data(static_cast<int64>(opt.value()));
+ } else {
+ init_starts->add_int64_data(static_cast<int64>(0));
+ }
+ }
+
+ // ends
+ auto init_ends = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_ends",
+ TensorProto_DataType_INT64);
+ for (auto& opt : params.end) {
+ if (opt.has_value()) {
+ init_ends->add_int64_data(static_cast<int64>(opt.value()));
+ } else {
+ init_ends->add_int64_data(static_cast<int64>(INT_MAX));
+ }
+ }
+
+ // axes
+ auto init_axes = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_axes",
+ TensorProto_DataType_INT64);
+ for (int64_t i = 0; i < nb_slices; ++i) {
+ init_axes->add_int64_data(static_cast<int64>(i));
+ }
+
+ // slice node
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node->attrs.name);
+ node_proto->set_op_type("Slice");
+ node_proto->add_input(node->inputs[0].node->attrs.name);
+ node_proto->add_input(node->attrs.name + "_starts");
+ node_proto->add_input(node->attrs.name + "_ends");
+ node_proto->add_input(node->attrs.name + "_axes");
+
+ // steps
+ if (params.step.ndim() != 0) {
+ auto init_steps = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_steps",
+ TensorProto_DataType_INT64);
+ for (auto& opt : params.step) {
+ if (opt.has_value()) {
+ init_steps->add_int64_data(static_cast<int64>(opt.value()));
+ } else {
+ init_steps->add_int64_data(static_cast<int64>(1));
+ }
+ }
+ node_proto->add_input(node->attrs.name + "_steps");
+ }
+
+ node_proto->add_output(node->attrs.name);
+}
+*/
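
The commented-out version above targets a newer ONNX opset, where Slice takes
starts/ends/axes/steps as tensor inputs (opset 10 and later); the active
ConvertSlice further down still passes them as attributes, which is what the
currently pinned ONNX version expects.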
+
+void ConvertIdentity(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
node_proto->set_op_type("Identity");
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
template <class ConvDeconvParam>
-void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*input*/,
+void ConvDeconvConvertHelper(NodeProto *node_proto, const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs,
const ConvDeconvParam& param,
ConvDeconvType type) {
if (type == ConvDeconvType::Convolution) {
@@ -239,25 +322,36 @@ void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
}
}
-void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+void ConvertConvolution(GraphProto *graph_proto, const std::string& node_name,
+ const NodeAttrs& attrs,
const nnvm::IndexedGraph& ig,
const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, conv_param,
ConvDeconvType::Convolution);
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
} // end ConvertConvolution
-void ConvertDeconvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+void ConvertDeconvolution(GraphProto *graph_proto, const std::string& node_name,
+ const NodeAttrs& attrs,
const nnvm::IndexedGraph& ig,
const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
const auto& deconv_param = nnvm::get<op::DeconvolutionParam>(attrs.parsed);
ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, deconv_param,
ConvDeconvType::Deconvolution);
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
} // end ConvertDeconvolution
-void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertPooling(GraphProto *graph_proto, const std::string& node_name,
+ const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
const auto& pooling_param = nnvm::get<op::PoolingParam>(attrs.parsed);
const mxnet::TShape kernel = pooling_param.kernel;
@@ -274,6 +368,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
} else {
LOG(FATAL) << "Pool type of node '" << attrs.name << "' unsupported: "
<< attrs.name;
}
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
return;
}
@@ -328,17 +423,24 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
} else {
count_include_pad->set_i(1);
}
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
} // end ConvertPooling
-void ConvertRelu(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertRelu(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& /*attrs*/,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
node_proto->set_op_type("Relu");
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
-void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertActivation(GraphProto *graph_proto, const std::string& node_name,
+ const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
const auto& act_param = nnvm::get<op::ActivationParam>(attrs.parsed);
std::string act_type;
switch (act_param.act_type) {
@@ -360,42 +462,120 @@ void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs,
}
node_proto->set_op_type(act_type);
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
-void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertFullyConnected(GraphProto *graph_proto, const std::string& node_name,
+ const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
const auto& act_param = nnvm::get<op::FullyConnectedParam>(attrs.parsed);
- if (act_param.no_bias) {
- node_proto->set_op_type("MatMul");
+ // The ONNX spec doesn't support GEMMs with inputs of different dims, so we need to replace it
+ // with Transpose+MatMul+Add
+ if (!act_param.flatten && !act_param.no_bias) {
+ NodeProto* transpose_node_proto = graph_proto->add_node();
+ NodeProto* matmul_node_proto = graph_proto->add_node();
+ NodeProto* add_node_proto = graph_proto->add_node();
+ transpose_node_proto->set_name(node_name+"_Transpose");
+ matmul_node_proto->set_name(node_name+"_MatMul");
+ add_node_proto->set_name(node_name+"_Add");
+
+ transpose_node_proto->set_op_type("Transpose");
+ matmul_node_proto->set_op_type("MatMul");
+ add_node_proto->set_op_type("Add");
+
+ std::string input_node_name = ig[inputs[op::conv::kData].node_id].source->attrs.name;
+ std::string weight_node_name = ig[inputs[op::conv::kWeight].node_id].source->attrs.name;
+ std::string bias_node_name = ig[inputs[op::conv::kBias].node_id].source->attrs.name;
+
+ transpose_node_proto->add_input(weight_node_name);
+ transpose_node_proto->add_output(node_name+"_Transpose");
+
+ matmul_node_proto->add_input(input_node_name);
+ matmul_node_proto->add_input(node_name+"_Transpose");
+ matmul_node_proto->add_output(node_name+"_MatMul");
+
+ add_node_proto->add_input(node_name+"_MatMul");
+ add_node_proto->add_input(bias_node_name);
+ // Add's output is the output of the Transpose+MatMul+Add subgraph
+ add_node_proto->add_output(node_name);
} else {
- node_proto->set_op_type("Gemm");
-
- AttributeProto* const alpha = node_proto->add_attribute();
- alpha->set_name("alpha");
- alpha->set_type(AttributeProto::FLOAT);
- alpha->set_f(1.0f);
-
- AttributeProto* const beta = node_proto->add_attribute();
- beta->set_name("beta");
- beta->set_type(AttributeProto::FLOAT);
- beta->set_f(1.0f);
-
- AttributeProto* const transA = node_proto->add_attribute();
- transA->set_name("transA");
- transA->set_type(AttributeProto::INT);
- transA->set_i(0);
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
+ if (act_param.no_bias) {
+ node_proto->set_op_type("MatMul");
+ } else {
+ node_proto->set_op_type("Gemm");
+
+ AttributeProto* const alpha = node_proto->add_attribute();
+ alpha->set_name("alpha");
+ alpha->set_type(AttributeProto::FLOAT);
+ alpha->set_f(1.0f);
+
+ AttributeProto* const beta = node_proto->add_attribute();
+ beta->set_name("beta");
+ beta->set_type(AttributeProto::FLOAT);
+ beta->set_f(1.0f);
+
+ AttributeProto* const transA = node_proto->add_attribute();
+ transA->set_name("transA");
+ transA->set_type(AttributeProto::INT);
+ transA->set_i(0);
+
+ AttributeProto* const transB = node_proto->add_attribute();
+ transB->set_name("transB");
+ transB->set_type(AttributeProto::INT);
+ transB->set_i(1);
+ }
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
+ }
+}
- AttributeProto* const transB = node_proto->add_attribute();
- transB->set_name("transB");
- transB->set_type(AttributeProto::INT);
- transB->set_i(1);
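
A shape walk for the new non-flatten, with-bias path (illustrative shapes, not
from the patch): with data X of shape (B, T, C_in), weight W of shape
(C_out, C_in) and bias b of shape (C_out):

    Transpose(W)    -> (C_in, C_out)
    MatMul(X, W^T)  -> (B, T, C_out)
    Add(., b)       -> (B, T, C_out)   // b broadcasts over the leading dims

ONNX Gemm only accepts 2-D A and B, so the flatten=False case with bias cannot
stay a single Gemm node.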
+void ConvertSlice(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
+ const auto& params = nnvm::get<SliceParam>(attrs.parsed);
+ node_proto->set_op_type("Slice");
+
+ // starts
+ AttributeProto* const starts = node_proto->add_attribute();
+ starts->set_name("starts");
+ starts->set_type(AttributeProto::INTS);
+
+ // ends
+ AttributeProto* const ends = node_proto->add_attribute();
+ ends->set_name("ends");
+ ends->set_type(AttributeProto::INTS);
+
+ // axes
+ AttributeProto* const axes = node_proto->add_attribute();
+ axes->set_name("axes");
+ axes->set_type(AttributeProto::INTS);
+
+ for (int64_t i = 1; i < params.begin.ndim(); ++i) {
+ if (params.begin[i].has_value()) {
+ starts->add_ints(static_cast<int64>(params.begin[i].value()));
+ } else {
+ starts->add_ints(static_cast<int64>(0));
+ }
+ if (params.end[i].has_value()) {
+ ends->add_ints(static_cast<int64>(params.end[i].value()));
+ } else {
+ ends->add_ints(static_cast<int64>(INT_MAX));
+ }
+ axes->add_ints(static_cast<int64>(i));
}
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
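
Note the loop starts at axis 1, leaving the batch axis alone, presumably
because the TensorRT subgraph runs with an implicit batch dimension. For
example (hypothetical parameters), begin=(None, 0, 2) and end=(None, None, 5)
yield starts=[0, 2], ends=[2147483647, 5], axes=[1, 2].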
-void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertFlatten(GraphProto *graph_proto, const std::string& node_name,
+ const NodeAttrs& /*attrs*/,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
node_proto->set_op_type("Flatten");
// Setting by default to 1 since MXNet doesn't provide such an attribute for Flatten in its
@@ -405,11 +585,15 @@ void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
axis->set_name("axis");
axis->set_type(AttributeProto::INT);
axis->set_i(1);
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
-void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertBatchNorm(GraphProto *graph_proto, const std::string& node_name,
+ const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
node_proto->set_op_type("BatchNormalization");
const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);
@@ -430,29 +614,45 @@ void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
// (default in ONNX3) implies running batchnorm on all spatial features so we need to explicitly
// disable this for MXNet's BatchNorm.
spatial->set_i(0);
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
-void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertElementwiseAdd(GraphProto *graph_proto, const std::string& node_name,
+ const NodeAttrs& /*attrs*/,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
node_proto->set_op_type("Add");
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
-void ConvertElementwiseSub(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertElementwiseSub(GraphProto *graph_proto, const std::string& node_name,
+ const NodeAttrs& /*attrs*/,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
node_proto->set_op_type("Sub");
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
-void ConvertElementwiseMul(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertElementwiseMul(GraphProto *graph_proto, const std::string& node_name,
+ const NodeAttrs& /*attrs*/,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
node_proto->set_op_type("Mul");
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
-void ConvertConcatenate(NodeProto* node_proto, const NodeAttrs& attrs,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertConcatenate(GraphProto *graph_proto, const std::string& node_name,
+ const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
const auto& _param = nnvm::get<ConcatParam>(attrs.parsed);
node_proto->set_op_type("Concat");
node_proto->set_name(attrs.name);
@@ -461,6 +661,7 @@ void ConvertConcatenate(NodeProto* node_proto, const NodeAttrs& attrs,
axis->set_name("axis");
axis->set_type(AttributeProto::INT);
axis->set_i(static_cast<int64_t>(_param.dim));
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
inline TensorProto_DataType ConvertDType(int dtype) {
@@ -615,9 +816,11 @@ void ConvertOutput(
}
}
-void ConvertClip(NodeProto* node_proto, const NodeAttrs& attrs,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertClip(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
const auto& param = nnvm::get<ClipParam>(attrs.parsed);
node_proto->set_op_type("Clip");
@@ -633,11 +836,14 @@ void ConvertClip(NodeProto* node_proto, const NodeAttrs& attrs,
a_min->set_name("min");
a_min->set_type(AttributeProto::FLOAT);
a_min->set_f(static_cast<float>(param.a_min));
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
-void ConvertPad(NodeProto* node_proto, const NodeAttrs& attrs,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertPad(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
const auto& param = nnvm::get<PadParam>(attrs.parsed);
node_proto->set_op_type("Pad");
@@ -679,12 +885,16 @@ void ConvertPad(NodeProto* node_proto, const NodeAttrs& attrs,
value->set_name("value");
value->set_type(AttributeProto::FLOAT);
value->set_f(param.constant_value);
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
-void ConvertDropout(NodeProto* node_proto, const NodeAttrs& attrs,
- const nnvm::IndexedGraph& /*ig*/,
- const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertDropout(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+ const nnvm::IndexedGraph& ig,
+ const array_view<IndexedGraph::NodeEntry>& inputs) {
+ NodeProto* node_proto = graph_proto->add_node();
+ node_proto->set_name(node_name);
node_proto->set_op_type("Dropout");
+ DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}
void PreprocessBatchNorm(const NodeAttrs &attrs,