This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 14d9cf804e [Unity][MSC][M0.3] MSCGraph Builder (#15615)
14d9cf804e is described below

commit 14d9cf804eebe64c6bebb8379228b998dac4171f
Author: Archermmt <[email protected]>
AuthorDate: Wed Aug 30 16:42:12 2023 +0800

    [Unity][MSC][M0.3] MSCGraph Builder (#15615)
    
    * add graph builder test
    
    * format fix
    
    * lint fix
    
    * lint fix
---
 python/tvm/contrib/msc/core/ir/__init__.py         |    1 +
 python/tvm/contrib/msc/core/ir/graph.py            |    4 +-
 python/tvm/contrib/msc/core/ir/translate.py        |  172 ++
 src/contrib/msc/core/ir/graph_builder.cc           |  695 +++++++
 src/contrib/msc/core/ir/graph_builder.h            |  325 ++++
 src/contrib/msc/core/transform/layout_utils.cc     |    8 +-
 src/contrib/msc/core/transform/set_expr_layout.cc  |   29 +-
 tests/lint/pylint.sh                               |    3 +
 tests/python/contrib/test_msc/test_graph_build.py  | 2037 ++++++++++++++++++++
 .../test_msc/test_transform_set_expr_layout.py     |   11 +-
 .../test_msc/test_transform_set_expr_name.py       |    9 +-
 11 files changed, 3270 insertions(+), 24 deletions(-)

diff --git a/python/tvm/contrib/msc/core/ir/__init__.py b/python/tvm/contrib/msc/core/ir/__init__.py
index ce23a2dd8b..81a34bedb6 100644
--- a/python/tvm/contrib/msc/core/ir/__init__.py
+++ b/python/tvm/contrib/msc/core/ir/__init__.py
@@ -17,3 +17,4 @@
 """tvm.contrib.msc.core.ir"""
 
 from .graph import *
+from .translate import *
diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py
index c058f74936..0db0fe6e81 100644
--- a/python/tvm/contrib/msc/core/ir/graph.py
+++ b/python/tvm/contrib/msc/core/ir/graph.py
@@ -99,7 +99,9 @@ class MSCTensor(Object):
             The tensor description in json format.
         """
 
-        return {"name": self.alias, "shape": self.get_shape(), "dtype": 
self.dtype_name}
+        tensor_des = {"name": self.alias, "shape": self.get_shape(), "dtype": 
self.dtype_name}
+        tensor_des["layout"] = self.layout.name if self.layout else ""
+        return tensor_des
 
     @property
     def dtype_name(self) -> str:
diff --git a/python/tvm/contrib/msc/core/ir/translate.py b/python/tvm/contrib/msc/core/ir/translate.py
new file mode 100644
index 0000000000..46dd59a09b
--- /dev/null
+++ b/python/tvm/contrib/msc/core/ir/translate.py
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.ir.translate"""
+
+from typing import Dict, Optional, Tuple
+
+import tvm
+from tvm.relax.transform import BindParams
+from tvm.relax.backend.pattern_registry import get_patterns_with_prefix
+from tvm.relay.build_module import bind_params_by_name
+from tvm.contrib.msc.core import transform as msc_transform
+from tvm.contrib.msc.core import _ffi_api
+from tvm.contrib.msc.core import utils as msc_utils
+from .graph import MSCGraph, MSCTensor
+
+
+def normalize_weights(
+    t_weights: Dict[MSCTensor, tvm.nd.array], graph: MSCGraph
+) -> Dict[str, tvm.nd.array]:
+    """Normalize the weights.
+
+    Each extracted weight is reshaped or transposed so that it matches the
+    shape/layout of the tensor with the same name in the translated graph.
+
+    Parameters
+    ----------
+    t_weights: dict of <MSCTensor, tvm.nd.array>
+        The weights extracted from IRModule.
+    graph: tvm.contrib.msc.core.ir.MSCGraph
+        The translated graph.
+
+    Returns
+    -------
+    weights: dict of <string:tvm.ndarray>
+        The normalized weights, keyed by weight name.
+    """
+
+    def _to_data(ref_t, data):
+        weight_t = graph.find_tensor(ref_t.name)
+        if weight_t.ndim == 1:
+            # 1-D graph weights only need a reshape when the extracted rank differs.
+            if ref_t.ndim != weight_t.ndim:
+                return tvm.nd.array(data.numpy().reshape(weight_t.get_shape()))
+            return data
+        if ref_t.layout and weight_t.layout:
+            ref_layout, weight_layout = ref_t.layout.name, weight_t.layout.name
+            if ref_layout != weight_layout:
+                # Layout names are plain strings: iterate over their axis characters
+                # (a character has no `.name` attribute).
+                assert len(ref_layout) == len(weight_layout) and all(
+                    l in ref_layout for l in weight_layout
+                ), "layout mismatch {} compare to {}".format(ref_t, weight_t)
+                permute = [ref_layout.index(l) for l in weight_layout]
+                return tvm.nd.array(data.numpy().transpose(*permute))
+        return data
+
+    weights = {t.name: _to_data(t, d) for t, d in t_weights.items()}
+    return weights
+
+
+def from_relax(
+    mod: tvm.IRModule,
+    params: Optional[Dict[str, tvm.nd.array]] = None,
+    trans_config: Optional[Dict[str, str]] = None,
+    build_config: Optional[Dict[str, str]] = None,
+) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]:
+    """Change IRModule to MSCGraph.
+
+    The "main" function is optionally bound with ``params``, fused by the
+    "msc" patterns, annotated with expr names and layouts, and then
+    translated into an MSCGraph together with its weights.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The IRModule of relax.
+    params: dict of <string:tvm.ndarray>
+        The parameters of the IRModule.
+    trans_config: dict
+        The config for transform IRModule. The key "allow_layout_missing"
+        (default True) is forwarded to SetExprLayout.
+    build_config: dict
+        The config for build MSCGraph.
+
+    Returns
+    -------
+    graph: tvm.contrib.msc.core.ir.MSCGraph
+        The translated graph.
+    weights: dict of <string:tvm.ndarray>
+        The weights from the IRModule.
+    """
+
+    trans_config = trans_config or {}
+    build_config = build_config or {}
+    # TODO(tong.meng): optimize before translate?
+    if params:
+        mod = BindParams("main", params)(mod)
+    patterns = get_patterns_with_prefix("msc")
+    passes = [
+        tvm.relax.transform.FuseOpsByPattern(
+            patterns, bind_constants=False, annotate_codegen=False
+        ),
+        msc_transform.SetExprName(),
+        msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
+    ]
+    mod = tvm.transform.Sequential(passes)(mod)
+    graph = _ffi_api.BuildFromRelax(mod, "main", msc_utils.dump_dict(build_config))
+    t_weights = _ffi_api.GetRelaxWeights(mod, "main")
+    return graph, normalize_weights(t_weights, graph)
+
+
+def from_relay(
+    mod: tvm.IRModule,
+    params: Optional[Dict[str, tvm.nd.array]] = None,
+    trans_config: Optional[Dict[str, str]] = None,
+    build_config: Optional[Dict[str, str]] = None,
+    opt_config: Optional[Dict[str, str]] = None,
+) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]:
+    """Change IRModule to MSCGraph.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The IRModule of relay.
+    params: dict of <string:tvm.ndarray>
+        The parameters of the IRModule.
+    trans_config: dict
+        The config for transform IRModule (currently unused for relay).
+    build_config: dict
+        The config for build MSCGraph.
+    opt_config: dict
+        The config for optimize the relay before translate. Recognized keys:
+        "opt_level" (default 0, i.e. only bind params), "target" (default
+        "llvm") and "disabled_pass" (extra pass names to disable).
+
+    Returns
+    -------
+    graph: tvm.contrib.msc.core.ir.MSCGraph
+        The translated graph.
+    weights: dict of <string:tvm.ndarray>
+        The weights from the IRModule.
+    """
+
+    trans_config = trans_config or {}
+    build_config = build_config or {}
+    opt_config = opt_config or {}
+    # TODO(tong.meng): optimize before translate?
+    opt_level = opt_config.get("opt_level", 0)
+    if opt_level == 0:
+        if params:
+            # NOTE(review): this assigns into the caller's IRModule in place.
+            mod["main"] = bind_params_by_name(mod["main"], params)
+    else:
+        target = opt_config.get("target", "llvm")
+        # list() so that a tuple of user-supplied pass names can be extended too.
+        # These passes are presumably disabled to keep the op structure close to
+        # the original graph before translation.
+        disabled_pass = list(opt_config.get("disabled_pass", [])) + [
+            "SimplifyInference",
+            "CanonicalizeOps",
+            "FuseOps",
+            "AlterOpLayout",
+        ]
+        with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass):
+            mod, params = tvm.relay.optimize(mod, target=target, params=params)
+    patterns = tvm.relay.op.contrib.get_pattern_table("msc")
+    passes = [
+        tvm.relay.transform.InferType(),
+        tvm.relay.transform.MergeComposite(patterns),
+        msc_transform.SetExprName(as_relax=False),
+    ]
+    mod = tvm.transform.Sequential(passes)(mod)
+    graph = _ffi_api.BuildFromRelay(mod, "main", msc_utils.dump_dict(build_config))
+    t_weights = _ffi_api.GetRelayWeights(mod, "main")
+    return graph, normalize_weights(t_weights, graph)
diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc
new file mode 100644
index 0000000000..55ff4a45d4
--- /dev/null
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -0,0 +1,695 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/ir/graph_builder.cc
+ */
+
+#include "graph_builder.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+void RelaxFuncAttrGetter::VisitExpr_(const relax::CallNode* op) {
+  // Calls without attributes contribute nothing to the collected map.
+  if (!op->attrs.defined()) {
+    return;
+  }
+  Map<String, String> call_attrs;
+  AttrGetter getter(&call_attrs);
+  const_cast<BaseAttrsNode*>(op->attrs.get())->VisitAttrs(&getter);
+  for (const auto& kv : call_attrs) {
+    const String& key = kv.first;
+    if (!attrs_.count(key)) {
+      attrs_.Set(key, kv.second);
+      continue;
+    }
+    // Key already collected from another call: store under the first free "<key>_<n>".
+    int suffix = 1;
+    while (attrs_.count(key + "_" + std::to_string(suffix))) {
+      ++suffix;
+    }
+    attrs_.Set(key + "_" + std::to_string(suffix), kv.second);
+  }
+}
+
+const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) {
+  // Add input nodes and record inputs;
+  Array<String> input_names, output_names;
+  for (const auto& p : func->params) {
+    AddNode(p, NullOpt, p->name_hint());
+    ICHECK(expr_tensor_map_.count(p)) << "Can not find func param " << p;
+    input_names.push_back(expr_tensor_map_[p][0]);
+  }
+  VisitExpr(func);
+  // Graph outputs are the tensors mapped to the body of the function's SeqExpr.
+  if (const auto* b_node = func->body.as<relax::SeqExprNode>()) {
+    ICHECK(expr_tensor_map_.count(b_node->body)) << "Can not find seqexpr body " << b_node->body;
+    output_names = expr_tensor_map_[b_node->body];
+  } else {
+    LOG(FATAL) << "Function body should be SeqExpr, get " << func->body;
+  }
+  // remove const nodes as weights: nodes registered in weights_ are dropped and
+  // the remaining nodes are re-indexed contiguously.
+  Array<MSCJoint> valid_nodes;
+  for (const auto& n : nodes_) {
+    if (!weights_.count(n->name)) {
+      n->index = valid_nodes.size();
+      valid_nodes.push_back(n);
+    }
+  }
+  const auto& graph = MSCGraph(name_, valid_nodes, input_names, output_names);
+  // set inputs and outputs alias; configured aliases are used only when their
+  // count matches, otherwise fall back to names derived from the producers.
+  if (config_.input_aliass.size() == input_names.size()) {
+    for (size_t i = 0; i < input_names.size(); i++) {
+      graph->FindTensor(input_names[i])->alias = config_.input_aliass[i];
+    }
+  } else {
+    for (size_t i = 0; i < input_names.size(); i++) {
+      graph->FindTensor(input_names[i])->alias = graph->FindProducer(input_names[i])->name;
+    }
+  }
+  if (config_.output_aliass.size() == output_names.size()) {
+    for (size_t i = 0; i < output_names.size(); i++) {
+      graph->FindTensor(output_names[i])->alias = config_.output_aliass[i];
+    }
+  } else {
+    for (size_t i = 0; i < output_names.size(); i++) {
+      const auto& output = graph->FindTensor(output_names[i]);
+      if (output->alias.size() > 0) {
+        continue;
+      }
+      // Single-output producers lend their name; otherwise use "<node>_<idx>".
+      const auto& producer = graph->FindProducer(output_names[i]);
+      output->alias = producer->outputs.size() == 1
+                          ? producer->name
+                          : StringUtils::Replace(output_names[i], ":", "_");
+    }
+  }
+  return graph;
+}
+
+const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional<Expr>& binding_var,
+                                          const String& name) {
+  // Create one MSCJoint for `expr`, register its output tensors and map
+  // `binding_var` (or `expr` itself when no var is given) to those tensor names.
+  const auto& node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, "name");
+  const auto& master_name = SpanUtils::GetAttr(expr->span, "master");
+  String optype;
+  if (expr->IsInstance<relax::VarNode>()) {
+    optype = "input";
+  } else if (expr->IsInstance<relax::ConstantNode>()) {
+    optype = "constant";
+  } else if (expr->IsInstance<relax::ShapeExprNode>()) {
+    optype = "shape";
+  } else if (expr->IsInstance<relax::TupleGetItemNode>()) {
+    optype = "get_item";
+  } else if (expr->IsInstance<relax::TupleNode>()) {
+    optype = "tuple";
+  } else if (const auto* call_node = expr.as<relax::CallNode>()) {
+    if (const auto* op_node = call_node->op.as<OpNode>()) {
+      optype = StringUtils::Replace(op_node->name, "relax.", "");
+    } else if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
+      // Calls to global functions must be composite functions; the composite
+      // name becomes the optype.
+      const auto& func = Downcast<relax::Function>(ref_module_->Lookup(v_node->name_hint));
+      const auto& name_opt = func->GetAttr<runtime::String>(relax::attr::kComposite);
+      ICHECK(name_opt.defined()) << "Unexpected global func without composite";
+      optype = name_opt.value();
+    } else if (const auto* f_node = call_node->op.as<relax::FunctionNode>()) {
+      const auto& name_opt = f_node->GetAttr<runtime::String>(relax::attr::kComposite);
+      ICHECK(name_opt.defined()) << "Unexpected func without composite";
+      optype = name_opt.value();
+    } else {
+      optype = "unknown_op";
+    }
+  } else {
+    optype = "unknown_expr";
+  }
+  // Extract attributes
+  Map<String, String> attrs;
+  if (const auto* call_node = expr.as<relax::CallNode>()) {
+    if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
+      const auto& func = Downcast<relax::Function>(ref_module_->Lookup(v_node->name_hint));
+      attrs = RelaxFuncAttrGetter().GetAttrs(func);
+    } else if (call_node->op->IsInstance<relax::FunctionNode>()) {
+      attrs = RelaxFuncAttrGetter().GetAttrs(call_node->op);
+    } else if (call_node->attrs.defined()) {
+      AttrGetter getter(&attrs);
+      const_cast<BaseAttrsNode*>(call_node->attrs.get())->VisitAttrs(&getter);
+    }
+  } else if (const auto* const_node = expr.as<relax::ConstantNode>()) {
+    // Scalar constants also carry their value as a formatted "scalar" attribute.
+    if (const_node->is_scalar()) {
+      const float val = ExprUtils::GetScalar<float>(Downcast<relax::Constant>(expr));
+      std::stringstream stream;
+      stream << std::fixed << std::setprecision(config_.float_precision) << val;
+      attrs.Set("scalar", stream.str());
+    }
+  } else if (const auto* shape_node = expr.as<relax::ShapeExprNode>()) {
+    attrs.Set("shape", StringUtils::ToString(shape_node->values));
+  } else if (const auto* get_node = expr.as<relax::TupleGetItemNode>()) {
+    attrs.Set("index", std::to_string(get_node->index));
+  }
+  // Get scope
+  Array<String> scope;
+  if (optype != "input" && optype != "constant") {
+    scope = StringUtils::Split(scope_name_, ".");
+  }
+  // Build inputs and weights
+  Array<String> input_names;
+  Map<String, MSCTensor> node_weights;
+  if (const auto* call_node = expr.as<relax::CallNode>()) {
+    const auto& input_types = ExprUtils::GetInputTypes(optype, call_node->args.size(), true);
+    for (size_t i = 0; i < call_node->args.size(); i++) {
+      const auto& arg = call_node->args[i];
+      // Shape and PrimValue arguments are recorded as attributes, not inputs.
+      if (const auto* s_node = arg.as<relax::ShapeExprNode>()) {
+        attrs.Set(input_types[i], StringUtils::ToString(s_node->values));
+        continue;
+      }
+      if (const auto* s_node = arg.as<relax::PrimValueNode>()) {
+        ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype
+                                          << " should have special type, get " << input_types;
+        attrs.Set(input_types[i], StringUtils::ToString(s_node->value));
+        continue;
+      }
+      ICHECK(expr_tensor_map_.count(arg)) << "Missing argument " << arg;
+      if (input_types[i] != "input" && arg->IsInstance<relax::ConstantNode>()) {
+        // Constant non-data arguments become (shared) weights of this node.
+        const auto& t_name = expr_tensor_map_[arg][0];
+        const auto& w_name = SpanUtils::GetAttr(arg->span, "name");
+        const auto& pair = tensor_input_map_[t_name];
+        const auto& producer = Downcast<MSCJoint>(pair.first);
+        if (!weights_.count(w_name)) {
+          const auto& ref = producer->OutputAt(pair.second);
+          const auto& weight = MSCTensor(w_name, ref->dtype, ref->layout.name(), ref->shape);
+          weights_.Set(w_name, weight);
+        }
+        if (producer->HasAttr("scalar")) {
+          attrs.Set(input_types[i], producer->GetTypeAttr<std::string>("scalar"));
+        }
+        node_weights.Set(input_types[i], weights_[w_name]);
+      } else {
+        for (const auto& in_name : expr_tensor_map_[arg]) {
+          input_names.push_back(in_name);
+        }
+      }
+    }
+  } else if (const auto* tuple_node = expr.as<relax::TupleNode>()) {
+    for (const auto& f : tuple_node->fields) {
+      ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f;
+      for (const auto& in_name : expr_tensor_map_[f]) {
+        input_names.push_back(in_name);
+      }
+    }
+  } else if (const auto* getitem_node = expr.as<relax::TupleGetItemNode>()) {
+    ICHECK(expr_tensor_map_.count(getitem_node->tuple))
+        << "Can not find tuple " << getitem_node->tuple;
+    input_names = expr_tensor_map_[getitem_node->tuple];
+  }
+  std::vector<std::pair<BaseJoint, size_t>> inputs;
+  for (const auto& i : input_names) {
+    inputs.push_back(tensor_input_map_[i]);
+  }
+  // Build outputs, named "<node_name>:<index>"
+  Array<MSCTensor> outputs;
+  const auto& layout = SpanUtils::GetAttr(expr->span, "layout");
+  const auto& sinfo = relax::GetStructInfo(expr);
+  if (const auto* t_info = sinfo.as<relax::TensorStructInfoNode>()) {
+    const auto& opt_shape = t_info->GetShape();
+    const auto& shape =
+        opt_shape.defined() ? ArrayUtils::Cast<Integer>(opt_shape.value()) : Array<Integer>();
+    const auto& output =
+        MSCTensor(node_name + ":" + std::to_string(0), t_info->dtype, layout, shape);
+    outputs.push_back(output);
+  } else if (const auto* s_sinfo = sinfo.as<relax::ShapeStructInfoNode>()) {
+    Array<Integer> shape{s_sinfo->ndim};
+    const auto& output = MSCTensor(node_name + ":" + std::to_string(0),
+                                   DataType(runtime::String2DLDataType("int32")), layout, shape);
+    outputs.push_back(output);
+  } else if (const auto* tuple_sinfo = sinfo.as<relax::TupleStructInfoNode>()) {
+    // Tuple layouts are stored comma separated, one entry per field.
+    Array<String> layouts = StringUtils::Split(layout, ",");
+    if (layouts.size() == 0) {
+      layouts = Array<String>(tuple_sinfo->fields.size(), "");
+    }
+    ICHECK_EQ(layouts.size(), tuple_sinfo->fields.size())
+        << "Layout " << layout << " mismatch with fields size " << tuple_sinfo->fields.size();
+    size_t field_size = tuple_sinfo->fields.size();
+    // Only the first tuple field is exposed as an output of nn.batch_norm.
+    if (optype == "nn.batch_norm") {
+      field_size = 1;
+    }
+    for (size_t i = 0; i < field_size; i++) {
+      const auto& t_info = Downcast<relax::TensorStructInfo>(tuple_sinfo->fields[i]);
+      const auto& opt_shape = t_info->GetShape();
+      const auto& shape =
+          opt_shape.defined() ? ArrayUtils::Cast<Integer>(opt_shape.value()) : Array<Integer>();
+      const auto& output =
+          MSCTensor(node_name + ":" + std::to_string(i), t_info->dtype, layouts[i], shape);
+      outputs.push_back(output);
+    }
+  } else {
+    LOG(FATAL) << "Unexpected struct info (" << sinfo->GetTypeKey() << ")" << sinfo;
+  }
+  // Build node
+  const auto& node = MSCJoint(nodes_.size(), node_name, master_name, optype, attrs, scope, inputs,
+                              outputs, node_weights);
+  Array<String> output_names;
+  for (size_t i = 0; i < outputs.size(); i++) {
+    output_names.push_back(outputs[i]->name);
+    tensor_input_map_[outputs[i]->name] = std::make_pair(node, i);
+  }
+  nodes_.push_back(node);
+  const auto& ref_expr = binding_var.defined() ? binding_var.value() : expr;
+  expr_tensor_map_.Set(ref_expr, output_names);
+  return node;
+}
+
+void RelaxGraphBuilder::VisitBindingBlock(const relax::BindingBlock& block) {
+  // The block's span name is remembered so nodes added while visiting it get it as scope.
+  const auto block_scope = SpanUtils::GetAttr(block->span, "name");
+  scope_name_ = block_scope;
+  RelaxExprVisitor::VisitBindingBlock(block);
+}
+
+void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) {
+  // Constants reached directly (not through a var binding) are added without a binding var.
+  const auto constant = GetRef<relax::Constant>(op);
+  AddNode(constant);
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+                                      const relax::ConstantNode* val) {
+  // Register the constant under the bound var so later uses of the var resolve to it.
+  const auto constant = GetRef<relax::Constant>(val);
+  AddNode(constant, binding->var);
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+                                      const relax::ShapeExprNode* val) {
+  // A bound shape expression becomes a "shape" node keyed by the bound var.
+  const auto shape_expr = GetRef<relax::ShapeExpr>(val);
+  AddNode(shape_expr, binding->var);
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+                                      const relax::CallNode* call_node) {
+  // Let the base visitor traverse the call first so its arguments are already mapped.
+  RelaxExprVisitor::VisitBinding_(binding, call_node);
+  try {
+    AddNode(GetRef<relax::Call>(call_node), binding->var);
+  } catch (runtime::InternalError& err) {
+    LOG(WARNING) << "Failed to add node from " << binding->var << " : " << binding->value
+                 << ", reason: " << err.message();
+    // Rethrow the original exception object instead of throwing a copy of it.
+    throw;
+  }
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+                                      const relax::TupleNode* val) {
+  // Traverse the binding first so every tuple field is already mapped to tensors.
+  RelaxExprVisitor::VisitBinding_(binding, val);
+  const auto tuple = GetRef<relax::Tuple>(val);
+  AddNode(tuple, binding->var);
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+                                      const relax::TupleGetItemNode* val) {
+  // Traverse the binding first so the source tuple is already mapped to tensors.
+  RelaxExprVisitor::VisitBinding_(binding, val);
+  const auto get_item = GetRef<relax::TupleGetItem>(val);
+  AddNode(get_item, binding->var);
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+                                      const relax::VarNode* val) {
+  RelaxExprVisitor::VisitBinding_(binding, val);
+  // Var-to-var binding: no new node, the bound var aliases the source var's tensors.
+  const auto source = GetRef<relax::Var>(val);
+  ICHECK(expr_tensor_map_.count(source)) << "Can not find var " << source;
+  const auto tensor_names = expr_tensor_map_[source];
+  expr_tensor_map_.Set(binding->var, tensor_names);
+}
+
+void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
+                                      const relax::DataflowVarNode* val) {
+  RelaxExprVisitor::VisitBinding_(binding, val);
+  // Dataflow-var binding: the bound var aliases the source var's tensors (no new node).
+  const auto& output = GetRef<relax::DataflowVar>(val);
+  ICHECK(expr_tensor_map_.count(output)) << "Can not find dataflow var " << output;
+  expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]);
+}
+
+Map<MSCTensor, NDArray> RelaxWeightsExtractor::GetWeights(const relax::Function& func) {
+  // Visiting the function collects every constant it contains into weights_.
+  VisitExpr(func);
+  return weights_;
+}
+
+void RelaxWeightsExtractor::VisitExpr_(const relax::ConstantNode* op) {
+  // Describe the constant as an MSCTensor (name/layout from its span) and record its data.
+  const auto& name = SpanUtils::GetAttr(op->span, "name");
+  const auto& layout = SpanUtils::GetAttr(op->span, "layout");
+  const auto& sinfo = relax::GetStructInfo(GetRef<relax::Constant>(op));
+  ICHECK(sinfo->IsInstance<relax::TensorStructInfoNode>())
+      << "Constant StructInfo should be TensorStructInfo";
+  const auto& t_info = Downcast<relax::TensorStructInfo>(sinfo);
+  const auto& opt_shape = t_info->GetShape();
+  const auto& shape =
+      opt_shape.defined() ? ArrayUtils::Cast<Integer>(opt_shape.value()) : Array<Integer>();
+  const auto& weight = MSCTensor(name, t_info->dtype, layout, shape);
+  weights_.Set(weight, op->data);
+}
+
+void RelayFuncAttrGetter::VisitExpr_(const relay::CallNode* op) {
+  // Traverse nested expressions first, then merge this call's attributes.
+  RelayExprVisitor::VisitExpr_(op);
+  if (!op->attrs.defined()) {
+    return;
+  }
+  Map<String, String> call_attrs;
+  AttrGetter getter(&call_attrs);
+  const_cast<BaseAttrsNode*>(op->attrs.get())->VisitAttrs(&getter);
+  for (const auto& kv : call_attrs) {
+    const String& key = kv.first;
+    if (!attrs_.count(key)) {
+      attrs_.Set(key, kv.second);
+      continue;
+    }
+    // Key already collected from another call: store under the first free "<key>_<n>".
+    int suffix = 1;
+    while (attrs_.count(key + "_" + std::to_string(suffix))) {
+      ++suffix;
+    }
+    attrs_.Set(key + "_" + std::to_string(suffix), kv.second);
+  }
+}
+
+MSCGraph RelayGraphBuilder::Build(const relay::Function& func) {
+  // Add input nodes and record inputs;
+  Array<String> input_names, output_names;
+  for (const auto& p : func->params) {
+    AddNode(p, p->name_hint());
+    ICHECK(expr_tensor_map_.count(p)) << "Can not find func param " << p;
+    input_names.push_back(expr_tensor_map_[p][0]);
+  }
+  VisitExpr(func);
+  // Graph outputs are the tensors mapped to the function body.
+  ICHECK(expr_tensor_map_.count(func->body)) << "Can not find func body " << func->body;
+  output_names = expr_tensor_map_[func->body];
+  // remove const nodes as weights: nodes registered in weights_ are dropped and
+  // the remaining nodes are re-indexed contiguously.
+  Array<MSCJoint> valid_nodes;
+  for (const auto& n : nodes_) {
+    if (!weights_.count(n->name)) {
+      n->index = valid_nodes.size();
+      valid_nodes.push_back(n);
+    }
+  }
+  const auto& graph = MSCGraph(name_, valid_nodes, input_names, output_names);
+  // set inputs and outputs alias; configured aliases are used only when their
+  // count matches, otherwise fall back to names derived from the producers.
+  if (config_.input_aliass.size() == input_names.size()) {
+    for (size_t i = 0; i < input_names.size(); i++) {
+      graph->FindTensor(input_names[i])->alias = config_.input_aliass[i];
+    }
+  } else {
+    for (size_t i = 0; i < input_names.size(); i++) {
+      graph->FindTensor(input_names[i])->alias = graph->FindProducer(input_names[i])->name;
+    }
+  }
+  if (config_.output_aliass.size() == output_names.size()) {
+    for (size_t i = 0; i < output_names.size(); i++) {
+      graph->FindTensor(output_names[i])->alias = config_.output_aliass[i];
+    }
+  } else {
+    for (size_t i = 0; i < output_names.size(); i++) {
+      const auto& output = graph->FindTensor(output_names[i]);
+      if (output->alias.size() > 0) {
+        continue;
+      }
+      // Single-output producers lend their name; otherwise use "<node>_<idx>".
+      const auto& producer = graph->FindProducer(output_names[i]);
+      output->alias = producer->outputs.size() == 1
+                          ? producer->name
+                          : StringUtils::Replace(output_names[i], ":", "_");
+    }
+  }
+  return graph;
+}
+
+MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const String& name) {
+  const auto& node_name = name.size() > 0 ? name : 
SpanUtils::GetAttr(expr->span, "name");
+  const auto& master_name = SpanUtils::GetAttr(expr->span, "master");
+  String optype;
+  if (expr->IsInstance<relay::VarNode>()) {
+    optype = "input";
+  } else if (expr->IsInstance<relay::ConstantNode>()) {
+    optype = "constant";
+  } else if (expr->IsInstance<relay::TupleGetItemNode>()) {
+    optype = "get_item";
+  } else if (expr->IsInstance<relay::TupleNode>()) {
+    optype = "tuple";
+  } else if (const auto* call_node = expr.as<relay::CallNode>()) {
+    if (const auto* op_node = call_node->op.as<OpNode>()) {
+      optype = StringUtils::Replace(op_node->name, "relay.", "");
+    } else {
+      optype = "unknown_op";
+    }
+  } else if (const auto* f_node = expr.as<relay::FunctionNode>()) {
+    const auto& name_opt = 
f_node->GetAttr<runtime::String>(relay::attr::kComposite);
+    ICHECK(name_opt.defined()) << "Unexpected func without composite";
+    optype = name_opt.value();
+  } else {
+    optype = "unknown_expr";
+  }
+  // Extract attributes
+  Map<String, String> attrs;
+  if (const auto* call_node = expr.as<relay::CallNode>()) {
+    if (call_node->attrs.defined()) {
+      AttrGetter getter(&attrs);
+      const_cast<BaseAttrsNode*>(call_node->attrs.get())->VisitAttrs(&getter);
+    }
+  } else if (expr->IsInstance<relay::FunctionNode>()) {
+    attrs = RelayFuncAttrGetter().GetAttrs(expr);
+  } else if (const auto* const_node = expr.as<relay::ConstantNode>()) {
+    if (const_node->is_scalar()) {
+      const float val = 
ExprUtils::GetScalar<float>(Downcast<relay::Constant>(expr));
+      std::stringstream stream;
+      stream << std::fixed << std::setprecision(config_.float_precision) << 
val;
+      attrs.Set("scalar", stream.str());
+    }
+  } else if (const auto* get_node = expr.as<relay::TupleGetItemNode>()) {
+    attrs.Set("index", std::to_string(get_node->index));
+  }
+  // Get scope
+  Array<String> scope;
+  if (optype != "input" && optype != "constant") {
+    scope.push_back("block");
+  }
+  // Build inputs and weights
+  Array<String> input_names;
+  Map<String, MSCTensor> node_weights;
+  if (const auto* call_node = expr.as<relay::CallNode>()) {
+    const auto& input_types = ExprUtils::GetInputTypes(optype, 
call_node->args.size(), false);
+    for (size_t i = 0; i < call_node->args.size(); i++) {
+      const auto& arg = call_node->args[i];
+      ICHECK(expr_tensor_map_.count(arg)) << "Missing argument " << arg;
+      if (input_types[i] != "input" && arg->IsInstance<relay::ConstantNode>()) 
{
+        const auto& t_name = expr_tensor_map_[arg][0];
+        const auto& w_name = SpanUtils::GetAttr(arg->span, "name");
+        const auto& pair = tensor_input_map_[t_name];
+        const auto& producer = Downcast<MSCJoint>(pair.first);
+        if (!weights_.count(w_name)) {
+          const auto& ref = producer->OutputAt(pair.second);
+          const auto& weight = MSCTensor(w_name, ref->dtype, 
ref->layout.name(), ref->shape);
+          weights_.Set(w_name, weight);
+        }
+        if (producer->HasAttr("scalar")) {
+          attrs.Set(input_types[i], 
producer->GetTypeAttr<std::string>("scalar"));
+        }
+        node_weights.Set(input_types[i], weights_[w_name]);
+      } else {
+        for (const auto& in_name : expr_tensor_map_[arg]) {
+          input_names.push_back(in_name);
+        }
+      }
+    }
+  } else if (const auto* f_node = expr.as<relay::FunctionNode>()) {
+    for (const auto& p : f_node->params) {
+      for (const auto& in_name : expr_tensor_map_[p]) {
+        input_names.push_back(in_name);
+      }
+    }
+    ICHECK(HasFuncScope()) << "Function without func scope " << 
relay::PrettyPrint(expr);
+    const auto& weight_names = func_scopes_.top().GetFuncWeights();
+    const auto& input_types =
+        ExprUtils::GetInputTypes(optype, f_node->params.size() + 
weight_names.size(), false);
+    for (size_t i = 0; i < weight_names.size(); i++) {
+      const auto& pair = tensor_input_map_[weight_names[i]];
+      const auto& producer = Downcast<MSCJoint>(pair.first);
+      if (!weights_.count(producer->name)) {
+        const auto& ref = producer->OutputAt(pair.second);
+        const auto& weight = MSCTensor(producer->name, ref->dtype, 
ref->layout.name(), ref->shape);
+        weights_.Set(producer->name, weight);
+      }
+      if (producer->HasAttr("scalar")) {
+        attrs.Set(input_types[i], 
producer->GetTypeAttr<std::string>("scalar"));
+      }
+      node_weights.Set(input_types[i + f_node->params.size()], 
weights_[producer->name]);
+    }
+  } else if (const auto* tuple_node = expr.as<relay::TupleNode>()) {
+    for (const auto& f : tuple_node->fields) {
+      ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f;
+      for (const auto& in_name : expr_tensor_map_[f]) {
+        input_names.push_back(in_name);
+      }
+    }
+  } else if (const auto* getitem_node = expr.as<relay::TupleGetItemNode>()) {
+    ICHECK(expr_tensor_map_.count(getitem_node->tuple))
+        << "Can not find tuple " << getitem_node->tuple;
+    input_names = expr_tensor_map_[getitem_node->tuple];
+  }
+  std::vector<std::pair<BaseJoint, size_t>> inputs;
+  for (const auto& i : input_names) {
+    inputs.push_back(tensor_input_map_[i]);
+  }
+  // Build outputs
+  Array<MSCTensor> outputs;
+  const auto& layout = SpanUtils::GetAttr(expr->span, "layout");
+  Type checked_type = expr->checked_type_;
+  if (checked_type.defined() && 
checked_type->IsInstance<relay::FuncTypeNode>()) {
+    checked_type = Downcast<FuncType>(checked_type)->ret_type;
+  }
+  if (checked_type.defined()) {
+    if (const auto* t_info = checked_type.as<relay::TensorTypeNode>()) {
+      const auto& shape = ArrayUtils::Cast<Integer>(t_info->shape);
+      const auto& output =
+          MSCTensor(node_name + ":" + std::to_string(0), t_info->dtype, 
layout, shape);
+      outputs.push_back(output);
+    } else if (const auto* tuple_info = 
checked_type.as<relay::TupleTypeNode>()) {
+      Array<String> layouts = StringUtils::Split(layout, ",");
+      if (layouts.size() == 0) {
+        layouts = Array<String>(tuple_info->fields.size(), "");
+      }
+      ICHECK_EQ(layouts.size(), tuple_info->fields.size())
+          << "Layout " << layout << " msimatch with fileds size " << 
tuple_info->fields.size();
+      size_t field_size = tuple_info->fields.size();
+      if (optype == "nn.batch_norm") {
+        field_size = 1;
+      }
+      for (size_t i = 0; i < field_size; i++) {
+        const auto& t_info = 
Downcast<relay::TensorType>(tuple_info->fields[i]);
+        const auto& shape = ArrayUtils::Cast<Integer>(t_info->shape);
+        const auto& output =
+            MSCTensor(node_name + ":" + std::to_string(i), t_info->dtype, 
layouts[i], shape);
+        outputs.push_back(output);
+      }
+    } else {
+      LOG(FATAL) << "Unexpected checked_type " << checked_type;
+    }
+  }
+
+  // Build node
+  const auto& node = MSCJoint(nodes_.size(), node_name, master_name, optype, 
attrs, scope, inputs,
+                              outputs, node_weights);
+  Array<String> output_names;
+  for (size_t i = 0; i < outputs.size(); i++) {
+    output_names.push_back(outputs[i]->name);
+    tensor_input_map_[outputs[i]->name] = std::make_pair(node, i);
+  }
+  nodes_.push_back(node);
+  expr_tensor_map_.Set(expr, output_names);
+  return node;
+}
+
+void RelayGraphBuilder::VisitExpr_(const relay::ConstantNode* op) {
+  // Every constant becomes a node of the graph.
+  const MSCJoint const_node = AddNode(GetRef<relay::Constant>(op));
+  // Outside a composite function there is nothing else to record.
+  if (!HasFuncScope()) {
+    return;
+  }
+  // Inside a composite function, remember the constant's first output as a weight
+  // of the innermost scope; AddNode attaches these to the function node.
+  func_scopes_.top().AddFuncWeight(const_node->OutputAt(0)->name);
+}
+
+void RelayGraphBuilder::VisitExpr_(const relay::FunctionNode* op) {
+  // Composite functions open a func scope named after the "name" span attribute.
+  const bool is_composite = op->GetAttr<runtime::String>(relay::attr::kComposite).defined();
+  if (is_composite) {
+    StartFuncScope(SpanUtils::GetAttr(op->span, "name"));
+  }
+  RelayExprVisitor::VisitExpr_(op);
+  // NOTE(review): this tests for *any* open scope rather than the one opened above,
+  // so a non-composite function visited inside a composite scope would also add a
+  // node and close the enclosing scope. Confirm nested functions cannot occur here.
+  if (!HasFuncScope()) {
+    return;
+  }
+  AddNode(GetRef<relay::Function>(op));
+  EndFuncScope();
+}
+
+void RelayGraphBuilder::VisitExpr_(const relay::CallNode* op) {
+  // Calling a composite function: bind each callee parameter to the tensors already
+  // recorded for the matching argument, so the function body and the function node
+  // added for it can resolve their inputs.
+  if (const auto* f_node = op->op.as<relay::FunctionNode>()) {
+    const auto& name_opt = f_node->GetAttr<runtime::String>(relay::attr::kComposite);
+    if (name_opt.defined()) {
+      ICHECK_EQ(op->args.size(), f_node->params.size())
+          << "Arguments size mismatch with params of " << relay::PrettyPrint(op->op);
+      for (size_t i = 0; i < op->args.size(); i++) {
+        ICHECK(expr_tensor_map_.count(op->args[i]))
+            << "Can not find argument " << relay::PrettyPrint(op->args[i]);
+        expr_tensor_map_.Set(f_node->params[i], expr_tensor_map_[op->args[i]]);
+      }
+    }
+  }
+  RelayExprVisitor::VisitExpr_(op);
+  // Operator calls become nodes only outside composite functions; inside one, the
+  // enclosing function is added as a single node by the FunctionNode visitor.
+  if (!HasFuncScope() && op->op->IsInstance<OpNode>()) {
+    try {
+      AddNode(GetRef<relay::Call>(op));
+    } catch (const runtime::InternalError& err) {
+      LOG(WARNING) << "Failed to add node from " << relay::PrettyPrint(GetRef<relay::Call>(op))
+                   << " : " << err.message();
+      // Rethrow the in-flight exception itself; `throw err;` would throw a copy.
+      throw;
+    }
+  }
+  // A call to a function that was added as a node exposes that node's outputs.
+  if (op->op->IsInstance<relay::FunctionNode>() && expr_tensor_map_.count(op->op)) {
+    expr_tensor_map_.Set(GetRef<relay::Call>(op), expr_tensor_map_[op->op]);
+  }
+}
+
+void RelayGraphBuilder::VisitExpr_(const relay::TupleNode* val) {
+  // Visit the fields first so their tensors are known, then add the tuple node.
+  const relay::Tuple tuple = GetRef<relay::Tuple>(val);
+  RelayExprVisitor::VisitExpr_(val);
+  AddNode(tuple);
+}
+
+void RelayGraphBuilder::VisitExpr_(const relay::TupleGetItemNode* val) {
+  // Visit the source tuple first, then add a node selecting from it.
+  const relay::TupleGetItem item = GetRef<relay::TupleGetItem>(val);
+  RelayExprVisitor::VisitExpr_(val);
+  AddNode(item);
+}
+
+void RelayGraphBuilder::StartFuncScope(const String& name) {
+  // Push a new innermost scope; constants visited until EndFuncScope record into it.
+  func_scopes_.emplace(name);
+}
+
+void RelayGraphBuilder::EndFuncScope() {
+  // Pop the innermost scope; closing with no open scope is an error.
+  ICHECK(HasFuncScope()) << "No FuncScope found";
+  func_scopes_.pop();
+}
+
+// True while at least one composite-function scope is open.
+bool RelayGraphBuilder::HasFuncScope() { return !func_scopes_.empty(); }
+
+Map<MSCTensor, NDArray> RelayWeightsExtractor::GetWeights(const relay::Function& func) {
+  // Walk the whole function; each visited constant is added to weights_.
+  this->VisitExpr(func);
+  return weights_;
+}
+
+void RelayWeightsExtractor::VisitExpr_(const relay::ConstantNode* op) {
+  // Describe the constant with the name/layout stored in its span attributes and
+  // its tensor type, then map that description to the constant's data.
+  const auto& tensor_type = op->tensor_type();
+  const MSCTensor weight(SpanUtils::GetAttr(op->span, "name"), tensor_type->dtype,
+                         SpanUtils::GetAttr(op->span, "layout"),
+                         ArrayUtils::Cast<Integer>(tensor_type->shape));
+  weights_.Set(weight, op->data);
+}
+
+// Build an MSCGraph from the relax function `entry_name` of `relax_module`.
+TVM_REGISTER_GLOBAL("msc.core.BuildFromRelax")
+    .set_body_typed([](const IRModule& relax_module, const String& entry_name,
+                       const String& options) -> MSCGraph {
+      const relax::Function entry = Downcast<relax::Function>(relax_module->Lookup(entry_name));
+      RelaxGraphBuilder builder(relax_module, entry_name, options);
+      return builder.Build(entry);
+    });
+
+// Collect the constant weights of the relax function `entry_name`.
+TVM_REGISTER_GLOBAL("msc.core.GetRelaxWeights")
+    .set_body_typed([](const IRModule& relax_module,
+                       const String& entry_name) -> Map<MSCTensor, NDArray> {
+      const relax::Function entry = Downcast<relax::Function>(relax_module->Lookup(entry_name));
+      RelaxWeightsExtractor extractor;
+      return extractor.GetWeights(entry);
+    });
+
+// Build an MSCGraph from the relay function `entry_name` of `relay_module`.
+TVM_REGISTER_GLOBAL("msc.core.BuildFromRelay")
+    .set_body_typed([](const IRModule& relay_module, const String& entry_name,
+                       const String& options) -> MSCGraph {
+      const relay::Function entry = Downcast<relay::Function>(relay_module->Lookup(entry_name));
+      RelayGraphBuilder builder(relay_module, entry_name, options);
+      return builder.Build(entry);
+    });
+
+// Collect the constant weights of the relay function `entry_name`.
+TVM_REGISTER_GLOBAL("msc.core.GetRelayWeights")
+    .set_body_typed([](const IRModule& relay_module,
+                       const String& entry_name) -> Map<MSCTensor, NDArray> {
+      const relay::Function entry = Downcast<relay::Function>(relay_module->Lookup(entry_name));
+      RelayWeightsExtractor extractor;
+      return extractor.GetWeights(entry);
+    });
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
diff --git a/src/contrib/msc/core/ir/graph_builder.h 
b/src/contrib/msc/core/ir/graph_builder.h
new file mode 100644
index 0000000000..bb7223695d
--- /dev/null
+++ b/src/contrib/msc/core/ir/graph_builder.h
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/ir/graph_builder.h
+ * \brief Builder of MSCGraph.
+ */
+#ifndef TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_
+#define TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_
+
+#include <dmlc/json.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/tir/data_layout.h>
+
+#include <stack>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../utils.h"
+#include "graph.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+// Short aliases shared by the relax and relay graph builders below.
+using Expr = tvm::RelayExpr;
+using RelaxExprVisitor = tvm::relax::ExprVisitor;
+using RelayExprVisitor = tvm::relay::ExprVisitor;
+// NOTE(review): a using-directive in a header pulls tvm::runtime into every
+// translation unit that includes this file; consider narrowing it.
+using namespace tvm::runtime;
+
+/*!
+ * \brief Config for building MSCGraph.
+ *  Parsed from the JSON "options" string handed to the graph builders.
+ */
+struct MSCRBuildConfig {
+  bool prune_graph{false};
+  int float_precision = 6;
+  std::string sort_by;
+  std::vector<std::string> input_aliass;
+  std::vector<std::string> output_aliass;
+  // Filled from the "input_types" JSON object: key -> list of input type names.
+  std::unordered_map<std::string, std::vector<std::string>> input_types;
+
+  /*! \brief Read the "input_types" object, whose values are lists of strings. */
+  void LoadInputTypes(dmlc::JSONReader* reader) {
+    std::string key;
+    reader->BeginObject();
+    while (reader->NextObjectItem(&key)) {
+      std::vector<std::string> types;
+      reader->Read(&types);
+      input_types[key] = types;
+    }
+  }
+
+  /*!
+   * \brief Load the config from a JSON object.
+   *  Unknown keys are rejected explicitly. Their value would otherwise be left
+   *  unconsumed, and the reader would stop on it with an unrelated parse error.
+   */
+  void Load(dmlc::JSONReader* reader) {
+    std::string key;
+    reader->BeginObject();
+    while (reader->NextObjectItem(&key)) {
+      if (key == "prune_graph") {
+        reader->Read(&prune_graph);
+      } else if (key == "float_precision") {
+        reader->Read(&float_precision);
+      } else if (key == "sort_by") {
+        reader->Read(&sort_by);
+      } else if (key == "input_aliass") {
+        reader->Read(&input_aliass);
+      } else if (key == "output_aliass") {
+        reader->Read(&output_aliass);
+      } else if (key == "input_types") {
+        this->LoadInputTypes(reader);
+      } else {
+        LOG(FATAL) << "Unknown key " << key << " in MSCRBuildConfig";
+      }
+    }
+  }
+};
+
+/*!
+ * \brief AttrVisitor that stores every visited attribute as a string in a Map.
+ */
+class AttrGetter : public AttrVisitor {
+ public:
+  /*!
+   * \brief Create a getter writing into attrs.
+   * \param attrs the map receiving one entry per visited attribute.
+   */
+  explicit AttrGetter(Map<String, String>* attrs) : attrs_(attrs) {}
+
+  void Visit(const char* key, double* value) final { SetNumber(key, *value); }
+  void Visit(const char* key, int64_t* value) final { SetNumber(key, *value); }
+  void Visit(const char* key, uint64_t* value) final { SetNumber(key, *value); }
+  void Visit(const char* key, int* value) final { SetNumber(key, *value); }
+  // bool is promoted to int by std::to_string, so it is stored as "1" / "0".
+  void Visit(const char* key, bool* value) final { SetNumber(key, *value); }
+
+  void Visit(const char* key, std::string* value) final { attrs_->Set(key, *value); }
+
+  void Visit(const char* key, DataType* value) final {
+    attrs_->Set(key, runtime::DLDataType2String(*value));
+  }
+
+  void Visit(const char* key, runtime::ObjectRef* value) final {
+    attrs_->Set(key, StringUtils::ToString(*value));
+  }
+
+  // void* and NDArray attributes have no string form here and are rejected.
+  void Visit(const char* key, void** value) final {
+    LOG(FATAL) << "TypeError: void is not allowed in Attrs";
+  }
+
+  void Visit(const char* key, runtime::NDArray* value) final {
+    LOG(FATAL) << "TypeError: NDArray is not allowed in Attrs";
+  }
+
+ private:
+  /*! \brief Store a numeric value under key using std::to_string. */
+  template <typename T>
+  void SetNumber(const char* key, T value) {
+    attrs_->Set(key, std::to_string(value));
+  }
+
+  Map<String, String>* attrs_;
+};
+
+/*!
+ * \brief Relax visitor that accumulates attributes into attrs_ while visiting.
+ *  The CallNode handler that fills attrs_ is defined in the .cc file.
+ */
+class RelaxFuncAttrGetter : public RelaxExprVisitor {
+ public:
+  /*! \brief Visit expr and return the attributes collected as Map<String, String>*/
+  Map<String, String> GetAttrs(const Expr& expr) {
+    RelaxExprVisitor::VisitExpr(expr);
+    return attrs_;
+  }
+
+  void VisitExpr_(const relax::CallNode* op) final;
+
+ private:
+  // Attributes gathered so far; returned by GetAttrs.
+  Map<String, String> attrs_;
+};
+
+/*!
+ * \brief Visitor that builds an MSCGraph from a relax function.
+ */
+class RelaxGraphBuilder : public RelaxExprVisitor {
+ public:
+  /*!
+   * \brief The constructor of RelaxGraphBuilder
+   * \param ref_module the reference module.
+   * \param name the name of the graph.
+   * \param options JSON text parsed into MSCRBuildConfig; when empty the defaults are kept.
+   */
+  explicit RelaxGraphBuilder(const IRModule& ref_module, const String& name,
+                             const std::string& options = "")
+      : RelaxExprVisitor(), name_(name), ref_module_(ref_module) {
+    if (!options.empty()) {
+      std::istringstream is(options);
+      dmlc::JSONReader reader(&is);
+      reader.Read(&config_);
+    }
+  }
+
+  /*! \brief Build MSCGraph from relax function*/
+  const MSCGraph Build(const relax::Function& func);
+
+  /*! \brief Create and add MSCJoint from expr*/
+  const MSCJoint AddNode(const Expr& expr, const Optional<Expr>& binding_var = NullOpt,
+                         const String& name = "");
+
+  void VisitBindingBlock(const relax::BindingBlock& block) final;
+
+  void VisitExpr_(const relax::ConstantNode* op) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const relax::ConstantNode* val) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const relax::ShapeExprNode* val) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const relax::CallNode* call_node) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const relax::TupleNode* val) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding,
+                     const relax::TupleGetItemNode* val) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const relax::VarNode* val) final;
+
+  void VisitBinding_(const relax::VarBindingNode* binding, const relax::DataflowVarNode* val) final;
+
+ private:
+  String name_;
+  String scope_name_;
+  IRModule ref_module_;
+  MSCRBuildConfig config_;
+  Array<MSCJoint> nodes_;
+  Map<String, MSCTensor> weights_;
+  // Presumably expr -> produced tensor names, as in RelayGraphBuilder (TODO confirm in .cc).
+  Map<Expr, Array<String>> expr_tensor_map_;
+  // Presumably tensor name -> (producer node, output index) (TODO confirm in .cc).
+  std::unordered_map<String, std::pair<BaseJoint, size_t>> tensor_input_map_;
+};
+
+/*!
+ * \brief Relax visitor that collects constant weights into a Map<MSCTensor, NDArray>.
+ */
+class RelaxWeightsExtractor : public RelaxExprVisitor {
+ public:
+  /*! \brief Visit func and return the weights saved from its constants */
+  Map<MSCTensor, NDArray> GetWeights(const relax::Function& func);
+
+  // Presumably records one weight per visited constant (defined in the .cc; TODO confirm).
+  void VisitExpr_(const relax::ConstantNode* op) final;
+
+ private:
+  // Weights collected so far; returned by GetWeights.
+  Map<MSCTensor, NDArray> weights_;
+};
+
+/*!
+ * \brief Relay visitor that accumulates attributes into attrs_ while visiting.
+ *  The CallNode handler that fills attrs_ is defined in the .cc file.
+ */
+class RelayFuncAttrGetter : public RelayExprVisitor {
+ public:
+  /*! \brief Visit expr and return the attributes collected as Map<String, String>*/
+  Map<String, String> GetAttrs(const Expr& expr) {
+    // Same base-class entry point as RelaxFuncAttrGetter uses.
+    RelayExprVisitor::VisitExpr(expr);
+    return attrs_;
+  }
+
+  void VisitExpr_(const relay::CallNode* op) final;
+
+ private:
+  Map<String, String> attrs_;
+};
+
+/*!
+ * \brief A scope opened while visiting a composite relay function.
+ *  It records the names of constant tensors (weights) met inside that function.
+ */
+class RelayFuncScope {
+ public:
+  /*! \brief The constructor */
+  explicit RelayFuncScope(const String& name) : name_(name) {}
+
+  /*! \brief Add a weight, given as its tensor name */
+  void AddFuncWeight(const String& weight) { func_weights_.push_back(weight); }
+
+  /*! \brief Get the weight names recorded so far, in the order they were added */
+  Array<String> GetFuncWeights() const { return func_weights_; }
+
+ private:
+  String name_;
+  Array<String> func_weights_;
+};
+
+/*!
+ * \brief Visitor that builds an MSCGraph from a relay function.
+ *  Operators inside composite functions are not added individually; each
+ *  composite function is added as one node instead.
+ */
+class RelayGraphBuilder : public RelayExprVisitor {
+ public:
+  /*!
+   * \brief The constructor of RelayGraphBuilder
+   * \param ref_module the reference module.
+   * \param name the name of the graph.
+   * \param options JSON text parsed into MSCRBuildConfig; when empty the defaults are kept.
+   */
+  explicit RelayGraphBuilder(const IRModule& ref_module, const String& name,
+                             const std::string& options = "")
+      : RelayExprVisitor() {
+    name_ = name;
+    ref_module_ = ref_module;
+    if (options.size() > 0) {
+      std::istringstream is(options);
+      dmlc::JSONReader reader(&is);
+      reader.Read(&config_);
+    }
+  }
+
+  /*! \brief Build MSCGraph from relay function*/
+  MSCGraph Build(const relay::Function& func);
+
+  /*! \brief Create and add MSCJoint from expr*/
+  MSCJoint AddNode(const Expr& expr, const String& name = "");
+
+  void VisitExpr_(const relay::ConstantNode* op) final;
+
+  void VisitExpr_(const relay::FunctionNode* op) final;
+
+  void VisitExpr_(const relay::CallNode* op) final;
+
+  void VisitExpr_(const relay::TupleNode* val) final;
+
+  void VisitExpr_(const relay::TupleGetItemNode* val) final;
+
+ protected:
+  /*! \brief Start a func scope */
+  void StartFuncScope(const String& scope);
+
+  /*! \brief End a func scope */
+  void EndFuncScope();
+
+  /*! \brief Check if has func scopes left */
+  bool HasFuncScope();
+
+ private:
+  String name_;
+  MSCRBuildConfig config_;
+  IRModule ref_module_;
+  Array<MSCJoint> nodes_;
+  Map<String, MSCTensor> weights_;
+  // Output tensor names produced for each visited expr.
+  Map<Expr, Array<String>> expr_tensor_map_;
+  // Tensor name -> (producer node, output index).
+  std::unordered_map<String, std::pair<BaseJoint, size_t>> tensor_input_map_;
+  // Open composite-function scopes; top() is the innermost.
+  std::stack<RelayFuncScope> func_scopes_;
+};
+
+/*!
+ * \brief Relay visitor that collects every constant as a (MSCTensor, NDArray) weight.
+ */
+class RelayWeightsExtractor : public RelayExprVisitor {
+ public:
+  /*! \brief Visit func and return the weights saved from its constants*/
+  Map<MSCTensor, NDArray> GetWeights(const relay::Function& func);
+
+  // Records one weight per constant, described by its name/layout span attributes.
+  void VisitExpr_(const relay::ConstantNode* op) final;
+
+ private:
+  // Weights collected so far; returned by GetWeights.
+  Map<MSCTensor, NDArray> weights_;
+};
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
+#endif  // TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_
diff --git a/src/contrib/msc/core/transform/layout_utils.cc 
b/src/contrib/msc/core/transform/layout_utils.cc
index ffc631c6d0..3c70e1871b 100644
--- a/src/contrib/msc/core/transform/layout_utils.cc
+++ b/src/contrib/msc/core/transform/layout_utils.cc
@@ -22,6 +22,7 @@
  */
 #include "layout_utils.h"
 
