This is an automated email from the ASF dual-hosted git repository. junrushao pushed a commit to branch ir-patch in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 5361b9bbe83dcc6a0da580c39e405c004abf0b62 Author: Junru Shao <[email protected]> AuthorDate: Wed Oct 9 23:01:24 2019 -0700 [IR-Bridge] Support attrs for operators: convolution, batch norm, relu (#16351) * Rebased * Trigger CI * ... * Trigger CI * Trigger CI * Trigger CI * ... * ... * ... * Trigger CI * Trigger CI * Trigger CI * Trigger CI * ... * ... --- Makefile | 4 +- src/imperative/cached_op.cc | 18 +- src/v3/include/bridge/legacy_nnvm.h | 64 +++++++ src/v3/include/ir.h | 188 +++++++++++++++++++++ src/v3/include/op/attrs/nn.h | 71 ++++++++ src/v3/src/bridge/legacy_nnvm/attrs.cc | 120 +++++++++++++ .../legacy_nnvm/ir.cc} | 109 ++++++------ src/v3/src/op/attrs.cc | 40 +++++ tests/python/unittest/test_numpy_op.py | 9 +- 9 files changed, 563 insertions(+), 60 deletions(-) diff --git a/Makefile b/Makefile index b18edf0..3a675cd 100644 --- a/Makefile +++ b/Makefile @@ -462,7 +462,7 @@ endif all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages sample_lib -SRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc) +SRC = $(wildcard src/*/*/*/*/*/*.cc src/*/*/*/*/*.cc src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc) OBJ = $(patsubst %.cc, build/%.o, $(SRC)) CUSRC = $(wildcard src/*/*/*/*.cu src/*/*/*.cu src/*/*.cu src/*.cu) CUOBJ = $(patsubst %.cu, build/%_gpu.o, $(CUSRC)) @@ -795,6 +795,8 @@ clean_all: clean -include build/*/*.d -include build/*/*/*.d -include build/*/*/*/*.d +-include build/*/*/*/*/*.d +-include build/*/*/*/*/*/*.d ifneq ($(EXTRA_OPERATORS),) -include $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS)) endif diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 14e9527..5180c7f 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -25,18 +25,18 @@ #include "../operator/operator_common.h" #include "../operator/subgraph/common.h" -#if MXNET_USE_TVM_OP -#ifndef MXNET_AMALGAMATION -#include <tvm/node/node.h> +#if MXNET_USE_TVM_OP && !defined MXNET_AMALGAMATION +#include <tvm/relay/expr.h> 
namespace mxnet { namespace v3 { -namespace nnvm_relay_bridge { -tvm::NodeRef NNVMToRelay(const nnvm::Graph &g); -} // namespace nnvm_relay_bridge +namespace bridge { +namespace legacy_nnvm { +tvm::relay::Function NNVMToRelay(const nnvm::Graph &g); // must agree exactly with the definition's return type (ir::Function) +} // namespace legacy_nnvm +} // namespace bridge } // namespace v3 } // namespace mxnet -#endif // MXNET_AMALGAMATION -#endif // MXNET_USE_TVM_OP +#endif namespace mxnet { @@ -325,7 +325,7 @@ bool CachedOp::SetForwardGraph( CHECK_EQ(inputs.size(), num_inputs()); nnvm::Graph& g = info->fwd_graph; #if MXNET_USE_TVM_OP && !defined MXNET_AMALGAMATION - v3::nnvm_relay_bridge::NNVMToRelay(g); + v3::bridge::legacy_nnvm::NNVMToRelay(g); #endif // MXNET_USE_TVM_OP && !define MXNET_AMALGAMATION ShapeVector shape_inputs; DTypeVector dtype_inputs; diff --git a/src/v3/include/bridge/legacy_nnvm.h b/src/v3/include/bridge/legacy_nnvm.h new file mode 100644 index 0000000..e2c99a5 --- /dev/null +++ b/src/v3/include/bridge/legacy_nnvm.h @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*!
+ * Copyright (c) 2019 by Contributors + * \file legacy_nnvm.h + * \author Junru Shao + */ +#pragma once +#if MXNET_USE_TVM_OP && !defined MXNET_AMALGAMATION +#include <nnvm/node.h> + +#include "../ir.h" + +namespace nnvm { +class Op; +class Graph; +} // namespace nnvm + +namespace mxnet { +namespace v3 { +namespace bridge { +namespace legacy_nnvm { + +class NNVMCapsuleNode final : public ir::Node { + public: + nnvm::NodeAttrs attrs; + void VisitAttrs(tvm::AttrVisitor *v) final {} + static constexpr const char *_type_key = "mxnet.v3.bridge.NNVMCapsule"; + MX_V3_DEF_NODE_TYPE_INFO(NNVMCapsuleNode, ir::Node); +}; + +class NNVMCapsule final : public ir::NodeRef { + public: + MX_V3_DEF_NODE_REF_METHODS(NNVMCapsule, ir::NodeRef, NNVMCapsuleNode); + static NNVMCapsule make(const nnvm::NodeAttrs &attrs); +}; + +ir::Call ConvertCall(const nnvm::Op *op, const nnvm::NodeAttrs &attrs, + const ir::Array<ir::Expr> &args); + +ir::Function NNVMToRelay(const nnvm::Graph &g); + +} // namespace legacy_nnvm +} // namespace bridge +} // namespace v3 +} // namespace mxnet +#endif diff --git a/src/v3/include/ir.h b/src/v3/include/ir.h new file mode 100644 index 0000000..24440bc --- /dev/null +++ b/src/v3/include/ir.h @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2019 by Contributors + * \file ir.h + * \author Junru Shao + */ +#pragma once +#if MXNET_USE_TVM_OP && !defined MXNET_AMALGAMATION +// This is a compatibility layer between MXNet v3 and Relay +// We will borrow basically everything from TVM/Relay to here. + +#include <tvm/attrs.h> +#include <tvm/ir.h> +#include <tvm/runtime/c_runtime_api.h> +#include <tvm/runtime/packed_func.h> +#include <tvm/node/container.h> +#include <tvm/node/memory.h> +#include <tvm/node/node.h> +#include <tvm/relay/base.h> +#include <tvm/relay/expr.h> +#include <tvm/relay/expr_functor.h> +#include <tvm/relay/module.h> +#include <tvm/relay/op.h> +#include <tvm/relay/op_attr_types.h> +#include <tvm/relay/type.h> + +namespace mxnet { +namespace v3 { +namespace ir { + +using tvm::Array; +using tvm::Attrs; +using tvm::AttrsNode; +using tvm::Downcast; +using tvm::GetRef; +using tvm::Integer; +using tvm::IntImm; +using tvm::make_node; +using tvm::Map; +using tvm::MapNode; +using tvm::Node; +using tvm::NodePtr; +using tvm::NullValue; + +using tvm::relay::DataType; +using tvm::relay::IndexExpr; +using tvm::relay::NodeEqual; +using tvm::relay::NodeHash; +using tvm::relay::NodeRef; + +// Relay Expression +using tvm::relay::Expr; +using tvm::relay::ExprNode; + +using tvm::relay::FTVMCompute; +using tvm::relay::FTVMSchedule; +using tvm::relay::TOpPattern; +using tvm::relay::Op; +using tvm::relay::OpNode; + +using tvm::relay::Tuple; +using tvm::relay::TupleNode; + +using tvm::relay::Var; +using tvm::relay::VarNode; + +using tvm::relay::GlobalVar; +using tvm::relay::GlobalVarNode; + +using tvm::relay::Function; +using tvm::relay::FunctionNode; + +using tvm::relay::Call; +using tvm::relay::CallNode; + +using tvm::relay::Let; +using tvm::relay::LetNode; + +using tvm::relay::If; +using tvm::relay::IfNode; + +using tvm::relay::TupleGetItem; +using tvm::relay::TupleGetItemNode; + 
+using tvm::relay::RefCreate; +using tvm::relay::RefCreateNode; + +using tvm::relay::RefRead; +using tvm::relay::RefReadNode; + +using tvm::relay::RefWrite; +using tvm::relay::RefWriteNode; + +using tvm::relay::TempExpr; +using tvm::relay::TempExprNode; + +// Relay Types +using tvm::relay::Kind; + +using tvm::relay::Type; +using tvm::relay::TypeNode; + +using tvm::relay::BaseTensorType; +using tvm::relay::BaseTensorTypeNode; + +using tvm::relay::TensorType; +using tvm::relay::TensorTypeNode; + +using tvm::relay::TypeVar; +using tvm::relay::TypeVarNode; + +using tvm::relay::GlobalTypeVar; +using tvm::relay::GlobalTypeVarNode; + +using tvm::relay::TypeCall; +using tvm::relay::TypeCallNode; + +using tvm::relay::IncompleteType; +using tvm::relay::IncompleteTypeNode; + +using tvm::relay::FuncType; +using tvm::relay::FuncTypeNode; + +using tvm::relay::TupleType; +using tvm::relay::TupleTypeNode; + +using tvm::relay::RefType; +using tvm::relay::RefTypeNode; + +using tvm::relay::TypeConstraint; +using tvm::relay::TypeConstraintNode; + +using tvm::relay::TypeRelation; +using tvm::relay::TypeRelationNode; + +using tvm::relay::TypeReporter; + +// Relay Functors +using tvm::relay::ExprFunctor; + +} // namespace ir +} // namespace v3 +} // namespace mxnet + +#define MX_V3_DEF_NODE_TYPE_INFO(TypeName, Parent) TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) + +#define MX_V3_DEF_BASE_NODE_INFO(TypeName, Parent) TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) + +#define MX_V3_DEF_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ + TypeName() { \ + } \ + explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) { \ + } \ + NodeName* operator->() const { \ + return static_cast<NodeName*>(node_.get()); \ + } \ + operator bool() const { \ + return this->defined(); \ + } \ + using ContainerType = NodeName; + +#define MX_V3_DECLARE_ATTRS TVM_DECLARE_ATTRS + +#define MX_V3_ATTR_FIELD TVM_ATTR_FIELD + +#define MX_V3_REGISTER_NODE_TYPE TVM_REGISTER_NODE_TYPE + +#define 
MX_V3_REGISTER_OP RELAY_REGISTER_OP + +#define MX_V3_ADD_FILELINE TVM_ADD_FILELINE +#endif diff --git a/src/v3/include/op/attrs/nn.h b/src/v3/include/op/attrs/nn.h new file mode 100644 index 0000000..cd07603 --- /dev/null +++ b/src/v3/include/op/attrs/nn.h @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * Copyright (c) 2019 by Contributors + * \file nn.h + * \author Junru Shao + */ +#pragma once +#if MXNET_USE_TVM_OP && !defined MXNET_AMALGAMATION +#include <string> + +#include "../../ir.h" + +namespace mxnet { +namespace v3 { +namespace op { +namespace attrs { + +class ConvAttrs : public ir::AttrsNode<ConvAttrs> { + public: + ir::Array<ir::Integer> stride = {1}; + ir::Array<ir::Integer> padding = {0}; + ir::Array<ir::Integer> dilation = {1}; + int64_t groups = 1; + std::string layout = "INVALID"; + ir::NodeRef capsule{nullptr}; // set by the legacy-NNVM bridge; not registered as an attr field + + MX_V3_DECLARE_ATTRS(ConvAttrs, "mxnet.v3.attrs.ConvAttrs") { + MX_V3_ATTR_FIELD(stride); // {w}, {h, w}, {d, h, w} + MX_V3_ATTR_FIELD(padding); // {w}, {h, w}, {d, h, w} + MX_V3_ATTR_FIELD(dilation); // {w}, {h, w}, {d, h, w} + MX_V3_ATTR_FIELD(groups); + MX_V3_ATTR_FIELD(layout); + } +}; + +class BatchNormAttrs : public ir::AttrsNode<BatchNormAttrs> { + public: + double eps = 1e-5; + double momentum = 0.1; + bool affine = true; + ir::NodeRef capsule{nullptr}; // set by the legacy-NNVM bridge; not registered as an attr field + + MX_V3_DECLARE_ATTRS(BatchNormAttrs, "mxnet.v3.attrs.BatchNormAttrs") { // first argument must name this class, or the node type info resolves to the other class + MX_V3_ATTR_FIELD(eps); + MX_V3_ATTR_FIELD(momentum); + MX_V3_ATTR_FIELD(affine); + } +}; + +} // namespace attrs +} // namespace op +} // namespace v3 +} // namespace mxnet +#endif diff --git a/src/v3/src/bridge/legacy_nnvm/attrs.cc b/src/v3/src/bridge/legacy_nnvm/attrs.cc new file mode 100644 index 0000000..e88563d --- /dev/null +++ b/src/v3/src/bridge/legacy_nnvm/attrs.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2019 by Contributors + * \file attrs.cc + * \author Junru Shao + */ +#if MXNET_USE_TVM_OP && !defined MXNET_AMALGAMATION +#include <nnvm/node.h> + +#include "../../../../operator/nn/activation-inl.h" +#include "../../../../operator/nn/batch_norm-inl.h" +#include "../../../../operator/nn/convolution-inl.h" +#undef Assign + +#include "../../../include/bridge/legacy_nnvm.h" +#include "../../../include/op/attrs/nn.h" + +namespace mxnet { +namespace v3 { +namespace bridge { +namespace legacy_nnvm { + +using ir::Array; +using ir::Attrs; +using ir::Call; +using ir::CallNode; +using ir::Integer; +using ir::Op; + +static Array<Integer> AsArray(const mxnet::TShape &from) { // copies every dimension of `from` into a Relay integer array + Array<Integer> result; + for (const auto &item : from) { + result.push_back(Integer(item)); + } + return result; +} + +static Attrs ConvertAttrs(const mxnet::op::ConvolutionParam &attrs, + const nnvm::NodeAttrs &node_attrs) { // by reference: no copy of the NNVM attrs per call + static const std::unordered_map<int, std::string> layout_map = { // const: lookups never insert into this shared table + {mshadow::kNCW, "NCW"}, // 1-d conv + {mshadow::kNCHW, "NCHW"}, // 2-d conv + {mshadow::kNHWC, "NHWC"}, // 2-d conv + {mshadow::kNCDHW, "NCDHW"}, // 3-d conv + {mshadow::kNDHWC, "NDHWC"}, // 3-d conv + }; + auto relay_attrs = ir::make_node<v3::op::attrs::ConvAttrs>(); + relay_attrs->stride = AsArray(attrs.stride); + relay_attrs->dilation = AsArray(attrs.dilate); + relay_attrs->padding = AsArray(attrs.pad); + relay_attrs->groups = attrs.num_group; + relay_attrs->layout = layout_map.at(attrs.layout.value()); // at(): an unmapped layout throws instead of silently becoming ""; NOTE(review): assumes attrs.layout is populated by parsing — confirm + relay_attrs->capsule = NNVMCapsule::make(node_attrs); + return
ir::Attrs(relay_attrs); +} + +static Attrs ConvertAttrs(const mxnet::op::BatchNormParam &attrs, + const nnvm::NodeAttrs &node_attrs) { + auto relay_attrs = ir::make_node<v3::op::attrs::BatchNormAttrs>(); + relay_attrs->eps = attrs.eps; + relay_attrs->momentum = attrs.momentum; + relay_attrs->affine = !attrs.fix_gamma; + relay_attrs->capsule = NNVMCapsule::make(node_attrs); + return ir::Attrs(relay_attrs); +} + +Call ConvertCall(const nnvm::Op *op, const nnvm::NodeAttrs &attrs, // builds the Relay call for one NNVM node; unrecognized ops fall back to "add" + const ir::Array<ir::Expr> &args) { + CHECK(op != nullptr) << "InternalError: operator undefined."; + if (op->name == "Convolution") { + static const Op &relay_op = Op::Get("nn.conv2d"); // NOTE(review): 1-d/3-d convolutions are also mapped to nn.conv2d here — confirm + const auto &nnvm_attrs = + nnvm::get<mxnet::op::ConvolutionParam>(attrs.parsed); + return CallNode::make(relay_op, args, ConvertAttrs(nnvm_attrs, attrs)); + } else if (op->name == "BatchNorm") { + static const Op &relay_op = Op::Get("nn.batch_norm"); + const auto &nnvm_attrs = nnvm::get<mxnet::op::BatchNormParam>(attrs.parsed); + return CallNode::make(relay_op, args, ConvertAttrs(nnvm_attrs, attrs)); + } else if (op->name == "elemwise_add") { + static const Op &relay_op = Op::Get("add"); + return CallNode::make(relay_op, args, {}); + } else if (op->name == "Activation") { + static const std::unordered_map<int, Op> op_map = { // act_type -> Relay op; unlisted act types fall through to the fallback below + {mxnet::op::activation::kReLU, Op::Get("nn.relu")}, + {mxnet::op::activation::kSigmoid, Op::Get("sigmoid")}, + {mxnet::op::activation::kTanh, Op::Get("tanh")}, + }; + const auto &nnvm_attrs = + nnvm::get<mxnet::op::ActivationParam>(attrs.parsed); + if (op_map.count(nnvm_attrs.act_type)) { + return CallNode::make(op_map.at(nnvm_attrs.act_type), args, {}); + } + } + LOG(WARNING) << "cannot recognize NNVM operator " << op->name + << ", fallback to add"; + return CallNode::make(Op::Get("add"), args, {}, {}); +} + +} // namespace legacy_nnvm +} // namespace bridge +} // namespace v3 +} // namespace mxnet +#endif diff --git a/src/v3/src/nnvm_relay_bridge.cc b/src/v3/src/bridge/legacy_nnvm/ir.cc similarity index 67% rename from
src/v3/src/nnvm_relay_bridge.cc rename to src/v3/src/bridge/legacy_nnvm/ir.cc index 298ce65..4367315 100644 --- a/src/v3/src/nnvm_relay_bridge.cc +++ b/src/v3/src/bridge/legacy_nnvm/ir.cc @@ -19,31 +19,38 @@ /*! * Copyright (c) 2019 by Contributors - * \file nnvm_relay_bridge.cc + * \file ir.cc * \author Junru Shao */ -#if MXNET_USE_TVM_OP -#ifndef MXNET_AMALGAMATION +#if MXNET_USE_TVM_OP && !defined MXNET_AMALGAMATION #include <nnvm/graph.h> -#include <tvm/relay/expr.h> -#include <tvm/relay/op.h> -#include <tvm/node/container.h> -#include <tvm/node/node.h> + +#include "../../../include/bridge/legacy_nnvm.h" +#include "../../../include/ir.h" +#include "../../../include/op/attrs/nn.h" namespace mxnet { namespace v3 { -namespace nnvm_relay_bridge { +namespace bridge { +namespace legacy_nnvm { + +using ir::Array; +using ir::CallNode; +using ir::Expr; +using ir::Function; +using ir::FunctionNode; +using ir::LetNode; +using ir::NodeRef; +using ir::TupleGetItemNode; +using ir::TupleNode; +using ir::Var; +using ir::VarNode; -using tvm::relay::Expr; -using tvm::relay::TupleGetItemNode; -using tvm::relay::FunctionNode; -using tvm::relay::Var; -using tvm::relay::VarNode; -using tvm::relay::CallNode; -using tvm::relay::TupleNode; -using tvm::relay::LetNode; -using tvm::NodeRef; -using tvm::Array; +NNVMCapsule NNVMCapsule::make(const nnvm::NodeAttrs &attrs) { + auto node = ir::make_node<NNVMCapsuleNode>(); + node->attrs = attrs; + return NNVMCapsule(node); +} static void PrintIndexedGraph(const nnvm::Graph &g) { const auto &idx = g.indexed_graph(); @@ -58,7 +65,8 @@ static void PrintIndexedGraph(const nnvm::Graph &g) { std::string op_name = op ? op->name : "None"; if (input_nodes.count(i)) { input_cnt += 1; - op_name = (op ? op->name + " [input " : "[input ") + std::to_string(input_cnt) + "]"; + op_name = (op ? op->name + " [input " : "[input ") + + std::to_string(input_cnt) + "]"; } else { op_name = op ? 
op->name : "None"; } @@ -66,49 +74,49 @@ static void PrintIndexedGraph(const nnvm::Graph &g) { << ", #(input node entries) = " << idx[i].inputs.size() << std::endl; int j_cnt = 0; + for (const auto &attr : node->attrs.dict) { + std::cout << " " << attr.first << " = " << attr.second << std::endl; + } for (const nnvm::IndexedGraph::NodeEntry &j : idx[i].inputs) { std::cout << " input entry #" << ++j_cnt << ", entry_id = " << idx.entry_id(j) << ", (node_id = " << j.node_id << ", index = " << j.index - << ", version = " << j.version << ")" - << std::endl; + << ", version = " << j.version << ")" << std::endl; } for (int j_cnt = 0, n_out = node->num_outputs(); j_cnt < n_out; ++j_cnt) { uint32_t entry_id = idx.entry_id(i, j_cnt); std::cout << " output entry #" << j_cnt + 1 - << ", entry_id = " << entry_id - << std::endl; + << ", entry_id = " << entry_id << std::endl; } } - std::cout << idx.outputs().size() << " output node entries: " - << std::endl; + std::cout << idx.outputs().size() << " output node entries: " << std::endl; int j_cnt = 0; for (const nnvm::IndexedGraph::NodeEntry &j : idx.outputs()) { std::cout << " output entry #" << ++j_cnt << ", entry_id = " << idx.entry_id(j) << ", (node_id = " << j.node_id << ", index = " << j.index - << ", version = " << j.version << ")" - << std::endl; + << ", version = " << j.version << ")" << std::endl; } } -NodeRef NNVMToRelay(const nnvm::Graph &g) { +Function NNVMToRelay(const nnvm::Graph &g) { PrintIndexedGraph(g); const auto &idx = g.indexed_graph(); int n_nodes = idx.num_nodes(); // maps: node -> var std::vector<Var> node2var(n_nodes); // maps: (node, output_index) -> var - std::vector<std::vector<Var> > entries(n_nodes); + std::vector<std::vector<Var>> entries(n_nodes); // maps: node -> #outputs of the node std::vector<int> n_outputs(n_nodes); - for (int node_id = 0, input_cnt = 0, compute_cnt = 0; node_id < n_nodes; ++node_id) { + for (int node_id = 0, input_cnt = 0, compute_cnt = 0; node_id < n_nodes; + ++node_id) { 
const nnvm::Node *node = idx[node_id].source; int n_out = node->num_outputs(); n_outputs[node_id] = n_out; - std::string name = node->is_variable() ? - "arg_" + std::to_string(++input_cnt) : - "x_" + std::to_string(++compute_cnt); + std::string name = node->is_variable() + ? "arg_" + std::to_string(++input_cnt) + : "x_" + std::to_string(++compute_cnt); Var var = node2var[node_id] = VarNode::make(name, {}); std::vector<Var> &outputs = entries[node_id]; if (n_out == 1) { @@ -121,30 +129,30 @@ NodeRef NNVMToRelay(const nnvm::Graph &g) { } } // Create the let list - std::vector<std::pair<Var, Expr> > let_list; + std::vector<std::pair<Var, Expr>> let_list; for (int node_id = 0; node_id < n_nodes; ++node_id) { const Var &var = node2var[node_id]; const nnvm::IndexedGraph::Node &node = idx[node_id]; int n_out = n_outputs[node_id]; - if (node.source->is_variable()) { + const auto &src = node.source; + if (src->is_variable()) { CHECK_EQ(n_out, 1) << "InternalError: internal assumption violation"; continue; } // Create call_args - std::vector<Expr> call_args; + Array<Expr> call_args; for (const nnvm::IndexedGraph::NodeEntry &input : node.inputs) { - CHECK_LT((int)input.node_id, node_id) << "InternalError: IndexedGraph is not topo-sorted"; + CHECK_LT((int)input.node_id, node_id) + << "InternalError: IndexedGraph is not topo-sorted"; call_args.push_back(entries[input.node_id][input.index]); } - // TODO(@junrushao1994): map attrs // Add a CallNode - let_list.push_back({var, CallNode::make(tvm::relay::Op::Get("add"), call_args)}); + let_list.push_back({var, ConvertCall(src->op(), src->attrs, call_args)}); // Add logic for de-tuple if (n_out > 1) { for (int index = 0; index < n_out; ++index) { - let_list.push_back(std::make_pair( - entries[node_id][index], - TupleGetItemNode::make(var, index))); + let_list.push_back(std::make_pair(entries[node_id][index], + TupleGetItemNode::make(var, index))); } } } @@ -164,9 +172,14 @@ NodeRef NNVMToRelay(const nnvm::Graph &g) { for (const 
nnvm::IndexedGraph::NodeEntry &j : idx.outputs()) { outputs.push_back(entries[j.node_id][j.index]); } - body = TupleNode::make(std::move(outputs)); - // 2) Construct let out of let-list - for ( ; !let_list.empty(); let_list.pop_back()) { + CHECK(!outputs.empty()) << "InternalError: NNVM graph has no output"; + if (outputs.size() == 1) { + body = outputs[0]; + } else { + body = TupleNode::make(std::move(outputs)); + } + // 2) Construct the body out of let-list + for (; !let_list.empty(); let_list.pop_back()) { const std::pair<Var, Expr> &last = let_list.back(); body = LetNode::make(last.first, last.second, body); } @@ -175,8 +188,8 @@ NodeRef NNVMToRelay(const nnvm::Graph &g) { return FunctionNode::make(std::move(params), std::move(body), {}, {}, {}); } -} // namespace nnvm_relay_bridge +} // namespace legacy_nnvm +} // namespace bridge } // namespace v3 } // namespace mxnet -#endif // MXNET_AMALGAMATION -#endif // MXNET_USE_TVM_OP +#endif diff --git a/src/v3/src/op/attrs.cc b/src/v3/src/op/attrs.cc new file mode 100644 index 0000000..3396fc0 --- /dev/null +++ b/src/v3/src/op/attrs.cc @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * Copyright (c) 2019 by Contributors + * \file attrs.cc + * \author Junru Shao + */ +#if MXNET_USE_TVM_OP && !defined MXNET_AMALGAMATION +#include "../../include/ir.h" +#include "../../include/op/attrs/nn.h" + +namespace mxnet { +namespace v3 { +namespace op { +namespace attrs { +namespace { +MX_V3_REGISTER_NODE_TYPE(ConvAttrs); +MX_V3_REGISTER_NODE_TYPE(BatchNormAttrs); +} // namespace +} // namespace attrs +} // namespace op +} // namespace v3 +} // namespace mxnet +#endif diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1f90f30..7870486 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -228,7 +228,7 @@ def test_np_ldexp(): def hybrid_forward(self, F, x1, x2): return F.np.ldexp(x1, x2) - + def _np_ldexp(x1, x2): return x1 * _np.power(2.0, x2) @@ -427,6 +427,7 @@ def test_np_inner(): rtol=1e-1, atol=1e-1, dtype=dtype) [email protected]("flaky") @with_seed() @use_np def test_np_outer(): @@ -547,7 +548,7 @@ def test_np_sum(): np_out = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype) assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) - [email protected]('flaky') @with_seed() @use_np def test_np_max_min(): @@ -655,6 +656,7 @@ def test_np_max_min(): _test_np_exception(func, shape, dim) [email protected]("flaky") @with_seed() @use_np def test_np_mean(): @@ -719,6 +721,7 @@ def test_np_mean(): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) [email protected]("flaky") @with_seed() @use_np def test_np_moment(): @@ -1019,6 +1022,7 @@ def test_np_squeeze(): rtol=1e-5, atol=1e-6, use_broadcast=False) [email protected]("flaky") @with_seed() @use_np def test_np_prod(): @@ -1764,6 +1768,7 @@ def test_np_randint(): verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs, nrepeat=100) [email protected]("flaky") @with_seed() @use_np def test_np_minimum_maximum():
