This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0cb5b1fa11 [Unity][OP] Add an operator for fused multi head attention 
(#14150)
0cb5b1fa11 is described below

commit 0cb5b1fa11412228e60d0f38fdbf5ca3e9dcada4
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Mar 3 16:57:49 2023 -0800

    [Unity][OP] Add an operator for fused multi head attention (#14150)
    
    * [Unity][OP] Add an operator for fused multi head attention
    
    This PR introduces the new relax operator `R.nn.attention` for fused multi 
head attention, and adds support for fused multi head attention to the relax 
cutlass BYOC. The inputs of the operator are the query, key and value tensors, 
in `BSNH` layout, namely `[batch size, sequence length, number of heads, 
dimension of heads]`. The output shares the same layout as the input tensors.
    
    * remove unused code, remove attrs and add memoize
    
    * add more dispatches
    
    * nit and fix rebase
    
    * fix linter
    
    * add support for bias
    
    * fix lint
    
    * BNSS layout for bias
    
    * update doc
    
    * fix typo
    
    * support bias broadcast
---
 3rdparty/cutlass                                  |   2 +-
 python/tvm/contrib/cutlass/attention_operation.py | 134 ++++++++++++++++++++
 python/tvm/contrib/cutlass/build.py               |  47 +++++++
 python/tvm/contrib/cutlass/gen_tensor_op.py       |  69 ++++++++--
 python/tvm/contrib/cutlass/library.py             |   6 +
 python/tvm/relax/backend/contrib/cutlass.py       |  14 ++-
 python/tvm/relax/backend/patterns.py              |  27 ++++
 python/tvm/relax/op/nn/nn.py                      |  39 ++++++
 src/relax/op/nn/attention.cc                      | 123 ++++++++++++++++++
 src/relax/op/nn/attention.h                       |  41 ++++++
 tests/python/relax/test_codegen_cutlass.py        | 147 ++++++++++++++++++++++
 11 files changed, 635 insertions(+), 14 deletions(-)

diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index d8359c804b..92ebbf1dc4 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit d8359c804b7e3915a0f0668c19213f63ae88aac6
+Subproject commit 92ebbf1dc4612bf838ace6f2e6d262919f0abd63
diff --git a/python/tvm/contrib/cutlass/attention_operation.py 
b/python/tvm/contrib/cutlass/attention_operation.py
new file mode 100644
index 0000000000..9093a03dd6
--- /dev/null
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):
+    """Return CUTLASS host code for fused multi head attention
+    based on a template and the provided attribute map."""
+
+    bias_template = {
+        "B11S'": """
+  CHECK(${arg3}->ndim == 2); // B, 1, 1, S'
+
+  p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
+  p.bias_strideM = 0; // 0
+  p.bias_strideH = 0; // 0
+  p.bias_strideB = p.num_keys; // S'
+""",
+        "B1SS'": """
+  CHECK(${arg3}->ndim == 3); // B, 1, S, S'
+
+  p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
+  p.bias_strideM = p.num_keys; // S'
+  p.bias_strideH = 0; // 0
+  p.bias_strideB = p.bias_strideM * p.num_queries; // S' * S
+""",
+        "BNSS'": """
+  CHECK(${arg3}->ndim == 4); // B, N, S, S'
+
+  p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
+  p.bias_strideM = p.num_keys; // S'
+  p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S
+  p.bias_strideB = p.bias_strideH * p.num_heads; // S' * S * N
+""",
+    }
+
+    template = """
+  using T = ${data_type};
+
+  CHECK(${arg0}->ndim == 4); // B, S, N, H
+  CHECK(${arg1}->ndim == 4); // B, S', N, H
+  CHECK(${arg2}->ndim == 4); // B, S', N, H'
+  CHECK(out0->ndim == 4); // B, S, N, H'
+
+  using Attention =
+      AttentionKernel<T,
+                      /*ArchTag=*/${arch},
+                      /*is_aligned=*/${kIsAligned},
+                      /*queries_per_block=*/${kQueriesPerBlock},
+                      /*keys_per_block=*/${kKeysPerBlock},
+                      /*single_value_iteration=*/${kSingleValueIteration},
+                      /*supports_dropout=*/${kSupportsDropout},
+                      /*supports_bias=*/${kSupportsBias}
+      >;
+
+  typename Attention::Params p;
+
+  p.query_ptr = reinterpret_cast<T *>(${arg0}->data);
+  p.key_ptr = reinterpret_cast<T *>(${arg1}->data);
+  p.value_ptr = reinterpret_cast<T *>(${arg2}->data);
+  p.logsumexp_ptr = nullptr;
+  p.output_ptr = reinterpret_cast<T *>(out0->data);
+  p.output_accum_ptr = nullptr;
+  if (Attention::kNeedsOutputAccumulatorBuffer) {
+    cudaMalloc(
+      &p.output_accum_ptr,
+      ${output_size} * sizeof(Attention::output_accum_t)
+    );
+  }
+
+  p.num_heads = ${num_heads}; // N
+  p.num_batches = ${num_batches}; // B
+  p.head_dim = ${head_dim}; // H
+  p.head_dim_value = ${head_dim_value}; // H'
+  p.num_queries = ${num_queries}; // S
+  p.num_keys = ${num_keys}; // S'
+  p.scale = 1.0f / sqrt(float(${head_dim}));
+
+  // stride for N
+  p.q_strideH = p.head_dim; // H
+  p.k_strideH = p.head_dim; // H
+  p.v_strideH = p.head_dim_value; // H'
+
+  // stride for S
+  p.q_strideM = p.q_strideH * p.num_heads; // H * N
+  p.k_strideM = p.k_strideH * p.num_heads; // H * N
+  p.v_strideM = p.v_strideH * p.num_heads; // H' * N
+  p.o_strideM = p.head_dim_value * p.num_heads; // H' * N
+
+  // stride for B
+  p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
+  p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
+  p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'
+
+  ${bias_template}
+
+  constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
+  int smem_bytes = sizeof(typename Attention::SharedStorage);
+  if (smem_bytes > 0xc000) {
+    static bool once = [&]() {
+      cudaFuncSetAttribute(
+          kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+      return true;
+    }();
+  }
+
+  CHECK(Attention::check_supported(p));
+  kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
+"""
+    if attrs["kSupportsBias"]:
+        template = substitute_template(
+            template, {"bias_template": bias_template[attrs["bias_layout"]]}
+        )
+    else:
+        template = substitute_template(template, {"bias_template": ""})
+    for i, arg in enumerate(func_args):
+        attrs["arg{}".format(i)] = arg
+    return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/build.py 
b/python/tvm/contrib/cutlass/build.py
index 7e81113f44..0e8d419bae 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -57,6 +57,7 @@ def _get_cutlass_compile_options(sm, threads, 
use_fast_math=False):
     cutlass_root = _get_cutlass_path()
     cutlass_include = os.path.join(cutlass_root, "include")
     cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")
+    cutlass_attention_include = os.path.join(cutlass_root, 
"examples/41_fused_multi_head_attention")
 
     kwargs = {}
     kwargs["cc"] = "nvcc"
@@ -71,6 +72,7 @@ def _get_cutlass_compile_options(sm, threads, 
use_fast_math=False):
         "-std=c++17",
         "-I" + cutlass_include,
         "-I" + cutlass_util_include,
+        "-I" + cutlass_attention_include,
     ]
     if use_fast_math:
         kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
