This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new aa21782828 [Unity][MSC][M0.1] Enable set name and layout for exprs 
(#15500)
aa21782828 is described below

commit aa2178282811d34df3ca3d9ee58580396fd0358d
Author: Archermmt <[email protected]>
AuthorDate: Mon Aug 14 19:39:53 2023 +0800

    [Unity][MSC][M0.1] Enable set name and layout for exprs (#15500)
    
    * add test for set expr name
    
    * roll back to M0.1
    
    * add annotation
    
    * add annotation
    
    * change test to unity
    
    * remove msg
    
    * minor fix
    
    * move test to task_python_relax
---
 CMakeLists.txt                                     |    2 +
 cmake/config.cmake                                 |    3 +
 cmake/modules/LibInfo.cmake                        |    1 +
 .../modules/contrib/MSC.cmake                      |   26 +-
 .../tvm/contrib/msc/__init__.py                    |   23 +-
 .../tvm/contrib/msc/core/__init__.py               |   23 +-
 .../tvm/contrib/msc/core/_ffi_api.py               |   23 +-
 .../tvm/contrib/msc/core/transform/__init__.py     |   24 +-
 python/tvm/contrib/msc/core/transform/pattern.py   |  626 ++++++++++
 python/tvm/contrib/msc/core/transform/transform.py |   61 +
 .../tvm/contrib/msc/core/utils/__init__.py         |   23 +-
 python/tvm/contrib/msc/core/utils/expr.py          |  105 ++
 src/contrib/msc/core/transform/layout_utils.cc     |  190 +++
 src/contrib/msc/core/transform/layout_utils.h      |  110 ++
 src/contrib/msc/core/transform/set_expr_layout.cc  | 1215 ++++++++++++++++++++
 src/contrib/msc/core/transform/set_expr_name.cc    |  348 ++++++
 src/contrib/msc/core/utils.cc                      |  314 +++++
 src/contrib/msc/core/utils.h                       |  270 +++++
 src/support/libinfo.cc                             |    1 +
 .../test_msc/test_transform_set_expr_layout.py     |   73 ++
 .../test_msc/test_transform_set_expr_name.py       |  101 ++
 tests/scripts/task_config_build_cpu.sh             |    1 +
 tests/scripts/task_config_build_gpu.sh             |    1 +
 tests/scripts/unity/task_python_relax.sh           |    3 +
 24 files changed, 3442 insertions(+), 125 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 47d57d56bd..f7c34fa22b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -121,6 +121,7 @@ tvm_option(USE_CLML "Build with CLML Codegen support" OFF)
 tvm_option(USE_CLML_GRAPH_EXECUTOR "Build with CLML graph runtime" OFF)
 tvm_option(USE_UMA "Build with UMA support" OFF)
 tvm_option(USE_VERILATOR "Build with Verilator support" OFF)
+tvm_option(USE_MSC "Enable Multi-System Compiler" OFF)
 
 # include directories
 include_directories(${CMAKE_INCLUDE_PATH})
@@ -545,6 +546,7 @@ include(cmake/modules/contrib/TensorRT.cmake)
 include(cmake/modules/contrib/VitisAI.cmake)
 include(cmake/modules/contrib/Verilator.cmake)
 include(cmake/modules/contrib/UMA.cmake)
+include(cmake/modules/contrib/MSC.cmake)
 include(cmake/modules/Git.cmake)
 include(cmake/modules/LibInfo.cmake)
 include(cmake/modules/RustExt.cmake)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 8a7a0f1fdd..4990e52d63 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -281,6 +281,9 @@ set(USE_VITIS_AI OFF)
 # Build Verilator codegen and runtime
 set(USE_VERILATOR OFF)
 
+# Whether to use the Multi-System Compiler
+set(USE_MSC OFF)
+
 #Whether to use CLML codegen
 set(USE_CLML OFF)
 # USE_CLML_GRAPH_EXECUTOR - CLML SDK PATH or ON or OFF
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index 17a56ac439..9e1f71c729 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -125,6 +125,7 @@ function(add_lib_info src_file)
     TVM_INFO_USE_TVM_CLML_VERSION="${CLML_VERSION_MAJOR}"
     TVM_INFO_USE_UMA="${USE_UMA}"
     TVM_INFO_USE_VERILATOR="${USE_VERILATOR}"
+    TVM_INFO_USE_MSC="${USE_MSC}"
     TVM_INFO_USE_CCACHE="${USE_CCACHE}"
     TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}"
   )
diff --git a/tests/scripts/unity/task_python_relax.sh 
b/cmake/modules/contrib/MSC.cmake
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to cmake/modules/contrib/MSC.cmake
index b6b70ab457..45ce776a08
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/cmake/modules/contrib/MSC.cmake
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -16,23 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -euxo pipefail
+if(USE_MSC)
+    tvm_file_glob(GLOB_RECURSE MSC_CORE_SOURCE "src/contrib/msc/*.cc")
+    list(APPEND COMPILER_SRCS ${MSC_CORE_SOURCE})
 
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
+    tvm_file_glob(GLOB_RECURSE MSC_RUNTIME_SOURCE 
"src/runtime/contrib/msc/*.cc")
+    list(APPEND RUNTIME_SRCS ${MSC_RUNTIME_SOURCE})
 
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+    message(STATUS "Build with MSC support...")
+endif()
diff --git a/tests/scripts/unity/task_python_relax.sh 
b/python/tvm/contrib/msc/__init__.py
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to python/tvm/contrib/msc/__init__.py
index b6b70ab457..a2813b4a2d
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/python/tvm/contrib/msc/__init__.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,24 +14,4 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-set -euxo pipefail
-
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
-
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+"""tvm.contrib.msc"""
diff --git a/tests/scripts/unity/task_python_relax.sh 
b/python/tvm/contrib/msc/core/__init__.py
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to python/tvm/contrib/msc/core/__init__.py
index b6b70ab457..6d1a7c68c8
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/python/tvm/contrib/msc/core/__init__.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,24 +14,4 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-set -euxo pipefail
-
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
-
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+"""tvm.contrib.msc.core"""
diff --git a/tests/scripts/unity/task_python_relax.sh 
b/python/tvm/contrib/msc/core/_ffi_api.py
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to python/tvm/contrib/msc/core/_ffi_api.py
index b6b70ab457..c0b0e21267
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/python/tvm/contrib/msc/core/_ffi_api.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,24 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""tvm.contrib.msc.core._ffi_api"""
 
-set -euxo pipefail
+import tvm._ffi
 
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
-
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+tvm._ffi._init_api("msc.core", __name__)
diff --git a/tests/scripts/unity/task_python_relax.sh 
b/python/tvm/contrib/msc/core/transform/__init__.py
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to python/tvm/contrib/msc/core/transform/__init__.py
index b6b70ab457..ec74597803
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/python/tvm/contrib/msc/core/transform/__init__.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,24 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""tvm.contrib.msc.core.transform"""
 
