This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new bf2d5b1  Add 1:many conversions in nnvm_to_onnx and non-flatten GEMM (#19615)
bf2d5b1 is described below

commit bf2d5b1d1635963e2072aedcfeccbec558e381ee
Author: Serge Panev <[email protected]>
AuthorDate: Mon Jan 4 17:27:36 2021 -0800

    Add 1:many conversions in nnvm_to_onnx and non-flatten GEMM (#19615)
    
    Signed-off-by: Serge Panev <[email protected]>
---
 src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h |  72 +++--
 src/operator/subgraph/tensorrt/nnvm_to_onnx.cc    | 378 +++++++++++++++++-----
 2 files changed, 348 insertions(+), 102 deletions(-)

diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
index 2d13357..721eb67 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
@@ -47,7 +47,8 @@ using namespace nnvm;
 using namespace ::onnx;
 using int64 = ::google::protobuf::int64;
 
-std::unordered_map<std::string, mxnet::TShape> GetPlaceholderShapes(const 
ShapeVector& shape_inputs,
+std::unordered_map<std::string, mxnet::TShape> GetPlaceholderShapes(
+    const ShapeVector& shape_inputs,
     const nnvm::IndexedGraph& ig);
 
 std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector& 
dtype_inputs,
@@ -70,7 +71,12 @@ void ConvertOutput(GraphProto* graph_proto,
                    const std::string& node_name, const ShapeVector& shapes,
                    const DTypeVector& dtypes, const nnvm::IndexedGraph &ig);
 
-typedef void (*ConverterFunction)(NodeProto *node_proto,
+void DefaultConnectInputsOutputs(const array_view<IndexedGraph::NodeEntry>& 
inputs,
+                                 const nnvm::IndexedGraph& ig,
+                                 const std::string& node_name);
+
+typedef void (*ConverterFunction)(GraphProto *graph_proto,
+                                  const std::string& node_name,
                                   const NodeAttrs &attrs,
                                   const nnvm::IndexedGraph &ig,
                                   const array_view<IndexedGraph::NodeEntry> 
&inputs);
@@ -84,83 +90,112 @@ void ConvDeconvConvertHelper(NodeProto *node_proto,
                              ConvDeconvType type);
 
 // Forward declarations
-void ConvertIdentity(NodeProto* node_proto,
+void ConvertIdentity(GraphProto *graph_proto,
+                     const std::string& node_name,
                      const NodeAttrs &attrs,
                      const nnvm::IndexedGraph& ig,
                      const array_view<IndexedGraph::NodeEntry> &inputs);
 
 void ConvertConvolution(
-                        NodeProto *node_proto,
+                        GraphProto *graph_proto,
+                        const std::string& node_name,
                         const NodeAttrs &attrs,
                         const nnvm::IndexedGraph &ig,
                         const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertDeconvolution(NodeProto *node_proto,
+void ConvertDeconvolution(GraphProto *graph_proto,
+                        const std::string& node_name,
                         const NodeAttrs &attrs,
                         const nnvm::IndexedGraph &ig,
                         const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertPooling(NodeProto *node_proto,
+void ConvertPooling(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertRelu(NodeProto *node_proto,
+void ConvertRelu(GraphProto *graph_proto,
+                 const std::string& node_name,
                  const NodeAttrs &attrs,
                  const nnvm::IndexedGraph &ig,
                  const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertActivation(NodeProto *node_proto,
+void ConvertActivation(GraphProto *graph_proto,
+                       const std::string& node_name,
                        const NodeAttrs &attrs,
                        const nnvm::IndexedGraph &ig,
                        const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertFullyConnected(NodeProto *node_proto,
+void ConvertFullyConnected(GraphProto *graph_proto,
+                           const std::string& node_name,
                            const NodeAttrs &attrs,
                            const nnvm::IndexedGraph &ig,
                            const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertFlatten(NodeProto *node_proto,
+
+void ConvertSlice(GraphProto *graph_proto,
+                  const std::string& node_name,
+                  const NodeAttrs &attrs,
+                  const nnvm::IndexedGraph &ig,
+                  const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertSoftmaxOutput(GraphProto *graph_proto,
+                          const std::string& node_name,
+                          const NodeAttrs &attrs,
+                          const nnvm::IndexedGraph &ig,
+                          const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertFlatten(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertDropout(NodeProto *node_proto,
+void ConvertDropout(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertBatchNorm(NodeProto *node_proto,
+void ConvertBatchNorm(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertElementwiseAdd(NodeProto *node_proto,
+void ConvertElementwiseAdd(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertElementwiseMul(NodeProto *node_proto,
+void ConvertElementwiseMul(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertElementwiseSub(NodeProto *node_proto,
+void ConvertElementwiseSub(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertConcatenate(NodeProto *node_proto,
+void ConvertConcatenate(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertClip(NodeProto *node_proto,
+void ConvertClip(GraphProto *graph_proto,
+                 const std::string& node_name,
                  const NodeAttrs &attrs,
                  const nnvm::IndexedGraph &ig,
                  const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertPad(NodeProto* node_proto,
+void ConvertPad(GraphProto *graph_proto,
+                const std::string& node_name,
                 const NodeAttrs & attrs,
                 const nnvm::IndexedGraph &ig,
                 const array_view<IndexedGraph::NodeEntry> &inputs);
@@ -185,6 +220,7 @@ static const std::unordered_map<std::string, ConverterFunction> converter_map =
   {"Pad", ConvertPad},
   {"Pooling", ConvertPooling},
   {"relu", ConvertRelu},
+  {"slice", ConvertSlice}
 };
 
 typedef void (*PreprocessFunction)(const NodeAttrs &attrs,
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
index 3c03126..cd83cb7 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
@@ -129,8 +129,6 @@ std::string ConvertNnvmGraphToOnnx(
       }  // is_placeholder
     } else {
       // It's an op, rather than a "variable" (constant or placeholder)
-      NodeProto* node_proto = graph_proto->add_node();
-      node_proto->set_name(node_name);
       if (converter_map.count(op->name) == 0) {
         LOG(FATAL) << "Conversion for node of type " << op->name << " (node "
                    << node_name << ") "
@@ -139,19 +137,7 @@ std::string ConvertNnvmGraphToOnnx(
       // Find function ptr to a converter based on the op name, and invoke the 
converter. This
       // looks unsafe because find may not succeed, but it does because we're 
in the operator
       // logic after testing that this node name does not represent a variable.
-      converter_map.find(op->name)->second(node_proto, attrs, ig, node.inputs);
-      // Add all inputs to the current node (i.e. add graph edges)
-      for (const nnvm::IndexedGraph::NodeEntry& entry : node.inputs) {
-        std::string in_node_name = ig[entry.node_id].source->attrs.name;
-        // As before, we're not adding labels e.g. for SoftmaxOutput, but I 
wish there was a less
-        // hacky way to do it than name matching.
-        if (in_node_name.find("label") != std::string::npos) {
-          continue;
-        }
-        node_proto->add_input(in_node_name);
-      }
-      // The node's output will have the same name as the node name.
-      node_proto->add_output(node_name);
+      converter_map.find(op->name)->second(graph_proto, node_name, attrs, ig, 
node.inputs);
       // See if the current node is an output node
       auto out_iter = output_lookup.find(node_name);
       // We found an output
@@ -170,16 +156,113 @@ std::string ConvertNnvmGraphToOnnx(
   return serialized_onnx_graph;
 }
 
-void ConvertIdentity(NodeProto* node_proto, const NodeAttrs& attrs,
-                     const nnvm::IndexedGraph& /*ig*/,
-                     const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void DefaultConnectInputsOutputs(NodeProto *node_proto,
+                                 const array_view<IndexedGraph::NodeEntry>& 
inputs,
+                                 const nnvm::IndexedGraph& ig,
+                                 const std::string& node_name) {
+  for (const nnvm::IndexedGraph::NodeEntry& entry : inputs) {
+    std::string in_node_name = ig[entry.node_id].source->attrs.name;
+    // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish 
there was a less
+    // hacky way to do it than name matching.
+    if (in_node_name.find("label") != std::string::npos) {
+      continue;
+    }
+    node_proto->add_input(in_node_name);
+  }
+  // The node's output will have the same name as the node name.
+  node_proto->add_output(node_name);
+}
+
+TensorProto* const Make1DTensor(GraphProto* const graph_proto, const int64_t& 
size,
+                                const std::string& name, const 
TensorProto_DataType& dtype) {
+  TensorProto* const initializer_proto = graph_proto->add_initializer();
+  initializer_proto->set_name(name);
+  initializer_proto->set_data_type(dtype);
+  initializer_proto->add_dims(static_cast<int64>(size));
+
+  ValueInfoProto* const input_proto = graph_proto->add_input();
+  input_proto->set_name(name);
+  auto var = input_proto->mutable_type()->mutable_tensor_type();
+  var->set_elem_type(dtype);
+  var->mutable_shape()->add_dim()->set_dim_value(static_cast<int64>(size));
+  return initializer_proto;
+}
+
+// Keep for when ONNX version will be updated
+/*
+void ConvertSlice(GraphProto* const graph_proto, const Node* node, const 
Graph& g) {
+  const auto& params = nnvm::get<SliceParam>(node->attrs.parsed);
+  int64 nb_slices = static_cast<int64>(params.begin.ndim());
+
+  // starts
+  auto init_starts = Make1DTensor(graph_proto, nb_slices, node->attrs.name + 
"_starts",
+                                  TensorProto_DataType_INT64);
+  for (auto& opt : params.begin) {
+    if (opt.has_value()) {
+      init_starts->add_int64_data(static_cast<int64>(opt.value()));
+    } else {
+      init_starts->add_int64_data(static_cast<int64>(0));
+    }
+  }
+
+  // ends
+  auto init_ends = Make1DTensor(graph_proto, nb_slices, node->attrs.name + 
"_ends",
+                                TensorProto_DataType_INT64);
+  for (auto& opt : params.end) {
+    if (opt.has_value()) {
+      init_ends->add_int64_data(static_cast<int64>(opt.value()));
+    } else {
+      init_ends->add_int64_data(static_cast<int64>(INT_MAX));
+    }
+  }
+
+  // axes
+  auto init_axes = Make1DTensor(graph_proto, nb_slices, node->attrs.name + 
"_axes",
+                                TensorProto_DataType_INT64);
+  for (int64_t i = 0; i < nb_slices; ++i) {
+    init_axes->add_int64_data(static_cast<int64>(i));
+  }
+
+  // slice node
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node->attrs.name);
+  node_proto->set_op_type("Slice");
+  node_proto->add_input(node->inputs[0].node->attrs.name);
+  node_proto->add_input(node->attrs.name + "_starts");
+  node_proto->add_input(node->attrs.name + "_ends");
+  node_proto->add_input(node->attrs.name + "_axes");
+
+  // steps
+  if (params.step.ndim() != 0) {
+    auto init_steps = Make1DTensor(graph_proto, nb_slices, node->attrs.name + 
"_steps",
+                                   TensorProto_DataType_INT64);
+    for (auto& opt : params.step) {
+      if (opt.has_value()) {
+        init_steps->add_int64_data(static_cast<int64>(opt.value()));
+      } else {
+        init_steps->add_int64_data(static_cast<int64>(1));
+      }
+    }
+    node_proto->add_input(node->attrs.name + "_steps");
+  }
+
+  node_proto->add_output(node->attrs.name);
+}
+*/
+
+void ConvertIdentity(GraphProto *graph_proto, const std::string& node_name, 
const NodeAttrs& attrs,
+                     const nnvm::IndexedGraph& ig,
+                     const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Identity");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
 template <class ConvDeconvParam>
-void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
-                             const nnvm::IndexedGraph& /*ig*/,
-                             const array_view<IndexedGraph::NodeEntry>& 
/*input*/,
+void ConvDeconvConvertHelper(NodeProto *node_proto, const NodeAttrs& attrs,
+                             const nnvm::IndexedGraph& ig,
+                             const array_view<IndexedGraph::NodeEntry>& inputs,
                              const ConvDeconvParam& param,
                              ConvDeconvType type) {
   if (type == ConvDeconvType::Convolution) {
@@ -239,25 +322,36 @@ void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
   }
 }
 
-void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+void ConvertConvolution(GraphProto *graph_proto, const std::string& node_name,
+                        const NodeAttrs& attrs,
                         const nnvm::IndexedGraph& ig,
                         const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
   ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, conv_param,
       ConvDeconvType::Convolution);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }  // end ConvertConvolution
 
-void ConvertDeconvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+void ConvertDeconvolution(GraphProto *graph_proto, const std::string& 
node_name,
+                          const NodeAttrs& attrs,
                           const nnvm::IndexedGraph& ig,
                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& deconv_param = nnvm::get<op::DeconvolutionParam>(attrs.parsed);
   ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, deconv_param,
       ConvDeconvType::Deconvolution);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }  // end ConvertDeconvolution
 
-void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
-                    const nnvm::IndexedGraph& /*ig*/,
-                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertPooling(GraphProto *graph_proto, const std::string& node_name,
+                    const NodeAttrs& attrs,
+                    const nnvm::IndexedGraph& ig,
+                    const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& pooling_param = nnvm::get<op::PoolingParam>(attrs.parsed);
 
   const mxnet::TShape kernel = pooling_param.kernel;
@@ -274,6 +368,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
     } else {
       LOG(FATAL) << "Pool type of node '" << attrs.name << "' unsupported: " 
<< attrs.name;
     }
+    DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
     return;
   }
 
@@ -328,17 +423,24 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
   } else {
     count_include_pad->set_i(1);
   }
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }  // end ConvertPooling
 
-void ConvertRelu(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                 const nnvm::IndexedGraph& /*ig*/,
-                 const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertRelu(GraphProto *graph_proto, const std::string& node_name, const 
NodeAttrs& /*attrs*/,
+                 const nnvm::IndexedGraph& ig,
+                 const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Relu");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs,
-                       const nnvm::IndexedGraph& /*ig*/,
-                       const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertActivation(GraphProto *graph_proto, const std::string& node_name,
+                       const NodeAttrs& attrs,
+                       const nnvm::IndexedGraph& ig,
+                       const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& act_param = nnvm::get<op::ActivationParam>(attrs.parsed);
   std::string act_type;
   switch (act_param.act_type) {
@@ -360,42 +462,120 @@ void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs,
   }
 
   node_proto->set_op_type(act_type);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& 
/*inputs*/) {
+void ConvertFullyConnected(GraphProto *graph_proto, const std::string& 
node_name,
+                           const NodeAttrs& attrs,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
   const auto& act_param = nnvm::get<op::FullyConnectedParam>(attrs.parsed);
-  if (act_param.no_bias) {
-      node_proto->set_op_type("MatMul");
+  // ONNX spec doesn't support GEMMs with input of different dims, so we need 
to replace it
+  // by Transpose+MatMul+Add
+  if (!act_param.flatten && !act_param.no_bias) {
+    NodeProto* tranpose_node_proto = graph_proto->add_node();
+    NodeProto* matmul_node_proto = graph_proto->add_node();
+    NodeProto* add_node_proto = graph_proto->add_node();
+    tranpose_node_proto->set_name(node_name+"_Transpose");
+    matmul_node_proto->set_name(node_name+"_MatMul");
+    add_node_proto->set_name(node_name+"_Add");
+
+    tranpose_node_proto->set_op_type("Transpose");
+    matmul_node_proto->set_op_type("MatMul");
+    add_node_proto->set_op_type("Add");
+
+    std::string input_node_name = 
ig[inputs[op::conv::kData].node_id].source->attrs.name;
+    std::string weight_node_name = 
ig[inputs[op::conv::kWeight].node_id].source->attrs.name;
+    std::string bias_node_name = 
ig[inputs[op::conv::kBias].node_id].source->attrs.name;
+
+    tranpose_node_proto->add_input(weight_node_name);
+    tranpose_node_proto->add_output(node_name+"_Transpose");
+
+    matmul_node_proto->add_input(input_node_name);
+    matmul_node_proto->add_input(node_name+"_Transpose");
+    matmul_node_proto->add_output(node_name+"_MatMul");
+
+    add_node_proto->add_input(node_name+"_MatMul");
+    add_node_proto->add_input(bias_node_name);
+    // Add's output is the output of the Transpose+MatMul+Add subgraph
+    add_node_proto->add_output(node_name);
   } else {
-      node_proto->set_op_type("Gemm");
-
-      AttributeProto* const alpha = node_proto->add_attribute();
-      alpha->set_name("alpha");
-      alpha->set_type(AttributeProto::FLOAT);
-      alpha->set_f(1.0f);
-
-      AttributeProto* const beta = node_proto->add_attribute();
-      beta->set_name("beta");
-      beta->set_type(AttributeProto::FLOAT);
-      beta->set_f(1.0f);
-
-      AttributeProto* const transA = node_proto->add_attribute();
-      transA->set_name("transA");
-      transA->set_type(AttributeProto::INT);
-      transA->set_i(0);
+    NodeProto* node_proto = graph_proto->add_node();
+    node_proto->set_name(node_name);
+    if (act_param.no_bias) {
+        node_proto->set_op_type("MatMul");
+    } else {
+        node_proto->set_op_type("Gemm");
+
+        AttributeProto* const alpha = node_proto->add_attribute();
+        alpha->set_name("alpha");
+        alpha->set_type(AttributeProto::FLOAT);
+        alpha->set_f(1.0f);
+
+        AttributeProto* const beta = node_proto->add_attribute();
+        beta->set_name("beta");
+        beta->set_type(AttributeProto::FLOAT);
+        beta->set_f(1.0f);
+
+        AttributeProto* const transA = node_proto->add_attribute();
+        transA->set_name("transA");
+        transA->set_type(AttributeProto::INT);
+        transA->set_i(0);
+
+        AttributeProto* const transB = node_proto->add_attribute();
+        transB->set_name("transB");
+        transB->set_type(AttributeProto::INT);
+        transB->set_i(1);
+    }
+    DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
+  }
+}
 
-      AttributeProto* const transB = node_proto->add_attribute();
-      transB->set_name("transB");
-      transB->set_type(AttributeProto::INT);
-      transB->set_i(1);
+void ConvertSlice(GraphProto *graph_proto, const std::string& node_name, const 
NodeAttrs& attrs,
+                  const nnvm::IndexedGraph& ig,
+                  const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
+  const auto& params = nnvm::get<SliceParam>(attrs.parsed);
+  node_proto->set_op_type("Slice");
+
+  // starts
+  AttributeProto* const starts = node_proto->add_attribute();
+  starts->set_name("starts");
+  starts->set_type(AttributeProto::INTS);
+
+  // ends
+  AttributeProto* const ends = node_proto->add_attribute();
+  ends->set_name("ends");
+  ends->set_type(AttributeProto::INTS);
+
+  // axes
+  AttributeProto* const axes = node_proto->add_attribute();
+  axes->set_name("axes");
+  axes->set_type(AttributeProto::INTS);
+
+  for (int64_t i = 1; i < params.begin.ndim(); ++i) {
+    if (params.begin[i].has_value()) {
+      starts->add_ints(static_cast<int64>(params.begin[i].value()));
+    } else {
+      starts->add_ints(static_cast<int64>(0));
+    }
+    if (params.end[i].has_value()) {
+      ends->add_ints(static_cast<int64>(params.end[i].value()));
+    } else {
+      ends->add_ints(static_cast<int64>(INT_MAX));
+    }
+    axes->add_ints(static_cast<int64>(i));
   }
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                    const nnvm::IndexedGraph& /*ig*/,
-                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertFlatten(GraphProto *graph_proto, const std::string& node_name,
+                    const NodeAttrs& /*attrs*/,
+                    const nnvm::IndexedGraph& ig,
+                    const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Flatten");
 
   // Setting by default to 1 since MXNet doesn't provide such an attribute for 
Flatten in its
@@ -405,11 +585,15 @@ void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
   axis->set_name("axis");
   axis->set_type(AttributeProto::INT);
   axis->set_i(1);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
-                      const nnvm::IndexedGraph& /*ig*/,
-                      const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertBatchNorm(GraphProto *graph_proto, const std::string& node_name,
+                      const NodeAttrs& attrs,
+                      const nnvm::IndexedGraph& ig,
+                      const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("BatchNormalization");
   const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);
 
@@ -430,29 +614,45 @@ void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
   // (default in ONNX3) implies running batchnorm on all spatial features so 
we need to explicitly
   // disable this for MXNet's BatchNorm.
   spatial->set_i(0);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& 
/*inputs*/) {
+void ConvertElementwiseAdd(GraphProto *graph_proto, const std::string& 
node_name,
+                           const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Add");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertElementwiseSub(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& 
/*inputs*/) {
+void ConvertElementwiseSub(GraphProto *graph_proto, const std::string& 
node_name,
+                           const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Sub");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertElementwiseMul(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& 
/*inputs*/) {
+void ConvertElementwiseMul(GraphProto *graph_proto, const std::string& 
node_name,
+                           const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Mul");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertConcatenate(NodeProto* node_proto, const NodeAttrs& attrs,
-                        const nnvm::IndexedGraph& /*ig*/,
-                        const array_view<IndexedGraph::NodeEntry>& /*inputs*/) 
{
+void ConvertConcatenate(GraphProto *graph_proto, const std::string& node_name,
+                        const NodeAttrs& attrs,
+                        const nnvm::IndexedGraph& ig,
+                        const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& _param = nnvm::get<ConcatParam>(attrs.parsed);
   node_proto->set_op_type("Concat");
   node_proto->set_name(attrs.name);
@@ -461,6 +661,7 @@ void ConvertConcatenate(NodeProto* node_proto, const NodeAttrs& attrs,
   axis->set_name("axis");
   axis->set_type(AttributeProto::INT);
   axis->set_i(static_cast<int64_t>(_param.dim));
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
 inline TensorProto_DataType ConvertDType(int dtype) {
@@ -615,9 +816,11 @@ void ConvertOutput(
   }
 }
 
-void ConvertClip(NodeProto* node_proto, const NodeAttrs& attrs,
-                 const nnvm::IndexedGraph& /*ig*/,
-                 const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertClip(GraphProto *graph_proto, const std::string& node_name, const 
NodeAttrs& attrs,
+                 const nnvm::IndexedGraph& ig,
+                 const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& param = nnvm::get<ClipParam>(attrs.parsed);
 
   node_proto->set_op_type("Clip");
@@ -633,11 +836,14 @@ void ConvertClip(NodeProto* node_proto, const NodeAttrs& attrs,
   a_min->set_name("min");
   a_min->set_type(AttributeProto::FLOAT);
   a_min->set_f(static_cast<float>(param.a_min));
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertPad(NodeProto* node_proto, const NodeAttrs& attrs,
-                const nnvm::IndexedGraph& /*ig*/,
-                const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertPad(GraphProto *graph_proto, const std::string& node_name, const 
NodeAttrs& attrs,
+                const nnvm::IndexedGraph& ig,
+                const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& param = nnvm::get<PadParam>(attrs.parsed);
 
   node_proto->set_op_type("Pad");
@@ -679,12 +885,16 @@ void ConvertPad(NodeProto* node_proto, const NodeAttrs& attrs,
   value->set_name("value");
   value->set_type(AttributeProto::FLOAT);
   value->set_f(param.constant_value);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertDropout(NodeProto* node_proto, const NodeAttrs& attrs,
-                    const nnvm::IndexedGraph& /*ig*/,
-                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertDropout(GraphProto *graph_proto, const std::string& node_name, 
const NodeAttrs& attrs,
+                    const nnvm::IndexedGraph& ig,
+                    const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Dropout");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
 void PreprocessBatchNorm(const NodeAttrs &attrs,

Reply via email to