@@ -756,6 +758,49 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             }
         )
 
+    def handle_attention(self, f, op_type):
+        """Annotate an attention op."""
+        signature = _extract_relax_function_signature(f)
+
+        q_shape = signature["arg0_shape"]
+        k_shape = signature["arg1_shape"]
+        v_shape = signature["arg2_shape"]
+        out_shape = signature["ret_shape"]
+        q_dtype = signature["arg0_dtype"]
+        k_dtype = signature["arg1_dtype"]
+        v_dtype = signature["arg2_dtype"]
+        out_dtype = signature["ret_dtype"]
+        num_batches, num_queries, num_heads, head_dim = q_shape
+        _, num_keys, _, _ = k_shape
+        _, _, _, head_dim_value = v_shape
+        bias = {}
+        if "arg3_dtype" in signature:
+            bias["arg3_dtype"] = signature["arg3_dtype"]
+        if "arg3_shape" in signature:
+            bias["arg3_shape"] = signature["arg3_shape"]
+
+        return f.with_attrs(
+            {
+                "op_type": op_type,
+                "arg0_dtype": q_dtype,
+                "arg1_dtype": k_dtype,
+                "arg2_dtype": v_dtype,
+                "ret_dtype": out_dtype,
+                "arg0_shape": q_shape,
+                "arg1_shape": k_shape,
+                "arg2_shape": v_shape,
+                "ret_shape": out_shape,
+                "num_batches": num_batches,
+                "num_queries": num_queries,
+                "num_keys": num_keys,
+                "num_heads": num_heads,
+                "head_dim": head_dim,
+                "head_dim_value": head_dim_value,
+                "arch": self.options["sm"],
+                **bias,
+            }
+        )
+
     def visit_function_(self, f):
         if "Composite" not in f.attrs:
             body = super().visit_expr(f.body)