-set -euxo pipefail
-
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
-
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+from .pattern import *
+from .transform import *
diff --git a/python/tvm/contrib/msc/core/transform/pattern.py 
b/python/tvm/contrib/msc/core/transform/pattern.py
new file mode 100644
index 0000000000..76e9651c60
--- /dev/null
+++ b/python/tvm/contrib/msc/core/transform/pattern.py
@@ -0,0 +1,626 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.core.transform.pattern"""
+
+from typing import Mapping, Tuple
+
+import tvm
+from tvm.relax.dpl import pattern as relax_pattern
+from tvm.relay import dataflow_pattern as relay_pattern
+
+from tvm.relax.transform import PatternCheckContext
+from tvm.relax.backend.pattern_registry import register_patterns
+from tvm.relay.op.contrib.register import register_pattern_table
+
+
+def make_relax_conv_bias_pattern(
+    op_name: str,
+) -> Tuple[relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]]:
+    """A simple utility to create patterns for an conv fused with bias.
+
+    Parameters
+    ----------
+    op_name: str
+        The name of a Relax op, such as "relax.nn.conv2d"
+
+    Returns
+    -------
+    out: tvm.relax.dpl.pattern.DFPattern
+        The resulting pattern describing a conv_bias operation.
+
+    annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+        A mapping from name to sub pattern. It can be used to extract
+        important expressions from match result, to power the partition
+        check function and codegen.
+    """
+
+    data = relax_pattern.wildcard()
+    weight = relax_pattern.is_const()
+    conv = relax_pattern.is_op(op_name)(data, weight)
+    bias = relax_pattern.is_const()
+    shape = relax_pattern.wildcard()
+    reshape = relax_pattern.is_op("relax.reshape")(bias, shape)
+    out = relax_pattern.is_op("relax.add")(conv, reshape)
+    annotations = {"bias": bias, "reshape": reshape}
+    return out, annotations
+
+
+def _check_relax_conv_bias(context: PatternCheckContext) -> bool:
+    """Check if conv_bias fuse pattern is correct.
+
+    Returns
+    -------
+    pass: bool
+        Whether the pattern is correct.
+    """
+
+    bias = context.annotated_expr["bias"]
+    reshape = context.annotated_expr["reshape"]
+    non_one_dims = len([i for i in reshape.struct_info.shape.values if i > 1])
+    return non_one_dims <= 1 and bias.struct_info.ndim == 1
+
+
+def make_relax_linear_pattern() -> Tuple[
+    relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+    """A simple utility to create patterns for linear.
+
+    Returns
+    -------
+    out: tvm.relax.dpl.pattern.DFPattern
+        The resulting pattern describing a linear operation.
+
+    annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+        A mapping from name to sub pattern. It can be used to extract
+        important expressions from match result, to power the partition
+        check function and codegen.
+    """
+
+    data = relax_pattern.wildcard()
+    weight = relax_pattern.is_const()
+    permute = relax_pattern.is_op("relax.permute_dims")(weight)
+    out = relax_pattern.is_op("relax.matmul")(data, permute)
+    annotations = {"weight": weight, "permute": permute}
+    return out, annotations
+
+
+def _check_relax_linear(context: PatternCheckContext) -> bool:
+    """Check if linear pattern is correct.
+
+    Returns
+    -------
+    pass: bool
+        Whether the pattern is correct.
+    """
+
+    weight = context.annotated_expr["weight"]
+    permute = context.annotated_expr["permute"]
+    return weight.struct_info.ndim == 2 and not permute.attrs["axes"]
+
+
+def make_relax_linear_bias_pattern() -> Tuple[
+    relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+    """A simple utility to create patterns for linear with bias.
+
+    Returns
+    -------
+    out: tvm.relax.dpl.pattern.DFPattern
+        The resulting pattern describing a linear_bias operation.
+
+    annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+        A mapping from name to sub pattern. It can be used to extract
+        important expressions from match result, to power the partition
+        check function and codegen.
+
+    """
+
+    linear, annotations = make_relax_linear_pattern()
+    bias = relax_pattern.is_const()
+    out = relax_pattern.is_op("relax.add")(linear, bias)
+    annotations.update({"bias": bias, "out": out})
+    return out, annotations
+
+
+def _check_relax_linear_bias(context: PatternCheckContext) -> bool:
+    """Check if linear_bias pattern is correct.
+
+    Returns
+    -------
+    pass: bool
+        Whether the pattern is correct.
+    """
+
+    if not _check_relax_linear(context):
+        return False
+    bias = context.annotated_expr["bias"]
+    return bias.struct_info.ndim == 1
+
+
+def make_relax_embedding_pattern() -> Tuple[
+    relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+    """A simple utility to create patterns for embedding.
+
+    Returns
+    -------
+    out: tvm.relax.dpl.pattern.DFPattern
+        The resulting pattern describing a embedding operation.
+
+    annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+        A mapping from name to sub pattern. It can be used to extract
+        important expressions from match result, to power the partition
+        check function and codegen.
+    """
+
+    weight = relax_pattern.is_const()
+    data = relax_pattern.wildcard()
+    astype = relax_pattern.is_op("relax.astype")(data)
+    out = relax_pattern.is_op("relax.take")(weight, astype)
+    annotations = {"weight": weight, "astype": astype}
+    return out, annotations
+
+
+def _check_relax_embedding(context: PatternCheckContext) -> bool:
+    """Check if 1d embedding pattern is correct.
+
+    Returns
+    -------
+    pass: bool
+        Whether the pattern is correct.
+    """
+
+    weight = context.annotated_expr["weight"]
+    astype = context.annotated_expr["astype"]
+    return (
+        astype.attrs["dtype"] == "int32"
+        and weight.struct_info.ndim == 2
+        and weight.struct_info.dtype == "float32"
+    )
+
+
+def make_relax_reshape_embedding_pattern() -> Tuple[
+    relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+    """A simple utility to create patterns for reshaped embedding.
+
+    Returns
+    -------
+    out: tvm.relax.dpl.pattern.DFPattern
+        The resulting pattern describing a reshape_embedding operation.
+
+    annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+        A mapping from name to sub pattern. It can be used to extract
+        important expressions from match result, to power the partition
+        check function and codegen.
+    """
+
+    weight = relax_pattern.is_const()
+    data = relax_pattern.wildcard()
+    astype = relax_pattern.is_op("relax.astype")(data)
+    reduce_shape = relax_pattern.wildcard()
+    reduce_in = relax_pattern.is_op("relax.reshape")(astype, reduce_shape)
+    take = relax_pattern.is_op("relax.take")(weight, reduce_in)
+    expand_shape = relax_pattern.wildcard()
+    out = relax_pattern.is_op("relax.reshape")(take, expand_shape)
+    annotations = {"weight": weight, "astype": astype, "reduce_in": reduce_in}
+    return out, annotations
+
+
+def _check_relax_reshape_embedding(context: PatternCheckContext) -> bool:
+    """Check if reshape embedding pattern is correct.
+
+    Returns
+    -------
+    pass: bool
+        Whether the pattern is correct.
+    """
+
+    weight = context.annotated_expr["weight"]
+    if weight.struct_info.ndim != 2 or weight.struct_info.dtype != "float32":
+        return False
+    astype = context.annotated_expr["astype"]
+    reduce_in = context.annotated_expr["reduce_in"]
+    if astype.attrs["dtype"] != "int32" or reduce_in.struct_info.ndim != 1:
+        return False
+    return True
+
+
+def make_relax_attention_pattern() -> Tuple[
+    relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+    """A simple utility to create patterns for attention.
+
+    Returns
+    -------
+    out: tvm.relax.dpl.pattern.DFPattern
+        The resulting pattern describing a attention operation.
+
+    annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+        A mapping from name to sub pattern. It can be used to extract
+        important expressions from match result, to power the partition
+        check function and codegen.
+    """
+
+    weight_q = relax_pattern.wildcard()
+    weight_k = relax_pattern.wildcard()
+    weight_v = relax_pattern.wildcard()
+    q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q)
+    k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k)
+    v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v)
+    out = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans)
+    annotations = {"q_trans": q_trans, "k_trans": k_trans, "v_trans": v_trans}
+    return out, annotations
+
+
+def _check_relax_attention(context: PatternCheckContext) -> bool:
+    """Check if attention pattern is correct.
+
+    Returns
+    -------
+    pass: bool
+        Whether the pattern is correct.
+    """
+
+    return True
+
+
+def make_relax_mask_attention_pattern() -> Tuple[
+    relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+    """A simple utility to create patterns for mask_attention.
+
+    Returns
+    -------
+    out: tvm.relax.dpl.pattern.DFPattern
+        The resulting pattern describing a mask_attention operation.
+
+    annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+        A mapping from name to sub pattern. It can be used to extract
+        important expressions from match result, to power the partition
+        check function and codegen.
+    """
+
+    weight_q = relax_pattern.wildcard()
+    weight_k = relax_pattern.wildcard()
+    weight_v = relax_pattern.wildcard()
+    mask = relax_pattern.wildcard()
+    q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q)
+    k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k)
+    v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v)
+    out = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, 
v_trans, mask)
+    annotations = {"q_trans": q_trans, "k_trans": k_trans, "v_trans": v_trans}
+    return out, annotations
+
+
+def _check_relax_mask_attention(context: PatternCheckContext) -> bool:
+    """Check if mask_attention pattern is correct.
+
+    Returns
+    -------
+    pass: bool
+        Whether the pattern is correct.
+    """
+
+    return True
+
+
+# TODO(tong.meng): support patterns after optimize
+register_patterns(
+    [
+        (
+            "msc.conv1d_bias",
+            *make_relax_conv_bias_pattern(
+                "relax.nn.conv1d",
+            ),
+            _check_relax_conv_bias,
+        ),
+        (
+            "msc.conv2d_bias",
+            *make_relax_conv_bias_pattern(
+                "relax.nn.conv2d",
+            ),
+            _check_relax_conv_bias,
+        ),
+        (
+            "msc.linear",
+            *make_relax_linear_pattern(),
+            _check_relax_linear,
+        ),
+        (
+            "msc.linear_bias",
+            *make_relax_linear_bias_pattern(),
+            _check_relax_linear_bias,
+        ),
+        (
+            "msc.embedding",
+            *make_relax_embedding_pattern(),
+            _check_relax_embedding,
+        ),
+        (
+            "msc.embedding",
+            *make_relax_reshape_embedding_pattern(),
+            _check_relax_reshape_embedding,
+        ),
+        (
+            "msc.attention",
+            *make_relax_attention_pattern(),
+            _check_relax_attention,
+        ),
+        (
+            "msc.attention",
+            *make_relax_mask_attention_pattern(),
+            _check_relax_mask_attention,
+        ),
+    ]
+)
+
+
+# TODO(tong.meng): support patterns after optimize
+@register_pattern_table("msc")
+def pattern_table():
+    """Returns list of triples describing the name, dataflow pattern and 
predicate for all
+    the MSC-supported operators."""
+
+    def make_relay_conv_bias_pattern(
+        op_name: str, optimized: bool = False
+    ) -> relay_pattern.DFPattern:
+        """A simple utility to create patterns for an operation fused with 
bias.
+
+        Parameters
+        ----------
+        op_name: str
+            The name of a Relay op, such as "relay.nn.conv2d"
+        optimized: bool
+            Whether the relay is optimized
+
+        Returns
+        -------
+        pattern: tvm.relay.dataflow_pattern.DFPattern
+            The resulting pattern describing a conv_bias operation
+        """
+
+        data = relay_pattern.wildcard()
+        weight = relay_pattern.is_constant()
+        bias = relay_pattern.is_constant()
+        conv = relay_pattern.is_op(op_name)(data, weight)
+        if optimized:
+            out = relay_pattern.is_op("add")(conv, bias)
+        else:
+            out = relay_pattern.is_op("nn.bias_add")(conv, bias)
+        return out
+
+    def _check_relay_conv_bias(call: tvm.relay.Expr) -> bool:
+        """Check if conv_bias fuse pattern is correct.
+
+        Returns
+        -------
+        pass: bool
+            Whether the pattern is correct.
+        """
+
+        if call.op.name == "nn.bias_add":
+            bias = call.args[1]
+            return len(bias.checked_type.shape) == 1
+        if call.op.name == "add":
+            return True
+        return False
+
+    def make_relay_linear_pattern(optimized: bool = False) -> 
relay_pattern.DFPattern:
+        """A simple utility to create patterns for linear.
+
+        Parameters
+        ----------
+        optimized: bool
+            Whether the relay is optimized
+
+        Returns
+        -------
+        pattern: tvm.relay.dataflow_pattern.DFPattern
+            The resulting pattern describing a linear operation
+        """
+
+        if optimized:
+            data = relay_pattern.wildcard()
+            weight = relay_pattern.is_constant()
+            broadcast_data = relay_pattern.is_op("broadcast_to")(data)
+            reshape_data = relay_pattern.is_op("reshape")(broadcast_data)
+            batch_matmul = 
relay_pattern.is_op("nn.batch_matmul")(reshape_data, weight)
+            reshape_out = relay_pattern.is_op("reshape")(batch_matmul)
+            return relay_pattern.is_op("squeeze")(reshape_out)
+        data = relay_pattern.wildcard()
+        weight = relay_pattern.is_constant()
+        trans_weight = relay_pattern.is_op("transpose")(weight)
+        broadcast_data = relay_pattern.is_op("broadcast_to")(data)
+        broadcast_weight = relay_pattern.is_op("broadcast_to")(trans_weight)
+        reshape_data = relay_pattern.is_op("reshape")(broadcast_data)
+        reshape_weight = relay_pattern.is_op("reshape")(broadcast_weight)
+        batch_matmul = relay_pattern.is_op("nn.batch_matmul")(reshape_data, 
reshape_weight)
+        reshape_out = relay_pattern.is_op("reshape")(batch_matmul)
+        return relay_pattern.is_op("squeeze")(reshape_out)
+
+    def _check_relay_linear(call: tvm.relay.Expr) -> bool:
+        """Check if linear pattern is correct.
+
+        Returns
+        -------
+        pass: bool
+            Whether the pattern is correct.
+        """
+
+        return True
+
+    def make_relay_linear_bias_pattern(optimized: bool = False) -> 
relay_pattern.DFPattern:
+        """A simple utility to create patterns for linear_bias.
+
+        Parameters
+        ----------
+        optimized: bool
+            Whether the relay is optimized
+
+        Returns
+        -------
+        pattern: DFPattern
+            The resulting pattern describing a linear_bias operation
+        """
+
+        bias = relay_pattern.is_constant()
+        linear = make_relay_linear_pattern(optimized)
+        if optimized:
+            out = relay_pattern.is_op("add")(linear, bias)
+        else:
+            out = relay_pattern.is_op("nn.bias_add")(linear, bias)
+        return out
+
+    def _check_relay_linear_bias(call: tvm.relay.Expr) -> bool:
+        """Check if linear_bias pattern is correct."""
+        return True
+
+    def make_relay_matmul_pattern(dim: int = 2, optimized: bool = False) -> 
relay_pattern.DFPattern:
+        """A simple utility to create patterns for matmul.
+
+        Parameters
+        ----------
+        optimized: bool
+            Whether the relay is optimized
+
+        Returns
+        -------
+        pattern: tvm.relay.dataflow_pattern.DFPattern
+            The resulting pattern describing a matmul operation
+        """
+
+        if dim == 2:
+            a = relay_pattern.wildcard()
+            b = relay_pattern.wildcard()
+            trans_b = relay_pattern.is_op("transpose")(b)
+            dense = relay_pattern.is_op("nn.dense")(a, trans_b)
+            return dense | relay_pattern.is_op("squeeze")(dense)
+        elif dim == 3:
+            a = relay_pattern.wildcard()
+            b = relay_pattern.wildcard()
+            broadcast_a = relay_pattern.is_op("broadcast_to")(a)
+            broadcast_b = relay_pattern.is_op("broadcast_to")(b)
+            reshape_a = relay_pattern.is_op("reshape")(broadcast_a)
+            reshape_b = relay_pattern.is_op("reshape")(broadcast_b)
+            batch_matmul = relay_pattern.is_op("nn.batch_matmul")(reshape_a, 
reshape_b)
+            reshape_out = relay_pattern.is_op("reshape")(batch_matmul)
+            return relay_pattern.is_op("squeeze")(reshape_out)
+        else:
+            raise Exception("matmul pattern only support dim 2 and 3")
+
+    def _check_relay_matmul(call: tvm.relay.Expr) -> bool:
+        """Check if matmul pattern is correct.
+
+        Returns
+        -------
+        pass: bool
+            Whether the pattern is correct.
+        """
+
+        last_call = call.args[0] if call.op.name == "squeeze" else call
+        if last_call.op.name == "nn.dense":
+            trans_b = last_call.args[1]
+            b = trans_b.args[0]
+            if len(b.checked_type.shape) != 2:
+                return False
+            return trans_b.attrs["axes"] is None or 
list(trans_b.attrs["axes"]) == [1, 0]
+        return True
+
+    def make_relay_embedding_pattern(optimized: bool = False) -> 
relay_pattern.DFPattern:
+        """A simple utility to create patterns for 1d embedding.
+
+        Returns
+        -------
+        pattern: tvm.relay.dataflow_pattern.DFPattern
+            The resulting pattern describing a embedding operation
+        """
+
+        weight = relay_pattern.is_constant()
+        data = relay_pattern.wildcard()
+        astype = relay_pattern.is_op("cast")(data)
+        return relay_pattern.is_op("take")(weight, astype)
+
+    def _check_relay_embedding(call: tvm.relay.Expr) -> bool:
+        """Check if embedding pattern is correct.
+
+        Returns
+        -------
+        pass: bool
+            Whether the pattern is correct.
+        """
+
+        weight = call.args[0]
+        cast = call.args[1]
+        return (
+            cast.attrs["dtype"] == "int32"
+            and len(weight.checked_type.shape) == 2
+            and weight.checked_type.dtype == "float32"
+        )
+
+    def make_relay_gelu_pattern(optimized: bool = False) -> 
relay_pattern.DFPattern:
+        """A simple utility to create patterns for gelu.
+
+        Returns
+        -------
+        pattern: tvm.relay.dataflow_pattern.DFPattern
+            The resulting pattern describing a gelu operation.
+        """
+
+        data = relay_pattern.wildcard()
+        factor_1 = relay_pattern.is_constant()
+        mul_1 = relay_pattern.is_op("multiply")(data, factor_1)
+        erf = relay_pattern.is_op("erf")(mul_1)
+        factor_2 = relay_pattern.is_constant()
+        mul_2 = relay_pattern.is_op("multiply")(erf, factor_2)
+        factor_3 = relay_pattern.is_constant()
+        add = relay_pattern.is_op("add")(factor_3, mul_2)
+        return relay_pattern.is_op("multiply")(data, add)
+
+    def _check_relay_gelu(call: tvm.relay.Expr) -> bool:
+        """Check if gelu pattern is correct.
+
+        Returns
+        -------
+        pass: bool
+            Whether the pattern is correct.
+        """
+
+        return True
+
+    return [
+        ("msc.conv1d_bias", make_relay_conv_bias_pattern("nn.conv1d"), 
_check_relay_conv_bias),
+        (
+            "msc.conv1d_bias",
+            make_relay_conv_bias_pattern("nn.conv1d", True),
+            _check_relay_conv_bias,
+        ),
+        ("msc.conv2d_bias", make_relay_conv_bias_pattern("nn.conv2d"), 
_check_relay_conv_bias),
+        (
+            "msc.conv2d_bias",
+            make_relay_conv_bias_pattern("nn.conv2d", True),
+            _check_relay_conv_bias,
+        ),
+        ("msc.linear_bias", make_relay_linear_bias_pattern(), 
_check_relay_linear_bias),
+        ("msc.linear", make_relay_linear_pattern(), _check_relay_linear),
+        ("msc.linear", make_relay_linear_pattern(True), _check_relay_linear),
+        ("msc.matmul", make_relay_matmul_pattern(dim=2), _check_relay_matmul),
+        ("msc.matmul", make_relay_matmul_pattern(dim=3), _check_relay_matmul),
+        ("msc.embedding", make_relay_embedding_pattern(), 
_check_relay_embedding),
+        ("msc.gelu", make_relay_gelu_pattern(), _check_relay_gelu),
+    ]
diff --git a/python/tvm/contrib/msc/core/transform/transform.py 
b/python/tvm/contrib/msc/core/transform/transform.py
new file mode 100644
index 0000000000..355922d6de
--- /dev/null
+++ b/python/tvm/contrib/msc/core/transform/transform.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""tvm.contrib.msc.core.transform.transform"""
+
+import tvm
+from tvm.relax.transform import _ffi_api as relax_api
+from tvm.relay.transform import _ffi_api as relay_api
+
+
def SetExprName(as_relax=True, entry_name="main") -> tvm.ir.transform.Pass:
    """Set name for the call and constant in IRModule.

    Parameters
    ----------
    as_relax: bool
        Whether set names for relax, otherwise for relay.
    entry_name: str
        The entry name

    Returns
    -------
    ret: tvm.ir.transform.Pass
    """

    # Dispatch to the matching FFI pass builder for the chosen IR.
    builder = relax_api.SetRelaxExprName if as_relax else relay_api.SetRelayExprName
    return builder(entry_name)  # type: ignore
+
+
def SetExprLayout(allow_missing=True, entry_name="main") -> tvm.ir.transform.Pass:
    """Set layout for the var and constant in IRModule.

    Parameters
    ----------
    allow_missing: bool
        Whether allow missing layouts.
    entry_name: str
        The entry name

    Returns
    -------
    ret: tvm.ir.transform.Pass
    """

    # Layout setting is only implemented for relax.
    layout_pass = relax_api.SetExprLayout(allow_missing, entry_name)  # type: ignore
    return layout_pass
diff --git a/tests/scripts/unity/task_python_relax.sh 
b/python/tvm/contrib/msc/core/utils/__init__.py
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to python/tvm/contrib/msc/core/utils/__init__.py
index b6b70ab457..65f9e1b326
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/python/tvm/contrib/msc/core/utils/__init__.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,24 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""tvm.contrib.msc.core.utils"""
 
-set -euxo pipefail
-
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
-
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+from .expr import *
diff --git a/python/tvm/contrib/msc/core/utils/expr.py 
b/python/tvm/contrib/msc/core/utils/expr.py
new file mode 100644
index 0000000000..ad459e7832
--- /dev/null
+++ b/python/tvm/contrib/msc/core/utils/expr.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.utils.expr"""
+
+import tvm
+from tvm import relax
+from tvm.relax import PyExprVisitor
+from tvm.contrib.msc.core import _ffi_api
+
+
def get_span_attrs(mod: tvm.IRModule) -> dict:
    """Extract the span attributes from relax.Function.

    Parameters
    ----------
    mod: IRModule
        The IRModule of relax.

    Returns
    -------
    attrs: dict
        Mapping of function name -> {expr name -> span attributes dict}.
    """

    @relax.expr_functor.visitor
    class SpanVisitor(PyExprVisitor):
        """Visitor for get attributes in span"""

        def extract(self, expr: relax.Expr) -> dict:
            # Collected results: expr name -> dict of span attributes.
            self._span_info = {}
            if isinstance(expr, relax.Expr):
                self.visit_expr(expr)
            elif isinstance(expr, relax.BindingBlock):
                self.visit_binding_block(expr)
            return self._span_info

        def _update_attrs(self, expr: relax.Expr, name: str = "") -> None:
            # Exprs without a span have no attributes to record.
            if not expr.span:
                return
            # Fall back to the "name" attribute stored inside the span itself.
            name = name or _ffi_api.SpanGetAttr(expr.span, "name")
            if not name:
                return
            self._span_info[name] = _ffi_api.SpanGetAttrs(expr.span)

        def visit_var_binding_(self, binding: relax.VarBinding) -> None:
            super().visit_var_binding_(binding)
            # Record the bound value's attrs under the bound var's name.
            self._update_attrs(binding.value, binding.var.name_hint)

        def visit_constant_(self, op: relax.Constant) -> None:
            super().visit_constant_(op)
            self._update_attrs(op)

        def visit_var_(self, op: relax.Var) -> None:
            super().visit_var_(op)
            self._update_attrs(op, op.name_hint)

    return {v.name_hint: SpanVisitor().extract(mod[v]) for v in mod.functions}
+
+
def msc_script(mod: tvm.IRModule, script: str = "") -> str:
    """Add span attrs after lines.

    Parameters
    ----------
    mod: IRModule
        The IRModule of relax.
    script: string
        The script to be replaced

    Returns
    -------
    script: string
        The replaced script
    """

    source = script or str(mod)
    attrs = get_span_attrs(mod)
    cur_attr = {}
    annotated = []
    for line in source.split("\n"):
        stripped = line.strip()
        if stripped.startswith("def "):
            # Entering a new function: switch to its span-attribute table.
            func_name = stripped.split("def ")[1].split("(")[0]
            cur_attr = attrs.get(func_name, {})
        if ": " in line:
            v_name = stripped.split(": ")[0]
            if v_name in cur_attr:
                # Append the attrs as a trailing "# k=v, ... #" comment.
                info = ", ".join("{}={}".format(k, v) for k, v in cur_attr[v_name].items())
                line = line + " # " + info + " #"
        annotated.append(line)
    return "\n".join(annotated)
