This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 14d9cf804e [Unity][MSC][M0.3] MSCGraph Builder (#15615)
14d9cf804e is described below
commit 14d9cf804eebe64c6bebb8379228b998dac4171f
Author: Archermmt <[email protected]>
AuthorDate: Wed Aug 30 16:42:12 2023 +0800
[Unity][MSC][M0.3] MSCGraph Builder (#15615)
* add graph builder test
* format fix
* lint fix
* lint fix
---
python/tvm/contrib/msc/core/ir/__init__.py | 1 +
python/tvm/contrib/msc/core/ir/graph.py | 4 +-
python/tvm/contrib/msc/core/ir/translate.py | 172 ++
src/contrib/msc/core/ir/graph_builder.cc | 695 +++++++
src/contrib/msc/core/ir/graph_builder.h | 325 ++++
src/contrib/msc/core/transform/layout_utils.cc | 8 +-
src/contrib/msc/core/transform/set_expr_layout.cc | 29 +-
tests/lint/pylint.sh | 3 +
tests/python/contrib/test_msc/test_graph_build.py | 2037 ++++++++++++++++++++
.../test_msc/test_transform_set_expr_layout.py | 11 +-
.../test_msc/test_transform_set_expr_name.py | 9 +-
11 files changed, 3270 insertions(+), 24 deletions(-)
diff --git a/python/tvm/contrib/msc/core/ir/__init__.py
b/python/tvm/contrib/msc/core/ir/__init__.py
index ce23a2dd8b..81a34bedb6 100644
--- a/python/tvm/contrib/msc/core/ir/__init__.py
+++ b/python/tvm/contrib/msc/core/ir/__init__.py
@@ -17,3 +17,4 @@
"""tvm.contrib.msc.core.ir"""
from .graph import *
+from .translate import *
diff --git a/python/tvm/contrib/msc/core/ir/graph.py
b/python/tvm/contrib/msc/core/ir/graph.py
index c058f74936..0db0fe6e81 100644
--- a/python/tvm/contrib/msc/core/ir/graph.py
+++ b/python/tvm/contrib/msc/core/ir/graph.py
@@ -99,7 +99,9 @@ class MSCTensor(Object):
The tensor description in json format.
"""
- return {"name": self.alias, "shape": self.get_shape(), "dtype":
self.dtype_name}
+ tensor_des = {"name": self.alias, "shape": self.get_shape(), "dtype":
self.dtype_name}
+ tensor_des["layout"] = self.layout.name if self.layout else ""
+ return tensor_des
@property
def dtype_name(self) -> str:
diff --git a/python/tvm/contrib/msc/core/ir/translate.py
b/python/tvm/contrib/msc/core/ir/translate.py
new file mode 100644
index 0000000000..46dd59a09b
--- /dev/null
+++ b/python/tvm/contrib/msc/core/ir/translate.py
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.ir.translate"""
+
+from typing import Dict, Optional, Tuple
+
+import tvm
+from tvm.relax.transform import BindParams
+from tvm.relax.backend.pattern_registry import get_patterns_with_prefix
+from tvm.relay.build_module import bind_params_by_name
+from tvm.contrib.msc.core import transform as msc_transform
+from tvm.contrib.msc.core import _ffi_api
+from tvm.contrib.msc.core import utils as msc_utils
+from .graph import MSCGraph, MSCTensor
+
+
+def normalize_weights(
+ t_weights: Dict[MSCTensor, tvm.nd.array], graph: MSCGraph
+) -> Dict[str, tvm.nd.array]:
+ """Normalize the weights.
+
+ Parameters
+ ----------
+ t_weights: dict of <MSCTensor, tvm.nd.array>
+ The weights extracted from IRModule.
+ graph: tvm.contrib.msc.core.ir.MSCGraph
+ The translated graph.
+
+ Returns
+ -------
+ weights: dict of <string:tvm.ndarray>
+ The normalized weights.
+ """
+
+ def _to_data(ref_t, data):
+ weight_t = graph.find_tensor(ref_t.name)
+ if weight_t.ndim == 1:
+ if ref_t.ndim != weight_t.ndim:
+ return
tvm.nd.array(data.asnumpy().reshape(weight_t.get_shape()))
+ return data
+ if ref_t.layout and weight_t.layout:
+ ref_layout, weight_layout = ref_t.layout.name, weight_t.layout.name
+ if ref_layout != weight_layout:
+ assert all(
+ l.name in ref_layout for l in weight_layout
+ ), "layout mismatch {} compare to {}".format(ref_t, weight_t)
+ permute = [ref_layout.index(l) for l in weight_layout]
+ return tvm.nd.array(data.asnumpy().transpose(*permute))
+ return data
+
+ weights = {t.name: _to_data(t, d) for t, d in t_weights.items()}
+ return weights
+
+
+def from_relax(
+ mod: tvm.IRModule,
+ params: Optional[Dict[str, tvm.nd.array]] = None,
+ trans_config: Optional[Dict[str, str]] = None,
+ build_config: Optional[Dict[str, str]] = None,
+) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]:
+ """Change IRModule to MSCGraph.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The IRModule of relax.
+ params: dict of <string:tvm.ndarray>
+ The parameters of the IRModule.
+ trans_config: dict
+ The config for transform IRModule.
+ build_config: dict
+ The config for build MSCGraph.
+
+ Returns
+ -------
+ graph: tvm.contrib.msc.core.ir.MSCGraph
+ The translated graph.
+ weights: dict of <string:tvm.ndarray>
+ The weights from the IRModule.
+ """
+
+ trans_config = trans_config or {}
+ build_config = build_config or {}
+ # TODO(tong.meng): optimize before translate?
+ if params:
+ mod = BindParams("main", params)(mod)
+ patterns = get_patterns_with_prefix("msc")
+ passes = [
+ tvm.relax.transform.FuseOpsByPattern(
+ patterns, bind_constants=False, annotate_codegen=False
+ ),
+ msc_transform.SetExprName(),
+ msc_transform.SetExprLayout(trans_config.get("allow_layout_missing",
True)),
+ ]
+ mod = tvm.transform.Sequential(passes)(mod)
+ graph = _ffi_api.BuildFromRelax(mod, "main",
msc_utils.dump_dict(build_config))
+ t_weights = _ffi_api.GetRelaxWeights(mod, "main")
+ return graph, normalize_weights(t_weights, graph)
+
+
+def from_relay(
+ mod: tvm.IRModule,
+ params: Optional[Dict[str, tvm.nd.array]] = None,
+ trans_config: Optional[Dict[str, str]] = None,
+ build_config: Optional[Dict[str, str]] = None,
+ opt_config: Optional[Dict[str, str]] = None,
+) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]:
+ """Change IRModule to MSCGraph.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The IRModule of relay.
+ params: dict of <string:tvm.ndarray>
+ The parameters of the IRModule.
+ trans_config: dict
+ The config for transform IRModule.
+ build_config: dict
+ The config for build MSCGraph.
+ opt_config: dict
+ The config for optimize the relay before translate.
+
+ Returns
+ -------
+ graph: tvm.contrib.msc.core.ir.MSCGraph
+ The translated graph.
+ weights: dict of <string:tvm.ndarray>
+ The weights from the IRModule.
+ """
+
+ trans_config = trans_config or {}
+ build_config = build_config or {}
+ opt_config = opt_config or {}
+ # TODO(tong.meng): optimize before translate?
+ opt_level = opt_config.get("opt_level", 0)
+ if opt_level == 0:
+ if params:
+ mod["main"] = bind_params_by_name(mod["main"], params)
+ else:
+ target = opt_config.get("target", "llvm")
+ disabled_pass = opt_config.get("disabled_pass", []) + [
+ "SimplifyInference",
+ "CanonicalizeOps",
+ "FuseOps",
+ "AlterOpLayout",
+ ]
+ with tvm.transform.PassContext(opt_level=opt_level,
disabled_pass=disabled_pass):
+ mod, params = tvm.relay.optimize(mod, target=target, params=params)
+ patterns = tvm.relay.op.contrib.get_pattern_table("msc")
+ passes = [
+ tvm.relay.transform.InferType(),
+ tvm.relay.transform.MergeComposite(patterns),
+ msc_transform.SetExprName(as_relax=False),
+ ]
+ mod = tvm.transform.Sequential(passes)(mod)
+ graph = _ffi_api.BuildFromRelay(mod, "main",
msc_utils.dump_dict(build_config))
+ t_weights = _ffi_api.GetRelayWeights(mod, "main")
+ return graph, normalize_weights(t_weights, graph)
diff --git a/src/contrib/msc/core/ir/graph_builder.cc
b/src/contrib/msc/core/ir/graph_builder.cc
new file mode 100644
index 0000000000..55ff4a45d4
--- /dev/null
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -0,0 +1,695 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/ir/graph_builder.cc
+ */
+
+#include "graph_builder.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+void RelaxFuncAttrGetter::VisitExpr_(const relax::CallNode* op) {
+ if (op->attrs.defined()) {
+ Map<String, String> attrs;
+ AttrGetter getter(&attrs);
+ const_cast<BaseAttrsNode*>(op->attrs.get())->VisitAttrs(&getter);
+ for (const auto& pair : attrs) {
+ if (attrs_.count(pair.first)) {
+ int cnt = 1;
+ String rep_key = pair.first;
+ while (attrs_.count(rep_key + "_" + std::to_string(cnt))) {
+ cnt++;
+ }
+ attrs_.Set(pair.first + "_" + std::to_string(cnt), pair.second);
+ } else {
+ attrs_.Set(pair.first, pair.second);
+ }
+ }
+ }
+}
+
+const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) {
+ // Add input nodes and record inputs;
+ Array<String> input_names, output_names;
+ for (const auto& p : func->params) {
+ AddNode(p, NullOpt, p->name_hint());
+ ICHECK(expr_tensor_map_.count(p)) << "Can not find func param " << p;
+ input_names.push_back(expr_tensor_map_[p][0]);
+ }
+ VisitExpr(func);
+ if (const auto* b_node = func->body.as<relax::SeqExprNode>()) {
+ ICHECK(expr_tensor_map_.count(b_node->body)) << "Can not find seqexpr body
" << b_node->body;
+ output_names = expr_tensor_map_[b_node->body];
+ } else {
+ LOG(FATAL) << "Function body should be SeqExpr, get " << func->body;
+ }
+ // remove const nodes as weights
+ Array<MSCJoint> valid_nodes;
+ for (const auto& n : nodes_) {
+ if (!weights_.count(n->name)) {
+ n->index = valid_nodes.size();
+ valid_nodes.push_back(n);
+ }
+ }
+ const auto& graph = MSCGraph(name_, valid_nodes, input_names, output_names);
+ // set inputs and outputs alias
+ if (config_.input_aliass.size() == input_names.size()) {
+ for (size_t i = 0; i < input_names.size(); i++) {
+ graph->FindTensor(input_names[i])->alias = config_.input_aliass[i];
+ }
+ } else {
+ for (size_t i = 0; i < input_names.size(); i++) {
+ graph->FindTensor(input_names[i])->alias =
graph->FindProducer(input_names[i])->name;
+ }
+ }
+ if (config_.output_aliass.size() == output_names.size()) {
+ for (size_t i = 0; i < output_names.size(); i++) {
+ graph->FindTensor(output_names[i])->alias = config_.output_aliass[i];
+ }
+ } else {
+ for (size_t i = 0; i < output_names.size(); i++) {
+ const auto& output = graph->FindTensor(output_names[i]);
+ if (output->alias.size() > 0) {
+ continue;
+ }
+ const auto& producer = graph->FindProducer(output_names[i]);
+ output->alias = producer->outputs.size() == 1
+ ? producer->name
+ : StringUtils::Replace(output_names[i], ":", "_");
+ }
+ }
+ return graph;
+}
+
+const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const
Optional<Expr>& binding_var,
+ const String& name) {
+ const auto& node_name = name.size() > 0 ? name :
SpanUtils::GetAttr(expr->span, "name");
+ const auto& master_name = SpanUtils::GetAttr(expr->span, "master");
+ String optype;
+ if (expr->IsInstance<relax::VarNode>()) {
+ optype = "input";
+ } else if (expr->IsInstance<relax::ConstantNode>()) {
+ optype = "constant";
+ } else if (expr->IsInstance<relax::ShapeExprNode>()) {
+ optype = "shape";
+ } else if (expr->IsInstance<relax::TupleGetItemNode>()) {
+ optype = "get_item";
+ } else if (expr->IsInstance<relax::TupleNode>()) {
+ optype = "tuple";
+ } else if (const auto* call_node = expr.as<relax::CallNode>()) {
+ if (const auto* op_node = call_node->op.as<OpNode>()) {
+ optype = StringUtils::Replace(op_node->name, "relax.", "");
+ } else if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
+ const auto& func =
Downcast<relax::Function>(ref_module_->Lookup(v_node->name_hint));
+ const auto& name_opt =
func->GetAttr<runtime::String>(relax::attr::kComposite);
+ ICHECK(name_opt.defined()) << "Unexpected global func without composite";
+ optype = name_opt.value();
+ } else if (const auto* f_node = call_node->op.as<relax::FunctionNode>()) {
+ const auto& name_opt =
f_node->GetAttr<runtime::String>(relax::attr::kComposite);
+ ICHECK(name_opt.defined()) << "Unexpected func without composite";
+ optype = name_opt.value();
+ } else {
+ optype = "unknown_op";
+ }
+ } else {
+ optype = "unknown_expr";
+ }
+ // Extract attributes
+ Map<String, String> attrs;
+ if (const auto* call_node = expr.as<relax::CallNode>()) {
+ if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
+ const auto& func =
Downcast<relax::Function>(ref_module_->Lookup(v_node->name_hint));
+ attrs = RelaxFuncAttrGetter().GetAttrs(func);
+ } else if (call_node->op->IsInstance<relax::FunctionNode>()) {
+ attrs = RelaxFuncAttrGetter().GetAttrs(call_node->op);
+ } else if (call_node->attrs.defined()) {
+ AttrGetter getter(&attrs);
+ const_cast<BaseAttrsNode*>(call_node->attrs.get())->VisitAttrs(&getter);
+ }
+ } else if (const auto* const_node = expr.as<relax::ConstantNode>()) {
+ if (const_node->is_scalar()) {
+ const float val =
ExprUtils::GetScalar<float>(Downcast<relax::Constant>(expr));
+ std::stringstream stream;
+ stream << std::fixed << std::setprecision(config_.float_precision) <<
val;
+ attrs.Set("scalar", stream.str());
+ }
+ } else if (const auto* shape_node = expr.as<relax::ShapeExprNode>()) {
+ attrs.Set("shape", StringUtils::ToString(shape_node->values));
+ } else if (const auto* get_node = expr.as<relax::TupleGetItemNode>()) {
+ attrs.Set("index", std::to_string(get_node->index));
+ }
+ // Get scope
+ Array<String> scope;
+ if (optype != "input" && optype != "constant") {
+ scope = StringUtils::Split(scope_name_, ".");
+ }
+ // Build inputs and weights
+ Array<String> input_names;
+ Map<String, MSCTensor> node_weights;
+ if (const auto* call_node = expr.as<relax::CallNode>()) {
+ const auto& input_types = ExprUtils::GetInputTypes(optype,
call_node->args.size(), true);
+ for (size_t i = 0; i < call_node->args.size(); i++) {
+ const auto& arg = call_node->args[i];
+ if (const auto* s_node = arg.as<relax::ShapeExprNode>()) {
+ attrs.Set(input_types[i], StringUtils::ToString(s_node->values));
+ continue;
+ }
+ if (const auto* s_node = arg.as<relax::PrimValueNode>()) {
+ ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype
+ << " should has special type, get "
<< input_types;
+ attrs.Set(input_types[i], StringUtils::ToString(s_node->value));
+ continue;
+ }
+ ICHECK(expr_tensor_map_.count(arg)) << "Missing argument " << arg;
+ if (input_types[i] != "input" && arg->IsInstance<relax::ConstantNode>())
{
+ const auto& t_name = expr_tensor_map_[arg][0];
+ const auto& w_name = SpanUtils::GetAttr(arg->span, "name");
+ const auto& pair = tensor_input_map_[t_name];
+ const auto& producer = Downcast<MSCJoint>(pair.first);
+ if (!weights_.count(w_name)) {
+ const auto& ref = producer->OutputAt(pair.second);
+ const auto& weight = MSCTensor(w_name, ref->dtype,
ref->layout.name(), ref->shape);
+ weights_.Set(w_name, weight);
+ }
+ if (producer->HasAttr("scalar")) {
+ attrs.Set(input_types[i],
producer->GetTypeAttr<std::string>("scalar"));
+ }
+ node_weights.Set(input_types[i], weights_[w_name]);
+ } else {
+ for (const auto& in_name : expr_tensor_map_[arg]) {
+ input_names.push_back(in_name);
+ }
+ }
+ }
+ } else if (const auto* tuple_node = expr.as<relax::TupleNode>()) {
+ for (const auto& f : tuple_node->fields) {
+ ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f;
+ for (const auto& in_name : expr_tensor_map_[f]) {
+ input_names.push_back(in_name);
+ }
+ }
+ } else if (const auto* getitem_node = expr.as<relax::TupleGetItemNode>()) {
+ ICHECK(expr_tensor_map_.count(getitem_node->tuple))
+ << "Can not find tuple " << getitem_node->tuple;
+ input_names = expr_tensor_map_[getitem_node->tuple];
+ }
+ std::vector<std::pair<BaseJoint, size_t>> inputs;
+ for (const auto& i : input_names) {
+ inputs.push_back(tensor_input_map_[i]);
+ }
+ // Build outputs
+ Array<MSCTensor> outputs;
+ const auto& layout = SpanUtils::GetAttr(expr->span, "layout");
+ const auto& sinfo = relax::GetStructInfo(expr);
+ if (const auto* t_info = sinfo.as<relax::TensorStructInfoNode>()) {
+ const auto& opt_shape = t_info->GetShape();
+ const auto& shape =
+ opt_shape.defined() ? ArrayUtils::Cast<Integer>(opt_shape.value()) :
Array<Integer>();
+ const auto& output =
+ MSCTensor(node_name + ":" + std::to_string(0), t_info->dtype, layout,
shape);
+ outputs.push_back(output);
+ } else if (const auto* s_sinfo = sinfo.as<relax::ShapeStructInfoNode>()) {
+ Array<Integer> shape{s_sinfo->ndim};
+ const auto& output = MSCTensor(node_name + ":" + std::to_string(0),
+
DataType(runtime::String2DLDataType("int32")), layout, shape);
+ outputs.push_back(output);
+ } else if (const auto* tuple_sinfo = sinfo.as<relax::TupleStructInfoNode>())
{
+ Array<String> layouts = StringUtils::Split(layout, ",");
+ if (layouts.size() == 0) {
+ layouts = Array<String>(tuple_sinfo->fields.size(), "");
+ }
+ ICHECK_EQ(layouts.size(), tuple_sinfo->fields.size())
+ << "Layout " << layout << " msimatch with fileds size " <<
tuple_sinfo->fields.size();
+ size_t field_size = tuple_sinfo->fields.size();
+ if (optype == "nn.batch_norm") {
+ field_size = 1;
+ }
+ for (size_t i = 0; i < field_size; i++) {
+ const auto& t_info =
Downcast<relax::TensorStructInfo>(tuple_sinfo->fields[i]);
+ const auto& opt_shape = t_info->GetShape();
+ const auto& shape =
+ opt_shape.defined() ? ArrayUtils::Cast<Integer>(opt_shape.value()) :
Array<Integer>();
+ const auto& output =
+ MSCTensor(node_name + ":" + std::to_string(i), t_info->dtype,
layouts[i], shape);
+ outputs.push_back(output);
+ }
+ } else {
+ LOG(FATAL) << "Unexpected struct info (" << sinfo->GetTypeKey() << ")" <<
sinfo;
+ }
+ // Build node
+ const auto& node = MSCJoint(nodes_.size(), node_name, master_name, optype,
attrs, scope, inputs,
+ outputs, node_weights);
+ Array<String> output_names;
+ for (size_t i = 0; i < outputs.size(); i++) {
+ output_names.push_back(outputs[i]->name);
+ tensor_input_map_[outputs[i]->name] = std::make_pair(node, i);
+ }
+ nodes_.push_back(node);
+ const auto& ref_expr = binding_var.defined() ? binding_var.value() : expr;
+ expr_tensor_map_.Set(ref_expr, output_names);
+ return node;
+}
+
+void RelaxGraphBuilder::VisitBindingBlock(const relax::BindingBlock& block) {
+ scope_name_ = SpanUtils::GetAttr(block->span, "name");
+ RelaxExprVisitor::VisitBindingBlock(block);
+}
+
+void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) {
+ AddNode(GetRef<relax::Constant>(op));
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+ const relax::ConstantNode* val) {
+ AddNode(GetRef<relax::Constant>(val), binding->var);
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+ const relax::ShapeExprNode* val) {
+ AddNode(GetRef<relax::ShapeExpr>(val), binding->var);
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+ const relax::CallNode* call_node) {
+ RelaxExprVisitor::VisitBinding_(binding, call_node);
+ try {
+ AddNode(GetRef<relax::Call>(call_node), binding->var);
+ } catch (runtime::InternalError& err) {
+ LOG(WARNING) << "Failed to add node from " << binding->var << " : " <<
binding->value
+ << ", reason: " << err.message();
+ throw err;
+ }
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+ const relax::TupleNode* val) {
+ RelaxExprVisitor::VisitBinding_(binding, val);
+ AddNode(GetRef<relax::Tuple>(val), binding->var);
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+ const relax::TupleGetItemNode* val) {
+ RelaxExprVisitor::VisitBinding_(binding, val);
+ AddNode(GetRef<relax::TupleGetItem>(val), binding->var);
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+ const relax::VarNode* val) {
+ RelaxExprVisitor::VisitBinding_(binding, val);
+ const auto& output = GetRef<relax::Var>(val);
+ ICHECK(expr_tensor_map_.count(output)) << "Can not find var " << output;
+ expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]);
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+ const relax::DataflowVarNode* val) {
+ RelaxExprVisitor::VisitBinding_(binding, val);
+ const auto& output = GetRef<relax::DataflowVar>(val);
+ ICHECK(expr_tensor_map_.count(output)) << "Can not find dataflow var " <<
output;
+ expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]);
+}
+
+Map<MSCTensor, NDArray> RelaxWeightsExtractor::GetWeights(const
relax::Function& func) {
+ VisitExpr(func);
+ return weights_;
+}
+
+void RelaxWeightsExtractor::VisitExpr_(const relax::ConstantNode* op) {
+ const auto& name = SpanUtils::GetAttr(op->span, "name");
+ const auto& layout = SpanUtils::GetAttr(op->span, "layout");
+ const auto& sinfo = relax::GetStructInfo(GetRef<relax::Constant>(op));
+ ICHECK(sinfo->IsInstance<relax::TensorStructInfoNode>())
+ << "Constant StrcutInfo should be TensorStructInfo";
+ const auto& t_info = Downcast<relax::TensorStructInfo>(sinfo);
+ const auto& opt_shape = t_info->GetShape();
+ const auto& shape =
+ opt_shape.defined() ? ArrayUtils::Cast<Integer>(opt_shape.value()) :
Array<Integer>();
+ const auto& weight = MSCTensor(name, t_info->dtype, layout, shape);
+ weights_.Set(weight, op->data);
+}
+
+void RelayFuncAttrGetter::VisitExpr_(const relay::CallNode* op) {
+ RelayExprVisitor::VisitExpr_(op);
+ if (op->attrs.defined()) {
+ Map<String, String> attrs;
+ AttrGetter getter(&attrs);
+ const_cast<BaseAttrsNode*>(op->attrs.get())->VisitAttrs(&getter);
+ for (const auto& pair : attrs) {
+ if (attrs_.count(pair.first)) {
+ int cnt = 1;
+ String rep_key = pair.first;
+ while (attrs_.count(rep_key + "_" + std::to_string(cnt))) {
+ cnt++;
+ }
+ attrs_.Set(pair.first + "_" + std::to_string(cnt), pair.second);
+ } else {
+ attrs_.Set(pair.first, pair.second);
+ }
+ }
+ }
+}
+
+MSCGraph RelayGraphBuilder::Build(const relay::Function& func) {
+ // Add input nodes and record inputs;
+ Array<String> input_names, output_names;
+ for (const auto& p : func->params) {
+ AddNode(p, p->name_hint());
+ ICHECK(expr_tensor_map_.count(p)) << "Can not find func param " << p;
+ input_names.push_back(expr_tensor_map_[p][0]);
+ }
+ VisitExpr(func);
+ ICHECK(expr_tensor_map_.count(func->body)) << "Can not find func body " <<
func->body;
+ output_names = expr_tensor_map_[func->body];
+ // remove const nodes as weights
+ Array<MSCJoint> valid_nodes;
+ for (const auto& n : nodes_) {
+ if (!weights_.count(n->name)) {
+ n->index = valid_nodes.size();
+ valid_nodes.push_back(n);
+ }
+ }
+ const auto& graph = MSCGraph(name_, valid_nodes, input_names, output_names);
+ // set inputs and outputs alias
+ if (config_.input_aliass.size() == input_names.size()) {
+ for (size_t i = 0; i < input_names.size(); i++) {
+ graph->FindTensor(input_names[i])->alias = config_.input_aliass[i];
+ }
+ } else {
+ for (size_t i = 0; i < input_names.size(); i++) {
+ graph->FindTensor(input_names[i])->alias =
graph->FindProducer(input_names[i])->name;
+ }
+ }
+ if (config_.output_aliass.size() == output_names.size()) {
+ for (size_t i = 0; i < output_names.size(); i++) {
+ graph->FindTensor(output_names[i])->alias = config_.output_aliass[i];
+ }
+ } else {
+ for (size_t i = 0; i < output_names.size(); i++) {
+ const auto& output = graph->FindTensor(output_names[i]);
+ if (output->alias.size() > 0) {
+ continue;
+ }
+ const auto& producer = graph->FindProducer(output_names[i]);
+ output->alias = producer->outputs.size() == 1
+ ? producer->name
+ : StringUtils::Replace(output_names[i], ":", "_");
+ }
+ }
+ return graph;
+}
+
+MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const String& name) {
+ const auto& node_name = name.size() > 0 ? name :
SpanUtils::GetAttr(expr->span, "name");
+ const auto& master_name = SpanUtils::GetAttr(expr->span, "master");
+ String optype;
+ if (expr->IsInstance<relay::VarNode>()) {
+ optype = "input";
+ } else if (expr->IsInstance<relay::ConstantNode>()) {
+ optype = "constant";
+ } else if (expr->IsInstance<relay::TupleGetItemNode>()) {
+ optype = "get_item";
+ } else if (expr->IsInstance<relay::TupleNode>()) {
+ optype = "tuple";
+ } else if (const auto* call_node = expr.as<relay::CallNode>()) {
+ if (const auto* op_node = call_node->op.as<OpNode>()) {
+ optype = StringUtils::Replace(op_node->name, "relay.", "");
+ } else {
+ optype = "unknown_op";
+ }
+ } else if (const auto* f_node = expr.as<relay::FunctionNode>()) {
+ const auto& name_opt =
f_node->GetAttr<runtime::String>(relay::attr::kComposite);
+ ICHECK(name_opt.defined()) << "Unexpected func without composite";
+ optype = name_opt.value();
+ } else {
+ optype = "unknown_expr";
+ }
+ // Extract attributes
+ Map<String, String> attrs;
+ if (const auto* call_node = expr.as<relay::CallNode>()) {
+ if (call_node->attrs.defined()) {
+ AttrGetter getter(&attrs);
+ const_cast<BaseAttrsNode*>(call_node->attrs.get())->VisitAttrs(&getter);
+ }
+ } else if (expr->IsInstance<relay::FunctionNode>()) {
+ attrs = RelayFuncAttrGetter().GetAttrs(expr);
+ } else if (const auto* const_node = expr.as<relay::ConstantNode>()) {
+ if (const_node->is_scalar()) {
+ const float val =
ExprUtils::GetScalar<float>(Downcast<relay::Constant>(expr));
+ std::stringstream stream;
+ stream << std::fixed << std::setprecision(config_.float_precision) <<
val;
+ attrs.Set("scalar", stream.str());
+ }
+ } else if (const auto* get_node = expr.as<relay::TupleGetItemNode>()) {
+ attrs.Set("index", std::to_string(get_node->index));
+ }
+ // Get scope
+ Array<String> scope;
+ if (optype != "input" && optype != "constant") {
+ scope.push_back("block");
+ }
+ // Build inputs and weights
+ Array<String> input_names;
+ Map<String, MSCTensor> node_weights;
+ if (const auto* call_node = expr.as<relay::CallNode>()) {
+ const auto& input_types = ExprUtils::GetInputTypes(optype,
call_node->args.size(), false);
+ for (size_t i = 0; i < call_node->args.size(); i++) {
+ const auto& arg = call_node->args[i];
+ ICHECK(expr_tensor_map_.count(arg)) << "Missing argument " << arg;
+ if (input_types[i] != "input" && arg->IsInstance<relay::ConstantNode>())
{
+ const auto& t_name = expr_tensor_map_[arg][0];
+ const auto& w_name = SpanUtils::GetAttr(arg->span, "name");
+ const auto& pair = tensor_input_map_[t_name];
+ const auto& producer = Downcast<MSCJoint>(pair.first);
+ if (!weights_.count(w_name)) {
+ const auto& ref = producer->OutputAt(pair.second);
+ const auto& weight = MSCTensor(w_name, ref->dtype,
ref->layout.name(), ref->shape);
+ weights_.Set(w_name, weight);
+ }
+ if (producer->HasAttr("scalar")) {
+ attrs.Set(input_types[i],
producer->GetTypeAttr<std::string>("scalar"));
+ }
+ node_weights.Set(input_types[i], weights_[w_name]);
+ } else {
+ for (const auto& in_name : expr_tensor_map_[arg]) {
+ input_names.push_back(in_name);
+ }
+ }
+ }
+ } else if (const auto* f_node = expr.as<relay::FunctionNode>()) {
+ for (const auto& p : f_node->params) {
+ for (const auto& in_name : expr_tensor_map_[p]) {
+ input_names.push_back(in_name);
+ }
+ }
+ ICHECK(HasFuncScope()) << "Function without func scope " <<
relay::PrettyPrint(expr);
+ const auto& weight_names = func_scopes_.top().GetFuncWeights();
+ const auto& input_types =
+ ExprUtils::GetInputTypes(optype, f_node->params.size() +
weight_names.size(), false);
+ for (size_t i = 0; i < weight_names.size(); i++) {
+ const auto& pair = tensor_input_map_[weight_names[i]];
+ const auto& producer = Downcast<MSCJoint>(pair.first);
+ if (!weights_.count(producer->name)) {
+ const auto& ref = producer->OutputAt(pair.second);
+ const auto& weight = MSCTensor(producer->name, ref->dtype,
ref->layout.name(), ref->shape);
+ weights_.Set(producer->name, weight);
+ }
+ if (producer->HasAttr("scalar")) {
+ attrs.Set(input_types[i],
producer->GetTypeAttr<std::string>("scalar"));
+ }
+ node_weights.Set(input_types[i + f_node->params.size()],
weights_[producer->name]);
+ }
+ } else if (const auto* tuple_node = expr.as<relay::TupleNode>()) {
+ for (const auto& f : tuple_node->fields) {
+ ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f;
+ for (const auto& in_name : expr_tensor_map_[f]) {
+ input_names.push_back(in_name);
+ }
+ }
+ } else if (const auto* getitem_node = expr.as<relay::TupleGetItemNode>()) {
+ ICHECK(expr_tensor_map_.count(getitem_node->tuple))
+ << "Can not find tuple " << getitem_node->tuple;
+ input_names = expr_tensor_map_[getitem_node->tuple];
+ }
+ std::vector<std::pair<BaseJoint, size_t>> inputs;
+ for (const auto& i : input_names) {
+ inputs.push_back(tensor_input_map_[i]);
+ }
+ // Build outputs
+ Array<MSCTensor> outputs;
+ const auto& layout = SpanUtils::GetAttr(expr->span, "layout");
+ Type checked_type = expr->checked_type_;
+ if (checked_type.defined() &&
checked_type->IsInstance<relay::FuncTypeNode>()) {
+ checked_type = Downcast<FuncType>(checked_type)->ret_type;
+ }
+ if (checked_type.defined()) {
+ if (const auto* t_info = checked_type.as<relay::TensorTypeNode>()) {
+ const auto& shape = ArrayUtils::Cast<Integer>(t_info->shape);
+ const auto& output =
+ MSCTensor(node_name + ":" + std::to_string(0), t_info->dtype,
layout, shape);
+ outputs.push_back(output);
+ } else if (const auto* tuple_info =
checked_type.as<relay::TupleTypeNode>()) {
+ Array<String> layouts = StringUtils::Split(layout, ",");
+ if (layouts.size() == 0) {
+ layouts = Array<String>(tuple_info->fields.size(), "");
+ }
+ ICHECK_EQ(layouts.size(), tuple_info->fields.size())
+ << "Layout " << layout << " msimatch with fileds size " <<
tuple_info->fields.size();
+ size_t field_size = tuple_info->fields.size();
+ if (optype == "nn.batch_norm") {
+ field_size = 1;
+ }
+ for (size_t i = 0; i < field_size; i++) {
+ const auto& t_info =
Downcast<relay::TensorType>(tuple_info->fields[i]);
+ const auto& shape = ArrayUtils::Cast<Integer>(t_info->shape);
+ const auto& output =
+ MSCTensor(node_name + ":" + std::to_string(i), t_info->dtype,
layouts[i], shape);
+ outputs.push_back(output);
+ }
+ } else {
+ LOG(FATAL) << "Unexpected checked_type " << checked_type;
+ }
+ }
+
+ // Build node
+ const auto& node = MSCJoint(nodes_.size(), node_name, master_name, optype,
attrs, scope, inputs,
+ outputs, node_weights);
+ Array<String> output_names;
+ for (size_t i = 0; i < outputs.size(); i++) {
+ output_names.push_back(outputs[i]->name);
+ tensor_input_map_[outputs[i]->name] = std::make_pair(node, i);
+ }
+ nodes_.push_back(node);
+ expr_tensor_map_.Set(expr, output_names);
+ return node;
+}
+
+void RelayGraphBuilder::VisitExpr_(const relay::ConstantNode* op) {
+ const auto& node = AddNode(GetRef<relay::Constant>(op));
+ if (HasFuncScope()) {
+ func_scopes_.top().AddFuncWeight(node->OutputAt(0)->name);
+ }
+}
+
+void RelayGraphBuilder::VisitExpr_(const relay::FunctionNode* op) {
+ const auto& name_opt = op->GetAttr<runtime::String>(relay::attr::kComposite);
+ if (name_opt.defined()) {
+ StartFuncScope(SpanUtils::GetAttr(op->span, "name"));
+ }
+ RelayExprVisitor::VisitExpr_(op);
+ if (HasFuncScope()) {
+ AddNode(GetRef<relay::Function>(op));
+ EndFuncScope();
+ }
+}
+
+void RelayGraphBuilder::VisitExpr_(const relay::CallNode* op) {
+ if (const auto* f_node = op->op.as<relay::FunctionNode>()) {
+ const auto& name_opt =
f_node->GetAttr<runtime::String>(relay::attr::kComposite);
+ if (name_opt.defined()) {
+ for (size_t i = 0; i < op->args.size(); i++) {
+ ICHECK(expr_tensor_map_.count(op->args[i]))
+ << "Can not find argument " << relay::PrettyPrint(op->args[i]);
+ expr_tensor_map_.Set(f_node->params[i], expr_tensor_map_[op->args[i]]);
+ }
+ }
+ }
+ RelayExprVisitor::VisitExpr_(op);
+ if (!HasFuncScope() && op->op->IsInstance<OpNode>()) {
+ try {
+ AddNode(GetRef<relay::Call>(op));
+ } catch (runtime::InternalError& err) {
+ LOG(WARNING) << "Failed to add node from " <<
relay::PrettyPrint(GetRef<relay::Call>(op))
+ << " : " << err.message();
+ throw err;
+ }
+ }
+ if (op->op->IsInstance<relay::FunctionNode>() &&
expr_tensor_map_.count(op->op)) {
+ expr_tensor_map_.Set(GetRef<relay::Call>(op), expr_tensor_map_[op->op]);
+ }
+}
+
+void RelayGraphBuilder::VisitExpr_(const relay::TupleNode* val) {
+ RelayExprVisitor::VisitExpr_(val);
+ AddNode(GetRef<relay::Tuple>(val));
+}
+
+void RelayGraphBuilder::VisitExpr_(const relay::TupleGetItemNode* val) {
+ RelayExprVisitor::VisitExpr_(val);
+ AddNode(GetRef<relay::TupleGetItem>(val));
+}
+
+void RelayGraphBuilder::StartFuncScope(const String& name) {
+ RelayFuncScope func_scope = RelayFuncScope(name);
+ func_scopes_.push(func_scope);
+}
+void RelayGraphBuilder::EndFuncScope() {
+ ICHECK(HasFuncScope()) << "No FuncScope found";
+ func_scopes_.pop();
+}
+
+bool RelayGraphBuilder::HasFuncScope() { return func_scopes_.size() > 0; }
+
+Map<MSCTensor, NDArray> RelayWeightsExtractor::GetWeights(const
relay::Function& func) {
+ VisitExpr(func);
+ return weights_;
+}
+
+void RelayWeightsExtractor::VisitExpr_(const relay::ConstantNode* op) {
+ const auto& name = SpanUtils::GetAttr(op->span, "name");
+ const auto& layout = SpanUtils::GetAttr(op->span, "layout");
+ const auto& t_info = op->tensor_type();
+ const auto& shape = ArrayUtils::Cast<Integer>(t_info->shape);
+ const auto& weight = MSCTensor(name, t_info->dtype, layout, shape);
+ weights_.Set(weight, op->data);
+}
+
+TVM_REGISTER_GLOBAL("msc.core.BuildFromRelax")
+ .set_body_typed([](const IRModule& relax_module, const String& entry_name,
+ const String& options) -> MSCGraph {
+ const auto& func =
Downcast<relax::Function>(relax_module->Lookup(entry_name));
+ return RelaxGraphBuilder(relax_module, entry_name, options).Build(func);
+ });
+
+TVM_REGISTER_GLOBAL("msc.core.GetRelaxWeights")
+ .set_body_typed([](const IRModule& relax_module,
+ const String& entry_name) -> Map<MSCTensor, NDArray> {
+ const auto& func =
Downcast<relax::Function>(relax_module->Lookup(entry_name));
+ return RelaxWeightsExtractor().GetWeights(func);
+ });
+
+TVM_REGISTER_GLOBAL("msc.core.BuildFromRelay")
+ .set_body_typed([](const IRModule& relay_module, const String& entry_name,
+ const String& options) -> MSCGraph {
+ const auto& func =
Downcast<relay::Function>(relay_module->Lookup(entry_name));
+ return RelayGraphBuilder(relay_module, entry_name, options).Build(func);
+ });
+
+TVM_REGISTER_GLOBAL("msc.core.GetRelayWeights")
+ .set_body_typed([](const IRModule& relay_module,
+ const String& entry_name) -> Map<MSCTensor, NDArray> {
+ const auto& func =
Downcast<relay::Function>(relay_module->Lookup(entry_name));
+ return RelayWeightsExtractor().GetWeights(func);
+ });
+
+} // namespace msc
+} // namespace contrib
+} // namespace tvm
diff --git a/src/contrib/msc/core/ir/graph_builder.h
b/src/contrib/msc/core/ir/graph_builder.h
new file mode 100644
index 0000000000..bb7223695d
--- /dev/null
+++ b/src/contrib/msc/core/ir/graph_builder.h
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/ir/graph_builder.h
+ * \brief Builder of MSCGraph.
+ */
+#ifndef TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_
+#define TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_
+
+#include <dmlc/json.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/tir/data_layout.h>
+
+#include <stack>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../utils.h"
+#include "graph.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using Expr = tvm::RelayExpr;
+using RelaxExprVisitor = tvm::relax::ExprVisitor;
+using RelayExprVisitor = tvm::relay::ExprVisitor;
+using namespace tvm::runtime;
+
+/*!
+ * \brief Config for building MSCGraph.
+ * Define the configuration for building MSCGraph
+ */
+struct MSCRBuildConfig {
+  // whether to prune the built graph — TODO confirm exact pruning semantics
+  bool prune_graph{false};
+  // precision used for float values — presumably digits when stringifying
+  int float_precision = 6;
+  // sorting key for nodes; empty presumably keeps visit order
+  std::string sort_by;
+  // NOTE(review): "aliass" spelling is part of the JSON schema read below;
+  // renaming the fields/keys would break existing option strings.
+  std::vector<std::string> input_aliass;
+  std::vector<std::string> output_aliass;
+  // per-input list of type strings, keyed by input name
+  std::unordered_map<std::string, std::vector<std::string>> input_types;
+
+  // Parse the "input_types" JSON object: {input_name: [type, ...], ...}.
+  void LoadInputTypes(dmlc::JSONReader* reader) {
+    std::string key;
+    reader->BeginObject();
+    while (reader->NextObjectItem(&key)) {
+      std::vector<std::string> types;
+      reader->Read(&types);
+      input_types[key] = types;
+    }
+  }
+
+  // Load the whole config from a JSON object; unknown keys are ignored.
+  void Load(dmlc::JSONReader* reader) {
+    std::string key;
+    reader->BeginObject();
+    while (reader->NextObjectItem(&key)) {
+      if (key == "prune_graph") {
+        reader->Read(&prune_graph);
+      } else if (key == "float_precision") {
+        reader->Read(&float_precision);
+      } else if (key == "sort_by") {
+        reader->Read(&sort_by);
+      } else if (key == "input_aliass") {
+        reader->Read(&input_aliass);
+      } else if (key == "output_aliass") {
+        reader->Read(&output_aliass);
+      } else if (key == "input_types") {
+        this->LoadInputTypes(reader);
+      }
+    }
+  }
+};
+
+// Serializes node attributes into string form by visiting each field.
+class AttrGetter : public AttrVisitor {
+ public:
+  /*!
+   * \brief Get the attributes as Map<String, String>
+   * \param attrs the attributes.
+   */
+  explicit AttrGetter(Map<String, String>* attrs) : attrs_(attrs) {}
+
+  // Numeric values are stored in their std::to_string decimal form.
+  void Visit(const char* key, double* value) final { attrs_->Set(key,
std::to_string(*value)); }
+
+  void Visit(const char* key, int64_t* value) final { attrs_->Set(key,
std::to_string(*value)); }
+
+  void Visit(const char* key, uint64_t* value) final { attrs_->Set(key,
std::to_string(*value)); }
+
+  void Visit(const char* key, int* value) final { attrs_->Set(key,
std::to_string(*value)); }
+
+  // NOTE: bool goes through std::to_string, i.e. "1"/"0", not "true"/"false".
+  void Visit(const char* key, bool* value) final { attrs_->Set(key,
std::to_string(*value)); }
+
+  void Visit(const char* key, std::string* value) final { attrs_->Set(key,
*value); }
+
+  // DataType is rendered with the canonical DLPack string, e.g. "float32".
+  void Visit(const char* key, DataType* value) final {
+    attrs_->Set(key, runtime::DLDataType2String(*value));
+  }
+
+  void Visit(const char* key, runtime::ObjectRef* value) final {
+    attrs_->Set(key, StringUtils::ToString(*value));
+  }
+
+  // Pointer and tensor attributes have no string form — reject them loudly.
+  void Visit(const char* key, void** value) final {
+    LOG(FATAL) << "TypeError: void is not allowed in Attrs";
+  }
+
+  void Visit(const char* key, runtime::NDArray* value) final {
+    LOG(FATAL) << "TypeError: NDArray is not allowed in Attrs";
+  }
+
+ private:
+  // Destination map, owned by the caller; must outlive this visitor.
+  Map<String, String>* attrs_;
+};
+
+// Collects call attributes from a relax function body as string pairs.
+class RelaxFuncAttrGetter : public RelaxExprVisitor {
+ public:
+  /*! \brief Get the attributes as Map<String, String>*/
+  Map<String, String> GetAttrs(const Expr& expr) {
+    RelaxExprVisitor::VisitExpr(expr);
+    return attrs_;
+  }
+
+  void VisitExpr_(const relax::CallNode* op) final;
+
+ private:
+  // Accumulated attributes, filled by VisitExpr_ (defined in the .cc file).
+  Map<String, String> attrs_;
+};
+
+// Visits a relax function and assembles the corresponding MSCGraph.
+class RelaxGraphBuilder : public RelaxExprVisitor {
+ public:
+  /*!
+   * \brief The constructor of RelaxGraphBuilder
+   * \param ref_module the reference module.
+   * \param name the name of the graph.
+   * \param options the options of build the graph.
+   */
+  explicit RelaxGraphBuilder(const IRModule& ref_module, const String& name,
+                             const std::string& options = "")
+      : RelaxExprVisitor() {
+    name_ = name;
+    ref_module_ = ref_module;
+    // options is a JSON string matching MSCRBuildConfig; empty keeps defaults
+    if (options.size() > 0) {
+      std::istringstream is(options);
+      dmlc::JSONReader reader(&is);
+      reader.Read(&config_);
+    }
+  }
+
+  /*! \brief Build MSCGraph from relax function*/
+  const MSCGraph Build(const relax::Function& func);
+
+  /*! \brief Create and add MSCJoint from expr*/
+  const MSCJoint AddNode(const Expr& expr, const Optional<Expr>& binding_var =
NullOpt,
+                         const String& name = "");
+
+  void VisitBindingBlock(const relax::BindingBlock& block) final;
+
+  void VisitExpr_(const relax::ConstantNode* op) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const
relax::ConstantNode* val) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const
relax::ShapeExprNode* val) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const
relax::CallNode* call_node) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const
relax::TupleNode* val) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding,
+                     const relax::TupleGetItemNode* val) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const
relax::VarNode* val) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const
relax::DataflowVarNode* val) final;
+
+ private:
+  // graph name passed to the constructor
+  String name_;
+  // name of the scope currently being visited
+  String scope_name_;
+  // module the entry function comes from
+  IRModule ref_module_;
+  // parsed build options
+  MSCRBuildConfig config_;
+  // nodes created so far, in creation order
+  Array<MSCJoint> nodes_;
+  // weight tensors keyed by name
+  Map<String, MSCTensor> weights_;
+  // expr -> names of the tensors it produces
+  Map<Expr, Array<String>> expr_tensor_map_;
+  // tensor name -> (producer node, output index)
+  std::unordered_map<String, std::pair<BaseJoint, size_t>> tensor_input_map_;
+};
+
+// Walks a relax function and gathers every constant as a weight tensor.
+class RelaxWeightsExtractor : public RelaxExprVisitor {
+ public:
+  /*! \brief Visit the constant and save weights */
+  Map<MSCTensor, NDArray> GetWeights(const relax::Function& func);
+
+  void VisitExpr_(const relax::ConstantNode* op) final;
+
+ private:
+  // MSCTensor descriptor -> backing data, filled by VisitExpr_.
+  Map<MSCTensor, NDArray> weights_;
+};
+
+class RelayFuncAttrGetter : public RelayExprVisitor {
+ public:
+ /*! \brief Get the attributes as Map<String, String>*/
+ Map<String, String> GetAttrs(const Expr& expr) {
+ RelayFuncAttrGetter::VisitExpr(expr);
+ return attrs_;
+ }
+
+ void VisitExpr_(const relay::CallNode* op) final;
+
+ private:
+ Map<String, String> attrs_;
+};
+
+/*!
+ * \brief A Scope for recording func
+ * Tracks the weights encountered while a relay function is being visited;
+ * instances live on RelayGraphBuilder::func_scopes_.
+ */
+class RelayFuncScope {
+ public:
+  /*! \brief The constructor */
+  explicit RelayFuncScope(const String& name) : name_(name) {}
+
+  /*! \brief Add a weight */
+  void AddFuncWeight(const String& weight) { func_weights_.push_back(weight); }
+
+  /*! \brief Get weights */
+  const Array<String> GetFuncWeights() { return func_weights_; }
+
+ private:
+  // scope name, currently only kept for identification
+  String name_;
+  // weight names recorded in this scope, in insertion order
+  Array<String> func_weights_;
+};
+
+class RelayGraphBuilder : public RelayExprVisitor {
+ public:
+ /*!
+ * \brief The constructor of RelayGraphBuilder
+ * \param ref_module the reference module.
+ * \param name the name of the graph.
+ * \param options the options of build the graph.
+ */
+ explicit RelayGraphBuilder(const IRModule& ref_module, const String& name,
+ const std::string& options = "")
+ : RelayExprVisitor() {
+ name_ = name;
+ ref_module_ = ref_module;
+ if (options.size() > 0) {
+ std::istringstream is(options);
+ dmlc::JSONReader reader(&is);
+ reader.Read(&config_);
+ }
+ while (!func_scopes_.empty()) {
+ func_scopes_.pop();
+ }
+ }
+
+ /*! \brief Build MSCGraph from relax function*/
+ MSCGraph Build(const relay::Function& func);
+
+ /*! \brief Create and add MSCJoint from expr*/
+ MSCJoint AddNode(const Expr& expr, const String& name = "");
+
+ void VisitExpr_(const relay::ConstantNode* op) final;
+
+ void VisitExpr_(const relay::FunctionNode* op) final;
+
+ void VisitExpr_(const relay::CallNode* op) final;
+
+ void VisitExpr_(const relay::TupleNode* val) final;
+
+ void VisitExpr_(const relay::TupleGetItemNode* val) final;
+
+ protected:
+ /*! \brief Start a func scope */
+ void StartFuncScope(const String& scope);
+
+ /*! \brief End a func scope */
+ void EndFuncScope();
+
+ /*! \brief Check if has func scopes left */
+ bool HasFuncScope();
+
+ private:
+ String name_;
+ MSCRBuildConfig config_;
+ IRModule ref_module_;
+ Array<MSCJoint> nodes_;
+ Map<String, MSCTensor> weights_;
+ Map<Expr, Array<String>> expr_tensor_map_;
+ std::unordered_map<String, std::pair<BaseJoint, size_t>> tensor_input_map_;
+ std::stack<RelayFuncScope> func_scopes_;
+};
+
+// Walks a relay function and gathers every constant as a weight tensor.
+// Relay counterpart of RelaxWeightsExtractor.
+class RelayWeightsExtractor : public RelayExprVisitor {
+ public:
+  /*! \brief Visit the constant and save weights*/
+  Map<MSCTensor, NDArray> GetWeights(const relay::Function& func);
+
+  void VisitExpr_(const relay::ConstantNode* op) final;
+
+ private:
+  // MSCTensor descriptor -> backing data, filled by VisitExpr_.
+  Map<MSCTensor, NDArray> weights_;
+};
+
+} // namespace msc
+} // namespace contrib
+} // namespace tvm
+#endif // TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_
diff --git a/src/contrib/msc/core/transform/layout_utils.cc
b/src/contrib/msc/core/transform/layout_utils.cc
index ffc631c6d0..3c70e1871b 100644
--- a/src/contrib/msc/core/transform/layout_utils.cc
+++ b/src/contrib/msc/core/transform/layout_utils.cc
@@ -22,6 +22,7 @@
*/
#include "layout_utils.h"
+#include <algorithm>
#include <set>
#include <string>
@@ -118,12 +119,15 @@ const LayoutDecision LayoutUtils::ExpandLayout(const
LayoutDecision& src_layout,
if (!src_layout->layout.defined()) {
return src_layout;
}
+ // sort expand axes
+ std::vector<size_t> axes = expand_axes;
+ std::sort(std::begin(axes), std::end(axes));
std::string new_layout = src_layout.name();
ICHECK_EQ(new_layout.size(), src_layout->layout.ndim())
<< "Only support normal layout, get " << src_layout->layout;
std::vector<std::string> priority_dims{"N", "C", "H", "W", "D", "G", "T"};
- size_t left_size = expand_axes.size();
- for (const auto& a : expand_axes) {
+ size_t left_size = axes.size();
+ for (const auto& a : axes) {
std::string target = "U";
if (new_layout.find("H") && !new_layout.find("W")) {
target = "W";
diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc
b/src/contrib/msc/core/transform/set_expr_layout.cc
index 5915bef9e1..a94c846b3f 100644
--- a/src/contrib/msc/core/transform/set_expr_layout.cc
+++ b/src/contrib/msc/core/transform/set_expr_layout.cc
@@ -1118,28 +1118,27 @@ class LayoutInfer : public ExprVisitor {
void VisitBinding_(const VarBindingNode* binding, const TupleNode* val)
final {
ExprVisitor::VisitBinding_(binding, val);
- std::vector<NLayout> input_layout;
- for (const auto& field : val->fields) {
- if (binding->var->IsInstance<DataflowVarNode>()) {
- // Df var: Use the current realized layout to group the tuple;
- input_layout.push_back(GetNLayout(var_layout_map_, field));
- } else {
- // Global var: Use the initial layout to group the tuple;
- input_layout.push_back(InitialNLayout(field));
- }
- }
if (IsNestedTensor(binding->var)) {
- var_layout_map_[binding->var] = input_layout;
+ Array<NLayout> input_layouts;
+ for (const auto& field : val->fields) {
+ input_layouts.push_back(InferLayoutDecision(field, var_layout_map_));
+ }
+ var_layout_map_[binding->var] = input_layouts;
+ if (LayoutUtils::SetLayout(GetRef<Tuple>(val), NLayout(input_layouts))) {
+ infered_ = true;
+ }
}
RecordExpr(binding->var, GetRef<Tuple>(val));
}
void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode*
val) final {
ExprVisitor::VisitBinding_(binding, val);
- NLayout input_layout = binding->var->IsInstance<DataflowVarNode>()
- ? GetNLayout(var_layout_map_, val->tuple)
- : InitialNLayout(val->tuple);
- var_layout_map_[binding->var] = input_layout.NestedArray()[val->index];
+ const auto& out_layout =
+ InferLayoutDecisionAt(GetRef<TupleGetItem>(val)->tuple,
var_layout_map_, val->index);
+ var_layout_map_[binding->var] = out_layout;
+ if (LayoutUtils::SetLayout(GetRef<TupleGetItem>(val), out_layout)) {
+ infered_ = true;
+ }
RecordExpr(binding->var, GetRef<TupleGetItem>(val));
}
diff --git a/tests/lint/pylint.sh b/tests/lint/pylint.sh
index ac93a6f15d..7fb6af30a1 100755
--- a/tests/lint/pylint.sh
+++ b/tests/lint/pylint.sh
@@ -45,3 +45,6 @@ python3 -m pylint tests/python/frontend/oneflow/*.py
--rcfile="$(dirname "$0")"/
python3 -m pylint tests/python/frontend/tensorflow/test_forward.py
--rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/frontend/pytorch/test_forward.py
--rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/frontend/tflite/test_forward.py
--rcfile="$(dirname "$0")"/pylintrc
+
+# tests/python/contrib/test_msc tests
+python3 -m pylint tests/python/contrib/test_msc/*.py --rcfile="$(dirname
"$0")"/pylintrc
diff --git a/tests/python/contrib/test_msc/test_graph_build.py
b/tests/python/contrib/test_msc/test_graph_build.py
new file mode 100644
index 0000000000..6e410e584c
--- /dev/null
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -0,0 +1,2037 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Test graph builder && graph. """
+
+import torch
+from torch import fx
+from torch.nn import Module
+
+import tvm.testing
+from tvm.relax.frontend.torch import from_fx
+from tvm.contrib.msc.core.ir import translate
+from tvm.contrib.msc.core import utils as msc_utils
+
+
def verify_model(torch_model, input_info, expected):
    """Trace a torch model, translate it to a MSCGraph and compare the inspect result."""
    fx_model = fx.symbolic_trace(torch_model)
    with torch.no_grad():
        relax_mod = from_fx(fx_model, input_info)
    # translate returns (graph, weights); only the graph structure is checked
    graph, _ = translate.from_relax(relax_mod)
    inspect = graph.inspect()
    assert msc_utils.dict_equal(inspect, expected), "Inspect {} mismatch with expected {}".format(
        inspect, expected
    )
+
+
def test_conv1d():
    """test graph builder for conv1d"""

    class Conv1D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    # with bias, conv+bias are fused into the composite msc.conv1d_bias node
    expected1 = {
        "inputs": [{"name": "inp_0", "shape": [1, 3, 10], "dtype": "float32", "layout": "NCW"}],
        "outputs": [
            {"name": "msc.conv1d_bias", "shape": [1, 6, 4], "dtype": "float32", "layout": "NCW"}
        ],
        "nodes": {"total": 2, "input": 1, "msc.conv1d_bias": 1},
    }

    class Conv1D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.conv(data)

    # without bias, the plain nn.conv1d op is kept
    expected2 = {
        "inputs": [{"name": "inp_0", "shape": [1, 3, 10], "dtype": "float32", "layout": "NCW"}],
        "outputs": [{"name": "conv1d", "shape": [1, 6, 4], "dtype": "float32", "layout": "NCW"}],
        "nodes": {"total": 2, "input": 1, "nn.conv1d": 1},
    }

    input_info = [([1, 3, 10], "float32")]
    verify_model(Conv1D1(), input_info, expected1)
    verify_model(Conv1D2(), input_info, expected2)
+
+
def test_conv2d():
    """test graph builder for conv2d"""

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    # with bias, conv+bias are fused into the composite msc.conv2d_bias node
    expected1 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {
                "name": "msc.conv2d_bias",
                "shape": [1, 6, 4, 4],
                "dtype": "float32",
                "layout": "NCHW",
            }
        ],
        "nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1},
    }

    class Conv2D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.conv(data)

    # without bias, the plain nn.conv2d op is kept
    expected2 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {"name": "conv2d", "shape": [1, 6, 4, 4], "dtype": "float32", "layout": "NCHW"}
        ],
        "nodes": {"total": 2, "input": 1, "nn.conv2d": 1},
    }
    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Conv2D1(), input_info, expected1)
    verify_model(Conv2D2(), input_info, expected2)
+
+
def test_linear():
    """test graph builder for linear"""

    class Dense1(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=True)

        def forward(self, data):
            return self.linear(data)

    # linear with bias maps to the composite msc.linear_bias node
    expected1 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {
                "name": "msc.linear_bias",
                "shape": [1, 3, 10, 7],
                "dtype": "float32",
                "layout": "NCHW",
            }
        ],
        "nodes": {"total": 2, "input": 1, "msc.linear_bias": 1},
    }

    class Dense2(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=False)

        def forward(self, data):
            return self.linear(data)

    # linear without bias maps to msc.linear
    expected2 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {"name": "msc.linear", "shape": [1, 3, 10, 7], "dtype": "float32", "layout": "NCHW"}
        ],
        "nodes": {"total": 2, "input": 1, "msc.linear": 1},
    }

    class MatMul1(Module):
        def forward(self, x, y):
            return torch.matmul(x, y)

    # plain matmul keeps both operands as inputs (second gets IO layout)
    expected3 = {
        "inputs": [
            {"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": "NC"},
            {"name": "inp_1", "shape": [10, 10], "dtype": "float32", "layout": "IO"},
        ],
        "outputs": [{"name": "matmul", "shape": [10, 10], "dtype": "float32", "layout": "NC"}],
        "nodes": {"total": 3, "input": 2, "matmul": 1},
    }

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Dense1(), input_info, expected1)
    verify_model(Dense2(), input_info, expected2)
    verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")], expected3)
+
+
def test_bmm():
    """test graph builder for bmm"""

    class BMM(Module):
        def forward(self, x, y):
            return torch.bmm(x, y)

    # batched matmul becomes a single matmul node with batched layouts
    expected = {
        "inputs": [
            {"name": "inp_0", "shape": [4, 128, 256], "dtype": "float32", "layout": "NCD"},
            {"name": "inp_1", "shape": [4, 256, 512], "dtype": "float32", "layout": "NIO"},
        ],
        "outputs": [
            {"name": "matmul", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"}
        ],
        "nodes": {"total": 3, "input": 2, "matmul": 1},
    }

    input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")]
    verify_model(BMM(), input_info, expected)
+
+
def test_baddbmm():
    """test graph builder for baddbmm"""

    class BAddBMM1(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y)

    # default alpha/beta: bmm followed by add of the bias input
    expected1 = {
        "inputs": [
            {"name": "inp_0", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"},
            {"name": "inp_1", "shape": [4, 128, 256], "dtype": "float32", "layout": "NCD"},
            {"name": "inp_2", "shape": [4, 256, 512], "dtype": "float32", "layout": "NIO"},
        ],
        "outputs": [{"name": "add", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"}],
        "nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1},
    }

    class BAddBMM2(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y, alpha=2, beta=0)

    # beta=0 drops the add, leaving matmul scaled by a constant multiply
    expected2 = {
        "inputs": [
            {"name": "inp_0", "shape": [4, 128, 512], "dtype": "float32", "layout": ""},
            {"name": "inp_1", "shape": [4, 128, 256], "dtype": "float32", "layout": "NCD"},
            {"name": "inp_2", "shape": [4, 256, 512], "dtype": "float32", "layout": "NIO"},
        ],
        "outputs": [
            {"name": "multiply", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"}
        ],
        "nodes": {"total": 6, "input": 3, "matmul": 1, "constant": 1, "multiply": 1},
    }

    input_info = [
        ((4, 128, 512), "float32"),
        ((4, 128, 256), "float32"),
        ((4, 256, 512), "float32"),
    ]
    verify_model(BAddBMM1(), input_info, expected1)
    verify_model(BAddBMM2(), input_info, expected2)
+
+
def test_relu():
    """test graph builder for relu"""

    class ReLU(Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, data):
            return self.relu(data)

    class ReLU1(Module):
        def forward(self, data):
            return torch.nn.functional.relu(data)

    # module form and functional form must translate to the same graph
    expected = {
        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": "AB"}],
        "outputs": [{"name": "relu", "shape": [10, 10], "dtype": "float32", "layout": "AB"}],
        "nodes": {"total": 2, "input": 1, "nn.relu": 1},
    }

    input_info = [([10, 10], "float32")]
    verify_model(ReLU(), input_info, expected)
    verify_model(ReLU1(), input_info, expected)
+
+
def test_relu6():
    """test graph builder for relu6"""

    class ReLU6(Module):
        def __init__(self):
            super().__init__()
            self.relu6 = torch.nn.ReLU6()

        def forward(self, data):
            return self.relu6(data)

    # relu6 lowers to a clip op; no layout is inferred for it
    expected = {
        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "clip", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "clip": 1},
    }
    input_info = [([10, 10], "float32")]
    verify_model(ReLU6(), input_info, expected)
+
+
def test_maxpool2d():
    """test graph builder for maxpool2d"""

    class MaxPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    # identity-sized kernel: output shape unchanged
    expected1 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {"name": "max_pool2d", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1},
    }

    class MaxPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])

        def forward(self, data):
            return self.pool(data)

    # dilation shrinks the spatial dims to 4x4
    expected2 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {"name": "max_pool2d", "shape": [1, 3, 4, 4], "dtype": "float32", "layout": "NCHW"}
        ],
        "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1},
    }

    class MaxPool2d3(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)

        def forward(self, data):
            return self.pool(data)

    # padding + stride case: spatial dims become 6x6
    expected3 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {"name": "max_pool2d", "shape": [1, 3, 6, 6], "dtype": "float32", "layout": "NCHW"}
        ],
        "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1},
    }

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(MaxPool2d(), input_info, expected1)
    verify_model(MaxPool2d2(), input_info, expected2)
    verify_model(MaxPool2d3(), input_info, expected3)
+
+
def test_avgpool2d():
    """test graph builder for avgpool2d"""

    class AvgPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    # identity-sized kernel: output shape unchanged
    expected1 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {"name": "avg_pool2d", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1},
    }

    class AvgPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True)

        def forward(self, data):
            return self.pool(data)

    # stride/padding/ceil_mode case: spatial dims become 6x6
    expected2 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {"name": "avg_pool2d", "shape": [1, 3, 6, 6], "dtype": "float32", "layout": "NCHW"}
        ],
        "nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1},
    }

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(AvgPool2d(), input_info, expected1)
    verify_model(AvgPool2d2(), input_info, expected2)
+
+
def test_adaptive_avgpool2d():
    """test graph builder for adaptive_avgpool2d"""

    class AdaptiveAvgPool2d0(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])

        def forward(self, data):
            return self.pool(data)

    # target size equals input size, so the output shape is unchanged
    expected = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {
                "name": "adaptive_avg_pool2d",
                "shape": [1, 3, 10, 10],
                "dtype": "float32",
                "layout": "NCHW",
            }
        ],
        "nodes": {"total": 2, "input": 1, "nn.adaptive_avg_pool2d": 1},
    }

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(AdaptiveAvgPool2d0(), input_info, expected)
+
+
def test_flatten():
    """test graph builder for flatten"""

    class Flatten(Module):
        def __init__(self):
            super().__init__()
            self.f = torch.nn.Flatten(2, -1)

        def forward(self, data):
            return self.f(data)

    # flatten lowers to a reshape merging dims 2..-1 (10*10 -> 100)
    expected = {
        "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "reshape", "shape": [1, 3, 100], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "reshape": 1},
    }

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Flatten(), input_info, expected)
    # a bare torch.nn.Flatten module must translate identically
    verify_model(torch.nn.Flatten(2, -1), input_info, expected)
+
+
def test_batchnorm2d():
    """test graph builder for batchnorm2d"""

    class BatchNorm2d(Module):
        def __init__(self):
            super().__init__()
            self.batchnorm = torch.nn.BatchNorm2d(3)

        def forward(self, data):
            return self.batchnorm(data)

    # batch_norm returns a tuple, hence the extra get_item node and ".0" name
    expected = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {
                "name": "batch_norm.0",
                "shape": [1, 3, 10, 10],
                "dtype": "float32",
                "layout": "NCHW",
            }
        ],
        "nodes": {"total": 3, "input": 1, "nn.batch_norm": 1, "get_item": 1},
    }

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(BatchNorm2d(), input_info, expected)
+
+
def test_embedding():
    """test graph builder for embedding"""

    class Embedding(Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 3)

        def forward(self, data):
            return self.embedding(data)

    # 1-D index input
    expected1 = {
        "inputs": [{"name": "inp_0", "shape": [4], "dtype": "int64", "layout": "A"}],
        "outputs": [{"name": "msc.embedding", "shape": [4, 3], "dtype": "float32", "layout": "NA"}],
        "nodes": {"total": 2, "input": 1, "msc.embedding": 1},
    }

    # 2-D index input
    expected2 = {
        "inputs": [{"name": "inp_0", "shape": [4, 5], "dtype": "int64", "layout": "AB"}],
        "outputs": [
            {"name": "msc.embedding", "shape": [4, 5, 3], "dtype": "float32", "layout": "CNB"}
        ],
        "nodes": {"total": 2, "input": 1, "msc.embedding": 1},
    }

    verify_model(Embedding(), [([4], "int64")], expected1)
    verify_model(Embedding(), [([4, 5], "int64")], expected2)
+
+
def test_dropout():
    """test graph builder for dropout"""

    class Dropout1(Module):
        def __init__(self):
            super().__init__()
            self.dropout = torch.nn.Dropout(0.5)

        def forward(self, data):
            return self.dropout(data)

    class Dropout2(Module):
        def forward(self, data):
            return torch.dropout(data, 0.5, train=True)

    # dropout is a no-op at translation time, so the graph is input-only
    # and the output is the input tensor itself
    expected = {
        "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 1, "input": 1},
    }

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Dropout1(), input_info, expected)
    verify_model(Dropout2(), input_info, expected)
+
+
def test_layernorm():
    """test graph builder for layernorm"""

    class LayerNorm(Module):
        def __init__(self):
            super().__init__()
            self.layernorm = torch.nn.LayerNorm((10, 10))

        def forward(self, data):
            return self.layernorm(data)

    expected = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {"name": "layer_norm", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "nodes": {"total": 2, "input": 1, "nn.layer_norm": 1},
    }

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(LayerNorm(), input_info, expected)
+
+
def test_functional_layernorm():
    """test graph builder for functional_layernorm"""

    class LayerNorm(Module):
        def __init__(self, shape):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(shape))
            self.bias = torch.nn.Parameter(torch.zeros(shape))

        def forward(self, data):
            return torch.nn.functional.layer_norm(
                data, self.weight.shape, self.weight, self.bias, 1e-5
            )

    # functional form must translate the same as the module form
    expected = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "outputs": [
            {"name": "layer_norm", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
        ],
        "nodes": {"total": 2, "input": 1, "nn.layer_norm": 1},
    }

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(LayerNorm((10, 10)), input_info, expected)
+
+
def test_cross_entropy():
    """test graph builder for cross_entropy"""

    class CrossEntropy1(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    # default loss: log_softmax + nll_loss, scalar output
    expected1 = {
        "inputs": [
            {"name": "inp_0", "shape": [3, 2], "dtype": "float32", "layout": ""},
            {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""},
        ],
        "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1},
    }

    class CrossEntropy2(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones((2,)))
            self.loss = torch.nn.CrossEntropyLoss(weight=self.weight)

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    # class weights add a constant node feeding nll_loss
    expected2 = {
        "inputs": [
            {"name": "inp_0", "shape": [3, 2], "dtype": "float32", "layout": ""},
            {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""},
        ],
        "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 5, "input": 2, "nn.log_softmax": 1, "constant": 1, "nn.nll_loss": 1},
    }

    class CrossEntropy3(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=1, reduction="sum")

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    # ignore_index/reduction only change op attrs, not the node set
    expected3 = {
        "inputs": [
            {"name": "inp_0", "shape": [3, 2], "dtype": "float32", "layout": ""},
            {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""},
        ],
        "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1},
    }

    input_info = [([3, 2], "float32"), ([3], "int32")]
    verify_model(CrossEntropy1(), input_info, expected1)
    verify_model(CrossEntropy2(), input_info, expected2)
    verify_model(CrossEntropy3(), input_info, expected3)
+
+
def test_functional_cross_entropy():
    """test graph builder for functional_cross_entropy"""

    class CrossEntropy(Module):
        def forward(self, logits, targets):
            return torch.nn.functional.cross_entropy(logits, targets)

    # functional form must translate the same as the module form
    expected = {
        "inputs": [
            {"name": "inp_0", "shape": [3, 10], "dtype": "float32", "layout": ""},
            {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""},
        ],
        "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1},
    }

    input_info = [([3, 10], "float32"), ([3], "int32")]
    verify_model(CrossEntropy(), input_info, expected)
+
+
def test_silu():
    """Check MSCGraph building for silu, module and functional forms."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class SiLU(Module):
        def __init__(self):
            super().__init__()
            self.silu = torch.nn.SiLU()

        def forward(self, data):
            return self.silu(data)

    class SiLU2(Module):
        def forward(self, data):
            return torch.nn.functional.silu(data)

    # both spellings should lower to a single nn.silu node
    graph_info = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "ABCD")],
        "outputs": [_t("silu", [1, 3, 10, 10], "float32", "ABCD")],
        "nodes": {"total": 2, "input": 1, "nn.silu": 1},
    }

    shapes = [([1, 3, 10, 10], "float32")]
    verify_model(SiLU(), shapes, graph_info)
    verify_model(SiLU2(), shapes, graph_info)
+
+
def test_groupnorm():
    """Check MSCGraph building for groupnorm."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class GroupNorm(Module):
        def __init__(self):
            super().__init__()
            self.groupnorm = torch.nn.GroupNorm(3, 3)

        def forward(self, data):
            return self.groupnorm(data)

    graph_info = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "NCHW")],
        "outputs": [_t("group_norm", [1, 3, 10, 10], "float32", "NCHW")],
        "nodes": {"total": 2, "input": 1, "nn.group_norm": 1},
    }

    verify_model(GroupNorm(), [([1, 3, 10, 10], "float32")], graph_info)
+
+
def test_softmax():
    """Check MSCGraph building for softmax."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Softmax(Module):
        def __init__(self):
            super().__init__()
            self.softmax = torch.nn.Softmax(dim=1)

        def forward(self, data):
            return self.softmax(data)

    graph_info = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "ABCD")],
        "outputs": [_t("softmax", [1, 3, 10, 10], "float32", "ABCD")],
        "nodes": {"total": 2, "input": 1, "nn.softmax": 1},
    }

    verify_model(Softmax(), [([1, 3, 10, 10], "float32")], graph_info)
+
+
def test_binary():
    """test graph builder for binary

    Covers elementwise binary ops in tensor-tensor and tensor-scalar
    form; the scalar operand is lifted into an extra constant node.
    """

    # two-tensor and one-tensor input signatures shared by all cases
    input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
    input_info2 = [([1, 3, 10, 10], "float32")]

    # Add
    class Add1(Module):
        def forward(self, lhs, rhs):
            return lhs + rhs

    expected_add1 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
            {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
        ],
        "outputs": [{"name": "add", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}],
        "nodes": {"total": 3, "input": 2, "add": 1},
    }

    class Add2(Module):
        def forward(self, lhs):
            return lhs + 1.0

    expected_add2 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [{"name": "add", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}],
        "nodes": {"total": 3, "input": 1, "constant": 1, "add": 1},
    }

    verify_model(Add1(), input_info1, expected_add1)
    verify_model(Add2(), input_info2, expected_add2)

    # Sub
    class Sub1(Module):
        def forward(self, lhs, rhs):
            return lhs - rhs

    expected_sub1 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
            {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
        ],
        "outputs": [
            {"name": "subtract", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 3, "input": 2, "subtract": 1},
    }

    class Sub2(Module):
        def forward(self, lhs):
            return lhs - 1.0

    expected_sub2 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [
            {"name": "subtract", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 3, "input": 1, "constant": 1, "subtract": 1},
    }

    verify_model(Sub1(), input_info1, expected_sub1)
    verify_model(Sub2(), input_info2, expected_sub2)

    # Mul
    class Mul1(Module):
        def forward(self, lhs, rhs):
            return lhs * rhs

    expected_mul1 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
            {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
        ],
        "outputs": [
            {"name": "multiply", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 3, "input": 2, "multiply": 1},
    }

    class Mul2(Module):
        def forward(self, lhs):
            return lhs * 1.0

    expected_mul2 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [
            {"name": "multiply", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 3, "input": 1, "constant": 1, "multiply": 1},
    }

    verify_model(Mul1(), input_info1, expected_mul1)
    verify_model(Mul2(), input_info2, expected_mul2)

    # True div
    class TrueDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs / rhs

    expected_div1 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
            {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
        ],
        "outputs": [
            {"name": "divide", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 3, "input": 2, "divide": 1},
    }

    class TrueDiv2(Module):
        def forward(self, lhs):
            return lhs / 1.0

    expected_div2 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [
            {"name": "divide", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 3, "input": 1, "constant": 1, "divide": 1},
    }

    verify_model(TrueDiv1(), input_info1, expected_div1)
    verify_model(TrueDiv2(), input_info2, expected_div2)

    # Floor div
    class FloorDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs // rhs

    expected_floordiv1 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
            {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
        ],
        "outputs": [
            {
                "name": "floor_divide",
                "shape": [1, 3, 10, 10],
                "dtype": "float32",
                "layout": "ABCD",
            }
        ],
        "nodes": {"total": 3, "input": 2, "floor_divide": 1},
    }

    class FloorDiv2(Module):
        def forward(self, lhs):
            return lhs // 1.0

    expected_floordiv2 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [
            {
                "name": "floor_divide",
                "shape": [1, 3, 10, 10],
                "dtype": "float32",
                "layout": "ABCD",
            }
        ],
        "nodes": {"total": 3, "input": 1, "constant": 1, "floor_divide": 1},
    }

    verify_model(FloorDiv1(), input_info1, expected_floordiv1)
    verify_model(FloorDiv2(), input_info2, expected_floordiv2)

    # Power
    class Power1(Module):
        def forward(self, lhs, rhs):
            return lhs**rhs

    expected_power1 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
            {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
        ],
        "outputs": [
            {"name": "power", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 3, "input": 2, "power": 1},
    }

    class Power2(Module):
        def forward(self, lhs):
            return lhs**1.0

    expected_power2 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [
            {"name": "power", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 3, "input": 1, "constant": 1, "power": 1},
    }

    verify_model(Power1(), input_info1, expected_power1)
    verify_model(Power2(), input_info2, expected_power2)

    # LT -- comparisons produce a bool tensor rather than float32
    class LT1(Module):
        def forward(self, lhs, rhs):
            return lhs < rhs

    expected_lt1 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
            {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
        ],
        "outputs": [{"name": "less", "shape": [1, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}],
        "nodes": {"total": 3, "input": 2, "less": 1},
    }

    class LT2(Module):
        def forward(self, lhs):
            return lhs < 1.0

    expected_lt2 = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [{"name": "less", "shape": [1, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}],
        "nodes": {"total": 3, "input": 1, "constant": 1, "less": 1},
    }

    verify_model(LT1(), input_info1, expected_lt1)
    verify_model(LT2(), input_info2, expected_lt2)
+
+
def test_size():
    """Check MSCGraph building for tensor.size()."""

    class Size(Module):
        def forward(self, data):
            return data.size()

    # size() becomes a shape node yielding a 1-D int32 tensor of rank 4
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "shape", "shape": [4], "dtype": "int32", "layout": "O"}],
        "nodes": {"total": 2, "input": 1, "shape": 1},
    }

    verify_model(Size(), [([1, 3, 10, 10], "float32")], graph_info)
+
+
def test_squeeze():
    """Check MSCGraph building for squeeze, with and without a dim."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Squeeze1(Module):
        def forward(self, data):
            return data.squeeze(1)

    class Squeeze2(Module):
        def forward(self, data):
            return data.squeeze()

    graph_info1 = {
        "inputs": [_t("inp_0", [3, 1, 4, 1], "float32", "ANBC")],
        "outputs": [_t("squeeze", [3, 4, 1], "float32", "ABC")],
        "nodes": {"total": 2, "input": 1, "squeeze": 1},
    }
    # squeeze() without a dim drops every unit dimension
    graph_info2 = {
        "inputs": [_t("inp_0", [3, 1, 4, 1], "float32", "ANBC")],
        "outputs": [_t("squeeze", [3, 4], "float32", "AB")],
        "nodes": {"total": 2, "input": 1, "squeeze": 1},
    }

    shapes = [([3, 1, 4, 1], "float32")]
    verify_model(Squeeze1(), shapes, graph_info1)
    verify_model(Squeeze2(), shapes, graph_info2)
+
+
def test_unsqueeze():
    """Check MSCGraph building for unsqueeze with positive and negative dims."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Unsqueeze1(Module):
        def forward(self, data):
            return data.unsqueeze(1)

    class Unsqueeze2(Module):
        def forward(self, data):
            return data.unsqueeze(-1)

    # unsqueeze lowers to expand_dims; the fresh axis gets a new layout letter
    graph_info1 = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "ACDE")],
        "outputs": [_t("expand_dims", [1, 1, 3, 10, 10], "float32", "ABCDE")],
        "nodes": {"total": 2, "input": 1, "expand_dims": 1},
    }
    graph_info2 = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "ABCE")],
        "outputs": [_t("expand_dims", [1, 3, 10, 10, 1], "float32", "ABCDE")],
        "nodes": {"total": 2, "input": 1, "expand_dims": 1},
    }

    shapes = [([1, 3, 10, 10], "float32")]
    verify_model(Unsqueeze1(), shapes, graph_info1)
    verify_model(Unsqueeze2(), shapes, graph_info2)
+
+
def test_getattr():
    """Check MSCGraph building for the shape attribute access."""

    class GetAttr1(Module):
        def forward(self, data):
            return data.shape

    # .shape is traced like .size(): a shape node with int32 output
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "shape", "shape": [4], "dtype": "int32", "layout": "O"}],
        "nodes": {"total": 2, "input": 1, "shape": 1},
    }

    verify_model(GetAttr1(), [([1, 3, 10, 10], "float32")], graph_info)
+
+
def test_getitem():
    """Check MSCGraph building for __getitem__ slicing."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Slice1(Module):
        def forward(self, x):
            return x[0, 1::2, :, :3]

    class Slice2(Module):
        def forward(self, x):
            return x[:, None, None, :, None]

    # slicing maps to strided_slice followed by a reshape
    graph_info1 = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "ABCD")],
        "outputs": [_t("reshape", [1, 1, 10, 3], "float32", "ABCD")],
        "nodes": {"total": 3, "input": 1, "strided_slice": 1, "reshape": 1},
    }
    graph_info2 = {
        "inputs": [_t("inp_0", [8, 16], "float32", "AB")],
        "outputs": [_t("reshape", [8, 1, 1, 16, 1], "float32", "ANCHB")],
        "nodes": {"total": 3, "input": 1, "strided_slice": 1, "reshape": 1},
    }

    verify_model(Slice1(), [([1, 3, 10, 10], "float32")], graph_info1)
    verify_model(Slice2(), [([8, 16], "float32")], graph_info2)
+
+
def test_unary():
    """test graph builder for unary

    Each elementwise unary op should become a single node that keeps
    the input shape, dtype and ABCD layout.
    """

    input_info = [([1, 3, 10, 10], "float32")]

    # sin
    class Sin(Module):
        def forward(self, data):
            return torch.sin(data)

    expected_sin = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [{"name": "sin", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}],
        "nodes": {"total": 2, "input": 1, "sin": 1},
    }

    verify_model(Sin(), input_info, expected_sin)

    # cos
    class Cos(Module):
        def forward(self, data):
            return torch.cos(data)

    expected_cos = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [{"name": "cos", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}],
        "nodes": {"total": 2, "input": 1, "cos": 1},
    }

    verify_model(Cos(), input_info, expected_cos)

    # exp
    class Exp(Module):
        def forward(self, data):
            return torch.exp(data)

    expected_exp = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [{"name": "exp", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}],
        "nodes": {"total": 2, "input": 1, "exp": 1},
    }

    verify_model(Exp(), input_info, expected_exp)

    # sqrt
    class Sqrt(Module):
        def forward(self, data):
            return torch.sqrt(data)

    expected_sqrt = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [
            {"name": "sqrt", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 2, "input": 1, "sqrt": 1},
    }

    verify_model(Sqrt(), input_info, expected_sqrt)

    # sigmoid
    class Sigmoid(Module):
        def forward(self, data):
            return torch.sigmoid(data)

    expected_sigmoid = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [
            {"name": "sigmoid", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 2, "input": 1, "sigmoid": 1},
    }

    verify_model(Sigmoid(), input_info, expected_sigmoid)

    # round
    class Round(Module):
        def forward(self, data):
            return torch.round(data)

    expected_round = {
        "inputs": [
            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "outputs": [
            {"name": "round", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 2, "input": 1, "round": 1},
    }

    verify_model(Round(), input_info, expected_round)
+
+
def test_gelu():
    """Check MSCGraph building for gelu."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Gelu(Module):
        def forward(self, data):
            return torch.nn.functional.gelu(data)

    graph_info = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "ABCD")],
        "outputs": [_t("gelu", [1, 3, 10, 10], "float32", "ABCD")],
        "nodes": {"total": 2, "input": 1, "nn.gelu": 1},
    }

    verify_model(Gelu(), [([1, 3, 10, 10], "float32")], graph_info)
+
+
def test_tanh():
    """Check MSCGraph building for tanh."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Tanh(Module):
        def forward(self, data):
            return torch.tanh(data)

    graph_info = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "ABCD")],
        "outputs": [_t("tanh", [1, 3, 10, 10], "float32", "ABCD")],
        "nodes": {"total": 2, "input": 1, "tanh": 1},
    }

    verify_model(Tanh(), [([1, 3, 10, 10], "float32")], graph_info)
+
+
def test_clamp():
    """Check MSCGraph building for clamp."""

    class Clamp(Module):
        def forward(self, data):
            return torch.clamp(data, min=0.1, max=0.5)

    # clamp lowers to a clip node; no layout is inferred here
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "clip", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "clip": 1},
    }

    verify_model(Clamp(), [([1, 3, 10, 10], "float32")], graph_info)
+
+
def test_interpolate():
    """Check MSCGraph building for interpolate."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Interpolate(Module):
        def forward(self, data):
            return torch.nn.functional.interpolate(data, (5, 5))

    # interpolate to a fixed size maps to image.resize2d
    graph_info = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "ABCD")],
        "outputs": [_t("resize2d", [1, 3, 5, 5], "float32", "ABCD")],
        "nodes": {"total": 2, "input": 1, "image.resize2d": 1},
    }

    verify_model(Interpolate(), [([1, 3, 10, 10], "float32")], graph_info)
+
+
def test_addmm():
    """Check MSCGraph building for addmm."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Addmm(Module):
        def forward(self, x_1, x_2, x_3):
            return torch.addmm(x_1, x_2, x_3)

    # addmm decomposes into matmul + add
    graph_info = {
        "inputs": [
            _t("inp_0", [10, 10], "float32", "NC"),
            _t("inp_1", [10, 10], "float32", "NC"),
            _t("inp_2", [10, 10], "float32", "IO"),
        ],
        "outputs": [_t("add", [10, 10], "float32", "NC")],
        "nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1},
    }

    shapes = [([10, 10], "float32")] * 3
    verify_model(Addmm(), shapes, graph_info)
+
+
def test_split():
    """Check MSCGraph building for split."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Split(Module):
        def forward(self, data):
            return torch.split(data, 1, dim=1)

    # one split node with three slice outputs
    graph_info = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "ABCD")],
        "outputs": [
            _t("split_0", [1, 1, 10, 10], "float32", "ABCD"),
            _t("split_1", [1, 1, 10, 10], "float32", "ABCD"),
            _t("split_2", [1, 1, 10, 10], "float32", "ABCD"),
        ],
        "nodes": {"total": 2, "input": 1, "split": 1},
    }

    verify_model(Split(), [([1, 3, 10, 10], "float32")], graph_info)
+
+
def test_cumsum():
    """Check MSCGraph building for cumsum with an explicit output dtype."""

    class Cumsum(Module):
        def forward(self, data):
            return torch.cumsum(data, dim=1, dtype=torch.int32)

    # the dtype argument carries through to the node output
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "cumsum", "shape": [1, 2, 3, 4], "dtype": "int32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "cumsum": 1},
    }

    verify_model(Cumsum(), [([1, 2, 3, 4], "float32")], graph_info)
+
+
def test_chunk():
    """Check MSCGraph building for chunk."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Chunk(Module):
        def forward(self, data):
            return torch.chunk(data, 3, dim=1)

    # chunk lowers to the same split node as torch.split
    graph_info = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "ABCD")],
        "outputs": [
            _t("split_0", [1, 1, 10, 10], "float32", "ABCD"),
            _t("split_1", [1, 1, 10, 10], "float32", "ABCD"),
            _t("split_2", [1, 1, 10, 10], "float32", "ABCD"),
        ],
        "nodes": {"total": 2, "input": 1, "split": 1},
    }

    verify_model(Chunk(), [([1, 3, 10, 10], "float32")], graph_info)
+
+
def test_inplace_fill():
    """Check MSCGraph building for the in-place fill_ op."""

    class InplaceFill(Module):
        def forward(self, data):
            data.fill_(1.5)
            return data

    # fill_ is functionalized into constant + full nodes
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "full", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 3, "input": 1, "constant": 1, "full": 1},
    }

    verify_model(InplaceFill(), [([10, 10], "float32")], graph_info)
+
+
def test_arange():
    """Check MSCGraph building for arange."""

    class Arange(Module):
        def forward(self):
            return torch.arange(0, 20, dtype=torch.int32)

    # forward takes no tensor, so the fed data only appears as a graph
    # input while arange itself folds into a constant
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "const", "shape": [20], "dtype": "int32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "constant": 1},
    }

    verify_model(Arange(), [([10, 10], "float32")], graph_info)
+
+
def test_empty():
    """Check MSCGraph building for empty."""

    class Empty(Module):
        def forward(self):
            return torch.empty((10, 10), dtype=torch.float32)

    # torch.empty folds into a constant node of the requested shape/dtype
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "const", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "constant": 1},
    }

    verify_model(Empty(), [([10, 10], "float32")], graph_info)
+
+
def test_tensor():
    """test graph builder for tensor

    torch.tensor on a python scalar folds into a 0-d constant; without
    an explicit dtype a python int defaults to int64.
    """

    # Renamed from Empty1/Empty2 (copy-paste from test_empty): these
    # modules exercise torch.tensor, not torch.empty.
    class Tensor1(Module):
        def forward(self):
            return torch.tensor(3, dtype=torch.float32)

    expected1 = {
        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "const", "shape": [], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "constant": 1},
    }

    class Tensor2(Module):
        def forward(self):
            return torch.tensor(3)

    expected2 = {
        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "const", "shape": [], "dtype": "int64", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "constant": 1},
    }

    verify_model(Tensor1(), [([10, 10], "float32")], expected1)
    verify_model(Tensor2(), [([10, 10], "float32")], expected2)
+
+
def test_tril():
    """Check MSCGraph building for tril, functional and in-place forms."""

    class Tril(Module):
        def forward(self, data):
            return torch.tril(data, 1)

    class InplaceTril(Module):
        def forward(self, data):
            data.tril_(1)
            return data

    # both forms should reduce to a single tril node
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "tril", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "tril": 1},
    }

    shapes = [([10, 10], "float32")]
    verify_model(Tril(), shapes, graph_info)
    verify_model(InplaceTril(), shapes, graph_info)
+
+
def test_triu():
    """Check MSCGraph building for triu, functional and in-place forms."""

    class Triu(Module):
        def forward(self, data):
            return torch.triu(data, 1)

    class InplaceTriu(Module):
        def forward(self, data):
            data.triu_(1)
            return data

    # both forms should reduce to a single triu node
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "triu", "shape": [10, 10], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "triu": 1},
    }

    shapes = [([10, 10], "float32")]
    verify_model(Triu(), shapes, graph_info)
    verify_model(InplaceTriu(), shapes, graph_info)
+
+
def test_new_ones():
    """Check MSCGraph building for new_ones."""

    class NewOnes(Module):
        def forward(self, x):
            return x.new_ones(1, 2, 3)

    # new_ones lowers to constant + full nodes
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "full", "shape": [1, 2, 3], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 3, "input": 1, "constant": 1, "full": 1},
    }

    verify_model(NewOnes(), [([1, 2, 3], "float32")], graph_info)
+
+
def test_expand():
    """Check MSCGraph building for expand."""

    class Expand(Module):
        def forward(self, x):
            return x.expand(4, 2, 3, 4)

    # expand maps to broadcast_to
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}],
        "outputs": [
            {"name": "broadcast_to", "shape": [4, 2, 3, 4], "dtype": "float32", "layout": ""}
        ],
        "nodes": {"total": 2, "input": 1, "broadcast_to": 1},
    }

    verify_model(Expand(), [([1, 2, 3, 4], "float32")], graph_info)
+
+
def test_reduce():
    """Check MSCGraph building for reduce ops."""

    # sum over two axes collapses them away from the output layout
    class Sum(Module):
        def forward(self, x):
            return torch.sum(x, (2, 1))

    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ANCB"}],
        "outputs": [{"name": "sum", "shape": [1, 4], "dtype": "float32", "layout": "AB"}],
        "nodes": {"total": 2, "input": 1, "sum": 1},
    }

    verify_model(Sum(), [([1, 2, 3, 4], "float32")], graph_info)
+
+
def test_datatype():
    """test graph builder for datatype

    float()/half()/type()/astype() conversions should all lower to a
    single astype node carrying the target dtype.
    """

    input_info = [([1, 2, 3, 4], "float32")]

    # float
    class ToFloat(Module):
        def forward(self, x):
            return x.float()

    expected1 = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}],
        "outputs": [
            {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 2, "input": 1, "astype": 1},
    }

    verify_model(ToFloat(), input_info, expected1)

    # half
    class ToHalf(Module):
        def forward(self, x):
            return x.half()

    expected2 = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}],
        "outputs": [
            {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float16", "layout": "ABCD"}
        ],
        "nodes": {"total": 2, "input": 1, "astype": 1},
    }

    verify_model(ToHalf(), input_info, expected2)

    # type
    class Type(Module):
        def forward(self, x):
            return x.type(torch.float32)

    expected3 = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}],
        "outputs": [
            {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 2, "input": 1, "astype": 1},
    }

    # type
    # NOTE(review): x.getattr("dtype") is not a plain torch.Tensor method;
    # presumably it is resolved by the tracer used in verify_model -- confirm.
    class TypeFromAttr(Module):
        def forward(self, x):
            return x.type(x.getattr("dtype"))

    expected4 = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}],
        "outputs": [
            {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 2, "input": 1, "astype": 1},
    }

    # astype
    class AsType(Module):
        def forward(self, x):
            return x.astype(torch.float32)

    expected5 = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}],
        "outputs": [
            {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 2, "input": 1, "astype": 1},
    }

    verify_model(Type(), input_info, expected3)
    verify_model(TypeFromAttr(), input_info, expected4)
    verify_model(AsType(), input_info, expected5)
+
+
def test_permute():
    """Check MSCGraph building for permute."""

    class Permute(Module):
        def forward(self, x):
            return x.permute(0, 3, 2, 1)

    # permute maps to permute_dims; the input layout mirrors the axis order
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ADCB"}],
        "outputs": [
            {"name": "permute_dims", "shape": [1, 4, 3, 2], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 2, "input": 1, "permute_dims": 1},
    }

    verify_model(Permute(), [([1, 2, 3, 4], "float32")], graph_info)
+
+
def test_reshape():
    """Check MSCGraph building for reshape."""

    class Reshape(Module):
        def forward(self, x):
            return x.reshape(2, 12)

    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "reshape", "shape": [2, 12], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "reshape": 1},
    }

    verify_model(Reshape(), [([1, 2, 3, 4], "float32")], graph_info)
+
+
def test_transpose():
    """Check MSCGraph building for transpose."""

    class Transpose(Module):
        def forward(self, x):
            return x.transpose(1, 3)

    # a two-axis transpose also lowers to permute_dims
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ADCB"}],
        "outputs": [
            {"name": "permute_dims", "shape": [1, 4, 3, 2], "dtype": "float32", "layout": "ABCD"}
        ],
        "nodes": {"total": 2, "input": 1, "permute_dims": 1},
    }

    verify_model(Transpose(), [([1, 2, 3, 4], "float32")], graph_info)
+
+
def test_view():
    """Check MSCGraph building for view."""

    class View(Module):
        def forward(self, x):
            return x.view(2, 12)

    # view lowers to the same reshape node as .reshape
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "reshape", "shape": [2, 12], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "reshape": 1},
    }

    verify_model(View(), [([1, 2, 3, 4], "float32")], graph_info)
+
+
def test_keep_params():
    """Check MSCGraph building with conv params kept as node weights."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    # conv + bias fuse into a single msc.conv2d_bias node
    graph_info = {
        "inputs": [_t("inp_0", [1, 3, 10, 10], "float32", "NCHW")],
        "outputs": [_t("msc.conv2d_bias", [1, 6, 4, 4], "float32", "NCHW")],
        "nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1},
    }

    verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")], graph_info)
+
+
def test_unwrap_unit_return_tuple():
    """Check MSCGraph building when forward returns a 1-tuple."""

    class Identity(Module):
        def forward(self, x):
            return (x,)

    # a unit tuple is unwrapped to a single tuple node output
    graph_info = {
        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "tuple", "shape": [256, 256], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "tuple": 1},
    }

    verify_model(Identity(), [([256, 256], "float32")], graph_info)
+
+
def test_no_bind_return_tuple():
    """Check MSCGraph building when forward returns an unbound tuple."""

    def _t(name, shape, dtype, layout):
        return {"name": name, "shape": shape, "dtype": dtype, "layout": layout}

    class Identity(Module):
        def forward(self, x, y):
            return (x, y)

    graph_info = {
        "inputs": [
            _t("inp_0", [256, 256], "float32", ""),
            _t("inp_1", [256, 256], "float32", ""),
        ],
        "outputs": [
            _t("tuple_0", [256, 256], "float32", ""),
            _t("tuple_1", [256, 256], "float32", ""),
        ],
        "nodes": {"total": 3, "input": 2, "tuple": 1},
    }

    verify_model(Identity(), [([256, 256], "float32"), ([256, 256], "float32")], graph_info)
+
+
def test_argmax():
    """Check MSCGraph building for argmax, with and without keepdim."""

    class Argmax1(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1)

    class Argmax2(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1, keepdim=True)

    graph_info1 = {
        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "argmax", "shape": [256], "dtype": "int64", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "argmax": 1},
    }
    # keepdim retains the reduced axis as size 1
    graph_info2 = {
        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "argmax", "shape": [256, 1], "dtype": "int64", "layout": ""}],
        "nodes": {"total": 2, "input": 1, "argmax": 1},
    }

    verify_model(Argmax1(), [([256, 256], "float32")], graph_info1)
    verify_model(Argmax2(), [([256, 256], "float32")], graph_info2)
+
+
def test_argmin():
    """Graph builder test for torch.argmin with and without keepdim."""

    class Argmin1(Module):
        def forward(self, data):
            return torch.argmin(data)

    class Argmin2(Module):
        def forward(self, data):
            return torch.argmin(data, keepdim=True)

    def _expected(out_shape):
        # Full reduction yields a scalar []; keepdim keeps one-sized dims.
        return {
            "inputs": [
                {"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}
            ],
            "outputs": [
                {"name": "argmin", "shape": out_shape, "dtype": "int64", "layout": ""}
            ],
            "nodes": {"total": 2, "input": 1, "argmin": 1},
        }

    verify_model(Argmin1(), [([256, 256], "float32")], _expected([]))
    verify_model(Argmin2(), [([256, 256], "float32")], _expected([1, 1]))
+
+
def test_to():
    """Graph builder test for Tensor.to: dtype cast vs. device move."""

    class To1(Module):
        def forward(self, data):
            return data.to(torch.float16)

    class To2(Module):
        def forward(self, data):
            return data.to("cpu")

    # Casting to a dtype produces an astype node.
    expected_cast = {
        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}],
        "outputs": [{"name": "astype", "shape": [256, 256], "dtype": "float16", "layout": "AB"}],
        "nodes": {"total": 2, "input": 1, "astype": 1},
    }
    # Moving to a device adds no node: the input is passed straight through.
    expected_noop = {
        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}],
        "outputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}],
        "nodes": {"total": 1, "input": 1},
    }

    verify_model(To1(), [([256, 256], "float32")], expected_cast)
    verify_model(To2(), [([256, 256], "float32")], expected_noop)
+
+
def test_mean():
    """Graph builder test for Tensor.mean with and without keepdim."""

    class Mean(Module):
        def forward(self, data):
            return data.mean(-1)

    class MeanKeepDim(Module):
        def forward(self, data):
            return data.mean(-1, keepdim=True)

    # Reducing away the last dim: the reduced axis gets the "N" layout tag.
    expected_reduce = {
        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AN"}],
        "outputs": [{"name": "mean", "shape": [256], "dtype": "float32", "layout": "A"}],
        "nodes": {"total": 2, "input": 1, "mean": 1},
    }
    # keepdim=True preserves rank, so input and output layouts match.
    expected_keepdim = {
        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}],
        "outputs": [{"name": "mean", "shape": [256, 1], "dtype": "float32", "layout": "AB"}],
        "nodes": {"total": 2, "input": 1, "mean": 1},
    }

    verify_model(Mean(), [([256, 256], "float32")], expected_reduce)
    verify_model(MeanKeepDim(), [([256, 256], "float32")], expected_keepdim)
+
+
def test_rsqrt():
    """Graph builder test for torch.rsqrt."""

    class Rsqrt(Module):
        def forward(self, data):
            return torch.rsqrt(data)

    # Elementwise op: shape, dtype and layout are preserved end to end.
    base = {"shape": [256, 256], "dtype": "float32", "layout": "AB"}
    expected = {
        "inputs": [{"name": "inp_0", **base}],
        "outputs": [{"name": "rsqrt", **base}],
        "nodes": {"total": 2, "input": 1, "rsqrt": 1},
    }

    verify_model(Rsqrt(), [([256, 256], "float32")], expected)
+
+
def test_neg():
    """Graph builder test for unary negation (-x -> negative node)."""

    class Neg(Module):
        def forward(self, data):
            return -data

    # Elementwise op: shape, dtype and layout are preserved end to end.
    base = {"shape": [256, 256], "dtype": "float32", "layout": "AB"}
    expected = {
        "inputs": [{"name": "inp_0", **base}],
        "outputs": [{"name": "negative", **base}],
        "nodes": {"total": 2, "input": 1, "negative": 1},
    }

    verify_model(Neg(), [([256, 256], "float32")], expected)
+
+
def test_max():
    """Graph builder test for the binary torch.max (elementwise maximum)."""

    class Max(Module):
        def forward(self, x, y):
            return torch.max(x, y)

    def _tensor(name):
        # Both inputs and the output share shape/dtype/layout.
        return {"name": name, "shape": [256, 256], "dtype": "float32", "layout": "AB"}

    expected = {
        "inputs": [_tensor("inp_0"), _tensor("inp_1")],
        "outputs": [_tensor("maximum")],
        "nodes": {"total": 3, "input": 2, "maximum": 1},
    }

    input_info = [([256, 256], "float32"), ([256, 256], "float32")]
    verify_model(Max(), input_info, expected)
+
+
def test_attention():
    """Graph builder test for scaled dot-product attention (with/without mask)."""

    # pylint: disable=import-outside-toplevel
    import torch.nn.functional as F

    class Attention1(Module):
        def forward(self, q_data, k_data, v_data):
            return F.scaled_dot_product_attention(q_data, k_data, v_data)

    class Attention2(Module):
        def forward(self, q_data, k_data, v_data):
            return F.scaled_dot_product_attention(q_data, k_data, v_data, is_causal=True)

    class Attention3(Module):
        def forward(self, q_data, k_data, v_data, mask):
            return F.scaled_dot_product_attention(q_data, k_data, v_data, mask)

    def _qkv(name):
        # Query/key/value inputs all have the same shape and ACBD layout.
        return {"name": name, "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}

    def _attn_output():
        return {
            "name": "msc.attention",
            "shape": [32, 128, 8, 64],
            "dtype": "float32",
            "layout": "ABCD",
        }

    qkv_info = [([32, 8, 128, 64], "float32") for _ in range(3)]

    # is_causal only changes an attribute, so both variants share one expectation.
    expected1 = {
        "inputs": [_qkv("inp_0"), _qkv("inp_1"), _qkv("inp_2")],
        "outputs": [_attn_output()],
        "nodes": {"total": 4, "input": 3, "msc.attention": 1},
    }
    verify_model(Attention1(), qkv_info, expected1)
    verify_model(Attention2(), qkv_info, expected1)

    # The explicit mask becomes a fourth input with ABCD layout.
    mask_shape = [32, 8, 128, 128]
    expected2 = {
        "inputs": [
            _qkv("inp_0"),
            _qkv("inp_1"),
            _qkv("inp_2"),
            {"name": "inp_3", "shape": mask_shape, "dtype": "float32", "layout": "ABCD"},
        ],
        "outputs": [_attn_output()],
        "nodes": {"total": 5, "input": 4, "msc.attention": 1},
    }
    verify_model(Attention3(), qkv_info + [(mask_shape, "float32")], expected2)
+
+
# Entry point: run all tests in this module via TVM's pytest wrapper.
if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
b/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
index 4717437d76..34cd3a3214 100644
--- a/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
+++ b/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
@@ -15,11 +15,9 @@
# specific language governing permissions and limitations
# under the License.
-import tvm.testing
-from tvm.relay import testing
-from tvm.relay.expr_functor import ExprVisitor
-from tvm.relay.build_module import bind_params_by_name
+""" Test SetExprLayout Pass. """
+import tvm.testing
from tvm.relax.frontend.torch import from_fx
from tvm.relax import PyExprVisitor
@@ -52,11 +50,14 @@ class RelaxChecker(PyExprVisitor):
def test_relax():
+ """Test SetExprLayout for relax"""
+
+ # pylint: disable=import-outside-toplevel
try:
import torch
import torchvision
from torch import fx
- except:
+ except: # pylint: disable=bare-except
print("please install pytorch python package")
return
diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_name.py
b/tests/python/contrib/test_msc/test_transform_set_expr_name.py
index 0c174ff7bd..426860145c 100644
--- a/tests/python/contrib/test_msc/test_transform_set_expr_name.py
+++ b/tests/python/contrib/test_msc/test_transform_set_expr_name.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+""" Test SetExprName Pass. """
+
import tvm.testing
from tvm.relay import testing
from tvm.relay.expr_functor import ExprVisitor
@@ -73,6 +75,8 @@ class RelaxChecker(PyExprVisitor):
def test_relay():
+ """Test SetExprName for relay"""
+
mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1,
dtype="float32")
mod["main"] = bind_params_by_name(mod["main"], params)
mod = msc_transform.SetExprName(as_relax=False)(mod)
@@ -80,11 +84,14 @@ def test_relay():
def test_relax():
+ """Test SetExprName for relax"""
+
+ # pylint: disable=import-outside-toplevel
try:
import torch
import torchvision
from torch import fx
- except:
+ except: # pylint: disable=bare-except
print("please install pytorch python package")
return