@@ -767,6 +812,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             return self.handle_conv2d(f, op_type)
         elif "matmul" in op_type:
             return self.handle_matmul(f, op_type)
+        elif "attention" in op_type:
+            return self.handle_attention(f, op_type)
 
         raise ValueError("Unsupported composite {}".format(op_type))
 
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py 
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 2976946dd2..78e2b489c6 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -30,8 +30,10 @@ from tvm.tir import IntImm
 from . import _ffi_api as ffi
 from .conv2d_operation import instantiate_conv2d_template
 from .gemm_operation import instantiate_gemm_template
+from .attention_operation import instantiate_attention_template
 from .library import (
     DataType,
+    DataTypeSize,
     DataTypeTag,
     EpilogueFunctor,
     MathInstruction,
@@ -549,7 +551,7 @@ def instantiate_template(func_name, annotations, func_args):
         attrs["ElementInputB"] = 
DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]]
         attrs["ElementOutput"] = 
DataTypeTag[dtype_map[annotations["ret_dtype"]]]
 
-        attrs["K"] = str(int(lhs_shape[lhs_batched_offset + 1]))
+        attrs["K"] = lhs_shape[lhs_batched_offset + 1]
         attrs["M"] = get_dim(lhs_shape[lhs_batched_offset], lhs_arg, 0, 
lhs_batched_offset)
 
         if transposed:
@@ -630,27 +632,70 @@ def instantiate_template(func_name, annotations, 
func_args):
         attrs["N"] = get_dim(activation_shape[0], activation_var, 0)
         attrs["H"] = get_dim(activation_shape[1], activation_var, 1)
         attrs["W"] = get_dim(activation_shape[2], activation_var, 2)
-        attrs["C"] = str(int(activation_shape[3]))
+        attrs["C"] = activation_shape[3]
         attrs["P"] = get_dim(output_shape[1], "out0", 1)
         attrs["Q"] = get_dim(output_shape[2], "out0", 2)
-        attrs["K"] = str(int(output_shape[3]))
-        attrs["R"] = str(int(weight_shape[1]))
-        attrs["S"] = str(int(weight_shape[2]))
-        attrs["pad_h"] = str(int(annotations["padding"][0]))
-        attrs["pad_w"] = str(int(annotations["padding"][1]))
-        attrs["stride_h"] = str(int(annotations["strides"][0]))
-        attrs["stride_w"] = str(int(annotations["strides"][1]))
-        attrs["dilation_h"] = str(int(annotations["dilation"][0]))
-        attrs["dilation_w"] = str(int(annotations["dilation"][1]))
+        attrs["K"] = output_shape[3]
+        attrs["R"] = weight_shape[1]
+        attrs["S"] = weight_shape[2]
+        attrs["pad_h"] = annotations["padding"][0]
+        attrs["pad_w"] = annotations["padding"][1]
+        attrs["stride_h"] = annotations["strides"][0]
+        attrs["stride_w"] = annotations["strides"][1]
+        attrs["dilation_h"] = annotations["dilation"][0]
+        attrs["dilation_w"] = annotations["dilation"][1]
 
         if "splitk" in op_name:
             attrs["split_k_mode"] = "kParallel"
             attrs["split_k_slices"] = str(re.search(r"splitk(\d+)", 
op_name).group(1))
         else:
             attrs["split_k_mode"] = "kSerial"
