This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 517d0457d3 [Unity][Transform] SplitCallTIRByPattern and CUTLASS
backend (#14274)
517d0457d3 is described below
commit 517d0457d35a067dbdab57496a678bc076cd76c4
Author: Bohan Hou <[email protected]>
AuthorDate: Fri Mar 24 13:39:20 2023 -0400
[Unity][Transform] SplitCallTIRByPattern and CUTLASS backend (#14274)
Currently, the BYOC system is based on op-level pattern matching. This PR
intends to provide preliminary support for TIR-level pattern matching based on
backend registration and dispatching.
For now, it simply matches the first set of for loops in a PrimFunc.
Co-authored-by: Hongyi Jin (@jinhongyii)
---
include/tvm/relax/tir_pattern.h | 75 +++
python/tvm/contrib/cutlass/gemm_operation.py | 3 +-
python/tvm/relax/backend_tir/__init__.py | 20 +
python/tvm/relax/backend_tir/contrib/__init__.py | 20 +
python/tvm/relax/backend_tir/contrib/cutlass.py | 720 +++++++++++++++++++++
python/tvm/relax/backend_tir/pattern.py | 576 +++++++++++++++++
python/tvm/relax/transform/transform.py | 19 +
src/relax/backend/vm/codegen_vm.cc | 11 +
src/relax/ir/tir_pattern.cc | 37 ++
src/relax/transform/split_call_tir_by_pattern.cc | 782 +++++++++++++++++++++++
tests/python/relax/test_codegen_tir_cutlass.py | 709 ++++++++++++++++++++
11 files changed, 2971 insertions(+), 1 deletion(-)
diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h
new file mode 100644
index 0000000000..02634dcbbf
--- /dev/null
+++ b/include/tvm/relax/tir_pattern.h
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir_pattern.h
+ * \brief Data Structure of TIR Pattern used for matching.
+ */
+
+#ifndef TVM_RELAX_TIR_PATTERN_H_
+#define TVM_RELAX_TIR_PATTERN_H_
+
+#include <tvm/tir/function.h>
+
+namespace tvm {
+namespace relax {
+
+using TIRPattern = tir::PrimFunc;
+
+/*!
+ * \brief The match result of a TIR pattern.
+ *
+ * Holds the pattern that matched together with the symbolic-variable values
+ * and the buffers recorded for that match. All three fields are exposed to
+ * reflection through VisitAttrs under their own names.
+ */
+class MatchResultNode : public Object {
+ public:
+  /*! \brief The matched tir pattern. */
+  TIRPattern pattern;
+  /*! \brief The evaluated values of symbolic vars. */
+  Array<PrimExpr> symbol_values;
+  /*! \brief The matched buffers of input and output. */
+  Array<tir::Buffer> matched_buffers;
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("pattern", &pattern);
+    v->Visit("symbol_values", &symbol_values);
+    v->Visit("matched_buffers", &matched_buffers);
+  }
+  static constexpr const char* _type_key = "relax.MatchResult";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MatchResultNode, Object);
+};
+
+/*!
+ * \brief Managed reference to MatchResultNode.
+ *
+ * The node's type key is "relax.MatchResult"; the Python class of the same
+ * name is registered under that key.
+ */
+class MatchResult : public ObjectRef {
+ public:
+  /*!
+   * \brief Constructor
+   * \param pattern The matched tir pattern.
+   * \param symbol_values The evaluated values of symbolic vars.
+   * \param matched_buffers The matched buffers of input and output.
+   */
+  TVM_DLL explicit MatchResult(TIRPattern pattern, Array<PrimExpr> symbol_values,
+                               Array<tir::Buffer> matched_buffers);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode)
+};
+
+/*!
+ * \brief Signature of a backend codegen callback.
+ *
+ * The callback receives an array of match results and returns an array of
+ * objects. The Python CUTLASS callback added in this change returns
+ * [generated code, number of match results consumed, argument buffers] on
+ * success and ["", 0] when no CUTLASS pattern applies.
+ */
+using FCodegen = runtime::TypedPackedFunc<Array<ObjectRef>(Array<MatchResult> match_results)>;
+} // namespace relax
+} // namespace tvm
+#endif // TVM_RELAX_TIR_PATTERN_H_
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py
b/python/tvm/contrib/cutlass/gemm_operation.py
index eb9f92dad3..b820ead016 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -369,7 +369,8 @@ def instantiate_gemm_template(attrs):
{
"bias_decl": "void* ptr_bias = (void*)(${bias_arg}->data);\n",
"ptr_c": "ptr_bias",
- "c_stride": "${bias_arg}->ndim == 1 ? 0 : " + attrs["ldc"],
+ "c_stride": "(${bias_arg}->ndim == 1 || ${bias_arg}->shape[0]
== 1) ? 0 : "
+ + attrs["ldc"],
}
)
else:
diff --git a/python/tvm/relax/backend_tir/__init__.py
b/python/tvm/relax/backend_tir/__init__.py
new file mode 100644
index 0000000000..eeb8fe438f
--- /dev/null
+++ b/python/tvm/relax/backend_tir/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relax backends, tir based"""
+
+from . import contrib
+from .pattern import get_tir_pattern
diff --git a/python/tvm/relax/backend_tir/contrib/__init__.py
b/python/tvm/relax/backend_tir/contrib/__init__.py
new file mode 100644
index 0000000000..9274f22374
--- /dev/null
+++ b/python/tvm/relax/backend_tir/contrib/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""External backend codegen modules for Relax, tir based."""
+
+from .cutlass import cutlass_fcodegen
diff --git a/python/tvm/relax/backend_tir/contrib/cutlass.py
b/python/tvm/relax/backend_tir/contrib/cutlass.py
new file mode 100644
index 0000000000..0dbe31c468
--- /dev/null
+++ b/python/tvm/relax/backend_tir/contrib/cutlass.py
@@ -0,0 +1,720 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:
disable=invalid-name,comparison-with-callable,unused-variable,missing-function-docstring
+"""codegen for cutlass"""
+import operator
+from functools import reduce
+from typing import List, Dict, Any
+
+from tvm.contrib.cutlass.build import _get_cutlass_path,
_get_cutlass_compile_options
+from tvm.contrib.nvcc import get_target_compute_version
+from tvm.contrib.cutlass.library import LayoutType, ConvKind
+from tvm.contrib.cutlass.gen_tensor_op import instantiate_template
+from tvm.contrib.cutlass.gen_gemm import CutlassGemmProfiler
+from tvm.contrib.cutlass.gen_conv2d import CutlassConv2DProfiler
+from ..pattern import (
+ MatchResult,
+ matmul_rrr_fp16,
+ bias_row_2d_fp16,
+ bias_row_1d_fp16,
+ batch_bias_row_2d_fp16,
+ batch_bias_row_1d_fp16,
+ relu_fp16,
+ erf_3d_fp32,
+ batch_matmul_rrr_2d_fp16,
+ batch_matmul_rrr_3d_fp16,
+ conv2d_nhwc_fp16,
+ padding_2d_nhwc_fp16,
+ copy_4d_fp16,
+ bias_add_nhwc_2d_fp16,
+ bias_add_nhwc_1d_fp16,
+ elem_add_4d_fp16,
+ elem_mul_3d_fp16,
+ scalar_add_3d_fp16,
+ scalar_mul_3d_fp16,
+ cast_3d_fp16,
+ cast_3d_fp32,
+)
+
+#### helper functions ####
+# list representing the anchor ops
+# in the future more layouts/dtypes can be supported
+MATMUL_LIST = [matmul_rrr_fp16]
+MATMUL_BIAS_LIST = [bias_row_2d_fp16, bias_row_1d_fp16]
+BATCH_MATMUL_LIST = [batch_matmul_rrr_2d_fp16, batch_matmul_rrr_3d_fp16]
+BATCH_MATMUL_BIAS_LIST = [batch_bias_row_2d_fp16, batch_bias_row_1d_fp16]
+CONV2D_LIST = [conv2d_nhwc_fp16]
+
+# attributes for anchor ops used in code generation
+OP_PATTERN_ATTR_LIST = {
+ matmul_rrr_fp16: {
+ "arg0_dtype": "float16",
+ "arg1_dtype": "float16",
+ "ret_dtype": "float16",
+ },
+ batch_matmul_rrr_2d_fp16: {
+ "arg0_dtype": "float16",
+ "arg1_dtype": "float16",
+ "ret_dtype": "float16",
+ },
+ batch_matmul_rrr_3d_fp16: {
+ "arg0_dtype": "float16",
+ "arg1_dtype": "float16",
+ "ret_dtype": "float16",
+ },
+ conv2d_nhwc_fp16: {
+ "arg0_dtype": "float16",
+ "arg1_dtype": "float16",
+ "ret_dtype": "float16",
+ # in the future we can add layout here
+ },
+}
+
+
+def _get_cutlass_code(attr):
+    """Generate CUTLASS source for ``attr`` by dispatching on ``attr["op_type"]``.
+
+    Op types starting with "cutlass.matmul" go to ``cutlass_codegen_gemm`` and
+    those starting with "cutlass.conv2d" go to ``cutlass_codegen_conv2d``.
+    Any other op type raises ``ValueError``.
+    """
+    op_type = attr["op_type"]
+    # Checked in order; the first matching prefix wins.
+    dispatch = (
+        ("cutlass.matmul", cutlass_codegen_gemm),
+        ("cutlass.conv2d", cutlass_codegen_conv2d),
+    )
+    for prefix, codegen in dispatch:
+        if op_type.startswith(prefix):
+            return codegen(attr)
+    raise ValueError("op not supported")
+
+
+def _final_code(code, headers, func_args):
+    """Wrap a kernel body into a complete source string.
+
+    Parameters
+    ----------
+    code : str
+        Body placed verbatim between the braces of ``_cutlass_kernel``.
+    headers : iterable of str
+        Extra header names, each emitted as ``#include <name>`` after the
+        fixed TVM/CUDA/CUTLASS includes.
+    func_args : iterable of str
+        Names of the ``NDArray`` parameters; ``NDArray out0`` is always
+        appended as the last parameter.
+
+    Returns
+    -------
+    str
+        The assembled source. The literal ``{global_symbol}`` placeholder is
+        left in the export macro line for a later substitution step.
+    """
+    fixed_includes = (
+        "tvm/runtime/c_runtime_api.h",
+        "tvm/runtime/packed_func.h",
+        "dlpack/dlpack.h",
+        "cuda_fp16.h",
+        "cutlass/cutlass.h",
+        "cutlass/coord.h",
+        "cutlass/tensor_ref.h",
+        "cutlass/util/host_tensor.h",
+    )
+    pieces = ["#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>\n"]
+    pieces.extend("#include <" + name + ">\n" for name in fixed_includes)
+    pieces.extend("#include <" + header + ">\n" for header in headers)
+    pieces.append("namespace {\n")
+    pieces.append("using namespace tvm;\n")
+    pieces.append("using namespace tvm::runtime;\n")
+    params = "".join("NDArray " + arg + ", " for arg in func_args)
+    # No newline after the opening brace: the body is spliced in directly.
+    pieces.append("void _cutlass_kernel(" + params + "NDArray out0) {")
+    pieces.append(code)
+    pieces.append("}\n")
+    pieces.append("} // namespace\n")
+    pieces.append("TVM_DLL_EXPORT_TYPED_FUNC({global_symbol}, _cutlass_kernel);\n")
+    return "".join(pieces)
+
+
+#### cutlass patterns ####
+def matmul_bias_relu(match_results, attr, get_code=True):
+    """Match matmul + bias + relu over the first three match results.
+
+    Requires ``matmul_bias`` to accept the first two results, the third to be
+    ``relu_fp16`` with the same two symbol values as the bias step, and the
+    bias output buffer to be the relu input buffer.
+
+    Returns ``[code, 3, args]`` when ``get_code`` is true, the updated ``attr``
+    otherwise, or ``None`` when the sequence does not match.
+    """
+    if len(match_results) < 3:
+        return None
+    attr = matmul_bias(match_results[:2], attr, get_code=False)
+    if attr is None:
+        return None
+    bias_match, relu_match = match_results[1], match_results[2]
+    if relu_match.pattern != relu_fp16:
+        return None
+    bias_m, bias_n = bias_match.symbol_values
+    relu_m, relu_n = relu_match.symbol_values
+    _, _, bias_out = bias_match.matched_buffers
+    relu_in, _ = relu_match.matched_buffers
+    if not (bias_m == relu_m and bias_n == relu_n and bias_out == relu_in):
+        return None
+    attr["op_type"] = "cutlass.matmul_bias_relu"
+    if get_code:
+        return [_get_cutlass_code(attr=attr), 3, attr["args"]]
+    return attr
+
+
+def matmul_bias(match_results, attr, get_code=True):
+    """Match matmul + bias over the first two match results.
+
+    Requires ``matmul`` to accept the first result, the second to be one of
+    ``MATMUL_BIAS_LIST`` with matching M and N, and the matmul output buffer
+    to be the first buffer of the bias step. On success the second buffer of
+    the bias step is appended to ``attr["args"]`` and ``bias_arg_idx`` is set
+    to 2.
+
+    Returns ``[code, 2, args]`` when ``get_code`` is true, the updated ``attr``
+    otherwise, or ``None`` when the sequence does not match.
+    """
+    if len(match_results) < 2:
+        return None
+    attr = matmul(match_results[:1], attr, get_code=False)
+    if attr is None:
+        return None
+    gemm_match, bias_match = match_results[0], match_results[1]
+    if bias_match.pattern not in MATMUL_BIAS_LIST:
+        return None
+    gemm_m, gemm_n, _ = gemm_match.symbol_values
+    bias_m, bias_n = bias_match.symbol_values
+    _, _, gemm_out = gemm_match.matched_buffers
+    bias_in, bias_vec, _ = bias_match.matched_buffers
+    if not (gemm_m == bias_m and gemm_n == bias_n and gemm_out == bias_in):
+        return None
+    attr["op_type"] = "cutlass.matmul_bias"
+    attr["bias_arg_idx"] = 2
+    attr["args"].append(bias_vec)
+    if get_code:
+        return [_get_cutlass_code(attr=attr), 2, attr["args"]]
+    return attr
+
+
+def matmul(match_results, attr, get_code=True):
+    """Match a plain matmul: the first match result must be in ``MATMUL_LIST``.
+
+    Returns ``[code, 1, args]`` when ``get_code`` is true, the updated ``attr``
+    otherwise, or ``None`` when there is no match.
+    """
+    if len(match_results) < 1:
+        return None
+    if match_results[0].pattern not in MATMUL_LIST:
+        return None
+    attr["op_type"] = "cutlass.matmul"
+    if get_code:
+        return [_get_cutlass_code(attr=attr), 1, attr["args"]]
+    return attr
+
+
+def batch_matmul_bias_gelu(match_results, attr, get_code=True):
+    """Match batch_matmul + bias followed by a seven-step gelu expansion.
+
+    The nine match results must be: batch matmul, batch bias, scalar mul,
+    cast to fp32, erf, cast to fp16, scalar mul, scalar add, elementwise mul.
+    Consecutive steps 1..8 must agree on their first three symbol values, the
+    buffers must chain as checked below, and the scalar constants must be
+    sqrt(0.5), 0.5 and 0.5 within 1e-5.
+
+    Returns ``[code, 9, args]`` when ``get_code`` is true, the updated ``attr``
+    otherwise, or ``None`` when the sequence does not match.
+    """
+    if len(match_results) < 9:
+        return None
+    # batch_matmul, batch_bias
+    attr = batch_matmul_bias(match_results[:2], attr, get_code=False)
+    if attr is None:
+        return None
+
+    # Patterns expected at positions 2..8.
+    expected_tail = (
+        scalar_mul_3d_fp16,
+        cast_3d_fp32,
+        erf_3d_fp32,
+        cast_3d_fp16,
+        scalar_mul_3d_fp16,
+        scalar_add_3d_fp16,
+        elem_mul_3d_fp16,
+    )
+    for position, expected in enumerate(expected_tail, start=2):
+        if match_results[position].pattern != expected:
+            return None
+
+    def shape_match_3d(shape1, shape2):
+        if len(shape1) < 3 or len(shape2) < 3:
+            return False
+        return shape1[0] == shape2[0] and shape1[1] == shape2[1] and shape1[2] == shape2[2]
+
+    for step in range(1, 8):
+        current = match_results[step].symbol_values
+        following = match_results[step + 1].symbol_values
+        if not shape_match_3d(current, following):
+            return None
+
+    # (producer, consumer, consumer input slot): the producer's last buffer
+    # must be the consumer's buffer at that slot. The final elementwise mul
+    # reads both the bias output (slot 0) and the scalar-add output (slot 1).
+    buffer_links = (
+        (1, 2, 0),
+        (2, 3, 0),
+        (3, 4, 0),
+        (4, 5, 0),
+        (5, 6, 0),
+        (6, 7, 0),
+        (1, 8, 0),
+        (7, 8, 1),
+    )
+    for producer, consumer, slot in buffer_links:
+        produced = match_results[producer].matched_buffers[-1]
+        if not produced == match_results[consumer].matched_buffers[slot]:
+            return None
+
+    # (position, constant): the last symbol value of that step must equal it.
+    scalar_constants = ((2, 0.5**0.5), (6, 0.5), (7, 0.5))
+    for position, constant in scalar_constants:
+        if abs(float(match_results[position].symbol_values[-1] - constant)) > 1e-5:
+            return None
+
+    attr["op_type"] = "cutlass.matmul_bias_gelu"
+    if get_code:
+        return [_get_cutlass_code(attr=attr), 9, attr["args"]]
+    return attr
+
+
+def batch_matmul_bias_residual_mul(match_results, attr, get_code=True):
+    """Match batch_matmul + bias + elementwise multiply (three match results).
+
+    Requires ``batch_matmul_bias`` to accept the first two results, the third
+    to be ``elem_mul_3d_fp16`` with the same three symbol values as the bias
+    step, and the bias output buffer to be the first multiply input.
+
+    Returns ``[code, 3, args]`` when ``get_code`` is true, the updated ``attr``
+    otherwise, or ``None`` when the sequence does not match.
+
+    NOTE(review): ``residual_arg_idx`` is set to 3, but unlike
+    ``conv2d_bias_residual_add`` no residual buffer is appended to
+    ``attr["args"]`` here -- confirm the residual operand reaches the kernel.
+    """
+    if len(match_results) < 3:
+        return None
+    # batch_matmul, batch_bias
+    attr = batch_matmul_bias(match_results[:2], attr, get_code=False)
+    if attr is None:
+        return None
+    bias_match, mul_match = match_results[1], match_results[2]
+    if mul_match.pattern != elem_mul_3d_fp16:
+        return None
+    bias_b, bias_m, bias_n = bias_match.symbol_values
+    mul_b, mul_m, mul_n = mul_match.symbol_values
+    _, _, bias_out = bias_match.matched_buffers
+    mul_lhs, _, _ = mul_match.matched_buffers
+    if not (bias_b == mul_b and bias_m == mul_m and bias_n == mul_n and bias_out == mul_lhs):
+        return None
+    attr["op_type"] = "cutlass.matmul_bias_residual_multiply"
+    attr["residual_arg_idx"] = 3
+    if get_code:
+        return [_get_cutlass_code(attr=attr), 3, attr["args"]]
+    return attr
+
+
+def batch_matmul_bias(match_results, attr, get_code=True):
+    """Match batch_matmul + bias over the first two match results.
+
+    Requires ``batch_matmul`` to accept the first result, the second to be one
+    of ``BATCH_MATMUL_BIAS_LIST`` with matching batch, M and N, and the matmul
+    output buffer to be the first buffer of the bias step. On success the
+    second buffer of the bias step is appended to ``attr["args"]`` and
+    ``bias_arg_idx`` is set to 2.
+
+    Returns ``[code, 2, args]`` when ``get_code`` is true, the updated ``attr``
+    otherwise, or ``None`` when the sequence does not match.
+    """
+    if len(match_results) < 2:
+        return None
+    attr = batch_matmul(match_results[:1], attr, get_code=False)
+    if attr is None:
+        return None
+    gemm_match, bias_match = match_results[0], match_results[1]
+    if bias_match.pattern not in BATCH_MATMUL_BIAS_LIST:
+        return None
+    gemm_b, gemm_m, gemm_n, _ = gemm_match.symbol_values
+    bias_b, bias_m, bias_n = bias_match.symbol_values
+    _, _, gemm_out = gemm_match.matched_buffers
+    bias_in, bias_vec, _ = bias_match.matched_buffers
+    if not (gemm_b == bias_b and gemm_m == bias_m and gemm_n == bias_n and gemm_out == bias_in):
+        return None
+    attr["op_type"] = "cutlass.matmul_bias"
+    attr["bias_arg_idx"] = 2
+    attr["args"].append(bias_vec)
+    if get_code:
+        return [_get_cutlass_code(attr=attr), 2, attr["args"]]
+    return attr
+
+
+def batch_matmul(match_results, attr, get_code=True):
+    """Match a plain batch matmul: the first result must be in ``BATCH_MATMUL_LIST``.
+
+    The op type is set to "cutlass.matmul", the same as the non-batched case.
+
+    Returns ``[code, 1, args]`` when ``get_code`` is true, the updated ``attr``
+    otherwise, or ``None`` when there is no match.
+    """
+    if len(match_results) < 1:
+        return None
+    if match_results[0].pattern not in BATCH_MATMUL_LIST:
+        return None
+    attr["op_type"] = "cutlass.matmul"
+    if get_code:
+        return [_get_cutlass_code(attr=attr), 1, attr["args"]]
+    return attr
+
+
+def conv2d_bias_residual_add(match_results, attr, get_code=True):
+    """Match (pad|copy) + conv2d + bias + elementwise add (four match results).
+
+    Requires ``conv2d_bias`` to accept the first three results, the fourth to
+    be ``elem_add_4d_fp16`` with the same four symbol values as the bias step,
+    and the bias output to be one of the two add inputs. The other add input
+    is appended to ``attr["args"]`` as the residual.
+
+    Returns ``[code, 4, args]`` when ``get_code`` is true, the updated ``attr``
+    otherwise, or ``None`` when the sequence does not match.
+    """
+    if len(match_results) < 4:
+        return None
+    attr = conv2d_bias(match_results[:3], attr, get_code=False)
+    if attr is None:
+        return None
+    bias_match, add_match = match_results[2], match_results[3]
+    if add_match.pattern != elem_add_4d_fp16:
+        return None
+    bias_n, bias_h, bias_w, bias_c = bias_match.symbol_values
+    _, _, bias_out = bias_match.matched_buffers
+    add_n, add_h, add_w, add_c = add_match.symbol_values
+    add_lhs, add_rhs, _ = add_match.matched_buffers
+    if not (
+        bias_n == add_n
+        and bias_h == add_h
+        and bias_w == add_w
+        and bias_c == add_c
+        and bias_out in [add_lhs, add_rhs]
+    ):
+        return None
+    attr["op_type"] = "cutlass.conv2d_bias_residual_add"
+    attr["residual_arg_idx"] = 3
+    # The residual is whichever add operand is not the bias output.
+    attr["args"].append(add_rhs if bias_out == add_lhs else add_lhs)
+    if get_code:
+        return [_get_cutlass_code(attr=attr), 4, attr["args"]]
+    return attr
+
+
+def conv2d_bias(match_results, attr, get_code=True):
+    """Match (pad|copy) + conv2d + bias add over the first three match results.
+
+    Requires ``conv2d`` to accept the first two results and the third to be
+    ``bias_add_nhwc_2d_fp16`` or ``bias_add_nhwc_1d_fp16``. The bias N/H/W/C
+    must equal the conv N / H / W / O symbol values, and the conv output
+    buffer must be the first buffer of the bias step. On success the second
+    buffer of the bias step is appended to ``attr["args"]`` and
+    ``bias_arg_idx`` is set to 2.
+
+    Returns ``[code, 3, args]`` when ``get_code`` is true, the updated ``attr``
+    otherwise, or ``None`` when the sequence does not match.
+    """
+    if len(match_results) < 3:
+        return None
+    attr = conv2d(match_results[:2], attr, get_code=False)
+    if attr is None:
+        return None
+    conv_match, bias_match = match_results[1], match_results[2]
+    if bias_match.pattern not in [bias_add_nhwc_2d_fp16, bias_add_nhwc_1d_fp16]:
+        return None
+    # Only the first seven conv symbols are needed here.
+    conv_n, _, _, conv_h, conv_w, _, conv_o = conv_match.symbol_values[:7]
+    _, _, conv_out = conv_match.matched_buffers
+    bias_n, bias_h, bias_w, bias_c = bias_match.symbol_values
+    bias_in, bias_vec, _ = bias_match.matched_buffers
+    if not (
+        bias_n == conv_n
+        and bias_h == conv_h
+        and bias_w == conv_w
+        and bias_c == conv_o
+        and conv_out == bias_in
+    ):
+        return None
+    attr["op_type"] = "cutlass.conv2d_bias"
+    attr["bias_arg_idx"] = 2
+    attr["args"].append(bias_vec)
+    if get_code:
+        return [_get_cutlass_code(attr=attr), 3, attr["args"]]
+    return attr
+
+
+def conv2d(match_results, attr, get_code=True):
+    """Match a padding (or copy) step followed by an NHWC conv2d.
+
+    The first match result must be ``padding_2d_nhwc_fp16`` or
+    ``copy_4d_fp16`` and the second ``conv2d_nhwc_fp16``. A copy is treated as
+    padding with zero low padding and padded extent equal to the input
+    extent. The padded buffer must be the conv input, the sizes must agree,
+    and the padding bounds must be self-consistent. On success ``padding``,
+    ``strides`` and ``dilation`` are stored in ``attr``.
+
+    Returns ``[code, 2, args]`` when ``get_code`` is true, the updated ``attr``
+    otherwise, or ``None`` when the sequence does not match.
+    """
+    if len(match_results) < 2:
+        return None
+    pad_match, conv_match = match_results[0], match_results[1]
+    if not (
+        pad_match.pattern in [padding_2d_nhwc_fp16, copy_4d_fp16]
+        and conv_match.pattern == conv2d_nhwc_fp16
+    ):
+        return None
+
+    if pad_match.pattern == padding_2d_nhwc_fp16:
+        (
+            batch,
+            in_h,
+            in_w,
+            in_c,
+            padded_h,
+            padded_w,
+            lo_h,
+            lo_w,
+            hi_h,
+            hi_w,
+        ) = pad_match.symbol_values
+    else:
+        batch, in_h, in_w, in_c = pad_match.symbol_values
+        padded_h = hi_h = in_h
+        padded_w = hi_w = in_w
+        lo_h = lo_w = 0
+
+    (
+        conv_batch,
+        conv_padded_h,
+        conv_padded_w,
+        _,
+        _,
+        conv_in_c,
+        _,
+        _,
+        _,
+        stride_h,
+        stride_w,
+        dilation_h,
+        dilation_w,
+    ) = conv_match.symbol_values
+    _, padded_buf = pad_match.matched_buffers
+    conv_in_buf, _, _ = conv_match.matched_buffers
+
+    # The conv must read exactly what the padding step produced.
+    if not (
+        batch == conv_batch
+        and padded_h == conv_padded_h
+        and padded_w == conv_padded_w
+        and in_c == conv_in_c
+        and padded_buf == conv_in_buf
+    ):
+        return None
+    # The low/high padding bounds must be consistent with the extents.
+    if not (
+        lo_h == padded_h - hi_h
+        and lo_w == padded_w - hi_w
+        and lo_h + in_h == hi_h
+        and lo_w + in_w == hi_w
+    ):
+        return None
+
+    attr["padding"] = (lo_h, lo_w)
+    attr["strides"] = (stride_h, stride_w)
+    attr["dilation"] = (dilation_h, dilation_w)
+    attr["op_type"] = "cutlass.conv2d"
+    if get_code:
+        return [_get_cutlass_code(attr=attr), 2, attr["args"]]
+    return attr
+
+
+### cutlass codegen functions ###
+def compile_options(target, threads=-1, use_fast_math=False):
+    """Build the CUTLASS compile options for ``target`` without the ``-c`` flag.
+
+    The compute version string of ``target`` has its dots removed and is
+    converted to an int before being passed, with ``threads`` and
+    ``use_fast_math``, to ``_get_cutlass_compile_options``. ``"-c"`` is then
+    removed from the returned ``"options"`` list.
+    """
+    version_text = get_target_compute_version(target)
+    compute_version = int(version_text.replace(".", ""))
+    kwargs = _get_cutlass_compile_options(compute_version, threads, use_fast_math)
+    kwargs["options"].remove("-c")
+    return kwargs
+
+
+def cutlass_fcodegen(sm=80, bin_dir="./bin"):
+    """Create the CUTLASS codegen callback.
+
+    Parameters
+    ----------
+    sm : int
+        Forwarded as the first argument to both CUTLASS profilers.
+    bin_dir : str
+        Forwarded as the last argument to both CUTLASS profilers.
+
+    Returns
+    -------
+    callable
+        A function taking a list of ``MatchResult``. It returns
+        ``[code, num_matched, args]`` for the first CUTLASS pattern that
+        accepts the match results (longest patterns are tried first), or
+        ``["", 0]`` when none applies.
+    """
+    gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), bin_dir)
+    conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), bin_dir)
+
+    def _make_attr(anchor_pattern, lhs, rhs, out) -> Dict[Any, Any]:
+        """Build a fresh attribute dict for the anchor op and its buffers."""
+        # Copy the template: OP_PATTERN_ATTR_LIST is module-level state, and
+        # the pattern functions and codegen add and delete keys on `attr`.
+        # Writing into the shared entry would leak keys such as
+        # "bias_arg_idx" into later, unrelated invocations.
+        attr: Dict[Any, Any] = dict(OP_PATTERN_ATTR_LIST[anchor_pattern])
+        attr["args"] = [lhs, rhs]
+        attr["arg0_shape"] = lhs.shape
+        attr["arg1_shape"] = rhs.shape
+        attr["ret_shape"] = out.shape
+        attr["lhs_arg_idx"] = 0
+        attr["rhs_arg_idx"] = 1
+        # add profiler into attr
+        attr["gemm_profiler"] = gemm_profiler
+        attr["conv2d_profiler"] = conv2d_profiler
+        return attr
+
+    def cutlass_codegen_with_match_results(match_results: List[MatchResult]):
+        """generate cutlass code with match results"""
+        assert len(match_results) > 0
+
+        # add shape into attr
+        first = match_results[0]
+        if first.pattern in MATMUL_LIST or first.pattern in BATCH_MATMUL_LIST:
+            lhs, rhs, out = first.matched_buffers
+            attr = _make_attr(first.pattern, lhs, rhs, out)
+        elif len(match_results) >= 2 and match_results[1].pattern in CONV2D_LIST:
+            # The conv anchor is the second result, so at least two results
+            # are needed before indexing it. The kernel input is the buffer
+            # fed to the preceding pad/copy step.
+            conv_input = first.matched_buffers[0]
+            _, weight, out = match_results[1].matched_buffers
+            attr = _make_attr(match_results[1].pattern, conv_input, weight, out)
+        else:
+            return ["", 0]
+
+        cutlass_patterns = [
+            # 9
+            batch_matmul_bias_gelu,
+            # 4
+            conv2d_bias_residual_add,
+            # 3
+            batch_matmul_bias_residual_mul,
+            matmul_bias_relu,
+            conv2d_bias,
+            # 2
+            matmul_bias,
+            batch_matmul_bias,
+            conv2d,
+            # 1
+            matmul,
+            batch_matmul,
+        ]
+        for pattern in cutlass_patterns:
+            # Each attempt gets its own attr and args list. A pattern that
+            # fails part-way may already have set op_type/bias_arg_idx or
+            # appended a bias buffer, and that must not reach the next attempt.
+            trial_attr = dict(attr)
+            trial_attr["args"] = list(attr["args"])
+            res = pattern(match_results, trial_attr)
+            if res is not None:
+                return res
+
+        return ["", 0]
+
+    return cutlass_codegen_with_match_results
+
+
+def cutlass_codegen_gemm(attrs):
+    """cutlass codegen for gemm
+
+    Profiles a CUTLASS gemm for the shapes and dtypes in ``attrs`` and returns
+    the generated source as a string.
+
+    ``attrs`` is updated in place: the profiled op name and definition, the
+    leading dimensions and (for batched problems) the batch attributes are
+    added, and the ``gemm_profiler`` / ``conv2d_profiler`` entries are removed.
+
+    For "cutlass.matmul_bias_residual_multiply" a hand-written kernel wrapper
+    is returned instead of going through ``instantiate_template``. Its
+    ``{global_symbol}`` placeholder is left for a later substitution step.
+
+    Raises
+    ------
+    ValueError
+        If a dtype other than "float16"/"float32" reaches the residual path.
+    """
+    gemm_profiler = attrs["gemm_profiler"]
+    op_type = attrs["op_type"]
+    lhs_shape = attrs["arg0_shape"]
+    rhs_shape = attrs["arg1_shape"]
+    MM = lhs_shape[-2]
+    KK = lhs_shape[-1]
+    if "transposed" in op_type:
+        NN = rhs_shape[-2]
+        ldb = "K"
+        layout_b = LayoutType.ColumnMajor
+    else:
+        NN = rhs_shape[-1]
+        ldb = "N"
+        layout_b = LayoutType.RowMajor
+
+    lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
+    rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
+    if lhs_batches == 1 and rhs_batches == 1:
+        # Regular matmul
+        is_batched = False
+        batch_attrs = {}
+    else:
+        is_batched = True
+        batch_attrs = {
+            # If both lhs_batches and rhs_batches are greater than 1,
+            # they must be equal. This is checked by is_shape_valid_for_cutlass_matmul.
+            "batch": lhs_batches if rhs_batches == 1 else rhs_batches,
+            "batch_stride_A": 0 if lhs_batches == 1 else MM * KK,
+            "batch_stride_B": 0 if rhs_batches == 1 else KK * NN,
+            "batch_stride_C": MM * NN,
+        }
+    op_name, op_def, _ = gemm_profiler.profile(
+        op_type,
+        MM,
+        NN,
+        KK,
+        attrs["ret_dtype"],
+        attrs["arg0_dtype"],
+        attrs["arg1_dtype"],
+        False,
+        batched=is_batched,
+        find_first_valid=False,
+        use_multiprocessing=True,
+        layout_b=layout_b,
+    )
+    attrs["cutlass_op_name"] = op_name
+    attrs["cutlass_op_def"] = op_def
+    attrs["lda"] = "K"
+    attrs["ldb"] = ldb
+    attrs["ldc"] = "N"
+    attrs.update(batch_attrs)
+    del attrs["gemm_profiler"]
+    del attrs["conv2d_profiler"]
+
+    # Two operands, plus one each for an optional bias and residual.
+    nargs = 2
+    if "bias_arg_idx" in attrs:
+        nargs += 1
+    if "residual_arg_idx" in attrs:
+        nargs += 1
+    func_args = ["inp" + str(i) for i in range(nargs)]
+
+    # A temporary solution to handle batch matmul residual cases
+    # TODO(@bohan): remove this after initialize_template supports bmm residual
+    if op_type in [
+        "cutlass.matmul_bias_residual_multiply",
+    ]:
+
+        def _convert_dtype_str(dtype):
+            if isinstance(dtype, list):
+                arr = []
+                for t in dtype:
+                    arr.append(_convert_dtype_str(t))
+                return arr
+            elif isinstance(dtype, str):
+                if dtype == "float16":
+                    return "cutlass::half_t"
+                elif dtype == "float32":
+                    return "float"
+            raise ValueError("dtype not supported")
+
+        typea, typeb, typec = _convert_dtype_str(
+            [attrs["arg0_dtype"], attrs["arg1_dtype"], attrs["ret_dtype"]]
+        )
+
+        # Doubled braces are literal braces in the emitted C++; single braces
+        # are substituted here. B's extents are indexed with bdim (B's own
+        # rank): Bias may be 1-D, in which case bias_dim - 2 would be -1.
+        text = f"""
+#define CUTLASS_ENABLE_CUBLAS 1
+#define CUTLASS_NAMESPACE cutlass
+#define CUTLASS_ENABLE_TENSOR_CORE_MMA 1
+#define NDEBUG
+#include <cutlass/cutlass.h>
+#include <cutlass/tensor_ref.h>
+#include <cutlass/util/host_tensor.h>
+#include <cutlass/gemm/device/gemm.h>
+#include <cutlass/gemm/device/gemm_batched.h>
+#include <cutlass/layout/matrix.h>
+#include <cutlass/numeric_types.h>
+#include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
+#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <vector>
+#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+namespace {{
+using namespace tvm;
+using namespace tvm::runtime;
+void _BHGEMM(NDArray A, NDArray B, NDArray Bias, NDArray D, NDArray C) {{
+  // A: [Batch, M, K], B: [1, K, N]/[K, N], Bias: [1, N]/[N], D: [Batch, M, N], C: [Batch, M, N]
+  CHECK_EQ(A->ndim, 3);
+  int bdim = B->ndim;
+  int bias_dim = Bias->ndim;
+  CHECK_EQ(C->ndim, 3);
+  CHECK_EQ(A->shape[2], B->shape[bdim - 2]);
+  CHECK_EQ(Bias->shape[bias_dim - 1], B->shape[bdim - 1]);
+  CHECK_EQ(D->ndim, 3);
+  CHECK_EQ(D->shape[0], A->shape[0]);
+  CHECK_EQ(D->shape[1], A->shape[1]);
+  CHECK_EQ(D->shape[2], B->shape[bdim - 1]);
+  CHECK_EQ(C->shape[0], A->shape[0]);
+  CHECK_EQ(C->shape[1], A->shape[1]);
+  CHECK_EQ(C->shape[2], B->shape[bdim - 1]);
+  int64_t M = A->shape[0] * A->shape[1];
+  int64_t N = B->shape[bdim - 1];
+  int64_t K = A->shape[2];
+  int64_t input_a_batch_stride = M * K;
+  int64_t input_a_stride = K;
+  int64_t input_a_offset = 0; // default to 0
+  int64_t input_b_batch_stride = K * N;
+  int64_t input_b_stride = N;
+  int64_t input_b_offset = 0; // default to 0
+  int64_t output_stride = N;
+  int64_t output_offset = 0;
+  int64_t a_size = 1;
+  a_size *= A->shape[0];
+  a_size *= A->shape[1];
+  a_size *= A->shape[2];
+
+  int64_t b_size = 1;
+  b_size *= B->shape[bdim - 2];
+  b_size *= B->shape[bdim - 1];
+
+  int64_t c_size = 1;
+  c_size *= C->shape[0];
+  c_size *= C->shape[1];
+  c_size *= C->shape[2];
+
+  // Define the GEMM operation
+  {op_def}
+  using kernel = Operation_{op_name};
+  using ElementComputeEpilogue = typename kernel::ElementAccumulator;
+  typename kernel::Arguments arguments({{
+    cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
+    {{M, N, K}}, // GemmCoord problem_size
+    1, // int batch_count
+    {{ElementComputeEpilogue(1), ElementComputeEpilogue(1)}}, // typename EpilogueOutputOp::Params epilogue
+    ({typea}*)(A->data) + input_a_offset, // void const * ptr_A
+    ({typeb}*)(B->data) + input_b_offset, // void const * ptr_B
+    ({typec}*)(D->data), // void const * ptr_C1
+    ({typec}*)(C->data) + output_offset, // void * ptr_D
+    ({typea}*)(Bias->data), // void * ptr_Vector
+    nullptr, // void * ptr_Tensor
+    input_a_batch_stride, // int64_t batch_stride_A
+    input_b_batch_stride, // int64_t batch_stride_B
+    0, // int64_t batch_stride_C1
+    0, // int64_t batch_stride_D
+    0, // int64_t batch_stride_Vector
+    0, // int64_t batch_stride_Tensor
+    input_a_stride, // typename LayoutA::Stride::Index lda
+    input_b_stride, // typename LayoutB::Stride::Index ldb
+    N, // typename LayoutC::Stride::Index ldc1
+    output_stride, // typename LayoutC::Stride::Index ldd
+    0, // typename LayoutC::Stride::Index ldr
+    0, // typename LayoutC::Stride::Index ldt
+  }});
+  kernel gemm_op;
+  size_t workspace_size = gemm_op.get_workspace_size(arguments);
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  cutlass::Status status = gemm_op.can_implement(arguments);
+  CHECK(status == cutlass::Status::kSuccess);
+  status = gemm_op.initialize(arguments, workspace.get());
+  CHECK(status == cutlass::Status::kSuccess);
+  status = gemm_op();
+  CHECK(status == cutlass::Status::kSuccess);
+  return;
+}}
+}} // namespace
+TVM_DLL_EXPORT_TYPED_FUNC({{global_symbol}}, _BHGEMM);
+        """
+        return text
+
+    code = instantiate_template(op_type, attrs, func_args)
+    return _final_code(code.code, code.headers, func_args)
+
+
+def cutlass_codegen_conv2d(attrs):
+    """cutlass codegen for conv2d
+
+    Profiles a forward-propagation conv2d for the shapes, padding, strides,
+    dilation and dtypes in ``attrs`` and returns the generated source.
+
+    ``attrs`` is updated in place: the profiled op definition and name are
+    added and the ``gemm_profiler`` / ``conv2d_profiler`` entries are removed.
+    """
+    # cutlass backend only supports nhwc for now
+    profiler = attrs["conv2d_profiler"]
+    op_type = attrs["op_type"]
+    profile_kwargs = dict(
+        op_type=attrs["op_type"],
+        d_shape=attrs["arg0_shape"],
+        w_shape=attrs["arg1_shape"],
+        padding=attrs["padding"],
+        stride=attrs["strides"],
+        dilation=attrs["dilation"],
+        out_dtype=attrs["ret_dtype"],
+        data_dtype=attrs["arg0_dtype"],
+        weight_dtype=attrs["arg1_dtype"],
+        use_3xtf32=False,
+        conv_kind=ConvKind.Fprop,
+        split_k_slices=[1],
+        profile_all_alignments=True,
+        find_first_valid=False,
+        use_multiprocessing=True,
+    )
+    op_name, op_def, _ = profiler.profile(**profile_kwargs)
+    attrs["cutlass_op_def"] = op_def
+    attrs["cutlass_op_name"] = op_name
+    for profiler_key in ("gemm_profiler", "conv2d_profiler"):
+        del attrs[profiler_key]
+
+    # Two operands, plus one each for an optional bias and residual.
+    optional_keys = ("bias_arg_idx", "residual_arg_idx")
+    nargs = 2 + sum(1 for key in optional_keys if key in attrs)
+    func_args = ["inp" + str(i) for i in range(nargs)]
+    code = instantiate_template(op_type, attrs, func_args)
+    return _final_code(code.code, code.headers, func_args)
diff --git a/python/tvm/relax/backend_tir/pattern.py
b/python/tvm/relax/backend_tir/pattern.py
new file mode 100644
index 0000000000..10f7a3b162
--- /dev/null
+++ b/python/tvm/relax/backend_tir/pattern.py
@@ -0,0 +1,576 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring,chained-comparison
+"""TIR Patterns"""
+from typing import List
+
+import tvm
+from tvm.runtime import Object
+import tvm._ffi
+
+from tvm.script import tir as T
+
+
+@tvm._ffi.register_object("relax.MatchResult")
+class MatchResult(Object):
+ """The match result of a TIR pattern."""
+
+ def __init__(self, pattern, symbol_values, matched_buffers):
+ self.__init_handle_by_constructor__(
+ tvm._ffi.MatchResult, pattern, symbol_values, matched_buffers
+ )
+
+
[email protected]_func
+def matmul_rrr_fp16(
+ var_rxplaceholder: T.handle,
+ var_rxplaceholder_1: T.handle,
+ var_matmul: T.handle,
+ M: T.int64,
+ N: T.int64,
+ K: T.int64,
+) -> None:
+ # function attr dict
+ T.func_attr({"tir.noalias": True})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, [M, K], dtype="float16")
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N],
dtype="float16")
+ matmul = T.match_buffer(var_matmul, [M, N], dtype="float16")
+ # body
+ # with T.block("root")
+ for i0, i1, i2 in T.grid(M, N, K):
+ with T.block("matmul"):
+ i0_1, i1_1, k = T.axis.remap("SSR", [i0, i1, i2])
+ T.reads(rxplaceholder[i0_1, k], rxplaceholder_1[k, i1_1])
+ T.writes(matmul[i0_1, i1_1])
+ with T.init():
+ matmul[i0_1, i1_1] = T.float16(0)
+ matmul[i0_1, i1_1] = (
+ matmul[i0_1, i1_1] + rxplaceholder[i0_1, k] *
rxplaceholder_1[k, i1_1]
+ )
+
+
[email protected]_func
+def bias_row_2d_fp16(
+ var_rxplaceholder: T.handle,
+ var_rxplaceholder_1: T.handle,
+ var_T_add: T.handle,
+ M: T.int64,
+ N: T.int64,
+) -> None:
+ # function attr dict
+ T.func_attr({"tir.noalias": True})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N],
dtype="float16")
+ T_add = T.match_buffer(var_T_add, [M, N], dtype="float16")
+ # body
+ # with T.block("root")
+ for i0, i1 in T.grid(M, N):
+ with T.block("T_add"):
+ ax0, ax1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[T.int64(0), ax1])
+ T.writes(T_add[ax0, ax1])
+ T_add[ax0, ax1] = rxplaceholder[ax0, ax1] +
rxplaceholder_1[T.int64(0), ax1]
+
+
[email protected]_func
+def bias_row_1d_fp16(
+ var_rxplaceholder: T.handle,
+ var_rxplaceholder_1: T.handle,
+ var_T_add: T.handle,
+ M: T.int64,
+ N: T.int64,
+) -> None:
+ # function attr dict
+ T.func_attr({"tir.noalias": True})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16")
+ T_add = T.match_buffer(var_T_add, [M, N], dtype="float16")
+ # body
+ # with T.block("root")
+ for i0, i1 in T.grid(M, N):
+ with T.block("T_add"):
+ ax0, ax1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1])
+ T.writes(T_add[ax0, ax1])
+ T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[ax1]
+
+
[email protected]_func
+def batch_bias_row_2d_fp16(
+ var_rxplaceholder: T.handle,
+ var_rxplaceholder_1: T.handle,
+ var_T_add: T.handle,
+ batch: T.int64,
+ M: T.int64,
+ N: T.int64,
+) -> None:
+ # function attr dict
+ T.func_attr({"tir.noalias": True})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N],
dtype="float16")
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N],
dtype="float16")
+ T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16")
+ # body
+ # with T.block("root")
+ for i0, i1, i2 in T.grid(batch, M, N):
+ with T.block("T_add"):
+ ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[T.int64(0),
ax2])
+ T.writes(T_add[ax0, ax1, ax2])
+ T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] +
rxplaceholder_1[T.int64(0), ax2]
+
+
[email protected]_func
+def batch_bias_row_1d_fp16(
+ var_rxplaceholder: T.handle,
+ var_rxplaceholder_1: T.handle,
+ var_T_add: T.handle,
+ batch: T.int64,
+ M: T.int64,
+ N: T.int64,
+) -> None:
+ # function attr dict
+ T.func_attr({"tir.noalias": True})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N],
dtype="float16")
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16")
+ T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16")
+ # body
+ # with T.block("root")
+ for i0, i1, i2 in T.grid(batch, M, N):
+ with T.block("T_add"):
+ ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[ax2])
+ T.writes(T_add[ax0, ax1, ax2])
+ T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] +
rxplaceholder_1[ax2]
+
+
[email protected]_func
+def relu_fp16(var_rxplaceholder: T.handle, var_compute: T.handle, M: T.int64,
N: T.int64) -> None:
+ # function attr dict
+ T.func_attr({"tir.noalias": True})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
+ compute = T.match_buffer(var_compute, [M, N], dtype="float16")
+ # body
+ # with T.block("root")
+ for i0, i1 in T.grid(M, N):
+ with T.block("compute"):
+ i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[i0_1, i1_1])
+ T.writes(compute[i0_1, i1_1])
+ compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1],
T.float16(0))
+
+
[email protected]_func
+def batch_matmul_rrr_2d_fp16(
+ var_rxplaceholder: T.handle,
+ var_rxplaceholder_1: T.handle,
+ var_matmul: T.handle,
+ batch: T.int64,
+ M: T.int64,
+ N: T.int64,
+ K: T.int64,
+) -> None:
+ # function attr dict
+ T.func_attr({"tir.noalias": True})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K],
dtype="float16")
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N],
dtype="float16")
+ matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16")
+ # body
+ # with T.block("root")
+ for i0, i1, i2, i3 in T.grid(batch, M, N, K):
+ with T.block("matmul"):
+ i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
+ T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[k, i2_1])
+ T.writes(matmul[i0_1, i1_1, i2_1])
+ with T.init():
+ matmul[i0_1, i1_1, i2_1] = T.float16(0)
+ matmul[i0_1, i1_1, i2_1] = (
+ matmul[i0_1, i1_1, i2_1] + rxplaceholder[i0_1, i1_1, k] *
rxplaceholder_1[k, i2_1]
+ )
+
+
[email protected]_func
+def batch_matmul_rrr_3d_fp16(
+ var_rxplaceholder: T.handle,
+ var_rxplaceholder_1: T.handle,
+ var_matmul: T.handle,
+ batch: T.int64,
+ M: T.int64,
+ N: T.int64,
+ K: T.int64,
+) -> None:
+ # function attr dict
+ T.func_attr({"tir.noalias": True})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K],
dtype="float16")
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [batch, K, N],
dtype="float16")
+ matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16")
+ # body
+ # with T.block("root")
+ for i0, i1, i2, i3 in T.grid(batch, M, N, K):
+ with T.block("matmul"):
+ i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
+ T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[i0_1, k,
i2_1])
+ T.writes(matmul[i0_1, i1_1, i2_1])
+ with T.init():
+ matmul[i0_1, i1_1, i2_1] = T.float16(0)
+ matmul[i0_1, i1_1, i2_1] = (
+ matmul[i0_1, i1_1, i2_1]
+ + rxplaceholder[i0_1, i1_1, k] * rxplaceholder_1[i0_1, k, i2_1]
+ )
+
+
[email protected]_func
+def copy_4d_fp16(
+ A_handle: T.handle,
+ B_handle: T.handle,
+ N: T.int64,
+ H: T.int64,
+ W: T.int64,
+ C: T.int64,
+) -> None:
+ A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
+ B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16")
+ # body
+ # with T.block("root")
+ for n, h, w, c in T.grid(N, H, W, C):
+ with T.block("copy"):
+ vn, vh, vw, vc = T.axis.remap("SSSS", [n, h, w, c])
+ T.reads(A[vn, vh, vw, vc])
+ T.writes(B[vn, vh, vw, vc])
+ B[vn, vh, vw, vc] = A[vn, vh, vw, vc]
+
+
[email protected]_func
+def padding_2d_nhwc_fp16(
+ A_handle: T.handle,
+ B_handle: T.handle,
+ N: T.int64,
+ H: T.int64,
+ W: T.int64,
+ C: T.int64,
+ pH: T.int64,
+ pW: T.int64,
+ lH: T.int64,
+ lW: T.int64,
+ rH: T.int64,
+ rW: T.int64,
+) -> None:
+ A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
+ B = T.match_buffer(B_handle, [N, pH, pW, C], dtype="float16")
+ # body
+ # with T.block("root")
+ for v, v_1, v_2, v_3 in T.grid(N, pH, pW, C):
+ with T.block("copy"):
+ v_4, v_5, v_6, v_7 = T.axis.remap("SSSS", [v, v_1, v_2, v_3])
+ T.reads(A[v_4, v_5 - lH, v_6 - lW, v_7])
+ T.writes(B[v_4, v_5, v_6, v_7])
+ B[v_4, v_5, v_6, v_7] = T.if_then_else(
+ lH <= v_5 and v_5 < rH and lW <= v_6 and v_6 < rW,
+ A[v_4, v_5 - lH, v_6 - lW, v_7],
+ T.float16(0),
+ dtype="float16",
+ )
+
+
[email protected]_func
+def conv2d_nhwc_fp16(
+ A_handle: T.handle,
+ B_handle: T.handle,
+ out_handle: T.handle,
+ N: T.int64,
+ pH: T.int64,
+ pW: T.int64,
+ H: T.int64,
+ W: T.int64,
+ C: T.int64,
+ O: T.int64,
+ KH: T.int64,
+ KW: T.int64,
+ StrideH: T.int64,
+ StrideW: T.int64,
+ DilateH: T.int64,
+ DilateW: T.int64,
+) -> None:
+ A = T.match_buffer(A_handle, [N, pH, pW, C], dtype="float16")
+ B = T.match_buffer(B_handle, [O, KH, KW, C], dtype="float16")
+ out = T.match_buffer(out_handle, [N, H, W, O], dtype="float16")
+ # body
+ # with T.block("root")
+ for v, v_1, v_2, v_3, v_4, v_5, v_6 in T.grid(N, H, W, O, KH, KW, C):
+ with T.block("conv"):
+ v_7, v_8, v_9, v_10, v_11, v_12, v_13 = T.axis.remap(
+ "SSSSRRR", [v, v_1, v_2, v_3, v_4, v_5, v_6]
+ )
+ T.reads(
+ A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 *
StrideW, v_13],
+ B[v_10, v_11, v_12, v_13],
+ )
+ T.writes(out[v_7, v_8, v_9, v_10])
+ with T.init():
+ out[v_7, v_8, v_9, v_10] = T.float16(0)
+ out[v_7, v_8, v_9, v_10] = (
+ out[v_7, v_8, v_9, v_10]
+ + A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9
* StrideW, v_13]
+ * B[v_10, v_11, v_12, v_13]
+ )
+
+
[email protected]_func
+def bias_add_nhwc_2d_fp16(
+ A_handle: T.handle,
+ B_handle: T.handle,
+ out_handle: T.handle,
+ N: T.int64,
+ H: T.int64,
+ W: T.int64,
+ C: T.int64,
+):
+ A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
+ B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16")
+ out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
+ for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2,
ax3])
+ T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0),
T.int64(0), v_ax3])
+ T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
+ out[v_ax0, v_ax1, v_ax2, v_ax3] = (
+ A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, T.int64(0),
T.int64(0), v_ax3]
+ )
+
+
[email protected]_func
+def bias_add_nhwc_1d_fp16(
+ A_handle: T.handle,
+ B_handle: T.handle,
+ out_handle: T.handle,
+ N: T.int64,
+ H: T.int64,
+ W: T.int64,
+ C: T.int64,
+):
+ A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
+ B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16")
+ out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
+ for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2,
ax3])
+ T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[T.int64(0), T.int64(0),
T.int64(0), v_ax3])
+ T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
+ out[v_ax0, v_ax1, v_ax2, v_ax3] = (
+ A[v_ax0, v_ax1, v_ax2, v_ax3] + B[T.int64(0), T.int64(0),
T.int64(0), v_ax3]
+ )
+
+
[email protected]_func
+def elem_add_2d_fp16(
+ in0_handle: T.handle,
+ in1_handle: T.handle,
+ out_handle: T.handle,
+ N: T.int64,
+ M: T.int64,
+):
+ in0 = T.match_buffer(in0_handle, [N, M], dtype="float16")
+ in1 = T.match_buffer(in1_handle, [N, M], dtype="float16")
+ out = T.match_buffer(out_handle, [N, M], dtype="float16")
+ for ax0, ax1 in T.grid(N, M):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(in0[v_ax0, v_ax1], in1[v_ax0, v_ax1])
+ T.writes(out[v_ax0, v_ax1])
+ out[v_ax0, v_ax1] = in0[v_ax0, v_ax1] + in1[v_ax0, v_ax1]
+
+
[email protected]_func
+def elem_add_3d_fp16(
+ in0_handle: T.handle,
+ in1_handle: T.handle,
+ out_handle: T.handle,
+ B: T.int64,
+ N: T.int64,
+ M: T.int64,
+):
+ in0 = T.match_buffer(in0_handle, [B, N, M], dtype="float16")
+ in1 = T.match_buffer(in1_handle, [B, N, M], dtype="float16")
+ out = T.match_buffer(out_handle, [B, N, M], dtype="float16")
+ for ax0, ax1, ax2 in T.grid(B, N, M):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(in0[v_ax0, v_ax1, v_ax2], in1[v_ax0, v_ax1, v_ax2])
+ T.writes(out[v_ax0, v_ax1, v_ax2])
+ out[v_ax0, v_ax1, v_ax2] = in0[v_ax0, v_ax1, v_ax2] + in1[v_ax0,
v_ax1, v_ax2]
+
+
[email protected]_func
+def elem_add_4d_fp16(
+ A_handle: T.handle,
+ B_handle: T.handle,
+ out_handle: T.handle,
+ N: T.int64,
+ H: T.int64,
+ W: T.int64,
+ C: T.int64,
+):
+ A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
+ B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16")
+ out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
+ for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2,
ax3])
+ T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
+ out[v_ax0, v_ax1, v_ax2, v_ax3] = (
+ A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, v_ax1, v_ax2, v_ax3]
+ )
+
+
[email protected]_func
+def scalar_mul_3d_fp16(
+ inp0_handle: T.handle,
+ out_handle: T.handle,
+ D1: T.int64,
+ D2: T.int64,
+ D3: T.int64,
+ scalar: T.float16,
+):
+ inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
+ out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
+ for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+ with T.block("T_mul"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(inp0[v_ax0, v_ax1, v_ax2])
+ T.writes(out[v_ax0, v_ax1, v_ax2])
+ out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * scalar
+
+
[email protected]_func
+def erf_3d_fp32(
+ inp0_handle: T.handle,
+ out_handle: T.handle,
+ D1: T.int64,
+ D2: T.int64,
+ D3: T.int64,
+):
+ inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32")
+ out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32")
+ for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+ with T.block("T_erf"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(inp0[v_ax0, v_ax1, v_ax2])
+ T.writes(out[v_ax0, v_ax1, v_ax2])
+ out[v_ax0, v_ax1, v_ax2] = T.erf(inp0[v_ax0, v_ax1, v_ax2])
+
+
[email protected]_func
+def scalar_add_3d_fp16(
+ inp0_handle: T.handle,
+ out_handle: T.handle,
+ D1: T.int64,
+ D2: T.int64,
+ D3: T.int64,
+ scalar: T.float16,
+):
+ inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
+ out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
+ for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(inp0[v_ax0, v_ax1, v_ax2])
+ T.writes(out[v_ax0, v_ax1, v_ax2])
+ out[v_ax0, v_ax1, v_ax2] = scalar + inp0[v_ax0, v_ax1, v_ax2]
+
+
[email protected]_func
+def elem_mul_3d_fp16(
+ inp0_handle: T.handle,
+ inp1_handle: T.handle,
+ out_handle: T.handle,
+ D1: T.int64,
+ D2: T.int64,
+ D3: T.int64,
+):
+ inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
+ inp1 = T.match_buffer(inp1_handle, [D1, D2, D3], dtype="float16")
+ out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
+ for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+ with T.block("T_mul"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(inp0[v_ax0, v_ax1, v_ax2], inp1[v_ax0, v_ax1, v_ax2])
+ T.writes(out[v_ax0, v_ax1, v_ax2])
+ out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * inp1[v_ax0,
v_ax1, v_ax2]
+
+
[email protected]_func
+def cast_3d_fp16(
+ inp0_handle: T.handle,
+ out_handle: T.handle,
+ D1: T.int64,
+ D2: T.int64,
+ D3: T.int64,
+):
+ inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32")
+ out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
+ for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+ with T.block("T_cast"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(inp0[v_ax0, v_ax1, v_ax2])
+ T.writes(out[v_ax0, v_ax1, v_ax2])
+ out[v_ax0, v_ax1, v_ax2] = T.Cast("float16", inp0[v_ax0, v_ax1,
v_ax2])
+
+
[email protected]_func
+def cast_3d_fp32(
+ inp0_handle: T.handle,
+ out_handle: T.handle,
+ D1: T.int64,
+ D2: T.int64,
+ D3: T.int64,
+):
+ inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
+ out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32")
+ for ax0, ax1, ax2 in T.grid(D1, D2, D3):
+ with T.block("T_cast"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(inp0[v_ax0, v_ax1, v_ax2])
+ T.writes(out[v_ax0, v_ax1, v_ax2])
+ out[v_ax0, v_ax1, v_ax2] = T.Cast("float32", inp0[v_ax0, v_ax1,
v_ax2])
+
+
+def get_tir_pattern() -> List[tvm.tir.PrimFunc]:
+ """Get the TIR patterns for backend dispatch: every pattern PrimFunc defined in this module, in the order listed below."""
+ return [
+ matmul_rrr_fp16,
+ bias_row_2d_fp16,
+ bias_row_1d_fp16,
+ batch_bias_row_2d_fp16,
+ batch_bias_row_1d_fp16,
+ relu_fp16,
+ erf_3d_fp32,
+ batch_matmul_rrr_2d_fp16,
+ batch_matmul_rrr_3d_fp16,
+ copy_4d_fp16,
+ padding_2d_nhwc_fp16,
+ conv2d_nhwc_fp16,
+ bias_add_nhwc_2d_fp16,
+ bias_add_nhwc_1d_fp16,
+ elem_add_2d_fp16,
+ elem_add_3d_fp16,
+ elem_add_4d_fp16,
+ elem_mul_3d_fp16,
+ scalar_add_3d_fp16,
+ scalar_mul_3d_fp16,
+ cast_3d_fp16,
+ cast_3d_fp32,
+ ]
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index c03df804ee..18321e8dba 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -693,6 +693,25 @@ def ToMixedPrecision(out_dtype="float32") ->
tvm.ir.transform.Pass:
return _ffi_api.ToMixedPrecision(out_dtype) # type: ignore
+def SplitCallTIRByPattern(patterns, fcodegen) -> tvm.ir.transform.Pass:
+ """Split a PrimFunc into 2 parts: the first part is a TIR PrimFunc which is
+ matched with some pattern, and the second part is the rest of the original
+ PrimFunc. It will call fcodegen to generate the code for the matched pattern
+ and replace the matched part with an ExternFunc call.
+
+ Parameters
+ ----------
+ patterns : List[PrimFunc]
+ The list of patterns to match.
+ fcodegen : Callable[[List[MatchResult]], List[Object]]
+ The function to generate the code for the matched patterns.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass for splitting call_tir.
+ """
+ return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen) # type: ignore
+
+
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""
diff --git a/src/relax/backend/vm/codegen_vm.cc
b/src/relax/backend/vm/codegen_vm.cc
index da0ca3a0b5..b36b5ed4d6 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -315,6 +315,17 @@ class CodeGenVM : public
ExprFunctor<Instruction::Arg(const Expr&)> {
}
 Instruction::Arg VisitExpr_(const ExternFuncNode* op) final {  // Declares the extern symbol as a packed function; imports attached C source first.
+ static const constexpr char* kCSource = "c_source";  // Attr key holding the source text; the same key string is defined again in split_call_tir_by_pattern.cc.
+ static const constexpr char* kCSourceFmt = "c_source_fmt";  // Attr key holding the source format.
+ if (Optional<String> opt_code = op->attrs.GetAttr<String>(kCSource)) {  // Only extern funcs that carry source text get a module.
+ String sym = op->global_symbol;
+ String fmt = op->attrs.GetAttr<String>(kCSourceFmt).value_or("c");  // Format defaults to "c" when the attr is absent.
+ String code = opt_code.value();
+ Module c_source_module =
+ codegen::CSourceModuleCreate(/*code=*/code, /*fmt=*/fmt,
/*func_names=*/{sym},
+ /*const_vars=*/{});
+ builder_->exec()->Import(c_source_module);  // NOTE(review): runs on every visit of such a node; confirm an ExternFunc reached more than once is not imported twice.
+ }
 builder_->DeclareFunction(op->global_symbol,
VMFuncInfo::FuncKind::kPackedFunc);
 return builder_->GetFunction(op->global_symbol);
 }
diff --git a/src/relax/ir/tir_pattern.cc b/src/relax/ir/tir_pattern.cc
new file mode 100644
index 0000000000..cbe4170bb9
--- /dev/null
+++ b/src/relax/ir/tir_pattern.cc
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/tir_pattern.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Construct a MatchResult from a pattern, its symbol values and the matched buffers.
+ * \param pattern The matched pattern.
+ * \param symbol_values The values stored as the node's symbol_values.
+ * \param matched_buffers The buffers stored as the node's matched_buffers.
+ */
+MatchResult::MatchResult(TIRPattern pattern, Array<PrimExpr> symbol_values,
+                         Array<tir::Buffer> matched_buffers) {
+  // The arguments are taken by value, so each one is moved into the freshly created node.
+  auto node = make_object<MatchResultNode>();
+  node->matched_buffers = std::move(matched_buffers);
+  node->symbol_values = std::move(symbol_values);
+  node->pattern = std::move(pattern);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(MatchResultNode);
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/split_call_tir_by_pattern.cc
b/src/relax/transform/split_call_tir_by_pattern.cc
new file mode 100644
index 0000000000..7fcc2cb34a
--- /dev/null
+++ b/src/relax/transform/split_call_tir_by_pattern.cc
@@ -0,0 +1,782 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/split_call_tir_by_pattern.cc
+ * \brief Split PrimFuncs called through call_tir according to TIR patterns.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir/module.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/tir_pattern.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../tir/schedule/ir_comparator.h"
+
+namespace tvm {
+
+static const constexpr char* kLibraryKernel = "library_kernel";
+static const constexpr char* kCSource = "c_source";
+static const constexpr char* kCSourceFmt = "c_source_fmt";
+static const constexpr char* kCSourceFmtCuda = "cu";
+
+namespace tir {
+
+using relax::FCodegen;
+using relax::MatchResult;
+using relax::TIRPattern;
+
+/*! \brief helper to match a for stmt to a pattern*/
+class ForMatcher : public TensorizeComparator {
+ public:
+ using SymbolMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash,
ObjectPtrEqual>;
+  /*!
+   * \brief Create a matcher for one pattern function.
+   * \param pattern The pattern PrimFunc to match against.
+   * \param pattern_vars The variables of the pattern that are treated as pattern variables.
+   */
+  explicit ForMatcher(const tir::PrimFunc& pattern, const Array<Var>& pattern_vars)
+      : TensorizeComparator(IRModule({{GlobalVar(""), pattern}}), false), pattern_(pattern) {
+    for (const auto& symbol : pattern_vars) {
+      pattern_vars_.insert(symbol);
+    }
+    // Open the outermost scope that receives the symbol bindings.
+    evaluated_symbols.emplace_back();
+  }
+
+  /*!
+   * \brief Match a loop nest against the first loop nest of the pattern function.
+   * \param top The outermost loop of the candidate loop nest.
+   * \return Whether the loop nest matches the pattern. On success, one buffer per buffer
+   *         parameter of the pattern is appended to `evaluated_buffers`, in parameter order.
+   */
+  bool Match(const For& top) {
+    // The pattern body must be a BlockRealize whose block body is a For loop. Check both casts
+    // so that a malformed pattern fails the ICHECK instead of dereferencing a null pointer.
+    const auto* pattern_root = pattern_->body.as<BlockRealizeNode>();
+    ICHECK(pattern_root) << "Invalid pattern function";
+    const ForNode* pattern_top = pattern_root->block->body.as<ForNode>();
+    ICHECK(pattern_top) << "Invalid pattern function";
+    if (!VisitStmt(top, GetRef<Stmt>(pattern_top))) {
+      return false;
+    }
+    // Get evaluated symbols, buffers from the pattern.
+    for (const auto& arg : pattern_->params) {
+      auto it = pattern_->buffer_map.find(arg);
+      if (it != pattern_->buffer_map.end()) {
+        // Every buffer parameter of the pattern must have been mapped during the match.
+        auto itt = rhs_buffer_map_.find((*it).second);
+        ICHECK(itt != rhs_buffer_map_.end());
+        evaluated_buffers.push_back(itt->second);
+      }
+    }
+    return true;
+  }
+
+ std::vector<SymbolMap> evaluated_symbols;
+ std::vector<Buffer> evaluated_buffers;
+
+ private:
+ using ExprComparator::VisitExpr_;
+
+  /*!
+   * \brief Look up the value bound to a pattern variable.
+   * \param var The pattern variable to look up.
+   * \return The value from the first scope (front to back) that binds `var`, or NullOpt.
+   */
+  Optional<PrimExpr> QueryEvaluatedSymbols(const Var& var) {
+    Optional<PrimExpr> bound = NullOpt;
+    for (const SymbolMap& scope : evaluated_symbols) {
+      auto entry = scope.find(var);
+      if (entry == scope.end()) continue;
+      bound = entry->second;
+      break;
+    }
+    return bound;
+  }
+
+  /*!
+   * \brief Compare an expression of the candidate function (lhs) with a pattern expression (rhs).
+   *
+   * Pattern variables are handled before deferring to TensorizeComparator:
+   *  - a pattern variable matches an integer or floating-point constant; it is bound to that
+   *    constant, or compared with the value it was bound to earlier;
+   *  - `pattern_var * expr` (either operand order) also matches a plain `expr`, binding the
+   *    pattern variable to 1;
+   *  - `pattern_var + expr` (either operand order) also matches a plain `expr`, binding the
+   *    pattern variable to 0.
+   */
+  bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final {
+    // Case 1: rhs is a pattern variable and lhs is not a variable.
+    const auto* rhs_var = rhs.as<VarNode>();
+    if (rhs_var != nullptr && pattern_vars_.count(GetRef<Var>(rhs_var)) &&
+        lhs.as<VarNode>() == nullptr) {
+      // Only constants may be bound to a pattern variable.
+      if (!lhs->IsInstance<tir::IntImmNode>() && !lhs->IsInstance<tir::FloatImmNode>()) {
+        return false;
+      }
+      const Var symbol = GetRef<Var>(rhs_var);
+      Optional<PrimExpr> bound = QueryEvaluatedSymbols(symbol);
+      if (!bound.defined()) {
+        evaluated_symbols.back()[symbol] = lhs;
+        return true;
+      }
+      return analyzer_.CanProveEqual(lhs, bound.value());
+    }
+    // Try to match lhs against `rest` alone, treating `maybe_symbol` as the identity element of
+    // the enclosing operation. Bindings made during the attempt go into a temporary scope and
+    // are merged into the enclosing scope only when the attempt succeeds.
+    auto match_dropping_symbol = [this, &lhs](const PrimExpr& maybe_symbol, const PrimExpr& rest,
+                                              int identity) -> bool {
+      const auto* symbol = maybe_symbol.as<VarNode>();
+      if (symbol == nullptr || !pattern_vars_.count(GetRef<Var>(symbol))) return false;
+      evaluated_symbols.push_back(SymbolMap());
+      const bool matched = VisitExpr(lhs, rest);
+      SymbolMap scoped = std::move(evaluated_symbols.back());
+      evaluated_symbols.pop_back();
+      if (!matched) return false;
+      evaluated_symbols.back().insert(scoped.begin(), scoped.end());
+      evaluated_symbols.back()[GetRef<Var>(symbol)] = MakeConstScalar(rest.dtype(), identity);
+      return true;
+    };
+    // Case 2: pattern_var * expr; the pattern variable is tried on the left, then on the right.
+    if (const auto* mul = rhs.as<MulNode>()) {
+      if (match_dropping_symbol(mul->a, mul->b, 1)) return true;
+      if (match_dropping_symbol(mul->b, mul->a, 1)) return true;
+    }
+    // Case 3: pattern_var + expr; same operand order as above.
+    if (const auto* add = rhs.as<AddNode>()) {
+      if (match_dropping_symbol(add->a, add->b, 0)) return true;
+      if (match_dropping_symbol(add->b, add->a, 0)) return true;
+    }
+    return TensorizeComparator::VisitExpr(lhs, rhs);
+  }
+
+  /*!
+   * \brief Match an addition against a pattern addition, accepting either operand order.
+   *
+   * Each operand pairing is tried in its own symbol scope; the bindings of the first pairing
+   * that matches are merged into the enclosing scope.
+   */
+  bool VisitExpr_(const tir::AddNode* add, const PrimExpr& other) final {
+    const auto* rhs = other.as<AddNode>();
+    if (rhs == nullptr) return false;
+    auto try_pairing = [this, add](const PrimExpr& first, const PrimExpr& second) -> bool {
+      evaluated_symbols.push_back(SymbolMap());
+      const bool matched = VisitExpr(add->a, first) && VisitExpr(add->b, second);
+      SymbolMap scoped = std::move(evaluated_symbols.back());
+      evaluated_symbols.pop_back();
+      if (matched) {
+        evaluated_symbols.back().insert(scoped.begin(), scoped.end());
+      }
+      return matched;
+    };
+    // Straight order first, then swapped operands.
+    return try_pairing(rhs->a, rhs->b) || try_pairing(rhs->b, rhs->a);
+  }
+
+  /*!
+   * \brief Match a multiplication against a pattern multiplication, accepting either operand
+   *        order.
+   *
+   * Each operand pairing is tried in its own symbol scope; the bindings of the first pairing
+   * that matches are merged into the enclosing scope.
+   */
+  bool VisitExpr_(const tir::MulNode* mul, const PrimExpr& other) final {
+    const auto* rhs = other.as<MulNode>();
+    if (rhs == nullptr) return false;
+    auto try_pairing = [this, mul](const PrimExpr& first, const PrimExpr& second) -> bool {
+      evaluated_symbols.push_back(SymbolMap());
+      const bool matched = VisitExpr(mul->a, first) && VisitExpr(mul->b, second);
+      SymbolMap scoped = std::move(evaluated_symbols.back());
+      evaluated_symbols.pop_back();
+      if (matched) {
+        evaluated_symbols.back().insert(scoped.begin(), scoped.end());
+      }
+      return matched;
+    };
+    // Straight order first, then swapped operands.
+    return try_pairing(rhs->a, rhs->b) || try_pairing(rhs->b, rhs->a);
+  }
+
+  /*!
+   * \brief Match a call against a pattern call.
+   *
+   * Both callees must be Ops with the same name, and the arguments must match pairwise in order.
+   */
+  bool VisitExpr_(const tir::CallNode* call, const PrimExpr& other) final {
+    const auto* rhs = other.as<CallNode>();
+    if (rhs == nullptr) return false;
+    const auto* lhs_op = call->op.as<OpNode>();
+    const auto* rhs_op = rhs->op.as<OpNode>();
+    const bool same_callee =
+        lhs_op != nullptr && rhs_op != nullptr && lhs_op->name == rhs_op->name;
+    if (!same_callee || call->args.size() != rhs->args.size()) return false;
+    // Walk the arguments until the first mismatch.
+    size_t idx = 0;
+    while (idx < call->args.size() && VisitExpr(call->args[idx], rhs->args[idx])) {
+      ++idx;
+    }
+    return idx == call->args.size();
+  }
+
+  /*!
+   * \brief Match a loop of the candidate function (op) against a pattern loop (other).
+   *
+   * Both loops must be serial, start at 0, carry no thread binding and no annotations, and have
+   * a nested loop or a BlockRealize as body. The extents are compared with VisitExpr.
+   *
+   * \note The loops are pushed onto loop_stack_lhs_ / loop_stack_rhs_ and are not popped here;
+   *       the BlockRealize visitor indexes these stacks by position.
+   */
+  bool VisitStmt_(const tir::ForNode* op, const Stmt& other) final {
+    const auto* rhs = other.as<ForNode>();
+    // Defensive guard, consistent with the Add/Mul/Call visitors of this class.
+    if (rhs == nullptr) return false;
+    loop_stack_lhs_.push_back(GetRef<For>(op));
+    loop_stack_rhs_.push_back(GetRef<For>(rhs));
+    // The body of loop must be loop or BlockRealize
+    if (!op->body->IsInstance<BlockRealizeNode>() && !op->body->IsInstance<ForNode>()) {
+      return false;
+    }
+    if (!rhs->body->IsInstance<BlockRealizeNode>() && !rhs->body->IsInstance<ForNode>()) {
+      return false;
+    }
+    // Build mapping between the loop vars
+    if (!DefEqual(op->loop_var, rhs->loop_var)) return false;
+    // Only handle the case where the loop start from 0
+    if (!is_zero(op->min) || !is_zero(rhs->min)) return false;
+    if (op->thread_binding.defined() || rhs->thread_binding.defined()) return false;
+    if (op->kind != ForKind::kSerial || op->kind != rhs->kind) return false;
+    if (!op->annotations.empty() || !rhs->annotations.empty()) return false;
+    // Match the extents of loops
+    if (!VisitExpr(op->extent, rhs->extent)) return false;
+    return VisitStmt(op->body, rhs->body);
+  }
+
+  /*!
+   * \brief Match a block of the candidate function (op) against a pattern block (other).
+   *
+   * Iter vars, write regions and read regions must match in order, neither block may allocate
+   * buffers, both bodies must be a BufferStore, and the init statements must either both be
+   * absent or match each other.
+   */
+  bool VisitStmt_(const tir::BlockNode* op, const Stmt& other) final {
+    const auto* rhs = other.as<BlockNode>();
+    // Defensive guard, consistent with the Add/Mul/Call visitors of this class.
+    if (rhs == nullptr) return false;
+    // Check block equality.
+    // All iter vars and buffer regions including the order should match.
+    // When checking iter vars, DefEqual is used to remap variables.
+    if (!CompareArray(op->iter_vars, rhs->iter_vars, &ForMatcher::CompareIterVar)) {
+      return false;
+    }
+    // disallow alloc buffers inside the block
+    if (!op->alloc_buffers.empty() || !rhs->alloc_buffers.empty()) return false;
+    if (!CompareArray(op->writes, rhs->writes, &ForMatcher::CompareBufferRegion)) {
+      return false;
+    }
+    if (!CompareArray(op->reads, rhs->reads, &ForMatcher::CompareBufferRegion)) {
+      return false;
+    }
+    // The body of the block has to be BufferStore
+    if (!op->body->IsInstance<BufferStoreNode>() || !rhs->body->IsInstance<BufferStoreNode>()) {
+      return false;
+    }
+    // Handle init block
+    if (op->init.defined() && !rhs->init.defined()) return false;
+    if (!op->init.defined() && rhs->init.defined()) return false;
+    if (op->init.defined() && rhs->init.defined()) {
+      if (!VisitStmt(op->init.value(), rhs->init.value())) return false;
+    }
+    return VisitStmt(op->body, rhs->body);
+  }
+
+ /*!
+  * \brief Compare two BlockRealize nodes.
+  *
+  * Only trivial bindings are accepted: the i-th iter value must be exactly the
+  * loop variable of the i-th loop recorded on the corresponding loop stack.
+  * Predicates other than constant one are rejected.
+  */
+ bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) final {
+   const auto* rhs = other.as<BlockRealizeNode>();
+   // A block with more iter values than recorded loops cannot have trivial bindings;
+   // reject it here instead of indexing past the end of the loop stacks below.
+   if (op->iter_values.size() > loop_stack_lhs_.size() ||
+       rhs->iter_values.size() > loop_stack_rhs_.size()) {
+     return false;
+   }
+   // Only allow trivial bindings
+   for (size_t i = 0; i < op->iter_values.size(); ++i) {
+     if (!op->iter_values[i].same_as(loop_stack_lhs_[i]->loop_var)) return false;
+   }
+   for (size_t i = 0; i < rhs->iter_values.size(); ++i) {
+     if (!rhs->iter_values[i].same_as(loop_stack_rhs_[i]->loop_var)) return false;
+   }
+   // Disallow predicates now
+   if (!is_one(op->predicate) || !is_one(rhs->predicate)) return false;
+   return VisitStmt(op->block, rhs->block);
+ }
+
+ /*! \brief Two stores match when buffer and indices match and the stored values match. */
+ bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) {
+   const auto* rhs = other.as<BufferStoreNode>();
+   // Compare the access (buffer + indices) first; the value is only visited if it matches.
+   if (!CompareBufferAccess(op, rhs)) return false;
+   return VisitExpr(op->value, rhs->value);
+ }
+
+ /*! \brief Two loads match when they access matching buffers with matching indices. */
+ bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) {
+   const auto* rhs_load = other.as<BufferLoadNode>();
+   const bool same_access = CompareBufferAccess(op, rhs_load);
+   return same_access;
+ }
+
+ /*!
+  * \brief Compare two buffers, recording the rhs -> lhs association on success.
+  *
+  * A buffer already present in rhs_buffer_map_ must map to exactly lhs. Otherwise
+  * the shapes are compared dimension by dimension, the data vars are bound with
+  * DefEqual, and dtype and scope must agree.
+  */
+ bool CompareBuffer(const Buffer& lhs, const Buffer& rhs) {
+   if (lhs.same_as(rhs)) return true;
+   // rhs was seen before: it has to be associated with this very lhs buffer.
+   auto known = rhs_buffer_map_.find(rhs);
+   if (known != rhs_buffer_map_.end()) {
+     return (*known).second.same_as(lhs);
+   }
+   // First encounter: compare the shapes.
+   const size_t ndim = lhs->shape.size();
+   if (ndim != rhs->shape.size()) return false;
+   for (size_t dim = 0; dim < ndim; ++dim) {
+     if (!VisitExpr(lhs->shape[dim], rhs->shape[dim])) return false;
+   }
+   // Remap both the buffer itself and the buffer data.
+   const bool matched =
+       DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype && lhs.scope() == rhs.scope();
+   if (matched) {
+     rhs_buffer_map_[rhs] = lhs;
+   }
+   return matched;
+ }
+
+ /*! \brief Two buffer regions match when their buffers match and their ranges match pairwise. */
+ bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) {
+   // The ranges are only compared once the buffers have matched.
+   return CompareBuffer(lhs->buffer, rhs->buffer) &&
+          CompareArray(lhs->region, rhs->region, &ForMatcher::CompareRange);
+ }
+
+ /*!
+  * \brief Compare two buffer accesses of the same node type T.
+  * \note T has to expose `buffer` and `indices`; both are read below.
+  */
+ template <typename T>
+ bool CompareBufferAccess(const T* lhs, const T* rhs) {
+   // The indices are only compared once the buffers have matched.
+   return CompareBuffer(lhs->buffer, rhs->buffer) &&
+          CompareArray(lhs->indices, rhs->indices, &ForMatcher::VisitExpr);
+ }
+
+ /*!
+  * \brief Compare two arrays element-wise with a member-function comparator.
+  * \param lhs The first array.
+  * \param rhs The second array.
+  * \param cmp Pointer to the member of Self used to compare one pair of elements.
+  * \return true if both arrays are the same object, or have equal length and
+  *         every pair compares true; the comparison stops at the first mismatch.
+  */
+ template <typename T, typename Self, typename F>
+ bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F Self::*cmp) {
+   if (lhs.same_as(rhs)) return true;
+   const size_t count = lhs.size();
+   if (count != rhs.size()) return false;
+   Self* self = static_cast<Self*>(this);
+   size_t idx = 0;
+   while (idx < count && (self->*cmp)(lhs[idx], rhs[idx])) {
+     ++idx;
+   }
+   return idx == count;
+ }
+
+ arith::Analyzer analyzer_;
+ std::vector<For> loop_stack_lhs_, loop_stack_rhs_;
+ tir::PrimFunc pattern_;
+ std::unordered_set<Var, ObjectHash, ObjectEqual> pattern_vars_;
+};
+
+/*! \brief Analyze the function and match it with a list of patterns */
+class TIRPatternMatcher {
+ public:
+  /*!
+   * \brief Match the leading loop nests of `body` against `patterns`.
+   * \param patterns The candidate patterns, tried in order for every loop nest.
+   * \param body A single For, or a SeqStmt whose leading statements are For loops.
+   * \return One MatchResult per matched leading loop nest; empty if `body` has
+   *         another shape or if nothing matched.
+   */
+  static Array<MatchResult> Match(Array<TIRPattern> patterns, Stmt body) {
+    TIRPatternMatcher matcher(patterns);
+    matcher.OpPatternMatch(body);
+    if (matcher.fail_) return {};
+    return matcher.match_results_;
+  }
+
+ private:
+  explicit TIRPatternMatcher(Array<TIRPattern> patterns) : patterns_(patterns) {}
+
+  /*!
+   * \brief Find the first pattern that matches the loop nest rooted at `top`.
+   * \return true (after appending a MatchResult) if some pattern matched.
+   */
+  bool BlockPatternMatch(const For& top) {
+    for (const TIRPattern& pattern : patterns_) {
+      tir::PrimFunc pattern_func = pattern;
+      // Params that come after the buffer params are handed to ForMatcher as symbolic vars.
+      Array<Var> pattern_symbolic_vars;
+      const int buffer_count = pattern_func->buffer_map.size();
+      const int param_count = static_cast<int>(pattern_func->params.size());
+      for (int i = buffer_count; i < param_count; i++) {
+        pattern_symbolic_vars.push_back(pattern_func->params[i]);
+      }
+      ForMatcher block_matcher(pattern_func, pattern_symbolic_vars);
+      if (block_matcher.Match(top)) {
+        // We have found a match: collect the value evaluated for every symbolic param.
+        Array<PrimExpr> symbol_values;
+        for (int i = buffer_count; i < param_count; i++) {
+          symbol_values.push_back(block_matcher.evaluated_symbols.back()[pattern_func->params[i]]);
+        }
+        match_results_.push_back(
+            MatchResult(pattern, symbol_values, block_matcher.evaluated_buffers));
+        return true;
+      }
+    }
+    // The block fails to match any pattern
+    return false;
+  }
+
+  /*!
+   * \brief For each loop nest in the body, try to find its corresponding pattern one by one.
+   *
+   * Matching stops at the first statement that is not a For or that matches no
+   * pattern. fail_ is set when the body has an unsupported shape or nothing matched.
+   */
+  void OpPatternMatch(const Stmt& body) {
+    Array<Stmt> blocks;
+    if (body->IsInstance<ForNode>()) {
+      // {for}
+      blocks = {body};
+    } else if (const SeqStmtNode* seq = body.as<SeqStmtNode>()) {
+      blocks = seq->seq;
+    } else {
+      fail_ = true;
+      return;
+    }
+    for (const Stmt& stmt : blocks) {
+      const ForNode* loop = stmt.as<ForNode>();
+      if (loop == nullptr || !BlockPatternMatch(GetRef<For>(loop))) {
+        break;
+      }
+    }
+    if (match_results_.empty()) {
+      fail_ = true;
+    }
+  }
+  /*! \brief Indicate whether we fail to match.*/
+  bool fail_ = false;
+  /*! \brief The patterns we match the target stmt to.*/
+  Array<TIRPattern> patterns_;
+  /*! \brief The results of the matching process.*/
+  Array<MatchResult> match_results_;
+};
+
+/*!
+ * \brief Visitor that assigns every visited block to one of two parts: the first
+ * `num_matched_ops` blocks form the first function, all later blocks the second.
+ * The collected sets are later used to construct the two partitioned functions.
+ */
+class FunctionPartitioner : public StmtExprVisitor {
+ public:
+  explicit FunctionPartitioner(int num_matched_ops) : num_matched_ops_(num_matched_ops) {}
+  /*! \brief Buffers written by blocks of the first function */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocs1;
+  /*! \brief Buffers written by blocks of the second function */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocs2;
+  /*! \brief Maps each visited block to whether it belongs to the first function */
+  Map<Block, Bool> block_partition;
+  /*! \brief Buffers read by blocks of the first function */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> input1;
+  /*! \brief Buffers read or written by blocks of the second function */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> input2;
+  /*! \brief The buffer most recently written by a block of the first function; it is the
+   * output of the first function and the input of the second one */
+  Buffer intermediate_buffer;
+  /*! \brief Set when a block of the second function writes a buffer already written by the
+   * first one. The caller then keeps the original function unsplit. */
+  bool fail = false;
+
+ private:
+  void VisitStmt_(const BlockNode* op) final {
+    ++block_counter_;
+    const bool in_first_part = block_counter_ <= num_matched_ops_;
+    if (block_counter_ == num_matched_ops_) {
+      // On reaching the last matched block, drop the previously recorded
+      // intermediate buffer from the first function's writes.
+      allocs1.erase(intermediate_buffer);
+    }
+    auto& read_set = in_first_part ? input1 : input2;
+    for (const auto& read : op->reads) {
+      read_set.insert(read->buffer);
+    }
+    for (const auto& write : op->writes) {
+      if (in_first_part) {
+        allocs1.insert(write->buffer);
+        intermediate_buffer = write->buffer;
+        continue;
+      }
+      if (allocs1.count(write->buffer)) {
+        // The second part overwrites a buffer produced by the first part: give up.
+        fail = true;
+        return;
+      }
+      allocs2.insert(write->buffer);
+      input2.insert(write->buffer);
+    }
+    block_partition.Set(GetRef<Block>(op), Bool(in_first_part));
+  }
+  // The number of matched ops in the function
+  size_t num_matched_ops_;
+  // The number of blocks visited so far
+  size_t block_counter_ = 0;
+};
+
+/*!
+ * \brief Mutator that keeps only the blocks belonging to one side of a partition
+ * and filters every block's alloc_buffers down to a given set.
+ */
+class BlockRemover : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Run the remover on a statement.
+   * \param stmt The statement to rewrite.
+   * \param block_partition Maps each non-root block to whether it is in the library part.
+   * \param allocs The buffers allowed to stay in alloc_buffers.
+   * \param is_library_part Which side of the partition to keep.
+   */
+  static Stmt RemoveBlockByPartition(
+      Stmt stmt, const Map<Block, Bool>& block_partition,
+      const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& allocs,
+      bool is_library_part) {
+    BlockRemover block_remover(block_partition, allocs, is_library_part);
+    return block_remover(stmt);
+  }
+
+ private:
+  BlockRemover(const Map<Block, Bool>& block_partition,
+               const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& allocs,
+               bool is_library_part)
+      : block_partition(block_partition), allocs_(allocs), is_library_part_(is_library_part) {}
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block visited = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+    ObjectPtr<BlockNode> node = make_object<BlockNode>(*visited.operator->());
+    if (op->name_hint != "root") {
+      ICHECK(block_partition.count(GetRef<Block>(op)));
+      const bool block_is_library = block_partition[GetRef<Block>(op)]->value;
+      // A block on the other side of the partition is flagged; the enclosing
+      // SeqStmt visitor drops the statement that contains it.
+      if (is_library_part_ != block_is_library) {
+        erased_ = true;
+      }
+    }
+    // Keep only the allocations that belong to this side.
+    Array<Buffer> kept_allocs;
+    for (const Buffer& buf : visited->alloc_buffers) {
+      if (allocs_.count(buf)) {
+        kept_allocs.push_back(buf);
+      }
+    }
+    node->alloc_buffers = kept_allocs;
+    return Block(node);
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    Array<Stmt> kept;
+    for (const Stmt& child : op->seq) {
+      Stmt new_child = VisitStmt(child);
+      if (erased_) {
+        // The child contained a flagged block: reset the flag and skip the child.
+        erased_ = false;
+        continue;
+      }
+      kept.push_back(new_child);
+    }
+    return SeqStmt::Flatten(kept);
+  }
+
+  bool erased_ = false;
+  Map<Block, Bool> block_partition;
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocs_;
+  bool is_library_part_ = false;
+};
+
+/*!
+ * \brief Split the input function into two functions, one for the library kernel and one for the
+ * rest.
+ * \param func The input function.
+ * \param arg_partition Output. When the function is actually split, two index lists are
+ * appended: the positions in func->params used by the first and by the second function.
+ * \param patterns The patterns to match.
+ * \param f_codegen The function to generate the code for the library kernel.
+ * \return A pair of functions, the first one is the library kernel and the second one is the
+ * rest. The second one is NullOpt when no split happened; in that case the first one is either
+ * the unchanged input or the input annotated with the library code.
+ */
+std::pair<PrimFunc, Optional<PrimFunc>> SplitFunctions(PrimFunc func,
+                                                       std::vector<std::vector<int>>* arg_partition,
+                                                       Array<TIRPattern> patterns,
+                                                       FCodegen f_codegen) {
+  // Step 1. Find the library kernel and the rest.
+  // Only functions whose body is a root BlockRealize can be analyzed; leave others untouched.
+  const auto* root_realize = func->body.as<BlockRealizeNode>();
+  if (root_realize == nullptr) {
+    return {func, NullOpt};
+  }
+  Stmt body = root_realize->block->body;
+  Array<MatchResult> match_results = TIRPatternMatcher::Match(patterns, body);
+  if (match_results.empty()) {
+    return {func, NullOpt};
+  }
+  // f_codegen returns (library source, number of matched ops, buffer args of the kernel).
+  Array<ObjectRef> codegen_result = f_codegen(match_results);
+  ICHECK(codegen_result.size() == 3);
+  String library_code = Downcast<String>(codegen_result[0]);
+  int num_matched_ops = Downcast<Integer>(codegen_result[1])->value;
+  Array<Buffer> func1_args = Downcast<Array<Buffer>>(codegen_result[2]);
+  if (num_matched_ops == 0) {
+    return {func, NullOpt};
+  }
+  FunctionPartitioner partitioner(num_matched_ops);
+  partitioner(body);
+  if (partitioner.fail) {
+    return {func, NullOpt};
+  }
+  bool has_second_func = false;
+  for (const auto& pr : partitioner.block_partition) {
+    if (!pr.second->value) {
+      has_second_func = true;
+      break;
+    }
+  }
+  if (!has_second_func) {
+    // No need to split the function.
+    return {WithAttr(func, kLibraryKernel, library_code), NullOpt};
+  }
+  // Step 2. Split the function into two functions.
+  Stmt body1 = BlockRemover::RemoveBlockByPartition(func->body, partitioner.block_partition,
+                                                    partitioner.allocs1, true);
+  Stmt body2 = BlockRemover::RemoveBlockByPartition(func->body, partitioner.block_partition,
+                                                    partitioner.allocs2, false);
+  // Step 3. Craft the first function.
+  Array<Var> new_params1;
+  std::vector<int> arg_partition1;
+  ICHECK_LE(func1_args.size(), partitioner.input1.size());
+  for (const auto& buffer : func1_args) {
+    ICHECK(partitioner.input1.find(buffer) != partitioner.input1.end());
+    for (size_t i = 0; i < func->params.size(); i++) {
+      Var param = func->params[i];
+      // Params without an entry in buffer_map cannot be the buffer we look for;
+      // skip them instead of indexing the map with a missing key.
+      if (!func->buffer_map.count(param)) continue;
+      if (func->buffer_map[param].same_as(buffer)) {
+        new_params1.push_back(param);
+        arg_partition1.push_back(i);
+        break;
+      }
+    }
+  }
+  arg_partition->push_back(arg_partition1);
+  // The intermediate buffer becomes the trailing "output" param of the first function.
+  new_params1.push_back(Var("output", DataType::Handle()));
+  Map<Var, Buffer> new_buffer_map1;
+  for (const auto& kv : func->buffer_map) {
+    if (partitioner.input1.count(kv.second)) {
+      new_buffer_map1.Set(kv.first, kv.second);
+    }
+  }
+  new_buffer_map1.Set(new_params1.back(), partitioner.intermediate_buffer);
+  PrimFunc func1 = PrimFunc(new_params1, body1, func->ret_type, new_buffer_map1, func->attrs);
+  func1 = WithAttr(func1, kLibraryKernel, library_code);
+  // Step 4. Craft the second function.
+  Array<Var> new_params2;
+  std::vector<int> arg_partition2;
+  // The intermediate buffer becomes the leading "input" param of the second function.
+  new_params2.push_back(Var("input", DataType::Handle()));
+  for (int i = 0; i < static_cast<int>(func->params.size()); i++) {
+    Var param = func->params[i];
+    // Same guard as in step 3: ignore params that have no buffer.
+    if (!func->buffer_map.count(param)) continue;
+    if (partitioner.input2.count(func->buffer_map[param])) {
+      new_params2.push_back(param);
+      // The last param of func is not recorded as a call argument.
+      if (i != static_cast<int>(func->params.size()) - 1) {
+        arg_partition2.push_back(i);
+      }
+    }
+  }
+  arg_partition->push_back(arg_partition2);
+  Map<Var, Buffer> new_buffer_map2;
+  new_buffer_map2.Set(new_params2[0], partitioner.intermediate_buffer);
+  for (const auto& kv : func->buffer_map) {
+    if (partitioner.input2.count(kv.second)) {
+      new_buffer_map2.Set(kv.first, kv.second);
+    }
+  }
+  PrimFunc func2 = PrimFunc(new_params2, body2, func->ret_type, new_buffer_map2, func->attrs);
+  return {func1, func2};
+}
+} // namespace tir
+
+namespace relax {
+/*!
+ * \brief Replace every occurrence of `search` in `*subject` with `replace`, in place.
+ * \param subject The string to modify.
+ * \param search The substring to look for. An empty `search` leaves `*subject` unchanged.
+ * \param replace The replacement text; it is never rescanned for further matches.
+ */
+void StringReplace(std::string* subject, const std::string& search, const std::string& replace) {
+  // find("") succeeds at every position, so an empty needle would never terminate.
+  if (search.empty()) return;
+  for (size_t pos = subject->find(search); pos != std::string::npos;
+       pos = subject->find(search, pos + replace.length())) {
+    subject->replace(pos, search.length(), replace);
+  }
+}
+
+/*!
+ * \brief Turn a PrimFunc that carries library source code into an ExternFunc.
+ * \param pf The PrimFunc to inspect.
+ * \param global_symbol Name of the resulting ExternFunc; it also replaces every
+ * "{global_symbol}" placeholder in the library source.
+ * \return The PrimFunc itself if it has no kLibraryKernel attribute, otherwise an
+ * ExternFunc with the kCSource and kCSourceFmt attributes set.
+ */
+tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, String global_symbol) {
+  using namespace tvm::tir;
+  Optional<runtime::String> library_code = pf->attrs.GetAttr<runtime::String>(kLibraryKernel);
+  if (library_code.defined()) {
+    // Instantiate the source template with the concrete symbol name.
+    std::string source = library_code.value();
+    StringReplace(&source, "{global_symbol}", global_symbol);
+    Map<String, ObjectRef> extern_attrs{
+        {String(kCSource), String(source)},
+        {String(kCSourceFmt), String(kCSourceFmtCuda)},
+    };
+    ExternFunc ret = WithAttrs(ExternFunc(global_symbol), extern_attrs);
+    return ret;
+  }
+  // Nothing to generate: hand the function back unchanged.
+  return GetRef<tir::PrimFunc>(pf);
+}
+
+/*!
+ * \brief Rewrite call_tir sites whose PrimFunc matches the patterns: emit a
+ * call_dps_packed to the generated library kernel and, when part of the PrimFunc
+ * is left over, a second call_tir to the rest of the function.
+ */
+class SplitMutator : public ExprMutator {
+ public:
+  SplitMutator(const tvm::IRModule& mod, Array<TIRPattern> patterns, FCodegen fcodegen)
+      : ExprMutator(mod), mod_(mod), patterns_(patterns), fcodegen_(fcodegen) {}
+  /*!
+   * \brief Apply the mutator to every Relax function of the module.
+   * \return The module held by the mutator's block builder after all updates.
+   */
+  static IRModule Transform(const IRModule& mod, Array<TIRPattern> patterns, FCodegen fcodegen) {
+    SplitMutator mutator(mod, patterns, fcodegen);
+    for (auto& kv : mod->functions) {
+      if (auto* func = kv.second.as<FunctionNode>()) {
+        Function new_func = Downcast<Function>(mutator(GetRef<Function>(func)));
+        mutator.builder_->UpdateFunction(kv.first, new_func);
+      }
+    }
+    return mutator.builder_->GetContextIRModule();
+  }
+
+ private:
+  using ExprMutator::VisitExpr_;
+
+  /*! \brief Return the fields of a tuple argument, or a one-element array otherwise. */
+  inline Array<Expr> GetCallTIRArgs(Expr args) {
+    if (args.as<TupleNode>()) {
+      return args.as<TupleNode>()->fields;
+    } else {
+      return {args};
+    }
+  }
+
+  Expr VisitExpr_(const CallNode* op) final {
+    Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+    // Both ops are looked up once here; no separate member copy is kept.
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    static const Op& call_dps_packed_ = Op::Get("relax.call_dps_packed");
+    if (!call->op.same_as(call_tir_op_)) return call;
+    // the first argument is the function to be called
+    const auto* gv_ptr = call->args[0].as<GlobalVarNode>();
+    if (gv_ptr == nullptr) return call;
+    GlobalVar gv = GetRef<GlobalVar>(gv_ptr);
+    // retrieve the function from the module and split it
+    tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    std::vector<std::vector<int>> arg_partition;
+    // split the function into two functions, one for the library kernel and one for the rest.
+    std::pair<tir::PrimFunc, Optional<tir::PrimFunc>> split_funcs =
+        tir::SplitFunctions(func, &arg_partition, patterns_, fcodegen_);
+    if (!split_funcs.second.defined()) {
+      // no need to split, the function itself a library kernel
+      tvm::BaseFunc lib_func = CodegenWithLibrary(split_funcs.first.get(), gv->name_hint);
+      // No library code attached: keep the call, including any rewritten children.
+      if (lib_func->IsInstance<tir::PrimFuncNode>()) return call;
+      // Update the function in the module with the library kernel
+      ICHECK(lib_func->IsInstance<ExternFuncNode>());
+      builder_->UpdateFunction(gv, lib_func);
+      // emit the call to the library kernel
+      ObjectPtr<CallNode> new_call = make_object<CallNode>(*call.operator->());
+      new_call->op = call_dps_packed_;
+      new_call->args = {lib_func, call->args[1]};
+      return Call(new_call);
+    }
+    tir::PrimFunc func1 = tir::RenewDefs(split_funcs.first);
+    tir::PrimFunc func2 = tir::RenewDefs(split_funcs.second.value());
+    ICHECK(arg_partition.size() == 2);
+    // The original call arguments, unpacked once and indexed by both partitions.
+    const Array<Expr> call_args = GetCallTIRArgs(call->args[1]);
+    // emit the first call to the library kernel
+    Array<Expr> args1;
+    for (int p : arg_partition[0]) {
+      args1.push_back(call_args[p]);
+    }
+    // replace the function in the module with the library kernel
+    tvm::BaseFunc lib_func = CodegenWithLibrary(func1.get(), gv->name_hint);
+    // No library code attached: keep the call, including any rewritten children.
+    if (lib_func->IsInstance<tir::PrimFuncNode>()) return call;
+    ICHECK(lib_func->IsInstance<ExternFuncNode>());
+    builder_->UpdateFunction(gv, lib_func);
+    // The last param of func1 is bound to the intermediate buffer; its shape and
+    // dtype give the struct info of the first call's result.
+    tir::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back());
+    DataType dtype = intermediate_buffer->dtype;
+    Call call1(call_dps_packed_, {lib_func, Tuple(args1)}, call->attrs,
+               {TensorStructInfo(ShapeExpr(intermediate_buffer->shape), dtype)});
+    Var call_var1 = builder_->Emit(call1);
+    // emit the second call to the rest of the function
+    Array<Expr> args2;
+    args2.push_back(call_var1);
+    for (int p : arg_partition[1]) {
+      args2.push_back(call_args[p]);
+    }
+    GlobalVar gv2 = builder_->AddFunction(func2, "unfused_epilogue");
+    Call call2(call_tir_op_, {gv2, Tuple(args2)}, call->attrs, call->sinfo_args);
+    builder_->UpdateFunction(gv, WithoutAttr(func, "global_symbol"));
+    return call2;
+  }
+
+  /*! \brief The module the PrimFuncs are looked up in. */
+  tvm::IRModule mod_;
+  /*! \brief The patterns handed to tir::SplitFunctions. */
+  Array<TIRPattern> patterns_;
+  /*! \brief The codegen callback handed to tir::SplitFunctions. */
+  FCodegen fcodegen_;
+};
+
+namespace transform {
+/*!
+ * \brief Create the module pass that runs SplitMutator::Transform.
+ * \param patterns The TIR patterns to match.
+ * \param fcodegen The callback that generates the library code for the matches.
+ * \return A module pass named "SplitCallTIRByPattern" with opt level 0.
+ */
+Pass SplitCallTIRByPattern(Array<TIRPattern> patterns, FCodegen fcodegen) {
+  // Both arguments are captured by value so the pass owns its own references.
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [patterns, fcodegen](IRModule mod, PassContext ctx) {
+        return SplitMutator::Transform(mod, patterns, fcodegen);
+      };
+  return CreateModulePass(/*pass_function=*/pass_func,
+                          /*opt_level=*/0,
+                          /*pass_name=*/"SplitCallTIRByPattern",
+                          /*required=*/{});
+}
+TVM_REGISTER_GLOBAL("relax.transform.SplitCallTIRByPattern").set_body_typed(SplitCallTIRByPattern);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_codegen_tir_cutlass.py
b/tests/python/relax/test_codegen_tir_cutlass.py
new file mode 100644
index 0000000000..9c960ed355
--- /dev/null
+++ b/tests/python/relax/test_codegen_tir_cutlass.py
@@ -0,0 +1,709 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+import tempfile
+
+from tvm import relax, runtime
+import tvm
+import tvm.testing
+from tvm import relax
+import scipy
+from scipy.special import erf
+import numpy as np
+from tvm.target import Target
+from tvm.relax.vm_build import build as relax_build
+from tvm.script.ir_builder import relax as R
+from tvm.script.ir_builder import ir as I
+from tvm.script.ir_builder import tir as T
+from tvm.script.ir_builder import IRBuilder
+
+from tvm.relax.backend_tir import get_tir_pattern
+from tvm.relax.backend_tir.contrib.cutlass import cutlass_fcodegen,
compile_options
+
+A_TYPE = "float16"
+B_TYPE = "float16"
+C_TYPE = "float16"
+
+target = Target("cuda")
+
+
+def f_run(rt_mod: runtime.Module, device: runtime.ndarray.Device, *input):
+ vm = relax.vm.VirtualMachine(rt_mod=rt_mod, device=device)
+ return vm["main"](*input)
+
+
+def build(mod):
+ mod = relax.transform.LegalizeOps()(mod)
+ mod = relax.transform.AnnotateTIROpPattern()(mod)
+ mod = relax.transform.FuseOps()(mod)
+ mod = relax.transform.FuseTIR()(mod)
+ mod = relax.transform.SplitCallTIRByPattern(get_tir_pattern(),
cutlass_fcodegen())(mod)
+ mod = relax.transform.DeadCodeElimination()(mod)
+ print(mod.script())
+ f = tempfile.NamedTemporaryFile(suffix=".so", delete=True)
+ executable = relax_build(mod, target)
+
+ executable.mod.export_library(f.name, **compile_options(target))
+ rt_mod = runtime.load_module(f.name)
+ f.close()
+ return rt_mod
+
+
+def build_and_run_reference(mod, inputs_np):
+ mod = relax.transform.LegalizeOps()(mod)
+ dev = tvm.device("llvm", 0)
+ ex = relax.build(mod, "llvm")
+ vm = relax.VirtualMachine(ex, dev)
+ f = vm["main"]
+ inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
+ return f(*inputs).numpy()
+
+
+def constructGEMM(M, N, K):
+ """Build a module whose Relax function "main" returns matmul(A, B).
+
+ A is declared as (M, K) with dtype A_TYPE and B as (K, N) with dtype B_TYPE;
+ the matmul is emitted with out_dtype=C_TYPE. Returns the result of ib.get().
+ """
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ with I.ir_module() as frame:
+ with R.function():
+ R.func_name("main")
+ A = R.arg(
+ "A", relax.TensorStructInfo((M, K), A_TYPE)
+ ) # pylint: disable=invalid-name
+ B = R.arg(
+ "B", relax.TensorStructInfo((K, N), B_TYPE)
+ ) # pylint: disable=invalid-name
+ with R.dataflow() as df:
+ C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+ R.output(C)
+ (C,) = df.output_vars
+ R.func_ret_value(C)
+ relax_mod = ib.get()
+ return relax_mod
+
+
[email protected]_cutlass
+def test_cutlass_dense():
+ m, n, k = 128, 64, 256
+ executable = build(constructGEMM(m, n, k))
+ dev = tvm.cuda()
+ A = np.random.randn(m, k).astype("float16")
+ B = np.random.randn(k, n).astype("float16")
+ A_tvm = tvm.nd.array(A, dev)
+ B_tvm = tvm.nd.array(B, dev)
+ result = f_run(executable, dev, A_tvm, B_tvm)
+ np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
+
+
+def constructGEMM_bias(M, N, K):
+ """Build a module whose Relax function "main" returns matmul(A, B) + bias.
+
+ A is (M, K), B is (K, N) and bias is declared with shape (1, N) and dtype A_TYPE.
+ Returns the result of ib.get().
+ """
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ with I.ir_module() as frame:
+ with R.function():
+ R.func_name("main")
+ A = R.arg(
+ "A", relax.TensorStructInfo((M, K), A_TYPE)
+ ) # pylint: disable=invalid-name
+ B = R.arg(
+ "B", relax.TensorStructInfo((K, N), B_TYPE)
+ ) # pylint: disable=invalid-name
+ bias = R.arg(
+ "bias", relax.TensorStructInfo((1, N), A_TYPE)
+ ) # pylint: disable=invalid-name
+ with R.dataflow() as df:
+ C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+ D = R.emit(R.add(C, bias))
+ R.output(D)
+ (D,) = df.output_vars
+ R.func_ret_value(D)
+ relax_mod = ib.get()
+ return relax_mod
+
+
+def constructGEMM_bias2(M, N, K):
+ """Build a module whose Relax function "main" returns matmul(A, B) + bias.
+
+ Same as the (1, N)-bias variant except that bias is declared with the
+ one-dimensional shape (N,). Returns the result of ib.get().
+ """
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ with I.ir_module() as frame:
+ with R.function():
+ R.func_name("main")
+ A = R.arg(
+ "A", relax.TensorStructInfo((M, K), A_TYPE)
+ ) # pylint: disable=invalid-name
+ B = R.arg(
+ "B", relax.TensorStructInfo((K, N), B_TYPE)
+ ) # pylint: disable=invalid-name
+ bias = R.arg(
+ "bias", relax.TensorStructInfo((N,), A_TYPE)
+ ) # pylint: disable=invalid-name
+ with R.dataflow() as df:
+ C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+ D = R.emit(R.add(C, bias))
+ R.output(D)
+ (D,) = df.output_vars
+ R.func_ret_value(D)
+ relax_mod = ib.get()
+ return relax_mod
+
+
[email protected]_cutlass
+def test_cutlass_dense_bias():
+ m, n, k = 128, 64, 256
+ executable = build(constructGEMM_bias(m, n, k))
+ dev = tvm.cuda()
+ A = np.random.randn(m, k).astype("float16")
+ B = np.random.randn(k, n).astype("float16")
+ bias = np.random.randn(1, n).astype("float16")
+ A_tvm = tvm.nd.array(A, dev)
+ B_tvm = tvm.nd.array(B, dev)
+ bias_tvm = tvm.nd.array(bias, dev)
+ result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+ np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2,
atol=5e-2)
+
+
[email protected]_cutlass
+def test_cutlass_dense_bias2():
+ m, n, k = 128, 64, 256
+ executable = build(constructGEMM_bias2(m, n, k))
+ dev = tvm.cuda()
+ A = np.random.randn(m, k).astype("float16")
+ B = np.random.randn(k, n).astype("float16")
+ bias = np.random.randn(n).astype("float16")
+ A_tvm = tvm.nd.array(A, dev)
+ B_tvm = tvm.nd.array(B, dev)
+ bias_tvm = tvm.nd.array(bias, dev)
+ result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+ np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2,
atol=5e-2)
+
+
+def constructGEMM_bias_relu(M, N, K):
+ """Build a module whose Relax function "main" returns relu(matmul(A, B) + bias).
+
+ A is (M, K), B is (K, N) and bias is declared with shape (1, N).
+ Returns the result of ib.get().
+ """
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ with I.ir_module() as frame:
+ with R.function():
+ R.func_name("main")
+ A = R.arg(
+ "A", relax.TensorStructInfo((M, K), A_TYPE)
+ ) # pylint: disable=invalid-name
+ B = R.arg(
+ "B", relax.TensorStructInfo((K, N), B_TYPE)
+ ) # pylint: disable=invalid-name
+ bias = R.arg(
+ "bias", relax.TensorStructInfo((1, N), A_TYPE)
+ ) # pylint: disable=invalid-name
+ with R.dataflow() as df:
+ C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+ D = R.emit(R.add(C, bias))
+ E = R.emit(R.nn.relu(D))
+ R.output(E)
+ (E,) = df.output_vars
+ R.func_ret_value(E)
+ relax_mod = ib.get()
+ return relax_mod
+
+
[email protected]_cutlass
+def test_cutlass_dense_bias_relu():
+ m, n, k = 128, 64, 256
+ executable = build(constructGEMM_bias_relu(m, n, k))
+ dev = tvm.cuda()
+ A = np.random.randn(m, k).astype("float16")
+ B = np.random.randn(k, n).astype("float16")
+ bias = np.random.randn(1, n).astype("float16")
+ A_tvm = tvm.nd.array(A, dev)
+ B_tvm = tvm.nd.array(B, dev)
+ bias_tvm = tvm.nd.array(bias, dev)
+ result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+ np.testing.assert_allclose(result.numpy(), np.maximum(A @ B + bias, 0),
rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM(batch, M, N, K):
+ """Build a module whose Relax function "main" returns matmul(A, B) for a batched A.
+
+ A is declared as (batch, M, K) and B as the unbatched (K, N).
+ Returns the result of ib.get().
+ """
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ with I.ir_module() as frame:
+ with R.function():
+ R.func_name("main")
+ A = R.arg(
+ "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+ ) # pylint: disable=invalid-name
+ B = R.arg(
+ "B", relax.TensorStructInfo((K, N), B_TYPE)
+ ) # pylint: disable=invalid-name
+ with R.dataflow() as df:
+ C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+ R.output(C)
+ (C,) = df.output_vars
+ R.func_ret_value(C)
+ relax_mod = ib.get()
+ return relax_mod
+
+
[email protected]_cutlass
+def test_cutlass_batch_dense():
+ b, m, n, k = 2, 128, 256, 64
+ executable = build(constructBatchGEMM(b, m, n, k))
+ dev = tvm.cuda()
+ A = np.random.randn(b, m, k).astype("float16")
+ B = np.random.randn(k, n).astype("float16")
+ A_tvm = tvm.nd.array(A, dev)
+ B_tvm = tvm.nd.array(B, dev)
+ result = f_run(executable, dev, A_tvm, B_tvm)
+ np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM2(batch, M, N, K):
+ """Build a module whose Relax function "main" returns matmul(A, B) with both operands batched.
+
+ A is declared as (batch, M, K) and B as (batch, K, N).
+ Returns the result of ib.get().
+ """
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ with I.ir_module() as frame:
+ with R.function():
+ R.func_name("main")
+ A = R.arg(
+ "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+ ) # pylint: disable=invalid-name
+ B = R.arg(
+ "B", relax.TensorStructInfo((batch, K, N), B_TYPE)
+ ) # pylint: disable=invalid-name
+ with R.dataflow() as df:
+ C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+ R.output(C)
+ (C,) = df.output_vars
+ R.func_ret_value(C)
+ relax_mod = ib.get()
+ return relax_mod
+
+
[email protected]_cutlass
+def test_cutlass_batch_dense2():
+ b, m, n, k = 2, 128, 256, 64
+ executable = build(constructBatchGEMM2(b, m, n, k))
+ dev = tvm.cuda()
+ A = np.random.randn(b, m, k).astype("float16")
+ B = np.random.randn(b, k, n).astype("float16")
+ A_tvm = tvm.nd.array(A, dev)
+ B_tvm = tvm.nd.array(B, dev)
+ result = f_run(executable, dev, A_tvm, B_tvm)
+ np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM_bias(batch, M, N, K):
+ """Build a module whose Relax function "main" returns matmul(A, B) + bias for a batched A.
+
+ A is declared as (batch, M, K), B as (K, N) and bias as (1, N).
+ Returns the result of ib.get().
+ """
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ with I.ir_module() as frame:
+ with R.function():
+ R.func_name("main")
+ A = R.arg(
+ "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+ ) # pylint: disable=invalid-name
+ B = R.arg(
+ "B", relax.TensorStructInfo((K, N), B_TYPE)
+ ) # pylint: disable=invalid-name
+ bias = R.arg(
+ "bias", relax.TensorStructInfo((1, N), A_TYPE)
+ ) # pylint: disable=invalid-name
+ with R.dataflow() as df:
+ C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+ D = R.emit(R.add(C, bias))
+ R.output(D)
+ (D,) = df.output_vars
+ R.func_ret_value(D)
+ relax_mod = ib.get()
+ return relax_mod
+
+
[email protected]_cutlass
+def test_cutlass_batch_dense_bias():
+ b, m, n, k = 2, 128, 256, 64
+ executable = build(constructBatchGEMM_bias(b, m, n, k))
+ dev = tvm.cuda()
+ A = np.random.randn(b, m, k).astype("float16")
+ B = np.random.randn(k, n).astype("float16")
+ bias = np.random.randn(1, n).astype("float16")
+ A_tvm = tvm.nd.array(A, dev)
+ B_tvm = tvm.nd.array(B, dev)
+ bias_tvm = tvm.nd.array(bias, dev)
+ result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+ np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2,
atol=5e-2)
+
+
+def constructBatchGEMM_bias2(batch, M, N, K):
+ """Build a module whose Relax function "main" returns matmul(A, B) + bias for a batched A.
+
+ A is declared as (batch, M, K), B as (K, N) and bias as the one-dimensional (N,).
+ Returns the result of ib.get().
+ """
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ with I.ir_module() as frame:
+ with R.function():
+ R.func_name("main")
+ A = R.arg(
+ "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+ ) # pylint: disable=invalid-name
+ B = R.arg(
+ "B", relax.TensorStructInfo((K, N), B_TYPE)
+ ) # pylint: disable=invalid-name
+ bias = R.arg(
+ "bias", relax.TensorStructInfo((N,), A_TYPE)
+ ) # pylint: disable=invalid-name
+ with R.dataflow() as df:
+ C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+ D = R.emit(R.add(C, bias))
+ R.output(D)
+ (D,) = df.output_vars
+ R.func_ret_value(D)
+ relax_mod = ib.get()
+ return relax_mod
+
+
[email protected]_cutlass
+def test_cutlass_batch_dense_bias2():
+ b, m, n, k = 2, 128, 256, 64
+ executable = build(constructBatchGEMM_bias2(b, m, n, k))
+ dev = tvm.cuda()
+ A = np.random.randn(b, m, k).astype("float16")
+ B = np.random.randn(k, n).astype("float16")
+ bias = np.random.randn(n).astype("float16")
+ A_tvm = tvm.nd.array(A, dev)
+ B_tvm = tvm.nd.array(B, dev)
+ bias_tvm = tvm.nd.array(bias, dev)
+ result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+ np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2,
atol=5e-2)
+
+
+def constructBatchGEMM_bias2_gelu(batch, M, N, K):
+ """Build a module whose Relax function "main" returns gelu(matmul(A, B) + bias).
+
+ A is declared as (batch, M, K), B as (K, N) and bias as (N,).
+ Returns the result of ib.get().
+ """
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ with I.ir_module() as frame:
+ with R.function():
+ R.func_name("main")
+ A = R.arg(
+ "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+ ) # pylint: disable=invalid-name
+ B = R.arg(
+ "B", relax.TensorStructInfo((K, N), B_TYPE)
+ ) # pylint: disable=invalid-name
+ bias = R.arg(
+ "bias", relax.TensorStructInfo((N,), A_TYPE)
+ ) # pylint: disable=invalid-name
+ with R.dataflow() as df:
+ C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+ D = R.emit(R.add(C, bias))
+ E = R.emit(R.nn.gelu(D))
+ R.output(E)
+ (E,) = df.output_vars
+ R.func_ret_value(E)
+ relax_mod = ib.get()
+ return relax_mod
+
+
[email protected]_cutlass
+def test_cutlass_batch_dense_bias2_gelu():
+ b, m, n, k = 2, 128, 64, 256
+ executable = build(constructBatchGEMM_bias2_gelu(b, m, n, k))
+ dev = tvm.cuda()
+ A = np.random.randn(b, m, k).astype("float16")
+ B = np.random.randn(k, n).astype("float16")
+ bias = np.random.randn(n).astype("float16")
+ A_tvm = tvm.nd.array(A, dev)
+ B_tvm = tvm.nd.array(B, dev)
+ bias_tvm = tvm.nd.array(bias, dev)
+ result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+ C = A @ B + bias
+ O = 0.5 * C * (1 + erf(C / np.sqrt(2)))
+ np.testing.assert_allclose(result.numpy(), O, rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM_bias2_mul(batch, M, N, K):
+ """Build a module whose Relax function "main" returns (matmul(A, B) + bias) * residual.
+
+ A is declared as (batch, M, K), B as (K, N), bias as (N,) and residual as
+ (batch, M, N). Returns the result of ib.get().
+ """
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ with I.ir_module() as frame:
+ with R.function():
+ R.func_name("main")
+ A = R.arg(
+ "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+ ) # pylint: disable=invalid-name
+ B = R.arg(
+ "B", relax.TensorStructInfo((K, N), B_TYPE)
+ ) # pylint: disable=invalid-name
+ bias = R.arg(
+ "bias", relax.TensorStructInfo((N,), A_TYPE)
+ ) # pylint: disable=invalid-name
+ residual = R.arg("residual", relax.TensorStructInfo((batch, M,
N), A_TYPE))
+ with R.dataflow() as df:
+ C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+ D = R.emit(R.add(C, bias))
+ E = R.emit(R.multiply(D, residual))
+ R.output(E)
+ (E,) = df.output_vars
+ R.func_ret_value(E)
+ relax_mod = ib.get()
+ return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_batch_dense_bias2_mul():
+    """Check batched matmul + bias + elementwise multiply against NumPy.
+
+    The module from constructBatchGEMM_bias2_mul is passed through `build`
+    and run on tvm.cuda(); the result must match ((A @ B) + bias) * residual
+    within rtol=atol=5e-2.
+    """
+    b, m, n, k = 2, 128, 256, 64
+    executable = build(constructBatchGEMM_bias2_mul(b, m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(b, m, k).astype("float16")
+    B = np.random.randn(k, n).astype("float16")
+    bias = np.random.randn(n).astype("float16")
+    residual = np.random.randn(b, m, n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    bias_tvm = tvm.nd.array(bias, dev)
+    residual_tvm = tvm.nd.array(residual, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, residual_tvm)
+    np.testing.assert_allclose(result.numpy(), ((A @ B) + bias) * residual, rtol=5e-2, atol=5e-2)
+
+
+def constructBatchGEMM2_bias(batch, M, N, K):
+    """Build a Relax module whose "main" returns matmul(A, B) + bias.
+
+    Both operands are batched here: A has shape (batch, M, K) and B has
+    shape (batch, K, N); bias has shape (1, N). The matmul is emitted with
+    out_dtype=C_TYPE.
+    """
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                A = R.arg(
+                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
+                )  # pylint: disable=invalid-name
+                B = R.arg(
+                    "B", relax.TensorStructInfo((batch, K, N), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((1, N), A_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
+                    D = R.emit(R.add(C, bias))
+                    R.output(D)
+                # Re-read the output var from the closed dataflow frame.
+                (D,) = df.output_vars
+                R.func_ret_value(D)
+    relax_mod = ib.get()
+    return relax_mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_batch_dense2_bias():
+    """Check matmul with both operands batched, plus bias, against NumPy.
+
+    The module from constructBatchGEMM2_bias is passed through `build` and
+    run on tvm.cuda(); the result must match A @ B + bias within
+    rtol=atol=5e-2.
+    """
+    b, m, n, k = 2, 128, 256, 64
+    executable = build(constructBatchGEMM2_bias(b, m, n, k))
+    dev = tvm.cuda()
+    A = np.random.randn(b, m, k).astype("float16")
+    B = np.random.randn(b, k, n).astype("float16")
+    bias = np.random.randn(1, n).astype("float16")
+    A_tvm = tvm.nd.array(A, dev)
+    B_tvm = tvm.nd.array(B, dev)
+    bias_tvm = tvm.nd.array(bias, dev)
+    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2)
+
+
+def constructConv2D(N, C, H, W, KH, KW, O, strides, padding, dilation):
+    """Build a Relax module whose "main" returns conv2d(x, w).
+
+    x has shape (N, H, W, C) and w has shape (O, KH, KW, C). The convolution
+    uses data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC",
+    groups=1 and out_dtype=C_TYPE; strides, padding and dilation are
+    forwarded unchanged.
+    """
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import ir as I
+    from tvm.script.ir_builder import relax as R
+    from tvm.script.ir_builder import tir as T
+
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                x = R.arg(
+                    "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
+                )  # pylint: disable=invalid-name
+                w = R.arg(
+                    "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    # NOTE: from here on `C` is the conv2d result, not the
+                    # channel-count parameter.
+                    C = R.emit(
+                        R.nn.conv2d(
+                            x,
+                            w,
+                            strides=strides,
+                            padding=padding,
+                            dilation=dilation,
+                            groups=1,
+                            data_layout="NHWC",
+                            kernel_layout="OHWI",
+                            out_layout="NHWC",
+                            out_dtype=C_TYPE,
+                        )
+                    )
+                    R.output(C)
+                # Re-read the output var from the closed dataflow frame.
+                (C,) = df.output_vars
+                R.func_ret_value(C)
+    mod = ib.get()
+    return mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_conv2d():
+    """Check conv2d over a grid of strides, paddings and dilations.
+
+    For each combination the module from constructConv2D is passed through
+    `build` and run on tvm.cuda(); the result must match
+    build_and_run_reference(mod, [A, B]) within rtol=atol=5e-2.
+    """
+    n, c, h, w = 1, 3, 224, 224
+    kh, kw, o = 3, 3, 64
+    for strides in [(1, 1), (2, 2)]:
+        for padding in [(0, 0), (3, 3)]:
+            for dilation in [(1, 1), (4, 4)]:
+                mod = constructConv2D(n, c, h, w, kh, kw, o, strides, padding, dilation)
+                executable = build(mod)
+                dev = tvm.cuda()
+                # Re-seed per configuration so every case sees the same inputs.
+                np.random.seed(0)
+                A = np.random.randn(n, h, w, c).astype("float16")
+                B = np.random.randn(o, kh, kw, c).astype("float16")
+                A_tvm = tvm.nd.array(A, dev)
+                B_tvm = tvm.nd.array(B, dev)
+                result = f_run(executable, dev, A_tvm, B_tvm)
+                result_ref = build_and_run_reference(mod, [A, B])
+                np.testing.assert_allclose(
+                    result.numpy(),
+                    result_ref,
+                    rtol=5e-2,
+                    atol=5e-2,
+                )
+
+
+def constructConv2D_bias(N, C, H, W, KH, KW, O, strides, padding, dilation):
+    """Build a Relax module whose "main" returns conv2d(x, w) + bias.
+
+    x has shape (N, H, W, C), w has shape (O, KH, KW, C) and bias has shape
+    (1, 1, 1, O). The convolution uses data_layout="NHWC",
+    kernel_layout="OHWI", out_layout="NHWC", groups=1 and out_dtype=C_TYPE;
+    strides, padding and dilation are forwarded unchanged.
+    """
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import ir as I
+    from tvm.script.ir_builder import relax as R
+    from tvm.script.ir_builder import tir as T
+
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                x = R.arg(
+                    "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
+                )  # pylint: disable=invalid-name
+                w = R.arg(
+                    "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    # NOTE: from here on `C` is the conv2d result, not the
+                    # channel-count parameter.
+                    C = R.emit(
+                        R.nn.conv2d(
+                            x,
+                            w,
+                            strides=strides,
+                            padding=padding,
+                            dilation=dilation,
+                            groups=1,
+                            data_layout="NHWC",
+                            kernel_layout="OHWI",
+                            out_layout="NHWC",
+                            out_dtype=C_TYPE,
+                        )
+                    )
+                    D = R.emit(R.add(C, bias))
+                    R.output(D)
+                # Re-read the output var from the closed dataflow frame.
+                (D,) = df.output_vars
+                R.func_ret_value(D)
+    mod = ib.get()
+    return mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_conv2d_bias():
+    """Check conv2d + bias over batch sizes, strides, paddings and dilations.
+
+    For each combination the module from constructConv2D_bias is passed
+    through `build` and run on tvm.cuda(); the result must match
+    build_and_run_reference(mod, [A, B, bias]) within rtol=atol=5e-2.
+    """
+    c, h, w = 3, 224, 224
+    kh, kw, o = 3, 3, 64
+    for n in [1, 2]:
+        for strides in [(1, 1), (2, 2)]:
+            for padding in [(0, 0), (3, 3)]:
+                for dilation in [(1, 1), (4, 4)]:
+                    mod = constructConv2D_bias(n, c, h, w, kh, kw, o, strides, padding, dilation)
+                    executable = build(mod)
+                    dev = tvm.cuda()
+                    # Re-seed per configuration so every case sees the same inputs.
+                    np.random.seed(0)
+                    A = np.random.randn(n, h, w, c).astype("float16")
+                    B = np.random.randn(o, kh, kw, c).astype("float16")
+                    bias = np.random.randn(1, 1, 1, o).astype("float16")
+                    A_tvm = tvm.nd.array(A, dev)
+                    B_tvm = tvm.nd.array(B, dev)
+                    bias_tvm = tvm.nd.array(bias, dev)
+                    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
+                    result_ref = build_and_run_reference(mod, [A, B, bias])
+                    np.testing.assert_allclose(
+                        result.numpy(),
+                        result_ref,
+                        rtol=5e-2,
+                        atol=5e-2,
+                    )
+
+
+def constructConv2D_bias_add(N, C, H, W, KH, KW, O, OH, OW, strides, padding, dilation):
+    """Build a Relax module whose "main" returns conv2d(x, w) + bias + res.
+
+    x has shape (N, H, W, C), w has shape (O, KH, KW, C), bias has shape
+    (1, 1, 1, O) and res has shape (N, OH, OW, O). The convolution uses
+    data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC", groups=1
+    and out_dtype=C_TYPE; strides, padding and dilation are forwarded
+    unchanged.
+    """
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import ir as I
+    from tvm.script.ir_builder import relax as R
+    from tvm.script.ir_builder import tir as T
+
+    with IRBuilder() as ib:  # pylint: disable=invalid-name
+        with I.ir_module() as frame:
+            with R.function():
+                R.func_name("main")
+                x = R.arg(
+                    "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
+                )  # pylint: disable=invalid-name
+                w = R.arg(
+                    "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
+                )  # pylint: disable=invalid-name
+                bias = R.arg(
+                    "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE)
+                )  # pylint: disable=invalid-name
+                res = R.arg(
+                    "res", relax.TensorStructInfo((N, OH, OW, O), A_TYPE)
+                )  # pylint: disable=invalid-name
+                with R.dataflow() as df:
+                    # NOTE: from here on `C` is the conv2d result, not the
+                    # channel-count parameter.
+                    C = R.emit(
+                        R.nn.conv2d(
+                            x,
+                            w,
+                            strides=strides,
+                            padding=padding,
+                            dilation=dilation,
+                            groups=1,
+                            data_layout="NHWC",
+                            kernel_layout="OHWI",
+                            out_layout="NHWC",
+                            out_dtype=C_TYPE,
+                        )
+                    )
+                    D = R.emit(R.add(C, bias))
+                    E = R.emit(R.add(D, res))
+                    R.output(E)
+                # Re-read the output var from the closed dataflow frame.
+                (E,) = df.output_vars
+                R.func_ret_value(E)
+    mod = ib.get()
+    return mod
+
+
+@tvm.testing.requires_cutlass
+def test_cutlass_conv2d_bias_add():
+    """Check conv2d + bias + residual add over strides, paddings, dilations.
+
+    For each combination the module from constructConv2D_bias_add is passed
+    through `build` and run on tvm.cuda(); the result must match
+    build_and_run_reference(mod, [A, B, bias, res]) within rtol=atol=5e-2.
+    """
+    n, c, h, w = 2, 3, 224, 224
+    kh, kw, o = 3, 3, 64
+    for strides in [(1, 1), (2, 2)]:
+        for padding in [(0, 0), (3, 3)]:
+            for dilation in [(1, 1), (4, 4)]:
+                # Output spatial extent, needed to size the residual input.
+                oh = (h + 2 * padding[0] - dilation[0] * (kh - 1) - 1) // strides[0] + 1
+                ow = (w + 2 * padding[1] - dilation[1] * (kw - 1) - 1) // strides[1] + 1
+                mod = constructConv2D_bias_add(
+                    n, c, h, w, kh, kw, o, oh, ow, strides, padding, dilation
+                )
+                executable = build(mod)
+                dev = tvm.cuda()
+                # Re-seed per configuration so every case sees the same inputs.
+                np.random.seed(0)
+                A = np.random.randn(n, h, w, c).astype("float16")
+                B = np.random.randn(o, kh, kw, c).astype("float16")
+                bias = np.random.randn(1, 1, 1, o).astype("float16")
+                res = np.random.randn(n, oh, ow, o).astype("float16")
+                A_tvm = tvm.nd.array(A, dev)
+                B_tvm = tvm.nd.array(B, dev)
+                bias_tvm = tvm.nd.array(bias, dev)
+                res_tvm = tvm.nd.array(res, dev)
+                result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, res_tvm)
+                result_ref = build_and_run_reference(mod, [A, B, bias, res])
+                np.testing.assert_allclose(
+                    result.numpy(),
+                    result_ref,
+                    rtol=5e-2,
+                    atol=5e-2,
+                )
+
+
+# Entry point when this file is executed directly rather than imported.
+if __name__ == "__main__":
+    tvm.testing.main()