This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 517d0457d3 [Unity][Transform] SplitCallTIRByPattern and CUTLASS backend (#14274)
517d0457d3 is described below

commit 517d0457d35a067dbdab57496a678bc076cd76c4
Author: Bohan Hou <[email protected]>
AuthorDate: Fri Mar 24 13:39:20 2023 -0400

    [Unity][Transform] SplitCallTIRByPattern and CUTLASS backend (#14274)
    
    Currently, the BYOC system is based on op-level pattern matching. This PR
    provides initial support for TIR-level pattern matching based on backend
    registration and dispatching.
    
    For now, it simply matches the first set of for loops in a PrimFunc.
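    
    As an illustration, a minimal sketch of how the pieces compose (hypothetical
    driver code, not part of this change; it assumes mod is a Relax IRModule whose
    call_tir PrimFuncs contain fp16 GEMM loop nests):
    
        from tvm import relax
        from tvm.relax.backend_tir import get_tir_pattern
        from tvm.relax.backend_tir.contrib import cutlass_fcodegen
    
        # Match TIR patterns and offload the matched loop nests to CUTLASS.
        mod = relax.transform.SplitCallTIRByPattern(
            get_tir_pattern(),        # TIR patterns to try, in order
            cutlass_fcodegen(sm=80),  # FCodegen: List[MatchResult] -> [code, n_matched, args]
        )(mod)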
    
    Co-authored-by: Hongyi Jin (@jinhongyii)
---
 include/tvm/relax/tir_pattern.h                  |  75 +++
 python/tvm/contrib/cutlass/gemm_operation.py     |   3 +-
 python/tvm/relax/backend_tir/__init__.py         |  20 +
 python/tvm/relax/backend_tir/contrib/__init__.py |  20 +
 python/tvm/relax/backend_tir/contrib/cutlass.py  | 720 +++++++++++++++++++++
 python/tvm/relax/backend_tir/pattern.py          | 576 +++++++++++++++++
 python/tvm/relax/transform/transform.py          |  19 +
 src/relax/backend/vm/codegen_vm.cc               |  11 +
 src/relax/ir/tir_pattern.cc                      |  37 ++
 src/relax/transform/split_call_tir_by_pattern.cc | 782 +++++++++++++++++++++++
 tests/python/relax/test_codegen_tir_cutlass.py   | 709 ++++++++++++++++++++
 11 files changed, 2971 insertions(+), 1 deletion(-)

diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h
new file mode 100644
index 0000000000..02634dcbbf
--- /dev/null
+++ b/include/tvm/relax/tir_pattern.h
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir_pattern.h
+ * \brief Data structures for TIR pattern matching.
+ */
+
+#ifndef TVM_RELAX_TIR_PATTERN_H_
+#define TVM_RELAX_TIR_PATTERN_H_
+
+#include <tvm/tir/function.h>
+
+namespace tvm {
+namespace relax {
+
+using TIRPattern = tir::PrimFunc;
+
+/*!
+ * \brief The match result of a TIR pattern.
+ */
+class MatchResultNode : public Object {
+ public:
+  /*! \brief The matched TIR pattern. */
+  TIRPattern pattern;
+  /*! \brief The evaluated values of symbolic vars. */
+  Array<PrimExpr> symbol_values;
+  /*! \brief The matched buffers of input and output. */
+  Array<tir::Buffer> matched_buffers;
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("pattern", &pattern);
+    v->Visit("symbol_values", &symbol_values);
+    v->Visit("matched_buffers", &matched_buffers);
+  }
+  static constexpr const char* _type_key = "relax.MatchResult";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MatchResultNode, Object);
+};
+
+/*!
+ * \brief Managed reference to MatchResultNode.
+ */
+class MatchResult : public ObjectRef {
+ public:
+  /*!
+   * \brief Constructor
+   * \param pattern The matched tir pattern.
+   * \param symbol_values The evaluated values of symbolic vars.
+   * \param matched_buffers The matched buffers of input and output.
+   */
+  TVM_DLL explicit MatchResult(TIRPattern pattern, Array<PrimExpr> symbol_values,
+                               Array<tir::Buffer> matched_buffers);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode)
+};
+
+using FCodegen = runtime::TypedPackedFunc<Array<ObjectRef>(Array<MatchResult> match_results)>;
+}  // namespace relax
+}  // namespace tvm
+#endif  // TVM_RELAX_TIR_PATTERN_H_
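
A note on the FCodegen hook above: it receives the MatchResults for the loop
nests found in a PrimFunc and returns [generated C source, number of match
results consumed, extern call args]. A minimal Python sketch of that contract
(hypothetical; it mirrors the convention used by the CUTLASS backend later in
this patch):

    from typing import List
    from tvm.relax.backend_tir.pattern import MatchResult, matmul_rrr_fp16

    def trivial_fcodegen(match_results: List[MatchResult]):
        # Consume one matched fp16 GEMM; a real backend would emit kernel
        # source here (see cutlass_fcodegen below).
        if match_results and match_results[0].pattern == matmul_rrr_fp16:
            return ["/* generated kernel source */", 1, list(match_results[0].matched_buffers)]
        return ["", 0]  # nothing matched: no code, zero patterns consumed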
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index eb9f92dad3..b820ead016 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -369,7 +369,8 @@ def instantiate_gemm_template(attrs):
             {
                 "bias_decl": "void* ptr_bias = (void*)(${bias_arg}->data);\n",
                 "ptr_c": "ptr_bias",
-                "c_stride": "${bias_arg}->ndim == 1 ? 0 : " + attrs["ldc"],
+                "c_stride": "(${bias_arg}->ndim == 1 || ${bias_arg}->shape[0] 
== 1) ? 0 : "
+                + attrs["ldc"],
             }
         )
     else:
diff --git a/python/tvm/relax/backend_tir/__init__.py b/python/tvm/relax/backend_tir/__init__.py
new file mode 100644
index 0000000000..eeb8fe438f
--- /dev/null
+++ b/python/tvm/relax/backend_tir/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relax backends, tir based"""
+
+from . import contrib
+from .pattern import get_tir_pattern
diff --git a/python/tvm/relax/backend_tir/contrib/__init__.py b/python/tvm/relax/backend_tir/contrib/__init__.py
new file mode 100644
index 0000000000..9274f22374
--- /dev/null
+++ b/python/tvm/relax/backend_tir/contrib/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""External backend codegen modules for Relax, tir based."""
+
+from .cutlass import cutlass_fcodegen
diff --git a/python/tvm/relax/backend_tir/contrib/cutlass.py b/python/tvm/relax/backend_tir/contrib/cutlass.py
new file mode 100644
index 0000000000..0dbe31c468
--- /dev/null
+++ b/python/tvm/relax/backend_tir/contrib/cutlass.py
@@ -0,0 +1,720 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,comparison-with-callable,unused-variable,missing-function-docstring
+"""codegen for cutlass"""
+import operator
+from functools import reduce
+from typing import List, Dict, Any
+
+from tvm.contrib.cutlass.build import _get_cutlass_path, _get_cutlass_compile_options
+from tvm.contrib.nvcc import get_target_compute_version
+from tvm.contrib.cutlass.library import LayoutType, ConvKind
+from tvm.contrib.cutlass.gen_tensor_op import instantiate_template
+from tvm.contrib.cutlass.gen_gemm import CutlassGemmProfiler
+from tvm.contrib.cutlass.gen_conv2d import CutlassConv2DProfiler
+from ..pattern import (
+    MatchResult,
+    matmul_rrr_fp16,
+    bias_row_2d_fp16,
+    bias_row_1d_fp16,
+    batch_bias_row_2d_fp16,
+    batch_bias_row_1d_fp16,
+    relu_fp16,
+    erf_3d_fp32,
+    batch_matmul_rrr_2d_fp16,
+    batch_matmul_rrr_3d_fp16,
+    conv2d_nhwc_fp16,
+    padding_2d_nhwc_fp16,
+    copy_4d_fp16,
+    bias_add_nhwc_2d_fp16,
+    bias_add_nhwc_1d_fp16,
+    elem_add_4d_fp16,
+    elem_mul_3d_fp16,
+    scalar_add_3d_fp16,
+    scalar_mul_3d_fp16,
+    cast_3d_fp16,
+    cast_3d_fp32,
+)
+
+#### helper functions ####
+# lists of patterns for the anchor ops
+# in the future more layouts/dtypes can be supported
+MATMUL_LIST = [matmul_rrr_fp16]
+MATMUL_BIAS_LIST = [bias_row_2d_fp16, bias_row_1d_fp16]
+BATCH_MATMUL_LIST = [batch_matmul_rrr_2d_fp16, batch_matmul_rrr_3d_fp16]
+BATCH_MATMUL_BIAS_LIST = [batch_bias_row_2d_fp16, batch_bias_row_1d_fp16]
+CONV2D_LIST = [conv2d_nhwc_fp16]
+
+# attributes for anchor ops used in code generation
+OP_PATTERN_ATTR_LIST = {
+    matmul_rrr_fp16: {
+        "arg0_dtype": "float16",
+        "arg1_dtype": "float16",
+        "ret_dtype": "float16",
+    },
+    batch_matmul_rrr_2d_fp16: {
+        "arg0_dtype": "float16",
+        "arg1_dtype": "float16",
+        "ret_dtype": "float16",
+    },
+    batch_matmul_rrr_3d_fp16: {
+        "arg0_dtype": "float16",
+        "arg1_dtype": "float16",
+        "ret_dtype": "float16",
+    },
+    conv2d_nhwc_fp16: {
+        "arg0_dtype": "float16",
+        "arg1_dtype": "float16",
+        "ret_dtype": "float16",
+        # in the future we can add layout here
+    },
+}
+
+
+def _get_cutlass_code(attr):
+    pattern = attr["op_type"]
+    if pattern.startswith("cutlass.matmul"):
+        return cutlass_codegen_gemm(attr)
+    elif pattern.startswith("cutlass.conv2d"):
+        return cutlass_codegen_conv2d(attr)
+    else:
+        raise ValueError("op not supported")
+
+
+def _final_code(code, headers, func_args):
+    res = ""
+    res += "#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>\n"
+    res += "#include <tvm/runtime/c_runtime_api.h>\n"
+    res += "#include <tvm/runtime/packed_func.h>\n"
+    res += "#include <dlpack/dlpack.h>\n"
+    res += "#include <cuda_fp16.h>\n"
+    res += "#include <cutlass/cutlass.h>\n"
+    res += "#include <cutlass/coord.h>\n"
+    res += "#include <cutlass/tensor_ref.h>\n"
+    res += "#include <cutlass/util/host_tensor.h>\n"
+
+    for header in headers:
+        res += "#include <" + header + ">\n"
+    res += "namespace {\n"
+    res += "using namespace tvm;\n"
+    res += "using namespace tvm::runtime;\n"
+    res += "void _cutlass_kernel("
+    for arg in func_args:
+        res += "NDArray " + arg + ", "
+    res += "NDArray out0) {"
+    res += code
+    res += "}\n"
+    res += "}  // namespace\n"
+    res += "TVM_DLL_EXPORT_TYPED_FUNC({global_symbol}, _cutlass_kernel);\n"
+    return res
+
+
+#### cutlass patterns ####
+def matmul_bias_relu(match_results, attr, get_code=True):
+    if len(match_results) < 3:
+        return None
+    attr = matmul_bias(match_results[:2], attr, get_code=False)
+    if attr is None or match_results[2].pattern != relu_fp16:
+        return None
+    m_bias, n_bias = match_results[1].symbol_values
+    m_relu, n_relu = match_results[2].symbol_values
+    A_bias, B_bias, C_bias = match_results[1].matched_buffers
+    A_relu, B_relu = match_results[2].matched_buffers
+    if m_bias == m_relu and n_bias == n_relu and C_bias == A_relu:
+        attr["op_type"] = "cutlass.matmul_bias_relu"
+        return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr
+    return None
+
+
+def matmul_bias(match_results, attr, get_code=True):
+    if len(match_results) < 2:
+        return None
+    attr = matmul(match_results[:1], attr, get_code=False)
+    if attr is None or match_results[1].pattern not in MATMUL_BIAS_LIST:
+        return None
+    m_matmul, n_matmul, k_matmul = match_results[0].symbol_values
+    m_bias, n_bias = match_results[1].symbol_values
+    A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
+    A_bias, B_bias, C_bias = match_results[1].matched_buffers
+    if m_matmul == m_bias and n_matmul == n_bias and C_matmul == A_bias:
+        attr["op_type"] = "cutlass.matmul_bias"
+        attr["bias_arg_idx"] = 2
+        attr["args"].append(B_bias)
+        return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr
+    return None
+
+
+def matmul(match_results, attr, get_code=True):
+    if len(match_results) < 1:
+        return None
+    if match_results[0].pattern in MATMUL_LIST:
+        # matmul
+        attr["op_type"] = "cutlass.matmul"
+        return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code else attr
+    return None
+
+
+def batch_matmul_bias_gelu(match_results, attr, get_code=True):
+    if len(match_results) < 9:
+        return None
+    attr = batch_matmul_bias(match_results[:2], attr, get_code=False)  # batch_matmul, batch_bias
+    if (
+        attr is None
+        or match_results[2].pattern != scalar_mul_3d_fp16
+        or match_results[3].pattern != cast_3d_fp32
+        or match_results[4].pattern != erf_3d_fp32
+        or match_results[5].pattern != cast_3d_fp16
+        or match_results[6].pattern != scalar_mul_3d_fp16
+        or match_results[7].pattern != scalar_add_3d_fp16
+        or match_results[8].pattern != elem_mul_3d_fp16
+    ):
+        return None
+
+    def shape_match_3d(shape1, shape2):
+        if len(shape1) < 3 or len(shape2) < 3:
+            return False
+        return shape1[0] == shape2[0] and shape1[1] == shape2[1] and shape1[2] == shape2[2]
+
+    for i in range(1, 8):
+        if not shape_match_3d(match_results[i].symbol_values, match_results[i + 1].symbol_values):
+            return None
+
+    if not (
+        match_results[1].matched_buffers[-1] == match_results[2].matched_buffers[0]
+        and match_results[2].matched_buffers[-1] == match_results[3].matched_buffers[0]
+        and match_results[3].matched_buffers[-1] == match_results[4].matched_buffers[0]
+        and match_results[4].matched_buffers[-1] == match_results[5].matched_buffers[0]
+        and match_results[5].matched_buffers[-1] == match_results[6].matched_buffers[0]
+        and match_results[6].matched_buffers[-1] == match_results[7].matched_buffers[0]
+        and match_results[1].matched_buffers[-1] == match_results[8].matched_buffers[0]
+        and match_results[7].matched_buffers[-1] == match_results[8].matched_buffers[1]
+    ):
+        return None
+
+    if (
+        abs(float(match_results[2].symbol_values[-1] - 0.5**0.5)) > 1e-5
+        or abs(float(match_results[6].symbol_values[-1] - 0.5)) > 1e-5
+        or abs(float(match_results[7].symbol_values[-1] - 0.5)) > 1e-5
+    ):
+        return None
+
+    attr["op_type"] = "cutlass.matmul_bias_gelu"
+    return [_get_cutlass_code(attr=attr), 9, attr["args"]] if get_code else attr
+
+
+def batch_matmul_bias_residual_mul(match_results, attr, get_code=True):
+    if len(match_results) < 3:
+        return None
+    attr = batch_matmul_bias(match_results[:2], attr, get_code=False)  # batch_matmul, batch_bias
+    if attr is None or match_results[2].pattern != elem_mul_3d_fp16:
+        return None
+    (
+        b_bias,
+        m_bias,
+        n_bias,
+    ) = match_results[1].symbol_values
+    (
+        b_mul,
+        m_mul,
+        n_mul,
+    ) = match_results[2].symbol_values
+    A_bias, B_bias, C_bias = match_results[1].matched_buffers
+    A_mul, B_mul, C_mul = match_results[2].matched_buffers
+    if b_bias == b_mul and m_bias == m_mul and n_bias == n_mul and C_bias == A_mul:
+        attr["op_type"] = "cutlass.matmul_bias_residual_multiply"
+        attr["residual_arg_idx"] = 3
+        return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr
+    return None
+
+
+def batch_matmul_bias(match_results, attr, get_code=True):
+    if len(match_results) < 2:
+        return None
+    attr = batch_matmul(match_results[:1], attr, get_code=False)
+    if attr is None or match_results[1].pattern not in BATCH_MATMUL_BIAS_LIST:
+        return None
+    (
+        b_matmul,
+        m_matmul,
+        n_matmul,
+        k_matmul,
+    ) = match_results[0].symbol_values
+    (
+        b_bias,
+        m_bias,
+        n_bias,
+    ) = match_results[1].symbol_values
+    A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
+    A_bias, B_bias, C_bias = match_results[1].matched_buffers
+    if b_matmul == b_bias and m_matmul == m_bias and n_matmul == n_bias and C_matmul == A_bias:
+        attr["op_type"] = "cutlass.matmul_bias"
+        attr["bias_arg_idx"] = 2
+        attr["args"].append(B_bias)
+        return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr
+    return None
+
+
+def batch_matmul(match_results, attr, get_code=True):
+    if len(match_results) < 1:
+        return None
+    if match_results[0].pattern in BATCH_MATMUL_LIST:
+        attr["op_type"] = "cutlass.matmul"
+        return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code else attr
+    return None
+
+
+def conv2d_bias_residual_add(match_results, attr, get_code=True):
+    if len(match_results) < 4:
+        return None
+    attr = conv2d_bias(match_results[:3], attr, get_code=False)
+    if attr is None or match_results[3].pattern != elem_add_4d_fp16:
+        return None
+    N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values
+    in1_bias, in2_bias, out_bias = match_results[2].matched_buffers
+    N_add, H_add, W_add, C_add = match_results[3].symbol_values
+    in1_add, in2_add, out_add = match_results[3].matched_buffers
+    if (
+        N_bias == N_add
+        and H_bias == H_add
+        and W_bias == W_add
+        and C_bias == C_add
+        and out_bias in [in1_add, in2_add]
+    ):
+        attr["op_type"] = "cutlass.conv2d_bias_residual_add"
+        attr["residual_arg_idx"] = 3
+        attr["args"].append(in2_add if out_bias == in1_add else in1_add)
+        return [_get_cutlass_code(attr=attr), 4, attr["args"]] if get_code else attr
+    return None
+
+
+def conv2d_bias(match_results, attr, get_code=True):
+    if len(match_results) < 3:
+        return None
+    attr = conv2d(match_results[:2], attr, get_code=False)
+    if attr is None or (
+        match_results[2].pattern not in [bias_add_nhwc_2d_fp16, bias_add_nhwc_1d_fp16]
+    ):
+        return None
+    (N_conv, pH_conv, pW_conv, H_conv, W_conv, C_conv, O_conv,) = match_results[
+        1
+    ].symbol_values[:7]
+    A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers
+    N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values
+    A_bias, B_bias, out_bias = match_results[2].matched_buffers
+    if (
+        N_bias == N_conv
+        and H_bias == H_conv
+        and W_bias == W_conv
+        and C_bias == O_conv
+        and out_conv == A_bias
+    ):
+        attr["op_type"] = "cutlass.conv2d_bias"
+        attr["bias_arg_idx"] = 2
+        attr["args"].append(B_bias)
+        return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code else attr
+    return None
+
+
+def conv2d(match_results, attr, get_code=True):
+    if len(match_results) < 2:
+        return None
+    if (
+        match_results[0].pattern in [padding_2d_nhwc_fp16, copy_4d_fp16]
+        and match_results[1].pattern == conv2d_nhwc_fp16
+    ):
+        if match_results[0].pattern == padding_2d_nhwc_fp16:
+            (
+                N_pad,
+                H_pad,
+                W_pad,
+                C_pad,
+                pH_pad,
+                pW_pad,
+                lH_pad,
+                lW_pad,
+                rH_pad,
+                rW_pad,
+            ) = match_results[0].symbol_values
+        else:
+            (
+                N_pad,
+                H_pad,
+                W_pad,
+                C_pad,
+            ) = match_results[0].symbol_values
+            pH_pad = rH_pad = H_pad
+            pW_pad = rW_pad = W_pad
+            lH_pad = lW_pad = 0
+        (
+            N_conv,
+            pH_conv,
+            pW_conv,
+            H_conv,
+            W_conv,
+            C_conv,
+            O_conv,
+            KH_conv,
+            KW_conv,
+            stride_h_conv,
+            stride_w_conv,
+            dilation_h_conv,
+            dilation_w_conv,
+        ) = match_results[1].symbol_values
+        A, A_pad = match_results[0].matched_buffers
+        A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers
+        if (
+            N_pad == N_conv
+            and pH_pad == pH_conv
+            and pW_pad == pW_conv
+            and C_pad == C_conv
+            and A_pad == A_pad_conv
+        ):
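+            # Accept only symmetric padding: lH == pH - rH together with
+            # lH + H == rH means the low and high pads are equal on each
+            # spatial axis, so (lH_pad, lW_pad) fully describes the padding.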
+            if (
+                lH_pad == pH_pad - rH_pad
+                and lW_pad == pW_pad - rW_pad
+                and lH_pad + H_pad == rH_pad
+                and lW_pad + W_pad == rW_pad
+            ):
+                padding = (lH_pad, lW_pad)
+                strides = (stride_h_conv, stride_w_conv)
+                dilation = (dilation_h_conv, dilation_w_conv)
+                attr["padding"] = padding
+                attr["strides"] = strides
+                attr["dilation"] = dilation
+                attr["op_type"] = "cutlass.conv2d"
+                return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code else attr
+    return None
+
+
+### cutlass codegen functions ###
+def compile_options(target, threads=-1, use_fast_math=False):
+    compute_version = int("".join(get_target_compute_version(target).split(".")))
+    kwargs = _get_cutlass_compile_options(compute_version, threads, use_fast_math)
+    kwargs["options"].remove("-c")
+    return kwargs
+
+
+def cutlass_fcodegen(sm=80, bin_dir="./bin"):
+    gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), bin_dir)
+    conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), bin_dir)
+
+    def cutlass_codegen_with_match_results(match_results: List[MatchResult]):
+        """generate cutlass code with match results"""
+        nonlocal gemm_profiler
+        nonlocal conv2d_profiler
+
+        assert len(match_results) > 0
+
+        # add shape into attr
+        if match_results[0].pattern in MATMUL_LIST:
+            A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
+            attr: Dict[Any, Any] = OP_PATTERN_ATTR_LIST[match_results[0].pattern]
+            attr["args"] = [A_matmul, B_matmul]
+            attr["arg0_shape"] = A_matmul.shape
+            attr["arg1_shape"] = B_matmul.shape
+            attr["ret_shape"] = C_matmul.shape
+            attr["lhs_arg_idx"] = 0
+            attr["rhs_arg_idx"] = 1
+        elif match_results[0].pattern in BATCH_MATMUL_LIST:
+            A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
+            attr = OP_PATTERN_ATTR_LIST[match_results[0].pattern]
+            attr["args"] = [A_matmul, B_matmul]
+            attr["arg0_shape"] = A_matmul.shape
+            attr["arg1_shape"] = B_matmul.shape
+            attr["ret_shape"] = C_matmul.shape
+            attr["lhs_arg_idx"] = 0
+            attr["rhs_arg_idx"] = 1
+        elif len(match_results) >= 2 and match_results[1].pattern in CONV2D_LIST:
+            A_input = match_results[0].matched_buffers[0]
+            A_conv2d, B_conv2d, C_conv2d = match_results[1].matched_buffers
+            attr = OP_PATTERN_ATTR_LIST[match_results[1].pattern]
+            attr["args"] = [A_input, B_conv2d]
+            attr["arg0_shape"] = A_input.shape
+            attr["arg1_shape"] = B_conv2d.shape
+            attr["ret_shape"] = C_conv2d.shape
+            attr["lhs_arg_idx"] = 0
+            attr["rhs_arg_idx"] = 1
+        else:
+            return ["", 0]
+
+        # add profiler into attr
+        attr["gemm_profiler"] = gemm_profiler
+        attr["conv2d_profiler"] = conv2d_profiler
+
+        cutlass_patterns = [
+            # consumes 9 match results
+            batch_matmul_bias_gelu,
+            # consumes 4
+            conv2d_bias_residual_add,
+            # consumes 3
+            batch_matmul_bias_residual_mul,
+            matmul_bias_relu,
+            conv2d_bias,
+            # consumes 2
+            matmul_bias,
+            batch_matmul_bias,
+            conv2d,
+            # consumes 1
+            matmul,
+            batch_matmul,
+        ]
+        for pattern in cutlass_patterns:
+            res = pattern(match_results, attr)
+            if res is not None:
+                return res
+
+        return ["", 0]
+
+    return cutlass_codegen_with_match_results
+
+
+def cutlass_codegen_gemm(attrs):
+    """cutlass codegen for gemm"""
+    gemm_profiler = attrs["gemm_profiler"]
+    op_type = attrs["op_type"]
+    lhs_shape = attrs["arg0_shape"]
+    rhs_shape = attrs["arg1_shape"]
+    MM = lhs_shape[-2]
+    KK = lhs_shape[-1]
+    if "transposed" in op_type:
+        NN = rhs_shape[-2]
+        ldb = "K"
+        layout_b = LayoutType.ColumnMajor
+    else:
+        NN = rhs_shape[-1]
+        ldb = "N"
+        layout_b = LayoutType.RowMajor
+
+    lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
+    rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
+    if lhs_batches == 1 and rhs_batches == 1:
+        # Regular matmul
+        is_batched = False
+        batch_attrs = {}
+    else:
+        is_batched = True
+        batch_attrs = {
+            # If both lhs_batches and rhs_batches are greater than 1,
+            # they must be equal. This is checked by is_shape_valid_for_cutlass_matmul.
+            "batch": lhs_batches if rhs_batches == 1 else rhs_batches,
+            "batch_stride_A": 0 if lhs_batches == 1 else MM * KK,
+            "batch_stride_B": 0 if rhs_batches == 1 else KK * NN,
+            "batch_stride_C": MM * NN,
+        }
+    op_name, op_def, _ = gemm_profiler.profile(
+        op_type,
+        MM,
+        NN,
+        KK,
+        attrs["ret_dtype"],
+        attrs["arg0_dtype"],
+        attrs["arg1_dtype"],
+        False,
+        batched=is_batched,
+        find_first_valid=False,
+        use_multiprocessing=True,
+        layout_b=layout_b,
+    )
+    attrs["cutlass_op_name"] = op_name
+    attrs["cutlass_op_def"] = op_def
+    attrs["lda"] = "K"
+    attrs["ldb"] = ldb
+    attrs["ldc"] = "N"
+    attrs.update(batch_attrs)
+    del attrs["gemm_profiler"]
+    del attrs["conv2d_profiler"]
+
+    nargs = 2
+    if "bias_arg_idx" in attrs:
+        nargs += 1
+    if "residual_arg_idx" in attrs:
+        nargs += 1
+    func_args = ["inp" + str(i) for i in range(nargs)]
+
+    # A temporary solution to handle batch matmul residual cases
+    # TODO(@bohan): remove this after initialize_template supports bmm residual
+    if op_type in [
+        "cutlass.matmul_bias_residual_multiply",
+    ]:
+
+        def _convert_dtype_str(dtype):
+            if isinstance(dtype, list):
+                arr = []
+                for t in dtype:
+                    arr.append(_convert_dtype_str(t))
+                return arr
+            elif isinstance(dtype, str):
+                if dtype == "float16":
+                    return "cutlass::half_t"
+                elif dtype == "float32":
+                    return "float"
+            raise ValueError("dtype not supported")
+
+        typea, typeb, typec = _convert_dtype_str(
+            [attrs["arg0_dtype"], attrs["arg1_dtype"], attrs["ret_dtype"]]
+        )
+
+        text = f"""
+#define CUTLASS_ENABLE_CUBLAS 1
+#define CUTLASS_NAMESPACE cutlass
+#define CUTLASS_ENABLE_TENSOR_CORE_MMA 1
+#define NDEBUG
+#include <cutlass/cutlass.h>
+#include <cutlass/tensor_ref.h>
+#include <cutlass/util/host_tensor.h>
+#include <cutlass/gemm/device/gemm.h>
+#include <cutlass/gemm/device/gemm_batched.h>
+#include <cutlass/layout/matrix.h>
+#include <cutlass/numeric_types.h>
+#include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
+#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <vector>
+#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+namespace {{
+using namespace tvm;
+using namespace tvm::runtime;
+void _BHGEMM(NDArray A, NDArray B, NDArray Bias, NDArray D, NDArray C) {{
+    // A: [Batch, M, K], B: [1, K, N]/[K, N], Bias: [1, N]/[N], D: [Batch, M, N], C: [Batch, M, N]
+    CHECK_EQ(A->ndim, 3);
+    int bdim = B->ndim;
+    int bias_dim = Bias->ndim;
+    CHECK_EQ(C->ndim, 3);
+    CHECK_EQ(A->shape[2], B->shape[bdim - 2]);
+    CHECK_EQ(Bias->shape[bias_dim - 1], B->shape[bdim - 1]);
+    CHECK_EQ(D->ndim, 3);
+    CHECK_EQ(D->shape[0], A->shape[0]);
+    CHECK_EQ(D->shape[1], A->shape[1]);
+    CHECK_EQ(D->shape[2], B->shape[bdim - 1]);
+    CHECK_EQ(C->shape[0], A->shape[0]);
+    CHECK_EQ(C->shape[1], A->shape[1]);
+    CHECK_EQ(C->shape[2], B->shape[bdim - 1]);
+    int64_t M = A->shape[0] * A->shape[1];
+    int64_t N = B->shape[bdim - 1];
+    int64_t K = A->shape[2];
+    int64_t input_a_batch_stride = M * K;
+    int64_t input_a_stride = K;
+    int64_t input_a_offset = 0; // default to 0
+    int64_t input_b_batch_stride = K * N;
+    int64_t input_b_stride = N;
+    int64_t input_b_offset = 0; // default to 0
+    int64_t output_stride = N;
+    int64_t output_offset = 0;
+    int64_t a_size = 1;
+    a_size *= A->shape[0];
+    a_size *= A->shape[1];
+    a_size *= A->shape[2];
+
+    int64_t b_size = 1;
+    b_size *= B->shape[bdim - 2];
+    b_size *= B->shape[bdim - 1];
+
+    int64_t c_size = 1;
+    c_size *= C->shape[0];
+    c_size *= C->shape[1];
+    c_size *= C->shape[2];
+
+    // Define the GEMM operation
+    {op_def}
+    using kernel = Operation_{op_name};
+    using ElementComputeEpilogue = typename kernel::ElementAccumulator;
+    typename kernel::Arguments arguments({{
+        cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
+        {{M, N, K}}, // GemmCoord problem_size
+        1, // int batch_count
+        {{ElementComputeEpilogue(1), ElementComputeEpilogue(1)}}, // typename EpilogueOutputOp::Params epilogue
+        ({typea}*)(A->data) + input_a_offset, // void const * ptr_A
+        ({typeb}*)(B->data) + input_b_offset, // void const * ptr_B
+        ({typec}*)(D->data), // void const * ptr_C1
+        ({typec}*)(C->data) + output_offset, // void * ptr_D
+        ({typea}*)(Bias->data), // void * ptr_Vector
+        nullptr, // void * ptr_Tensor
+        input_a_batch_stride, // int64_t batch_stride_A
+        input_b_batch_stride, // int64_t batch_stride_B
+        0, // int64_t batch_stride_C1
+        0, // int64_t batch_stride_D
+        0, // int64_t batch_stride_Vector
+        0, // int64_t batch_stride_Tensor
+        input_a_stride, // typename LayoutA::Stride::Index lda
+        input_b_stride, // typename LayoutB::Stride::Index ldb
+        N, // typename LayoutC::Stride::Index ldc1
+        output_stride, // typename LayoutC::Stride::Index ldd
+        0, // typename LayoutC::Stride::Index ldr
+        0, // typename LayoutC::Stride::Index ldt
+    }});
+    kernel gemm_op;
+    size_t workspace_size = gemm_op.get_workspace_size(arguments);
+    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+    cutlass::Status status = gemm_op.can_implement(arguments);
+    CHECK(status == cutlass::Status::kSuccess);
+    status = gemm_op.initialize(arguments, workspace.get());
+    CHECK(status == cutlass::Status::kSuccess);
+    status = gemm_op();
+    CHECK(status == cutlass::Status::kSuccess);
+    return;
+}}
+}}  // namespace
+TVM_DLL_EXPORT_TYPED_FUNC({{global_symbol}}, _BHGEMM);
+      """
+        return text
+
+    code = instantiate_template(op_type, attrs, func_args)
+    return _final_code(code.code, code.headers, func_args)
+
+
+def cutlass_codegen_conv2d(attrs):
+    """cutlass codegen for conv2d"""
+    # cutlass backend only supports nhwc for now
+    conv2d_profiler = attrs["conv2d_profiler"]
+    op_type = attrs["op_type"]
+    conv_kind = ConvKind.Fprop
+    op_name, op_def, _ = conv2d_profiler.profile(
+        op_type=attrs["op_type"],
+        d_shape=attrs["arg0_shape"],
+        w_shape=attrs["arg1_shape"],
+        padding=attrs["padding"],
+        stride=attrs["strides"],
+        dilation=attrs["dilation"],
+        out_dtype=attrs["ret_dtype"],
+        data_dtype=attrs["arg0_dtype"],
+        weight_dtype=attrs["arg1_dtype"],
+        use_3xtf32=False,
+        conv_kind=conv_kind,
+        split_k_slices=[1],
+        profile_all_alignments=True,
+        find_first_valid=False,
+        use_multiprocessing=True,
+    )
+    attrs["cutlass_op_def"] = op_def
+    attrs["cutlass_op_name"] = op_name
+    del attrs["gemm_profiler"]
+    del attrs["conv2d_profiler"]
+
+    nargs = 2
+    if "bias_arg_idx" in attrs:
+        nargs += 1
+    if "residual_arg_idx" in attrs:
+        nargs += 1
+    func_args = ["inp" + str(i) for i in range(nargs)]
+    code = instantiate_template(op_type, attrs, func_args)
+    return _final_code(code.code, code.headers, func_args)
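
A note on the dispatch order in cutlass_codegen_with_match_results above: fusion
patterns are tried from the longest chain to the shortest, so a matmul followed
by a row bias and a ReLU is consumed as one fused kernel. A hedged sketch
(assumes match_results holds the three MatchResults the pass produced for such
a chain):

    fcodegen = cutlass_fcodegen(sm=80, bin_dir="./bin")
    code, consumed, extern_args = fcodegen(match_results)
    # consumed == 3: matmul + bias + relu fused into one CUTLASS kernel with
    # op_type "cutlass.matmul_bias_relu"; code is the C source the pass
    # attaches to the replacement ExternFunc call.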
diff --git a/python/tvm/relax/backend_tir/pattern.py b/python/tvm/relax/backend_tir/pattern.py
new file mode 100644
index 0000000000..10f7a3b162
--- /dev/null
+++ b/python/tvm/relax/backend_tir/pattern.py
@@ -0,0 +1,576 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring,chained-comparison
+"""TIR Patterns"""
+from typing import List
+
+import tvm
+from tvm.runtime import Object
+import tvm._ffi
+
+from tvm.script import tir as T
+
+
+@tvm._ffi.register_object("relax.MatchResult")
+class MatchResult(Object):
+    """The match result of a TIR pattern."""
+
+    def __init__(self, pattern, symbol_values, matched_buffers):
+        self.__init_handle_by_constructor__(
+            tvm._ffi.MatchResult, pattern, symbol_values, matched_buffers
+        )
+
+
+@T.prim_func
+def matmul_rrr_fp16(
+    var_rxplaceholder: T.handle,
+    var_rxplaceholder_1: T.handle,
+    var_matmul: T.handle,
+    M: T.int64,
+    N: T.int64,
+    K: T.int64,
+) -> None:
+    # function attr dict
+    T.func_attr({"tir.noalias": True})
+    rxplaceholder = T.match_buffer(var_rxplaceholder, [M, K], dtype="float16")
+    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N], dtype="float16")
+    matmul = T.match_buffer(var_matmul, [M, N], dtype="float16")
+    # body
+    # with T.block("root")
+    for i0, i1, i2 in T.grid(M, N, K):
+        with T.block("matmul"):
+            i0_1, i1_1, k = T.axis.remap("SSR", [i0, i1, i2])
+            T.reads(rxplaceholder[i0_1, k], rxplaceholder_1[k, i1_1])
+            T.writes(matmul[i0_1, i1_1])
+            with T.init():
+                matmul[i0_1, i1_1] = T.float16(0)
+            matmul[i0_1, i1_1] = (
+                matmul[i0_1, i1_1] + rxplaceholder[i0_1, k] * rxplaceholder_1[k, i1_1]
+            )
+
+
+@T.prim_func
+def bias_row_2d_fp16(
+    var_rxplaceholder: T.handle,
+    var_rxplaceholder_1: T.handle,
+    var_T_add: T.handle,
+    M: T.int64,
+    N: T.int64,
+) -> None:
+    # function attr dict
+    T.func_attr({"tir.noalias": True})
+    rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
+    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N], dtype="float16")
+    T_add = T.match_buffer(var_T_add, [M, N], dtype="float16")
+    # body
+    # with T.block("root")
+    for i0, i1 in T.grid(M, N):
+        with T.block("T_add"):
+            ax0, ax1 = T.axis.remap("SS", [i0, i1])
+            T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[T.int64(0), ax1])
+            T.writes(T_add[ax0, ax1])
+            T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[T.int64(0), ax1]
+
+
+@T.prim_func
+def bias_row_1d_fp16(
+    var_rxplaceholder: T.handle,
+    var_rxplaceholder_1: T.handle,
+    var_T_add: T.handle,
+    M: T.int64,
+    N: T.int64,
+) -> None:
+    # function attr dict
+    T.func_attr({"tir.noalias": True})
+    rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
+    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16")
+    T_add = T.match_buffer(var_T_add, [M, N], dtype="float16")
+    # body
+    # with T.block("root")
+    for i0, i1 in T.grid(M, N):
+        with T.block("T_add"):
+            ax0, ax1 = T.axis.remap("SS", [i0, i1])
+            T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1])
+            T.writes(T_add[ax0, ax1])
+            T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[ax1]
+
+
+@T.prim_func
+def batch_bias_row_2d_fp16(
+    var_rxplaceholder: T.handle,
+    var_rxplaceholder_1: T.handle,
+    var_T_add: T.handle,
+    batch: T.int64,
+    M: T.int64,
+    N: T.int64,
+) -> None:
+    # function attr dict
+    T.func_attr({"tir.noalias": True})
+    rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N], dtype="float16")
+    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N], dtype="float16")
+    T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16")
+    # body
+    # with T.block("root")
+    for i0, i1, i2 in T.grid(batch, M, N):
+        with T.block("T_add"):
+            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+            T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[T.int64(0), ax2])
+            T.writes(T_add[ax0, ax1, ax2])
+            T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + rxplaceholder_1[T.int64(0), ax2]
+
+
+@T.prim_func
+def batch_bias_row_1d_fp16(
+    var_rxplaceholder: T.handle,
+    var_rxplaceholder_1: T.handle,
+    var_T_add: T.handle,
+    batch: T.int64,
+    M: T.int64,
+    N: T.int64,
+) -> None:
+    # function attr dict
+    T.func_attr({"tir.noalias": True})
+    rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N], dtype="float16")
+    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16")
+    T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16")
+    # body
+    # with T.block("root")
+    for i0, i1, i2 in T.grid(batch, M, N):
+        with T.block("T_add"):
+            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+            T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[ax2])
+            T.writes(T_add[ax0, ax1, ax2])
+            T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + rxplaceholder_1[ax2]
+
+
+@T.prim_func
+def relu_fp16(var_rxplaceholder: T.handle, var_compute: T.handle, M: T.int64, N: T.int64) -> None:
+    # function attr dict
+    T.func_attr({"tir.noalias": True})
+    rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
+    compute = T.match_buffer(var_compute, [M, N], dtype="float16")
+    # body
+    # with T.block("root")
+    for i0, i1 in T.grid(M, N):
+        with T.block("compute"):
+            i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
+            T.reads(rxplaceholder[i0_1, i1_1])
+            T.writes(compute[i0_1, i1_1])
+            compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float16(0))
+
+
+@T.prim_func
+def batch_matmul_rrr_2d_fp16(
+    var_rxplaceholder: T.handle,
+    var_rxplaceholder_1: T.handle,
+    var_matmul: T.handle,
+    batch: T.int64,
+    M: T.int64,
+    N: T.int64,
+    K: T.int64,
+) -> None:
+    # function attr dict
+    T.func_attr({"tir.noalias": True})
+    rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K], dtype="float16")
+    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N], dtype="float16")
+    matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16")
+    # body
+    # with T.block("root")
+    for i0, i1, i2, i3 in T.grid(batch, M, N, K):
+        with T.block("matmul"):
+            i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
+            T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[k, i2_1])
+            T.writes(matmul[i0_1, i1_1, i2_1])
+            with T.init():
+                matmul[i0_1, i1_1, i2_1] = T.float16(0)
+            matmul[i0_1, i1_1, i2_1] = (
+                matmul[i0_1, i1_1, i2_1] + rxplaceholder[i0_1, i1_1, k] * rxplaceholder_1[k, i2_1]
+            )
+
+
+@T.prim_func
+def batch_matmul_rrr_3d_fp16(
+    var_rxplaceholder: T.handle,
+    var_rxplaceholder_1: T.handle,
+    var_matmul: T.handle,
+    batch: T.int64,
+    M: T.int64,
+    N: T.int64,
+    K: T.int64,
+) -> None:
+    # function attr dict
+    T.func_attr({"tir.noalias": True})
+    rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K], dtype="float16")
+    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [batch, K, N], dtype="float16")
+    matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16")
+    # body
+    # with T.block("root")
+    for i0, i1, i2, i3 in T.grid(batch, M, N, K):
+        with T.block("matmul"):
+            i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
+            T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[i0_1, k, i2_1])
+            T.writes(matmul[i0_1, i1_1, i2_1])
+            with T.init():
+                matmul[i0_1, i1_1, i2_1] = T.float16(0)
+            matmul[i0_1, i1_1, i2_1] = (
+                matmul[i0_1, i1_1, i2_1]
+                + rxplaceholder[i0_1, i1_1, k] * rxplaceholder_1[i0_1, k, i2_1]
+            )
+
+
+@T.prim_func
+def copy_4d_fp16(
+    A_handle: T.handle,
+    B_handle: T.handle,
+    N: T.int64,
+    H: T.int64,
+    W: T.int64,
+    C: T.int64,
+) -> None:
+    A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
+    B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16")
+    # body
+    # with T.block("root")
+    for n, h, w, c in T.grid(N, H, W, C):
+        with T.block("copy"):
+            vn, vh, vw, vc = T.axis.remap("SSSS", [n, h, w, c])
+            T.reads(A[vn, vh, vw, vc])
+            T.writes(B[vn, vh, vw, vc])
+            B[vn, vh, vw, vc] = A[vn, vh, vw, vc]
+
+
+@T.prim_func
+def padding_2d_nhwc_fp16(
+    A_handle: T.handle,
+    B_handle: T.handle,
+    N: T.int64,
+    H: T.int64,
+    W: T.int64,
+    C: T.int64,
+    pH: T.int64,
+    pW: T.int64,
+    lH: T.int64,
+    lW: T.int64,
+    rH: T.int64,
+    rW: T.int64,
+) -> None:
+    A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
+    B = T.match_buffer(B_handle, [N, pH, pW, C], dtype="float16")
+    # body
+    # with T.block("root")
+    for v, v_1, v_2, v_3 in T.grid(N, pH, pW, C):
+        with T.block("copy"):
+            v_4, v_5, v_6, v_7 = T.axis.remap("SSSS", [v, v_1, v_2, v_3])
+            T.reads(A[v_4, v_5 - lH, v_6 - lW, v_7])
+            T.writes(B[v_4, v_5, v_6, v_7])
+            B[v_4, v_5, v_6, v_7] = T.if_then_else(
+                lH <= v_5 and v_5 < rH and lW <= v_6 and v_6 < rW,
+                A[v_4, v_5 - lH, v_6 - lW, v_7],
+                T.float16(0),
+                dtype="float16",
+            )
+
+
+@T.prim_func
+def conv2d_nhwc_fp16(
+    A_handle: T.handle,
+    B_handle: T.handle,
+    out_handle: T.handle,
+    N: T.int64,
+    pH: T.int64,
+    pW: T.int64,
+    H: T.int64,
+    W: T.int64,
+    C: T.int64,
+    O: T.int64,
+    KH: T.int64,
+    KW: T.int64,
+    StrideH: T.int64,
+    StrideW: T.int64,
+    DilateH: T.int64,
+    DilateW: T.int64,
+) -> None:
+    A = T.match_buffer(A_handle, [N, pH, pW, C], dtype="float16")
+    B = T.match_buffer(B_handle, [O, KH, KW, C], dtype="float16")
+    out = T.match_buffer(out_handle, [N, H, W, O], dtype="float16")
+    # body
+    # with T.block("root")
+    for v, v_1, v_2, v_3, v_4, v_5, v_6 in T.grid(N, H, W, O, KH, KW, C):
+        with T.block("conv"):
+            v_7, v_8, v_9, v_10, v_11, v_12, v_13 = T.axis.remap(
+                "SSSSRRR", [v, v_1, v_2, v_3, v_4, v_5, v_6]
+            )
+            T.reads(
+                A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 * StrideW, v_13],
+                B[v_10, v_11, v_12, v_13],
+            )
+            T.writes(out[v_7, v_8, v_9, v_10])
+            with T.init():
+                out[v_7, v_8, v_9, v_10] = T.float16(0)
+            out[v_7, v_8, v_9, v_10] = (
+                out[v_7, v_8, v_9, v_10]
+                + A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 * StrideW, v_13]
+                * B[v_10, v_11, v_12, v_13]
+            )
+
+
+@T.prim_func
+def bias_add_nhwc_2d_fp16(
+    A_handle: T.handle,
+    B_handle: T.handle,
+    out_handle: T.handle,
+    N: T.int64,
+    H: T.int64,
+    W: T.int64,
+    C: T.int64,
+):
+    A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
+    B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16")
+    out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
+    for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
+        with T.block("T_add"):
+            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), T.int64(0), v_ax3])
+            T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
+            out[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, T.int64(0), T.int64(0), v_ax3]
+            )
+
+
+@T.prim_func
+def bias_add_nhwc_1d_fp16(
+    A_handle: T.handle,
+    B_handle: T.handle,
+    out_handle: T.handle,
+    N: T.int64,
+    H: T.int64,
+    W: T.int64,
+    C: T.int64,
+):
+    A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
+    B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16")
+    out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
+    for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
+        with T.block("T_add"):
+            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[T.int64(0), T.int64(0), T.int64(0), v_ax3])
+            T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
+            out[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                A[v_ax0, v_ax1, v_ax2, v_ax3] + B[T.int64(0), T.int64(0), T.int64(0), v_ax3]
+            )
+
+
+@T.prim_func
+def elem_add_2d_fp16(
+    in0_handle: T.handle,
+    in1_handle: T.handle,
+    out_handle: T.handle,
+    N: T.int64,
+    M: T.int64,
+):
+    in0 = T.match_buffer(in0_handle, [N, M], dtype="float16")
+    in1 = T.match_buffer(in1_handle, [N, M], dtype="float16")
+    out = T.match_buffer(out_handle, [N, M], dtype="float16")
+    for ax0, ax1 in T.grid(N, M):
+        with T.block("T_add"):
+            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+            T.reads(in0[v_ax0, v_ax1], in1[v_ax0, v_ax1])
+            T.writes(out[v_ax0, v_ax1])
+            out[v_ax0, v_ax1] = in0[v_ax0, v_ax1] + in1[v_ax0, v_ax1]
+
+
+@T.prim_func
+def elem_add_3d_fp16(
+    in0_handle: T.handle,
+    in1_handle: T.handle,
+    out_handle: T.handle,
+    B: T.int64,
+    N: T.int64,
+    M: T.int64,
+):
+    in0 = T.match_buffer(in0_handle, [B, N, M], dtype="float16")
+    in1 = T.match_buffer(in1_handle, [B, N, M], dtype="float16")
+    out = T.match_buffer(out_handle, [B, N, M], dtype="float16")
+    for ax0, ax1, ax2 in T.grid(B, N, M):
+        with T.block("T_add"):
+            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+            T.reads(in0[v_ax0, v_ax1, v_ax2], in1[v_ax0, v_ax1, v_ax2])
+            T.writes(out[v_ax0, v_ax1, v_ax2])
+            out[v_ax0, v_ax1, v_ax2] = in0[v_ax0, v_ax1, v_ax2] + in1[v_ax0, v_ax1, v_ax2]
+
+
+@T.prim_func
+def elem_add_4d_fp16(
+    A_handle: T.handle,
+    B_handle: T.handle,
+    out_handle: T.handle,
+    N: T.int64,
+    H: T.int64,
+    W: T.int64,
+    C: T.int64,
+):
+    A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
+    B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16")
+    out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
+    for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
+        with T.block("T_add"):
+            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, v_ax1, v_ax2, v_ax3])
+            T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
+            out[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, v_ax1, v_ax2, v_ax3]
+            )
+
+
+@T.prim_func
+def scalar_mul_3d_fp16(
+    inp0_handle: T.handle,
+    out_handle: T.handle,
+    D1: T.int64,
+    D2: T.int64,
+    D3: T.int64,
+    scalar: T.float16,
+):
+    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
+    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
+    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+        with T.block("T_mul"):
+            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+            T.reads(inp0[v_ax0, v_ax1, v_ax2])
+            T.writes(out[v_ax0, v_ax1, v_ax2])
+            out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * scalar
+
+
+@T.prim_func
+def erf_3d_fp32(
+    inp0_handle: T.handle,
+    out_handle: T.handle,
+    D1: T.int64,
+    D2: T.int64,
+    D3: T.int64,
+):
+    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32")
+    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32")
+    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+        with T.block("T_erf"):
+            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+            T.reads(inp0[v_ax0, v_ax1, v_ax2])
+            T.writes(out[v_ax0, v_ax1, v_ax2])
+            out[v_ax0, v_ax1, v_ax2] = T.erf(inp0[v_ax0, v_ax1, v_ax2])
+
+
+@T.prim_func
+def scalar_add_3d_fp16(
+    inp0_handle: T.handle,
+    out_handle: T.handle,
+    D1: T.int64,
+    D2: T.int64,
+    D3: T.int64,
+    scalar: T.float16,
+):
+    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
+    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
+    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+        with T.block("T_add"):
+            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+            T.reads(inp0[v_ax0, v_ax1, v_ax2])
+            T.writes(out[v_ax0, v_ax1, v_ax2])
+            out[v_ax0, v_ax1, v_ax2] = scalar + inp0[v_ax0, v_ax1, v_ax2]
+
+
+@T.prim_func
+def elem_mul_3d_fp16(
+    inp0_handle: T.handle,
+    inp1_handle: T.handle,
+    out_handle: T.handle,
+    D1: T.int64,
+    D2: T.int64,
+    D3: T.int64,
+):
+    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
+    inp1 = T.match_buffer(inp1_handle, [D1, D2, D3], dtype="float16")
+    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
+    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+        with T.block("T_mul"):
+            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+            T.reads(inp0[v_ax0, v_ax1, v_ax2], inp1[v_ax0, v_ax1, v_ax2])
+            T.writes(out[v_ax0, v_ax1, v_ax2])
+            out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * inp1[v_ax0, v_ax1, v_ax2]
+
+
+@T.prim_func
+def cast_3d_fp16(
+    inp0_handle: T.handle,
+    out_handle: T.handle,
+    D1: T.int64,
+    D2: T.int64,
+    D3: T.int64,
+):
+    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32")
+    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
+    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+        with T.block("T_cast"):
+            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+            T.reads(inp0[v_ax0, v_ax1, v_ax2])
+            T.writes(out[v_ax0, v_ax1, v_ax2])
+            out[v_ax0, v_ax1, v_ax2] = T.Cast("float16", inp0[v_ax0, v_ax1, 
v_ax2])
+
+
+@T.prim_func
+def cast_3d_fp32(
+    inp0_handle: T.handle,
+    out_handle: T.handle,
+    D1: T.int64,
+    D2: T.int64,
+    D3: T.int64,
+):
+    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
+    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32")
+    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+        with T.block("T_cast"):
+            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+            T.reads(inp0[v_ax0, v_ax1, v_ax2])
+            T.writes(out[v_ax0, v_ax1, v_ax2])
+            out[v_ax0, v_ax1, v_ax2] = T.Cast("float32", inp0[v_ax0, v_ax1, 
v_ax2])
+
+
+def get_tir_pattern() -> List[tvm.tir.PrimFunc]:
+    """Get the tir patterns for backend dispatch."""
+    return [
+        matmul_rrr_fp16,
+        bias_row_2d_fp16,
+        bias_row_1d_fp16,
+        batch_bias_row_2d_fp16,
+        batch_bias_row_1d_fp16,
+        relu_fp16,
+        erf_3d_fp32,
+        batch_matmul_rrr_2d_fp16,
+        batch_matmul_rrr_3d_fp16,
+        copy_4d_fp16,
+        padding_2d_nhwc_fp16,
+        conv2d_nhwc_fp16,
+        bias_add_nhwc_2d_fp16,
+        bias_add_nhwc_1d_fp16,
+        elem_add_2d_fp16,
+        elem_add_3d_fp16,
+        elem_add_4d_fp16,
+        elem_mul_3d_fp16,
+        scalar_add_3d_fp16,
+        scalar_mul_3d_fp16,
+        cast_3d_fp16,
+        cast_3d_fp32,
+    ]
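
Each pattern above is an ordinary TVMScript PrimFunc: its scalar parameters
(M, N, K, ...) become the symbol_values of a MatchResult, and its buffers
become the matched_buffers. As a hedged sketch, a new (hypothetical) 3-D ReLU
pattern following these conventions would look like:

    from tvm.script import tir as T

    @T.prim_func
    def relu_3d_fp16(
        var_inp: T.handle, var_out: T.handle, B: T.int64, M: T.int64, N: T.int64
    ) -> None:
        T.func_attr({"tir.noalias": True})
        inp = T.match_buffer(var_inp, [B, M, N], dtype="float16")
        out = T.match_buffer(var_out, [B, M, N], dtype="float16")
        for i0, i1, i2 in T.grid(B, M, N):
            with T.block("compute"):
                v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(inp[v0, v1, v2])
                T.writes(out[v0, v1, v2])
                out[v0, v1, v2] = T.max(inp[v0, v1, v2], T.float16(0))

    # ...and it would then be appended to the list returned by get_tir_pattern().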
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index c03df804ee..18321e8dba 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -693,6 +693,25 @@ def ToMixedPrecision(out_dtype="float32") -> tvm.ir.transform.Pass:
     return _ffi_api.ToMixedPrecision(out_dtype)  # type: ignore
 
 
+def SplitCallTIRByPattern(patterns, fcodegen) -> tvm.ir.transform.Pass:
+    """Split a PrimFunc into 2 parts: the first part is a TIR PrimFunc which is
+       matched with some pattern, and the second part is the rest of the 
original
+       PrimFunc. It will call fcodegen to generate the code for the matched 
pattern
+       to replace it with a ExternFunc call.
+    Parameters
+    ----------
+    patterns : List[PrimFunc]
+        The list of patterns to match.
+    fcodegen: Callable[[List[MatchResult]], List[Object]]
+        The function to generate the code for the matched patterns.
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for splitting call_tir.
+    """
+    return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen)  # type: ignore
+
+
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass."""
 
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
index da0ca3a0b5..b36b5ed4d6 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -315,6 +315,17 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
   }
 
   Instruction::Arg VisitExpr_(const ExternFuncNode* op) final {
+    static constexpr const char* kCSource = "c_source";
+    static constexpr const char* kCSourceFmt = "c_source_fmt";
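+    // If the ExternFunc carries inlined C source (for example, code emitted by
+    // an FCodegen backend such as the CUTLASS one in this patch), wrap it into
+    // a CSourceModule and import it into the executable so the symbol can be
+    // resolved at runtime.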
+    if (Optional<String> opt_code = op->attrs.GetAttr<String>(kCSource)) {
+      String sym = op->global_symbol;
+      String fmt = op->attrs.GetAttr<String>(kCSourceFmt).value_or("c");
+      String code = opt_code.value();
+      Module c_source_module =
+          codegen::CSourceModuleCreate(/*code=*/code, /*fmt=*/fmt, /*func_names=*/{sym},
+                                       /*const_vars=*/{});
+      builder_->exec()->Import(c_source_module);
+    }
    builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc);
     return builder_->GetFunction(op->global_symbol);
   }
diff --git a/src/relax/ir/tir_pattern.cc b/src/relax/ir/tir_pattern.cc
new file mode 100644
index 0000000000..cbe4170bb9
--- /dev/null
+++ b/src/relax/ir/tir_pattern.cc
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/tir_pattern.h>
+
+namespace tvm {
+namespace relax {
+
+MatchResult::MatchResult(TIRPattern pattern, Array<PrimExpr> symbol_values,
+                         Array<tir::Buffer> matched_buffers) {
+  auto n = make_object<MatchResultNode>();
+  n->pattern = std::move(pattern);
+  n->symbol_values = std::move(symbol_values);
+  n->matched_buffers = std::move(matched_buffers);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(MatchResultNode);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc
new file mode 100644
index 0000000000..7fcc2cb34a
--- /dev/null
+++ b/src/relax/transform/split_call_tir_by_pattern.cc
@@ -0,0 +1,782 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/split_call_tir_by_pattern.cc
+ * \brief Split a call_tir into a library kernel call and the remaining TIR computation,
+ * based on TIR-level pattern matching.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir/module.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/tir_pattern.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../tir/schedule/ir_comparator.h"
+
+namespace tvm {
+
+static const constexpr char* kLibraryKernel = "library_kernel";
+static const constexpr char* kCSource = "c_source";
+static const constexpr char* kCSourceFmt = "c_source_fmt";
+static const constexpr char* kCSourceFmtCuda = "cu";
+
+namespace tir {
+
+using relax::FCodegen;
+using relax::MatchResult;
+using relax::TIRPattern;
+
+/*! \brief helper to match a for stmt to a pattern*/
+class ForMatcher : public TensorizeComparator {
+ public:
+  using SymbolMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+  explicit ForMatcher(const tir::PrimFunc& pattern, const Array<Var>& pattern_vars)
+      : TensorizeComparator(IRModule({{GlobalVar(""), pattern}}), false), pattern_(pattern) {
+    for (const auto& pattern_var : pattern_vars) {
+      this->pattern_vars_.insert(pattern_var);
+    }
+    this->evaluated_symbols.push_back(SymbolMap());
+  }
+
+  bool Match(const For& top) {
+    const ForNode* pattern_top = pattern_->body.as<BlockRealizeNode>()->block->body.as<ForNode>();
+    ICHECK(pattern_top) << "Invalid pattern function";
+    if (!VisitStmt(top, GetRef<Stmt>(pattern_top))) {
+      return false;
+    }
+    // Get evaluated symbols, buffers from the pattern.
+    for (const auto& arg : pattern_->params) {
+      auto it = pattern_->buffer_map.find(arg);
+      if (it != pattern_->buffer_map.end()) {
+        auto itt = rhs_buffer_map_.find((*it).second);
+        ICHECK(itt != rhs_buffer_map_.end());
+        evaluated_buffers.push_back(itt->second);
+      }
+    }
+    return true;
+  }
+
+  std::vector<SymbolMap> evaluated_symbols;
+  std::vector<Buffer> evaluated_buffers;
+
+ private:
+  using ExprComparator::VisitExpr_;
+
+  Optional<PrimExpr> QueryEvaluatedSymbols(const Var& var) {
+    for (const SymbolMap& symbol_map : evaluated_symbols) {
+      auto it = symbol_map.find(var);
+      if (it != symbol_map.end()) {
+        return it->second;
+      }
+    }
+    return NullOpt;
+  }
+
+  bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final {
+    if (const auto* op = rhs.as<VarNode>()) {
+      if (pattern_vars_.count(GetRef<Var>(op))) {
+        // special case for pattern vars
+        const auto* lhs_ptr = lhs.as<VarNode>();
+        if (lhs_ptr == nullptr) {
+          if (lhs->IsInstance<tir::IntImmNode>() || lhs->IsInstance<tir::FloatImmNode>()) {
+            Optional<PrimExpr> value = QueryEvaluatedSymbols(GetRef<Var>(op));
+            if (value.defined()) {
+              if (!analyzer_.CanProveEqual(lhs, value.value())) return false;
+            } else {
+              evaluated_symbols.back()[GetRef<Var>(op)] = lhs;
+            }
+            return true;
+          } else {
+            return false;
+          }
+        }
+      }
+    }
+    // pattern_var * expr
+    if (const auto* rhs_ptr = rhs.as<MulNode>()) {
+      const auto* operand_a = rhs_ptr->a.as<VarNode>();
+      const auto* operand_b = rhs_ptr->b.as<VarNode>();
+      if (operand_a != nullptr && pattern_vars_.count(GetRef<Var>(operand_a))) {
+        // pattern var is on the left
+        evaluated_symbols.push_back(SymbolMap());
+        bool match = VisitExpr(lhs, rhs_ptr->b);
+        SymbolMap symbol_map = std::move(evaluated_symbols.back());
+        evaluated_symbols.pop_back();
+        if (match) {
+          evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end());
+          evaluated_symbols.back()[GetRef<Var>(operand_a)] = MakeConstScalar(rhs_ptr->b.dtype(), 1);
+          return true;
+        }
+      }
+      if (operand_b != nullptr && pattern_vars_.count(GetRef<Var>(operand_b))) {
+        // pattern var is on the right
+        evaluated_symbols.push_back(SymbolMap());
+        bool match = VisitExpr(lhs, rhs_ptr->a);
+        SymbolMap symbol_map = std::move(evaluated_symbols.back());
+        evaluated_symbols.pop_back();
+        if (match) {
+          evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end());
+          evaluated_symbols.back()[GetRef<Var>(operand_b)] = MakeConstScalar(rhs_ptr->a.dtype(), 1);
+          return true;
+        }
+      }
+    }
+    // pattern_var + expr
+    if (const auto* rhs_ptr = rhs.as<AddNode>()) {
+      const auto* operand_a = rhs_ptr->a.as<VarNode>();
+      const auto* operand_b = rhs_ptr->b.as<VarNode>();
+      if (operand_a != nullptr && pattern_vars_.count(GetRef<Var>(operand_a))) {
+        // pattern var is on the left
+        evaluated_symbols.push_back(SymbolMap());
+        bool match = VisitExpr(lhs, rhs_ptr->b);
+        SymbolMap symbol_map = std::move(evaluated_symbols.back());
+        evaluated_symbols.pop_back();
+        if (match) {
+          evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end());
+          evaluated_symbols.back()[GetRef<Var>(operand_a)] = MakeConstScalar(rhs_ptr->b.dtype(), 0);
+          return true;
+        }
+      }
+      if (operand_b != nullptr && pattern_vars_.count(GetRef<Var>(operand_b))) {
+        // pattern var is on the right
+        evaluated_symbols.push_back(SymbolMap());
+        bool match = VisitExpr(lhs, rhs_ptr->a);
+        SymbolMap symbol_map = std::move(evaluated_symbols.back());
+        evaluated_symbols.pop_back();
+        if (match) {
+          evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end());
+          evaluated_symbols.back()[GetRef<Var>(operand_b)] = MakeConstScalar(rhs_ptr->a.dtype(), 0);
+          return true;
+        }
+      }
+    }
+    return TensorizeComparator::VisitExpr(lhs, rhs);
+  }
+
+  bool VisitExpr_(const tir::AddNode* add, const PrimExpr& other) final {
+    const auto* rhs = other.as<AddNode>();
+    if (rhs == nullptr) return false;
+    {
+      this->evaluated_symbols.push_back(SymbolMap());
+      bool match = VisitExpr(add->a, rhs->a) && VisitExpr(add->b, rhs->b);
+      SymbolMap symbol_map = std::move(evaluated_symbols.back());
+      this->evaluated_symbols.pop_back();
+      if (match) {
+        this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end());
+        return true;
+      }
+    }
+    {
+      this->evaluated_symbols.push_back(SymbolMap());
+      bool match = VisitExpr(add->a, rhs->b) && VisitExpr(add->b, rhs->a);
+      SymbolMap symbol_map = std::move(evaluated_symbols.back());
+      this->evaluated_symbols.pop_back();
+      if (match) {
+        this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end());
+        return true;
+      }
+    }
+    return false;
+  }
+
+  bool VisitExpr_(const tir::MulNode* mul, const PrimExpr& other) final {
+    const auto* rhs = other.as<MulNode>();
+    if (rhs == nullptr) return false;
+    {
+      this->evaluated_symbols.push_back(SymbolMap());
+      bool match = VisitExpr(mul->a, rhs->a) && VisitExpr(mul->b, rhs->b);
+      SymbolMap symbol_map = std::move(evaluated_symbols.back());
+      this->evaluated_symbols.pop_back();
+      if (match) {
+        this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end());
+        return true;
+      }
+    }
+    {
+      this->evaluated_symbols.push_back(SymbolMap());
+      bool match = VisitExpr(mul->a, rhs->b) && VisitExpr(mul->b, rhs->a);
+      SymbolMap symbol_map = std::move(evaluated_symbols.back());
+      this->evaluated_symbols.pop_back();
+      if (match) {
+        this->evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end());
+        return true;
+      }
+    }
+    return false;
+  }
+
+  bool VisitExpr_(const tir::CallNode* call, const PrimExpr& other) final {
+    const auto* rhs = other.as<CallNode>();
+    if (rhs == nullptr) return false;
+    const auto* lhs_op = call->op.as<OpNode>();
+    const auto* rhs_op = rhs->op.as<OpNode>();
+    if (lhs_op == nullptr || rhs_op == nullptr) return false;
+    if (lhs_op->name != rhs_op->name) return false;
+    if (call->args.size() != rhs->args.size()) return false;
+    for (size_t i = 0; i < call->args.size(); ++i) {
+      if (!VisitExpr(call->args[i], rhs->args[i])) return false;
+    }
+    return true;
+  }
+
+  bool VisitStmt_(const tir::ForNode* op, const Stmt& other) final {
+    const auto* rhs = other.as<ForNode>();
+    loop_stack_lhs_.push_back(GetRef<For>(op));
+    loop_stack_rhs_.push_back(GetRef<For>(rhs));
+    // The body of loop must be loop or BlockRealize
+    if (!op->body->IsInstance<BlockRealizeNode>() && !op->body->IsInstance<ForNode>()) {
+      return false;
+    }
+    if (!rhs->body->IsInstance<BlockRealizeNode>() && !rhs->body->IsInstance<ForNode>()) {
+      return false;
+    }
+    // Build mapping between the loop vars
+    if (!DefEqual(op->loop_var, rhs->loop_var)) return false;
+    // Only handle the case where the loop starts from 0
+    if (!is_zero(op->min) || !is_zero(rhs->min)) return false;
+    if (op->thread_binding.defined() || rhs->thread_binding.defined()) return false;
+    if (op->kind != ForKind::kSerial || op->kind != rhs->kind) return false;
+    if (!op->annotations.empty() || !rhs->annotations.empty()) return false;
+    // Match the extents of loops
+    if (!VisitExpr(op->extent, rhs->extent)) return false;
+    return VisitStmt(op->body, rhs->body);
+  }
+
+  bool VisitStmt_(const tir::BlockNode* op, const Stmt& other) final {
+    const auto* rhs = other.as<BlockNode>();
+    // Check block equality.
+    // All iter vars and buffer regions including the order should match.
+    // When checking iter vars, DefEqual is used to remap variables.
+    if (!CompareArray(op->iter_vars, rhs->iter_vars, &ForMatcher::CompareIterVar)) {
+      return false;
+    }
+    // Disallow alloc_buffers inside the block
+    if (!op->alloc_buffers.empty() || !rhs->alloc_buffers.empty()) return false;
+    if (!CompareArray(op->writes, rhs->writes, &ForMatcher::CompareBufferRegion)) {
+      return false;
+    }
+    if (!CompareArray(op->reads, rhs->reads, &ForMatcher::CompareBufferRegion)) {
+      return false;
+    }
+    // The body of the block has to be a BufferStore
+    if (!op->body->IsInstance<BufferStoreNode>() || !rhs->body->IsInstance<BufferStoreNode>()) {
+      return false;
+    }
+    // Handle init block
+    if (op->init.defined() && !rhs->init.defined()) return false;
+    if (!op->init.defined() && rhs->init.defined()) return false;
+    if (op->init.defined() && rhs->init.defined()) {
+      if (!VisitStmt(op->init.value(), rhs->init.value())) return false;
+    }
+    return VisitStmt(op->body, rhs->body);
+  }
+
+  bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) final {
+    const auto* rhs = other.as<BlockRealizeNode>();
+    // Only allow trivial bindings
+    for (size_t i = 0; i < op->iter_values.size(); ++i) {
+      if (!op->iter_values[i].same_as(loop_stack_lhs_[i]->loop_var)) return false;
+    }
+    for (size_t i = 0; i < rhs->iter_values.size(); ++i) {
+      if (!rhs->iter_values[i].same_as(loop_stack_rhs_[i]->loop_var)) return false;
+    }
+    // Disallow predicates now
+    if (!is_one(op->predicate) || !is_one(rhs->predicate)) return false;
+    return VisitStmt(op->block, rhs->block);
+  }
+
+  bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) final {
+    const auto* rhs = other.as<BufferStoreNode>();
+    return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);
+  }
+
+  bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) final {
+    const auto* rhs = other.as<BufferLoadNode>();
+    return CompareBufferAccess(op, rhs);
+  }
+
+  bool CompareBuffer(const Buffer& lhs, const Buffer& rhs) {
+    if (lhs.same_as(rhs)) return true;
+    auto it = rhs_buffer_map_.find(rhs);
+    bool equal;
+    if (it != rhs_buffer_map_.end()) {
+      equal = (*it).second.same_as(lhs);
+    } else {
+      // Compare shape
+      if (lhs->shape.size() != rhs->shape.size()) return false;
+      for (size_t i = 0; i < lhs->shape.size(); ++i) {
+        if (!VisitExpr(lhs->shape[i], rhs->shape[i])) return false;
+      }
+      // Remap both buffer itself and buffer data
+      equal =
+          DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype && lhs.scope() == rhs.scope();
+      if (equal) {
+        rhs_buffer_map_[rhs] = lhs;
+      }
+    }
+    return equal;
+  }
+
+  bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) {
+    if (!CompareBuffer(lhs->buffer, rhs->buffer)) {
+      return false;
+    }
+    return CompareArray(lhs->region, rhs->region, &ForMatcher::CompareRange);
+  }
+
+  template <typename T>
+  bool CompareBufferAccess(const T* lhs, const T* rhs) {
+    if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false;
+    return CompareArray(lhs->indices, rhs->indices, &ForMatcher::VisitExpr);
+  }
+
+  template <typename T, typename Self, typename F>
+  bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F Self::*cmp) {
+    if (lhs.same_as(rhs)) return true;
+    if (lhs.size() != rhs.size()) return false;
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      if (!(static_cast<Self*>(this)->*cmp)(lhs[i], rhs[i])) return false;
+    }
+    return true;
+  }
+
+  arith::Analyzer analyzer_;
+  std::vector<For> loop_stack_lhs_, loop_stack_rhs_;
+  tir::PrimFunc pattern_;
+  std::unordered_set<Var, ObjectHash, ObjectEqual> pattern_vars_;
+};
+
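To make the symbol evaluation concrete: the matcher walks the workload loop nest and the pattern loop nest in lockstep, and records a value for each pattern var the first time it is compared against a constant. A sketch of what a successful match establishes for the cast_3d_fp16 pattern defined earlier in this diff (the concrete workload below is illustrative only):

    from tvm.script import tir as T

    @T.prim_func
    def concrete_cast(inp: T.Buffer((2, 4, 8), "float32"),
                      out: T.Buffer((2, 4, 8), "float16")):
        for ax0, ax1, ax2 in T.grid(2, 4, 8):
            with T.block("T_cast"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                out[v_ax0, v_ax1, v_ax2] = T.Cast("float16", inp[v_ax0, v_ax1, v_ax2])

    # Matching cast_3d_fp16 against this loop nest evaluates the pattern's
    # symbolic vars as D1 = 2, D2 = 4, D3 = 8 and pairs the pattern buffers
    # (inp0, out) with (inp, out); these become MatchResult.symbol_values and
    # MatchResult.matched_buffers.
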
+/*! \brief Analyze the function and match it with a list of patterns */
+class TIRPatternMatcher {
+ public:
+  static Array<MatchResult> Match(Array<TIRPattern> patterns, Stmt body) {
+    TIRPatternMatcher matcher(patterns);
+    matcher.OpPatternMatch(body);
+    if (matcher.fail_) return {};
+    return matcher.match_results_;
+  }
+
+ private:
+  explicit TIRPatternMatcher(Array<TIRPattern> patterns) : patterns_(patterns) {}
+
+  // Find an op that matches this block
+  bool BlockPatternMatch(const For& top) {
+    for (const TIRPattern& pattern : patterns_) {
+      tir::PrimFunc pattern_func = pattern;
+      Array<Var> pattern_symbolic_vars;
+      int buffer_count = pattern_func->buffer_map.size();
+      for (int i = buffer_count; i < static_cast<int>(pattern_func->params.size()); i++) {
+        pattern_symbolic_vars.push_back(pattern_func->params[i]);
+      }
+      ForMatcher block_matcher(pattern_func, pattern_symbolic_vars);
+      if (block_matcher.Match(top)) {
+        // We have found a match
+        Array<PrimExpr> symbol_values;
+        for (int i = buffer_count; i < static_cast<int>(pattern_func->params.size()); i++) {
+          symbol_values.push_back(block_matcher.evaluated_symbols.back()[pattern_func->params[i]]);
+        }
+        match_results_.push_back(
+            MatchResult(pattern, symbol_values, block_matcher.evaluated_buffers));
+        return true;
+      }
+    }
+    // The block fails to match any pattern
+    return false;
+  }
+
+  // For each block in the body, try to find its corresponding pattern one by one
+  void OpPatternMatch(const Stmt& body) {
+    Array<Stmt> blocks;
+    if (body->IsInstance<ForNode>()) {
+      // {for}
+      blocks = {body};
+    } else if (const SeqStmtNode* seq = body.as<SeqStmtNode>()) {
+      blocks = seq->seq;
+    } else {
+      fail_ = true;
+      return;
+    }
+    for (const Stmt& stmt : blocks) {
+      const ForNode* loop = stmt.as<ForNode>();
+      if (loop == nullptr || !BlockPatternMatch(GetRef<For>(loop))) {
+        break;
+      }
+    }
+    if (match_results_.empty()) {
+      fail_ = true;
+    }
+  }
+  /*! \brief Indicate whether we fail to match.*/
+  bool fail_ = false;
+  /*! \brief The patterns we match the target stmt to.*/
+  Array<TIRPattern> patterns_;
+  /*! \brief The results of the matching process.*/
+  Array<MatchResult> match_results_;
+};
+
+/*! \brief Helper class to partition a function into two parts. It records the
+ * information needed to construct the two partitioned parts. */
+class FunctionPartitioner : public StmtExprVisitor {
+ public:
+  explicit FunctionPartitioner(int num_matched_ops) : num_matched_ops_(num_matched_ops) {}
+  /*! \brief alloc_buffers for the first function */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocs1;
+  /*! \brief alloc_buffers for the second function */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocs2;
+  /*! \brief whether the current block is in the first function */
+  Map<Block, Bool> block_partition;
+  /*! \brief input buffers for the first function */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> input1;
+  /*! \brief input buffers for the second function */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> input2;
+  /*! \brief The output buffer for the first function, which is also the input buffer for the
+  second function */
+  Buffer intermediate_buffer;
+  /*! \brief Indicates whether we have failed. If failed, we will not do any further analysis and
+  directly return the original function. */
+  bool fail = false;
+
+ private:
+  void VisitStmt_(const BlockNode* op) final {
+    block_counter_++;
+    bool is_matching_ = block_counter_ <= num_matched_ops_;
+    if (block_counter_ == num_matched_ops_) {
+      allocs1.erase(intermediate_buffer);
+    }
+    for (const auto& read : op->reads) {
+      if (is_matching_) {
+        input1.insert(read->buffer);
+      } else {
+        input2.insert(read->buffer);
+      }
+    }
+    for (const auto& write : op->writes) {
+      if (is_matching_) {
+        allocs1.insert(write->buffer);
+      } else if (allocs1.count(write->buffer)) {
+        fail = true;
+        return;
+      } else {
+        allocs2.insert(write->buffer);
+      }
+      if (is_matching_) {
+        intermediate_buffer = write->buffer;
+      } else {
+        input2.insert(write->buffer);
+      }
+    }
+    block_partition.Set(GetRef<Block>(op), Bool(is_matching_));
+  }
+  // The number of matched ops in the function
+  size_t num_matched_ops_;
+  size_t block_counter_ = 0;
+};
+
+/*! \brief Remove parts according to the block partition, and update the alloc_buffers for blocks */
+class BlockRemover : public StmtExprMutator {
+ public:
+  static Stmt RemoveBlockByPartition(
+      Stmt stmt, const Map<Block, Bool>& block_partition,
+      const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& allocs,
+      bool is_library_part) {
+    BlockRemover remover(block_partition, allocs, is_library_part);
+    return remover(stmt);
+  }
+
+ private:
+  BlockRemover(const Map<Block, Bool>& block_partition,
+               const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& allocs,
+               bool is_library_part)
+      : block_partition(block_partition), allocs_(allocs), is_library_part_(is_library_part) {}
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+    ObjectPtr<BlockNode> n = make_object<BlockNode>(*block.operator->());
+    if (op->name_hint != "root") {
+      ICHECK(block_partition.count(GetRef<Block>(op)));
+      bool block_is_library = block_partition[GetRef<Block>(op)]->value;
+      if (!(is_library_part_ ^ block_is_library)) {
+        n->body = block->body;
+      } else {
+        erased_ = true;
+      }
+    }
+    Array<Buffer> alloc_buffers;
+    for (const Buffer& b : block->alloc_buffers) {
+      if (allocs_.count(b)) {
+        alloc_buffers.push_back(b);
+      }
+    }
+    n->alloc_buffers = alloc_buffers;
+    return Block(n);
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    Array<Stmt> seq;
+    for (const Stmt& s : op->seq) {
+      Stmt new_s = VisitStmt(s);
+      if (erased_) {
+        erased_ = false;
+      } else {
+        seq.push_back(new_s);
+      }
+    }
+    return SeqStmt::Flatten(seq);
+  }
+
+  bool erased_ = false;
+  Map<Block, Bool> block_partition;
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocs_;
+  bool is_library_part_ = false;
+};
+
+/*!
+ * \brief Split the input function into two functions, one for the library kernel and one for
+ * the rest.
+ * \param func The input function.
+ * \param arg_partition Output: for each function after the split, the indices of the original
+ * arguments it takes.
+ * \param patterns The patterns to match.
+ * \param f_codegen The function to generate the code for the library kernel.
+ * \return A pair of functions; the first one is the library kernel and the second one is the
+ * rest.
+ */
+std::pair<PrimFunc, Optional<PrimFunc>> SplitFunctions(PrimFunc func,
+                                                       std::vector<std::vector<int>>* arg_partition,
+                                                       Array<TIRPattern> patterns,
+                                                       FCodegen f_codegen) {
+  // Step 1. Find the library kernel and the rest.
+  Stmt body = func->body.as<BlockRealizeNode>()->block->body;
+  Array<MatchResult> match_results = TIRPatternMatcher::Match(patterns, body);
+  if (match_results.empty()) {
+    return {func, NullOpt};
+  }
+  Array<ObjectRef> codegen_result = f_codegen(match_results);
+  ICHECK(codegen_result.size() == 3);
+  String library_code = Downcast<String>(codegen_result[0]);
+  int num_matched_ops = Downcast<Integer>(codegen_result[1])->value;
+  Array<Buffer> func1_args = Downcast<Array<Buffer>>(codegen_result[2]);
+  if (num_matched_ops == 0) {
+    return {func, NullOpt};
+  }
+  FunctionPartitioner partitioner(num_matched_ops);
+  partitioner(body);
+  if (partitioner.fail) {
+    return {func, NullOpt};
+  }
+  bool has_second_func = false;
+  for (const auto& pr : partitioner.block_partition) {
+    if (!pr.second->value) {
+      has_second_func = true;
+      break;
+    }
+  }
+  if (!has_second_func) {
+    // No need to split the function.
+    return {WithAttr(func, kLibraryKernel, library_code), NullOpt};
+  }
+  // Step 2. Split the function into two functions.
+  Stmt body1 = BlockRemover::RemoveBlockByPartition(func->body, partitioner.block_partition,
+                                                    partitioner.allocs1, true);
+  Stmt body2 = BlockRemover::RemoveBlockByPartition(func->body, partitioner.block_partition,
+                                                    partitioner.allocs2, false);
+  // Step 3. Craft the first function.
+  Array<Var> new_params1;
+  std::vector<int> arg_partition1;
+  ICHECK_LE(func1_args.size(), partitioner.input1.size());
+  for (const auto& buffer : func1_args) {
+    ICHECK(partitioner.input1.find(buffer) != partitioner.input1.end());
+    for (size_t i = 0; i < func->params.size(); i++) {
+      if (func->buffer_map[func->params[i]].same_as(buffer)) {
+        new_params1.push_back(func->params[i]);
+        arg_partition1.push_back(i);
+        break;
+      }
+    }
+  }
+  arg_partition->push_back(arg_partition1);
+  new_params1.push_back(Var("output", DataType::Handle()));
+  Map<Var, Buffer> new_buffer_map1;
+  for (const auto& kv : func->buffer_map) {
+    if (partitioner.input1.count(kv.second)) {
+      new_buffer_map1.Set(kv.first, kv.second);
+    }
+  }
+  new_buffer_map1.Set(new_params1.back(), partitioner.intermediate_buffer);
+  PrimFunc func1 = PrimFunc(new_params1, body1, func->ret_type, new_buffer_map1, func->attrs);
+  func1 = WithAttr(func1, kLibraryKernel, library_code);
+  // Step 4. Craft the second function.
+  Array<Var> new_params2;
+  std::vector<int> arg_partition2;
+  new_params2.push_back(Var("input", DataType::Handle()));
+  for (int i = 0; i < static_cast<int>(func->params.size()); i++) {
+    Var param = func->params[i];
+    if (partitioner.input2.count(func->buffer_map[param])) {
+      new_params2.push_back(param);
+      if (i != static_cast<int>(func->params.size()) - 1) {
+        arg_partition2.push_back(i);
+      }
+    }
+  }
+  arg_partition->push_back(arg_partition2);
+  Map<Var, Buffer> new_buffer_map2;
+  new_buffer_map2.Set(new_params2[0], partitioner.intermediate_buffer);
+  for (const auto& kv : func->buffer_map) {
+    if (partitioner.input2.count(kv.second)) {
+      new_buffer_map2.Set(kv.first, kv.second);
+    }
+  }
+  PrimFunc func2 = PrimFunc(new_params2, body2, func->ret_type, new_buffer_map2, func->attrs);
+  return {func1, func2};
+}
+}  // namespace tir
+
+namespace relax {
+void StringReplace(std::string* subject, const std::string& search, const std::string& replace) {
+  for (size_t pos = 0; (pos = subject->find(search, pos)) != std::string::npos;
+       pos += replace.length()) {
+    subject->replace(pos, search.length(), replace);
+  }
+}
+
+tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, String global_symbol) {
+  using namespace tvm::tir;
+  Optional<runtime::String> library_code = pf->attrs.GetAttr<runtime::String>(kLibraryKernel);
+  if (!library_code.defined()) {
+    return GetRef<tir::PrimFunc>(pf);
+  }
+  std::string source = library_code.value();
+  StringReplace(&source, "{global_symbol}", global_symbol);
+  ExternFunc ret(global_symbol);
+  ret = WithAttrs(std::move(ret), Map<String, ObjectRef>{
+                                      {String(kCSource), String(source)},
+                                      {String(kCSourceFmt), String(kCSourceFmtCuda)},
+                                  });
+  return ret;
+}
+
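For reference, the ExternFunc produced here carries the same attributes that the VM codegen (codegen_vm.cc above) reads back. A sketch of the equivalent construction from Python, assuming relax.ExternFunc supports the usual BaseFunc.with_attr; the symbol and source are placeholders:

    from tvm import relax

    cuda_source = "/* generated kernel source for my_kernel */"
    extern = relax.ExternFunc("my_kernel")
    extern = extern.with_attr("c_source", cuda_source)  # read via kCSource
    extern = extern.with_attr("c_source_fmt", "cu")     # read via kCSourceFmt
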
+/*! \brief Emit two calls: one to the library kernel and one to the rest of the function. */
+class SplitMutator : public ExprMutator {
+ public:
+  SplitMutator(const tvm::IRModule& mod, Array<TIRPattern> patterns, FCodegen fcodegen)
+      : ExprMutator(mod), mod_(mod), patterns_(patterns), fcodegen_(fcodegen) {}
+  static IRModule Transform(const IRModule& mod, Array<TIRPattern> patterns, FCodegen fcodegen) {
+    SplitMutator mutator(mod, patterns, fcodegen);
+    for (auto& kv : mod->functions) {
+      if (auto* func = kv.second.as<FunctionNode>()) {
+        Function new_func = Downcast<Function>(mutator(GetRef<Function>(func)));
+        mutator.builder_->UpdateFunction(kv.first, new_func);
+      }
+    }
+    return mutator.builder_->GetContextIRModule();
+  }
+
+ private:
+  using ExprMutator::VisitExpr_;
+
+  inline Array<Expr> GetCallTIRArgs(Expr args) {
+    if (args.as<TupleNode>()) {
+      return args.as<TupleNode>()->fields;
+    } else {
+      return {args};
+    }
+  }
+
+  Expr VisitExpr_(const CallNode* op) final {
+    Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    static const Op& call_dps_packed_ = Op::Get("relax.call_dps_packed");
+    if (!call->op.same_as(call_tir_op_)) return call;
+    // the first argument is the function to be called
+    const auto* gv_ptr = call->args[0].as<GlobalVarNode>();
+    if (gv_ptr == nullptr) return call;
+    GlobalVar gv = GetRef<GlobalVar>(gv_ptr);
+    // retrieve the function from the module and split it
+    tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    std::vector<std::vector<int>> arg_partition;
+    // split the function into two functions, one for the library kernel and one for the rest
+    std::pair<tir::PrimFunc, Optional<tir::PrimFunc>> split_funcs =
+        tir::SplitFunctions(func, &arg_partition, patterns_, fcodegen_);
+    if (!split_funcs.second.defined()) {
+      // no need to split: the function is itself a library kernel
+      tvm::BaseFunc lib_func = CodegenWithLibrary(split_funcs.first.get(), gv->name_hint);
+      if (lib_func->IsInstance<tir::PrimFuncNode>()) return GetRef<Call>(op);
+      // Update the function in the module with the library kernel
+      ICHECK(lib_func->IsInstance<ExternFuncNode>());
+      builder_->UpdateFunction(gv, lib_func);
+      // emit the call to the library kernel
+      ObjectPtr<CallNode> new_call = make_object<CallNode>(*call.operator->());
+      new_call->op = this->call_dps_packed_;
+      new_call->args = {lib_func, call->args[1]};
+      return Call(new_call);
+    }
+    tir::PrimFunc func1 = tir::RenewDefs(split_funcs.first);
+    tir::PrimFunc func2 = tir::RenewDefs(split_funcs.second.value());
+    ICHECK(arg_partition.size() == 2);
+    // emit the first call to the library kernel
+    Array<Expr> args1;
+    for (int p : arg_partition[0]) {
+      args1.push_back(GetCallTIRArgs(call->args[1])[p]);
+    }
+    // replace the function in the module with the library kernel
+    tvm::BaseFunc lib_func = CodegenWithLibrary(func1.get(), gv->name_hint);
+    if (lib_func->IsInstance<tir::PrimFuncNode>()) return GetRef<Call>(op);
+    ICHECK(lib_func->IsInstance<ExternFuncNode>());
+    builder_->UpdateFunction(gv, lib_func);
+    tir::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back());
+    DataType dtype = intermediate_buffer->dtype;
+    Call call1(call_dps_packed_, {lib_func, Tuple(args1)}, call->attrs,
+               {TensorStructInfo(ShapeExpr(intermediate_buffer->shape), dtype)});
+    Var call_var1 = builder_->Emit(call1);
+    // emit the second call to the rest of the function
+    Array<Expr> args2;
+    args2.push_back(call_var1);
+    for (int p : arg_partition[1]) {
+      args2.push_back(GetCallTIRArgs(call->args[1])[p]);
+    }
+    GlobalVar gv2 = builder_->AddFunction(func2, "unfused_epilogue");
+    Call call2(call_tir_op_, {gv2, Tuple(args2)}, call->attrs, call->sinfo_args);
+    builder_->UpdateFunction(gv, WithoutAttr(func, "global_symbol"));
+    return call2;
+  }
+
+  const Op& call_dps_packed_ = Op::Get("relax.call_dps_packed");
+  tvm::IRModule mod_;
+  Array<TIRPattern> patterns_;
+  FCodegen fcodegen_;
+};
+
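Schematically, for a fused matmul-plus-epilogue, the rewrite above turns one call_tir into a call_dps_packed on the generated kernel followed by a call_tir on the leftover epilogue. A rough before/after sketch in pseudo-Relax (unfused_epilogue is the name used in the code above; the other names and shapes are illustrative):

    # Before:
    #   lv = relax.call_tir(fused_matmul_add, (A, B, bias), out_sinfo)
    #
    # After (the matmul prefix is matched and dispatched to the library):
    #   lv0 = relax.call_dps_packed(fused_matmul_add,  # now an ExternFunc with c_source
    #                               (A, B), intermediate_sinfo)
    #   lv  = relax.call_tir(unfused_epilogue, (lv0, bias), out_sinfo)
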
+namespace transform {
+Pass SplitCallTIRByPattern(Array<TIRPattern> patterns, FCodegen fcodegen) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
+      [=](IRModule m, PassContext pc) { return SplitMutator::Transform(m, patterns, fcodegen); };
+  return CreateModulePass(/*pass_function=*/pass_func,            //
+                          /*opt_level=*/0,                        //
+                          /*pass_name=*/"SplitCallTIRByPattern",  //
+                          /*required=*/{});
+}
+TVM_REGISTER_GLOBAL("relax.transform.SplitCallTIRByPattern").set_body_typed(SplitCallTIRByPattern);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_codegen_tir_cutlass.py b/tests/python/relax/test_codegen_tir_cutlass.py
new file mode 100644
index 0000000000..9c960ed355
--- /dev/null
+++ b/tests/python/relax/test_codegen_tir_cutlass.py
@@ -0,0 +1,709 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import tempfile
+
+import numpy as np
+from scipy.special import erf
+
+import tvm
+import tvm.testing
+from tvm import relax, runtime
+from tvm.target import Target
+from tvm.relax.vm_build import build as relax_build
+from tvm.script.ir_builder import relax as R
+from tvm.script.ir_builder import ir as I
+from tvm.script.ir_builder import tir as T
+from tvm.script.ir_builder import IRBuilder
+
+from tvm.relax.backend_tir import get_tir_pattern
+from tvm.relax.backend_tir.contrib.cutlass import cutlass_fcodegen, compile_options
+
+A_TYPE = "float16"
+B_TYPE = "float16"
+C_TYPE = "float16"
+
+target = Target("cuda")
+
+
+def f_run(rt_mod: runtime.Module, device: runtime.ndarray.Device, *input):
+    vm = relax.vm.VirtualMachine(rt_mod=rt_mod, device=device)
+    return vm["main"](*input)
+
+
+def build(mod):
+    mod = relax.transform.LegalizeOps()(mod)
+    mod = relax.transform.AnnotateTIROpPattern()(mod)
+    mod = relax.transform.FuseOps()(mod)
+    mod = relax.transform.FuseTIR()(mod)
+    mod = relax.transform.SplitCallTIRByPattern(get_tir_pattern(), cutlass_fcodegen())(mod)
+    mod = relax.transform.DeadCodeElimination()(mod)
+    print(mod.script())
+    f = tempfile.NamedTemporaryFile(suffix=".so", delete=True)
+    executable = relax_build(mod, target)
+
+    executable.mod.export_library(f.name, **compile_options(target))
+    rt_mod = runtime.load_module(f.name)
+    f.close()
+    return rt_mod
+
+
+def build_and_run_reference(mod, inputs_np):
+    mod = relax.transform.LegalizeOps()(mod)
+    dev = tvm.device("llvm", 0)
+    ex = relax.build(mod, "llvm")
+    vm = relax.VirtualMachine(ex, dev)
+    f = vm["main"]
+    inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
+    return f(*inputs).numpy()
+
+
+def constructGEMM(M, N, K):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    R.output(C)
+                (C,) = df.output_vars
+                R.func_ret_value(C)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_dense():
+    m, n, k = 128, 64, 256
+    executable = build(constructGEMM(m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(m, k).astype("float16")
+    B = np.random.randn(k, n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm)
+    np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
+
+
+def constructGEMM_bias(M, N, K):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((1, N), A_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    D = R.emit(R.add(C, bias))
+                    R.output(D)
+                (D,) = df.output_vars
+                R.func_ret_value(D)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+def constructGEMM_bias2(M, N, K):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((N,), A_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    D = R.emit(R.add(C, bias))
+                    R.output(D)
+                (D,) = df.output_vars
+                R.func_ret_value(D)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_dense_bias():
+    m, n, k = 128, 64, 256
+    executable = build(constructGEMM_bias(m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(m, k).astype("float16")
+    B = np.random.randn(k, n).astype("float16")
+    bias = np.random.randn(1, n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    bias_tvm = tvm.nd.array(bias, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2)
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_dense_bias2():
+    m, n, k = 128, 64, 256
+    executable = build(constructGEMM_bias2(m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(m, k).astype("float16")
+    B = np.random.randn(k, n).astype("float16")
+    bias = np.random.randn(n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    bias_tvm = tvm.nd.array(bias, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2)
+
+
+def constructGEMM_bias_relu(M, N, K):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((1, N), A_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    D = R.emit(R.add(C, bias))
+                    E = R.emit(R.nn.relu(D))
+                    R.output(E)
+                (E,) = df.output_vars
+                R.func_ret_value(E)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_dense_bias_relu():
+    m, n, k = 128, 64, 256
+    executable = build(constructGEMM_bias_relu(m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(m, k).astype("float16")
+    B = np.random.randn(k, n).astype("float16")
+    bias = np.random.randn(1, n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    bias_tvm = tvm.nd.array(bias, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+    np.testing.assert_allclose(result.numpy(), np.maximum(A @ B + bias, 0), rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM(batch, M, N, K):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    R.output(C)
+                (C,) = df.output_vars
+                R.func_ret_value(C)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_batch_dense():
+    b, m, n, k = 2, 128, 256, 64
+    executable = build(constructBatchGEMM(b, m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(b, m, k).astype("float16")
+    B = np.random.randn(k, n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm)
+    np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM2(batch, M, N, K):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((batch, K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    R.output(C)
+                (C,) = df.output_vars
+                R.func_ret_value(C)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_batch_dense2():
+    b, m, n, k = 2, 128, 256, 64
+    executable = build(constructBatchGEMM2(b, m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(b, m, k).astype("float16")
+    B = np.random.randn(b, k, n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm)
+    np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM_bias(batch, M, N, K):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((1, N), A_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    D = R.emit(R.add(C, bias))
+                    R.output(D)
+                (D,) = df.output_vars
+                R.func_ret_value(D)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_batch_dense_bias():
+    b, m, n, k = 2, 128, 256, 64
+    executable = build(constructBatchGEMM_bias(b, m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(b, m, k).astype("float16")
+    B = np.random.randn(k, n).astype("float16")
+    bias = np.random.randn(1, n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    bias_tvm = tvm.nd.array(bias, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM_bias2(batch, M, N, K):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((N,), A_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    D = R.emit(R.add(C, bias))
+                    R.output(D)
+                (D,) = df.output_vars
+                R.func_ret_value(D)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_batch_dense_bias2():
+    b, m, n, k = 2, 128, 256, 64
+    executable = build(constructBatchGEMM_bias2(b, m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(b, m, k).astype("float16")
+    B = np.random.randn(k, n).astype("float16")
+    bias = np.random.randn(n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    bias_tvm = tvm.nd.array(bias, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM_bias2_gelu(batch, M, N, K):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((N,), A_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    D = R.emit(R.add(C, bias))
+                    E = R.emit(R.nn.gelu(D))
+                    R.output(E)
+                (E,) = df.output_vars
+                R.func_ret_value(E)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_batch_dense_bias2_gelu():
+    b, m, n, k = 2, 128, 64, 256
+    executable = build(constructBatchGEMM_bias2_gelu(b, m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(b, m, k).astype("float16")
+    B = np.random.randn(k, n).astype("float16")
+    bias = np.random.randn(n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    bias_tvm = tvm.nd.array(bias, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+    C = A @ B + bias
+    O = 0.5 * C * (1 + erf(C / np.sqrt(2)))
+    np.testing.assert_allclose(result.numpy(), O, rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM_bias2_mul(batch, M, N, K):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((N,), A_TYPE)
+                )  # pylint: disable=invalid-name
+                residual = R.arg("residual", relax.TensorStructInfo((batch, M, N), A_TYPE))
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    D = R.emit(R.add(C, bias))
+                    E = R.emit(R.multiply(D, residual))
+                    R.output(E)
+                (E,) = df.output_vars
+                R.func_ret_value(E)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_batch_dense_bias2_mul():
+    b, m, n, k = 2, 128, 256, 64
+    executable = build(constructBatchGEMM_bias2_mul(b, m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(b, m, k).astype("float16")
+    B = np.random.randn(k, n).astype("float16")
+    bias = np.random.randn(n).astype("float16")
+    residual = np.random.randn(b, m, n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    bias_tvm = tvm.nd.array(bias, dev)
+    residual_tvm = tvm.nd.array(residual, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, residual_tvm)
+    np.testing.assert_allclose(result.numpy(), ((A @ B) + bias) * residual, rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM2_bias(batch, M, N, K):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((batch, K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((1, N), A_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    D = R.emit(R.add(C, bias))
+                    R.output(D)
+                (D,) = df.output_vars
+                R.func_ret_value(D)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_batch_dense2_bias():
+    b, m, n, k = 2, 128, 256, 64
+    executable = build(constructBatchGEMM2_bias(b, m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(b, m, k).astype("float16")
+    B = np.random.randn(b, k, n).astype("float16")
+    bias = np.random.randn(1, n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    bias_tvm = tvm.nd.array(bias, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2)
+
+
+def constructConv2D(N, C, H, W, KH, KW, O, strides, padding, dilation):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                x = R.arg(
+                    "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
+                )  # pylint: disable=invalid-name
+                w = R.arg(
+                    "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(
+                        R.nn.conv2d(
+                            x,
+                            w,
+                            strides=strides,
+                            padding=padding,
+                            dilation=dilation,
+                            groups=1,
+                            data_layout="NHWC",
+                            kernel_layout="OHWI",
+                            out_layout="NHWC",
+                            out_dtype=C_TYPE,
+                        )
+                    )
+                    R.output(C)
+                (C,) = df.output_vars
+                R.func_ret_value(C)
+    mod = ib.get()
+    return mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_conv2d():
+    n, c, h, w = 1, 3, 224, 224
+    kh, kw, o = 3, 3, 64
+    for strides in [(1, 1), (2, 2)]:
+        for padding in [(0, 0), (3, 3)]:
+            for dilation in [(1, 1), (4, 4)]:
+                mod = constructConv2D(n, c, h, w, kh, kw, o, strides, padding, dilation)
+                executable = build(mod)
+                dev = tvm.cuda()
+                np.random.seed(0)
+                A = np.random.randn(n, h, w, c).astype("float16")
+                B = np.random.randn(o, kh, kw, c).astype("float16")
+                A_tvm = tvm.nd.array(A, dev)
+                B_tvm = tvm.nd.array(B, dev)
+                result = f_run(executable, dev, A_tvm, B_tvm)
+                result_ref = build_and_run_reference(mod, [A, B])
+                np.testing.assert_allclose(
+                    result.numpy(),
+                    result_ref,
+                    rtol=5e-2,
+                    atol=5e-2,
+                )
+
+
+def constructConv2D_bias(N, C, H, W, KH, KW, O, strides, padding, dilation):
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                x = R.arg(
+                    "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
+                )  # pylint: disable=invalid-name
+                w = R.arg(
+                    "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(
+                        R.nn.conv2d(
+                            x,
+                            w,
+                            strides=strides,
+                            padding=padding,
+                            dilation=dilation,
+                            groups=1,
+                            data_layout="NHWC",
+                            kernel_layout="OHWI",
+                            out_layout="NHWC",
+                            out_dtype=C_TYPE,
+                        )
+                    )
+                    D = R.emit(R.add(C, bias))
+                    R.output(D)
+                (D,) = df.output_vars
+                R.func_ret_value(D)
+    mod = ib.get()
+    return mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_conv2d_bias():
+    c, h, w = 3, 224, 224
+    kh, kw, o = 3, 3, 64
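+    # Same sweep as test_cutlass_conv2d, additionally over batch size n.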
+    for n in [1, 2]:
+        for strides in [(1, 1), (2, 2)]:
+            for padding in [(0, 0), (3, 3)]:
+                for dilation in [(1, 1), (4, 4)]:
+                    mod = constructConv2D_bias(
+                        n, c, h, w, kh, kw, o, strides, padding, dilation
+                    )
+                    executable = build(mod)
+                    dev = tvm.cuda()
+                    np.random.seed(0)
+                    A = np.random.randn(n, h, w, c).astype("float16")
+                    B = np.random.randn(o, kh, kw, c).astype("float16")
+                    bias = np.random.randn(1, 1, 1, o).astype("float16")
+                    A_tvm = tvm.nd.array(A, dev)
+                    B_tvm = tvm.nd.array(B, dev)
+                    bias_tvm = tvm.nd.array(bias, dev)
+                    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+                    result_ref = build_and_run_reference(mod, [A, B, bias])
+                    np.testing.assert_allclose(
+                        result.numpy(),
+                        result_ref,
+                        rtol=5e-2,
+                        atol=5e-2,
+                    )
+
+
+def constructConv2D_bias_add(N, C, H, W, KH, KW, O, OH, OW, strides, padding, dilation):
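+    """Construct a Relax IRModule computing conv2d + bias add + residual add."""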
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import ir as I
+    from tvm.script.ir_builder import relax as R
+    from tvm.script.ir_builder import tir as T
+
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                x = R.arg(
+                    "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
+                )  # pylint: disable=invalid-name
+                w = R.arg(
+                    "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE)
+                )  # pylint: disable=invalid-name
+                res = R.arg(
+                    "res", relax.TensorStructInfo((N, OH, OW, O), A_TYPE)
+                )  # pylint: disable=invalid-name
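+                # res is the residual input; its shape (N, OH, OW, O) matches the conv2d output.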
+                with R.dataflow() as df:
+                    C = R.emit(
+                        R.nn.conv2d(
+                            x,
+                            w,
+                            strides=strides,
+                            padding=padding,
+                            dilation=dilation,
+                            groups=1,
+                            data_layout="NHWC",
+                            kernel_layout="OHWI",
+                            out_layout="NHWC",
+                            out_dtype=C_TYPE,
+                        )
+                    )
+                    D = R.emit(R.add(C, bias))
+                    E = R.emit(R.add(D, res))
+                    R.output(E)
+                (E,) = df.output_vars
+                R.func_ret_value(E)
+    mod = ib.get()
+    return mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_conv2d_bias_add():
+    n, c, h, w = 2, 3, 224, 224
+    kh, kw, o = 3, 3, 64
+    for strides in [(1, 1), (2, 2)]:
+        for padding in [(0, 0), (3, 3)]:
+            for dilation in [(1, 1), (4, 4)]:
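+                # Standard convolution output extent: (in + 2 * pad - dilation * (k - 1) - 1) // stride + 1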
+                oh = (h + 2 * padding[0] - dilation[0] * (kh - 1) - 1) // strides[0] + 1
+                ow = (w + 2 * padding[1] - dilation[1] * (kw - 1) - 1) // strides[1] + 1
+                mod = constructConv2D_bias_add(
+                    n, c, h, w, kh, kw, o, oh, ow, strides, padding, dilation
+                )
+                executable = build(mod)
+                dev = tvm.cuda()
+                np.random.seed(0)
+                A = np.random.randn(n, h, w, c).astype("float16")
+                B = np.random.randn(o, kh, kw, c).astype("float16")
+                bias = np.random.randn(1, 1, 1, o).astype("float16")
+                res = np.random.randn(n, oh, ow, o).astype("float16")
+                A_tvm = tvm.nd.array(A, dev)
+                B_tvm = tvm.nd.array(B, dev)
+                bias_tvm = tvm.nd.array(bias, dev)
+                res_tvm = tvm.nd.array(res, dev)
+                result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, res_tvm)
+                result_ref = build_and_run_reference(mod, [A, B, bias, res])
+                np.testing.assert_allclose(
+                    result.numpy(),
+                    result_ref,
+                    rtol=5e-2,
+                    atol=5e-2,
+                )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