-            attrs["split_k_slices"] = "1"
+            attrs["split_k_slices"] = 1
 
         code = instantiate_conv2d_template(attrs, func_args)
         return CodegenResult(code, headers)
 
+    elif "attention" in func_name:
+        headers.append("kernel_forward.h")
+        data_type = dtype_map[annotations["arg0_dtype"]]
+        attrs["data_type"] = DataTypeTag[data_type]
+        attrs["num_batches"] = b = annotations["num_batches"]
+        attrs["num_queries"] = s = annotations["num_queries"]
+        attrs["num_keys"] = annotations["num_keys"]
+        attrs["num_heads"] = n = annotations["num_heads"]
+        attrs["head_dim"] = h = annotations["head_dim"]
+        attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
+        data_type_size = DataTypeSize[data_type]
+        if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) 
% 16 == 0:
+            attrs["kIsAligned"] = True
+        elif (h % 4 == 0) and (h_v % 4 == 0):
+            attrs["kIsAligned"] = False
+        else:
+            raise NotImplementedError()
+        if h_v > 64:
+            attrs["kQueriesPerBlock"] = 32
+            attrs["kKeysPerBlock"] = 128
+            attrs["kSingleValueIteration"] = h_v <= 128
+        else:
+            attrs["kQueriesPerBlock"] = 64
+            attrs["kKeysPerBlock"] = 64
+            attrs["kSingleValueIteration"] = True
+        attrs["output_size"] = b * s * n * h_v
+        attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
+        attrs["kSupportsDropout"] = False
+        if len(func_args) > 3:
+            attrs["kSupportsBias"] = True
+            if len(annotations["arg3_shape"]) == 4:
+                attrs["bias_layout"] = "BNSS'"
+            elif len(annotations["arg3_shape"]) == 3:
+                attrs["bias_layout"] = "B1SS'"
+            elif len(annotations["arg3_shape"]) == 2:
+                attrs["bias_layout"] = "B11S'"
+            else:
+                raise NotImplementedError()
+        else:
+            attrs["kSupportsBias"] = False
+        code = instantiate_attention_template(attrs, func_args)
+        return CodegenResult(code, headers)
+
     raise ValueError("Do not have a template for {}".format(func_name))
diff --git a/python/tvm/contrib/cutlass/library.py 
b/python/tvm/contrib/cutlass/library.py
index 8632ab1564..b72553ef60 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -20,6 +20,8 @@ import re
 import enum
 from enum import auto as enum_auto
 
+from tvm.tir.expr import IntImm
+
 
 class GeneratorTarget(enum.Enum):
     Library = enum_auto()
@@ -143,6 +145,10 @@ def substitute_template(template, values):
     while changed:
         changed = False
         for key, value in values.items():
+            if isinstance(value, (int, IntImm)):
+                value = str(int(value))
+            elif isinstance(value, bool):
+                value = str(value).lower()
             regex = "\\$\\{%s\\}" % key
             newtext = re.sub(regex, value, text)
             if newtext != text:
diff --git a/python/tvm/relax/backend/contrib/cutlass.py 
b/python/tvm/relax/backend/contrib/cutlass.py
index 2d8908184b..19165fa832 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -25,7 +25,11 @@ from tvm.relax import Call, Expr, ShapeExpr, transform
 from tvm.relax.dpl import DFPattern
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
-from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern
+from ..patterns import (
+    make_fused_bias_activation_pattern,
+    make_matmul_pattern,
+    make_attention_pattern,
+)
 
 
 def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]:
@@ -157,6 +161,14 @@ register_patterns(
             ),
             _check_matmul,
         ),