+#include <algorithm>
 #include <set>
 #include <string>
 
@@ -118,12 +119,15 @@ const LayoutDecision LayoutUtils::ExpandLayout(const 
LayoutDecision& src_layout,
   if (!src_layout->layout.defined()) {
     return src_layout;
   }
+  // sort expand axes
+  std::vector<size_t> axes = expand_axes;
+  std::sort(std::begin(axes), std::end(axes));
   std::string new_layout = src_layout.name();
   ICHECK_EQ(new_layout.size(), src_layout->layout.ndim())
       << "Only support normal layout, get " << src_layout->layout;
   std::vector<std::string> priority_dims{"N", "C", "H", "W", "D", "G", "T"};
-  size_t left_size = expand_axes.size();
-  for (const auto& a : expand_axes) {
+  size_t left_size = axes.size();
+  for (const auto& a : axes) {
     std::string target = "U";
     if (new_layout.find("H") && !new_layout.find("W")) {
       target = "W";
diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc 
b/src/contrib/msc/core/transform/set_expr_layout.cc
index 5915bef9e1..a94c846b3f 100644
--- a/src/contrib/msc/core/transform/set_expr_layout.cc
+++ b/src/contrib/msc/core/transform/set_expr_layout.cc
@@ -1118,28 +1118,27 @@ class LayoutInfer : public ExprVisitor {
 
   void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) 
final {
     ExprVisitor::VisitBinding_(binding, val);
-    std::vector<NLayout> input_layout;
-    for (const auto& field : val->fields) {
-      if (binding->var->IsInstance<DataflowVarNode>()) {
-        // Df var: Use the current realized layout to group the tuple;
-        input_layout.push_back(GetNLayout(var_layout_map_, field));
-      } else {
-        // Global var: Use the initial layout to group the tuple;
-        input_layout.push_back(InitialNLayout(field));
-      }
-    }
     if (IsNestedTensor(binding->var)) {
-      var_layout_map_[binding->var] = input_layout;
+      Array<NLayout> input_layouts;
+      for (const auto& field : val->fields) {
+        input_layouts.push_back(InferLayoutDecision(field, var_layout_map_));
+      }
+      var_layout_map_[binding->var] = input_layouts;
+      if (LayoutUtils::SetLayout(GetRef<Tuple>(val), NLayout(input_layouts))) {
+        infered_ = true;
+      }
     }
     RecordExpr(binding->var, GetRef<Tuple>(val));
   }
 
   void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* 
val) final {
     ExprVisitor::VisitBinding_(binding, val);
-    NLayout input_layout = binding->var->IsInstance<DataflowVarNode>()
-                               ? GetNLayout(var_layout_map_, val->tuple)
-                               : InitialNLayout(val->tuple);
-    var_layout_map_[binding->var] = input_layout.NestedArray()[val->index];
+    const auto& out_layout =
+        InferLayoutDecisionAt(GetRef<TupleGetItem>(val)->tuple, 
var_layout_map_, val->index);
+    var_layout_map_[binding->var] = out_layout;
+    if (LayoutUtils::SetLayout(GetRef<TupleGetItem>(val), out_layout)) {
+      infered_ = true;
+    }
     RecordExpr(binding->var, GetRef<TupleGetItem>(val));
   }
 
diff --git a/tests/lint/pylint.sh b/tests/lint/pylint.sh
index ac93a6f15d..7fb6af30a1 100755
--- a/tests/lint/pylint.sh
+++ b/tests/lint/pylint.sh
@@ -45,3 +45,6 @@ python3 -m pylint tests/python/frontend/oneflow/*.py 
--rcfile="$(dirname "$0")"/
 python3 -m pylint tests/python/frontend/tensorflow/test_forward.py 
--rcfile="$(dirname "$0")"/pylintrc
 python3 -m pylint tests/python/frontend/pytorch/test_forward.py 
--rcfile="$(dirname "$0")"/pylintrc
 python3 -m pylint tests/python/frontend/tflite/test_forward.py 
--rcfile="$(dirname "$0")"/pylintrc
+
+# tests/python/contrib/test_msc tests
+python3 -m pylint tests/python/contrib/test_msc/*.py --rcfile="$(dirname 
"$0")"/pylintrc
diff --git a/tests/python/contrib/test_msc/test_graph_build.py 
b/tests/python/contrib/test_msc/test_graph_build.py
new file mode 100644
index 0000000000..6e410e584c
--- /dev/null
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -0,0 +1,2037 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Test graph builder && graph. """
+
+import torch
+from torch import fx
+from torch.nn import Module
+
+import tvm.testing
+from tvm.relax.frontend.torch import from_fx
+from tvm.contrib.msc.core.ir import translate
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+def verify_model(torch_model, input_info, expected):
+    graph_model = fx.symbolic_trace(torch_model)
+    with torch.no_grad():
+        mod = from_fx(graph_model, input_info)
+    graph, _ = translate.from_relax(mod)
+    inspect = graph.inspect()
+    assert msc_utils.dict_equal(inspect, expected), "Inspect {} mismatch with 
expected {}".format(
+        inspect, expected
+    )
+
+
+def test_conv1d():
+    """test graph builder for conv1d"""
+
+    class Conv1D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)
+
+        def forward(self, data):
+            return self.conv(data)
+
+    expected1 = {
+        "inputs": [{"name": "inp_0", "shape": [1, 3, 10], "dtype": "float32", 
"layout": "NCW"}],
+        "outputs": [
+            {"name": "msc.conv1d_bias", "shape": [1, 6, 4], "dtype": 
"float32", "layout": "NCW"}
+        ],
+        "nodes": {"total": 2, "input": 1, "msc.conv1d_bias": 1},
+    }
+
+    class Conv1D2(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)
+
+        def forward(self, data):
+            return self.conv(data)
+
+    expected2 = {
+        "inputs": [{"name": "inp_0", "shape": [1, 3, 10], "dtype": "float32", 
"layout": "NCW"}],
+        "outputs": [{"name": "conv1d", "shape": [1, 6, 4], "dtype": "float32", 
"layout": "NCW"}],
+        "nodes": {"total": 2, "input": 1, "nn.conv1d": 1},
+    }
+
+    input_info = [([1, 3, 10], "float32")]
+    verify_model(Conv1D1(), input_info, expected1)
+    verify_model(Conv1D2(), input_info, expected2)
+
+
+def test_conv2d():
+    """test graph builder for conv2d"""
+
+    class Conv2D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
+
+        def forward(self, data):
+            return self.conv(data)
+
+    expected1 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", 
"layout": "NCHW"}
+        ],
+        "outputs": [
+            {
+                "name": "msc.conv2d_bias",
+                "shape": [1, 6, 4, 4],
+                "dtype": "float32",
+                "layout": "NCHW",
+            }
+        ],
+        "nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1},
+    }
+
+    class Conv2D2(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)
+
+        def forward(self, data):
+            return self.conv(data)
+
+    expected2 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", 
"layout": "NCHW"}
+        ],
+        "outputs": [
+            {"name": "conv2d", "shape": [1, 6, 4, 4], "dtype": "float32", 
"layout": "NCHW"}
+        ],
+        "nodes": {"total": 2, "input": 1, "nn.conv2d": 1},
+    }
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Conv2D1(), input_info, expected1)
+    verify_model(Conv2D2(), input_info, expected2)
+
+
+def test_linear():
+    """test graph builder for linear"""
+
+    class Dense1(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(10, 7, bias=True)
+
+        def forward(self, data):
+            return self.linear(data)
+
+    expected1 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", 
"layout": "NCHW"}
+        ],
+        "outputs": [
+            {
+                "name": "msc.linear_bias",
+                "shape": [1, 3, 10, 7],
+                "dtype": "float32",
+                "layout": "NCHW",
+            }
+        ],
+        "nodes": {"total": 2, "input": 1, "msc.linear_bias": 1},
+    }
+
+    class Dense2(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(10, 7, bias=False)
+
+        def forward(self, data):
+            return self.linear(data)
+
+    expected2 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", 
"layout": "NCHW"}
+        ],
+        "outputs": [
+            {"name": "msc.linear", "shape": [1, 3, 10, 7], "dtype": "float32", 
"layout": "NCHW"}
+        ],
+        "nodes": {"total": 2, "input": 1, "msc.linear": 1},
+    }
+
+    class MatMul1(Module):
+        def forward(self, x, y):
+            return torch.matmul(x, y)
+
+    expected3 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": 
"NC"},
+            {"name": "inp_1", "shape": [10, 10], "dtype": "float32", "layout": 
"IO"},
+        ],
+        "outputs": [{"name": "matmul", "shape": [10, 10], "dtype": "float32", 
"layout": "NC"}],
+        "nodes": {"total": 3, "input": 2, "matmul": 1},
+    }
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Dense1(), input_info, expected1)
+    verify_model(Dense2(), input_info, expected2)
+    verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")], 
expected3)
+
+
+def test_bmm():
+    """test graph builder for bmm"""
+
+    class BMM(Module):
+        def forward(self, x, y):
+            return torch.bmm(x, y)
+
+    expected = {
+        "inputs": [
+            {"name": "inp_0", "shape": [4, 128, 256], "dtype": "float32", 
"layout": "NCD"},
+            {"name": "inp_1", "shape": [4, 256, 512], "dtype": "float32", 
"layout": "NIO"},
+        ],
+        "outputs": [
+            {"name": "matmul", "shape": [4, 128, 512], "dtype": "float32", 
"layout": "NCD"}
+        ],
+        "nodes": {"total": 3, "input": 2, "matmul": 1},
+    }
+
+    input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")]
+    verify_model(BMM(), input_info, expected)
+
+
+def test_baddbmm():
+    """test graph builder for baddbmm"""
+
+    class BAddBMM1(Module):
+        def forward(self, c, x, y):
+            return torch.baddbmm(c, x, y)
+
+    expected1 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [4, 128, 512], "dtype": "float32", 
"layout": "NCD"},
+            {"name": "inp_1", "shape": [4, 128, 256], "dtype": "float32", 
"layout": "NCD"},
+            {"name": "inp_2", "shape": [4, 256, 512], "dtype": "float32", 
"layout": "NIO"},
+        ],
+        "outputs": [{"name": "add", "shape": [4, 128, 512], "dtype": 
"float32", "layout": "NCD"}],
+        "nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1},
+    }
+
+    class BAddBMM2(Module):
+        def forward(self, c, x, y):
+            return torch.baddbmm(c, x, y, alpha=2, beta=0)
+
+    expected2 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [4, 128, 512], "dtype": "float32", 
"layout": ""},
+            {"name": "inp_1", "shape": [4, 128, 256], "dtype": "float32", 
"layout": "NCD"},
+            {"name": "inp_2", "shape": [4, 256, 512], "dtype": "float32", 
"layout": "NIO"},
+        ],
+        "outputs": [
+            {"name": "multiply", "shape": [4, 128, 512], "dtype": "float32", 
"layout": "NCD"}
+        ],
+        "nodes": {"total": 6, "input": 3, "matmul": 1, "constant": 1, 
"multiply": 1},
+    }
+
+    input_info = [
+        ((4, 128, 512), "float32"),
+        ((4, 128, 256), "float32"),
+        ((4, 256, 512), "float32"),
+    ]
+    verify_model(BAddBMM1(), input_info, expected1)
+    verify_model(BAddBMM2(), input_info, expected2)
+
+
+def test_relu():
+    """test graph builder for relu"""
+
+    class ReLU(Module):
+        def __init__(self):
+            super().__init__()
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, data):
+            return self.relu(data)
+
+    class ReLU1(Module):
+        def forward(self, data):
+            return torch.nn.functional.relu(data)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", 
"layout": "AB"}],
+        "outputs": [{"name": "relu", "shape": [10, 10], "dtype": "float32", 
"layout": "AB"}],
+        "nodes": {"total": 2, "input": 1, "nn.relu": 1},
+    }
+
+    input_info = [([10, 10], "float32")]
+    verify_model(ReLU(), input_info, expected)
+    verify_model(ReLU1(), input_info, expected)
+
+
+def test_relu6():
+    """test graph builder for relu6"""
+
+    class ReLU6(Module):
+        def __init__(self):
+            super().__init__()
+            self.relu6 = torch.nn.ReLU6()
+
+        def forward(self, data):
+            return self.relu6(data)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "clip", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "clip": 1},
+    }
+    input_info = [([10, 10], "float32")]
+    verify_model(ReLU6(), input_info, expected)
+
+
def test_maxpool2d():
    """Graph builder check: ``torch.nn.MaxPool2d`` variants map to ``nn.max_pool2d``."""

    class MaxPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    class MaxPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])

        def forward(self, data):
            return self.pool(data)

    class MaxPool2d3(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)

        def forward(self, data):
            return self.pool(data)

    def pooled_graph(out_shape):
        """Expected inspect dict for an NCHW max-pool producing ``out_shape``."""
        return {
            "inputs": [
                {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
            ],
            "outputs": [
                {"name": "max_pool2d", "shape": out_shape, "dtype": "float32", "layout": "NCHW"}
            ],
            "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1},
        }

    input_info = [([1, 3, 10, 10], "float32")]
    # (model class, expected pooled shape); each model is built right before it is verified.
    for model_cls, out_shape in (
        (MaxPool2d, [1, 3, 10, 10]),
        (MaxPool2d2, [1, 3, 4, 4]),
        (MaxPool2d3, [1, 3, 6, 6]),
    ):
        verify_model(model_cls(), input_info, pooled_graph(out_shape))
+
+
def test_avgpool2d():
    """Graph builder check: ``torch.nn.AvgPool2d`` variants map to ``nn.avg_pool2d``."""

    class AvgPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.pool(data)

    class AvgPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True)

        def forward(self, data):
            return self.pool(data)

    def pooled_graph(out_shape):
        """Expected inspect dict for an NCHW average-pool producing ``out_shape``."""
        return {
            "inputs": [
                {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
            ],
            "outputs": [
                {"name": "avg_pool2d", "shape": out_shape, "dtype": "float32", "layout": "NCHW"}
            ],
            "nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1},
        }

    input_info = [([1, 3, 10, 10], "float32")]
    for model_cls, out_shape in ((AvgPool2d, [1, 3, 10, 10]), (AvgPool2d2, [1, 3, 6, 6])):
        verify_model(model_cls(), input_info, pooled_graph(out_shape))
+
+
def test_adaptive_avgpool2d():
    """Graph builder check: ``AdaptiveAvgPool2d`` maps to ``nn.adaptive_avg_pool2d``."""

    class AdaptiveAvgPool2d0(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])

        def forward(self, data):
            return self.pool(data)

    def nchw(name):
        # Output size equals input size here, so both tensors share one description.
        return {"name": name, "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}

    expected = {
        "inputs": [nchw("inp_0")],
        "outputs": [nchw("adaptive_avg_pool2d")],
        "nodes": {"total": 2, "input": 1, "nn.adaptive_avg_pool2d": 1},
    }
    verify_model(AdaptiveAvgPool2d0(), [([1, 3, 10, 10], "float32")], expected)
+
+
def test_flatten():
    """Graph builder check: flattening dims ``2..-1`` is expected to yield one ``reshape``."""

    class Flatten(Module):
        def __init__(self):
            super().__init__()
            self.f = torch.nn.Flatten(2, -1)

        def forward(self, data):
            return self.f(data)

    in_desc = dict(name="inp_0", shape=[1, 3, 10, 10], dtype="float32", layout="")
    out_desc = dict(name="reshape", shape=[1, 3, 100], dtype="float32", layout="")
    expected = dict(
        inputs=[in_desc],
        outputs=[out_desc],
        nodes={"total": 2, "input": 1, "reshape": 1},
    )

    input_info = [([1, 3, 10, 10], "float32")]
    # A wrapping Module and a bare torch.nn.Flatten must produce the same graph.
    for make_model in (Flatten, lambda: torch.nn.Flatten(2, -1)):
        verify_model(make_model(), input_info, expected)
+
+
def test_batchnorm2d():
    """Graph builder check: ``BatchNorm2d`` becomes ``nn.batch_norm`` plus a ``get_item``."""

    class BatchNorm2d(Module):
        def __init__(self):
            super().__init__()
            self.batchnorm = torch.nn.BatchNorm2d(3)

        def forward(self, data):
            return self.batchnorm(data)

    def nchw(name):
        return {"name": name, "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}

    # The expected graph output is named "batch_norm.0" and reached through a get_item node.
    expected = {
        "inputs": [nchw("inp_0")],
        "outputs": [nchw("batch_norm.0")],
        "nodes": {"total": 3, "input": 1, "nn.batch_norm": 1, "get_item": 1},
    }
    verify_model(BatchNorm2d(), [([1, 3, 10, 10], "float32")], expected)
+
+
def test_embedding():
    """Graph builder check: ``torch.nn.Embedding`` maps to ``msc.embedding`` for 1-D and 2-D ids."""

    class Embedding(Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 3)

        def forward(self, data):
            return self.embedding(data)

    def lookup_graph(ids_shape, ids_layout, out_shape, out_layout):
        """Expected inspect dict for one embedding lookup over int64 ids."""
        return {
            "inputs": [
                {"name": "inp_0", "shape": ids_shape, "dtype": "int64", "layout": ids_layout}
            ],
            "outputs": [
                {
                    "name": "msc.embedding",
                    "shape": out_shape,
                    "dtype": "float32",
                    "layout": out_layout,
                }
            ],
            "nodes": {"total": 2, "input": 1, "msc.embedding": 1},
        }

    # (ids shape, ids layout, output shape, output layout)
    cases = (
        ([4], "A", [4, 3], "NA"),
        ([4, 5], "AB", [4, 5, 3], "CNB"),
    )
    for ids_shape, ids_layout, out_shape, out_layout in cases:
        expected = lookup_graph(ids_shape, ids_layout, out_shape, out_layout)
        verify_model(Embedding(), [(list(ids_shape), "int64")], expected)
+
+
def test_dropout():
    """Graph builder check: dropout is expected to vanish, leaving only the input node."""

    class Dropout1(Module):
        def __init__(self):
            super().__init__()
            self.dropout = torch.nn.Dropout(0.5)

        def forward(self, data):
            return self.dropout(data)

    class Dropout2(Module):
        def forward(self, data):
            return torch.dropout(data, 0.5, train=True)

    # The expected graph output is the input tensor "inp_0" itself.
    passthrough = {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}
    expected = {
        "inputs": [dict(passthrough)],
        "outputs": [dict(passthrough)],
        "nodes": {"total": 1, "input": 1},
    }

    input_info = [([1, 3, 10, 10], "float32")]
    for model_cls in (Dropout1, Dropout2):
        verify_model(model_cls(), input_info, expected)
+
+
def test_layernorm():
    """Graph builder check: ``torch.nn.LayerNorm`` maps to ``nn.layer_norm``."""

    class LayerNorm(Module):
        def __init__(self):
            super().__init__()
            self.layernorm = torch.nn.LayerNorm((10, 10))

        def forward(self, data):
            return self.layernorm(data)

    def nchw(name):
        return {"name": name, "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}

    expected = {
        "inputs": [nchw("inp_0")],
        "outputs": [nchw("layer_norm")],
        "nodes": {"total": 2, "input": 1, "nn.layer_norm": 1},
    }
    verify_model(LayerNorm(), [([1, 3, 10, 10], "float32")], expected)
+
+
def test_functional_layernorm():
    """Graph builder check: ``F.layer_norm`` with explicit affine params maps to ``nn.layer_norm``."""

    class LayerNorm(Module):
        def __init__(self, shape):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(shape))
            self.bias = torch.nn.Parameter(torch.zeros(shape))

        def forward(self, data):
            return torch.nn.functional.layer_norm(
                data, self.weight.shape, self.weight, self.bias, 1e-5
            )

    def nchw(name):
        return {"name": name, "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}

    expected = {
        "inputs": [nchw("inp_0")],
        "outputs": [nchw("layer_norm")],
        "nodes": {"total": 2, "input": 1, "nn.layer_norm": 1},
    }
    verify_model(LayerNorm((10, 10)), [([1, 3, 10, 10], "float32")], expected)
+
+
def test_cross_entropy():
    """Graph builder check: ``CrossEntropyLoss`` lowers to ``log_softmax`` + ``nll_loss``."""

    class CrossEntropy1(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class CrossEntropy2(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones((2,)))
            self.loss = torch.nn.CrossEntropyLoss(weight=self.weight)

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class CrossEntropy3(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=1, reduction="sum")

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    def loss_graph(nodes):
        """Expected inspect dict; only the node census differs between variants."""
        return {
            "inputs": [
                {"name": "inp_0", "shape": [3, 2], "dtype": "float32", "layout": ""},
                {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""},
            ],
            "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}],
            "nodes": dict(nodes),
        }

    plain_nodes = {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}
    # The class-weight variant is expected to carry one extra constant node.
    weighted_nodes = {
        "total": 5,
        "input": 2,
        "nn.log_softmax": 1,
        "constant": 1,
        "nn.nll_loss": 1,
    }

    input_info = [([3, 2], "float32"), ([3], "int32")]
    for model_cls, nodes in (
        (CrossEntropy1, plain_nodes),
        (CrossEntropy2, weighted_nodes),
        (CrossEntropy3, plain_nodes),
    ):
        verify_model(model_cls(), input_info, loss_graph(nodes))
+
+
def test_functional_cross_entropy():
    """Graph builder check: ``F.cross_entropy`` lowers to ``log_softmax`` + ``nll_loss``."""

    class CrossEntropy(Module):
        def forward(self, logits, targets):
            return torch.nn.functional.cross_entropy(logits, targets)

    logits_desc = dict(name="inp_0", shape=[3, 10], dtype="float32", layout="")
    targets_desc = dict(name="inp_1", shape=[3], dtype="int32", layout="")
    loss_desc = dict(name="nll_loss", shape=[], dtype="float32", layout="")
    expected = dict(
        inputs=[logits_desc, targets_desc],
        outputs=[loss_desc],
        nodes={"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1},
    )
    verify_model(CrossEntropy(), [([3, 10], "float32"), ([3], "int32")], expected)
+
+
def test_silu():
    """Graph builder check: module and functional SiLU both map to ``nn.silu``."""

    class SiLU(Module):
        def __init__(self):
            super().__init__()
            self.silu = torch.nn.SiLU()

        def forward(self, data):
            return self.silu(data)

    class SiLU2(Module):
        def forward(self, data):
            return torch.nn.functional.silu(data)

    def abcd(name):
        return {"name": name, "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}

    expected = {
        "inputs": [abcd("inp_0")],
        "outputs": [abcd("silu")],
        "nodes": {"total": 2, "input": 1, "nn.silu": 1},
    }
    input_info = [([1, 3, 10, 10], "float32")]
    for model_cls in (SiLU, SiLU2):
        verify_model(model_cls(), input_info, expected)
+
+
def test_groupnorm():
    """Graph builder check: ``torch.nn.GroupNorm`` maps to ``nn.group_norm``."""

    class GroupNorm(Module):
        def __init__(self):
            super().__init__()
            self.groupnorm = torch.nn.GroupNorm(3, 3)

        def forward(self, data):
            return self.groupnorm(data)

    def nchw(name):
        return {"name": name, "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}

    expected = {
        "inputs": [nchw("inp_0")],
        "outputs": [nchw("group_norm")],
        "nodes": {"total": 2, "input": 1, "nn.group_norm": 1},
    }
    verify_model(GroupNorm(), [([1, 3, 10, 10], "float32")], expected)
+
+
def test_softmax():
    """Graph builder check: ``torch.nn.Softmax`` maps to ``nn.softmax``."""

    class Softmax(Module):
        def __init__(self):
            super().__init__()
            self.softmax = torch.nn.Softmax(dim=1)

        def forward(self, data):
            return self.softmax(data)

    def abcd(name):
        return {"name": name, "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}

    expected = {
        "inputs": [abcd("inp_0")],
        "outputs": [abcd("softmax")],
        "nodes": {"total": 2, "input": 1, "nn.softmax": 1},
    }
    verify_model(Softmax(), [([1, 3, 10, 10], "float32")], expected)
+
+
def test_binary():
    """Graph builder check: elementwise binary ops, tensor-tensor and tensor-scalar forms."""

    input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
    input_info2 = [([1, 3, 10, 10], "float32")]

    def tensor(name, dtype="float32"):
        return {"name": name, "shape": [1, 3, 10, 10], "dtype": dtype, "layout": "ABCD"}

    def binary_graph(op_name, two_inputs, out_dtype):
        """Expected inspect dict for one binary op.

        With ``two_inputs`` both operands are graph inputs; otherwise the scalar
        operand is expected to appear as a single ``constant`` node.
        """
        if two_inputs:
            inputs = [tensor("inp_0"), tensor("inp_1")]
            nodes = {"total": 3, "input": 2, op_name: 1}
        else:
            inputs = [tensor("inp_0")]
            nodes = {"total": 3, "input": 1, "constant": 1, op_name: 1}
        return {"inputs": inputs, "outputs": [tensor(op_name, out_dtype)], "nodes": nodes}

    # Add
    class Add1(Module):
        def forward(self, lhs, rhs):
            return lhs + rhs

    class Add2(Module):
        def forward(self, lhs):
            return lhs + 1.0

    # Sub
    class Sub1(Module):
        def forward(self, lhs, rhs):
            return lhs - rhs

    class Sub2(Module):
        def forward(self, lhs):
            return lhs - 1.0

    # Mul
    class Mul1(Module):
        def forward(self, lhs, rhs):
            return lhs * rhs

    class Mul2(Module):
        def forward(self, lhs):
            return lhs * 1.0

    # True div
    class TrueDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs / rhs

    class TrueDiv2(Module):
        def forward(self, lhs):
            return lhs / 1.0

    # Floor div
    class FloorDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs // rhs

    class FloorDiv2(Module):
        def forward(self, lhs):
            return lhs // 1.0

    # Power
    class Power1(Module):
        def forward(self, lhs, rhs):
            return lhs**rhs

    class Power2(Module):
        def forward(self, lhs):
            return lhs**1.0

    # LT
    class LT1(Module):
        def forward(self, lhs, rhs):
            return lhs < rhs

    class LT2(Module):
        def forward(self, lhs):
            return lhs < 1.0

    # (tensor-tensor model, tensor-scalar model, op/output name, output dtype)
    cases = [
        (Add1, Add2, "add", "float32"),
        (Sub1, Sub2, "subtract", "float32"),
        (Mul1, Mul2, "multiply", "float32"),
        (TrueDiv1, TrueDiv2, "divide", "float32"),
        (FloorDiv1, FloorDiv2, "floor_divide", "float32"),
        (Power1, Power2, "power", "float32"),
        (LT1, LT2, "less", "bool"),
    ]
    for pair_cls, scalar_cls, op_name, out_dtype in cases:
        verify_model(pair_cls(), input_info1, binary_graph(op_name, True, out_dtype))
        verify_model(scalar_cls(), input_info2, binary_graph(op_name, False, out_dtype))
+
+
def test_size():
    """Graph builder check: ``Tensor.size()`` is expected to become a ``shape`` node."""

    class Size(Module):
        def forward(self, data):
            return data.size()

    in_desc = dict(name="inp_0", shape=[1, 3, 10, 10], dtype="float32", layout="")
    out_desc = dict(name="shape", shape=[4], dtype="int32", layout="O")
    expected = dict(
        inputs=[in_desc],
        outputs=[out_desc],
        nodes={"total": 2, "input": 1, "shape": 1},
    )
    verify_model(Size(), [([1, 3, 10, 10], "float32")], expected)
+
+
def test_squeeze():
    """Graph builder check: single-axis and all-axes ``squeeze`` map to one ``squeeze`` node."""

    class Squeeze1(Module):
        def forward(self, data):
            return data.squeeze(1)

    class Squeeze2(Module):
        def forward(self, data):
            return data.squeeze()

    def squeeze_graph(out_shape, out_layout):
        """Expected inspect dict for squeezing the [3, 1, 4, 1] input."""
        return {
            "inputs": [
                {"name": "inp_0", "shape": [3, 1, 4, 1], "dtype": "float32", "layout": "ANBC"}
            ],
            "outputs": [
                {"name": "squeeze", "shape": out_shape, "dtype": "float32", "layout": out_layout}
            ],
            "nodes": {"total": 2, "input": 1, "squeeze": 1},
        }

    input_info = [([3, 1, 4, 1], "float32")]
    verify_model(Squeeze1(), input_info, squeeze_graph([3, 4, 1], "ABC"))
    verify_model(Squeeze2(), input_info, squeeze_graph([3, 4], "AB"))
+
+
def test_unsqueeze():
    """Graph builder check: ``unsqueeze`` at front/back positions maps to ``expand_dims``."""

    class Unsqueeze1(Module):
        def forward(self, data):
            return data.unsqueeze(1)

    class Unsqueeze2(Module):
        def forward(self, data):
            return data.unsqueeze(-1)

    def expand_graph(in_layout, out_shape):
        """Expected inspect dict; the 5-D output layout is always "ABCDE"."""
        return {
            "inputs": [
                {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": in_layout}
            ],
            "outputs": [
                {
                    "name": "expand_dims",
                    "shape": out_shape,
                    "dtype": "float32",
                    "layout": "ABCDE",
                }
            ],
            "nodes": {"total": 2, "input": 1, "expand_dims": 1},
        }

    input_info = [([1, 3, 10, 10], "float32")]
    verify_model(Unsqueeze1(), input_info, expand_graph("ACDE", [1, 1, 3, 10, 10]))
    verify_model(Unsqueeze2(), input_info, expand_graph("ABCE", [1, 3, 10, 10, 1]))
+
+
def test_getattr():
    """Graph builder check: reading ``Tensor.shape`` is expected to become a ``shape`` node."""

    class GetAttr1(Module):
        def forward(self, data):
            return data.shape

    in_desc = dict(name="inp_0", shape=[1, 3, 10, 10], dtype="float32", layout="")
    out_desc = dict(name="shape", shape=[4], dtype="int32", layout="O")
    expected = dict(
        inputs=[in_desc],
        outputs=[out_desc],
        nodes={"total": 2, "input": 1, "shape": 1},
    )
    verify_model(GetAttr1(), [([1, 3, 10, 10], "float32")], expected)
+
+
def test_getitem():
    """Graph builder check: tensor indexing lowers to ``strided_slice`` + ``reshape``."""

    class Slice1(Module):
        def forward(self, x):
            return x[0, 1::2, :, :3]

    class Slice2(Module):
        def forward(self, x):
            return x[:, None, None, :, None]

    def slice_graph(in_shape, in_layout, out_shape, out_layout):
        """Expected inspect dict for one indexing expression."""
        return {
            "inputs": [
                {"name": "inp_0", "shape": in_shape, "dtype": "float32", "layout": in_layout}
            ],
            "outputs": [
                {"name": "reshape", "shape": out_shape, "dtype": "float32", "layout": out_layout}
            ],
            "nodes": {"total": 3, "input": 1, "strided_slice": 1, "reshape": 1},
        }

    # (model class, input shape, input layout, output shape, output layout)
    cases = (
        (Slice1, [1, 3, 10, 10], "ABCD", [1, 1, 10, 3], "ABCD"),
        (Slice2, [8, 16], "AB", [8, 1, 1, 16, 1], "ANCHB"),
    )
    for model_cls, in_shape, in_layout, out_shape, out_layout in cases:
        expected = slice_graph(in_shape, in_layout, out_shape, out_layout)
        verify_model(model_cls(), [(list(in_shape), "float32")], expected)
+
+
def test_unary():
    """Graph builder check: elementwise unary ops each map to a same-named single node."""

    input_info = [([1, 3, 10, 10], "float32")]

    class Sin(Module):
        def forward(self, data):
            return torch.sin(data)

    class Cos(Module):
        def forward(self, data):
            return torch.cos(data)

    class Exp(Module):
        def forward(self, data):
            return torch.exp(data)

    class Sqrt(Module):
        def forward(self, data):
            return torch.sqrt(data)

    class Sigmoid(Module):
        def forward(self, data):
            return torch.sigmoid(data)

    class Round(Module):
        def forward(self, data):
            return torch.round(data)

    def abcd(name):
        return {"name": name, "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}

    # For every unary op the output tensor and the node type share one name.
    for model_cls, op_name in (
        (Sin, "sin"),
        (Cos, "cos"),
        (Exp, "exp"),
        (Sqrt, "sqrt"),
        (Sigmoid, "sigmoid"),
        (Round, "round"),
    ):
        expected = {
            "inputs": [abcd("inp_0")],
            "outputs": [abcd(op_name)],
            "nodes": {"total": 2, "input": 1, op_name: 1},
        }
        verify_model(model_cls(), input_info, expected)
+
+
def test_gelu():
    """Graph builder check: functional GELU maps to ``nn.gelu``."""

    class Gelu(Module):
        def forward(self, data):
            return torch.nn.functional.gelu(data)

    def abcd(name):
        return {"name": name, "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}

    expected = {
        "inputs": [abcd("inp_0")],
        "outputs": [abcd("gelu")],
        "nodes": {"total": 2, "input": 1, "nn.gelu": 1},
    }
    verify_model(Gelu(), [([1, 3, 10, 10], "float32")], expected)
+
+
def test_tanh():
    """Graph builder check: ``torch.tanh`` maps to a single ``tanh`` node."""

    class Tanh(Module):
        def forward(self, data):
            return torch.tanh(data)

    def abcd(name):
        return {"name": name, "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}

    expected = {
        "inputs": [abcd("inp_0")],
        "outputs": [abcd("tanh")],
        "nodes": {"total": 2, "input": 1, "tanh": 1},
    }
    verify_model(Tanh(), [([1, 3, 10, 10], "float32")], expected)
+
+
def test_clamp():
    """Graph builder check: ``torch.clamp`` with min/max maps to a single ``clip`` node."""

    class Clamp(Module):
        def forward(self, data):
            return torch.clamp(data, min=0.1, max=0.5)

    # Neither side is expected to carry a layout around clip.
    in_desc = dict(name="inp_0", shape=[1, 3, 10, 10], dtype="float32", layout="")
    out_desc = dict(name="clip", shape=[1, 3, 10, 10], dtype="float32", layout="")
    expected = dict(
        inputs=[in_desc],
        outputs=[out_desc],
        nodes={"total": 2, "input": 1, "clip": 1},
    )
    verify_model(Clamp(), [([1, 3, 10, 10], "float32")], expected)
+
+
def test_interpolate():
    """Graph builder check: ``F.interpolate`` to a fixed size maps to ``image.resize2d``."""

    class Interpolate(Module):
        def forward(self, data):
            return torch.nn.functional.interpolate(data, (5, 5))

    def abcd(name, shape):
        return {"name": name, "shape": shape, "dtype": "float32", "layout": "ABCD"}

    expected = {
        "inputs": [abcd("inp_0", [1, 3, 10, 10])],
        "outputs": [abcd("resize2d", [1, 3, 5, 5])],
        "nodes": {"total": 2, "input": 1, "image.resize2d": 1},
    }
    verify_model(Interpolate(), [([1, 3, 10, 10], "float32")], expected)
+
+
+def test_addmm():
+    """test graph builder for addmm"""
+
+    class Addmm(Module):
+        def forward(self, x_1, x_2, x_3):
+            return torch.addmm(x_1, x_2, x_3)
+
+    expected = {
+        "inputs": [
+            {"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": 
"NC"},
+            {"name": "inp_1", "shape": [10, 10], "dtype": "float32", "layout": 
"NC"},
+            {"name": "inp_2", "shape": [10, 10], "dtype": "float32", "layout": 
"IO"},
+        ],
+        "outputs": [{"name": "add", "shape": [10, 10], "dtype": "float32", 
"layout": "NC"}],
+        "nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1},
+    }
+
+    input_info = [
+        ([10, 10], "float32"),
+        ([10, 10], "float32"),
+        ([10, 10], "float32"),
+    ]
+    verify_model(Addmm(), input_info, expected)
+
+
+def test_split():
+    """test graph builder for split"""
+
+    class Split(Module):
+        def forward(self, data):
+            return torch.split(data, 1, dim=1)
+
+    expected = {
+        "inputs": [
+            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", 
"layout": "ABCD"}
+        ],
+        "outputs": [
+            {"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", 
"layout": "ABCD"},
+            {"name": "split_1", "shape": [1, 1, 10, 10], "dtype": "float32", 
"layout": "ABCD"},
+            {"name": "split_2", "shape": [1, 1, 10, 10], "dtype": "float32", 
"layout": "ABCD"},
+        ],
+        "nodes": {"total": 2, "input": 1, "split": 1},
+    }
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Split(), input_info, expected)
+
+
+def test_cumsum():
+    """test graph builder for cumsum"""
+
+    class Cumsum(Module):
+        def forward(self, data):
+            return torch.cumsum(data, dim=1, dtype=torch.int32)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": ""}],
+        "outputs": [{"name": "cumsum", "shape": [1, 2, 3, 4], "dtype": 
"int32", "layout": ""}],
+        "nodes": {"total": 2, "input": 1, "cumsum": 1},
+    }
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(Cumsum(), input_info, expected)
+
+
+def test_chunk():
+    """test graph builder for chunk"""
+
+    class Chunk(Module):
+        def forward(self, data):
+            return torch.chunk(data, 3, dim=1)
+
+    expected = {
+        "inputs": [
+            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", 
"layout": "ABCD"}
+        ],
+        "outputs": [
+            {"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", 
"layout": "ABCD"},
+            {"name": "split_1", "shape": [1, 1, 10, 10], "dtype": "float32", 
"layout": "ABCD"},
+            {"name": "split_2", "shape": [1, 1, 10, 10], "dtype": "float32", 
"layout": "ABCD"},
+        ],
+        "nodes": {"total": 2, "input": 1, "split": 1},
+    }
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Chunk(), input_info, expected)
+
+
+def test_inplace_fill():
+    """test graph builder for inplace_fill"""
+
+    class InplaceFill(Module):
+        def forward(self, data):
+            data.fill_(1.5)
+            return data
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "full", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "nodes": {"total": 3, "input": 1, "constant": 1, "full": 1},
+    }
+
+    verify_model(InplaceFill(), [([10, 10], "float32")], expected)
+
+
+def test_arange():
+    """test graph builder for arange"""
+
+    class Arange(Module):
+        def forward(self):
+            return torch.arange(0, 20, dtype=torch.int32)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "const", "shape": [20], "dtype": "int32", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "constant": 1},
+    }
+
+    verify_model(Arange(), [([10, 10], "float32")], expected)
+
+
+def test_empty():
+    """test graph builder for empty"""
+
+    class Empty(Module):
+        def forward(self):
+            return torch.empty((10, 10), dtype=torch.float32)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "const", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "constant": 1},
+    }
+
+    verify_model(Empty(), [([10, 10], "float32")], expected)
+
+
+def test_tensor():
+    """test graph builder for tensor"""
+
+    class Empty1(Module):
+        def forward(self):
+            return torch.tensor(3, dtype=torch.float32)
+
+    expected1 = {
+        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "const", "shape": [], "dtype": "float32", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "constant": 1},
+    }
+
+    class Empty2(Module):
+        def forward(self):
+            return torch.tensor(3)
+
+    expected2 = {
+        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "const", "shape": [], "dtype": "int64", "layout": 
""}],
+        "nodes": {"total": 2, "input": 1, "constant": 1},
+    }
+
+    verify_model(Empty1(), [([10, 10], "float32")], expected1)
+    verify_model(Empty2(), [([10, 10], "float32")], expected2)
+
+
+def test_tril():
+    """test graph builder for tril"""
+
+    class Tril(Module):
+        def forward(self, data):
+            return torch.tril(data, 1)
+
+    class InplaceTril(Module):
+        def forward(self, data):
+            data.tril_(1)
+            return data
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "tril", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "tril": 1},
+    }
+
+    input_info = [([10, 10], "float32")]
+    verify_model(Tril(), input_info, expected)
+    verify_model(InplaceTril(), input_info, expected)
+
+
+def test_triu():
+    """test graph builder for triu"""
+
+    class Triu(Module):
+        def forward(self, data):
+            return torch.triu(data, 1)
+
+    class InplaceTriu(Module):
+        def forward(self, data):
+            data.triu_(1)
+            return data
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "triu", "shape": [10, 10], "dtype": "float32", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "triu": 1},
+    }
+
+    input_info = [([10, 10], "float32")]
+    verify_model(Triu(), input_info, expected)
+    verify_model(InplaceTriu(), input_info, expected)
+
+
+def test_new_ones():
+    """test graph builder for new_ones"""
+
+    class NewOnes(Module):
+        def forward(self, x):
+            return x.new_ones(1, 2, 3)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "full", "shape": [1, 2, 3], "dtype": "float32", 
"layout": ""}],
+        "nodes": {"total": 3, "input": 1, "constant": 1, "full": 1},
+    }
+
+    input_info = [([1, 2, 3], "float32")]
+    verify_model(NewOnes(), input_info, expected)
+
+
+def test_expand():
+    """test graph builder for expand"""
+
+    class Expand(Module):
+        def forward(self, x):
+            return x.expand(4, 2, 3, 4)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": ""}],
+        "outputs": [
+            {"name": "broadcast_to", "shape": [4, 2, 3, 4], "dtype": 
"float32", "layout": ""}
+        ],
+        "nodes": {"total": 2, "input": 1, "broadcast_to": 1},
+    }
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(Expand(), input_info, expected)
+
+
+def test_reduce():
+    """test graph builder for reduce"""
+
+    # sum
+    class Sum(Module):
+        def forward(self, x):
+            return torch.sum(x, (2, 1))
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": "ANCB"}],
+        "outputs": [{"name": "sum", "shape": [1, 4], "dtype": "float32", 
"layout": "AB"}],
+        "nodes": {"total": 2, "input": 1, "sum": 1},
+    }
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(Sum(), input_info, expected)
+
+
+def test_datatype():
+    """test graph builder for datatype"""
+
+    input_info = [([1, 2, 3, 4], "float32")]
+
+    # float
+    class ToFloat(Module):
+        def forward(self, x):
+            return x.float()
+
+    expected1 = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": "ABCD"}],
+        "outputs": [
+            {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", 
"layout": "ABCD"}
+        ],
+        "nodes": {"total": 2, "input": 1, "astype": 1},
+    }
+
+    verify_model(ToFloat(), input_info, expected1)
+
+    # half
+    class ToHalf(Module):
+        def forward(self, x):
+            return x.half()
+
+    expected2 = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": "ABCD"}],
+        "outputs": [
+            {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float16", 
"layout": "ABCD"}
+        ],
+        "nodes": {"total": 2, "input": 1, "astype": 1},
+    }
+
+    verify_model(ToHalf(), input_info, expected2)
+
+    # type
+    class Type(Module):
+        def forward(self, x):
+            return x.type(torch.float32)
+
+    expected3 = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": "ABCD"}],
+        "outputs": [
+            {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", 
"layout": "ABCD"}
+        ],
+        "nodes": {"total": 2, "input": 1, "astype": 1},
+    }
+
+    # type
+    class TypeFromAttr(Module):
+        def forward(self, x):
+            return x.type(x.getattr("dtype"))
+
+    expected4 = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": "ABCD"}],
+        "outputs": [
+            {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", 
"layout": "ABCD"}
+        ],
+        "nodes": {"total": 2, "input": 1, "astype": 1},
+    }
+
+    # astype
+    class AsType(Module):
+        def forward(self, x):
+            return x.astype(torch.float32)
+
+    expected5 = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": "ABCD"}],
+        "outputs": [
+            {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", 
"layout": "ABCD"}
+        ],
+        "nodes": {"total": 2, "input": 1, "astype": 1},
+    }
+
+    verify_model(Type(), input_info, expected3)
+    verify_model(TypeFromAttr(), input_info, expected4)
+    verify_model(AsType(), input_info, expected5)
+
+
+def test_permute():
+    """test graph builder for permute"""
+
+    class Permute(Module):
+        def forward(self, x):
+            return x.permute(0, 3, 2, 1)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": "ADCB"}],
+        "outputs": [
+            {"name": "permute_dims", "shape": [1, 4, 3, 2], "dtype": 
"float32", "layout": "ABCD"}
+        ],
+        "nodes": {"total": 2, "input": 1, "permute_dims": 1},
+    }
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(Permute(), input_info, expected)
+
+
+def test_reshape():
+    """test graph builder for reshape"""
+
+    class Reshape(Module):
+        def forward(self, x):
+            return x.reshape(2, 12)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": ""}],
+        "outputs": [{"name": "reshape", "shape": [2, 12], "dtype": "float32", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "reshape": 1},
+    }
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(Reshape(), input_info, expected)
+
+
+def test_transpose():
+    """test graph builder for transpose"""
+
+    class Transpose(Module):
+        def forward(self, x):
+            return x.transpose(1, 3)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": "ADCB"}],
+        "outputs": [
+            {"name": "permute_dims", "shape": [1, 4, 3, 2], "dtype": 
"float32", "layout": "ABCD"}
+        ],
+        "nodes": {"total": 2, "input": 1, "permute_dims": 1},
+    }
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(Transpose(), input_info, expected)
+
+
+def test_view():
+    """test graph builder for view"""
+
+    class View(Module):
+        def forward(self, x):
+            return x.view(2, 12)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": 
"float32", "layout": ""}],
+        "outputs": [{"name": "reshape", "shape": [2, 12], "dtype": "float32", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "reshape": 1},
+    }
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(View(), input_info, expected)
+
+
+def test_keep_params():
+    """test graph builder for keep_params"""
+
+    class Conv2D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
+
+        def forward(self, data):
+            return self.conv(data)
+
+    expected = {
+        "inputs": [
+            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", 
"layout": "NCHW"}
+        ],
+        "outputs": [
+            {
+                "name": "msc.conv2d_bias",
+                "shape": [1, 6, 4, 4],
+                "dtype": "float32",
+                "layout": "NCHW",
+            }
+        ],
+        "nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1},
+    }
+
+    verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")], expected)
+
+
+def test_unwrap_unit_return_tuple():
+    """test graph builder for unwrap_unit_return_tuple"""
+
+    class Identity(Module):
+        def forward(self, x):
+            return (x,)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "tuple", "shape": [256, 256], "dtype": "float32", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "tuple": 1},
+    }
+
+    verify_model(Identity(), [([256, 256], "float32")], expected)
+
+
+def test_no_bind_return_tuple():
+    """test graph builder for no_bind_return_tuple"""
+
+    class Identity(Module):
+        def forward(self, x, y):
+            return (x, y)
+
+    expected = {
+        "inputs": [
+            {"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": ""},
+            {"name": "inp_1", "shape": [256, 256], "dtype": "float32", 
"layout": ""},
+        ],
+        "outputs": [
+            {"name": "tuple_0", "shape": [256, 256], "dtype": "float32", 
"layout": ""},
+            {"name": "tuple_1", "shape": [256, 256], "dtype": "float32", 
"layout": ""},
+        ],
+        "nodes": {"total": 3, "input": 2, "tuple": 1},
+    }
+
+    input_info = [([256, 256], "float32"), ([256, 256], "float32")]
+    verify_model(Identity(), input_info, expected)
+
+
+def test_argmax():
+    """test graph builder for argmax"""
+
+    class Argmax1(Module):
+        def forward(self, data):
+            return torch.argmax(data, dim=-1)
+
+    expected1 = {
+        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "argmax", "shape": [256], "dtype": "int64", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "argmax": 1},
+    }
+
+    class Argmax2(Module):
+        def forward(self, data):
+            return torch.argmax(data, dim=-1, keepdim=True)
+
+    expected2 = {
+        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "argmax", "shape": [256, 1], "dtype": "int64", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "argmax": 1},
+    }
+
+    verify_model(Argmax1(), [([256, 256], "float32")], expected1)
+    verify_model(Argmax2(), [([256, 256], "float32")], expected2)
+
+
+def test_argmin():
+    """test graph builder for argmin"""
+
+    class Argmin1(Module):
+        def forward(self, data):
+            return torch.argmin(data)
+
+    expected1 = {
+        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "argmin", "shape": [], "dtype": "int64", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "argmin": 1},
+    }
+
+    class Argmin2(Module):
+        def forward(self, data):
+            return torch.argmin(data, keepdim=True)
+
+    expected2 = {
+        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "argmin", "shape": [1, 1], "dtype": "int64", 
"layout": ""}],
+        "nodes": {"total": 2, "input": 1, "argmin": 1},
+    }
+
+    verify_model(Argmin1(), [([256, 256], "float32")], expected1)
+    verify_model(Argmin2(), [([256, 256], "float32")], expected2)
+
+
+def test_to():
+    """test graph builder for to"""
+
+    class To1(Module):
+        def forward(self, data):
+            return data.to(torch.float16)
+
+    expected1 = {
+        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": "AB"}],
+        "outputs": [{"name": "astype", "shape": [256, 256], "dtype": 
"float16", "layout": "AB"}],
+        "nodes": {"total": 2, "input": 1, "astype": 1},
+    }
+
+    class To2(Module):
+        def forward(self, data):
+            return data.to("cpu")
+
+    expected2 = {
+        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": ""}],
+        "outputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": ""}],
+        "nodes": {"total": 1, "input": 1},
+    }
+
+    verify_model(To1(), [([256, 256], "float32")], expected1)
+    verify_model(To2(), [([256, 256], "float32")], expected2)
+
+
+def test_mean():
+    """test graph builder for mean"""
+
+    class Mean(Module):
+        def forward(self, data):
+            return data.mean(-1)
+
+    expected1 = {
+        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": "AN"}],
+        "outputs": [{"name": "mean", "shape": [256], "dtype": "float32", 
"layout": "A"}],
+        "nodes": {"total": 2, "input": 1, "mean": 1},
+    }
+
+    class MeanKeepDim(Module):
+        def forward(self, data):
+            return data.mean(-1, keepdim=True)
+
+    expected2 = {
+        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": "AB"}],
+        "outputs": [{"name": "mean", "shape": [256, 1], "dtype": "float32", 
"layout": "AB"}],
+        "nodes": {"total": 2, "input": 1, "mean": 1},
+    }
+
+    verify_model(Mean(), [([256, 256], "float32")], expected1)
+    verify_model(MeanKeepDim(), [([256, 256], "float32")], expected2)
+
+
+def test_rsqrt():
+    """test graph builder for rsqrt"""
+
+    class Rsqrt(Module):
+        def forward(self, data):
+            return torch.rsqrt(data)
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": "AB"}],
+        "outputs": [{"name": "rsqrt", "shape": [256, 256], "dtype": "float32", 
"layout": "AB"}],
+        "nodes": {"total": 2, "input": 1, "rsqrt": 1},
+    }
+
+    verify_model(Rsqrt(), [([256, 256], "float32")], expected)
+
+
+def test_neg():
+    """test graph builder for neg"""
+
+    class Neg(Module):
+        def forward(self, data):
+            return -data
+
+    expected = {
+        "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": "AB"}],
+        "outputs": [{"name": "negative", "shape": [256, 256], "dtype": 
"float32", "layout": "AB"}],
+        "nodes": {"total": 2, "input": 1, "negative": 1},
+    }
+
+    verify_model(Neg(), [([256, 256], "float32")], expected)
+
+
+def test_max():
+    """test graph builder for max"""
+
+    class Max(Module):
+        def forward(self, x, y):
+            return torch.max(x, y)
+
+    expected = {
+        "inputs": [
+            {"name": "inp_0", "shape": [256, 256], "dtype": "float32", 
"layout": "AB"},
+            {"name": "inp_1", "shape": [256, 256], "dtype": "float32", 
"layout": "AB"},
+        ],
+        "outputs": [{"name": "maximum", "shape": [256, 256], "dtype": 
"float32", "layout": "AB"}],
+        "nodes": {"total": 3, "input": 2, "maximum": 1},
+    }
+
+    verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], 
expected)
+
+
+def test_attention():
+    """test graph builder for attention"""
+
+    # pylint: disable=import-outside-toplevel
+    import torch.nn.functional as F
+
+    class Attention1(Module):
+        def forward(self, q_data, k_data, v_data):
+            return F.scaled_dot_product_attention(q_data, k_data, v_data)
+
+    class Attention2(Module):
+        def forward(self, q_data, k_data, v_data):
+            return F.scaled_dot_product_attention(q_data, k_data, v_data, 
is_causal=True)
+
+    expected1 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [32, 8, 128, 64], "dtype": "float32", 
"layout": "ACBD"},
+            {"name": "inp_1", "shape": [32, 8, 128, 64], "dtype": "float32", 
"layout": "ACBD"},
+            {"name": "inp_2", "shape": [32, 8, 128, 64], "dtype": "float32", 
"layout": "ACBD"},
+        ],
+        "outputs": [
+            {
+                "name": "msc.attention",
+                "shape": [32, 128, 8, 64],
+                "dtype": "float32",
+                "layout": "ABCD",
+            }
+        ],
+        "nodes": {"total": 4, "input": 3, "msc.attention": 1},
+    }
+
+    input_info = [
+        ([32, 8, 128, 64], "float32"),
+        ([32, 8, 128, 64], "float32"),
+        ([32, 8, 128, 64], "float32"),
+    ]
+    verify_model(Attention1(), input_info, expected1)
+    verify_model(Attention2(), input_info, expected1)
+
+    class Attention3(Module):
+        def forward(self, q_data, k_data, v_data, mask):
+            return F.scaled_dot_product_attention(q_data, k_data, v_data, mask)
+
+    expected2 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [32, 8, 128, 64], "dtype": "float32", 
"layout": "ACBD"},
+            {"name": "inp_1", "shape": [32, 8, 128, 64], "dtype": "float32", 
"layout": "ACBD"},
+            {"name": "inp_2", "shape": [32, 8, 128, 64], "dtype": "float32", 
"layout": "ACBD"},
+            {"name": "inp_3", "shape": [32, 8, 128, 128], "dtype": "float32", 
"layout": "ABCD"},
+        ],
+        "outputs": [
+            {
+                "name": "msc.attention",
+                "shape": [32, 128, 8, 64],
+                "dtype": "float32",
+                "layout": "ABCD",
+            }
+        ],
+        "nodes": {"total": 5, "input": 4, "msc.attention": 1},
+    }
+
+    verify_model(
+        Attention3(),
+        [
+            ([32, 8, 128, 64], "float32"),
+            ([32, 8, 128, 64], "float32"),
+            ([32, 8, 128, 64], "float32"),
+            ([32, 8, 128, 128], "float32"),
+        ],
+        expected2,
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_layout.py 
b/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
index 4717437d76..34cd3a3214 100644
--- a/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
+++ b/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
@@ -15,11 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import tvm.testing
-from tvm.relay import testing
-from tvm.relay.expr_functor import ExprVisitor
-from tvm.relay.build_module import bind_params_by_name
+""" Test SetExprLayout Pass. """
 
+import tvm.testing
 from tvm.relax.frontend.torch import from_fx
 from tvm.relax import PyExprVisitor
 
@@ -52,11 +50,14 @@ class RelaxChecker(PyExprVisitor):
 
 
 def test_relax():
+    """Test SetExprLayout for relax"""
+
+    # pylint: disable=import-outside-toplevel
     try:
         import torch
         import torchvision
         from torch import fx
-    except:
+    except:  # pylint: disable=bare-except
         print("please install pytorch python package")
         return
 
diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_name.py 
b/tests/python/contrib/test_msc/test_transform_set_expr_name.py
index 0c174ff7bd..426860145c 100644
--- a/tests/python/contrib/test_msc/test_transform_set_expr_name.py
+++ b/tests/python/contrib/test_msc/test_transform_set_expr_name.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+""" Test SetExprName Pass. """
+
 import tvm.testing
 from tvm.relay import testing
 from tvm.relay.expr_functor import ExprVisitor
@@ -73,6 +75,8 @@ class RelaxChecker(PyExprVisitor):
 
 
 def test_relay():
+    """Test SetExprName for relay"""
+
     mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, 
dtype="float32")
     mod["main"] = bind_params_by_name(mod["main"], params)
     mod = msc_transform.SetExprName(as_relax=False)(mod)
@@ -80,11 +84,14 @@ def test_relay():
 
 
 def test_relax():
+    """Test SetExprName for relax"""
+
+    # pylint: disable=import-outside-toplevel
     try:
         import torch
         import torchvision
         from torch import fx
-    except:
+    except:  # pylint: disable=bare-except
         print("please install pytorch python package")
         return
 

Reply via email to