This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new aa21782828 [Unity][MSC][M0.1] Enable set name and layout for exprs
(#15500)
aa21782828 is described below
commit aa2178282811d34df3ca3d9ee58580396fd0358d
Author: Archermmt <[email protected]>
AuthorDate: Mon Aug 14 19:39:53 2023 +0800
[Unity][MSC][M0.1] Enable set name and layout for exprs (#15500)
* add test for set expr name
* roll back to M0.1
* add annotation
* add annotation
* change test to unity
* remove msg
* minor fix
* move test to task_python_relax
---
CMakeLists.txt | 2 +
cmake/config.cmake | 3 +
cmake/modules/LibInfo.cmake | 1 +
.../modules/contrib/MSC.cmake | 26 +-
.../tvm/contrib/msc/__init__.py | 23 +-
.../tvm/contrib/msc/core/__init__.py | 23 +-
.../tvm/contrib/msc/core/_ffi_api.py | 23 +-
.../tvm/contrib/msc/core/transform/__init__.py | 24 +-
python/tvm/contrib/msc/core/transform/pattern.py | 626 ++++++++++
python/tvm/contrib/msc/core/transform/transform.py | 61 +
.../tvm/contrib/msc/core/utils/__init__.py | 23 +-
python/tvm/contrib/msc/core/utils/expr.py | 105 ++
src/contrib/msc/core/transform/layout_utils.cc | 190 +++
src/contrib/msc/core/transform/layout_utils.h | 110 ++
src/contrib/msc/core/transform/set_expr_layout.cc | 1215 ++++++++++++++++++++
src/contrib/msc/core/transform/set_expr_name.cc | 348 ++++++
src/contrib/msc/core/utils.cc | 314 +++++
src/contrib/msc/core/utils.h | 270 +++++
src/support/libinfo.cc | 1 +
.../test_msc/test_transform_set_expr_layout.py | 73 ++
.../test_msc/test_transform_set_expr_name.py | 101 ++
tests/scripts/task_config_build_cpu.sh | 1 +
tests/scripts/task_config_build_gpu.sh | 1 +
tests/scripts/unity/task_python_relax.sh | 3 +
24 files changed, 3442 insertions(+), 125 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 47d57d56bd..f7c34fa22b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -121,6 +121,7 @@ tvm_option(USE_CLML "Build with CLML Codegen support" OFF)
tvm_option(USE_CLML_GRAPH_EXECUTOR "Build with CLML graph runtime" OFF)
tvm_option(USE_UMA "Build with UMA support" OFF)
tvm_option(USE_VERILATOR "Build with Verilator support" OFF)
+tvm_option(USE_MSC "Enable Multi-System Compiler" OFF)
# include directories
include_directories(${CMAKE_INCLUDE_PATH})
@@ -545,6 +546,7 @@ include(cmake/modules/contrib/TensorRT.cmake)
include(cmake/modules/contrib/VitisAI.cmake)
include(cmake/modules/contrib/Verilator.cmake)
include(cmake/modules/contrib/UMA.cmake)
+include(cmake/modules/contrib/MSC.cmake)
include(cmake/modules/Git.cmake)
include(cmake/modules/LibInfo.cmake)
include(cmake/modules/RustExt.cmake)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 8a7a0f1fdd..4990e52d63 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -281,6 +281,9 @@ set(USE_VITIS_AI OFF)
# Build Verilator codegen and runtime
set(USE_VERILATOR OFF)
+# Whether to use the Multi-System Compiler
+set(USE_MSC OFF)
+
#Whether to use CLML codegen
set(USE_CLML OFF)
# USE_CLML_GRAPH_EXECUTOR - CLML SDK PATH or ON or OFF
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index 17a56ac439..9e1f71c729 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -125,6 +125,7 @@ function(add_lib_info src_file)
TVM_INFO_USE_TVM_CLML_VERSION="${CLML_VERSION_MAJOR}"
TVM_INFO_USE_UMA="${USE_UMA}"
TVM_INFO_USE_VERILATOR="${USE_VERILATOR}"
+ TVM_INFO_USE_MSC="${USE_MSC}"
TVM_INFO_USE_CCACHE="${USE_CCACHE}"
TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}"
)
diff --git a/tests/scripts/unity/task_python_relax.sh
b/cmake/modules/contrib/MSC.cmake
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to cmake/modules/contrib/MSC.cmake
index b6b70ab457..45ce776a08
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/cmake/modules/contrib/MSC.cmake
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -16,23 +15,12 @@
# specific language governing permissions and limitations
# under the License.
-set -euxo pipefail
+if(USE_MSC)
+ tvm_file_glob(GLOB_RECURSE MSC_CORE_SOURCE "src/contrib/msc/*.cc")
+ list(APPEND COMPILER_SRCS ${MSC_CORE_SOURCE})
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
+ tvm_file_glob(GLOB_RECURSE MSC_RUNTIME_SOURCE
"src/runtime/contrib/msc/*.cc")
+ list(APPEND RUNTIME_SRCS ${MSC_RUNTIME_SOURCE})
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+ message(STATUS "Build with MSC support...")
+endif()
diff --git a/tests/scripts/unity/task_python_relax.sh
b/python/tvm/contrib/msc/__init__.py
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to python/tvm/contrib/msc/__init__.py
index b6b70ab457..a2813b4a2d
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/python/tvm/contrib/msc/__init__.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,24 +14,4 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-set -euxo pipefail
-
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
-
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+"""tvm.contrib.msc"""
diff --git a/tests/scripts/unity/task_python_relax.sh
b/python/tvm/contrib/msc/core/__init__.py
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to python/tvm/contrib/msc/core/__init__.py
index b6b70ab457..6d1a7c68c8
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/python/tvm/contrib/msc/core/__init__.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,24 +14,4 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-set -euxo pipefail
-
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
-
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+"""tvm.contrib.msc.core"""
diff --git a/tests/scripts/unity/task_python_relax.sh
b/python/tvm/contrib/msc/core/_ffi_api.py
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to python/tvm/contrib/msc/core/_ffi_api.py
index b6b70ab457..c0b0e21267
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/python/tvm/contrib/msc/core/_ffi_api.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,24 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""tvm.contrib.msc.core._ffi_api"""
-set -euxo pipefail
+import tvm._ffi
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
-
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+tvm._ffi._init_api("msc.core", __name__)
diff --git a/tests/scripts/unity/task_python_relax.sh
b/python/tvm/contrib/msc/core/transform/__init__.py
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to python/tvm/contrib/msc/core/transform/__init__.py
index b6b70ab457..ec74597803
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/python/tvm/contrib/msc/core/transform/__init__.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,24 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""tvm.contrib.msc.core.transform"""
-set -euxo pipefail
-
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
-
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+from .pattern import *
+from .transform import *
diff --git a/python/tvm/contrib/msc/core/transform/pattern.py
b/python/tvm/contrib/msc/core/transform/pattern.py
new file mode 100644
index 0000000000..76e9651c60
--- /dev/null
+++ b/python/tvm/contrib/msc/core/transform/pattern.py
@@ -0,0 +1,626 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.core.transform.pattern"""
+
+from typing import Mapping, Tuple
+
+import tvm
+from tvm.relax.dpl import pattern as relax_pattern
+from tvm.relay import dataflow_pattern as relay_pattern
+
+from tvm.relax.transform import PatternCheckContext
+from tvm.relax.backend.pattern_registry import register_patterns
+from tvm.relay.op.contrib.register import register_pattern_table
+
+
+def make_relax_conv_bias_pattern(
+ op_name: str,
+) -> Tuple[relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]]:
+    """A simple utility to create patterns for a conv fused with bias.
+
+ Parameters
+ ----------
+ op_name: str
+ The name of a Relax op, such as "relax.nn.conv2d"
+
+ Returns
+ -------
+ out: tvm.relax.dpl.pattern.DFPattern
+ The resulting pattern describing a conv_bias operation.
+
+ annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+ A mapping from name to sub pattern. It can be used to extract
+ important expressions from match result, to power the partition
+ check function and codegen.
+ """
+
+ data = relax_pattern.wildcard()
+ weight = relax_pattern.is_const()
+ conv = relax_pattern.is_op(op_name)(data, weight)
+ bias = relax_pattern.is_const()
+ shape = relax_pattern.wildcard()
+ reshape = relax_pattern.is_op("relax.reshape")(bias, shape)
+ out = relax_pattern.is_op("relax.add")(conv, reshape)
+ annotations = {"bias": bias, "reshape": reshape}
+ return out, annotations
+
+
+def _check_relax_conv_bias(context: PatternCheckContext) -> bool:
+ """Check if conv_bias fuse pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ bias = context.annotated_expr["bias"]
+ reshape = context.annotated_expr["reshape"]
+ non_one_dims = len([i for i in reshape.struct_info.shape.values if i > 1])
+ return non_one_dims <= 1 and bias.struct_info.ndim == 1
+
+
+def make_relax_linear_pattern() -> Tuple[
+ relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+ """A simple utility to create patterns for linear.
+
+ Returns
+ -------
+ out: tvm.relax.dpl.pattern.DFPattern
+ The resulting pattern describing a linear operation.
+
+ annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+ A mapping from name to sub pattern. It can be used to extract
+ important expressions from match result, to power the partition
+ check function and codegen.
+ """
+
+ data = relax_pattern.wildcard()
+ weight = relax_pattern.is_const()
+ permute = relax_pattern.is_op("relax.permute_dims")(weight)
+ out = relax_pattern.is_op("relax.matmul")(data, permute)
+ annotations = {"weight": weight, "permute": permute}
+ return out, annotations
+
+
+def _check_relax_linear(context: PatternCheckContext) -> bool:
+ """Check if linear pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ weight = context.annotated_expr["weight"]
+ permute = context.annotated_expr["permute"]
+ return weight.struct_info.ndim == 2 and not permute.attrs["axes"]
+
+
+def make_relax_linear_bias_pattern() -> Tuple[
+ relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+ """A simple utility to create patterns for linear with bias.
+
+ Returns
+ -------
+ out: tvm.relax.dpl.pattern.DFPattern
+ The resulting pattern describing a linear_bias operation.
+
+ annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+ A mapping from name to sub pattern. It can be used to extract
+ important expressions from match result, to power the partition
+ check function and codegen.
+
+ """
+
+ linear, annotations = make_relax_linear_pattern()
+ bias = relax_pattern.is_const()
+ out = relax_pattern.is_op("relax.add")(linear, bias)
+ annotations.update({"bias": bias, "out": out})
+ return out, annotations
+
+
+def _check_relax_linear_bias(context: PatternCheckContext) -> bool:
+ """Check if linear_bias pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ if not _check_relax_linear(context):
+ return False
+ bias = context.annotated_expr["bias"]
+ return bias.struct_info.ndim == 1
+
+
+def make_relax_embedding_pattern() -> Tuple[
+ relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+ """A simple utility to create patterns for embedding.
+
+ Returns
+ -------
+ out: tvm.relax.dpl.pattern.DFPattern
+        The resulting pattern describing an embedding operation.
+
+ annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+ A mapping from name to sub pattern. It can be used to extract
+ important expressions from match result, to power the partition
+ check function and codegen.
+ """
+
+ weight = relax_pattern.is_const()
+ data = relax_pattern.wildcard()
+ astype = relax_pattern.is_op("relax.astype")(data)
+ out = relax_pattern.is_op("relax.take")(weight, astype)
+ annotations = {"weight": weight, "astype": astype}
+ return out, annotations
+
+
+def _check_relax_embedding(context: PatternCheckContext) -> bool:
+ """Check if 1d embedding pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ weight = context.annotated_expr["weight"]
+ astype = context.annotated_expr["astype"]
+ return (
+ astype.attrs["dtype"] == "int32"
+ and weight.struct_info.ndim == 2
+ and weight.struct_info.dtype == "float32"
+ )
+
+
+def make_relax_reshape_embedding_pattern() -> Tuple[
+ relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+ """A simple utility to create patterns for reshaped embedding.
+
+ Returns
+ -------
+ out: tvm.relax.dpl.pattern.DFPattern
+ The resulting pattern describing a reshape_embedding operation.
+
+ annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+ A mapping from name to sub pattern. It can be used to extract
+ important expressions from match result, to power the partition
+ check function and codegen.
+ """
+
+ weight = relax_pattern.is_const()
+ data = relax_pattern.wildcard()
+ astype = relax_pattern.is_op("relax.astype")(data)
+ reduce_shape = relax_pattern.wildcard()
+ reduce_in = relax_pattern.is_op("relax.reshape")(astype, reduce_shape)
+ take = relax_pattern.is_op("relax.take")(weight, reduce_in)
+ expand_shape = relax_pattern.wildcard()
+ out = relax_pattern.is_op("relax.reshape")(take, expand_shape)
+ annotations = {"weight": weight, "astype": astype, "reduce_in": reduce_in}
+ return out, annotations
+
+
+def _check_relax_reshape_embedding(context: PatternCheckContext) -> bool:
+ """Check if reshape embedding pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ weight = context.annotated_expr["weight"]
+ if weight.struct_info.ndim != 2 or weight.struct_info.dtype != "float32":
+ return False
+ astype = context.annotated_expr["astype"]
+ reduce_in = context.annotated_expr["reduce_in"]
+ if astype.attrs["dtype"] != "int32" or reduce_in.struct_info.ndim != 1:
+ return False
+ return True
+
+
+def make_relax_attention_pattern() -> Tuple[
+ relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+ """A simple utility to create patterns for attention.
+
+ Returns
+ -------
+ out: tvm.relax.dpl.pattern.DFPattern
+        The resulting pattern describing an attention operation.
+
+ annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+ A mapping from name to sub pattern. It can be used to extract
+ important expressions from match result, to power the partition
+ check function and codegen.
+ """
+
+ weight_q = relax_pattern.wildcard()
+ weight_k = relax_pattern.wildcard()
+ weight_v = relax_pattern.wildcard()
+ q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q)
+ k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k)
+ v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v)
+ out = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans)
+ annotations = {"q_trans": q_trans, "k_trans": k_trans, "v_trans": v_trans}
+ return out, annotations
+
+
+def _check_relax_attention(context: PatternCheckContext) -> bool:
+ """Check if attention pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ return True
+
+
+def make_relax_mask_attention_pattern() -> Tuple[
+ relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]
+]:
+ """A simple utility to create patterns for mask_attention.
+
+ Returns
+ -------
+ out: tvm.relax.dpl.pattern.DFPattern
+ The resulting pattern describing a mask_attention operation.
+
+ annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern]
+ A mapping from name to sub pattern. It can be used to extract
+ important expressions from match result, to power the partition
+ check function and codegen.
+ """
+
+ weight_q = relax_pattern.wildcard()
+ weight_k = relax_pattern.wildcard()
+ weight_v = relax_pattern.wildcard()
+ mask = relax_pattern.wildcard()
+ q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q)
+ k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k)
+ v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v)
+ out = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans,
v_trans, mask)
+ annotations = {"q_trans": q_trans, "k_trans": k_trans, "v_trans": v_trans}
+ return out, annotations
+
+
+def _check_relax_mask_attention(context: PatternCheckContext) -> bool:
+ """Check if mask_attention pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ return True
+
+
+# TODO(tong.meng): support patterns after optimize
+register_patterns(
+ [
+ (
+ "msc.conv1d_bias",
+ *make_relax_conv_bias_pattern(
+ "relax.nn.conv1d",
+ ),
+ _check_relax_conv_bias,
+ ),
+ (
+ "msc.conv2d_bias",
+ *make_relax_conv_bias_pattern(
+ "relax.nn.conv2d",
+ ),
+ _check_relax_conv_bias,
+ ),
+ (
+ "msc.linear",
+ *make_relax_linear_pattern(),
+ _check_relax_linear,
+ ),
+ (
+ "msc.linear_bias",
+ *make_relax_linear_bias_pattern(),
+ _check_relax_linear_bias,
+ ),
+ (
+ "msc.embedding",
+ *make_relax_embedding_pattern(),
+ _check_relax_embedding,
+ ),
+ (
+ "msc.embedding",
+ *make_relax_reshape_embedding_pattern(),
+ _check_relax_reshape_embedding,
+ ),
+ (
+ "msc.attention",
+ *make_relax_attention_pattern(),
+ _check_relax_attention,
+ ),
+ (
+ "msc.attention",
+ *make_relax_mask_attention_pattern(),
+ _check_relax_mask_attention,
+ ),
+ ]
+)
+
+
+# TODO(tong.meng): support patterns after optimize
+@register_pattern_table("msc")
+def pattern_table():
+ """Returns list of triples describing the name, dataflow pattern and
predicate for all
+ the MSC-supported operators."""
+
+ def make_relay_conv_bias_pattern(
+ op_name: str, optimized: bool = False
+ ) -> relay_pattern.DFPattern:
+ """A simple utility to create patterns for an operation fused with
bias.
+
+ Parameters
+ ----------
+ op_name: str
+ The name of a Relay op, such as "relay.nn.conv2d"
+ optimized: bool
+ Whether the relay is optimized
+
+ Returns
+ -------
+ pattern: tvm.relay.dataflow_pattern.DFPattern
+ The resulting pattern describing a conv_bias operation
+ """
+
+ data = relay_pattern.wildcard()
+ weight = relay_pattern.is_constant()
+ bias = relay_pattern.is_constant()
+ conv = relay_pattern.is_op(op_name)(data, weight)
+ if optimized:
+ out = relay_pattern.is_op("add")(conv, bias)
+ else:
+ out = relay_pattern.is_op("nn.bias_add")(conv, bias)
+ return out
+
+ def _check_relay_conv_bias(call: tvm.relay.Expr) -> bool:
+ """Check if conv_bias fuse pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ if call.op.name == "nn.bias_add":
+ bias = call.args[1]
+ return len(bias.checked_type.shape) == 1
+ if call.op.name == "add":
+ return True
+ return False
+
+ def make_relay_linear_pattern(optimized: bool = False) ->
relay_pattern.DFPattern:
+ """A simple utility to create patterns for linear.
+
+ Parameters
+ ----------
+ optimized: bool
+ Whether the relay is optimized
+
+ Returns
+ -------
+ pattern: tvm.relay.dataflow_pattern.DFPattern
+ The resulting pattern describing a linear operation
+ """
+
+ if optimized:
+ data = relay_pattern.wildcard()
+ weight = relay_pattern.is_constant()
+ broadcast_data = relay_pattern.is_op("broadcast_to")(data)
+ reshape_data = relay_pattern.is_op("reshape")(broadcast_data)
+ batch_matmul =
relay_pattern.is_op("nn.batch_matmul")(reshape_data, weight)
+ reshape_out = relay_pattern.is_op("reshape")(batch_matmul)
+ return relay_pattern.is_op("squeeze")(reshape_out)
+ data = relay_pattern.wildcard()
+ weight = relay_pattern.is_constant()
+ trans_weight = relay_pattern.is_op("transpose")(weight)
+ broadcast_data = relay_pattern.is_op("broadcast_to")(data)
+ broadcast_weight = relay_pattern.is_op("broadcast_to")(trans_weight)
+ reshape_data = relay_pattern.is_op("reshape")(broadcast_data)
+ reshape_weight = relay_pattern.is_op("reshape")(broadcast_weight)
+ batch_matmul = relay_pattern.is_op("nn.batch_matmul")(reshape_data,
reshape_weight)
+ reshape_out = relay_pattern.is_op("reshape")(batch_matmul)
+ return relay_pattern.is_op("squeeze")(reshape_out)
+
+ def _check_relay_linear(call: tvm.relay.Expr) -> bool:
+ """Check if linear pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ return True
+
+ def make_relay_linear_bias_pattern(optimized: bool = False) ->
relay_pattern.DFPattern:
+ """A simple utility to create patterns for linear_bias.
+
+ Parameters
+ ----------
+ optimized: bool
+ Whether the relay is optimized
+
+ Returns
+ -------
+ pattern: DFPattern
+ The resulting pattern describing a linear_bias operation
+ """
+
+ bias = relay_pattern.is_constant()
+ linear = make_relay_linear_pattern(optimized)
+ if optimized:
+ out = relay_pattern.is_op("add")(linear, bias)
+ else:
+ out = relay_pattern.is_op("nn.bias_add")(linear, bias)
+ return out
+
+ def _check_relay_linear_bias(call: tvm.relay.Expr) -> bool:
+ """Check if linear_bias pattern is correct."""
+ return True
+
+ def make_relay_matmul_pattern(dim: int = 2, optimized: bool = False) ->
relay_pattern.DFPattern:
+ """A simple utility to create patterns for matmul.
+
+ Parameters
+ ----------
+ optimized: bool
+ Whether the relay is optimized
+
+ Returns
+ -------
+ pattern: tvm.relay.dataflow_pattern.DFPattern
+ The resulting pattern describing a matmul operation
+ """
+
+ if dim == 2:
+ a = relay_pattern.wildcard()
+ b = relay_pattern.wildcard()
+ trans_b = relay_pattern.is_op("transpose")(b)
+ dense = relay_pattern.is_op("nn.dense")(a, trans_b)
+ return dense | relay_pattern.is_op("squeeze")(dense)
+ elif dim == 3:
+ a = relay_pattern.wildcard()
+ b = relay_pattern.wildcard()
+ broadcast_a = relay_pattern.is_op("broadcast_to")(a)
+ broadcast_b = relay_pattern.is_op("broadcast_to")(b)
+ reshape_a = relay_pattern.is_op("reshape")(broadcast_a)
+ reshape_b = relay_pattern.is_op("reshape")(broadcast_b)
+ batch_matmul = relay_pattern.is_op("nn.batch_matmul")(reshape_a,
reshape_b)
+ reshape_out = relay_pattern.is_op("reshape")(batch_matmul)
+ return relay_pattern.is_op("squeeze")(reshape_out)
+ else:
+ raise Exception("matmul pattern only support dim 2 and 3")
+
+ def _check_relay_matmul(call: tvm.relay.Expr) -> bool:
+ """Check if matmul pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ last_call = call.args[0] if call.op.name == "squeeze" else call
+ if last_call.op.name == "nn.dense":
+ trans_b = last_call.args[1]
+ b = trans_b.args[0]
+ if len(b.checked_type.shape) != 2:
+ return False
+ return trans_b.attrs["axes"] is None or
list(trans_b.attrs["axes"]) == [1, 0]
+ return True
+
+ def make_relay_embedding_pattern(optimized: bool = False) ->
relay_pattern.DFPattern:
+ """A simple utility to create patterns for 1d embedding.
+
+ Returns
+ -------
+ pattern: tvm.relay.dataflow_pattern.DFPattern
+            The resulting pattern describing an embedding operation
+ """
+
+ weight = relay_pattern.is_constant()
+ data = relay_pattern.wildcard()
+ astype = relay_pattern.is_op("cast")(data)
+ return relay_pattern.is_op("take")(weight, astype)
+
+ def _check_relay_embedding(call: tvm.relay.Expr) -> bool:
+ """Check if embedding pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ weight = call.args[0]
+ cast = call.args[1]
+ return (
+ cast.attrs["dtype"] == "int32"
+ and len(weight.checked_type.shape) == 2
+ and weight.checked_type.dtype == "float32"
+ )
+
+ def make_relay_gelu_pattern(optimized: bool = False) ->
relay_pattern.DFPattern:
+ """A simple utility to create patterns for gelu.
+
+ Returns
+ -------
+ pattern: tvm.relay.dataflow_pattern.DFPattern
+ The resulting pattern describing a gelu operation.
+ """
+
+ data = relay_pattern.wildcard()
+ factor_1 = relay_pattern.is_constant()
+ mul_1 = relay_pattern.is_op("multiply")(data, factor_1)
+ erf = relay_pattern.is_op("erf")(mul_1)
+ factor_2 = relay_pattern.is_constant()
+ mul_2 = relay_pattern.is_op("multiply")(erf, factor_2)
+ factor_3 = relay_pattern.is_constant()
+ add = relay_pattern.is_op("add")(factor_3, mul_2)
+ return relay_pattern.is_op("multiply")(data, add)
+
+ def _check_relay_gelu(call: tvm.relay.Expr) -> bool:
+ """Check if gelu pattern is correct.
+
+ Returns
+ -------
+ pass: bool
+ Whether the pattern is correct.
+ """
+
+ return True
+
+ return [
+ ("msc.conv1d_bias", make_relay_conv_bias_pattern("nn.conv1d"),
_check_relay_conv_bias),
+ (
+ "msc.conv1d_bias",
+ make_relay_conv_bias_pattern("nn.conv1d", True),
+ _check_relay_conv_bias,
+ ),
+ ("msc.conv2d_bias", make_relay_conv_bias_pattern("nn.conv2d"),
_check_relay_conv_bias),
+ (
+ "msc.conv2d_bias",
+ make_relay_conv_bias_pattern("nn.conv2d", True),
+ _check_relay_conv_bias,
+ ),
+ ("msc.linear_bias", make_relay_linear_bias_pattern(),
_check_relay_linear_bias),
+ ("msc.linear", make_relay_linear_pattern(), _check_relay_linear),
+ ("msc.linear", make_relay_linear_pattern(True), _check_relay_linear),
+ ("msc.matmul", make_relay_matmul_pattern(dim=2), _check_relay_matmul),
+ ("msc.matmul", make_relay_matmul_pattern(dim=3), _check_relay_matmul),
+ ("msc.embedding", make_relay_embedding_pattern(),
_check_relay_embedding),
+ ("msc.gelu", make_relay_gelu_pattern(), _check_relay_gelu),
+ ]
diff --git a/python/tvm/contrib/msc/core/transform/transform.py
b/python/tvm/contrib/msc/core/transform/transform.py
new file mode 100644
index 0000000000..355922d6de
--- /dev/null
+++ b/python/tvm/contrib/msc/core/transform/transform.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""tvm.contrib.msc.core.transform.transform"""
+
+import tvm
+from tvm.relax.transform import _ffi_api as relax_api
+from tvm.relay.transform import _ffi_api as relay_api
+
+
+def SetExprName(as_relax=True, entry_name="main") -> tvm.ir.transform.Pass:
+ """Set name for the call and constant in IRModule.
+
+ Parameters
+ ----------
+ as_relax: bool
+ Whether set names for relax, otherwise for relay.
+ entry_name: str
+ The entry name
+
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+
+ if as_relax:
+ return relax_api.SetRelaxExprName(entry_name) # type: ignore
+ return relay_api.SetRelayExprName(entry_name) # type: ignore
+
+
+def SetExprLayout(allow_missing=True, entry_name="main") ->
tvm.ir.transform.Pass:
+ """Set layout for the var and constant in IRModule.
+
+ Parameters
+ ----------
+ allow_missing: bool
+ Whether allow missing layouts.
+ entry_name: str
+ The entry name
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+
+ return relax_api.SetExprLayout(allow_missing, entry_name) # type: ignore
diff --git a/tests/scripts/unity/task_python_relax.sh
b/python/tvm/contrib/msc/core/utils/__init__.py
old mode 100755
new mode 100644
similarity index 55%
copy from tests/scripts/unity/task_python_relax.sh
copy to python/tvm/contrib/msc/core/utils/__init__.py
index b6b70ab457..65f9e1b326
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/python/tvm/contrib/msc/core/utils/__init__.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,24 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""tvm.contrib.msc.core.utils"""
-set -euxo pipefail
-
-source tests/scripts/setup-pytest-env.sh
-export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
-export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
-
-# to avoid CI CPU thread throttling.
-export TVM_BIND_THREADS=0
-export TVM_NUM_THREADS=2
-
-make cython3
-
-# Run Relax tests
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax
-TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
-
-# Run Relax examples
-# python3 ./apps/relax_examples/mlp.py
-# python3 ./apps/relax_examples/nn_module.py
-# python3 ./apps/relax_examples/resnet.py
+from .expr import *
diff --git a/python/tvm/contrib/msc/core/utils/expr.py
b/python/tvm/contrib/msc/core/utils/expr.py
new file mode 100644
index 0000000000..ad459e7832
--- /dev/null
+++ b/python/tvm/contrib/msc/core/utils/expr.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.utils.expr"""
+
+import tvm
+from tvm import relax
+from tvm.relax import PyExprVisitor
+from tvm.contrib.msc.core import _ffi_api
+
+
def get_span_attrs(mod: tvm.IRModule) -> dict:
    """Extract the span attributes from relax.Function.

    Parameters
    ----------
    mod: IRModule
        The IRModule of relax.

    Returns
    -------
    attrs: dict
        Mapping of function name -> {expr name -> span attrs dict}.
    """

    @relax.expr_functor.visitor
    class SpanVisitor(PyExprVisitor):
        """Visitor for get attributes in span"""

        def extract(self, expr: relax.Expr) -> dict:
            # Reset state so one visitor instance can be reused per function.
            self._span_info = {}
            if isinstance(expr, relax.Expr):
                self.visit_expr(expr)
            elif isinstance(expr, relax.BindingBlock):
                self.visit_binding_block(expr)
            return self._span_info

        def _update_attrs(self, expr: relax.Expr, name: str = "") -> None:
            # Record the span attrs under `name`, falling back to the span's
            # own "name" attr; exprs without a span or name are skipped.
            if not expr.span:
                return
            name = name or _ffi_api.SpanGetAttr(expr.span, "name")
            if not name:
                return
            self._span_info[name] = _ffi_api.SpanGetAttrs(expr.span)

        def visit_var_binding_(self, binding: relax.VarBinding) -> None:
            super().visit_var_binding_(binding)
            # Bound values are keyed by the bound var's name.
            self._update_attrs(binding.value, binding.var.name_hint)

        def visit_constant_(self, op: relax.Constant) -> None:
            super().visit_constant_(op)
            self._update_attrs(op)

        def visit_var_(self, op: relax.Var) -> None:
            super().visit_var_(op)
            self._update_attrs(op, op.name_hint)

    return {v.name_hint: SpanVisitor().extract(mod[v]) for v in mod.functions}
+
+
def msc_script(mod: tvm.IRModule, script: str = "") -> str:
    """Add span attrs after lines.

    Parameters
    ----------
    mod: IRModule
        The IRModule of relax.
    script: string
        The script to be replaced

    Returns
    -------
    script: string
        The replaced script
    """

    source = script or str(mod)
    attrs = get_span_attrs(mod)
    func_attrs, annotated = {}, []
    for raw_line in source.split("\n"):
        stripped = raw_line.strip()
        if stripped.startswith("def "):
            # Entering a new function: switch to its span-attr table.
            func_name = stripped.split("def ")[1].split("(")[0]
            func_attrs = attrs.get(func_name, {})
        if ": " in raw_line:
            var_name = stripped.split(": ")[0]
            if var_name in func_attrs:
                tags = ", ".join(["{}={}".format(k, v) for k, v in func_attrs[var_name].items()])
                raw_line = raw_line + " # " + tags + " #"
        annotated.append(raw_line)
    return "\n".join(annotated)
diff --git a/src/contrib/msc/core/transform/layout_utils.cc
b/src/contrib/msc/core/transform/layout_utils.cc
new file mode 100644
index 0000000000..ffc631c6d0
--- /dev/null
+++ b/src/contrib/msc/core/transform/layout_utils.cc
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/layout_utils.cc
+ */
+#include "layout_utils.h"
+
+#include <set>
+#include <string>
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+bool LayoutUtils::LayoutInfered(const Expr& expr) {
+ const String& layout = SpanUtils::GetAttr(expr->span, "layout");
+ return layout.size() > 0;
+}
+
bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) {
  // Writes `layout` into the expr's span "layout" attr.
  // Returns false when the layout is undefined or already recorded unchanged.
  const String& saved_layout = SpanUtils::GetAttr(expr->span, "layout");
  const auto& sinfo = GetStructInfo(expr);
  if (sinfo.as<TensorStructInfoNode>() || sinfo.as<ShapeStructInfoNode>()) {
    // Tensor/shape expr: exactly one (leaf) layout decision is expected.
    ICHECK(layout.IsLeaf()) << "Expr has tensor struct, but find nested layout " << expr;
    const auto& l_layout = layout.LeafValue()->layout;
    if (!l_layout.defined()) {
      return false;
    }
    if (saved_layout == l_layout.name()) {
      return false;
    }
    expr->span = SpanUtils::SetAttr(expr->span, "layout", l_layout.name());
  } else if (sinfo.as<TupleStructInfoNode>()) {
    // Tuple expr: join per-field leaf layouts with "," into one attr string.
    ICHECK(!layout.IsLeaf()) << "Expr has tupple struct, but find non-nested layout " << expr;
    String layout_str;
    Array<NLayout> nested_layouts = layout.NestedArray();
    for (size_t i = 0; i < nested_layouts.size(); i++) {
      ICHECK(nested_layouts[i].IsLeaf())
          << "Expr input[" << i << "] has tensor struct, but find nested layout " << expr;
      const auto& l_layout = nested_layouts[i].LeafValue()->layout;
      if (!l_layout.defined()) {
        // Any undefined field aborts the whole update.
        return false;
      }
      layout_str = layout_str + l_layout.name() + (i < nested_layouts.size() - 1 ? "," : "");
    }
    if (saved_layout == layout_str) {
      return false;
    }
    expr->span = SpanUtils::SetAttr(expr->span, "layout", layout_str);
  }
  // NOTE(review): exprs with other struct kinds fall through and report true
  // without writing anything — confirm this is intended.
  return true;
}
+
+const NLayout LayoutUtils::GetNLayout(const Expr& expr) {
+ if (!LayoutInfered(expr)) {
+ return LayoutDecision("");
+ }
+ auto sinfo = GetStructInfo(expr);
+ if (sinfo.as<TensorStructInfoNode>()) {
+ return LayoutDecision(SpanUtils::GetAttr(expr->span, "layout"));
+ }
+ if (sinfo.as<TupleStructInfoNode>()) {
+ String layout_str = SpanUtils::GetAttr(expr->span, "layout");
+ std::vector<NLayout> output_layout;
+ for (const auto& l : StringUtils::Split(layout_str, ",")) {
+ output_layout.push_back(LayoutDecision(l));
+ }
+ return NLayout(output_layout);
+ }
+ return LayoutDecision("");
+}
+
+const LayoutDecision LayoutUtils::GetLayoutDecision(const Expr& expr) {
+ NLayout nlayout = GetNLayout(expr);
+ ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr;
+ return nlayout.LeafValue();
+}
+
+bool LayoutUtils::HasUnknownDimTensor(const NLayout& nlayout) {
+ bool find = false;
+ auto fvisit = [&](const LayoutDecision& layout) {
+ find = find | (NLayoutEqual()(layout, LayoutDecision::InitUnknownDim()));
+ };
+ ForEachLeaf<LayoutDecision>(nlayout, fvisit);
+ return find;
+}
+
+bool LayoutUtils::HasUnknownDimTensor(const Array<Expr>& args) {
+ for (const auto& arg : args) {
+ if (IsNestedTensor(arg)) {
+ if (HasUnknownDimTensor(GetNLayout(arg))) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+const LayoutDecision LayoutUtils::ExpandLayout(const LayoutDecision&
src_layout,
+ const std::vector<size_t>&
expand_axes) {
+ if (!src_layout->layout.defined()) {
+ return src_layout;
+ }
+ std::string new_layout = src_layout.name();
+ ICHECK_EQ(new_layout.size(), src_layout->layout.ndim())
+ << "Only support normal layout, get " << src_layout->layout;
+ std::vector<std::string> priority_dims{"N", "C", "H", "W", "D", "G", "T"};
+ size_t left_size = expand_axes.size();
+ for (const auto& a : expand_axes) {
+ std::string target = "U";
+ if (new_layout.find("H") && !new_layout.find("W")) {
+ target = "W";
+ } else if (new_layout.find("W") && !new_layout.find("H")) {
+ target = "H";
+ } else if (left_size == 1 && new_layout.find("C") &&
!new_layout.find("D")) {
+ target = "D";
+ } else if (left_size == 1 && new_layout.find("D") &&
!new_layout.find("C")) {
+ target = "C";
+ } else {
+ for (const auto& p : priority_dims) {
+ int pos = new_layout.find(p);
+ if (pos < 0) {
+ target = p;
+ break;
+ }
+ }
+ }
+ new_layout = new_layout.insert(a, target);
+ left_size--;
+ }
+ return LayoutDecision(new_layout);
+}
+
+const LayoutDecision LayoutUtils::ReduceLayout(const LayoutDecision&
src_layout,
+ const std::vector<size_t>&
reduce_axes) {
+ if (!src_layout->layout.defined()) {
+ return src_layout;
+ }
+ std::set<size_t> reduce_axes_set;
+ for (const auto& a : reduce_axes) {
+ reduce_axes_set.insert(a);
+ }
+ std::string new_layout = "";
+ for (size_t i = 0; i < src_layout->layout.ndim(); i++) {
+ if (reduce_axes_set.count(i)) {
+ continue;
+ }
+ new_layout += src_layout->layout[i].name();
+ }
+ return LayoutDecision(new_layout);
+}
+
+const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision&
src_layout,
+ const Array<Integer>& axes) {
+ String layout_str;
+ for (const auto& a : axes) {
+ layout_str = layout_str + src_layout->layout[a->value].name();
+ }
+ return LayoutDecision(layout_str);
+}
+
+const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision&
src_layout,
+ const std::vector<size_t>&
axes) {
+ String layout_str;
+ for (const auto& a : axes) {
+ layout_str = layout_str + src_layout->layout[a].name();
+ }
+ return LayoutDecision(layout_str);
+}
+
+} // namespace msc
+} // namespace contrib
+} // namespace tvm
diff --git a/src/contrib/msc/core/transform/layout_utils.h
b/src/contrib/msc/core/transform/layout_utils.h
new file mode 100644
index 0000000000..b9de832838
--- /dev/null
+++ b/src/contrib/msc/core/transform/layout_utils.h
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/layout_utils.h
+ * \brief Common utilities for layout.
+ */
+#ifndef TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_
+#define TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_
+
+#include <tvm/ir/source_map.h>
+#include <tvm/relax/expr.h>
+
+#include <vector>
+
+#include "../../../../relax/transform/infer_layout_utils.h"
+#include "../../../../relax/transform/utils.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using Expr = tvm::RelayExpr;
+using namespace tvm::relax;
+
+/*!
+ * \brief Utils for Layout.
+ */
class LayoutUtils {
 public:
  /*!
   * \brief Check if the layout is inferred.
   * \param expr The expr to be checked.
   * \return Whether a "layout" attr exists on the expr's span.
   */
  TVM_DLL static bool LayoutInfered(const Expr& expr);

  /*!
   * \brief Set the layout to span
   * \param expr The expr whose span is updated.
   * \param layout The (possibly nested) layout to record.
   * \return Whether the layout was written (false if undefined or unchanged).
   */
  TVM_DLL static bool SetLayout(const Expr& expr, const NLayout& layout);

  /*!
   * \brief Get the layout from span
   * \param expr The expr whose span is read.
   * \return The NLayout (empty decision when no layout is recorded).
   */
  TVM_DLL static const NLayout GetNLayout(const Expr& expr);

  /*!
   * \brief Get the layout decision from span
   * \param expr The expr whose span is read; must hold a leaf layout.
   * \return The LayoutDecision.
   */
  TVM_DLL static const LayoutDecision GetLayoutDecision(const Expr& expr);

  /*!
   * \brief Check if the layout has unknown dim tensor.
   * \return Whether any leaf layout equals the unknown-dim sentinel.
   */
  TVM_DLL static bool HasUnknownDimTensor(const NLayout& nlayout);

  /*!
   * \brief Check if the args has unknown dim tensor.
   * \return Whether any nested-tensor arg carries an unknown-dim layout.
   */
  TVM_DLL static bool HasUnknownDimTensor(const Array<Expr>& args);

  /*!
   * \brief Insert axes to the Layout
   * \param src_layout The layout to expand; returned as-is when undefined.
   * \param expand_axes The positions where fresh axis names are inserted.
   * \return The new layout.
   */
  TVM_DLL static const LayoutDecision ExpandLayout(const LayoutDecision& src_layout,
                                                   const std::vector<size_t>& expand_axes);

  /*!
   * \brief Delete axes from the Layout
   * \param src_layout The layout to reduce; returned as-is when undefined.
   * \param reduce_axes The axis indices to drop.
   * \return The new layout.
   */
  TVM_DLL static const LayoutDecision ReduceLayout(const LayoutDecision& src_layout,
                                                   const std::vector<size_t>& reduce_axes);
  /*!
   * \brief Permute axes from the Layout
   * \param src_layout The layout to permute.
   * \param axes The permutation order (axis indices into src_layout).
   * \return The new layout.
   */
  TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout,
                                                    const Array<Integer>& axes);
  TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout,
                                                    const std::vector<size_t>& axes);
};
+
+} // namespace msc
+} // namespace contrib
+} // namespace tvm
+#endif // TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_
diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc
b/src/contrib/msc/core/transform/set_expr_layout.cc
new file mode 100644
index 0000000000..5915bef9e1
--- /dev/null
+++ b/src/contrib/msc/core/transform/set_expr_layout.cc
@@ -0,0 +1,1215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/set_expr_layout.cc
+ * \brief Pass for setting layout for expr and constant.
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include "../utils.h"
+#include "layout_utils.h"
+
+namespace tvm {
+namespace relax {
+
+using namespace tvm::contrib::msc;
+
// Get the nested layout for an expr: prefer the decision recorded in
// var_layout_map for vars, fall back to the "layout" span attribute.
NLayout InferNLayout(const Expr& expr, const VarLayoutMap& var_layout_map) {
  if (expr->IsInstance<VarNode>() && var_layout_map.count(Downcast<Var>(expr))) {
    return GetNLayout(var_layout_map, expr);
  }
  return LayoutUtils::GetNLayout(expr);
}
+
// Like InferNLayout, but requires the expr to carry a single (leaf) layout.
LayoutDecision InferLayoutDecision(const Expr& expr, const VarLayoutMap& var_layout_map) {
  const auto& nlayout = InferNLayout(expr, var_layout_map);
  ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr;
  return nlayout.LeafValue();
}
+
+LayoutDecision InferLayoutDecisionAt(const Expr& expr, const VarLayoutMap&
var_layout_map,
+ size_t index = 0) {
+ const auto& nlayouts = InferNLayout(expr, var_layout_map);
+ const auto& nlayout = nlayouts.NestedArray()[0];
+ ICHECK(nlayout.IsLeaf()) << "Cannot get output layout for " << expr;
+ return nlayout.LeafValue();
+}
+
std::tuple<int64_t, int64_t> AccumulateMatch(const std::vector<int64_t>& in_shape,
                                             const std::vector<int64_t>& out_shape, size_t in_start,
                                             size_t out_start) {
  // find input position in_pos and output position out_pos, such that
  // cumsum(in_shape[in_start:in_pos]) == cumsum(out_shape[out_start:out_pos])
  // (products of the dims, used to align axes across a reshape).
  // Returns (-1, -1) when no matching pair of positions exists.
  int64_t in_pos = -1;
  int64_t out_pos = -1;
  int64_t in_accumulate = 1;
  int64_t out_accumulate = 1;
  for (size_t i = in_start; i < in_shape.size(); i++) {
    in_accumulate *= in_shape[i];
    out_accumulate = 1;
    for (size_t j = out_start; j < out_shape.size(); j++) {
      out_accumulate *= out_shape[j];
      if (in_accumulate == out_accumulate) {
        in_pos = i;
        out_pos = j;
        break;
      } else if (out_accumulate > in_accumulate) {
        // Output product overshot: grow the input product and retry.
        break;
      }
    }
    if (in_pos >= 0) {
      break;
    }
  }
  // append tailed 1s: absorb trailing size-1 dims into the matched span.
  if (in_pos >= 0) {
    int64_t in_size = static_cast<int64_t>(in_shape.size());
    int64_t out_size = static_cast<int64_t>(out_shape.size());
    while (in_pos < in_size - 1 && in_shape[in_pos + 1] == 1) {
      in_pos++;
    }
    while (out_pos < out_size - 1 && out_shape[out_pos + 1] == 1) {
      out_pos++;
    }
  }
  return std::make_tuple(in_pos, out_pos);
}
+
// Infer which input axes disappear when reshaping input_shape down to
// output_shape; returns an empty vector when no consistent mapping is found.
std::vector<size_t> InferReduceAxes(const Array<PrimExpr>& input_shape,
                                    const Array<PrimExpr>& output_shape) {
  std::vector<size_t> reduce_axes, out_axes;
  std::vector<int64_t> in_shape, out_shape;
  for (const auto& s : input_shape) {
    in_shape.push_back(Downcast<Integer>(s)->value);
  }
  for (const auto& s : output_shape) {
    out_shape.push_back(Downcast<Integer>(s)->value);
  }
  // out_axes[k] is the input axis matched to output axis k.
  size_t start = 0;
  while (start < in_shape.size() && out_axes.size() < out_shape.size()) {
    if (in_shape[start] == out_shape[out_axes.size()]) {
      out_axes.push_back(start);
      start++;
    } else {
      // Dims differ: align spans of axes whose products match.
      int64_t in_pos, out_pos;
      size_t out_start = out_axes.size();
      std::tie(in_pos, out_pos) = AccumulateMatch(in_shape, out_shape, start, out_start);
      if (in_pos == -1) {
        return std::vector<size_t>();
      }
      // NOTE(review): this branch records i + 1 while the equal-dim branch
      // records `start` directly — confirm the offset is intentional.
      for (size_t i = out_start; i < static_cast<size_t>(out_pos) + 1; i++) {
        out_axes.push_back(i + 1);
      }
      start = in_pos + 1;
    }
  }
  if (out_axes.size() != out_shape.size()) {
    return std::vector<size_t>();
  }
  // Every input axis not matched to an output axis is a reduced axis.
  std::set<size_t> out_axes_set;
  for (const auto& a : out_axes) {
    out_axes_set.insert(a);
  }
  for (size_t i = 0; i < in_shape.size(); i++) {
    if (!out_axes_set.count(i)) {
      reduce_axes.push_back(i);
    }
  }
  return reduce_axes;
}
+
// Infer the output positions of axes created when reshaping input_shape up to
// output_shape; returns an empty vector when no consistent mapping is found.
std::vector<size_t> InferExpandAxes(const Array<PrimExpr>& input_shape,
                                    const Array<PrimExpr>& output_shape) {
  std::vector<size_t> expand_axes;
  std::vector<int64_t> in_shape, out_shape;
  for (const auto& s : input_shape) {
    in_shape.push_back(Downcast<Integer>(s)->value);
  }
  for (const auto& s : output_shape) {
    out_shape.push_back(Downcast<Integer>(s)->value);
  }
  size_t start = 0;
  while (start < in_shape.size() && expand_axes.size() + in_shape.size() < out_shape.size()) {
    if (in_shape[start] == out_shape[start + expand_axes.size()]) {
      // Dim carried over unchanged; move to the next input axis.
      start++;
    } else {
      // Dims differ: align spans with matching products, the surplus output
      // axes in the span are the newly expanded ones.
      int64_t in_pos, out_pos;
      size_t out_start = start + expand_axes.size();
      std::tie(in_pos, out_pos) = AccumulateMatch(in_shape, out_shape, start, out_start);
      if (in_pos == -1) {
        return std::vector<size_t>();
      }
      size_t expand_size = out_pos - in_pos - expand_axes.size();
      for (size_t i = 0; i < expand_size; i++) {
        expand_axes.push_back(out_start + i);
      }
      start = in_pos + 1;
    }
  }
  // Sanity: expansions must account for the full rank difference.
  if (expand_axes.size() + in_shape.size() != out_shape.size()) {
    return std::vector<size_t>();
  }
  return expand_axes;
}
+
+// Forward and Backward infer
+InferLayoutOutput MSCInferLayoutConv(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ LayoutDecision data_layout, kernel_layout, out_layout;
+ const String& op_name = Downcast<Op>(call->op)->name;
+ if (op_name == "relax.nn.conv1d") {
+ const auto* attrs = call->attrs.as<Conv1DAttrs>();
+ data_layout = LayoutDecision(attrs->data_layout);
+ kernel_layout = LayoutDecision(attrs->kernel_layout);
+ out_layout = LayoutDecision(attrs->out_layout);
+ } else if (op_name == "relax.nn.conv2d") {
+ const auto* attrs = call->attrs.as<Conv2DAttrs>();
+ data_layout = LayoutDecision(attrs->data_layout);
+ kernel_layout = LayoutDecision(attrs->kernel_layout);
+ out_layout = LayoutDecision(attrs->out_layout);
+ }
+ return InferLayoutOutput({data_layout, kernel_layout}, {out_layout},
Attrs());
+}
+
+InferLayoutOutput MSCInferLayoutPool2d(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ LayoutDecision layout, out_layout;
+ const String& op_name = Downcast<Op>(call->op)->name;
+ if (op_name == "relax.nn.adaptive_avg_pool2d") {
+ const auto* attrs = call->attrs.as<AdaptivePool2DAttrs>();
+ layout = LayoutDecision(attrs->layout);
+ out_layout = LayoutDecision(attrs->out_layout);
+ } else {
+ const auto* attrs = call->attrs.as<Pool2DAttrs>();
+ layout = LayoutDecision(attrs->layout);
+ out_layout = LayoutDecision(attrs->out_layout);
+ }
+ return InferLayoutOutput({layout}, {out_layout}, Attrs());
+}
+
+// Forward Infer
+InferLayoutOutput ForwardInferLayoutCommon(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map)
{
+ Array<NLayout> input_layouts;
+ LayoutDecision layout_hint;
+ for (const auto& arg : call->args) {
+ const auto& in_layout = InferLayoutDecision(arg, var_layout_map);
+ if (in_layout->layout.defined()) {
+ layout_hint = in_layout;
+ }
+ input_layouts.push_back(in_layout);
+ }
+ if (!layout_hint.defined()) {
+ return InferLayoutOutput();
+ }
+ std::vector<NLayout> output_layouts;
+ const auto& sinfo = GetStructInfo(call);
+ if (sinfo->IsInstance<TensorStructInfoNode>()) {
+ output_layouts.push_back(layout_hint);
+ } else if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
+ for (size_t i = 0; i < tuple_sinfo->fields.size(); i++) {
+ output_layouts.push_back(layout_hint);
+ }
+ } else {
+ return InferLayoutOutput();
+ }
+ return InferLayoutOutput(input_layouts, {output_layouts}, Attrs());
+}
+
InferLayoutOutput ForwardInferLayoutBinary(const Call& call,
                                           const Map<String, Array<String>>& desired_layouts,
                                           const VarLayoutMap& var_layout_map) {
  // Start from the common propagation, then blank out the layout of any 0-d
  // (scalar) input since scalars carry no layout.
  const auto& output = ForwardInferLayoutCommon(call, desired_layouts, var_layout_map);
  if (!output.defined()) {
    return output;
  }
  std::vector<NLayout> input_layouts;
  for (size_t i = 0; i < call->args.size(); i++) {
    const auto& sinfo = GetStructInfo(call->args[i]);
    if (const auto* t_info = sinfo.as<TensorStructInfoNode>()) {
      if (t_info->ndim == 0) {
        input_layouts.push_back(LayoutDecision(""));
      } else {
        input_layouts.push_back(output->input_layouts[i]);
      }
    } else {
      LOG(FATAL) << "Binary input should be tensor, get " << sinfo->GetTypeKey();
    }
  }
  return InferLayoutOutput(input_layouts, output->output_layouts, Attrs());
}
+
// In-place ops propagate layouts exactly like the common rule.
InferLayoutOutput ForwardInferLayoutInplace(const Call& call,
                                            const Map<String, Array<String>>& desired_layouts,
                                            const VarLayoutMap& var_layout_map) {
  return ForwardInferLayoutCommon(call, desired_layouts, var_layout_map);
}
+
InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call,
                                              const Map<String, Array<String>>& desired_layouts,
                                              const VarLayoutMap& var_layout_map) {
  Array<PrimExpr> empty;
  const auto& input_shape =
      Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
  if (input_shape.size() == 0) {
    // Unknown shape: give up rather than guess a default layout.
    return InferLayoutOutput();
  }
  LayoutDecision in_layout = InferLayoutDecision(call->args[0], var_layout_map);
  if (!in_layout->layout.defined()) {
    // Default layouts by rank; other ranks keep an undefined layout.
    if (input_shape.size() == 4) {
      in_layout = LayoutDecision("NCHW");
    } else if (input_shape.size() == 3) {
      in_layout = LayoutDecision("NCD");
    }
  }
  // Remaining args are 1-d parameters tagged "O" (presumably
  // gamma/beta/moving_mean/moving_var — verify against relax.nn.batch_norm).
  LayoutDecision g_layout = LayoutDecision("O");
  return InferLayoutOutput({in_layout, g_layout, g_layout, g_layout, g_layout},
                           {{in_layout, g_layout, g_layout}}, Attrs());
}
+
// NOTE(review): "Forkward" is a typo for "Forward"; kept unchanged because the
// symbol name is part of the block's interface.
InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call,
                                                const Map<String, Array<String>>& desired_layouts,
                                                const VarLayoutMap& var_layout_map) {
  LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map);
  if (!input_layout->layout.defined()) {
    return InferLayoutOutput();
  }
  Array<PrimExpr> empty;
  const auto& input_shape =
      Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
  if (input_shape.size() == 0) {
    return InferLayoutOutput();
  }
  const auto* attrs = call->attrs.as<ExpandDimsAttrs>();
  // Normalize possibly-negative axes against the input rank.
  std::vector<size_t> expand_axes;
  for (const auto& s : attrs->axis) {
    expand_axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size()));
  }
  // Output layout inserts fresh axis names at the expanded positions.
  LayoutDecision output_layout = LayoutUtils::ExpandLayout(input_layout, expand_axes);
  return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
}
+
InferLayoutOutput ForwardInferLayoutNormalize(const Call& call,
                                              const Map<String, Array<String>>& desired_layouts,
                                              const VarLayoutMap& var_layout_map) {
  Array<PrimExpr> empty;
  const auto& input_shape =
      Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
  if (input_shape.size() == 0) {
    // Unknown shape: give up rather than guess a default layout.
    return InferLayoutOutput();
  }
  LayoutDecision in_layout = InferLayoutDecision(call->args[0], var_layout_map);
  if (!in_layout->layout.defined()) {
    // Default layouts by rank; other ranks keep an undefined layout.
    if (input_shape.size() == 4) {
      in_layout = LayoutDecision("NCHW");
    } else if (input_shape.size() == 3) {
      in_layout = LayoutDecision("NCD");
    }
  }
  // Two trailing 1-d parameter args tagged "O"; single output keeps the
  // data layout.
  LayoutDecision g_layout = LayoutDecision("O");
  return InferLayoutOutput({in_layout, g_layout, g_layout}, {in_layout}, Attrs());
}
+
+InferLayoutOutput ForwardInferLayoutMatmul(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map)
{
+ Array<PrimExpr> empty;
+ const auto& a_shape =
+
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+ const auto& b_shape =
+
Downcast<TensorStructInfo>(GetStructInfo(call->args[1]))->GetShape().value_or(empty);
+
+ if (a_shape.size() == 0) {
+ return InferLayoutOutput();
+ }
+ LayoutDecision a_layout = InferLayoutDecision(call->args[0], var_layout_map);
+ if (!a_layout->layout.defined()) {
+ if (a_shape.size() == 4) {
+ a_layout = LayoutDecision("NCHW");
+ } else if (a_shape.size() == 3) {
+ a_layout = LayoutDecision("NCD");
+ } else if (a_shape.size() == 2) {
+ a_layout = LayoutDecision("NC");
+ }
+ }
+ size_t start = a_layout->layout.ndim() - b_shape.size();
+ String pre_layout;
+ for (size_t i = start; i < a_layout->layout.ndim() - 2; i++) {
+ pre_layout = pre_layout + a_layout->layout[i].name();
+ }
+ LayoutDecision b_layout = LayoutDecision(pre_layout + "IO");
+ return InferLayoutOutput({a_layout, b_layout}, {a_layout}, Attrs());
+}
+
+InferLayoutOutput ForwardInferLayoutPermute(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ LayoutDecision input_layout = InferLayoutDecision(call->args[0],
var_layout_map);
+ if (!input_layout->layout.defined()) {
+ return InferLayoutOutput();
+ }
+ std::vector<size_t> permute_axes;
+ const auto* attrs = call->attrs.as<PermuteDimsAttrs>();
+ if (!attrs->axes.defined()) {
+ for (size_t i = input_layout->layout.ndim(); i > 0; i--) {
+ permute_axes.push_back(i - 1);
+ }
+ } else {
+ for (const auto& a : attrs->axes.value()) {
+ permute_axes.push_back(a->value);
+ }
+ }
+ LayoutDecision output_layout = LayoutUtils::PermuteLayout(input_layout,
permute_axes);
+ return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
InferLayoutOutput ForwardInferLayoutReduceAxis(const Call& call,
                                               const Map<String, Array<String>>& desired_layouts,
                                               const VarLayoutMap& var_layout_map) {
  LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map);
  if (!input_layout->layout.defined()) {
    return InferLayoutOutput();
  }
  const auto* attrs = call->attrs.as<StatisticalAttrs>();
  if (attrs->keepdims) {
    // Rank is preserved: output keeps the input layout.
    return InferLayoutOutput({input_layout}, {input_layout}, Attrs());
  }
  if (!attrs->axis.defined()) {
    // Full reduction: the output carries no layout.
    return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs());
  }
  Array<PrimExpr> empty;
  const auto& input_shape =
      Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
  if (input_shape.size() == 0) {
    return InferLayoutOutput();
  }
  // Normalize possibly-negative axes, then drop them from the layout.
  std::vector<size_t> axes;
  for (const auto& s : attrs->axis.value()) {
    axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size()));
  }
  LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, axes);
  return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
}
+
InferLayoutOutput ForwardInferLayoutReshape(const Call& call,
                                            const Map<String, Array<String>>& desired_layouts,
                                            const VarLayoutMap& var_layout_map) {
  LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map);
  if (!input_layout->layout.defined()) {
    return InferLayoutOutput();
  }
  Array<PrimExpr> empty;
  const auto& input_shape =
      Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
  const auto& output_shape =
      Downcast<TensorStructInfo>(GetStructInfo(call))->GetShape().value_or(empty);
  if (input_shape.size() == 0 || output_shape.size() == 0) {
    return InferLayoutOutput();
  }
  LayoutDecision output_layout;
  if (input_shape.size() == output_shape.size()) {
    // Same rank: treat the reshape as layout-preserving.
    output_layout = input_layout;
  } else if (input_shape.size() > output_shape.size()) {
    // Rank shrinks: drop the axes that were merged away.
    const auto& reduce_axes = InferReduceAxes(input_shape, output_shape);
    if (reduce_axes.size() == 0) {
      return InferLayoutOutput();
    }
    output_layout = LayoutUtils::ReduceLayout(input_layout, reduce_axes);
  } else {
    // Rank grows: insert names for the newly created axes.
    const auto& expand_axes = InferExpandAxes(input_shape, output_shape);
    if (expand_axes.size() == 0) {
      return InferLayoutOutput();
    }
    output_layout = LayoutUtils::ExpandLayout(input_layout, expand_axes);
  }
  // Second input is the target-shape argument, tagged "O".
  return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs());
}
+
InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call,
                                            const Map<String, Array<String>>& desired_layouts,
                                            const VarLayoutMap& var_layout_map) {
  LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map);
  if (!input_layout->layout.defined()) {
    return InferLayoutOutput();
  }
  Array<PrimExpr> empty;
  const auto& input_shape =
      Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
  if (input_shape.size() == 0) {
    return InferLayoutOutput();
  }
  const auto* attrs = call->attrs.as<SqueezeAttrs>();
  // Collect the axes actually removed: only size-1 dims are squeezed.
  std::vector<size_t> reduce_axes;
  if (attrs->axis.defined()) {
    for (const auto& s : attrs->axis.value()) {
      // Normalize possibly-negative axis values against the rank.
      size_t v_index = CommonUtils::GetIndex(s->value, input_shape.size());
      if (Downcast<Integer>(input_shape[v_index])->value == 1) {
        reduce_axes.push_back(v_index);
      }
    }
  } else {
    // No axes given: squeeze every size-1 dim.
    for (size_t i = 0; i < input_shape.size(); i++) {
      if (Downcast<Integer>(input_shape[i])->value == 1) {
        reduce_axes.push_back(i);
      }
    }
  }
  LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, reduce_axes);
  return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
}
+
+InferLayoutOutput ForwardInferLayoutTake(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ LayoutDecision input_layout = InferLayoutDecision(call->args[1],
var_layout_map);
+ if (!input_layout->layout.defined()) {
+ return InferLayoutOutput();
+ }
+ LayoutDecision output_layout = LayoutUtils::ExpandLayout(input_layout,
std::vector<size_t>{0});
+ return InferLayoutOutput({LayoutDecision("WE"), input_layout},
{output_layout}, Attrs());
+}
+
// Register MSC forward layout inference functions (attr "FMSCForwardInferLayout").
// conv / pool ops
TVM_REGISTER_OP("relax.nn.conv1d")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", MSCInferLayoutConv);
TVM_REGISTER_OP("relax.nn.conv2d")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", MSCInferLayoutConv);
TVM_REGISTER_OP("relax.nn.max_pool2d")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", MSCInferLayoutPool2d);
TVM_REGISTER_OP("relax.nn.avg_pool2d")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", MSCInferLayoutPool2d);
TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", MSCInferLayoutPool2d);
// reduce axis ops
TVM_REGISTER_OP("relax.argmax")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.argmin")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.max")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.min")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.mean")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.sum")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.prod")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.std")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis);
// binary ops
TVM_REGISTER_OP("relax.add")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.divide")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.floor_divide")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.multiply")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.power")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.subtract")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.equal")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.greater")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.greater_equal")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.less")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.less_equal")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.not_equal")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.maximum")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.minimum")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.logical_and")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.logical_or")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.logical_xor")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.bitwise_and")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.bitwise_or")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
TVM_REGISTER_OP("relax.bitwise_xor")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBinary);
// math ops
// NOTE(review): "Forkward" is a typo but matches the definition's name elsewhere
// in this file; rename both together rather than only here.
TVM_REGISTER_OP("relax.expand_dims")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForkwardInferLayoutExpandDims);
TVM_REGISTER_OP("relax.matmul")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutMatmul);
TVM_REGISTER_OP("relax.permute_dims")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutPermute);
TVM_REGISTER_OP("relax.reshape")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutReshape);
TVM_REGISTER_OP("relax.squeeze")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutSqueeze);
TVM_REGISTER_OP("relax.take")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutTake);
// nn ops
TVM_REGISTER_OP("relax.nn.batch_norm")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutBatchNorm);
TVM_REGISTER_OP("relax.nn.group_norm")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutNormalize);
TVM_REGISTER_OP("relax.nn.layer_norm")
    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutNormalize);
+
// Backward Infer
/*!
 * \brief Common backward layout inference: take a layout hint from the call's
 *  own (output) layout and assign it to every argument whose layout is not
 *  yet defined; arguments with a defined layout keep it.
 */
InferLayoutOutput BackwardInferLayoutCommon(const Call& call,
                                            const Map<String, Array<String>>& desired_layouts,
                                            const VarLayoutMap& var_layout_map) {
  NLayout output_layout = InferNLayout(call, var_layout_map);
  LayoutDecision layout_hint;
  if (output_layout.IsLeaf()) {
    layout_hint = output_layout.LeafValue();
  } else {
    // Tuple output: the last leaf with a defined layout wins as the hint.
    for (const auto& l : output_layout.NestedArray()) {
      if (l.IsLeaf() && l.LeafValue()->layout.defined()) {
        layout_hint = l.LeafValue();
      }
    }
  }
  // NOTE(review): if no leaf had a defined layout, layout_hint may still be a
  // default-constructed decision here — confirm the dereference below is safe.
  if (!layout_hint->layout.defined()) {
    return InferLayoutOutput();
  }
  Array<NLayout> input_layouts;
  for (const auto& arg : call->args) {
    const auto& saved_layout = InferLayoutDecision(arg, var_layout_map);
    if (saved_layout->layout.defined()) {
      input_layouts.push_back(saved_layout);
    } else {
      input_layouts.push_back(layout_hint);
    }
  }
  return InferLayoutOutput(input_layouts, {output_layout}, Attrs());
}
+
+InferLayoutOutput BackwardInferLayoutBinary(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ const auto& output = BackwardInferLayoutCommon(call, desired_layouts,
var_layout_map);
+ if (!output.defined()) {
+ return output;
+ }
+ std::vector<NLayout> input_layouts;
+ for (size_t i = 0; i < call->args.size(); i++) {
+ const auto& sinfo = GetStructInfo(call->args[i]);
+ if (const auto* t_info = sinfo.as<TensorStructInfoNode>()) {
+ if (t_info->ndim == 0) {
+ input_layouts.push_back(LayoutDecision(""));
+ } else {
+ input_layouts.push_back(output->input_layouts[i]);
+ }
+ } else {
+ LOG(FATAL) << "Binary input should be tensor, get " <<
sinfo->GetTypeKey();
+ }
+ }
+ return InferLayoutOutput(input_layouts, output->output_layouts, Attrs());
+}
+
/*!
 * \brief Backward layout inference for inplace-style ops: simply delegates to
 *  the common backward rule (inputs inherit the output layout hint).
 */
InferLayoutOutput BackwardInferLayoutInplace(const Call& call,
                                             const Map<String, Array<String>>& desired_layouts,
                                             const VarLayoutMap& var_layout_map) {
  return BackwardInferLayoutCommon(call, desired_layouts, var_layout_map);
}
+
+InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call,
+ const Map<String,
Array<String>>& desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map,
0);
+ if (!output_layout->layout.defined()) {
+ return InferLayoutOutput();
+ }
+ LayoutDecision g_layout = LayoutDecision("O");
+ return InferLayoutOutput({output_layout, g_layout, g_layout, g_layout,
g_layout},
+ {{output_layout, g_layout, g_layout}}, Attrs());
+}
+
+InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call,
+ const Map<String,
Array<String>>& desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+ if (!output_layout->layout.defined()) {
+ return InferLayoutOutput();
+ }
+ Array<PrimExpr> empty;
+ const auto& input_shape =
+
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+ if (input_shape.size() == 0) {
+ return InferLayoutOutput();
+ }
+ const auto* attrs = call->attrs.as<ExpandDimsAttrs>();
+ std::vector<size_t> expand_axes;
+ for (const auto& s : attrs->axis) {
+ expand_axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size()));
+ }
+ LayoutDecision input_layout = LayoutUtils::ReduceLayout(output_layout,
expand_axes);
+ return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+InferLayoutOutput BackwardInferLayoutNormalize(const Call& call,
+ const Map<String,
Array<String>>& desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map,
0);
+ if (!output_layout->layout.defined()) {
+ return InferLayoutOutput();
+ }
+ LayoutDecision g_layout = LayoutDecision("O");
+ return InferLayoutOutput({output_layout, g_layout, g_layout},
{output_layout}, Attrs());
+}
+
+InferLayoutOutput BackwardInferLayoutMatmul(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+ if (!output_layout->layout.defined()) {
+ return InferLayoutOutput();
+ }
+ Array<PrimExpr> empty;
+ const auto& b_shape =
+
Downcast<TensorStructInfo>(GetStructInfo(call->args[1]))->GetShape().value_or(empty);
+ if (b_shape.size() == 0) {
+ return InferLayoutOutput();
+ }
+ size_t start = output_layout->layout.ndim() - b_shape.size();
+ String pre_layout;
+ for (size_t i = start; i < output_layout->layout.ndim() - 2; i++) {
+ pre_layout = pre_layout + output_layout->layout[i].name();
+ }
+ LayoutDecision b_layout = LayoutDecision(pre_layout + "IO");
+ return InferLayoutOutput({output_layout, b_layout}, {output_layout},
Attrs());
+}
+
+InferLayoutOutput BackwardInferLayoutPermute(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+ if (!output_layout->layout.defined()) {
+ return InferLayoutOutput();
+ }
+ std::vector<size_t> permute_axes;
+ const auto* attrs = call->attrs.as<PermuteDimsAttrs>();
+ if (!attrs->axes.defined()) {
+ for (size_t i = output_layout->layout.ndim(); i > 0; i--) {
+ permute_axes.push_back(i - 1);
+ }
+ } else {
+ std::vector<int> attr_axes;
+ for (const auto& s : attrs->axes.value()) {
+ attr_axes.push_back(s->value);
+ }
+ for (size_t i = 0; i < output_layout->layout.ndim(); i++) {
+ int pos = ArrayUtils::IndexOf(attr_axes, static_cast<int>(i));
+ if (pos >= 0) {
+ permute_axes.push_back(pos);
+ } else {
+ permute_axes.push_back(i);
+ }
+ }
+ }
+ LayoutDecision input_layout = LayoutUtils::PermuteLayout(output_layout,
permute_axes);
+ return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+InferLayoutOutput BackwardInferLayoutReduceAxis(const Call& call,
+ const Map<String,
Array<String>>& desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+ if (!output_layout->layout.defined()) {
+ return InferLayoutOutput();
+ }
+ const auto* attrs = call->attrs.as<StatisticalAttrs>();
+ if (attrs->keepdims) {
+ return InferLayoutOutput({output_layout}, {output_layout}, Attrs());
+ }
+ Array<PrimExpr> empty;
+ const auto& input_shape =
+
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+ if (input_shape.size() == 0) {
+ return InferLayoutOutput();
+ }
+ std::vector<size_t> axes;
+ for (const auto& s : attrs->axis.value()) {
+ axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size()));
+ }
+ LayoutDecision input_layout = LayoutUtils::ExpandLayout(output_layout, axes);
+ return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+InferLayoutOutput BackwardInferLayoutReshape(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+ if (!output_layout->layout.defined()) {
+ return InferLayoutOutput();
+ }
+ Array<PrimExpr> empty;
+ const auto& input_shape =
+
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+ const auto& output_shape =
+
Downcast<TensorStructInfo>(GetStructInfo(call))->GetShape().value_or(empty);
+ if (input_shape.size() == 0 || output_shape.size() == 0) {
+ return InferLayoutOutput();
+ }
+ LayoutDecision input_layout;
+ if (input_shape.size() == output_shape.size()) {
+ input_layout = output_layout;
+ } else if (input_shape.size() > output_shape.size()) {
+ const auto& reduce_axes = InferReduceAxes(input_shape, output_shape);
+ if (reduce_axes.size() == 0) {
+ return InferLayoutOutput();
+ }
+ input_layout = LayoutUtils::ExpandLayout(output_layout, reduce_axes);
+ } else {
+ const auto& expand_axes = InferExpandAxes(input_shape, output_shape);
+ if (expand_axes.size() == 0) {
+ return InferLayoutOutput();
+ }
+ input_layout = LayoutUtils::ReduceLayout(output_layout, expand_axes);
+ }
+ return InferLayoutOutput({input_layout, LayoutDecision("O")},
{output_layout}, Attrs());
+}
+
+InferLayoutOutput BackwardInferLayoutSqueeze(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+ if (!output_layout->layout.defined()) {
+ return InferLayoutOutput();
+ }
+ Array<PrimExpr> empty;
+ const auto& input_shape =
+
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->GetShape().value_or(empty);
+ if (input_shape.size() == 0) {
+ return InferLayoutOutput();
+ }
+ const auto* attrs = call->attrs.as<SqueezeAttrs>();
+ std::vector<size_t> reduce_axes;
+ if (attrs->axis.defined()) {
+ for (const auto& s : attrs->axis.value()) {
+ size_t v_index = CommonUtils::GetIndex(s->value, input_shape.size());
+ if (Downcast<Integer>(input_shape[v_index])->value == 1) {
+ reduce_axes.push_back(v_index);
+ }
+ }
+ } else {
+ for (size_t i = 0; i < input_shape.size(); i++) {
+ if (Downcast<Integer>(input_shape[i])->value == 1) {
+ reduce_axes.push_back(i);
+ }
+ }
+ }
+ LayoutDecision input_layout = LayoutUtils::ExpandLayout(output_layout,
reduce_axes);
+ return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
+}
+
+InferLayoutOutput BackwardInferLayoutTake(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map);
+ if (!output_layout->layout.defined()) {
+ return InferLayoutOutput();
+ }
+ LayoutDecision input_layout = LayoutUtils::ReduceLayout(output_layout,
std::vector<size_t>{0});
+ return InferLayoutOutput({LayoutDecision("WE"), input_layout},
{output_layout}, Attrs());
+}
+
// Register MSC backward layout inference functions (attr "FMSCBackwardInferLayout").
// conv / pool ops
TVM_REGISTER_OP("relax.nn.conv1d")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", MSCInferLayoutConv);
TVM_REGISTER_OP("relax.nn.conv2d")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", MSCInferLayoutConv);
TVM_REGISTER_OP("relax.nn.max_pool2d")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", MSCInferLayoutPool2d);
TVM_REGISTER_OP("relax.nn.avg_pool2d")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", MSCInferLayoutPool2d);
TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", MSCInferLayoutPool2d);
// reduce axis ops
TVM_REGISTER_OP("relax.argmax")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.argmin")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.max")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.min")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.mean")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.sum")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.prod")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis);
TVM_REGISTER_OP("relax.std")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis);
// binary ops
TVM_REGISTER_OP("relax.add")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.divide")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.floor_divide")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.multiply")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.power")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.subtract")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.equal")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.greater")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.greater_equal")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.less")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.less_equal")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.not_equal")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.maximum")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.minimum")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.logical_and")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.logical_or")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.logical_xor")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.bitwise_and")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.bitwise_or")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
TVM_REGISTER_OP("relax.bitwise_xor")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBinary);
// math ops
TVM_REGISTER_OP("relax.expand_dims")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutExpandDims);
TVM_REGISTER_OP("relax.matmul")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutMatmul);
TVM_REGISTER_OP("relax.permute_dims")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutPermute);
TVM_REGISTER_OP("relax.reshape")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutReshape);
TVM_REGISTER_OP("relax.squeeze")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutSqueeze);
TVM_REGISTER_OP("relax.take")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutTake);
// nn ops
TVM_REGISTER_OP("relax.nn.batch_norm")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutBatchNorm);
TVM_REGISTER_OP("relax.nn.group_norm")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutNormalize);
TVM_REGISTER_OP("relax.nn.layer_norm")
    .set_attr<FRelaxInferLayout>("FMSCBackwardInferLayout", BackwardInferLayoutNormalize);
+
/*!
 * \brief Visitor that infers and records layouts for a Relax function.
 *
 *  A forward pass visits bindings in order and runs the registered
 *  "FMSCForwardInferLayout" (or fallback) functions; a backward pass then
 *  walks the recorded exprs in reverse and runs "FMSCBackwardInferLayout"
 *  to fill in layouts the forward pass could not determine. Layouts are
 *  attached to the exprs via LayoutUtils::SetLayout (the expr itself is
 *  returned unchanged by Infer).
 */
class LayoutInfer : public ExprVisitor {
 public:
  /*! \brief Construct with the module used to resolve GlobalVar callees. */
  explicit LayoutInfer(const IRModule& ref_module) : ref_module_(ref_module) { Reset(); }

  /*! \brief Clear all per-run state. */
  void Reset() {
    infered_ = false;
    var_map_.clear();
    ordered_exprs_.clear();
  }

  /*! \brief Remember the expr bound to var, preserving visit order for the backward pass. */
  void RecordExpr(const Var& var, const Expr& expr) {
    var_map_.Set(var, expr);
    ordered_exprs_.push_back(expr);
  }

  /*! \brief Run forward then backward inference; returns the (unmodified) expr. */
  Expr Infer(const Expr& expr) {
    Reset();
    ForwardInfer(expr);
    BackwardInfer();
    return expr;
  }

  void ForwardInfer(const Expr& expr) { ExprVisitor::VisitExpr(expr); }

  /*!
   * \brief Reverse pass: for each recorded call whose args are not all
   *  infered yet, run the backward inference function and push the resulting
   *  layouts onto the inputs.
   */
  void BackwardInfer() {
    for (size_t e_idx = ordered_exprs_.size(); e_idx > 0; e_idx--) {
      const Expr& expr = ordered_exprs_[e_idx - 1];
      if (expr->IsInstance<TupleNode>()) {
        continue;
      }
      if (expr->IsInstance<TupleGetItemNode>()) {
        continue;
      }
      if (!expr->IsInstance<CallNode>()) {
        continue;
      }
      const Call& call = Downcast<Call>(expr);
      // Count how many args already have an infered layout.
      size_t infered_num = 0;
      for (const auto& arg : call->args) {
        if (arg->IsInstance<VarNode>() && var_map_.count(Downcast<Var>(arg))) {
          if (LayoutUtils::LayoutInfered(var_map_[Downcast<Var>(arg)]) > 0) {
            infered_num++;
          }
        } else if (LayoutUtils::LayoutInfered(arg)) {
          infered_num++;
        }
      }
      // Skip calls that need no work or cannot be handled.
      if (call->args.size() == 0 || infered_num == call->args.size() ||
          !call->op->IsInstance<OpNode>() || LayoutUtils::HasUnknownDimTensor(call->args)) {
        continue;
      }
      const OpNode* op_node = call->op.as<OpNode>();
      if (op_node == nullptr) {
        continue;
      }
      // Infer by op_node
      Op op = Downcast<Op>(GetRef<Op>(op_node));
      InferLayoutOutput infered_layout;
      const auto msc_infer_map = Op::GetAttrMap<FRelaxInferLayout>("FMSCBackwardInferLayout");
      try {
        if (msc_infer_map.count(op)) {
          FRelaxInferLayout f = msc_infer_map[op];
          infered_layout = f(call, Map<String, Array<String>>(), var_layout_map_);
        } else {
          infered_layout =
              BackwardInferLayoutCommon(call, Map<String, Array<String>>(), var_layout_map_);
        }
      } catch (runtime::InternalError& err) {
        // Backward inference is best-effort: log and continue.
        LOG(WARNING) << "Failed to backward infer layout " << expr << " : " << err.message();
        infered_layout = InferLayoutOutput();
      }
      try {
        if (infered_layout.defined()) {
          SetInputLayouts(infered_layout->input_layouts, call);
        }
      } catch (runtime::InternalError& err) {
        LOG(WARNING) << "Failed to backward set inputs layout for " << call << " : "
                     << err.message();
      }
    }
  }

  /*! \brief Attach the given layouts to the call's args (and their bound exprs). */
  void SetInputLayouts(const Array<NLayout>& input_layouts, const Call& call) {
    if (input_layouts.size() == call->args.size()) {
      for (size_t i = 0; i < input_layouts.size(); i++) {
        if (call->args[i]->IsInstance<VarNode>()) {
          const auto& var = Downcast<Var>(call->args[i]);
          var_layout_map_[var] = input_layouts[i];
          // Prefer setting the layout on the expr bound to the var, if known.
          if (var_map_.count(var)) {
            if (LayoutUtils::SetLayout(var_map_[var], input_layouts[i])) {
              infered_ = true;
            }
          } else if (LayoutUtils::SetLayout(var, input_layouts[i])) {
            infered_ = true;
          }
        } else if (LayoutUtils::SetLayout(call->args[i], input_layouts[i])) {
          infered_ = true;
        }
      }
    }
  }

  /*!
   * \brief Forward inference for a call binding. GlobalVar callees are
   *  recursively infered; Op callees use FMSCForwardInferLayout, falling back
   *  to the relax FRelaxInferLayout or the common forward rule.
   */
  void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final {
    ExprVisitor::VisitBinding_(binding, call_node);
    const auto& call = GetRef<Call>(call_node);
    if (const auto* v_node = call->op.as<GlobalVarNode>()) {
      // infer global func and set var layouts
      const auto& func = Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
      Infer(func);
      for (size_t i = 0; i < func->params.size(); i++) {
        if (var_layout_map_.count(func->params[i]) &&
            LayoutUtils::SetLayout(call->args[i], var_layout_map_[func->params[i]])) {
          infered_ = true;
        }
      }
      if (const auto* b_node = func->body.as<relax::SeqExprNode>()) {
        // The call's layout is the layout of the callee's body.
        var_layout_map_[binding->var] = GetNLayout(var_layout_map_, b_node->body);
        if (LayoutUtils::SetLayout(call, var_layout_map_[binding->var])) {
          infered_ = true;
        }
      } else {
        LOG(FATAL) << "Function body should be SeqExpr, get " << func->body;
      }
    } else {
      // infer call
      bool infer_outputs = true;
      RecordExpr(binding->var, call);
      if (LayoutUtils::LayoutInfered(call)) {
        infer_outputs = false;
      }
      if (call->args.size() == 0 || !call->op->IsInstance<OpNode>() ||
          LayoutUtils::HasUnknownDimTensor(call->args)) {
        infer_outputs = false;
      }
      const OpNode* op_node = call->op.as<OpNode>();
      if (op_node == nullptr) {
        infer_outputs = false;
      }
      if (infer_outputs) {
        // infer layouts
        Op op = Downcast<Op>(GetRef<Op>(op_node));
        InferLayoutOutput infered_layout;
        const auto msc_infer_map = Op::GetAttrMap<FRelaxInferLayout>("FMSCForwardInferLayout");
        const auto relax_infer_map = Op::GetAttrMap<FRelaxInferLayout>("FRelaxInferLayout");
        // When the relax fallback is used, its input layouts are not pushed back.
        bool set_inputs = true;
        try {
          if (msc_infer_map.count(op)) {
            FRelaxInferLayout f = msc_infer_map[op];
            infered_layout = f(call, Map<String, Array<String>>(), var_layout_map_);
          } else if (!relax_infer_map.count(op)) {
            infered_layout =
                ForwardInferLayoutCommon(call, Map<String, Array<String>>(), var_layout_map_);
          }
          if (relax_infer_map.count(op) && !infered_layout.defined()) {
            FRelaxInferLayout f = relax_infer_map[op];
            infered_layout = f(call, Map<String, Array<String>>(), var_layout_map_);
            set_inputs = false;
          }
        } catch (runtime::InternalError& err) {
          LOG(WARNING) << "Failed to forward infer layout for " << binding->var << " : "
                       << binding->value << ", reason: " << err.message();
          infered_layout = InferLayoutOutput();
        }
        if (infered_layout.defined() && infered_layout->output_layouts.size() == 1) {
          try {
            var_layout_map_[binding->var] = infered_layout->output_layouts[0];
            if (LayoutUtils::SetLayout(call, var_layout_map_[binding->var])) {
              infered_ = true;
            }
          } catch (runtime::InternalError& err) {
            LOG(WARNING) << "Failed to forward set output layout for " << binding->var << " : "
                         << binding->value << ", reason: " << err.message();
          }
        }
        if (set_inputs && infered_layout.defined()) {
          try {
            SetInputLayouts(infered_layout->input_layouts, call);
          } catch (runtime::InternalError& err) {
            LOG(WARNING) << "Failed to forward set inputs layout for " << call << " : "
                         << err.message();
          }
        }
      }
    }
  }

  /*! \brief Group field layouts into a nested layout for a tuple binding. */
  void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final {
    ExprVisitor::VisitBinding_(binding, val);
    std::vector<NLayout> input_layout;
    for (const auto& field : val->fields) {
      if (binding->var->IsInstance<DataflowVarNode>()) {
        // Df var: Use the current realized layout to group the tuple;
        input_layout.push_back(GetNLayout(var_layout_map_, field));
      } else {
        // Global var: Use the initial layout to group the tuple;
        input_layout.push_back(InitialNLayout(field));
      }
    }
    if (IsNestedTensor(binding->var)) {
      var_layout_map_[binding->var] = input_layout;
    }
    RecordExpr(binding->var, GetRef<Tuple>(val));
  }

  /*! \brief Propagate the selected field's layout out of a tuple-get-item. */
  void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final {
    ExprVisitor::VisitBinding_(binding, val);
    NLayout input_layout = binding->var->IsInstance<DataflowVarNode>()
                               ? GetNLayout(var_layout_map_, val->tuple)
                               : InitialNLayout(val->tuple);
    var_layout_map_[binding->var] = input_layout.NestedArray()[val->index];
    RecordExpr(binding->var, GetRef<TupleGetItem>(val));
  }

  /*! \brief Shape exprs always take the "O" layout. */
  void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) final {
    ExprVisitor::VisitBinding_(binding, val);
    const NLayout& out_layout = LayoutDecision("O");
    var_layout_map_[binding->var] = out_layout;
    if (LayoutUtils::SetLayout(GetRef<ShapeExpr>(val), out_layout)) {
      infered_ = true;
    }
  }

  /*! \brief Whether any layout was set during the last run. */
  bool infered() { return infered_; }

 private:
  IRModule ref_module_;                 // module for resolving GlobalVar callees
  bool infered_;                        // true once any layout has been set
  Map<Var, Expr> var_map_;              // var -> bound expr
  Array<Expr> ordered_exprs_;           // bound exprs in forward visit order
  std::unordered_map<Var, NLayout, ObjectPtrHash, ObjectPtrEqual> var_layout_map_;
};  // class LayoutInfer
+
+class LayoutChecker : public ExprVisitor {
+ public:
+ LayoutChecker() { missing_num_ = 0; }
+
+ void Check(const Expr& expr) {
+ ExprVisitor::VisitExpr(expr);
+ ICHECK_EQ(missing_num_, 0) << "Some layout is missing";
+ }
+
+ void VisitExpr_(const CallNode* call) final {
+ ExprVisitor::VisitExpr_(call);
+ if (!LayoutUtils::LayoutInfered(GetRef<Call>(call))) {
+ missing_num_++;
+ }
+ }
+
+ void VisitExpr_(const ConstantNode* cn) final {
+ ExprVisitor::VisitExpr_(cn);
+ if (!LayoutUtils::LayoutInfered(GetRef<Constant>(cn))) {
+ missing_num_++;
+ }
+ }
+
+ private:
+ size_t missing_num_;
+}; // class LayoutChecker
+
+void SetExprLayout(const IRModule& ref_module, const Expr& func, bool
allow_missing) {
+ auto layout_infer = LayoutInfer(ref_module);
+ auto new_func = layout_infer.Infer(func);
+ if (!allow_missing) {
+ LayoutChecker().Check(new_func);
+ }
+}
+
namespace transform {

/*!
 * \brief Module pass wrapping relax::SetExprLayout on the function looked up
 *  by entry_name.
 * \param allow_missing skip the post-check for exprs without a layout.
 * \param entry_name name of the entry function in the module.
 */
Pass SetExprLayout(bool allow_missing, const String& entry_name) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
                                                                            PassContext pc) {
    // Layouts are attached in place; the module itself is returned unchanged.
    relax::SetExprLayout(m, m->Lookup(entry_name), allow_missing);
    return m;
  };
  return CreateModulePass(pass_func, 0, "SetExprLayout", {});
}

TVM_REGISTER_GLOBAL("relax.transform.SetExprLayout").set_body_typed(SetExprLayout);

}  // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/src/contrib/msc/core/transform/set_expr_name.cc
b/src/contrib/msc/core/transform/set_expr_name.cc
new file mode 100644
index 0000000000..5b39a5a7ac
--- /dev/null
+++ b/src/contrib/msc/core/transform/set_expr_name.cc
@@ -0,0 +1,348 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/set_expr_name.cc
+ * \brief Pass for setting names for calls and constants.
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "../utils.h"
+
+namespace tvm {
+using namespace tvm::contrib::msc;
+
+namespace relax {
+
+/*!
+ * \brief Name setter for Relax
+ */
+class RelaxExprNameSetter : public ExprVisitor {
+ public:
+ explicit RelaxExprNameSetter(const IRModule& ref_module) :
ref_module_(ref_module) {}
+
+ void VisitBindingBlock(const BindingBlock& block) final {
+ String block_name = SpanUtils::GetAttr(block->span, "name");
+ if (block_name.size() == 0) {
+ block_name = "block";
+ }
+ if (setted_blocks_.count(block_name)) {
+ int cnt = 1;
+ while (setted_blocks_.count(block_name + "_" + std::to_string(cnt))) {
+ cnt++;
+ }
+ block_name = block_name + "_" + std::to_string(cnt);
+ }
+ setted_blocks_.insert(block_name);
+ block_stack_.push_back(block_name);
+ const String& unique_name = StringUtils::Join(block_stack_, ".");
+ block->span = SpanUtils::SetAttr(block->span, "name", unique_name);
+ ExprVisitor::VisitBindingBlock(block);
+ block_stack_.pop_back();
+ }
+
+ void VisitExpr_(const ConstantNode* val) {
+ ExprVisitor::VisitExpr_(val);
+ const String& unique_name = GetUniqueName(GetRef<Constant>(val), "const");
+ if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+ val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+ }
+ expr_names_.Set(GetRef<Constant>(val), unique_name);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) {
+ ExprVisitor::VisitBinding_(binding, val);
+ const String& unique_name = GetUniqueName(GetRef<Constant>(val), "const");
+ if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+ val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+ }
+ expr_names_.Set(binding->var, unique_name);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) {
+ ExprVisitor::VisitBinding_(binding, val);
+ const String& unique_name = GetUniqueName(GetRef<ShapeExpr>(val), "shape");
+ if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+ val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+ }
+ expr_names_.Set(binding->var, unique_name);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) {
+ ExprVisitor::VisitBinding_(binding, val);
+ const String& unique_name = GetUniqueName(GetRef<Tuple>(val), "tuple");
+ if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+ val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+ }
+ expr_names_.Set(binding->var, unique_name);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode*
val) {
+ ExprVisitor::VisitBinding_(binding, val);
+ ICHECK(expr_names_.count(val->tuple)) << "Can not find tuple of " <<
GetRef<TupleGetItem>(val);
+ const String& unique_name = expr_names_[val->tuple] + "." +
std::to_string(val->index);
+ if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+ val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+ }
+ expr_names_.Set(binding->var, unique_name);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const CallNode* val) {
+ ExprVisitor::VisitBinding_(binding, val);
+ String name_hint, optype;
+ if (const auto* op_node = val->op.as<OpNode>()) {
+ const std::string& op_name = op_node->name;
+ int rpos = op_name.rfind(".");
+ name_hint = op_name.substr(rpos + 1);
+ optype = StringUtils::Replace(op_node->name, "relax.", "");
+ } else if (const auto* v_node = val->op.as<GlobalVarNode>()) {
+ const auto& func =
Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
+ ExprVisitor::VisitExpr(func);
+ const auto& name_opt = func->GetAttr<runtime::String>(attr::kComposite);
+ ICHECK(name_opt.defined()) << "Unexpected global func without composite";
+ name_hint = name_opt.value();
+ optype = name_hint;
+ }
+ // set name
+ const String& unique_name = GetUniqueName(GetRef<Expr>(val), name_hint);
+ if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
+ val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
+ }
+ // set constant consumer && master
+ Array<String> input_types;
+ try {
+ input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true);
+ } catch (runtime::InternalError& err) {
+ LOG(WARNING) << "Failed to GetInputTypes for " << GetRef<Call>(val) << "
: " << err.message();
+ throw err;
+ }
+ for (size_t i = 0; i < input_types.size(); i++) {
+ if (input_types[i] == "input") {
+ continue;
+ }
+ if (const auto* c_node = val->args[i].as<ConstantNode>()) {
+ const String& const_name = SpanUtils::GetAttr(c_node->span, "name");
+ if (constant_consumers_.count(const_name)) {
+ val->span = SpanUtils::SetAttr(val->span, "master",
constant_consumers_[const_name]);
+ } else {
+ constant_consumers_.Set(const_name, unique_name);
+ }
+ }
+ }
+ expr_names_.Set(binding->var, unique_name);
+ }
+
+ private:
+ const String GetUniqueName(const Expr& expr, const String& name_hint) {
+ String expr_name = SpanUtils::GetAttr(expr->span, "name");
+ if (expr_name.size() == 0) {
+ expr_name = name_hint;
+ }
+ if (!setted_names_.count(expr_name)) {
+ setted_names_.Set(expr_name, expr);
+ return expr_name;
+ }
+ if (setted_names_[expr_name] == expr) {
+ return expr_name;
+ }
+ int cnt = 1;
+ while (setted_names_.count(expr_name + "_" + std::to_string(cnt)) &&
+ setted_names_[expr_name + "_" + std::to_string(cnt)] != expr) {
+ cnt++;
+ }
+ expr_name = expr_name + "_" + std::to_string(cnt);
+ if (!setted_names_.count(expr_name)) {
+ setted_names_.Set(expr_name, expr);
+ }
+ return expr_name;
+ }
+
+ Map<String, Expr> setted_names_;
+ Map<String, String> constant_consumers_;
+ std::set<String> setted_blocks_;
+ Array<String> block_stack_;
+ Map<Expr, String> expr_names_;
+ IRModule ref_module_;
+};  // class RelaxExprNameSetter
+
+void SetRelaxExprName(const IRModule& ref_module, const Expr& e) {
+ RelaxExprNameSetter(ref_module).VisitExpr(e);
+}
+
+namespace transform {
+
+Pass SetRelaxExprName(const String& entry_name) {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m,
+
PassContext pc) {
+ relax::SetRelaxExprName(m, m->Lookup(entry_name));
+ return m;
+ };
+ return CreateModulePass(pass_func, 0, "SetRelaxExprName", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.SetRelaxExprName").set_body_typed(SetRelaxExprName);
+
+} // namespace transform
+} // namespace relax
+
+namespace relay {
+
+/*!
+ * \brief Name setter for Relay
+ */
+class RelayExprNameSetter : public ExprVisitor {
+ public:
+ explicit RelayExprNameSetter(const IRModule& ref_module) :
ref_module_(ref_module) {}
+
+ void VisitExpr_(const ConstantNode* op) final {
+ ExprVisitor::VisitExpr_(op);
+ const String& unique_name = GetUniqueName(GetRef<Constant>(op), "const");
+ if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
+ op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
+ }
+ }
+
+ void VisitExpr_(const TupleNode* op) final {
+ ExprVisitor::VisitExpr_(op);
+ const String& unique_name = GetUniqueName(GetRef<Tuple>(op), "tuple");
+ if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
+ op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
+ }
+ }
+
+ void VisitExpr_(const TupleGetItemNode* op) final {
+ ExprVisitor::VisitExpr_(op);
+ const String& tuple_name = SpanUtils::GetAttr(op->tuple->span, "name");
+ const String& unique_name = tuple_name + "." + std::to_string(op->index);
+ if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
+ op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
+ }
+ }
+
+ void VisitExpr_(const FunctionNode* op) final {
+ ExprVisitor::VisitExpr_(op);
+ const auto& name_opt = op->GetAttr<runtime::String>(attr::kComposite);
+ const String& name_hint = name_opt.defined() ? name_opt.value() : "func";
+ const String& unique_name = GetUniqueName(GetRef<Function>(op), name_hint);
+ if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
+ op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
+ }
+ }
+
+ void VisitExpr_(const CallNode* op) final {
+ ExprVisitor::VisitExpr_(op);
+ String name_hint, optype;
+ if (const auto* op_node = op->op.as<OpNode>()) {
+ const std::string& op_name = op_node->name;
+ int rpos = op_name.rfind(".");
+ name_hint = op_name.substr(rpos + 1);
+ optype = StringUtils::Replace(op_node->name, "relay.", "");
+ } else if (const auto* v_node = op->op.as<GlobalVarNode>()) {
+ const auto& func =
Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
+ ExprVisitor::VisitExpr(func);
+ const auto& name_opt = func->GetAttr<runtime::String>(attr::kComposite);
+ ICHECK(name_opt.defined()) << "Unexpected global func without composite";
+ optype = name_opt.value();
+ name_hint = optype;
+ }
+ // set name
+ const String& unique_name = GetUniqueName(GetRef<Expr>(op), name_hint);
+ if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
+ op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
+ }
+ // set constant consumer && master
+ Array<String> input_types;
+ try {
+ input_types = ExprUtils::GetInputTypes(optype, op->args.size(), false);
+ } catch (runtime::InternalError& err) {
+ LOG(WARNING) << "Failed to GetInputTypes for " << GetRef<Call>(op) << "
: " << err.message();
+ throw err;
+ }
+ for (size_t i = 0; i < input_types.size(); i++) {
+ if (input_types[i] == "input") {
+ continue;
+ }
+ if (const auto* c_node = op->args[i].as<ConstantNode>()) {
+ const String& const_name = SpanUtils::GetAttr(c_node->span, "name");
+ if (constant_consumers_.count(const_name)) {
+ op->span = SpanUtils::SetAttr(op->span, "master",
constant_consumers_[const_name]);
+ } else {
+ constant_consumers_.Set(const_name, unique_name);
+ }
+ }
+ }
+ }
+
+ private:
+ const String GetUniqueName(const Expr& expr, const String& name_hint) {
+ String expr_name = SpanUtils::GetAttr(expr->span, "name");
+ if (expr_name.size() == 0) {
+ expr_name = name_hint;
+ }
+ if (!setted_names_.count(expr_name)) {
+ setted_names_.Set(expr_name, expr);
+ return expr_name;
+ }
+ if (setted_names_[expr_name] == expr) {
+ return expr_name;
+ }
+ int cnt = 1;
+ while (setted_names_.count(expr_name + "_" + std::to_string(cnt)) &&
+ setted_names_[expr_name + "_" + std::to_string(cnt)] != expr) {
+ cnt++;
+ }
+ expr_name = expr_name + "_" + std::to_string(cnt);
+ if (!setted_names_.count(expr_name)) {
+ setted_names_.Set(expr_name, expr);
+ }
+ return expr_name;
+ }
+
+ Map<String, Expr> setted_names_;
+ Map<String, String> constant_consumers_;
+ IRModule ref_module_;
+};  // class RelayExprNameSetter
+
+void SetRelayExprName(const IRModule& ref_module, const Expr& e) {
+ RelayExprNameSetter(ref_module).VisitExpr(e);
+}
+
+namespace transform {
+
+Pass SetRelayExprName(const String& entry_name) {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m,
+
PassContext pc) {
+ relay::SetRelayExprName(m, m->Lookup(entry_name));
+ return m;
+ };
+ return CreateModulePass(pass_func, 0, "SetRelayExprName", {});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.SetRelayExprName").set_body_typed(SetRelayExprName);
+
+} // namespace transform
+} // namespace relay
+
+} // namespace tvm
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
new file mode 100644
index 0000000000..7ecff876f2
--- /dev/null
+++ b/src/contrib/msc/core/utils.cc
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/utils.cc
+ */
+
+#include "utils.h"
+
+#include <string>
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+size_t CommonUtils::GetIndex(int index, size_t max_size) {
+ size_t v_index;
+ if (index < 0) {
+ v_index = index + max_size;
+ } else {
+ v_index = index;
+ }
+ ICHECK_LT(v_index, max_size) << "Index " << index << " out of range " <<
max_size;
+ return v_index;
+}
+
+const Array<String> StringUtils::Split(const String& src_string, const String&
sep) {
+ Array<String> sub_strings;
+ if (src_string.size() == 0) {
+ return sub_strings;
+ }
+ std::string src_cstring = src_string;
+ const std::string& csep = sep;
+ int pos = src_cstring.find(csep);
+ while (pos >= 0) {
+ if (pos > 0) {
+ sub_strings.push_back(src_cstring.substr(0, pos));
+ }
+ src_cstring = src_cstring.substr(pos + csep.size());
+ pos = src_cstring.find(csep);
+ }
+ if (src_cstring.size() > 0) {
+ sub_strings.push_back(src_cstring);
+ }
+ return sub_strings;
+}
+
+const String StringUtils::Join(const Array<String>& sub_strings, const String&
joint) {
+ String join_str = "";
+ for (size_t i = 0; i < sub_strings.size(); i++) {
+ join_str = join_str + sub_strings[i] + (i == sub_strings.size() - 1 ? "" :
joint);
+ }
+ return join_str;
+}
+
+const String StringUtils::Replace(const String& src_string, const String&
old_str,
+ const String& new_str) {
+ String new_string;
+ const auto& sub_strings = Split(src_string, old_str);
+ for (size_t i = 0; i < sub_strings.size(); i++) {
+ new_string = new_string + sub_strings[i] + (i == sub_strings.size() - 1 ?
"" : new_str);
+ }
+ return new_string;
+}
+
+const std::tuple<String, String> StringUtils::SplitOnce(const String&
src_string, const String& sep,
+ bool from_left) {
+ if (src_string.size() == 0) {
+ return std::make_tuple(String(), String());
+ }
+ std::string src_cstring = src_string;
+ const std::string& csep = sep;
+ int pos = from_left ? src_cstring.find(csep) : src_cstring.rfind(csep);
+ if (pos >= 0) {
+ return std::make_tuple(src_cstring.substr(0, pos), src_cstring.substr(pos
+ csep.size()));
+ }
+ return std::make_tuple(src_string, String());
+}
+
+const Array<String> StringUtils::GetClosures(const String& src_string, const
String& left,
+ const String& right) {
+ Array<String> tokens;
+ if (src_string.size() == 0) {
+ return tokens;
+ }
+ String token = "start";
+ String left_str = src_string;
+ while (token.size() > 0) {
+ std::tie(token, left_str) = StringUtils::SplitOnce(left_str, left);
+ if (left_str.size() > 0) {
+ std::tie(token, left_str) = StringUtils::SplitOnce(left_str, right);
+ } else {
+ token = "";
+ }
+ if (token.size() > 0) {
+ tokens.push_back(token);
+ }
+ }
+ return tokens;
+}
+
+const String StringUtils::GetClosureOnce(const String& src_string, const
String& left,
+ const String& right, bool from_left) {
+ if (src_string.size() == 0) {
+ return "";
+ }
+ String val = std::get<1>(SplitOnce(src_string, left, from_left));
+ if (val.size() > 0) {
+ val = std::get<0>(StringUtils::SplitOnce(val, right, from_left));
+ }
+ return val;
+}
+
+const String StringUtils::ToString(const runtime::ObjectRef& obj) {
+ String obj_string;
+ if (!obj.defined()) {
+ obj_string = "";
+ } else if (obj.as<StringObj>()) {
+ obj_string = Downcast<String>(obj);
+ } else if (const auto* n = obj.as<IntImmNode>()) {
+ obj_string = std::to_string(n->value);
+ } else if (const auto* n = obj.as<FloatImmNode>()) {
+ obj_string = std::to_string(n->value);
+ } else if (const auto* n = obj.as<ArrayNode>()) {
+ for (size_t i = 0; i < n->size(); i++) {
+ obj_string = obj_string + ToString((*n)[i]);
+ if (n->size() == 1 || i < n->size() - 1) {
+ obj_string = obj_string + ",";
+ }
+ }
+ } else {
+ std::ostringstream obj_des;
+ obj_des << obj;
+ obj_string = obj_des.str();
+ }
+ return obj_string;
+}
+
+bool StringUtils::CompareArrays(const Array<String>& left, const
Array<String>& right, int size) {
+ if (left.size() == right.size() && left.size() == 0) {
+ return true;
+ }
+ if (size == -1 && left.size() != right.size()) {
+ return false;
+ }
+ if (left.size() == 0 || right.size() == 0) {
+ return false;
+ }
+ size = left.size();
+ ICHECK_GT(size, 0) << "Positive size should be given, get " << size;
+ if (size > static_cast<int>(left.size()) || size >
static_cast<int>(right.size())) {
+ return false;
+ }
+ for (size_t i = 0; i < static_cast<size_t>(size); i++) {
+ if (left[i] != right[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+const Span SpanUtils::SetAttr(const Span& span, const String& key, const
String& value) {
+ if (value.size() == 0) {
+ return span;
+ }
+ String new_source;
+ Array<String> tokens{"<" + key + ">", "</" + key + ">"};
+ if (span.defined() && span->source_name.defined()) {
+ const String& source_str = span->source_name->name;
+ String left = std::get<0>(StringUtils::SplitOnce(source_str, tokens[0]));
+ String right = std::get<1>(StringUtils::SplitOnce(source_str, tokens[1]));
+ if (left.size() > 0) {
+ new_source = left + tokens[0] + value + tokens[1] + right;
+ } else {
+ new_source = source_str + tokens[0] + value + tokens[1];
+ }
+ } else {
+ new_source = tokens[0] + value + tokens[1];
+ }
+ if (span.defined()) {
+ return Span(SourceName::Get(new_source), span->line, span->end_line,
span->column,
+ span->end_column);
+ }
+ return Span(SourceName::Get(new_source), 0, 0, 0, 0);
+}
+
+const String SpanUtils::GetAttr(const Span& span, const String& key) {
+ if (span.defined() && span->source_name.defined()) {
+ Array<String> tokens{"<" + key + ">", "</" + key + ">"};
+ return StringUtils::GetClosureOnce(span->source_name->name, tokens[0],
tokens[1]);
+ }
+ return "";
+}
+
+const Map<String, String> SpanUtils::GetAttrs(const Span& span) {
+ Map<String, String> attrs;
+ for (const auto& key : StringUtils::GetClosures(span->source_name->name,
"</", ">")) {
+ attrs.Set(key, GetAttr(span, key));
+ }
+ return attrs;
+}
+
+const Array<String> ExprUtils::GetInputTypes(const String& optype, size_t
inputs_num,
+ bool as_relax) {
+ Array<String> input_types;
+ if (as_relax && (optype == "broadcast_to" || optype == "reshape")) {
+ input_types.push_back("input");
+ input_types.push_back("shape");
+ } else if (optype == "clip" && as_relax) {
+ input_types.push_back("input");
+ input_types.push_back("min");
+ input_types.push_back("max");
+ } else if (optype == "full" && as_relax) {
+ input_types.push_back("shape");
+ input_types.push_back("input");
+ } else if (optype == "trilu") {
+ input_types.push_back("input");
+ input_types.push_back("k");
+ } else if (optype == "image.resize2d" && as_relax) {
+ input_types.push_back("input");
+ input_types.push_back("size");
+ } else if (optype == "nn.conv1d" || optype == "nn.conv2d" || optype ==
"nn.conv3d") {
+ input_types.push_back("input");
+ input_types.push_back("weight");
+ } else if (optype == "nn.batch_norm") {
+ input_types.push_back("input");
+ input_types.push_back("gamma");
+ input_types.push_back("beta");
+ input_types.push_back("mean");
+ input_types.push_back("var");
+ } else if (optype == "nn.layer_norm" || optype == "nn.group_norm") {
+ input_types.push_back("input");
+ input_types.push_back("gamma");
+ input_types.push_back("beta");
+ } else if (optype == "msc.linear") {
+ if (as_relax) {
+ input_types.push_back("weight");
+ input_types.push_back("input");
+ } else {
+ input_types.push_back("input");
+ input_types.push_back("weight");
+ }
+ } else if (optype == "msc.conv1d_bias" || optype == "msc.conv2d_bias") {
+ input_types.push_back("input");
+ input_types.push_back("weight");
+ input_types.push_back("bias");
+ if (as_relax) {
+ input_types.push_back("expand_bias");
+ }
+ } else if (optype == "msc.linear_bias") {
+ if (as_relax) {
+ input_types.push_back("weight");
+ input_types.push_back("input");
+ } else {
+ input_types.push_back("input");
+ input_types.push_back("weight");
+ }
+ input_types.push_back("bias");
+ } else if (optype == "msc.embedding" && inputs_num == 2) {
+ input_types.push_back("input");
+ input_types.push_back("weight");
+ } else if (optype == "msc.embedding" && inputs_num == 4) {
+ input_types.push_back("input");
+ input_types.push_back("reduce_in");
+ input_types.push_back("weight");
+ input_types.push_back("expand_out");
+ } else if (optype == "msc.gelu") {
+ input_types.push_back("input");
+ input_types.push_back("factor_1");
+ input_types.push_back("factor_2");
+ input_types.push_back("factor_3");
+ } else {
+ for (size_t i = 0; i < inputs_num; i++) {
+ input_types.push_back("input");
+ }
+ }
+ ICHECK_EQ(input_types.size(), inputs_num)
+ << "Optype " << optype << " get input types " << input_types << " and
inputs_num "
+ << inputs_num << " mismatch";
+ return input_types;
+}
+
+const Array<String> ExprUtils::GetInputTypes(const RelaxCall& call) {
+ const String& optype = StringUtils::Replace(Downcast<Op>(call->op)->name,
"relax.", "");
+ return GetInputTypes(optype, call->args.size(), true);
+}
+
+const Array<String> ExprUtils::GetInputTypes(const RelayCall& call) {
+ const String& optype = StringUtils::Replace(Downcast<Op>(call->op)->name,
"relay.", "");
+ return GetInputTypes(optype, call->args.size(), false);
+}
+
+TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr);
+
+TVM_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs);
+
+} // namespace msc
+} // namespace contrib
+} // namespace tvm
diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h
new file mode 100644
index 0000000000..9da4ce3346
--- /dev/null
+++ b/src/contrib/msc/core/utils.h
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/utils.h
+ * \brief Common utilities for msc.
+ */
+#ifndef TVM_CONTRIB_MSC_CORE_UTILS_H_
+#define TVM_CONTRIB_MSC_CORE_UTILS_H_
+
+#include <tvm/ir/source_map.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relay/expr.h>
+
+#include <tuple>
+#include <vector>
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using Expr = tvm::RelayExpr;
+using RelaxCall = tvm::relax::Call;
+using RelayCall = tvm::relay::Call;
+
+class CommonUtils {
+ public:
+ /*!
+ * \brief Normalize a (possibly negative) index and check it is within range.
+ * \return The valid index.
+ */
+ TVM_DLL static size_t GetIndex(int index, size_t max_size);
+};
+
+/*!
+ * \brief Utils for String.
+ */
+class StringUtils {
+ public:
+ /*!
+ * \brief Split the String into sub Strings.
+ * \return The SubStrings.
+ */
+ TVM_DLL static const Array<String> Split(const String& src_string, const
String& sep);
+
+ /*!
+ * \brief Join the SubStrings into String.
+ * \return The String.
+ */
+ TVM_DLL static const String Join(const Array<String>& sub_strings, const
String& joint);
+
+ /*!
+ * \brief Replace occurrences of old_str with new_str in the String.
+ * \return The replaced String.
+ */
+ TVM_DLL static const String Replace(const String& src_string, const String&
old_str,
+ const String& new_str);
+
+ /*!
+ * \brief Split the String into two sub Strings, splitting only at the first sep.
+ * \return The SubStrings.
+ */
+ TVM_DLL static const std::tuple<String, String> SplitOnce(const String&
src_string,
+ const String& sep,
+ bool from_left =
true);
+
+ /*!
+ * \brief Get the tokens between left and right.
+ * \return The Tokens.
+ */
+ TVM_DLL static const Array<String> GetClosures(const String& src_string,
const String& left,
+ const String& right);
+
+ /*!
+ * \brief Get the first token between left and right.
+ * \return The Token.
+ */
+ TVM_DLL static const String GetClosureOnce(const String& src_string, const
String& left,
+ const String& right, bool
from_left = true);
+
+ /*!
+ * \brief Change Object to String.
+ * \return The String.
+ */
+ TVM_DLL static const String ToString(const runtime::ObjectRef& obj);
+
+ /*!
+ * \brief Compare String arrays.
+ * \return Whether two array are same.
+ */
+ TVM_DLL static bool CompareArrays(const Array<String>& left, const
Array<String>& right,
+ int size = -1);
+};
+
+/*!
+ * \brief Utils for Array.
+ */
+class ArrayUtils {
+ public:
+ /*!
+ * \brief Replace the element old to new in Array.
+ * \return The replaced Array.
+ */
+ template <typename T>
+ TVM_DLL static const Array<T> Replace(const Array<T>& src_array, const T&
old_ele,
+ const T& new_ele) {
+ Array<T> new_array;
+ for (const auto& a : src_array) {
+ if (a == old_ele) {
+ new_array.push_back(new_ele);
+ } else {
+ new_array.push_back(a);
+ }
+ }
+ return new_array;
+ }
+
+ /*!
+ * \brief Find the index of element.
+ * \return The index, -1 if not found.
+ */
+ template <typename T>
+ TVM_DLL static int IndexOf(const std::vector<T>& array, const T& ele) {
+ for (size_t i = 0; i < array.size(); i++) {
+ if (array[i] == ele) {
+ return i;
+ }
+ }
+ return -1;
+ }
+
+ /*!
+ * \brief Downcast elements in the array.
+ * \return The downcasted array
+ */
+ template <typename T>
+ TVM_DLL static const Array<T> Cast(const Array<PrimExpr>& src_array) {
+ Array<T> new_array;
+ for (const auto& s : src_array) {
+ new_array.push_back(Downcast<T>(s));
+ }
+ return new_array;
+ }
+};
+
+/*!
+ * \brief Utils for Span.
+ */
+class SpanUtils {
+ public:
+ /*!
+ * \brief Set <key>value</key> in the Span.
+ * \return The new Span.
+ */
+ TVM_DLL static const Span SetAttr(const Span& span, const String& key, const
String& value);
+
+ /*!
+ * \brief Get the value in <key>value</key> from the Span.
+ * \return The value String.
+ */
+ TVM_DLL static const String GetAttr(const Span& span, const String& key);
+
+ /*!
+ * \brief Get all the key:value pairs in format <key>value</key> from the Span.
+ * \return The Attrs Map.
+ */
+ TVM_DLL static const Map<String, String> GetAttrs(const Span& span);
+};
+
+/*!
+ * \brief Utils for Expr.
+ */
+class ExprUtils {
+ public:
+ /*!
+ * \brief Get the input types of call.
+ * \return The input types.
+ */
+ TVM_DLL static const Array<String> GetInputTypes(const String& optype,
size_t inputs_num,
+ bool as_relax);
+
+ /*!
+ * \brief Get the input types of call.
+ * \return The input types.
+ */
+ TVM_DLL static const Array<String> GetInputTypes(const RelaxCall& call);
+
+ /*!
+ * \brief Get the input types of call.
+ * \return The input types.
+ */
+ TVM_DLL static const Array<String> GetInputTypes(const RelayCall& call);
+
+ /*!
+ * \brief Get the scalar value of ndarray.
+ * \return The scalar value.
+ */
+ template <typename T>
+ TVM_DLL static const T GetScalar(const runtime::NDArray& array, size_t i =
0) {
+ if (array->dtype.code == kDLInt) {
+ if (array->dtype.bits == 8) {
+ return T(reinterpret_cast<int8_t*>(array->data)[i]);
+ } else if (array->dtype.bits == 16) {
+ return T(reinterpret_cast<int16_t*>(array->data)[i]);
+ } else if (array->dtype.bits == 32) {
+ return T(reinterpret_cast<int32_t*>(array->data)[i]);
+ } else if (array->dtype.bits == 64) {
+ return T(reinterpret_cast<int64_t*>(array->data)[i]);
+ }
+ } else if (array->dtype.code == kDLUInt) {
+ if (array->dtype.bits == 1) { // bool
+ return T(reinterpret_cast<uint8_t*>(array->data)[i]);
+ } else if (array->dtype.bits == 8) {
+ return T(reinterpret_cast<uint8_t*>(array->data)[i]);
+ } else if (array->dtype.bits == 16) {
+ return T(reinterpret_cast<uint16_t*>(array->data)[i]);
+ } else if (array->dtype.bits == 32) {
+ return T(reinterpret_cast<uint32_t*>(array->data)[i]);
+ } else if (array->dtype.bits == 64) {
+ return T(reinterpret_cast<uint64_t*>(array->data)[i]);
+ }
+ } else if (array->dtype.code == kDLFloat) {
+ if (array->dtype.bits == 32) {
+ return T(reinterpret_cast<float*>(array->data)[i]);
+ } else if (array->dtype.bits == 64) {
+ return T(reinterpret_cast<double*>(array->data)[i]);
+ }
+ }
+ LOG(FATAL) << "Failed to get scalar from array " << array;
+ }
+
+ /*!
+ * \brief Get the scalar value of relax constant.
+ * \return The scalar value.
+ */
+ template <typename T>
+ TVM_DLL static const T GetScalar(const relax::Constant& constant, size_t i =
0) {
+ return GetScalar<T>(constant->data, i);
+ }
+
+ /*!
+ * \brief Get the scalar value of relay constant.
+ * \return The scalar value.
+ */
+ template <typename T>
+ TVM_DLL static const T GetScalar(const relay::Constant& constant, size_t i =
0) {
+ return GetScalar<T>(constant->data, i);
+ }
+};
+
+} // namespace msc
+} // namespace contrib
+} // namespace tvm
+#endif // TVM_CONTRIB_MSC_CORE_UTILS_H_
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index 23091fffd2..3f028ba656 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -342,6 +342,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
{"USE_CLML_GRAPH_EXECUTOR", TVM_INFO_USE_CLML_GRAPH_EXECUTOR},
{"USE_UMA", TVM_INFO_USE_UMA},
{"USE_VERILATOR", TVM_INFO_USE_VERILATOR},
+ {"USE_MSC", TVM_INFO_USE_MSC},
{"USE_CCACHE", TVM_INFO_USE_CCACHE},
{"BACKTRACE_ON_SEGFAULT", TVM_INFO_BACKTRACE_ON_SEGFAULT},
};
diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
b/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
new file mode 100644
index 0000000000..4717437d76
--- /dev/null
+++ b/tests/python/contrib/test_msc/test_transform_set_expr_layout.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.testing
+from tvm.relay import testing
+from tvm.relay.expr_functor import ExprVisitor
+from tvm.relay.build_module import bind_params_by_name
+
+from tvm.relax.frontend.torch import from_fx
+from tvm.relax import PyExprVisitor
+
+from tvm.contrib.msc.core import _ffi_api
+from tvm.contrib.msc.core import transform as msc_transform
+
+
class RelaxChecker(PyExprVisitor):
    """Check that the layout span attribute is set on every relax expr."""

    def check(self, expr):
        # Collect exprs whose span lacks a "layout" attribute, then assert none remain.
        self._missing_exprs = []
        if isinstance(expr, tvm.relax.Expr):
            self.visit_expr(expr)
        elif isinstance(expr, tvm.relax.BindingBlock):
            self.visit_binding_block(expr)
        assert not self._missing_exprs, "Missing {} layouts".format(len(self._missing_exprs))

    def visit_var_binding_(self, binding) -> None:
        super().visit_var_binding_(binding)
        # The bound value carries the span; record it when no layout was attached.
        if not _ffi_api.SpanGetAttr(binding.value.span, "layout"):
            self._missing_exprs.append(binding.value)

    def visit_constant_(self, op) -> None:
        super().visit_constant_(op)
        if not _ffi_api.SpanGetAttr(op.span, "layout"):
            self._missing_exprs.append(op)
+
+
def test_relax():
    """Check that SetExprLayout attaches a layout span attribute to every expr
    of a resnet50 module imported through the fx frontend."""
    try:
        import torch
        import torchvision
        from torch import fx
    except ImportError:  # only skip on missing deps; bare except would also
        # swallow KeyboardInterrupt/SystemExit
        print("please install pytorch python package")
        return

    torch_model = torchvision.models.resnet50()
    graph_model = fx.symbolic_trace(torch_model)
    input_info = [([1, 3, 224, 224], "float32")]
    # Tracing must run without autograd bookkeeping.
    with torch.no_grad():
        mod = from_fx(graph_model, input_info)
    mod = msc_transform.SetExprLayout()(mod)
    RelaxChecker().check(mod)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_name.py
b/tests/python/contrib/test_msc/test_transform_set_expr_name.py
new file mode 100644
index 0000000000..0c174ff7bd
--- /dev/null
+++ b/tests/python/contrib/test_msc/test_transform_set_expr_name.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.testing
+from tvm.relay import testing
+from tvm.relay.expr_functor import ExprVisitor
+from tvm.relay.build_module import bind_params_by_name
+
+from tvm.relax.frontend.torch import from_fx
+from tvm.relax import PyExprVisitor
+
+from tvm.contrib.msc.core import _ffi_api
+from tvm.contrib.msc.core import transform as msc_transform
+
+
class RelayChecker(ExprVisitor):
    """Check that the name span attribute is set on every relay expr."""

    def check(self, expr):
        # Walk the expr, collecting nodes whose span lacks a "name" attribute.
        self._missing_exprs = []
        super().visit(expr)
        assert not self._missing_exprs, "Missing {} names".format(len(self._missing_exprs))

    def visit_constant(self, expr):
        super().visit_constant(expr)
        if not _ffi_api.SpanGetAttr(expr.span, "name"):
            self._missing_exprs.append(expr)

    def visit_call(self, expr):
        super().visit_call(expr)
        if not _ffi_api.SpanGetAttr(expr.span, "name"):
            self._missing_exprs.append(expr)
+
+
class RelaxChecker(PyExprVisitor):
    """Check that the name span attribute is set on every relax expr."""

    def check(self, expr):
        # Collect exprs whose span lacks a "name" attribute, then assert none remain.
        self._missing_exprs = []
        if isinstance(expr, tvm.relax.Expr):
            self.visit_expr(expr)
        elif isinstance(expr, tvm.relax.BindingBlock):
            self.visit_binding_block(expr)
        assert not self._missing_exprs, "Missing {} names".format(len(self._missing_exprs))

    def visit_var_binding_(self, binding) -> None:
        super().visit_var_binding_(binding)
        # The bound value carries the span; record it when no name was attached.
        if not _ffi_api.SpanGetAttr(binding.value.span, "name"):
            self._missing_exprs.append(binding.value)

    def visit_constant_(self, op) -> None:
        super().visit_constant_(op)
        if not _ffi_api.SpanGetAttr(op.span, "name"):
            self._missing_exprs.append(op)
+
+
def test_relay():
    """Check that SetExprName attaches a name span attribute to every expr
    of a relay resnet50 workload."""
    mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32")
    # Fold params into the function so constants are visited as well.
    mod["main"] = bind_params_by_name(mod["main"], params)
    named_mod = msc_transform.SetExprName(as_relax=False)(mod)
    RelayChecker().check(named_mod["main"])
+
+
def test_relax():
    """Check that SetExprName attaches a name span attribute to every expr
    of a resnet50 module imported through the fx frontend."""
    try:
        import torch
        import torchvision
        from torch import fx
    except ImportError:  # only skip on missing deps; bare except would also
        # swallow KeyboardInterrupt/SystemExit
        print("please install pytorch python package")
        return

    torch_model = torchvision.models.resnet50()
    graph_model = fx.symbolic_trace(torch_model)
    input_info = [([1, 3, 224, 224], "float32")]
    # Tracing must run without autograd bookkeeping.
    with torch.no_grad():
        mod = from_fx(graph_model, input_info)
    mod = msc_transform.SetExprName()(mod)
    RelaxChecker().check(mod)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/scripts/task_config_build_cpu.sh
b/tests/scripts/task_config_build_cpu.sh
index 9eda0d74d4..0d6c0e2cae 100755
--- a/tests/scripts/task_config_build_cpu.sh
+++ b/tests/scripts/task_config_build_cpu.sh
@@ -57,3 +57,4 @@ echo set\(USE_CCACHE OFF\) >> config.cmake
echo set\(USE_ETHOSU OFF\) >> config.cmake
echo set\(USE_UMA ON\) >> config.cmake
echo set\(SUMMARIZE ON\) >> config.cmake
+echo set\(USE_MSC ON\) >> config.cmake
diff --git a/tests/scripts/task_config_build_gpu.sh
b/tests/scripts/task_config_build_gpu.sh
index 8929ae5041..37ab0a87f1 100755
--- a/tests/scripts/task_config_build_gpu.sh
+++ b/tests/scripts/task_config_build_gpu.sh
@@ -53,3 +53,4 @@ echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake
echo set\(USE_PIPELINE_EXECUTOR ON\) >> config.cmake
echo set\(USE_CUTLASS ON\) >> config.cmake
echo set\(USE_CMSISNN ON\) >> config.cmake
+echo set\(USE_MSC ON\) >> config.cmake
diff --git a/tests/scripts/unity/task_python_relax.sh
b/tests/scripts/unity/task_python_relax.sh
index b6b70ab457..121ba1389a 100755
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/tests/scripts/unity/task_python_relax.sh
@@ -36,3 +36,6 @@ TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest
tests/python/dlight
# python3 ./apps/relax_examples/mlp.py
# python3 ./apps/relax_examples/nn_module.py
# python3 ./apps/relax_examples/resnet.py
+
+# Test for MSC
+pytest tests/python/contrib/test_msc