+        (
+            "cutlass.attention",
+            *make_attention_pattern(),
+        ),
+        (
+            "cutlass.attention_bias",
+            *make_attention_pattern(with_bias=True),
+        ),
     ]
 )
 
diff --git a/python/tvm/relax/backend/patterns.py 
b/python/tvm/relax/backend/patterns.py
index 2f744af660..a2ea803d9d 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -113,3 +113,30 @@ def make_matmul_pattern(
     out = is_op("relax.matmul")(lhs, rhs)
 
     return _with_bias_activation_pattern(out, args, with_bias, activation)
+
+
+def make_attention_pattern(with_bias: bool = False):
+    """
+    Create pattern for fused multi head attention.
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing a fused multi head attention.
+
+    args: Mapping[str, DFPattern]
+        The mapping from arg name to its pattern. It can be used to extract
+        arg expression from match result.
+    """
+    query = wildcard()
+    key = wildcard()
+    value = wildcard()
+    args = {"query": query, "key": key, "value": value}
+    if with_bias:
+        bias = wildcard()
+        args["bias"] = bias
+        out = is_op("relax.nn.attention_bias")(query, key, value, bias)
+    else:
+        out = is_op("relax.nn.attention")(query, key, value)
+
+    return out, args
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 0ff143fd04..2fef372497 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -577,3 +577,42 @@ def cross_entropy_with_logits(predictions: Expr, labels: 
Expr) -> Expr:
       The computed result.
     """
     return _ffi_api.cross_entropy_with_logits(predictions, labels)  # type: 
ignore
+
+
+def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = 
None) -> Expr:
+    r"""Computes fused multi head attention.
+
+    All input tensors are 4-D tensors with BSNH layout. In the formula below, d is head_dim.
+
+    .. math::
+        FMA(Q, K, V) = \text{Softmax}(Q @ K^T / \sqrt{d}) @ V
+
+    .. note::
+        The input tensor is required to have float16 dtype
+
+    Parameters
+    ----------
+    query: relax.Expr
+        The input query to the operator. The layout of the input query should 
be
+        (batch_size, seq_len, num_head, head_dim).
+
+    key: relax.Expr
+        The input key to the operator. The layout of the input key should be
+        (batch_size, seq_len_kv, num_head, head_dim).
+
+    value: relax.Expr
+        The input value to the operator. The layout of the input value should 
be
+        (batch_size, seq_len_kv, num_head, head_dim_v).
+
+    bias: Optional[Expr]
+        The optional attention bias to the operator. The layout of the 
attention bias should be
+        (batch_size, num_head, seq_len, seq_len_kv),
+        (batch_size, seq_len, seq_len_kv) or (batch_size, seq_len_kv).
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result. The layout of the output should be
+        (batch_size, seq_len, num_head, head_dim_v).
+    """
+    return _ffi_api.attention(query, key, value, bias)  # type: ignore
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
new file mode 100644
index 0000000000..e139aa09d6
--- /dev/null
+++ b/src/relax/op/nn/attention.cc
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "attention.h"
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/* relax.nn.attention */
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias) {
+  if (bias.defined()) {
+    return Call(Op::Get("relax.nn.attention_bias"),
+                {std::move(query), std::move(key), std::move(value), 
std::move(bias.value())}, {},
+                {});
+  }
+  return Call(Op::Get("relax.nn.attention"), {std::move(query), 
std::move(key), std::move(value)},
+              {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention);
+
+StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) 
{
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo q_sinfo = input_sinfo[0];
+  TensorStructInfo k_sinfo = input_sinfo[1];
+  TensorStructInfo v_sinfo = input_sinfo[2];
+  auto diag_dim = [&](TensorStructInfo sinfo, String name) {
+    if (sinfo->ndim != 4) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The " << name << " should have 4 dimension, namely "
+                       << "[batch size, sequence length, number of heads, 
dimension of heads].");
+    }
+  };
+  diag_dim(q_sinfo, "query");
+  diag_dim(k_sinfo, "key");
+  diag_dim(v_sinfo, "value");
+  const ShapeExprNode* q_shape = q_sinfo->shape.as<ShapeExprNode>();
+  const ShapeExprNode* k_shape = k_sinfo->shape.as<ShapeExprNode>();
+  const ShapeExprNode* v_shape = v_sinfo->shape.as<ShapeExprNode>();
+  PrimExpr num_batches = q_shape->values[0];
+  PrimExpr num_queries = q_shape->values[1];
+  PrimExpr num_heads = q_shape->values[2];
+  PrimExpr head_dim = q_shape->values[3];
+  PrimExpr num_keys = k_shape->values[1];
+  PrimExpr head_dim_value = v_shape->values[3];
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  auto diag_equal = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String 
dim) {
+    if (analyzer->CanProve(v1 != v2)) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The " << m1 << " " << dim << " and the " << m2 << " 
" << dim
+                       << " should be the same. However, the " << dim << " of 
" << m1 << " is "
+                       << v1 << " while the " << dim << " of " << m2 << " is " 
<< v2);
+    }
+  };
+  diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size");
+  diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size");
+  diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads");
+  diag_equal(num_heads, v_shape->values[2], "query", "value", "number of 
heads");
+  diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length");
+  diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of 
heads");
+
+  if (input_sinfo.size() == 4) {
+    TensorStructInfo bias_sinfo = input_sinfo[3];
+    const ShapeExprNode* bias_shape = bias_sinfo->shape.as<ShapeExprNode>();
+    if (bias_sinfo->ndim == 4) {
+      diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch 
size");
+      diag_equal(num_heads, bias_shape->values[1], "query", "bias", "number of 
heads");
+      diag_equal(num_queries, bias_shape->values[2], "query", "bias", 
"sequence length");
+      diag_equal(num_keys, bias_shape->values[3], "key", "bias", "sequence 
length");
+    } else if (bias_sinfo->ndim == 3) {
+      diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch 
size");
+      diag_equal(num_queries, bias_shape->values[1], "query", "bias", 
"sequence length");
+      diag_equal(num_keys, bias_shape->values[2], "key", "bias", "sequence 
length");
+    } else if (bias_sinfo->ndim == 2) {
+      diag_equal(num_batches, bias_shape->values[0], "query", "bias", "batch 
size");
+      diag_equal(num_keys, bias_shape->values[1], "key", "bias", "sequence 
length");
+    } else {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The bias should have 2, 3 or 4 dimensions."
+                       << "However, the bias input has " << bias_sinfo->ndim 
<< " dimensions.");
+    }
+  }
+
+  Array<PrimExpr> output_shape = {num_batches, num_queries, num_heads, 
head_dim_value};
+  return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.nn.attention")
+    .set_num_inputs(3)
+    .add_argument("query", "Tensor", "The input queries tensor.")
+    .add_argument("key", "Tensor", "The input keys tensor.")
+    .add_argument("value", "Tensor", "The input values tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention);
+
+TVM_REGISTER_OP("relax.nn.attention_bias")
+    .set_num_inputs(4)
+    .add_argument("query", "Tensor", "The input queries tensor.")
+    .add_argument("key", "Tensor", "The input keys tensor.")
+    .add_argument("value", "Tensor", "The input values tensor.")
+    .add_argument("bias", "Tensor", "The input bias tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h
new file mode 100644
index 0000000000..662e0b7e7b
--- /dev/null
+++ b/src/relax/op/nn/attention.h
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file attention.h
+ * \brief The functions to make Relax attention operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_NN_ATTENTION_H_
+#define TVM_RELAX_OP_NN_ATTENTION_H_
+
+#include <tvm/relax/attrs/nn.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief fused multi head attention */
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_NN_ATTENTION_H_
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index 83104d6fe1..36a1c4cd16 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -20,6 +20,8 @@ import scipy
 
 import tvm
 import tvm.testing
+import tvm.topi.testing
+from tvm.contrib.pickle_memoize import memoize
 from tvm import relax, relay
 from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
 from tvm.relax.backend import get_patterns_with_prefix
@@ -300,5 +302,150 @@ def test_cutlass_partition_matmul_blocked(x_shape, 
y_shape, transpose_y, dtype):
     tvm.ir.assert_structural_equal(mod, partition_for_cutlass(mod))
 
 
+@pytest.fixture(params=["float16", "float32"])
+def attention_dtype(request):
+    return request.param
+
+
+@pytest.fixture(
+    params=[
+        # B, S, N, H
+        (32, (8, 8), 16, (8, 8)),
+        (4, (16, 8), 32, (8, 8)),  # s != s_kv
+        (4, (16, 8), 32, (8, 16)),  # h != h_v
+        (32, (8, 8), 16, (4, 4)),  # h is not aligned
+        (2, (8, 8), 8, (256, 256)),  # needs output accumulator buffer
+    ]
+)
+def attention_size(request):
+    return request.param
+
+
+def get_relax_attention_module(q, k, v, bias=None):
+    dtype = str(q.dtype)
+
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import relax as relax_builder
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            q = R.arg("q", R.Tensor(q.shape, dtype))
+            k = R.arg("k", R.Tensor(k.shape, dtype))
+            v = R.arg("v", R.Tensor(v.shape, dtype))
+            if bias is not None:
+                bias = R.arg("bias", R.Tensor(bias.shape, dtype))
+            with R.dataflow() as frame:
+                result = R.emit(R.nn.attention(q, k, v, bias))
+                R.output(result)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_offload")
+def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, dtype):
+    q = np.random.randn(b, s, n, h).astype(dtype)
+    k = np.random.randn(b, s_kv, n, h).astype(dtype)
+    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
+    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
+    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    attn = tvm.topi.testing.softmax_python(score, -1)
+    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
+    ref = attn @ vt  # b, n, s, h_v
+    return q, k, v, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+
+
+def test_attention_offload(attention_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, ref = get_numpy_attention_ref(b, s, s_kv, n, h, h_v, 
attention_dtype)
+
+    mod = get_relax_attention_module(q, k, v)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_4d_offload")
+def get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, dtype):
+    q = np.random.randn(b, s, n, h).astype(dtype)
+    k = np.random.randn(b, s_kv, n, h).astype(dtype)
+    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    bias = np.random.randn(b, n, s, s_kv).astype(dtype)
+    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
+    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
+    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    score_bias = score + bias  # b, n, s, s_kv
+    attn = tvm.topi.testing.softmax_python(score_bias, -1)
+    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
+    ref = attn @ vt  # b, n, s, h_v
+    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+
+
+def test_attention_bias_4d_offload(attention_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, bias, ref = get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, 
h_v, attention_dtype)
+
+    mod = get_relax_attention_module(q, k, v, bias)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_3d_offload")
+def get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, dtype):
+    q = np.random.randn(b, s, n, h).astype(dtype)
+    k = np.random.randn(b, s_kv, n, h).astype(dtype)
+    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    bias = np.random.randn(b, s, s_kv).astype(dtype)
+    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
+    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
+    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    score_bias = score + bias.reshape(b, 1, s, s_kv)  # b, n, s, s_kv
+    attn = tvm.topi.testing.softmax_python(score_bias, -1)
+    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
+    ref = attn @ vt  # b, n, s, h_v
+    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+
+
+def test_attention_bias_3d_offload(attention_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, bias, ref = get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, 
h_v, attention_dtype)
+
+    mod = get_relax_attention_module(q, k, v, bias)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_2d_offload")
+def get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, dtype):
+    q = np.random.randn(b, s, n, h).astype(dtype)
+    k = np.random.randn(b, s_kv, n, h).astype(dtype)
+    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    bias = np.random.randn(b, s_kv).astype(dtype)
+    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
+    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
+    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    score_bias = score + bias.reshape(b, 1, 1, s_kv)  # b, n, s, s_kv
+    attn = tvm.topi.testing.softmax_python(score_bias, -1)
+    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
+    ref = attn @ vt  # b, n, s, h_v
+    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+
+
+def test_attention_bias_2d_offload(attention_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v) = attention_size
+    q, k, v, bias, ref = get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, 
h_v, attention_dtype)
+
+    mod = get_relax_attention_module(q, k, v, bias)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()


Reply via email to