This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0cb5b1fa11 [Unity][OP] Add an operator for fused multi head attention (#14150)
0cb5b1fa11 is described below
commit 0cb5b1fa11412228e60d0f38fdbf5ca3e9dcada4
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Mar 3 16:57:49 2023 -0800
[Unity][OP] Add an operator for fused multi head attention (#14150)
* [Unity][OP] Add an operator for fused multi head attention
This PR introduces the new Relax operator `R.nn.attention` for fused
multi head attention, and adds support for fused multi head attention
to the Relax CUTLASS BYOC. The inputs of the operator are the query,
key and value tensors, in `BSNH` layout, namely `[batch size, sequence
length, number of heads, dimension of heads]`. The output shares the
same layout as the input tensors.
* remove useless codes, remove attrs and add memoize
* add more dispatches
* nit and fix rebase
* fix linter
* add support for bias
* fix lint
* BNSS layout for bias
* update doc
* fix typo
* support bias broadcast
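
For reference, a minimal NumPy sketch of the computation `R.nn.attention`
performs (it mirrors the reference implementation in the tests below;
layouts as described above; illustrative only):

    import numpy as np

    def attention_ref(q, k, v, bias=None):
        # q: (B, S, N, H); k: (B, S', N, H); v: (B, S', N, H')
        # bias, if given, must broadcast against (B, N, S, S')
        qt = q.transpose(0, 2, 1, 3)                     # (B, N, S, H)
        kt = k.transpose(0, 2, 3, 1)                     # (B, N, H, S')
        score = qt @ kt / np.sqrt(q.shape[-1])           # (B, N, S, S')
        if bias is not None:
            score = score + bias
        score = score - score.max(axis=-1, keepdims=True)
        attn = np.exp(score)
        attn = attn / attn.sum(axis=-1, keepdims=True)   # softmax over S'
        out = attn @ v.transpose(0, 2, 1, 3)             # (B, N, S, H')
        return out.transpose(0, 2, 1, 3)                 # (B, S, N, H')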
---
3rdparty/cutlass | 2 +-
python/tvm/contrib/cutlass/attention_operation.py | 134 ++++++++++++++++++++
python/tvm/contrib/cutlass/build.py | 47 +++++++
python/tvm/contrib/cutlass/gen_tensor_op.py | 69 ++++++++--
python/tvm/contrib/cutlass/library.py | 6 +
python/tvm/relax/backend/contrib/cutlass.py | 14 ++-
python/tvm/relax/backend/patterns.py | 27 ++++
python/tvm/relax/op/nn/nn.py | 39 ++++++
src/relax/op/nn/attention.cc | 123 ++++++++++++++++++
src/relax/op/nn/attention.h | 41 ++++++
tests/python/relax/test_codegen_cutlass.py | 147 ++++++++++++++++++++++
11 files changed, 635 insertions(+), 14 deletions(-)
diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index d8359c804b..92ebbf1dc4 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit d8359c804b7e3915a0f0668c19213f63ae88aac6
+Subproject commit 92ebbf1dc4612bf838ace6f2e6d262919f0abd63
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
new file mode 100644
index 0000000000..9093a03dd6
--- /dev/null
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):
+ """Return CUTLASS host code for fused multi head attention
+ based on a template and the provided attribute map."""
+
+ bias_template = {
+ "B11S'": """
+ CHECK(${arg3}->ndim == 2); // B, 1, 1, S'
+
+ p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
+ p.bias_strideM = 0; // 0
+ p.bias_strideH = 0; // 0
+ p.bias_strideB = p.num_keys; // S'
+""",
+ "B1SS'": """
+ CHECK(${arg3}->ndim == 3); // B, 1, S, S'
+
+ p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
+ p.bias_strideM = p.num_keys; // S'
+ p.bias_strideH = 0; // 0
+ p.bias_strideB = p.bias_strideM * p.num_queries; // S' * S
+""",
+ "BNSS'": """
+ CHECK(${arg3}->ndim == 4); // B, N, S, S'
+
+ p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
+ p.bias_strideM = p.num_keys; // S'
+ p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S
+ p.bias_strideB = p.bias_strideH * p.num_heads; // S' * S * N
+""",
+ }
+
+ template = """
+ using T = ${data_type};
+
+ CHECK(${arg0}->ndim == 4); // B, S, N, H
+ CHECK(${arg1}->ndim == 4); // B, S', N, H
+ CHECK(${arg2}->ndim == 4); // B, S', N, H'
+ CHECK(out0->ndim == 4); // B, S, N, H'
+
+ using Attention =
+ AttentionKernel<T,
+ /*ArchTag=*/${arch},
+ /*is_aligned=*/${kIsAligned},
+ /*queries_per_block=*/${kQueriesPerBlock},
+ /*keys_per_block=*/${kKeysPerBlock},
+ /*single_value_iteration=*/${kSingleValueIteration},
+ /*supports_dropout=*/${kSupportsDropout},
+ /*supports_bias=*/${kSupportsBias}
+ >;
+
+ typename Attention::Params p;
+
+ p.query_ptr = reinterpret_cast<T *>(${arg0}->data);
+ p.key_ptr = reinterpret_cast<T *>(${arg1}->data);
+ p.value_ptr = reinterpret_cast<T *>(${arg2}->data);
+ p.logsumexp_ptr = nullptr;
+ p.output_ptr = reinterpret_cast<T *>(out0->data);
+ p.output_accum_ptr = nullptr;
+ if (Attention::kNeedsOutputAccumulatorBuffer) {
+ cudaMalloc(
+ &p.output_accum_ptr,
+ ${output_size} * sizeof(Attention::output_accum_t)
+ );
+ }
+
+ p.num_heads = ${num_heads}; // N
+ p.num_batches = ${num_batches}; // B
+ p.head_dim = ${head_dim}; // H
+ p.head_dim_value = ${head_dim_value}; // H'
+ p.num_queries = ${num_queries}; // S
+ p.num_keys = ${num_keys}; // S'
+ p.scale = 1.0f / sqrt(float(${head_dim}));
+
+ // stride for N
+ p.q_strideH = p.head_dim; // H
+ p.k_strideH = p.head_dim; // H
+ p.v_strideH = p.head_dim_value; // H'
+
+ // stride for S
+ p.q_strideM = p.q_strideH * p.num_heads; // H * N
+ p.k_strideM = p.k_strideH * p.num_heads; // H * N
+ p.v_strideM = p.v_strideH * p.num_heads; // H' * N
+ p.o_strideM = p.head_dim_value * p.num_heads; // H' * N
+
+ // stride for B
+ p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
+ p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
+ p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'
+
+ ${bias_template}
+
+ constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
+ int smem_bytes = sizeof(typename Attention::SharedStorage);
+ if (smem_bytes > 0xc000) {
+ static bool once = [&]() {
+ cudaFuncSetAttribute(
+ kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+ return true;
+ }();
+ }
+
+ CHECK(Attention::check_supported(p));
+ kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
+"""
+ if attrs["kSupportsBias"]:
+ template = substitute_template(
+ template, {"bias_template": bias_template[attrs["bias_layout"]]}
+ )
+ else:
+ template = substitute_template(template, {"bias_template": ""})
+ for i, arg in enumerate(func_args):
+ attrs["arg{}".format(i)] = arg
+ return substitute_template(template, attrs)
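
For illustration, a hedged sketch of driving this generator directly. The
real call site is instantiate_template in gen_tensor_op.py below, and all
attribute values here are hypothetical:

    from tvm.contrib.cutlass.attention_operation import instantiate_attention_template

    attrs = {
        "data_type": "cutlass::half_t",
        "arch": "cutlass::arch::Sm80",
        "kIsAligned": True,
        "kQueriesPerBlock": 64,
        "kKeysPerBlock": 64,
        "kSingleValueIteration": True,
        "kSupportsDropout": False,
        "kSupportsBias": False,
        "num_batches": 4, "num_queries": 16, "num_keys": 8,
        "num_heads": 32, "head_dim": 8, "head_dim_value": 16,
        "output_size": 4 * 16 * 32 * 16,  # B * S * N * H'
    }
    # Returns the C++ snippet above with every ${...} placeholder resolved;
    # "q"/"k"/"v" become arg0/arg1/arg2 in the generated code.
    code = instantiate_attention_template(attrs, ["q", "k", "v"])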
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 7e81113f44..0e8d419bae 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -57,6 +57,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
cutlass_root = _get_cutlass_path()
cutlass_include = os.path.join(cutlass_root, "include")
cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")
+    cutlass_attention_include = os.path.join(cutlass_root, "examples/41_fused_multi_head_attention")
kwargs = {}
kwargs["cc"] = "nvcc"
@@ -71,6 +72,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
"-std=c++17",
"-I" + cutlass_include,
"-I" + cutlass_util_include,
+ "-I" + cutlass_attention_include,
]
if use_fast_math:
kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
@@ -756,6 +758,49 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
}
)
+ def handle_attention(self, f, op_type):
+ """Tune and annotate a dense op."""
+ signature = _extract_relax_function_signature(f)
+
+ q_shape = signature["arg0_shape"]
+ k_shape = signature["arg1_shape"]
+ v_shape = signature["arg2_shape"]
+ out_shape = signature["ret_shape"]
+ q_dtype = signature["arg0_dtype"]
+ k_dtype = signature["arg1_dtype"]
+ v_dtype = signature["arg2_dtype"]
+ out_dtype = signature["ret_dtype"]
+ num_batches, num_queries, num_heads, head_dim = q_shape
+ _, num_keys, _, _ = k_shape
+ _, _, _, head_dim_value = v_shape
+ bias = {}
+ if "arg3_dtype" in signature:
+ bias["arg3_dtype"] = signature["arg3_dtype"]
+ if "arg3_shape" in signature:
+ bias["arg3_shape"] = signature["arg3_shape"]
+
+ return f.with_attrs(
+ {
+ "op_type": op_type,
+ "arg0_dtype": q_dtype,
+ "arg1_dtype": k_dtype,
+ "arg2_dtype": v_dtype,
+ "ret_dtype": out_dtype,
+ "arg0_shape": q_shape,
+ "arg1_shape": k_shape,
+ "arg2_shape": v_shape,
+ "ret_shape": out_shape,
+ "num_batches": num_batches,
+ "num_queries": num_queries,
+ "num_keys": num_keys,
+ "num_heads": num_heads,
+ "head_dim": head_dim,
+ "head_dim_value": head_dim_value,
+ "arch": self.options["sm"],
+ **bias,
+ }
+ )
+
def visit_function_(self, f):
if "Composite" not in f.attrs:
body = super().visit_expr(f.body)
@@ -767,6 +812,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
return self.handle_conv2d(f, op_type)
elif "matmul" in op_type:
return self.handle_matmul(f, op_type)
+ elif "attention" in op_type:
+ return self.handle_attention(f, op_type)
raise ValueError("Unsupported composite {}".format(op_type))
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 2976946dd2..78e2b489c6 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -30,8 +30,10 @@ from tvm.tir import IntImm
from . import _ffi_api as ffi
from .conv2d_operation import instantiate_conv2d_template
from .gemm_operation import instantiate_gemm_template
+from .attention_operation import instantiate_attention_template
from .library import (
DataType,
+ DataTypeSize,
DataTypeTag,
EpilogueFunctor,
MathInstruction,
@@ -549,7 +551,7 @@ def instantiate_template(func_name, annotations, func_args):
attrs["ElementInputB"] =
DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]]
attrs["ElementOutput"] =
DataTypeTag[dtype_map[annotations["ret_dtype"]]]
- attrs["K"] = str(int(lhs_shape[lhs_batched_offset + 1]))
+ attrs["K"] = lhs_shape[lhs_batched_offset + 1]
attrs["M"] = get_dim(lhs_shape[lhs_batched_offset], lhs_arg, 0,
lhs_batched_offset)
if transposed:
@@ -630,27 +632,70 @@ def instantiate_template(func_name, annotations, func_args):
attrs["N"] = get_dim(activation_shape[0], activation_var, 0)
attrs["H"] = get_dim(activation_shape[1], activation_var, 1)
attrs["W"] = get_dim(activation_shape[2], activation_var, 2)
- attrs["C"] = str(int(activation_shape[3]))
+ attrs["C"] = activation_shape[3]
attrs["P"] = get_dim(output_shape[1], "out0", 1)
attrs["Q"] = get_dim(output_shape[2], "out0", 2)
- attrs["K"] = str(int(output_shape[3]))
- attrs["R"] = str(int(weight_shape[1]))
- attrs["S"] = str(int(weight_shape[2]))
- attrs["pad_h"] = str(int(annotations["padding"][0]))
- attrs["pad_w"] = str(int(annotations["padding"][1]))
- attrs["stride_h"] = str(int(annotations["strides"][0]))
- attrs["stride_w"] = str(int(annotations["strides"][1]))
- attrs["dilation_h"] = str(int(annotations["dilation"][0]))
- attrs["dilation_w"] = str(int(annotations["dilation"][1]))
+ attrs["K"] = output_shape[3]
+ attrs["R"] = weight_shape[1]
+ attrs["S"] = weight_shape[2]
+ attrs["pad_h"] = annotations["padding"][0]
+ attrs["pad_w"] = annotations["padding"][1]
+ attrs["stride_h"] = annotations["strides"][0]
+ attrs["stride_w"] = annotations["strides"][1]
+ attrs["dilation_h"] = annotations["dilation"][0]
+ attrs["dilation_w"] = annotations["dilation"][1]
if "splitk" in op_name:
attrs["split_k_mode"] = "kParallel"
attrs["split_k_slices"] = str(re.search(r"splitk(\d+)",
op_name).group(1))
else:
attrs["split_k_mode"] = "kSerial"
- attrs["split_k_slices"] = "1"
+ attrs["split_k_slices"] = 1
code = instantiate_conv2d_template(attrs, func_args)
return CodegenResult(code, headers)
+ elif "attention" in func_name:
+ headers.append("kernel_forward.h")
+ data_type = dtype_map[annotations["arg0_dtype"]]
+ attrs["data_type"] = DataTypeTag[data_type]
+ attrs["num_batches"] = b = annotations["num_batches"]
+ attrs["num_queries"] = s = annotations["num_queries"]
+ attrs["num_keys"] = annotations["num_keys"]
+ attrs["num_heads"] = n = annotations["num_heads"]
+ attrs["head_dim"] = h = annotations["head_dim"]
+ attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
+ data_type_size = DataTypeSize[data_type]
+        if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
+ attrs["kIsAligned"] = True
+ elif (h % 4 == 0) and (h_v % 4 == 0):
+ attrs["kIsAligned"] = False
+ else:
+ raise NotImplementedError()
+ if h_v > 64:
+ attrs["kQueriesPerBlock"] = 32
+ attrs["kKeysPerBlock"] = 128
+ attrs["kSingleValueIteration"] = h_v <= 128
+ else:
+ attrs["kQueriesPerBlock"] = 64
+ attrs["kKeysPerBlock"] = 64
+ attrs["kSingleValueIteration"] = True
+ attrs["output_size"] = b * s * n * h_v
+ attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
+ attrs["kSupportsDropout"] = False
+ if len(func_args) > 3:
+ attrs["kSupportsBias"] = True
+ if len(annotations["arg3_shape"]) == 4:
+ attrs["bias_layout"] = "BNSS'"
+ elif len(annotations["arg3_shape"]) == 3:
+ attrs["bias_layout"] = "B1SS'"
+ elif len(annotations["arg3_shape"]) == 2:
+ attrs["bias_layout"] = "B11S'"
+ else:
+ raise NotImplementedError()
+ else:
+ attrs["kSupportsBias"] = False
+ code = instantiate_attention_template(attrs, func_args)
+ return CodegenResult(code, headers)
+
raise ValueError("Do not have a template for {}".format(func_name))
diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py
index 8632ab1564..b72553ef60 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -20,6 +20,8 @@ import re
import enum
from enum import auto as enum_auto
+from tvm.tir.expr import IntImm
+
class GeneratorTarget(enum.Enum):
Library = enum_auto()
@@ -143,6 +145,10 @@ def substitute_template(template, values):
while changed:
changed = False
for key, value in values.items():
+      if isinstance(value, bool):
+        # check bool before int: bool is a subclass of int in Python,
+        # and C++ template arguments expect lowercase true/false
+        value = str(value).lower()
+      elif isinstance(value, (int, IntImm)):
+        value = str(int(value))
regex = "\\$\\{%s\\}" % key
newtext = re.sub(regex, value, text)
if newtext != text:
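
A quick check of the scalar coercion added above, assuming the bool branch
is tested before the int branch as shown:

    from tvm.contrib.cutlass.library import substitute_template
    from tvm.tir.expr import IntImm

    out = substitute_template(
        "aligned=${flag}, n=${n}, k=${k}",
        {"flag": True, "n": 7, "k": IntImm("int64", 32)},
    )
    print(out)  # aligned=true, n=7, k=32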
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 2d8908184b..19165fa832 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -25,7 +25,11 @@ from tvm.relax import Call, Expr, ShapeExpr, transform
from tvm.relax.dpl import DFPattern
from ..pattern_registry import get_patterns_with_prefix, register_patterns
-from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern
+from ..patterns import (
+ make_fused_bias_activation_pattern,
+ make_matmul_pattern,
+ make_attention_pattern,
+)
def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]:
@@ -157,6 +161,14 @@ register_patterns(
),
_check_matmul,
),
+ (
+ "cutlass.attention",
+ *make_attention_pattern(),
+ ),
+ (
+ "cutlass.attention_bias",
+ *make_attention_pattern(with_bias=True),
+ ),
]
)
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 2f744af660..a2ea803d9d 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -113,3 +113,30 @@ def make_matmul_pattern(
out = is_op("relax.matmul")(lhs, rhs)
return _with_bias_activation_pattern(out, args, with_bias, activation)
+
+
+def make_attention_pattern(with_bias: bool = False):
+ """
+ Create pattern for fused multi head attention.
+
+ Returns
+ -------
+ pattern: DFPattern
+ The resulting pattern describing a fused multi head attention.
+
+ args: Mapping[str, DFPattern]
+ The mapping from arg name to its pattern. It can be used to extract
+ arg expression from match result.
+ """
+ query = wildcard()
+ key = wildcard()
+ value = wildcard()
+ args = {"query": query, "key": key, "value": value}
+ if with_bias:
+ bias = wildcard()
+ args["bias"] = bias
+ out = is_op("relax.nn.attention_bias")(query, key, value, bias)
+ else:
+ out = is_op("relax.nn.attention")(query, key, value)
+
+ return out, args
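
For context, a hedged sketch of consuming the returned pair (this mirrors
the registration in cutlass.py above):

    from tvm.relax.backend.patterns import make_attention_pattern

    pattern, args = make_attention_pattern(with_bias=True)
    # `pattern` matches relax.nn.attention_bias(query, key, value, bias);
    # `args` maps "query"/"key"/"value"/"bias" to their sub-patterns so a
    # BYOC backend can pull each argument expression out of a match result.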
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 0ff143fd04..2fef372497 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -577,3 +577,42 @@ def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr:
The computed result.
"""
return _ffi_api.cross_entropy_with_logits(predictions, labels)  # type: ignore
+
+
+def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None) -> Expr:
+ r"""Computes fused multi head attention.
+
+    All input tensors are 4-D tensors with BSNH layout.
+
+ .. math::
+        FMA(Q, K, V) = \text{Softmax}(Q @ K^T / \sqrt{d_{head}}) @ V
+
+ .. note::
+        The input tensors are required to have float16 dtype.
+
+ Parameters
+ ----------
+ query: relax.Expr
+        The input query to the operator. The layout of the input query should be
+        (batch_size, seq_len, num_head, head_dim).
+
+ key: relax.Expr
+ The input key to the operator. The layout of the input key should be
+ (batch_size, seq_len_kv, num_head, head_dim).
+
+ value: relax.Expr
+        The input value to the operator. The layout of the input value should be
+        (batch_size, seq_len_kv, num_head, head_dim_v).
+
+ bias: Optional[Expr]
+        The optional attention bias to the operator. The layout of the attention bias should be
+ (batch_size, num_head, seq_len, seq_len_kv),
+ (batch_size, seq_len, seq_len_kv) or (batch_size, seq_len_kv).
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result. The layout of the output should be
+ (batch_size, seq_len, num_head, head_dim_v).
+ """
+ return _ffi_api.attention(query, key, value, bias) # type: ignore
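
A hedged TVMScript sketch of calling the new operator (shapes are
hypothetical; an IRBuilder-based variant appears in the tests below):

    from tvm.script import relax as R

    @R.function
    def main(
        q: R.Tensor((4, 16, 32, 8), "float16"),
        k: R.Tensor((4, 8, 32, 8), "float16"),
        v: R.Tensor((4, 8, 32, 16), "float16"),
    ) -> R.Tensor((4, 16, 32, 16), "float16"):
        # output layout matches the inputs: (batch, seq_len, num_head, head_dim_v)
        out = R.nn.attention(q, k, v)
        return out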
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
new file mode 100644
index 0000000000..e139aa09d6
--- /dev/null
+++ b/src/relax/op/nn/attention.cc
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "attention.h"
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/* relax.nn.attention */
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias) {
+ if (bias.defined()) {
+ return Call(Op::Get("relax.nn.attention_bias"),
+                {std::move(query), std::move(key), std::move(value), std::move(bias.value())}, {},
+ {});
+ }
+ return Call(Op::Get("relax.nn.attention"), {std::move(query),
std::move(key), std::move(value)},
+ {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention);
+
+StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
+ Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+ TensorStructInfo q_sinfo = input_sinfo[0];
+ TensorStructInfo k_sinfo = input_sinfo[1];
+ TensorStructInfo v_sinfo = input_sinfo[2];
+ auto diag_dim = [&](TensorStructInfo sinfo, String name) {
+ if (sinfo->ndim != 4) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "The " << name << " should have 4 dimension, namely "
+ << "[batch size, sequence length, number of heads,
dimension of heads].");
+ }
+ };
+ diag_dim(q_sinfo, "query");
+ diag_dim(k_sinfo, "key");
+ diag_dim(v_sinfo, "value");
+ const ShapeExprNode* q_shape = q_sinfo->shape.as<ShapeExprNode>();
+ const ShapeExprNode* k_shape = k_sinfo->shape.as<ShapeExprNode>();
+ const ShapeExprNode* v_shape = v_sinfo->shape.as<ShapeExprNode>();
+ PrimExpr num_batches = q_shape->values[0];
+ PrimExpr num_queries = q_shape->values[1];
+ PrimExpr num_heads = q_shape->values[2];
+ PrimExpr head_dim = q_shape->values[3];
+ PrimExpr num_keys = k_shape->values[1];
+ PrimExpr head_dim_value = v_shape->values[3];
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  auto diag_equal = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) {
+ if (analyzer->CanProve(v1 != v2)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "The " << m1 << " " << dim << " and the " << m2 << "
" << dim
+ << " should be the same. However, the " << dim << " of
" << m1 << " is "
+ << v1 << " while the " << dim << " of " << m2 << " is "
<< v2);
+ }
+ };
+ diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size");
+ diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size");
+ diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads");
+  diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads");
+ diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length");
+  diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads");
+
+ if (input_sinfo.size() == 4) {
+ TensorStructInfo bias_sinfo = input_sinfo[3];
+ const ShapeExprNode* bias_shape = bias_sinfo->shape.as<ShapeExprNode>();
+ if (bias_sinfo->ndim == 4) {
+      diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size");
+      diag_equal(num_heads, bias_shape->values[1], "query", "bias", "number of heads");
+      diag_equal(num_queries, bias_shape->values[2], "query", "bias", "sequence length");
+      diag_equal(num_keys, bias_shape->values[3], "key", "bias", "sequence length");
+ } else if (bias_sinfo->ndim == 3) {
+      diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size");
+      diag_equal(num_queries, bias_shape->values[1], "query", "bias", "sequence length");
+      diag_equal(num_keys, bias_shape->values[2], "key", "bias", "sequence length");
+ } else if (bias_sinfo->ndim == 2) {
+      diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch size");
+      diag_equal(num_keys, bias_shape->values[1], "key", "bias", "sequence length");
+ } else {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "The bias should have 2, 3 or 4 dimensions."
+ << "However, the bias input has " << bias_sinfo->ndim
<< " dimensions.");
+ }
+ }
+
+  Array<PrimExpr> output_shape = {num_batches, num_queries, num_heads, head_dim_value};
+ return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.nn.attention")
+ .set_num_inputs(3)
+ .add_argument("query", "Tensor", "The input queries tensor.")
+ .add_argument("key", "Tensor", "The input keys tensor.")
+ .add_argument("value", "Tensor", "The input values tensor.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention);
+
+TVM_REGISTER_OP("relax.nn.attention_bias")
+ .set_num_inputs(4)
+ .add_argument("query", "Tensor", "The input queries tensor.")
+ .add_argument("key", "Tensor", "The input keys tensor.")
+ .add_argument("value", "Tensor", "The input values tensor.")
+ .add_argument("bias", "Tensor", "The input bias tensor.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention);
+
+} // namespace relax
+} // namespace tvm
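
A minimal sketch of the struct-info inference in action, with hypothetical
shapes, via the Python BlockBuilder:

    from tvm import relax

    bb = relax.BlockBuilder()
    q = relax.Var("q", relax.TensorStructInfo((4, 16, 32, 8), "float16"))
    k = relax.Var("k", relax.TensorStructInfo((4, 8, 32, 8), "float16"))
    v = relax.Var("v", relax.TensorStructInfo((4, 8, 32, 16), "float16"))
    with bb.function("main", [q, k, v]):
        out = bb.emit(relax.op.nn.attention(q, k, v))
        bb.emit_func_output(out)
    print(out.struct_info)  # R.Tensor((4, 16, 32, 16), dtype="float16")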
diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h
new file mode 100644
index 0000000000..662e0b7e7b
--- /dev/null
+++ b/src/relax/op/nn/attention.h
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file attention.h
+ * \brief The functions to make Relax attention operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_NN_ATTENTION_H_
+#define TVM_RELAX_OP_NN_ATTENTION_H_
+
+#include <tvm/relax/attrs/nn.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief fused multi head attention */
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias);
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_OP_NN_ATTENTION_H_
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 83104d6fe1..36a1c4cd16 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -20,6 +20,8 @@ import scipy
import tvm
import tvm.testing
+import tvm.topi.testing
+from tvm.contrib.pickle_memoize import memoize
from tvm import relax, relay
from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
from tvm.relax.backend import get_patterns_with_prefix
@@ -300,5 +302,150 @@ def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype):
tvm.ir.assert_structural_equal(mod, partition_for_cutlass(mod))
+@pytest.fixture(params=["float16", "float32"])
+def attention_dtype(request):
+ return request.param
+
+
+@pytest.fixture(
+ params=[
+ # B, S, N, H
+ (32, (8, 8), 16, (8, 8)),
+ (4, (16, 8), 32, (8, 8)), # s != s_kv
+ (4, (16, 8), 32, (8, 16)), # h != h_v
+ (32, (8, 8), 16, (4, 4)), # h is not aligned
+ (2, (8, 8), 8, (256, 256)), # needs output accumulator buffer
+ ]
+)
+def attention_size(request):
+ return request.param
+
+
+def get_relax_attention_module(q, k, v, bias=None):
+ dtype = str(q.dtype)
+
+ from tvm.script.ir_builder import IRBuilder
+ from tvm.script.ir_builder import relax as relax_builder
+
+ with IRBuilder() as builder:
+ with relax_builder.function():
+ R.func_name("main")
+ q = R.arg("q", R.Tensor(q.shape, dtype))
+ k = R.arg("k", R.Tensor(k.shape, dtype))
+ v = R.arg("v", R.Tensor(v.shape, dtype))
+ if bias is not None:
+ bias = R.arg("bias", R.Tensor(bias.shape, dtype))
+ with R.dataflow() as frame:
+ result = R.emit(R.nn.attention(q, k, v, bias))
+ R.output(result)
+
+ R.func_ret_value(frame.output_vars[0])
+
+ func = builder.get()
+ return tvm.IRModule({"main": func})
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_offload")
+def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, dtype):
+ q = np.random.randn(b, s, n, h).astype(dtype)
+ k = np.random.randn(b, s_kv, n, h).astype(dtype)
+ v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+ qt = q.transpose(0, 2, 1, 3) # b, n, s, h
+ kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
+ score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
+ attn = tvm.topi.testing.softmax_python(score, -1)
+ vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
+ ref = attn @ vt # b, n, s, h_v
+ return q, k, v, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
+
+
+def test_attention_offload(attention_size, attention_dtype):
+ b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, ref = get_numpy_attention_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+
+ mod = get_relax_attention_module(q, k, v)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v)
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_4d_offload")
+def get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, dtype):
+ q = np.random.randn(b, s, n, h).astype(dtype)
+ k = np.random.randn(b, s_kv, n, h).astype(dtype)
+ v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+ bias = np.random.randn(b, n, s, s_kv).astype(dtype)
+ qt = q.transpose(0, 2, 1, 3) # b, n, s, h
+ kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
+ score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
+ score_bias = score + bias # b, n, s, s_kv
+ attn = tvm.topi.testing.softmax_python(score_bias, -1)
+ vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
+ ref = attn @ vt # b, n, s, h_v
+ return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
+
+
+def test_attention_bias_4d_offload(attention_size, attention_dtype):
+ b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, bias, ref = get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+
+ mod = get_relax_attention_module(q, k, v, bias)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_3d_offload")
+def get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, dtype):
+ q = np.random.randn(b, s, n, h).astype(dtype)
+ k = np.random.randn(b, s_kv, n, h).astype(dtype)
+ v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+ bias = np.random.randn(b, s, s_kv).astype(dtype)
+ qt = q.transpose(0, 2, 1, 3) # b, n, s, h
+ kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
+ score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
+ score_bias = score + bias.reshape(b, 1, s, s_kv) # b, n, s, s_kv
+ attn = tvm.topi.testing.softmax_python(score_bias, -1)
+ vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
+ ref = attn @ vt # b, n, s, h_v
+ return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
+
+
+def test_attention_bias_3d_offload(attention_size, attention_dtype):
+ b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, bias, ref = get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+
+ mod = get_relax_attention_module(q, k, v, bias)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_2d_offload")
+def get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, dtype):
+ q = np.random.randn(b, s, n, h).astype(dtype)
+ k = np.random.randn(b, s_kv, n, h).astype(dtype)
+ v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+ bias = np.random.randn(b, s_kv).astype(dtype)
+ qt = q.transpose(0, 2, 1, 3) # b, n, s, h
+ kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
+ score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
+ score_bias = score + bias.reshape(b, 1, 1, s_kv) # b, n, s, s_kv
+ attn = tvm.topi.testing.softmax_python(score_bias, -1)
+ vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
+ ref = attn @ vt # b, n, s, h_v
+ return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
+
+
+def test_attention_bias_2d_offload(attention_size, attention_dtype):
+ b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, bias, ref = get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+
+ mod = get_relax_attention_module(q, k, v, bias)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
if __name__ == "__main__":
tvm.testing.main()