diff --git a/src/contrib/msc/core/transform/layout_utils.cc 
b/src/contrib/msc/core/transform/layout_utils.cc
new file mode 100644
index 0000000000..ffc631c6d0
--- /dev/null
+++ b/src/contrib/msc/core/transform/layout_utils.cc
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/layout_utils.cc
+ */
+#include "layout_utils.h"
+
+#include <set>
+#include <string>
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+bool LayoutUtils::LayoutInfered(const Expr& expr) {
+  const String& layout = SpanUtils::GetAttr(expr->span, "layout");
+  return layout.size() > 0;
+}
+
+bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) {
+  const String& saved_layout = SpanUtils::GetAttr(expr->span, "layout");
+  const auto& sinfo = GetStructInfo(expr);
+  if (sinfo.as<TensorStructInfoNode>() || sinfo.as<ShapeStructInfoNode>()) {
+    ICHECK(layout.IsLeaf()) << "Expr has tensor struct, but find nested layout 
" << expr;
+    const auto& l_layout = layout.LeafValue()->layout;
+    if (!l_layout.defined()) {
+      return false;
+    }
+    if (saved_layout == l_layout.name()) {
+      return false;
+    }
+    expr->span = SpanUtils::SetAttr(expr->span, "layout", l_layout.name());
+  } else if (sinfo.as<TupleStructInfoNode>()) {
+    ICHECK(!layout.IsLeaf()) << "Expr has tupple struct, but find non-nested 
layout " << expr;
+    String layout_str;
+    Array<NLayout> nested_layouts = layout.NestedArray();
+    for (size_t i = 0; i < nested_layouts.size(); i++) {
+      ICHECK(nested_layouts[i].IsLeaf())
+          << "Expr input[" << i << "] has tensor struct, but find nested 
layout " << expr;
+      const auto& l_layout = nested_layouts[i].LeafValue()->layout;
+      if (!l_layout.defined()) {
+        return false;
+      }
+      layout_str = layout_str + l_layout.name() + (i < nested_layouts.size() - 
1 ? "," : "");
+    }
+    if (saved_layout == layout_str) {
+      return false;
+    }
+    expr->span = SpanUtils::SetAttr(expr->span, "layout", layout_str);
+  }
+  return true;
+}
+
+const NLayout LayoutUtils::GetNLayout(const Expr& expr) {
+  if (!LayoutInfered(expr)) {
+    return LayoutDecision("");
+  }
+  auto sinfo = GetStructInfo(expr);
+  if (sinfo.as<TensorStructInfoNode>()) {
+    return LayoutDecision(SpanUtils::GetAttr(expr->span, "layout"));
+  }
+  if (sinfo.as<TupleStructInfoNode>()) {
+    String layout_str = SpanUtils::GetAttr(expr->span, "layout");
+    std::vector<NLayout> output_layout;
+    for (const auto& l : StringUtils::Split(layout_str, ",")) {
+      output_layout.push_back(LayoutDecision(l));
+    }
+    return NLayout(output_layout);
+  }
+  return LayoutDecision("");
+}
+
+const LayoutDecision LayoutUtils::GetLayoutDecision(const Expr& expr) {
+  NLayout nlayout = GetNLayout(expr);
+  ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr;
+  return nlayout.LeafValue();
+}
+
+bool LayoutUtils::HasUnknownDimTensor(const NLayout& nlayout) {
+  bool find = false;
+  auto fvisit = [&](const LayoutDecision& layout) {
+    find = find | (NLayoutEqual()(layout, LayoutDecision::InitUnknownDim()));
+  };
+  ForEachLeaf<LayoutDecision>(nlayout, fvisit);
+  return find;
+}
+
+bool LayoutUtils::HasUnknownDimTensor(const Array<Expr>& args) {
+  for (const auto& arg : args) {
+    if (IsNestedTensor(arg)) {
+      if (HasUnknownDimTensor(GetNLayout(arg))) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+const LayoutDecision LayoutUtils::ExpandLayout(const LayoutDecision& 
src_layout,
+                                               const std::vector<size_t>& 
expand_axes) {
+  if (!src_layout->layout.defined()) {
+    return src_layout;
+  }
+  std::string new_layout = src_layout.name();
+  ICHECK_EQ(new_layout.size(), src_layout->layout.ndim())
+      << "Only support normal layout, get " << src_layout->layout;
+  std::vector<std::string> priority_dims{"N", "C", "H", "W", "D", "G", "T"};
+  size_t left_size = expand_axes.size();
+  for (const auto& a : expand_axes) {
+    std::string target = "U";
+    if (new_layout.find("H") && !new_layout.find("W")) {
+      target = "W";
+    } else if (new_layout.find("W") && !new_layout.find("H")) {
+      target = "H";
+    } else if (left_size == 1 && new_layout.find("C") && 
!new_layout.find("D")) {
+      target = "D";
+    } else if (left_size == 1 && new_layout.find("D") && 
!new_layout.find("C")) {
+      target = "C";
+    } else {
+      for (const auto& p : priority_dims) {
+        int pos = new_layout.find(p);
+        if (pos < 0) {
+          target = p;
+          break;
+        }
+      }
+    }
+    new_layout = new_layout.insert(a, target);
+    left_size--;
+  }
+  return LayoutDecision(new_layout);
+}
+
+const LayoutDecision LayoutUtils::ReduceLayout(const LayoutDecision& 
src_layout,
+                                               const std::vector<size_t>& 
reduce_axes) {
+  if (!src_layout->layout.defined()) {
+    return src_layout;
+  }
+  std::set<size_t> reduce_axes_set;
+  for (const auto& a : reduce_axes) {
+    reduce_axes_set.insert(a);
+  }
+  std::string new_layout = "";
+  for (size_t i = 0; i < src_layout->layout.ndim(); i++) {
+    if (reduce_axes_set.count(i)) {
+      continue;
+    }
+    new_layout += src_layout->layout[i].name();
+  }
+  return LayoutDecision(new_layout);
+}
+
+const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& 
src_layout,
+                                                const Array<Integer>& axes) {
+  String layout_str;
+  for (const auto& a : axes) {
+    layout_str = layout_str + src_layout->layout[a->value].name();
+  }
+  return LayoutDecision(layout_str);
+}
+
+const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& 
src_layout,
+                                                const std::vector<size_t>& 
axes) {
+  String layout_str;
+  for (const auto& a : axes) {
+    layout_str = layout_str + src_layout->layout[a].name();
+  }
+  return LayoutDecision(layout_str);
+}
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
diff --git a/src/contrib/msc/core/transform/layout_utils.h 
b/src/contrib/msc/core/transform/layout_utils.h
new file mode 100644
index 0000000000..b9de832838
--- /dev/null
+++ b/src/contrib/msc/core/transform/layout_utils.h
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/layout_utils.h
+ * \brief Common utilities for layout.
+ */
+#ifndef TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_
+#define TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_
+
+#include <tvm/ir/source_map.h>
+#include <tvm/relax/expr.h>
+
+#include <vector>
+
+#include "../../../../relax/transform/infer_layout_utils.h"
+#include "../../../../relax/transform/utils.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using Expr = tvm::RelayExpr;
+using namespace tvm::relax;
+
/*!
 * \brief Utils for Layout.
 */
class LayoutUtils {
 public:
  /*!
   * \brief Check if the layout is inferred (span carries a "layout" attribute).
   * \return Whether the layout is inferred.
   */
  TVM_DLL static bool LayoutInfered(const Expr& expr);

  /*!
   * \brief Set the layout to span.
   * \return Whether the layout was newly set (false when unchanged or undefined).
   */
  TVM_DLL static bool SetLayout(const Expr& expr, const NLayout& layout);

  /*!
   * \brief Get the layout from span (nested for tuple exprs).
   * \return The NLayout.
   */
  TVM_DLL static const NLayout GetNLayout(const Expr& expr);

  /*!
   * \brief Get the layout decision from span.
   * \return The LayoutDecision.
   */
  TVM_DLL static const LayoutDecision GetLayoutDecision(const Expr& expr);

  /*!
   * \brief Check if the layout has unknown dim tensor.
   * \return Whether the layout has unknown dim tensor.
   */
  TVM_DLL static bool HasUnknownDimTensor(const NLayout& nlayout);

  /*!
   * \brief Check if the args have unknown dim tensor.
   * \return Whether the args have unknown dim tensor.
   */
  TVM_DLL static bool HasUnknownDimTensor(const Array<Expr>& args);

  /*!
   * \brief Insert axes into the Layout.
   * \return The new layout.
   */
  TVM_DLL static const LayoutDecision ExpandLayout(const LayoutDecision& src_layout,
                                                   const std::vector<size_t>& expand_axes);

  /*!
   * \brief Delete axes from the Layout.
   * \return The new layout.
   */
  TVM_DLL static const LayoutDecision ReduceLayout(const LayoutDecision& src_layout,
                                                   const std::vector<size_t>& reduce_axes);
  /*!
   * \brief Permute axes of the Layout.
   * \return The new layout.
   */
  TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout,
                                                    const Array<Integer>& axes);
  TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout,
                                                    const std::vector<size_t>& axes);
};
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
+#endif  // TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_
diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc 
b/src/contrib/msc/core/transform/set_expr_layout.cc
new file mode 100644
index 0000000000..5915bef9e1
--- /dev/null
+++ b/src/contrib/msc/core/transform/set_expr_layout.cc
@@ -0,0 +1,1215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/set_expr_layout.cc
+ * \brief Pass for setting layout for expr and constant.
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include "../utils.h"
+#include "layout_utils.h"
+
+namespace tvm {
+namespace relax {
+
+using namespace tvm::contrib::msc;
+
+NLayout InferNLayout(const Expr& expr, const VarLayoutMap& var_layout_map) {
+  if (expr->IsInstance<VarNode>() && 
var_layout_map.count(Downcast<Var>(expr))) {
+    return GetNLayout(var_layout_map, expr);
+  }
+  return LayoutUtils::GetNLayout(expr);
+}
+
+LayoutDecision InferLayoutDecision(const Expr& expr, const VarLayoutMap& 
var_layout_map) {
+  const auto& nlayout = InferNLayout(expr, var_layout_map);
+  ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr;
+  return nlayout.LeafValue();
+}
+
+LayoutDecision InferLayoutDecisionAt(const Expr& expr, const VarLayoutMap& 
var_layout_map,
+                                     size_t index = 0) {
+  const auto& nlayouts = InferNLayout(expr, var_layout_map);
+  const auto& nlayout = nlayouts.NestedArray()[0];
+  ICHECK(nlayout.IsLeaf()) << "Cannot get output layout for " << expr;
+  return nlayout.LeafValue();
+}
+
// Find the smallest (in_pos, out_pos) such that
// prod(in_shape[in_start..in_pos]) == prod(out_shape[out_start..out_pos]),
// then extend both positions over trailing size-1 dims. (-1, -1) if no match.
std::tuple<int64_t, int64_t> AccumulateMatch(const std::vector<int64_t>& in_shape,
                                             const std::vector<int64_t>& out_shape, size_t in_start,
                                             size_t out_start) {
  int64_t lhs = -1;
  int64_t rhs = -1;
  int64_t prod_in = 1;
  for (size_t i = in_start; i < in_shape.size() && lhs < 0; i++) {
    prod_in *= in_shape[i];
    int64_t prod_out = 1;
    for (size_t j = out_start; j < out_shape.size(); j++) {
      prod_out *= out_shape[j];
      if (prod_out == prod_in) {
        lhs = static_cast<int64_t>(i);
        rhs = static_cast<int64_t>(j);
        break;
      }
      if (prod_out > prod_in) {
        // Overshot: grow the input product first.
        break;
      }
    }
  }
  if (lhs >= 0) {
    // Swallow trailing dimensions of size 1 on both sides.
    const int64_t in_size = static_cast<int64_t>(in_shape.size());
    const int64_t out_size = static_cast<int64_t>(out_shape.size());
    while (lhs + 1 < in_size && in_shape[lhs + 1] == 1) {
      lhs++;
    }
    while (rhs + 1 < out_size && out_shape[rhs + 1] == 1) {
      rhs++;
    }
  }
  return std::make_tuple(lhs, rhs);
}
+
// Infer which input axes are removed when reshaping input_shape into
// output_shape. Returns an empty vector when the shapes cannot be aligned.
std::vector<size_t> InferReduceAxes(const Array<PrimExpr>& input_shape,
                                    const Array<PrimExpr>& output_shape) {
  std::vector<size_t> reduce_axes, out_axes;
  std::vector<int64_t> in_shape, out_shape;
  // Shapes are expected to be fully static (Integer) here.
  for (const auto& s : input_shape) {
    in_shape.push_back(Downcast<Integer>(s)->value);
  }
  for (const auto& s : output_shape) {
    out_shape.push_back(Downcast<Integer>(s)->value);
  }
  // out_axes records, for each output dim, the input axis it is matched to.
  size_t start = 0;
  while (start < in_shape.size() && out_axes.size() < out_shape.size()) {
    if (in_shape[start] == out_shape[out_axes.size()]) {
      // Dims agree one-to-one.
      out_axes.push_back(start);
      start++;
    } else {
      // Align a product of consecutive dims on both sides.
      int64_t in_pos, out_pos;
      size_t out_start = out_axes.size();
      std::tie(in_pos, out_pos) = AccumulateMatch(in_shape, out_shape, start, out_start);
      if (in_pos == -1) {
        return std::vector<size_t>();
      }
      // NOTE(review): this pushes output-index + 1 rather than an input axis
      // position — confirm the off-by-one is intended for the reduce mapping.
      for (size_t i = out_start; i < static_cast<size_t>(out_pos) + 1; i++) {
        out_axes.push_back(i + 1);
      }
      start = in_pos + 1;
    }
  }
  if (out_axes.size() != out_shape.size()) {
    return std::vector<size_t>();
  }
  // Every input axis not claimed by an output dim is a reduced axis.
  std::set<size_t> out_axes_set;
  for (const auto& a : out_axes) {
    out_axes_set.insert(a);
  }
  for (size_t i = 0; i < in_shape.size(); i++) {
    if (!out_axes_set.count(i)) {
      reduce_axes.push_back(i);
    }
  }
  return reduce_axes;
}
+
// Infer which output axes are newly inserted when reshaping input_shape into
// output_shape. Returns an empty vector when the shapes cannot be aligned.
std::vector<size_t> InferExpandAxes(const Array<PrimExpr>& input_shape,
                                    const Array<PrimExpr>& output_shape) {
  std::vector<size_t> expand_axes;
  std::vector<int64_t> in_shape, out_shape;
  // Shapes are expected to be fully static (Integer) here.
  for (const auto& s : input_shape) {
    in_shape.push_back(Downcast<Integer>(s)->value);
  }
  for (const auto& s : output_shape) {
    out_shape.push_back(Downcast<Integer>(s)->value);
  }
  size_t start = 0;
  while (start < in_shape.size() && expand_axes.size() + in_shape.size() < out_shape.size()) {
    if (in_shape[start] == out_shape[start + expand_axes.size()]) {
      // Dims agree one-to-one; advance the input cursor.
      start++;
    } else {
      // Align a product of consecutive dims on both sides.
      int64_t in_pos, out_pos;
      size_t out_start = start + expand_axes.size();
      std::tie(in_pos, out_pos) = AccumulateMatch(in_shape, out_shape, start, out_start);
      if (in_pos == -1) {
        return std::vector<size_t>();
      }
      // The surplus output dims in the matched window are the expanded axes.
      size_t expand_size = out_pos - in_pos - expand_axes.size();
      for (size_t i = 0; i < expand_size; i++) {
        expand_axes.push_back(out_start + i);
      }
      start = in_pos + 1;
    }
  }
  // All inserted axes must be accounted for, otherwise the alignment failed.
  if (expand_axes.size() + in_shape.size() != out_shape.size()) {
    return std::vector<size_t>();
  }
  return expand_axes;
}
+
+// Forward and Backward infer
+InferLayoutOutput MSCInferLayoutConv(const Call& call,
+                                     const Map<String, Array<String>>& 
desired_layouts,
+                                     const VarLayoutMap& var_layout_map) {
+  LayoutDecision data_layout, kernel_layout, out_layout;
+  const String& op_name = Downcast<Op>(call->op)->name;
+  if (op_name == "relax.nn.conv1d") {
+    const auto* attrs = call->attrs.as<Conv1DAttrs>();
+    data_layout = LayoutDecision(attrs->data_layout);
+    kernel_layout = LayoutDecision(attrs->kernel_layout);
+    out_layout = LayoutDecision(attrs->out_layout);
+  } else if (op_name == "relax.nn.conv2d") {
+    const auto* attrs = call->attrs.as<Conv2DAttrs>();
+    data_layout = LayoutDecision(attrs->data_layout);
+    kernel_layout = LayoutDecision(attrs->kernel_layout);
+    out_layout = LayoutDecision(attrs->out_layout);
+  }
+  return InferLayoutOutput({data_layout, kernel_layout}, {out_layout}, 
Attrs());
+}
+
+InferLayoutOutput MSCInferLayoutPool2d(const Call& call,
+                                       const Map<String, Array<String>>& 
desired_layouts,
+                                       const VarLayoutMap& var_layout_map) {
+  LayoutDecision layout, out_layout;
+  const String& op_name = Downcast<Op>(call->op)->name;
+  if (op_name == "relax.nn.adaptive_avg_pool2d") {
+    const auto* attrs = call->attrs.as<AdaptivePool2DAttrs>();
+    layout = LayoutDecision(attrs->layout);
+    out_layout = LayoutDecision(attrs->out_layout);
+  } else {
+    const auto* attrs = call->attrs.as<Pool2DAttrs>();
+    layout = LayoutDecision(attrs->layout);
+    out_layout = LayoutDecision(attrs->out_layout);
+  }
+  return InferLayoutOutput({layout}, {out_layout}, Attrs());
+}
+
+// Forward Infer
+InferLayoutOutput ForwardInferLayoutCommon(const Call& call,
+                                           const Map<String, Array<String>>& 
desired_layouts,
+                                           const VarLayoutMap& var_layout_map) 
{
+  Array<NLayout> input_layouts;
+  LayoutDecision layout_hint;
+  for (const auto& arg : call->args) {
+    const auto& in_layout = InferLayoutDecision(arg, var_layout_map);
+    if (in_layout->layout.defined()) {
+      layout_hint = in_layout;
+    }
+    input_layouts.push_back(in_layout);
+  }
+  if (!layout_hint.defined()) {
+    return InferLayoutOutput();
+  }
+  std::vector<NLayout> output_layouts;
+  const auto& sinfo = GetStructInfo(call);
+  if (sinfo->IsInstance<TensorStructInfoNode>()) {
+    output_layouts.push_back(layout_hint);
+  } else if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
+    for (size_t i = 0; i < tuple_sinfo->fields.size(); i++) {
+      output_layouts.push_back(layout_hint);
+    }
+  } else {
+    return InferLayoutOutput();
+  }
+  return InferLayoutOutput(input_layouts, {output_layouts}, Attrs());
+}
+
// Binary ops: propagate like the common case, but 0-d (scalar) tensor inputs
// get an empty layout since they have no axes to lay out.
InferLayoutOutput ForwardInferLayoutBinary(const Call& call,
                                           const Map<String, Array<String>>& desired_layouts,
                                           const VarLayoutMap& var_layout_map) {
  const auto& output = ForwardInferLayoutCommon(call, desired_layouts, var_layout_map);
  if (!output.defined()) {
    return output;
  }
  // Rebuild input layouts, masking out scalars.
  std::vector<NLayout> input_layouts;
  for (size_t i = 0; i < call->args.size(); i++) {
    const auto& sinfo = GetStructInfo(call->args[i]);
    if (const auto* t_info = sinfo.as<TensorStructInfoNode>()) {
      if (t_info->ndim == 0) {
        input_layouts.push_back(LayoutDecision(""));
      } else {
        input_layouts.push_back(output->input_layouts[i]);
      }
    } else {
      LOG(FATAL) << "Binary input should be tensor, get " << sinfo->GetTypeKey();
    }
  }
  return InferLayoutOutput(input_layouts, output->output_layouts, Attrs());
}
+
// Inplace ops reuse the common forward propagation unchanged.
InferLayoutOutput ForwardInferLayoutInplace(const Call& call,
                                            const Map<String, Array<String>>& desired_layouts,
                                            const VarLayoutMap& var_layout_map) {
  return ForwardInferLayoutCommon(call, desired_layouts, var_layout_map);
}
+
+InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call,
+                                              const Map<String, 
Array<String>>& desired_layouts,
+                                              const VarLayoutMap& 
var_layout_map) {
+  Array<PrimExpr> empty;
+  const auto& input_shape =
+      
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+  if (input_shape.size() == 0) {
+    return InferLayoutOutput();
+  }
+  LayoutDecision in_layout = InferLayoutDecision(call->args[0], 
var_layout_map);
+  if (!in_layout->layout.defined()) {
+    if (input_shape.size() == 4) {
+      in_layout = LayoutDecision("NCHW");
+    } else if (input_shape.size() == 3) {
+      in_layout = LayoutDecision("NCD");
+    }
+  }
+  LayoutDecision g_layout = LayoutDecision("O");
+  return InferLayoutOutput({in_layout, g_layout, g_layout, g_layout, g_layout},
+                           {{in_layout, g_layout, g_layout}}, Attrs());
+}
+
+InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call,
+                                                const Map<String, 
Array<String>>& desired_layouts,
+                                                const VarLayoutMap& 
var_layout_map) {
+  LayoutDecision input_layout = InferLayoutDecision(call->args[0], 
var_layout_map);
+  if (!input_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  Array<PrimExpr> empty;
+  const auto& input_shape =
+      
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+  if (input_shape.size() == 0) {
+    return InferLayoutOutput();
+  }
+  const auto* attrs = call->attrs.as<ExpandDimsAttrs>();
+  std::vector<size_t> expand_axes;
+  for (const auto& s : attrs->axis) {
+    expand_axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size()));
+  }
+  LayoutDecision output_layout = LayoutUtils::ExpandLayout(input_layout, 
expand_axes);
+  return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+InferLayoutOutput ForwardInferLayoutNormalize(const Call& call,
+                                              const Map<String, 
Array<String>>& desired_layouts,
+                                              const VarLayoutMap& 
var_layout_map) {
+  Array<PrimExpr> empty;
+  const auto& input_shape =
+      
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+  if (input_shape.size() == 0) {
+    return InferLayoutOutput();
+  }
+  LayoutDecision in_layout = InferLayoutDecision(call->args[0], 
var_layout_map);
+  if (!in_layout->layout.defined()) {
+    if (input_shape.size() == 4) {
+      in_layout = LayoutDecision("NCHW");
+    } else if (input_shape.size() == 3) {
+      in_layout = LayoutDecision("NCD");
+    }
+  }
+  LayoutDecision g_layout = LayoutDecision("O");
+  return InferLayoutOutput({in_layout, g_layout, g_layout}, {in_layout}, 
Attrs());
+}
+
InferLayoutOutput ForwardInferLayoutMatmul(const Call& call,
                                           const Map<String, Array<String>>& desired_layouts,
                                           const VarLayoutMap& var_layout_map) {
  Array<PrimExpr> empty;
  const auto& a_shape =
      Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
  const auto& b_shape =
      Downcast<TensorStructInfo>(GetStructInfo(call->args[1]))->GetShape().value_or(empty);

  if (a_shape.size() == 0) {
    return InferLayoutOutput();
  }
  // Default A to a channel-first layout when none was inferred.
  LayoutDecision a_layout = InferLayoutDecision(call->args[0], var_layout_map);
  if (!a_layout->layout.defined()) {
    if (a_shape.size() == 4) {
      a_layout = LayoutDecision("NCHW");
    } else if (a_shape.size() == 3) {
      a_layout = LayoutDecision("NCD");
    } else if (a_shape.size() == 2) {
      a_layout = LayoutDecision("NC");
    }
    // NOTE(review): for other ranks a_layout stays undefined, yet ndim() is
    // called on it below — confirm callers only reach here with rank 2/3/4.
  }
  // B copies A's leading axes and ends with "IO" for the contraction dims.
  // NOTE(review): `start` underflows (size_t) when b has more dims than a,
  // which makes the loop below run zero times — verify this is intended.
  size_t start = a_layout->layout.ndim() - b_shape.size();
  String pre_layout;
  for (size_t i = start; i < a_layout->layout.ndim() - 2; i++) {
    pre_layout = pre_layout + a_layout->layout[i].name();
  }
  LayoutDecision b_layout = LayoutDecision(pre_layout + "IO");
  return InferLayoutOutput({a_layout, b_layout}, {a_layout}, Attrs());
}
+
+// Forward layout inference for permute_dims: permute the input layout with
+// the attribute axes, or reverse it when axes are omitted.
+InferLayoutOutput ForwardInferLayoutPermute(const Call& call,
+                                            const Map<String, Array<String>>& 
desired_layouts,
+                                            const VarLayoutMap& 
var_layout_map) {
+  LayoutDecision input_layout = InferLayoutDecision(call->args[0], 
var_layout_map);
+  if (!input_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  std::vector<size_t> permute_axes;
+  const auto* attrs = call->attrs.as<PermuteDimsAttrs>();
+  if (!attrs->axes.defined()) {
+    // Missing axes means full reversal.
+    for (size_t i = input_layout->layout.ndim(); i > 0; i--) {
+      permute_axes.push_back(i - 1);
+    }
+  } else {
+    for (const auto& a : attrs->axes.value()) {
+      permute_axes.push_back(a->value);
+    }
+  }
+  LayoutDecision output_layout = LayoutUtils::PermuteLayout(input_layout, 
permute_axes);
+  return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+// Forward layout inference for reduce ops (argmax/max/sum/...): drop the
+// reduced axes from the input layout, unless keepdims preserves the rank.
+InferLayoutOutput ForwardInferLayoutReduceAxis(const Call& call,
+                                               const Map<String, 
Array<String>>& desired_layouts,
+                                               const VarLayoutMap& 
var_layout_map) {
+  LayoutDecision input_layout = InferLayoutDecision(call->args[0], 
var_layout_map);
+  if (!input_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  const auto* attrs = call->attrs.as<StatisticalAttrs>();
+  if (attrs->keepdims) {
+    // Rank unchanged: input and output share the layout.
+    return InferLayoutOutput({input_layout}, {input_layout}, Attrs());
+  }
+  if (!attrs->axis.defined()) {
+    // Reduce over all axes: the output gets an empty layout.
+    return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs());
+  }
+  Array<PrimExpr> empty;
+  const auto& input_shape =
+      
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+  if (input_shape.size() == 0) {
+    return InferLayoutOutput();
+  }
+  std::vector<size_t> axes;
+  for (const auto& s : attrs->axis.value()) {
+    // GetIndex normalizes negative axis values against the input rank.
+    axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size()));
+  }
+  LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, axes);
+  return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+// Forward layout inference for reshape: handles rank-preserving, pure-reduce
+// and pure-expand reshapes; any other reshape yields no decision.
+InferLayoutOutput ForwardInferLayoutReshape(const Call& call,
+                                            const Map<String, Array<String>>& 
desired_layouts,
+                                            const VarLayoutMap& 
var_layout_map) {
+  LayoutDecision input_layout = InferLayoutDecision(call->args[0], 
var_layout_map);
+  if (!input_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  Array<PrimExpr> empty;
+  const auto& input_shape =
+      
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+  const auto& output_shape =
+      
Downcast<TensorStructInfo>(GetStructInfo(call))->GetShape().value_or(empty);
+  if (input_shape.size() == 0 || output_shape.size() == 0) {
+    return InferLayoutOutput();
+  }
+  LayoutDecision output_layout;
+  if (input_shape.size() == output_shape.size()) {
+    output_layout = input_layout;
+  } else if (input_shape.size() > output_shape.size()) {
+    // Rank decreases: treat the reshape as removing axes.
+    const auto& reduce_axes = InferReduceAxes(input_shape, output_shape);
+    if (reduce_axes.size() == 0) {
+      return InferLayoutOutput();
+    }
+    output_layout = LayoutUtils::ReduceLayout(input_layout, reduce_axes);
+  } else {
+    // Rank increases: treat the reshape as inserting axes.
+    const auto& expand_axes = InferExpandAxes(input_shape, output_shape);
+    if (expand_axes.size() == 0) {
+      return InferLayoutOutput();
+    }
+    output_layout = LayoutUtils::ExpandLayout(input_layout, expand_axes);
+  }
+  // Second argument is the target shape; tag it with the 1-D layout "O".
+  return InferLayoutOutput({input_layout, LayoutDecision("O")}, 
{output_layout}, Attrs());
+}
+
+// Forward layout inference for squeeze: drop the layout axes whose extent is
+// 1 (the given axes, or every unit dimension when axis is omitted).
+InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call,
+                                            const Map<String, Array<String>>& 
desired_layouts,
+                                            const VarLayoutMap& 
var_layout_map) {
+  LayoutDecision input_layout = InferLayoutDecision(call->args[0], 
var_layout_map);
+  if (!input_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  Array<PrimExpr> empty;
+  const auto& input_shape =
+      
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+  if (input_shape.size() == 0) {
+    return InferLayoutOutput();
+  }
+  const auto* attrs = call->attrs.as<SqueezeAttrs>();
+  std::vector<size_t> reduce_axes;
+  if (attrs->axis.defined()) {
+    for (const auto& s : attrs->axis.value()) {
+      size_t v_index = CommonUtils::GetIndex(s->value, input_shape.size());
+      // Only unit extents are actually removed by squeeze.
+      if (Downcast<Integer>(input_shape[v_index])->value == 1) {
+        reduce_axes.push_back(v_index);
+      }
+    }
+  } else {
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      if (Downcast<Integer>(input_shape[i])->value == 1) {
+        reduce_axes.push_back(i);
+      }
+    }
+  }
+  LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, 
reduce_axes);
+  return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+// Forward layout inference for take: derive the output layout from args[1]'s
+// layout expanded with a leading axis; args[0] is tagged with layout "WE".
+// NOTE(review): presumably targets embedding-style lookups where args[1]
+// carries the data layout - confirm the "WE" convention with LayoutUtils.
+InferLayoutOutput ForwardInferLayoutTake(const Call& call,
+                                         const Map<String, Array<String>>& 
desired_layouts,
+                                         const VarLayoutMap& var_layout_map) {
+  LayoutDecision input_layout = InferLayoutDecision(call->args[1], 
var_layout_map);
+  if (!input_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  LayoutDecision output_layout = LayoutUtils::ExpandLayout(input_layout, 
std::vector<size_t>{0});
+  return InferLayoutOutput({LayoutDecision("WE"), input_layout}, 
{output_layout}, Attrs());
+}
+
+// Register the forward ("FMSCForwardInferLayout") rules, grouped by op kind.
+// conv and pool ops
+TVM_REGISTER_OP("relax.nn.conv1d")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", MSCInferLayoutConv);
+TVM_REGISTER_OP("relax.nn.conv2d")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", MSCInferLayoutConv);
+TVM_REGISTER_OP("relax.nn.max_pool2d")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
MSCInferLayoutPool2d);
+TVM_REGISTER_OP("relax.nn.avg_pool2d")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
MSCInferLayoutPool2d);
+TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
MSCInferLayoutPool2d);
+// reduce axis ops
+TVM_REGISTER_OP("relax.argmax")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.argmin")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.max")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.min")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.mean")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.sum")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.prod")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.std")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutReduceAxis);
+// binary ops
+TVM_REGISTER_OP("relax.add")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.divide")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.floor_divide")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.multiply")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.power")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.subtract")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.equal")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.greater")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.greater_equal")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.less")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.less_equal")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.not_equal")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.maximum")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.minimum")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.logical_and")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.logical_or")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.logical_xor")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.bitwise_and")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.bitwise_or")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.bitwise_xor")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBinary);
+// math ops
+TVM_REGISTER_OP("relax.expand_dims")
+    // NOTE(review): "ForkwardInferLayoutExpandDims" looks like a typo of
+    // "Forward..."; it presumably matches the definition's name (not visible
+    // here) since it compiles - consider renaming both together.
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForkwardInferLayoutExpandDims);
+TVM_REGISTER_OP("relax.matmul")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutMatmul);
+TVM_REGISTER_OP("relax.permute_dims")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutPermute);
+TVM_REGISTER_OP("relax.reshape")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutReshape);
+TVM_REGISTER_OP("relax.squeeze")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutSqueeze);
+TVM_REGISTER_OP("relax.take")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutTake);
+// nn ops
+TVM_REGISTER_OP("relax.nn.batch_norm")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutBatchNorm);
+TVM_REGISTER_OP("relax.nn.group_norm")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutNormalize);
+TVM_REGISTER_OP("relax.nn.layer_norm")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", 
ForwardInferLayoutNormalize);
+
+// Backward Infer
+// Generic backward rule: push the call's output layout back onto any input
+// that does not yet have a layout decision of its own.
+InferLayoutOutput BackwardInferLayoutCommon(const Call& call,
+                                            const Map<String, Array<String>>& desired_layouts,
+                                            const VarLayoutMap& var_layout_map) {
+  NLayout output_layout = InferNLayout(call, var_layout_map);
+  LayoutDecision layout_hint;
+  if (output_layout.IsLeaf()) {
+    layout_hint = output_layout.LeafValue();
+  } else {
+    // Pick any leaf of the tuple output that has a defined layout as the hint.
+    for (const auto& l : output_layout.NestedArray()) {
+      if (l.IsLeaf() && l.LeafValue()->layout.defined()) {
+        layout_hint = l.LeafValue();
+      }
+    }
+  }
+  // layout_hint stays a null ref when no nested leaf had a defined layout;
+  // check defined() first so we never dereference a null LayoutDecision.
+  if (!layout_hint.defined() || !layout_hint->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  Array<NLayout> input_layouts;
+  for (const auto& arg : call->args) {
+    const auto& saved_layout = InferLayoutDecision(arg, var_layout_map);
+    if (saved_layout->layout.defined()) {
+      input_layouts.push_back(saved_layout);
+    } else {
+      input_layouts.push_back(layout_hint);
+    }
+  }
+  return InferLayoutOutput(input_layouts, {output_layout}, Attrs());
+}
+
+// Backward pass for broadcast binary ops: reuse the common backward rule, but
+// give scalar (0-dim) tensor inputs an empty layout.
+InferLayoutOutput BackwardInferLayoutBinary(const Call& call,
+                                            const Map<String, Array<String>>& 
desired_layouts,
+                                            const VarLayoutMap& 
var_layout_map) {
+  const auto& output = BackwardInferLayoutCommon(call, desired_layouts, 
var_layout_map);
+  if (!output.defined()) {
+    return output;
+  }
+  std::vector<NLayout> input_layouts;
+  for (size_t i = 0; i < call->args.size(); i++) {
+    const auto& sinfo = GetStructInfo(call->args[i]);
+    if (const auto* t_info = sinfo.as<TensorStructInfoNode>()) {
+      if (t_info->ndim == 0) {
+        // Scalars have no layout to propagate.
+        input_layouts.push_back(LayoutDecision(""));
+      } else {
+        input_layouts.push_back(output->input_layouts[i]);
+      }
+    } else {
+      LOG(FATAL) << "Binary input should be tensor, get " << 
sinfo->GetTypeKey();
+    }
+  }
+  return InferLayoutOutput(input_layouts, output->output_layouts, Attrs());
+}
+
+// Backward pass for inplace-style ops: simply delegates to the common rule.
+InferLayoutOutput BackwardInferLayoutInplace(const Call& call,
+                                             const Map<String, Array<String>>& 
desired_layouts,
+                                             const VarLayoutMap& 
var_layout_map) {
+  return BackwardInferLayoutCommon(call, desired_layouts, var_layout_map);
+}
+
+// Backward pass for batch_norm: propagate the first output's layout back to
+// the data input; the four 1-D parameters use layout "O". The nested output
+// layout mirrors batch_norm's (data, mean, var) tuple result.
+InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call,
+                                               const Map<String, 
Array<String>>& desired_layouts,
+                                               const VarLayoutMap& 
var_layout_map) {
+  LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map, 
0);
+  if (!output_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  LayoutDecision g_layout = LayoutDecision("O");
+  return InferLayoutOutput({output_layout, g_layout, g_layout, g_layout, 
g_layout},
+                           {{output_layout, g_layout, g_layout}}, Attrs());
+}
+
+// Backward pass for expand_dims: remove the expanded axes from the output
+// layout to recover the input layout.
+InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call,
+                                                const Map<String, 
Array<String>>& desired_layouts,
+                                                const VarLayoutMap& 
var_layout_map) {
+  LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+  if (!output_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  Array<PrimExpr> empty;
+  const auto& input_shape =
+      
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+  if (input_shape.size() == 0) {
+    return InferLayoutOutput();
+  }
+  const auto* attrs = call->attrs.as<ExpandDimsAttrs>();
+  std::vector<size_t> expand_axes;
+  for (const auto& s : attrs->axis) {
+    // NOTE(review): axes are normalized against the *input* rank, yet they
+    // index the larger output layout in ReduceLayout - confirm intended.
+    expand_axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size()));
+  }
+  LayoutDecision input_layout = LayoutUtils::ReduceLayout(output_layout, 
expand_axes);
+  return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+// Backward pass for group_norm/layer_norm: the data input shares the output
+// layout; gamma/beta use the 1-D parameter layout "O".
+InferLayoutOutput BackwardInferLayoutNormalize(const Call& call,
+                                               const Map<String, 
Array<String>>& desired_layouts,
+                                               const VarLayoutMap& 
var_layout_map) {
+  LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map, 
0);
+  if (!output_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  LayoutDecision g_layout = LayoutDecision("O");
+  return InferLayoutOutput({output_layout, g_layout, g_layout}, 
{output_layout}, Attrs());
+}
+
+// Backward pass for matmul: input A shares the output layout and input B's
+// layout is derived as the output's trailing batch axes + "IO" (mirrors
+// ForwardInferLayoutMatmul).
+InferLayoutOutput BackwardInferLayoutMatmul(const Call& call,
+                                            const Map<String, Array<String>>& desired_layouts,
+                                            const VarLayoutMap& var_layout_map) {
+  LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+  if (!output_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  Array<PrimExpr> empty;
+  const auto& b_shape =
+      Downcast<TensorStructInfo>(GetStructInfo(call->args[1]))->GetShape().value_or(empty);
+  if (b_shape.size() == 0) {
+    return InferLayoutOutput();
+  }
+  if (b_shape.size() < 2 || b_shape.size() > output_layout->layout.ndim()) {
+    // Guard the size_t subtraction below against underflow when B outranks
+    // the output layout, and skip 1-D B, whose layout cannot be expressed as
+    // batch-dims + "IO".
+    return InferLayoutOutput();
+  }
+  // B's batch dims reuse the corresponding batch axes of the output layout.
+  size_t start = output_layout->layout.ndim() - b_shape.size();
+  String pre_layout;
+  for (size_t i = start; i < output_layout->layout.ndim() - 2; i++) {
+    pre_layout = pre_layout + output_layout->layout[i].name();
+  }
+  LayoutDecision b_layout = LayoutDecision(pre_layout + "IO");
+  return InferLayoutOutput({output_layout, b_layout}, {output_layout}, Attrs());
+}
+
+// Backward pass for permute_dims: apply the inverse permutation to the output
+// layout to recover the input layout.
+InferLayoutOutput BackwardInferLayoutPermute(const Call& call,
+                                             const Map<String, Array<String>>& 
desired_layouts,
+                                             const VarLayoutMap& 
var_layout_map) {
+  LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+  if (!output_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  std::vector<size_t> permute_axes;
+  const auto* attrs = call->attrs.as<PermuteDimsAttrs>();
+  if (!attrs->axes.defined()) {
+    // A full reversal is its own inverse.
+    for (size_t i = output_layout->layout.ndim(); i > 0; i--) {
+      permute_axes.push_back(i - 1);
+    }
+  } else {
+    std::vector<int> attr_axes;
+    for (const auto& s : attrs->axes.value()) {
+      attr_axes.push_back(s->value);
+    }
+    // Invert the permutation: the position of axis i in attr_axes is where it
+    // came from in the input.
+    for (size_t i = 0; i < output_layout->layout.ndim(); i++) {
+      int pos = ArrayUtils::IndexOf(attr_axes, static_cast<int>(i));
+      if (pos >= 0) {
+        permute_axes.push_back(pos);
+      } else {
+        permute_axes.push_back(i);
+      }
+    }
+  }
+  LayoutDecision input_layout = LayoutUtils::PermuteLayout(output_layout, 
permute_axes);
+  return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+// Backward pass for reduce ops: expand the (reduced) output layout back to the
+// input rank along the reduced axes.
+InferLayoutOutput BackwardInferLayoutReduceAxis(const Call& call,
+                                                const Map<String, Array<String>>& desired_layouts,
+                                                const VarLayoutMap& var_layout_map) {
+  LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+  if (!output_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  const auto* attrs = call->attrs.as<StatisticalAttrs>();
+  if (attrs->keepdims) {
+    // Rank unchanged: input and output share the layout.
+    return InferLayoutOutput({output_layout}, {output_layout}, Attrs());
+  }
+  if (!attrs->axis.defined()) {
+    // Full reduction without keepdims: the input layout cannot be recovered.
+    // Mirrors the guard in ForwardInferLayoutReduceAxis; without it the
+    // attrs->axis.value() below fails when axis is absent.
+    return InferLayoutOutput();
+  }
+  Array<PrimExpr> empty;
+  const auto& input_shape =
+      Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+  if (input_shape.size() == 0) {
+    return InferLayoutOutput();
+  }
+  std::vector<size_t> axes;
+  for (const auto& s : attrs->axis.value()) {
+    // GetIndex normalizes negative axis values against the input rank.
+    axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size()));
+  }
+  LayoutDecision input_layout = LayoutUtils::ExpandLayout(output_layout, axes);
+  return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+// Backward pass for reshape: mirror of ForwardInferLayoutReshape with the
+// expand/reduce directions swapped.
+InferLayoutOutput BackwardInferLayoutReshape(const Call& call,
+                                             const Map<String, Array<String>>& 
desired_layouts,
+                                             const VarLayoutMap& 
var_layout_map) {
+  LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+  if (!output_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  Array<PrimExpr> empty;
+  const auto& input_shape =
+      
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+  const auto& output_shape =
+      
Downcast<TensorStructInfo>(GetStructInfo(call))->GetShape().value_or(empty);
+  if (input_shape.size() == 0 || output_shape.size() == 0) {
+    return InferLayoutOutput();
+  }
+  LayoutDecision input_layout;
+  if (input_shape.size() == output_shape.size()) {
+    input_layout = output_layout;
+  } else if (input_shape.size() > output_shape.size()) {
+    // The forward reshape removed axes, so the backward pass re-expands them.
+    const auto& reduce_axes = InferReduceAxes(input_shape, output_shape);
+    if (reduce_axes.size() == 0) {
+      return InferLayoutOutput();
+    }
+    input_layout = LayoutUtils::ExpandLayout(output_layout, reduce_axes);
+  } else {
+    // The forward reshape inserted axes, so the backward pass drops them.
+    const auto& expand_axes = InferExpandAxes(input_shape, output_shape);
+    if (expand_axes.size() == 0) {
+      return InferLayoutOutput();
+    }
+    input_layout = LayoutUtils::ReduceLayout(output_layout, expand_axes);
+  }
+  // Second argument is the target shape; tag it with the 1-D layout "O".
+  return InferLayoutOutput({input_layout, LayoutDecision("O")}, 
{output_layout}, Attrs());
+}
+
+// Backward pass for squeeze: re-insert the squeezed unit axes into the output
+// layout to recover the input layout.
+InferLayoutOutput BackwardInferLayoutSqueeze(const Call& call,
+                                             const Map<String, Array<String>>& 
desired_layouts,
+                                             const VarLayoutMap& 
var_layout_map) {
+  LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+  if (!output_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  Array<PrimExpr> empty;
+  const auto& input_shape =
+      
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+  if (input_shape.size() == 0) {
+    return InferLayoutOutput();
+  }
+  const auto* attrs = call->attrs.as<SqueezeAttrs>();
+  std::vector<size_t> reduce_axes;
+  if (attrs->axis.defined()) {
+    for (const auto& s : attrs->axis.value()) {
+      size_t v_index = CommonUtils::GetIndex(s->value, input_shape.size());
+      // Only unit extents were actually removed by squeeze.
+      if (Downcast<Integer>(input_shape[v_index])->value == 1) {
+        reduce_axes.push_back(v_index);
+      }
+    }
+  } else {
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      if (Downcast<Integer>(input_shape[i])->value == 1) {
+        reduce_axes.push_back(i);
+      }
+    }
+  }
+  LayoutDecision input_layout = LayoutUtils::ExpandLayout(output_layout, 
reduce_axes);
+  return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+// Backward pass for take: drop the leading gathered axis from the output
+// layout to recover args[1]'s layout; args[0] keeps the "WE" tag (mirrors
+// ForwardInferLayoutTake).
+InferLayoutOutput BackwardInferLayoutTake(const Call& call,
+                                          const Map<String, Array<String>>& 
desired_layouts,
+                                          const VarLayoutMap& var_layout_map) {
+  LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+  if (!output_layout->layout.defined()) {
+    return InferLayoutOutput();
+  }
+  LayoutDecision input_layout = LayoutUtils::ReduceLayout(output_layout, 
std::vector<size_t>{0});
+  return InferLayoutOutput({LayoutDecision("WE"), input_layout}, 
{output_layout}, Attrs());
+}
+
+// Register the backward ("FMSCBackwardInferLayout") rules, grouped by op kind
+// in the same order as the forward registrations above.
+// conv and pool ops
+TVM_REGISTER_OP("relax.nn.conv1d")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
MSCInferLayoutConv);
+TVM_REGISTER_OP("relax.nn.conv2d")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
MSCInferLayoutConv);
+TVM_REGISTER_OP("relax.nn.max_pool2d")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
MSCInferLayoutPool2d);
+TVM_REGISTER_OP("relax.nn.avg_pool2d")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
MSCInferLayoutPool2d);
+TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
MSCInferLayoutPool2d);
+// reduce axis ops
+TVM_REGISTER_OP("relax.argmax")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.argmin")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.max")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.min")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.mean")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.sum")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.prod")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutReduceAxis);
+TVM_REGISTER_OP("relax.std")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutReduceAxis);
+// binary ops
+TVM_REGISTER_OP("relax.add")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.divide")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.floor_divide")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.multiply")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.power")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.subtract")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.equal")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.greater")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.greater_equal")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.less")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.less_equal")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.not_equal")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.maximum")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.minimum")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.logical_and")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.logical_or")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.logical_xor")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.bitwise_and")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.bitwise_or")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+TVM_REGISTER_OP("relax.bitwise_xor")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBinary);
+// math ops
+TVM_REGISTER_OP("relax.expand_dims")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutExpandDims);
+TVM_REGISTER_OP("relax.matmul")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutMatmul);
+TVM_REGISTER_OP("relax.permute_dims")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutPermute);
+TVM_REGISTER_OP("relax.reshape")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutReshape);
+TVM_REGISTER_OP("relax.squeeze")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutSqueeze);
+TVM_REGISTER_OP("relax.take")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutTake);
+// nn ops
+TVM_REGISTER_OP("relax.nn.batch_norm")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutBatchNorm);
+TVM_REGISTER_OP("relax.nn.group_norm")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutNormalize);
+TVM_REGISTER_OP("relax.nn.layer_norm")
+    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", 
BackwardInferLayoutNormalize);
+
+class LayoutInfer : public ExprVisitor {
+ public:
+  explicit LayoutInfer(const IRModule& ref_module) : ref_module_(ref_module) { 
Reset(); }
+
+  void Reset() {
+    infered_ = false;
+    var_map_.clear();
+    ordered_exprs_.clear();
+  }
+
+  void RecordExpr(const Var& var, const Expr& expr) {
+    var_map_.Set(var, expr);
+    ordered_exprs_.push_back(expr);
+  }
+
+  Expr Infer(const Expr& expr) {
+    Reset();
+    ForwardInfer(expr);
+    BackwardInfer();
+    return expr;
+  }
+
+  void ForwardInfer(const Expr& expr) { ExprVisitor::VisitExpr(expr); }
+
+  void BackwardInfer() {
+    for (size_t e_idx = ordered_exprs_.size(); e_idx > 0; e_idx--) {
+      const Expr& expr = ordered_exprs_[e_idx - 1];
+      if (expr->IsInstance<TupleNode>()) {
+        continue;
+      }
+      if (expr->IsInstance<TupleGetItemNode>()) {
+        continue;
+      }
+      if (!expr->IsInstance<CallNode>()) {
+        continue;
+      }
+      const Call& call = Downcast<Call>(expr);
+      size_t infered_num = 0;
+      for (const auto& arg : call->args) {
+        if (arg->IsInstance<VarNode>() && var_map_.count(Downcast<Var>(arg))) {
+          if (LayoutUtils::LayoutInfered(var_map_[Downcast<Var>(arg)]) > 0) {
+            infered_num++;
+          }
+        } else if (LayoutUtils::LayoutInfered(arg)) {
+          infered_num++;
+        }
+      }
+      if (call->args.size() == 0 || infered_num == call->args.size() ||
+          !call->op->IsInstance<OpNode>() || 
LayoutUtils::HasUnknownDimTensor(call->args)) {
+        continue;
+      }
+      const OpNode* op_node = call->op.as<OpNode>();
+      if (op_node == nullptr) {
+        continue;
+      }
+      // Infer by op_node
+      Op op = Downcast<Op>(GetRef<Op>(op_node));
+      InferLayoutOutput infered_layout;
+      const auto msc_infer_map = 
Op::GetAttrMap<FRelaxInferLayout>("FMSCBackwardInferLayout");
+      try {
+        if (msc_infer_map.count(op)) {
+          FRelaxInferLayout f = msc_infer_map[op];
+          infered_layout = f(call, Map<String, Array<String>>(), 
var_layout_map_);
+        } else {
+          infered_layout =
+              BackwardInferLayoutCommon(call, Map<String, Array<String>>(), 
var_layout_map_);
+        }
+      } catch (runtime::InternalError& err) {
+        LOG(WARNING) << "Failed to backward infer layout " << expr << " : " << 
err.message();
+        infered_layout = InferLayoutOutput();
+      }
+      try {
+        if (infered_layout.defined()) {
+          SetInputLayouts(infered_layout->input_layouts, call);
+        }
+      } catch (runtime::InternalError& err) {
+        LOG(WARNING) << "Failed to backward set inputs layout for " << call << 
" : "
+                     << err.message();
+      }
+    }
+  }
+
+  void SetInputLayouts(const Array<NLayout>& input_layouts, const Call& call) {
+    if (input_layouts.size() == call->args.size()) {
+      for (size_t i = 0; i < input_layouts.size(); i++) {
+        if (call->args[i]->IsInstance<VarNode>()) {
+          const auto& var = Downcast<Var>(call->args[i]);
+          var_layout_map_[var] = input_layouts[i];
+          if (var_map_.count(var)) {
+            if (LayoutUtils::SetLayout(var_map_[var], input_layouts[i])) {
+              infered_ = true;
+            }
+          } else if (LayoutUtils::SetLayout(var, input_layouts[i])) {
+            infered_ = true;
+          }
+        } else if (LayoutUtils::SetLayout(call->args[i], input_layouts[i])) {
+          infered_ = true;
+        }
+      }
+    }
+  }
+
+  // Forward layout inference for a call binding. Global-var callees recurse
+  // into the referenced function; Op callees dispatch to FMSCForwardInferLayout,
+  // the common forward infer, or FRelaxInferLayout, in that preference order.
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) 
final {
+    ExprVisitor::VisitBinding_(binding, call_node);
+    const auto& call = GetRef<Call>(call_node);
+    if (const auto* v_node = call->op.as<GlobalVarNode>()) {
+      // infer global func and set var layouts
+      const auto& func = 
Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
+      Infer(func);
+      // propagate the infered param layouts back onto the call arguments
+      for (size_t i = 0; i < func->params.size(); i++) {
+        if (var_layout_map_.count(func->params[i]) &&
+            LayoutUtils::SetLayout(call->args[i], 
var_layout_map_[func->params[i]])) {
+          infered_ = true;
+        }
+      }
+      // the binding var takes the layout of the function body's result
+      if (const auto* b_node = func->body.as<relax::SeqExprNode>()) {
+        var_layout_map_[binding->var] = GetNLayout(var_layout_map_, 
b_node->body);
+        if (LayoutUtils::SetLayout(call, var_layout_map_[binding->var])) {
+          infered_ = true;
+        }
+      } else {
+        LOG(FATAL) << "Function body should be SeqExpr, get " << func->body;
+      }
+    } else {
+      // infer call
+      bool infer_outputs = true;
+      RecordExpr(binding->var, call);
+      // skip calls whose layout is already attached
+      if (LayoutUtils::LayoutInfered(call)) {
+        infer_outputs = false;
+      }
+      // skip no-arg calls, non-Op callees and calls with unknown-dim tensors
+      if (call->args.size() == 0 || !call->op->IsInstance<OpNode>() ||
+          LayoutUtils::HasUnknownDimTensor(call->args)) {
+        infer_outputs = false;
+      }
+      const OpNode* op_node = call->op.as<OpNode>();
+      if (op_node == nullptr) {
+        infer_outputs = false;
+      }
+      if (infer_outputs) {
+        // infer layouts
+        Op op = Downcast<Op>(GetRef<Op>(op_node));
+        InferLayoutOutput infered_layout;
+        const auto msc_infer_map = 
Op::GetAttrMap<FRelaxInferLayout>("FMSCForwardInferLayout");
+        const auto relax_infer_map = 
Op::GetAttrMap<FRelaxInferLayout>("FRelaxInferLayout");
+        // set_inputs stays true unless the relax infer map is used; relax
+        // infered layouts are not pushed back onto the inputs
+        bool set_inputs = true;
+        try {
+          if (msc_infer_map.count(op)) {
+            FRelaxInferLayout f = msc_infer_map[op];
+            infered_layout = f(call, Map<String, Array<String>>(), 
var_layout_map_);
+          } else if (!relax_infer_map.count(op)) {
+            infered_layout =
+                ForwardInferLayoutCommon(call, Map<String, Array<String>>(), 
var_layout_map_);
+          }
+          // fall back to the relax infer map when nothing was infered yet
+          if (relax_infer_map.count(op) && !infered_layout.defined()) {
+            FRelaxInferLayout f = relax_infer_map[op];
+            infered_layout = f(call, Map<String, Array<String>>(), 
var_layout_map_);
+            set_inputs = false;
+          }
+        } catch (runtime::InternalError& err) {
+          // inference failures are tolerated; the layout simply stays unset
+          LOG(WARNING) << "Failed to forward infer layout for " << 
binding->var << " : "
+                       << binding->value << ", reason: " << err.message();
+          infered_layout = InferLayoutOutput();
+        }
+        // only single-output layouts are attached to the binding var
+        if (infered_layout.defined() && infered_layout->output_layouts.size() 
== 1) {
+          try {
+            var_layout_map_[binding->var] = infered_layout->output_layouts[0];
+            if (LayoutUtils::SetLayout(call, var_layout_map_[binding->var])) {
+              infered_ = true;
+            }
+          } catch (runtime::InternalError& err) {
+            LOG(WARNING) << "Failed to forward set output layout for " << 
binding->var << " : "
+                         << binding->value << ", reason: " << err.message();
+          }
+        }
+        if (set_inputs && infered_layout.defined()) {
+          try {
+            SetInputLayouts(infered_layout->input_layouts, call);
+          } catch (runtime::InternalError& err) {
+            LOG(WARNING) << "Failed to forward set inputs layout for " << call 
<< " : "
+                         << err.message();
+          }
+        }
+      }
+    }
+  }
+
+  // Group the field layouts of a tuple binding into one nested layout.
+  void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) 
final {
+    ExprVisitor::VisitBinding_(binding, val);
+    std::vector<NLayout> input_layout;
+    for (const auto& field : val->fields) {
+      if (binding->var->IsInstance<DataflowVarNode>()) {
+        // Df var: Use the current realized layout to group the tuple;
+        input_layout.push_back(GetNLayout(var_layout_map_, field));
+      } else {
+        // Global var: Use the initial layout to group the tuple;
+        input_layout.push_back(InitialNLayout(field));
+      }
+    }
+    // only nested-tensor vars carry a nested layout in the map
+    if (IsNestedTensor(binding->var)) {
+      var_layout_map_[binding->var] = input_layout;
+    }
+    RecordExpr(binding->var, GetRef<Tuple>(val));
+  }
+
+  // A tuple-get-item inherits the layout of the selected field from the
+  // tuple's nested layout.
+  void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* 
val) final {
+    ExprVisitor::VisitBinding_(binding, val);
+    NLayout input_layout = binding->var->IsInstance<DataflowVarNode>()
+                               ? GetNLayout(var_layout_map_, val->tuple)
+                               : InitialNLayout(val->tuple);
+    var_layout_map_[binding->var] = input_layout.NestedArray()[val->index];
+    RecordExpr(binding->var, GetRef<TupleGetItem>(val));
+  }
+
+  // Shape exprs are bound to the fixed LayoutDecision("O").
+  void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) 
final {
+    ExprVisitor::VisitBinding_(binding, val);
+    const NLayout& out_layout = LayoutDecision("O");
+    var_layout_map_[binding->var] = out_layout;
+    if (LayoutUtils::SetLayout(GetRef<ShapeExpr>(val), out_layout)) {
+      infered_ = true;
+    }
+  }
+
+  /*! \brief Whether at least one layout was set during the last Infer run. */
+  bool infered() { return infered_; }
+
+ private:
+  /*! \brief Module used to look up global functions referenced by calls. */
+  IRModule ref_module_;
+  /*! \brief True once any SetLayout call succeeded. */
+  bool infered_;
+  /*! \brief Binding var -> bound expr; presumably filled by RecordExpr (definition not in view). */
+  Map<Var, Expr> var_map_;
+  /*! \brief Exprs in visit order; maintained elsewhere in this class. */
+  Array<Expr> ordered_exprs_;
+  /*! \brief Var -> infered nested layout. */
+  std::unordered_map<Var, NLayout, ObjectPtrHash, ObjectPtrEqual> 
var_layout_map_;
+};  // class LayoutInfer
+
+/*!
+ * \brief Visitor that verifies every call and constant carries a layout.
+ */
+class LayoutChecker : public ExprVisitor {
+ public:
+  LayoutChecker() : missing_num_(0) {}
+
+  /*! \brief Visit the expr and fail when any layout is still missing. */
+  void Check(const Expr& expr) {
+    ExprVisitor::VisitExpr(expr);
+    ICHECK_EQ(missing_num_, 0) << "Some layout is missing";
+  }
+
+  void VisitExpr_(const CallNode* call) final {
+    ExprVisitor::VisitExpr_(call);
+    CountMissing(GetRef<Call>(call));
+  }
+
+  void VisitExpr_(const ConstantNode* cn) final {
+    ExprVisitor::VisitExpr_(cn);
+    CountMissing(GetRef<Constant>(cn));
+  }
+
+ private:
+  /*! \brief Bump the counter when the expr has no layout attached. */
+  void CountMissing(const Expr& expr) {
+    if (!LayoutUtils::LayoutInfered(expr)) {
+      missing_num_ = missing_num_ + 1;
+    }
+  }
+
+  /*! \brief Number of exprs found without a layout. */
+  size_t missing_num_;
+};  // class LayoutChecker
+
+/*!
+ * \brief Infer and attach layouts for all exprs reachable from func.
+ * When allow_missing is false, fail if any call/constant stays without layout.
+ */
+void SetExprLayout(const IRModule& ref_module, const Expr& func, bool 
allow_missing) {
+  LayoutInfer infer_visitor(ref_module);
+  const auto& infered_func = infer_visitor.Infer(func);
+  if (!allow_missing) {
+    LayoutChecker().Check(infered_func);
+  }
+}
+
+namespace transform {
+
+// Module pass wrapper: run layout inference over the entry function.
+// The module is mutated in place (spans updated) and returned unchanged.
+Pass SetExprLayout(bool allow_missing, const String& entry_name) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = 
[=](IRModule m,
+                                                                            
PassContext pc) {
+    relax::SetExprLayout(m, m->Lookup(entry_name), allow_missing);
+    return m;
+  };
+  return CreateModulePass(pass_func, 0, "SetExprLayout", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.SetExprLayout").set_body_typed(SetExprLayout);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/contrib/msc/core/transform/set_expr_name.cc 
b/src/contrib/msc/core/transform/set_expr_name.cc
new file mode 100644
index 0000000000..5b39a5a7ac
--- /dev/null
+++ b/src/contrib/msc/core/transform/set_expr_name.cc
@@ -0,0 +1,348 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/set_expr_name.cc
+ * \brief Pass for setting name for call and constant.
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "../utils.h"
+
+namespace tvm {
+using namespace tvm::contrib::msc;
+
+namespace relax {
+
+/*!
+ * \brief Name setter for Relax
+ */
+class RelaxExprNameSetter : public ExprVisitor {
+ public:
+  explicit RelaxExprNameSetter(const IRModule& ref_module) : 
ref_module_(ref_module) {}
+
+  // Give each binding block a unique, dot-joined hierarchical name and store
+  // it in the block's span.
+  void VisitBindingBlock(const BindingBlock& block) final {
+    String block_name = SpanUtils::GetAttr(block->span, "name");
+    if (block_name.size() == 0) {
+      block_name = "block";
+    }
+    // disambiguate repeated block names with a _<cnt> suffix
+    if (setted_blocks_.count(block_name)) {
+      int cnt = 1;
+      while (setted_blocks_.count(block_name + "_" + std::to_string(cnt))) {
+        cnt++;
+      }
+      block_name = block_name + "_" + std::to_string(cnt);
+    }
+    setted_blocks_.insert(block_name);
+    block_stack_.push_back(block_name);
+    const String& unique_name = StringUtils::Join(block_stack_, ".");
+    block->span = SpanUtils::SetAttr(block->span, "name", unique_name);
+    ExprVisitor::VisitBindingBlock(block);
+    block_stack_.pop_back();
+  }
+
+  // Name free-standing constants "const<N>".
+  void VisitExpr_(const ConstantNode* val) {
+    ExprVisitor::VisitExpr_(val);
+    const String& unique_name = GetUniqueName(GetRef<Constant>(val), "const");
+    if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+      val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+    }
+    expr_names_.Set(GetRef<Constant>(val), unique_name);
+  }
+
+  // Name bound constants and map the binding var to the name.
+  void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) {
+    ExprVisitor::VisitBinding_(binding, val);
+    const String& unique_name = GetUniqueName(GetRef<Constant>(val), "const");
+    if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+      val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+    }
+    expr_names_.Set(binding->var, unique_name);
+  }
+
+  // Name bound shape exprs.
+  void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) {
+    ExprVisitor::VisitBinding_(binding, val);
+    const String& unique_name = GetUniqueName(GetRef<ShapeExpr>(val), "shape");
+    if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+      val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+    }
+    expr_names_.Set(binding->var, unique_name);
+  }
+
+  // Name bound tuples.
+  void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) {
+    ExprVisitor::VisitBinding_(binding, val);
+    const String& unique_name = GetUniqueName(GetRef<Tuple>(val), "tuple");
+    if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+      val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+    }
+    expr_names_.Set(binding->var, unique_name);
+  }
+
+  // Tuple-get-items are named <tuple_name>.<index>.
+  void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* 
val) {
+    ExprVisitor::VisitBinding_(binding, val);
+    ICHECK(expr_names_.count(val->tuple)) << "Can not find tuple of " << 
GetRef<TupleGetItem>(val);
+    const String& unique_name = expr_names_[val->tuple] + "." + 
std::to_string(val->index);
+    if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+      val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+    }
+    expr_names_.Set(binding->var, unique_name);
+  }
+
+  // Name calls after the last segment of the op name (or the composite name
+  // for global funcs); also record the first consumer of each constant as its
+  // "master" for weight sharing.
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* val) {
+    ExprVisitor::VisitBinding_(binding, val);
+    String name_hint, optype;
+    if (const auto* op_node = val->op.as<OpNode>()) {
+      const std::string& op_name = op_node->name;
+      int rpos = op_name.rfind(".");
+      name_hint = op_name.substr(rpos + 1);
+      optype = StringUtils::Replace(op_node->name, "relax.", "");
+    } else if (const auto* v_node = val->op.as<GlobalVarNode>()) {
+      const auto& func = 
Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
+      ExprVisitor::VisitExpr(func);
+      const auto& name_opt = func->GetAttr<runtime::String>(attr::kComposite);
+      ICHECK(name_opt.defined()) << "Unexpected global func without composite";
+      name_hint = name_opt.value();
+      optype = name_hint;
+    }
+    // set name
+    const String& unique_name = GetUniqueName(GetRef<Expr>(val), name_hint);
+    if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+      val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+    }
+    // set constant consumer && master
+    Array<String> input_types;
+    try {
+      input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true);
+    } catch (runtime::InternalError& err) {
+      LOG(WARNING) << "Failed to GetInputTypes for " << GetRef<Call>(val) << " 
: " << err.message();
+      throw err;
+    }
+    for (size_t i = 0; i < input_types.size(); i++) {
+      if (input_types[i] == "input") {
+        continue;
+      }
+      // non-input (weight-like) constants: the first consumer becomes master
+      if (const auto* c_node = val->args[i].as<ConstantNode>()) {
+        const String& const_name = SpanUtils::GetAttr(c_node->span, "name");
+        if (constant_consumers_.count(const_name)) {
+          val->span = SpanUtils::SetAttr(val->span, "master", 
constant_consumers_[const_name]);
+        } else {
+          constant_consumers_.Set(const_name, unique_name);
+        }
+      }
+    }
+    expr_names_.Set(binding->var, unique_name);
+  }
+
+ private:
+  // Return a name unique among all setted names, reusing the span name or the
+  // hint and appending _<cnt> on collision with a different expr.
+  const String GetUniqueName(const Expr& expr, const String& name_hint) {
+    String expr_name = SpanUtils::GetAttr(expr->span, "name");
+    if (expr_name.size() == 0) {
+      expr_name = name_hint;
+    }
+    if (!setted_names_.count(expr_name)) {
+      setted_names_.Set(expr_name, expr);
+      return expr_name;
+    }
+    // same expr visited again keeps its name
+    if (setted_names_[expr_name] == expr) {
+      return expr_name;
+    }
+    int cnt = 1;
+    while (setted_names_.count(expr_name + "_" + std::to_string(cnt)) &&
+           setted_names_[expr_name + "_" + std::to_string(cnt)] != expr) {
+      cnt++;
+    }
+    expr_name = expr_name + "_" + std::to_string(cnt);
+    if (!setted_names_.count(expr_name)) {
+      setted_names_.Set(expr_name, expr);
+    }
+    return expr_name;
+  }
+
+  // name -> expr that owns it
+  Map<String, Expr> setted_names_;
+  // constant name -> name of its first (master) consumer
+  Map<String, String> constant_consumers_;
+  // already-used block names
+  std::set<String> setted_blocks_;
+  // current nesting of block names, joined with "."
+  Array<String> block_stack_;
+  // binding var / expr -> assigned unique name
+  Map<Expr, String> expr_names_;
+  // module used to resolve global funcs
+  IRModule ref_module_;
+};  // class RelaxExprNameSetter
+
+/*! \brief Attach unique names to the spans of all exprs reachable from e. */
+void SetRelaxExprName(const IRModule& ref_module, const Expr& e) {
+  RelaxExprNameSetter setter(ref_module);
+  setter.VisitExpr(e);
+}
+
+namespace transform {
+
+// Module pass wrapper: set unique names for exprs of the entry function.
+// The module is mutated in place (spans updated) and returned unchanged.
+Pass SetRelaxExprName(const String& entry_name) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = 
[=](IRModule m,
+                                                                            
PassContext pc) {
+    relax::SetRelaxExprName(m, m->Lookup(entry_name));
+    return m;
+  };
+  return CreateModulePass(pass_func, 0, "SetRelaxExprName", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.SetRelaxExprName").set_body_typed(SetRelaxExprName);
+
+}  // namespace transform
+}  // namespace relax
+
+namespace relay {
+
+/*!
+ * \brief Name setter for Relay
+ */
+class RelayExprNameSetter : public ExprVisitor {
+ public:
+  explicit RelayExprNameSetter(const IRModule& ref_module) : 
ref_module_(ref_module) {}
+
+  // Name constants "const<N>".
+  void VisitExpr_(const ConstantNode* op) final {
+    ExprVisitor::VisitExpr_(op);
+    const String& unique_name = GetUniqueName(GetRef<Constant>(op), "const");
+    if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
+      op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
+    }
+  }
+
+  // Name tuples "tuple<N>".
+  void VisitExpr_(const TupleNode* op) final {
+    ExprVisitor::VisitExpr_(op);
+    const String& unique_name = GetUniqueName(GetRef<Tuple>(op), "tuple");
+    if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
+      op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
+    }
+  }
+
+  // Tuple-get-items are named <tuple_name>.<index>.
+  void VisitExpr_(const TupleGetItemNode* op) final {
+    ExprVisitor::VisitExpr_(op);
+    const String& tuple_name = SpanUtils::GetAttr(op->tuple->span, "name");
+    const String& unique_name = tuple_name + "." + std::to_string(op->index);
+    if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
+      op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
+    }
+  }
+
+  // Functions are named after their composite attr, or "func".
+  void VisitExpr_(const FunctionNode* op) final {
+    ExprVisitor::VisitExpr_(op);
+    const auto& name_opt = op->GetAttr<runtime::String>(attr::kComposite);
+    const String& name_hint = name_opt.defined() ? name_opt.value() : "func";
+    const String& unique_name = GetUniqueName(GetRef<Function>(op), name_hint);
+    if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
+      op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
+    }
+  }
+
+  // Name calls after the last segment of the op name (or the composite name
+  // for global funcs); also record the first consumer of each constant as its
+  // "master" for weight sharing.
+  void VisitExpr_(const CallNode* op) final {
+    ExprVisitor::VisitExpr_(op);
+    String name_hint, optype;
+    if (const auto* op_node = op->op.as<OpNode>()) {
+      const std::string& op_name = op_node->name;
+      int rpos = op_name.rfind(".");
+      name_hint = op_name.substr(rpos + 1);
+      optype = StringUtils::Replace(op_node->name, "relay.", "");
+    } else if (const auto* v_node = op->op.as<GlobalVarNode>()) {
+      const auto& func = 
Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
+      ExprVisitor::VisitExpr(func);
+      const auto& name_opt = func->GetAttr<runtime::String>(attr::kComposite);
+      ICHECK(name_opt.defined()) << "Unexpected global func without composite";
+      optype = name_opt.value();
+      name_hint = optype;
+    }
+    // set name
+    const String& unique_name = GetUniqueName(GetRef<Expr>(op), name_hint);
+    if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
+      op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
+    }
+    // set constant consumer && master
+    Array<String> input_types;
+    try {
+      input_types = ExprUtils::GetInputTypes(optype, op->args.size(), false);
+    } catch (runtime::InternalError& err) {
+      LOG(WARNING) << "Failed to GetInputTypes for " << GetRef<Call>(op) << " 
: " << err.message();
+      throw err;
+    }
+    for (size_t i = 0; i < input_types.size(); i++) {
+      if (input_types[i] == "input") {
+        continue;
+      }
+      // non-input (weight-like) constants: the first consumer becomes master
+      if (const auto* c_node = op->args[i].as<ConstantNode>()) {
+        const String& const_name = SpanUtils::GetAttr(c_node->span, "name");
+        if (constant_consumers_.count(const_name)) {
+          op->span = SpanUtils::SetAttr(op->span, "master", 
constant_consumers_[const_name]);
+        } else {
+          constant_consumers_.Set(const_name, unique_name);
+        }
+      }
+    }
+  }
+
+ private:
+  // Return a name unique among all setted names, reusing the span name or the
+  // hint and appending _<cnt> on collision with a different expr.
+  const String GetUniqueName(const Expr& expr, const String& name_hint) {
+    String expr_name = SpanUtils::GetAttr(expr->span, "name");
+    if (expr_name.size() == 0) {
+      expr_name = name_hint;
+    }
+    if (!setted_names_.count(expr_name)) {
+      setted_names_.Set(expr_name, expr);
+      return expr_name;
+    }
+    // same expr visited again keeps its name
+    if (setted_names_[expr_name] == expr) {
+      return expr_name;
+    }
+    int cnt = 1;
+    while (setted_names_.count(expr_name + "_" + std::to_string(cnt)) &&
+           setted_names_[expr_name + "_" + std::to_string(cnt)] != expr) {
+      cnt++;
+    }
+    expr_name = expr_name + "_" + std::to_string(cnt);
+    if (!setted_names_.count(expr_name)) {
+      setted_names_.Set(expr_name, expr);
+    }
+    return expr_name;
+  }
+
+  // name -> expr that owns it
+  Map<String, Expr> setted_names_;
+  // constant name -> name of its first (master) consumer
+  Map<String, String> constant_consumers_;
+  // module used to resolve global funcs
+  IRModule ref_module_;
+};  // class RelayExprNameSetter
+
+/*! \brief Attach unique names to the spans of all exprs reachable from e. */
+void SetRelayExprName(const IRModule& ref_module, const Expr& e) {
+  RelayExprNameSetter setter(ref_module);
+  setter.VisitExpr(e);
+}
+
+namespace transform {
+
+// Module pass wrapper: set unique names for exprs of the entry function.
+// The module is mutated in place (spans updated) and returned unchanged.
+Pass SetRelayExprName(const String& entry_name) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = 
[=](IRModule m,
+                                                                            
PassContext pc) {
+    relay::SetRelayExprName(m, m->Lookup(entry_name));
+    return m;
+  };
+  return CreateModulePass(pass_func, 0, "SetRelayExprName", {});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.SetRelayExprName").set_body_typed(SetRelayExprName);
+
+}  // namespace transform
+}  // namespace relay
+
+}  // namespace tvm
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
new file mode 100644
index 0000000000..7ecff876f2
--- /dev/null
+++ b/src/contrib/msc/core/utils.cc
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/utils.cc
+ */
+
+#include "utils.h"
+
+#include <string>
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+size_t CommonUtils::GetIndex(int index, size_t max_size) {
+  // Negative indices count from the end, python style; the result must
+  // land inside [0, max_size).
+  const size_t v_index = index < 0 ? index + max_size : 
static_cast<size_t>(index);
+  ICHECK_LT(v_index, max_size) << "Index " << index << " out of range " << 
max_size;
+  return v_index;
+}
+
+const Array<String> StringUtils::Split(const String& src_string, const 
String& sep) {
+  // Cut src_string on every sep, dropping empty pieces (leading or
+  // consecutive separators produce no element). sep is expected non-empty.
+  Array<String> pieces;
+  if (src_string.size() == 0) {
+    return pieces;
+  }
+  std::string remain = src_string;
+  const std::string& delim = sep;
+  // find() yields npos when missing; the int narrowing turns that into -1
+  // and ends the loop
+  for (int hit = remain.find(delim); hit >= 0; hit = remain.find(delim)) {
+    if (hit > 0) {
+      pieces.push_back(remain.substr(0, hit));
+    }
+    remain = remain.substr(hit + delim.size());
+  }
+  if (remain.size() > 0) {
+    pieces.push_back(remain);
+  }
+  return pieces;
+}
+
+const String StringUtils::Join(const Array<String>& sub_strings, const 
String& joint) {
+  // Concatenate the sub strings, inserting joint between neighbors only.
+  String result = "";
+  const size_t num = sub_strings.size();
+  for (size_t i = 0; i < num; i++) {
+    result = result + sub_strings[i];
+    if (i + 1 < num) {
+      result = result + joint;
+    }
+  }
+  return result;
+}
+
+const String StringUtils::Replace(const String& src_string, const String& 
old_str,
+                                  const String& new_str) {
+  // Replace every occurrence of old_str in src_string with new_str.
+  // BUGFIX: the previous Split-based implementation dropped the replacement
+  // when old_str occurred at the end of src_string (the trailing empty piece
+  // was discarded), collapsed adjacent occurrences into one, and looped
+  // forever on an empty old_str. A direct find/replace scan fixes all three.
+  const std::string& target = old_str;
+  const std::string& repl = new_str;
+  if (target.size() == 0) {
+    return src_string;
+  }
+  std::string result = src_string;
+  size_t pos = result.find(target);
+  while (pos != std::string::npos) {
+    result.replace(pos, target.size(), repl);
+    // continue after the inserted text so the replacement is never rescanned
+    pos = result.find(target, pos + repl.size());
+  }
+  return result;
+}
+
+// Split src_string once on sep, from the left or the right, returning the
+// (before, after) pair. When sep is absent, the whole string comes back as
+// the first element and the second is empty.
+const std::tuple<String, String> StringUtils::SplitOnce(const String& 
src_string, const String& sep,
+                                                        bool from_left) {
+  if (src_string.size() == 0) {
+    return std::make_tuple(String(), String());
+  }
+  std::string src_cstring = src_string;
+  const std::string& csep = sep;
+  // NOTE(review): relies on size_t npos narrowing to -1 when stored in int;
+  // implementation-defined but assumed throughout this file
+  int pos = from_left ? src_cstring.find(csep) : src_cstring.rfind(csep);
+  if (pos >= 0) {
+    return std::make_tuple(src_cstring.substr(0, pos), src_cstring.substr(pos 
+ csep.size()));
+  }
+  return std::make_tuple(src_string, String());
+}
+
+// Extract every substring enclosed between left and right markers, e.g.
+// GetClosures("<a>x</a><b>y</b>", "<", ">") collects the tag contents.
+const Array<String> StringUtils::GetClosures(const String& src_string, const 
String& left,
+                                             const String& right) {
+  Array<String> tokens;
+  if (src_string.size() == 0) {
+    return tokens;
+  }
+  // "start" is just a non-empty seed so the loop runs at least once
+  String token = "start";
+  String left_str = src_string;
+  while (token.size() > 0) {
+    // drop everything up to the next left marker, then take up to right
+    std::tie(token, left_str) = StringUtils::SplitOnce(left_str, left);
+    if (left_str.size() > 0) {
+      std::tie(token, left_str) = StringUtils::SplitOnce(left_str, right);
+    } else {
+      token = "";
+    }
+    if (token.size() > 0) {
+      tokens.push_back(token);
+    }
+  }
+  return tokens;
+}
+
+// Extract the first (or last, when from_left is false) substring enclosed
+// between left and right; empty when no closure is found.
+const String StringUtils::GetClosureOnce(const String& src_string, const 
String& left,
+                                         const String& right, bool from_left) {
+  if (src_string.size() == 0) {
+    return "";
+  }
+  // take what follows the left marker, then cut at the right marker
+  String val = std::get<1>(SplitOnce(src_string, left, from_left));
+  if (val.size() > 0) {
+    val = std::get<0>(StringUtils::SplitOnce(val, right, from_left));
+  }
+  return val;
+}
+
+// Render an ObjectRef as a String: plain value for scalars/strings,
+// comma-joined elements for arrays, otherwise the object's repr.
+const String StringUtils::ToString(const runtime::ObjectRef& obj) {
+  String obj_string;
+  if (!obj.defined()) {
+    obj_string = "";
+  } else if (obj.as<StringObj>()) {
+    obj_string = Downcast<String>(obj);
+  } else if (const auto* n = obj.as<IntImmNode>()) {
+    obj_string = std::to_string(n->value);
+  } else if (const auto* n = obj.as<FloatImmNode>()) {
+    obj_string = std::to_string(n->value);
+  } else if (const auto* n = obj.as<ArrayNode>()) {
+    for (size_t i = 0; i < n->size(); i++) {
+      obj_string = obj_string + ToString((*n)[i]);
+      // NOTE(review): a single-element array gets a trailing comma ("x,") —
+      // presumably to keep it distinguishable from a scalar; confirm intent
+      if (n->size() == 1 || i < n->size() - 1) {
+        obj_string = obj_string + ",";
+      }
+    }
+  } else {
+    // fall back to the generic object printer
+    std::ostringstream obj_des;
+    obj_des << obj;
+    obj_string = obj_des.str();
+  }
+  return obj_string;
+}
+
+// Compare the first `size` elements of two arrays; size == -1 means compare
+// the arrays in full (and requires equal lengths).
+bool StringUtils::CompareArrays(const Array<String>& left, const 
Array<String>& right, int size) {
+  if (left.size() == right.size() && left.size() == 0) {
+    return true;
+  }
+  if (size == -1 && left.size() != right.size()) {
+    return false;
+  }
+  if (left.size() == 0 || right.size() == 0) {
+    return false;
+  }
+  // BUGFIX: size was previously overwritten with left.size() unconditionally,
+  // which silently ignored a caller-supplied size and made every partial
+  // comparison a full comparison; only default it for the size == -1 case.
+  if (size == -1) {
+    size = left.size();
+  }
+  ICHECK_GT(size, 0) << "Positive size should be given, get " << size;
+  if (size > static_cast<int>(left.size()) || size > 
static_cast<int>(right.size())) {
+    return false;
+  }
+  for (size_t i = 0; i < static_cast<size_t>(size); i++) {
+    if (left[i] != right[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Store `value` as <key>value</key> inside the span's source name, replacing
+// any previous value for the same key. Empty values leave the span untouched.
+const Span SpanUtils::SetAttr(const Span& span, const String& key, const 
String& value) {
+  if (value.size() == 0) {
+    return span;
+  }
+  String new_source;
+  Array<String> tokens{"<" + key + ">", "</" + key + ">"};
+  if (span.defined() && span->source_name.defined()) {
+    const String& source_str = span->source_name->name;
+    const std::string& source_cstr = source_str;
+    // BUGFIX: the key's presence was previously detected via left.size() > 0
+    // after SplitOnce, which conflates "key not found" with "key found at
+    // position 0"; an attribute that opened the source string was therefore
+    // duplicated instead of replaced. Detect presence explicitly instead.
+    if (source_cstr.find(std::string(tokens[0])) != std::string::npos) {
+      const String& left = std::get<0>(StringUtils::SplitOnce(source_str, 
tokens[0]));
+      const String& right = std::get<1>(StringUtils::SplitOnce(source_str, 
tokens[1]));
+      new_source = left + tokens[0] + value + tokens[1] + right;
+    } else {
+      new_source = source_str + tokens[0] + value + tokens[1];
+    }
+  } else {
+    new_source = tokens[0] + value + tokens[1];
+  }
+  // keep the original position info when the span already exists
+  if (span.defined()) {
+    return Span(SourceName::Get(new_source), span->line, span->end_line, 
span->column,
+                span->end_column);
+  }
+  return Span(SourceName::Get(new_source), 0, 0, 0, 0);
+}
+
+const String SpanUtils::GetAttr(const Span& span, const String& key) {
+  // Fetch the value stored as <key>value</key>; empty when the span has no
+  // source or the key is absent.
+  if (!span.defined() || !span->source_name.defined()) {
+    return "";
+  }
+  Array<String> tokens{"<" + key + ">", "</" + key + ">"};
+  return StringUtils::GetClosureOnce(span->source_name->name, tokens[0], 
tokens[1]);
+}
+
+// Collect every <key>value</key> attribute stored in the span's source name.
+const Map<String, String> SpanUtils::GetAttrs(const Span& span) {
+  Map<String, String> attrs;
+  // BUGFIX: guard undefined spans/sources, consistent with GetAttr; the
+  // previous version dereferenced span->source_name unconditionally and
+  // crashed on spans without a source.
+  if (!span.defined() || !span->source_name.defined()) {
+    return attrs;
+  }
+  // keys are recovered from the closing tags "</key>"
+  for (const auto& key : StringUtils::GetClosures(span->source_name->name, 
"</", ">")) {
+    attrs.Set(key, GetAttr(span, key));
+  }
+  return attrs;
+}
+
+// Return the semantic role ("input", "weight", "shape", ...) of each argument
+// of an op, keyed by optype and arity. Roles other than "input" mark
+// weight-like constants; as_relax switches between relax and relay argument
+// conventions. The result always has exactly inputs_num entries.
+const Array<String> ExprUtils::GetInputTypes(const String& optype, size_t 
inputs_num,
+                                             bool as_relax) {
+  Array<String> input_types;
+  if (as_relax && (optype == "broadcast_to" || optype == "reshape")) {
+    input_types.push_back("input");
+    input_types.push_back("shape");
+  } else if (optype == "clip" && as_relax) {
+    input_types.push_back("input");
+    input_types.push_back("min");
+    input_types.push_back("max");
+  } else if (optype == "full" && as_relax) {
+    input_types.push_back("shape");
+    input_types.push_back("input");
+  } else if (optype == "trilu") {
+    input_types.push_back("input");
+    input_types.push_back("k");
+  } else if (optype == "image.resize2d" && as_relax) {
+    input_types.push_back("input");
+    input_types.push_back("size");
+  } else if (optype == "nn.conv1d" || optype == "nn.conv2d" || optype == 
"nn.conv3d") {
+    input_types.push_back("input");
+    input_types.push_back("weight");
+  } else if (optype == "nn.batch_norm") {
+    input_types.push_back("input");
+    input_types.push_back("gamma");
+    input_types.push_back("beta");
+    input_types.push_back("mean");
+    input_types.push_back("var");
+  } else if (optype == "nn.layer_norm" || optype == "nn.group_norm") {
+    input_types.push_back("input");
+    input_types.push_back("gamma");
+    input_types.push_back("beta");
+  } else if (optype == "msc.linear") {
+    // relax linear takes (weight, input); relay takes (input, weight)
+    if (as_relax) {
+      input_types.push_back("weight");
+      input_types.push_back("input");
+    } else {
+      input_types.push_back("input");
+      input_types.push_back("weight");
+    }
+  } else if (optype == "msc.conv1d_bias" || optype == "msc.conv2d_bias") {
+    input_types.push_back("input");
+    input_types.push_back("weight");
+    input_types.push_back("bias");
+    if (as_relax) {
+      input_types.push_back("expand_bias");
+    }
+  } else if (optype == "msc.linear_bias") {
+    if (as_relax) {
+      input_types.push_back("weight");
+      input_types.push_back("input");
+    } else {
+      input_types.push_back("input");
+      input_types.push_back("weight");
+    }
+    input_types.push_back("bias");
+  } else if (optype == "msc.embedding" && inputs_num == 2) {
+    input_types.push_back("input");
+    input_types.push_back("weight");
+  } else if (optype == "msc.embedding" && inputs_num == 4) {
+    input_types.push_back("input");
+    input_types.push_back("reduce_in");
+    input_types.push_back("weight");
+    input_types.push_back("expand_out");
+  } else if (optype == "msc.gelu") {
+    input_types.push_back("input");
+    input_types.push_back("factor_1");
+    input_types.push_back("factor_2");
+    input_types.push_back("factor_3");
+  } else {
+    // unknown ops: every argument is treated as a plain input
+    for (size_t i = 0; i < inputs_num; i++) {
+      input_types.push_back("input");
+    }
+  }
+  ICHECK_EQ(input_types.size(), inputs_num)
+      << "Optype " << optype << " get input types " << input_types << " and 
inputs_num "
+      << inputs_num << " mismatch";
+  return input_types;
+}
+
+const Array<String> ExprUtils::GetInputTypes(const RelaxCall& call) {
+  // Trim the "relax." prefix from the op name, then dispatch by optype.
+  const Op& op = Downcast<Op>(call->op);
+  const String& optype = StringUtils::Replace(op->name, "relax.", "");
+  return GetInputTypes(optype, call->args.size(), true);
+}
+
+const Array<String> ExprUtils::GetInputTypes(const RelayCall& call) {
+  // Trim the "relay." prefix from the op name, then dispatch by optype.
+  const Op& op = Downcast<Op>(call->op);
+  const String& optype = StringUtils::Replace(op->name, "relay.", "");
+  return GetInputTypes(optype, call->args.size(), false);
+}
+
+// Expose the span attribute getters to the Python side (msc.core namespace).
+TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr);
+
+TVM_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs);
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h
new file mode 100644
index 0000000000..9da4ce3346
--- /dev/null
+++ b/src/contrib/msc/core/utils.h
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/utils.h
+ * \brief Common utilities for msc.
+ */
+#ifndef TVM_CONTRIB_MSC_CORE_UTILS_H_
+#define TVM_CONTRIB_MSC_CORE_UTILS_H_
+
+#include <tvm/ir/source_map.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relay/expr.h>
+
+#include <tuple>
+#include <vector>
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using Expr = tvm::RelayExpr;
+using RelaxCall = tvm::relax::Call;
+using RelayCall = tvm::relay::Call;
+
+class CommonUtils {
+ public:
+  /*!
+   * \brief Check if the index is in range.
+   * \return The valid index.
+   */
+  TVM_DLL static size_t GetIndex(int index, size_t max_size);
+};
+
+/*!
+ * \brief Utils for String.
+ */
+class StringUtils {
+ public:
+  /*!
+   * \brief Split the String into sub Strings.
+   * \return The SubStrings.
+   */
+  TVM_DLL static const Array<String> Split(const String& src_string, const 
String& sep);
+
+  /*!
+   * \brief Join the SubStrings into String.
+   * \return The String.
+   */
+  TVM_DLL static const String Join(const Array<String>& sub_strings, const 
String& joint);
+
+  /*!
+   * \brief Replace the substring old to new in String.
+   * \return The replaced String.
+   */
+  TVM_DLL static const String Replace(const String& src_string, const String& 
old_str,
+                                      const String& new_str);
+
+  /*!
+   * \brief Split the String into two sub Strings, only split by the first sep.
+   * \return The SubStrings.
+   */
+  TVM_DLL static const std::tuple<String, String> SplitOnce(const String& 
src_string,
+                                                            const String& sep,
+                                                            bool from_left = 
true);
+
+  /*!
+   * \brief Get the tokens between left and right.
+   * \return The Tokens.
+   */
+  TVM_DLL static const Array<String> GetClosures(const String& src_string, 
const String& left,
+                                                 const String& right);
+
+  /*!
+   * \brief Get the first token between left and right.
+   * \return The Token.
+   */
+  TVM_DLL static const String GetClosureOnce(const String& src_string, const 
String& left,
+                                             const String& right, bool 
from_left = true);
+
+  /*!
+   * \brief Change Object to String.
+   * \return The String.
+   */
+  TVM_DLL static const String ToString(const runtime::ObjectRef& obj);
+
+  /*!
+   * \brief Compare String arrays.
+   * \return Whether the two arrays are the same.
+   */
+  TVM_DLL static bool CompareArrays(const Array<String>& left, const 
Array<String>& right,
+                                    int size = -1);
+};
+
+/*!
+ * \brief Utils for Array.
+ */
+class ArrayUtils {
+ public:
+  /*!
+   * \brief Replace the element old to new in Array.
+   * \return The replaced Array.
+   */
+  template <typename T>
+  TVM_DLL static const Array<T> Replace(const Array<T>& src_array, const T& 
old_ele,
+                                        const T& new_ele) {
+    Array<T> new_array;
+    for (const auto& a : src_array) {
+      if (a == old_ele) {
+        new_array.push_back(new_ele);
+      } else {
+        new_array.push_back(a);
+      }
+    }
+    return new_array;
+  }
+
+  /*!
+   * \brief Find the index of element.
+   * \return The index, -1 if not found.
+   */
+  template <typename T>
+  TVM_DLL static int IndexOf(const std::vector<T>& array, const T& ele) {
+    for (size_t i = 0; i < array.size(); i++) {
+      if (array[i] == ele) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  /*!
+   * \brief Downcast elements in the array.
+   * \return The downcasted array
+   */
+  template <typename T>
+  TVM_DLL static const Array<T> Cast(const Array<PrimExpr>& src_array) {
+    Array<T> new_array;
+    for (const auto& s : src_array) {
+      new_array.push_back(Downcast<T>(s));
+    }
+    return new_array;
+  }
+};
+
+/*!
+ * \brief Utils for Span.
+ */
+class SpanUtils {
+ public:
+  /*!
+   * \brief Set <key>value</key> to the Span.
+   * \return The new Span.
+   */
+  TVM_DLL static const Span SetAttr(const Span& span, const String& key, const 
String& value);
+
+  /*!
+   * \brief Get the value in <key>value</key> from the Span.
+   * \return The value String.
+   */
+  TVM_DLL static const String GetAttr(const Span& span, const String& key);
+
+  /*!
+   * \brief Get all the key:value in format <key>value</key> from the Span.
+   * \return The Attrs Map.
+   */
+  TVM_DLL static const Map<String, String> GetAttrs(const Span& span);
+};
+
+/*!
+ * \brief Utils for Expr.
+ */
+class ExprUtils {
+ public:
+  /*!
+   * \brief Get the input types of call.
+   * \return The input types.
+   */
+  TVM_DLL static const Array<String> GetInputTypes(const String& optype, 
size_t inputs_num,
+                                                   bool as_relax);
+
+  /*!
+   * \brief Get the input types of call.
+   * \return The input types.
+   */
+  TVM_DLL static const Array<String> GetInputTypes(const RelaxCall& call);
+
+  /*!
+   * \brief Get the input types of call.
+   * \return The input types.
+   */
+  TVM_DLL static const Array<String> GetInputTypes(const RelayCall& call);
+
+  /*!
+   * \brief Get the scalar value of ndarray.
+   * \return The scalar value.
+   */
+  template <typename T>
+  TVM_DLL static const T GetScalar(const runtime::NDArray& array, size_t i = 
0) {
+    if (array->dtype.code == kDLInt) {
+      if (array->dtype.bits == 8) {
+        return T(reinterpret_cast<int8_t*>(array->data)[i]);
+      } else if (array->dtype.bits == 16) {
+        return T(reinterpret_cast<int16_t*>(array->data)[i]);
+      } else if (array->dtype.bits == 32) {
+        return T(reinterpret_cast<int32_t*>(array->data)[i]);
+      } else if (array->dtype.bits == 64) {
+        return T(reinterpret_cast<int64_t*>(array->data)[i]);
+      }
+    } else if (array->dtype.code == kDLUInt) {
+      if (array->dtype.bits == 1) {  // bool
+        return T(reinterpret_cast<uint8_t*>(array->data)[i]);
+      } else if (array->dtype.bits == 8) {
+        return T(reinterpret_cast<uint8_t*>(array->data)[i]);
+      } else if (array->dtype.bits == 16) {
+        return T(reinterpret_cast<uint16_t*>(array->data)[i]);
+      } else if (array->dtype.bits == 32) {
+        return T(reinterpret_cast<uint32_t*>(array->data)[i]);
+      } else if (array->dtype.bits == 64) {
+        return T(reinterpret_cast<uint64_t*>(array->data)[i]);
+      }
+    } else if (array->dtype.code == kDLFloat) {
+      if (array->dtype.bits == 32) {
+        return T(reinterpret_cast<float*>(array->data)[i]);
+      } else if (array->dtype.bits == 64) {
+        return T(reinterpret_cast<double*>(array->data)[i]);
+      }
+    }
+    LOG(FATAL) << "Failed to get scalar from array " << array;
+  }
+
+  /*!
+   * \brief Get the scalar value of relax constant.
+   * \return The scalar value.
+   */
+  template <typename T>
+  TVM_DLL static const T GetScalar(const relax::Constant& constant, size_t i = 
0) {
+    return GetScalar<T>(constant->data, i);
+  }
+
+  /*!
+   * \brief Get the scalar value of relay constant.
+   * \return The scalar value.
+   */
+  template <typename T>
+  TVM_DLL static const T GetScalar(const relay::Constant& constant, size_t i = 
0) {
+    return GetScalar<T>(constant->data, i);
+  }
+};
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
+#endif  // TVM_CONTRIB_MSC_CORE_UTILS_H_
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index 23091fffd2..3f028ba656 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -342,6 +342,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
       {"USE_CLML_GRAPH_EXECUTOR", TVM_INFO_USE_CLML_GRAPH_EXECUTOR},
       {"USE_UMA", TVM_INFO_USE_UMA},
       {"USE_VERILATOR", TVM_INFO_USE_VERILATOR},
+      {"USE_MSC", TVM_INFO_USE_MSC},
       {"USE_CCACHE", TVM_INFO_USE_CCACHE},
       {"BACKTRACE_ON_SEGFAULT", TVM_INFO_BACKTRACE_ON_SEGFAULT},
   };
diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_layout.py 
b/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
new file mode 100644
index 0000000000..4717437d76
--- /dev/null
+++ b/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.testing
+from tvm.relay import testing
+from tvm.relay.expr_functor import ExprVisitor
+from tvm.relay.build_module import bind_params_by_name
+
+from tvm.relax.frontend.torch import from_fx
+from tvm.relax import PyExprVisitor
+
+from tvm.contrib.msc.core import _ffi_api
+from tvm.contrib.msc.core import transform as msc_transform
+
+
+class RelaxChecker(PyExprVisitor):
+    """Check if the layout span attribute is set."""
+
+    def check(self, expr):
+        self._missing_exprs = []
+        if isinstance(expr, tvm.relax.Expr):
+            self.visit_expr(expr)
+        elif isinstance(expr, tvm.relax.BindingBlock):
+            self.visit_binding_block(expr)
+        assert len(self._missing_exprs) == 0, "Missing {} 
layouts".format(len(self._missing_exprs))
+
+    def visit_var_binding_(self, binding) -> None:
+        super().visit_var_binding_(binding)
+        layout = _ffi_api.SpanGetAttr(binding.value.span, "layout")
+        if not layout:
+            self._missing_exprs.append(binding.value)
+
+    def visit_constant_(self, op) -> None:
+        super().visit_constant_(op)
+        layout = _ffi_api.SpanGetAttr(op.span, "layout")
+        if not layout:
+            self._missing_exprs.append(op)
+
+
+def test_relax():
+    try:
+        import torch
+        import torchvision
+        from torch import fx
+    except:
+        print("please install pytorch python package")
+        return
+
+    torch_model = torchvision.models.resnet50()
+    graph_model = fx.symbolic_trace(torch_model)
+    input_info = [([1, 3, 224, 224], "float32")]
+    with torch.no_grad():
+        mod = from_fx(graph_model, input_info)
+    mod = msc_transform.SetExprLayout()(mod)
+    RelaxChecker().check(mod)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_name.py 
b/tests/python/contrib/test_msc/test_transform_set_expr_name.py
new file mode 100644
index 0000000000..0c174ff7bd
--- /dev/null
+++ b/tests/python/contrib/test_msc/test_transform_set_expr_name.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.testing
+from tvm.relay import testing
+from tvm.relay.expr_functor import ExprVisitor
+from tvm.relay.build_module import bind_params_by_name
+
+from tvm.relax.frontend.torch import from_fx
+from tvm.relax import PyExprVisitor
+
+from tvm.contrib.msc.core import _ffi_api
+from tvm.contrib.msc.core import transform as msc_transform
+
+
+class RelayChecker(ExprVisitor):
+    """Check if the name span attribute is set."""
+
+    def check(self, expr):
+        self._missing_exprs = []
+        super().visit(expr)
+        assert len(self._missing_exprs) == 0, "Missing {} 
names".format(len(self._missing_exprs))
+
+    def visit_constant(self, expr):
+        super().visit_constant(expr)
+        name = _ffi_api.SpanGetAttr(expr.span, "name")
+        if not name:
+            self._missing_exprs.append(expr)
+
+    def visit_call(self, expr):
+        super().visit_call(expr)
+        name = _ffi_api.SpanGetAttr(expr.span, "name")
+        if not name:
+            self._missing_exprs.append(expr)
+
+
+class RelaxChecker(PyExprVisitor):
+    """Check if the name span attribute is set."""
+
+    def check(self, expr):
+        self._missing_exprs = []
+        if isinstance(expr, tvm.relax.Expr):
+            self.visit_expr(expr)
+        elif isinstance(expr, tvm.relax.BindingBlock):
+            self.visit_binding_block(expr)
+        assert len(self._missing_exprs) == 0, "Missing {} 
names".format(len(self._missing_exprs))
+
+    def visit_var_binding_(self, binding) -> None:
+        super().visit_var_binding_(binding)
+        name = _ffi_api.SpanGetAttr(binding.value.span, "name")
+        if not name:
+            self._missing_exprs.append(binding.value)
+
+    def visit_constant_(self, op) -> None:
+        super().visit_constant_(op)
+        name = _ffi_api.SpanGetAttr(op.span, "name")
+        if not name:
+            self._missing_exprs.append(op)
+
+
+def test_relay():
+    mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, 
dtype="float32")
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = msc_transform.SetExprName(as_relax=False)(mod)
+    RelayChecker().check(mod["main"])
+
+
+def test_relax():
+    try:
+        import torch
+        import torchvision
+        from torch import fx
+    except:
+        print("please install pytorch python package")
+        return
+
+    torch_model = torchvision.models.resnet50()
+    graph_model = fx.symbolic_trace(torch_model)
+    input_info = [([1, 3, 224, 224], "float32")]
+    with torch.no_grad():
+        mod = from_fx(graph_model, input_info)
+    mod = msc_transform.SetExprName()(mod)
+    RelaxChecker().check(mod)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/scripts/task_config_build_cpu.sh 
b/tests/scripts/task_config_build_cpu.sh
index 9eda0d74d4..0d6c0e2cae 100755
--- a/tests/scripts/task_config_build_cpu.sh
+++ b/tests/scripts/task_config_build_cpu.sh
@@ -57,3 +57,4 @@ echo set\(USE_CCACHE OFF\) >> config.cmake
 echo set\(USE_ETHOSU OFF\) >> config.cmake
 echo set\(USE_UMA ON\) >> config.cmake
 echo set\(SUMMARIZE ON\) >> config.cmake
+echo set\(USE_MSC ON\) >> config.cmake
diff --git a/tests/scripts/task_config_build_gpu.sh 
b/tests/scripts/task_config_build_gpu.sh
index 8929ae5041..37ab0a87f1 100755
--- a/tests/scripts/task_config_build_gpu.sh
+++ b/tests/scripts/task_config_build_gpu.sh
@@ -53,3 +53,4 @@ echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake
 echo set\(USE_PIPELINE_EXECUTOR ON\) >> config.cmake
 echo set\(USE_CUTLASS ON\) >> config.cmake
 echo set\(USE_CMSISNN ON\) >> config.cmake
+echo set\(USE_MSC ON\) >> config.cmake
diff --git a/tests/scripts/unity/task_python_relax.sh 
b/tests/scripts/unity/task_python_relax.sh
index b6b70ab457..121ba1389a 100755
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/tests/scripts/unity/task_python_relax.sh
@@ -36,3 +36,6 @@ TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest 
tests/python/dlight
 # python3 ./apps/relax_examples/mlp.py
 # python3 ./apps/relax_examples/nn_module.py
 # python3 ./apps/relax_examples/resnet.py
+
+# Test for MSC
+pytest tests/python/contrib/test_msc


Reply via email to