This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5d5edd2fd8 [Relax] Integrate cuDNN attention (#17157)
5d5edd2fd8 is described below
commit 5d5edd2fd8b891bb74681f83095d606739cadfcb
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Jul 22 12:36:06 2024 -0700
[Relax] Integrate cuDNN attention (#17157)
* [Relax] Integrate cuDNN attention
* update cmake
* lint
* lint
* cudnn frontend
* lint
* lint
* fix test
* skip test
---
cmake/config.cmake | 7 +
cmake/modules/CUDA.cmake | 16 ++
python/tvm/contrib/cutlass/build.py | 32 +--
python/tvm/contrib/cutlass/gen_tensor_op.py | 4 +-
python/tvm/relax/backend/contrib/cudnn.py | 99 +++++++-
python/tvm/relax/backend/contrib/cutlass.py | 18 +-
python/tvm/relax/backend/patterns.py | 32 ++-
python/tvm/relax/frontend/nn/op.py | 9 +-
python/tvm/relax/testing/__init__.py | 1 +
python/tvm/relax/testing/attention.py | 148 ++++++++++++
python/tvm/topi/testing/__init__.py | 1 +
python/tvm/topi/testing/attention_python.py | 122 ++++++++++
src/relax/backend/contrib/cudnn/codegen.cc | 47 ++++
src/relax/transform/allocate_workspace.cc | 9 +-
src/relax/transform/fuse_ops.cc | 19 +-
.../contrib/cudnn/cudnn_frontend/attention.cc | 124 ++++++++++
.../contrib/cudnn/cudnn_frontend/attention.h | 83 +++++++
src/runtime/contrib/cudnn/cudnn_json_runtime.cc | 267 ++++++++++++---------
tests/python/relax/test_codegen_cudnn.py | 65 ++++-
tests/python/relax/test_codegen_cutlass.py | 213 +++++-----------
.../relax/test_transform_allocate_workspace.py | 3 +-
.../test_transform_merge_composite_functions.py | 5 +-
22 files changed, 1010 insertions(+), 314 deletions(-)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 416eec0dcb..26d50630f7 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -245,6 +245,13 @@ set(USE_EDGETPU OFF)
# - /path/to/cudnn: use specific path to cuDNN path
set(USE_CUDNN OFF)
+# Whether to use the cuDNN frontend (the C++ API shipped as cudnn_frontend.h).
+# Only takes effect when USE_CUDA is enabled. The frontend code builds on TVM's
+# cuDNN support, so presumably USE_CUDNN should be enabled as well.
+# Possible values:
+# - ON: enable cuDNN frontend, searching for cudnn_frontend.h in the default locations
+# - /path/to/cudnn_frontend: use specific path to cuDNN frontend; the header is
+#   expected at /path/to/cudnn_frontend/include/cudnn_frontend.h
+# - OFF: disable cuDNN frontend
+set(USE_CUDNN_FRONTEND OFF)
+
# Whether use cuBLAS
set(USE_CUBLAS OFF)
diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index b7b405f822..ad83ebe26b 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -77,6 +77,22 @@ if(USE_CUDA)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIBRARY})
endif(USE_CUDNN)
+  if(USE_CUDNN_FRONTEND)
+    message(STATUS "Build with cuDNN Frontend support")
+    # The frontend sources include the cuDNN runtime helpers (cudnn_utils.h) and rely on the
+    # cuDNN library, which is only linked by the USE_CUDNN block above. Fail early with a clear
+    # message instead of producing unresolved cuDNN symbols later.
+    if(NOT USE_CUDNN)
+      message(FATAL_ERROR "USE_CUDNN_FRONTEND requires USE_CUDNN to be enabled")
+    endif()
+    if(IS_DIRECTORY "${USE_CUDNN_FRONTEND}")
+      find_file(CUDNN_FRONTEND_HEADER cudnn_frontend.h HINTS "${USE_CUDNN_FRONTEND}/include")
+    else()
+      find_file(CUDNN_FRONTEND_HEADER cudnn_frontend.h)
+    endif()
+    if(NOT CUDNN_FRONTEND_HEADER)
+      message(FATAL_ERROR "Cannot find cudnn_frontend.h, please set USE_CUDNN_FRONTEND to the cuDNN frontend directory that contains include/cudnn_frontend.h")
+    endif()
+    # Put the directory of the header that was actually found on the include path. This also
+    # covers USE_CUDNN_FRONTEND=ON, where the header is located through the default search
+    # paths, which are not necessarily compiler include directories.
+    get_filename_component(CUDNN_FRONTEND_INCLUDE_DIR "${CUDNN_FRONTEND_HEADER}" DIRECTORY)
+    include_directories(SYSTEM "${CUDNN_FRONTEND_INCLUDE_DIR}")
+    tvm_file_glob(GLOB CONTRIB_CUDNN_FRONTEND_SRCS src/runtime/contrib/cudnn/cudnn_frontend/*.cc)
+    # TVM_USE_CUDNN_FRONTEND is defined on the sources in CONTRIB_CUDNN_SRCS (presumably collected
+    # by the USE_CUDNN block), not on the frontend sources themselves.
+    set_property(
+      SOURCE ${CONTRIB_CUDNN_SRCS}
+      APPEND PROPERTY COMPILE_DEFINITIONS TVM_USE_CUDNN_FRONTEND=1
+    )
+    list(APPEND RUNTIME_SRCS ${CONTRIB_CUDNN_FRONTEND_SRCS})
+  endif(USE_CUDNN_FRONTEND)
+
if(USE_CUBLAS)
message(STATUS "Build with cuBLAS support")
tvm_file_glob(GLOB CUBLAS_CONTRIB_SRC
src/relay/backend/contrib/cublas/*.cc src/relax/backend/contrib/cublas/*.cc)
diff --git a/python/tvm/contrib/cutlass/build.py
b/python/tvm/contrib/cutlass/build.py
index 1c0a30c62d..5c09c79bd9 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -868,34 +868,26 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
signature = _extract_relax_function_signature(f)
if _get_call_node(f.body, "relax.nn.attention") is not None:
- op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs
+ attention_node = _get_call_node(f.body, "relax.nn.attention")
+ op_attrs = attention_node.attrs
elif _get_call_node(f.body, "relax.nn.attention_bias") is not None:
- op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs
+ attention_node = _get_call_node(f.body, "relax.nn.attention_bias")
+ op_attrs = attention_node.attrs
elif _get_call_node(f.body, "relax.nn.attention_var_len") is not None:
- op_attrs = _get_call_node(f.body,
"relax.nn.attention_var_len").attrs
+ attention_node = _get_call_node(f.body,
"relax.nn.attention_var_len")
+ op_attrs = attention_node.attrs
else:
raise ValueError("Cannot find call node for attention")
arg = {}
if "stacked_attention" in op_type:
- arg["arg0_shape"] = signature["arg0_shape"]
arg["arg0_dtype"] = signature["arg0_dtype"]
- arg["arg1_shape"] = q_shape = signature["arg1_shape"]
-
- if "arg3_shape" not in signature:
- # arg0: qkv, arg1: shape, arg2: workspace
- arg["arg2_shape"] = k_shape = signature["arg1_shape"]
- arg["arg3_shape"] = v_shape = signature["arg1_shape"]
- else:
- # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4:
workspace
- arg["arg2_shape"] = k_shape = signature["arg2_shape"]
- arg["arg3_shape"] = v_shape = signature["arg3_shape"]
-
- if "arg5_dtype" in signature:
- # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4:
bias, arg5: workspace
- arg["bias_dtype"] = signature["arg4_dtype"]
- if "arg5_shape" in signature:
- arg["bias_shape"] = signature["arg4_shape"]
+ q_shape = get_const_tuple(attention_node.args[0].struct_info.shape)
+ k_shape = get_const_tuple(attention_node.args[1].struct_info.shape)
+ v_shape = get_const_tuple(attention_node.args[2].struct_info.shape)
+ if len(attention_node.args) == 4:
+ arg["bias_shape"] =
get_const_tuple(attention_node.args[3].struct_info.shape)
+ arg["bias_dtype"] = attention_node.args[3].struct_info.dtype
qkv_layout = "qkv_stacked"
else:
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 2f21a1d313..5d04cf13e6 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -745,8 +745,8 @@ def instantiate_template(func_name, annotations, func_args):
attrs["qkv"] = func_args[0]
attrs["num_queries"] = s = annotations["num_queries"]
attrs["num_keys"] = annotations["num_keys"]
- if len(func_args) > 5 and not is_var_len: # +1 for workspace, the
last arg
- attrs["bias"] = func_args[4]
+ if len(func_args) > 2 and not is_var_len: # +1 for workspace, the
last arg
+ attrs["bias"] = func_args[1]
else:
raise NotImplementedError()
diff --git a/python/tvm/relax/backend/contrib/cudnn.py
b/python/tvm/relax/backend/contrib/cudnn.py
index f730d4e5be..2f15e3a4fd 100644
--- a/python/tvm/relax/backend/contrib/cudnn.py
+++ b/python/tvm/relax/backend/contrib/cudnn.py
@@ -16,11 +16,16 @@
# under the License.
"""Pattern table for cuDNN backend"""
-from tvm.relax import transform
+import operator
+from functools import partial, reduce
+
+import tvm
+from tvm import relax
+from tvm.relax import PyExprMutator, expr_functor, transform
from tvm.relax.transform import PatternCheckContext
from ..pattern_registry import get_patterns_with_prefix, register_patterns
-from ..patterns import make_conv2d_pattern
+from ..patterns import make_conv2d_pattern, make_stacked_attention_pattern
from ..utils import has_leaking_intermediate_variables
@@ -60,6 +65,29 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
return True
+def _check_stacked_attention(context: PatternCheckContext, layout: str) ->
bool:
+ """Check if the given stacked attention workload can be offloaded to
cuDNN."""
+ if has_leaking_intermediate_variables(context):
+ return False
+ if layout == "BS3NH":
+ if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3:
+ return False
+ if "split" in context.annotated_expr:
+ split_op = context.annotated_expr["split"]
+ if not split_op.attrs.axis == 2:
+ return False
+ elif layout == "SBN3H":
+ if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 4:
+ return False
+ if "split" in context.annotated_expr:
+ split_op = context.annotated_expr["split"]
+ if not split_op.attrs.axis == 3:
+ return False
+ else:
+ raise NotImplementedError(f"Unsupported layout: {layout}")
+ return True
+
+
register_patterns(
[
(
@@ -84,6 +112,16 @@ register_patterns(
),
_check_conv2d,
),
+ (
+ "cudnn.attention.BS3NH",
+ *make_stacked_attention_pattern(start_op="split", layout="BS3NH"),
+ partial(_check_stacked_attention, layout="BS3NH"),
+ ),
+ (
+ "cudnn.attention.SBN3H",
+ *make_stacked_attention_pattern(start_op="split", layout="SBN3H"),
+ partial(_check_stacked_attention, layout="SBN3H"),
+ ),
]
)
@@ -105,4 +143,59 @@ def partition_for_cudnn(mod):
"""
patterns = get_patterns_with_prefix("cudnn")
- return transform.FuseOpsByPattern(patterns, bind_constants=False,
annotate_codegen=True)(mod)
+ return tvm.transform.Sequential(
+ [
+ transform.FuseOpsByPattern(patterns, bind_constants=False,
annotate_codegen=True),
+ annotate_workspace,
+ transform.AllocateWorkspace(),
+ ]
+ )(mod)
+
+
+def _shape_1d(shape):
+ return reduce(operator.mul, shape, 1)
+
+
+@expr_functor.mutator
+class WorkspaceAnnotator(PyExprMutator):
+ """Annotate a workspace requirement for each cuDNN-offloaded function."""
+
+ def __init__(self, mod):
+ super().__init__(mod)
+
+ def visit_function_(self, f):
+ if "Composite" not in f.attrs:
+ body = super().visit_expr(f.body)
+ new_f = relax.Function(f.params, body, f.ret_struct_info,
f.is_pure, f.attrs, f.span)
+
+ if "global_symbol" in f.attrs and "cudnn" in
f.attrs["global_symbol"]:
+ composite_func = body.blocks[0].bindings[0].value
+ if "WorkspaceSize" in composite_func.attrs:
+ return new_f.with_attr("WorkspaceSize",
composite_func.attrs["WorkspaceSize"])
+
+ return new_f
+
+ if "attention" in f.attrs["Composite"] and "cudnn" in
f.attrs["Composite"]:
+ # Workspace is needed only for larger head sizes, but for
simplicity we always allocate.
+ out_dtype = f.ret_struct_info.dtype
+ out_size_1d = _shape_1d(f.ret_struct_info.shape)
+ # This needs to be in sync with the actual value that the kernel
expects.
+ workspace_size_bytes = out_size_1d * {"float16": 2, "float32":
4}[out_dtype]
+ if not isinstance(workspace_size_bytes, (int,
tvm.tir.expr.IntImm)):
+ # Tempororay workaround for dynamic shape workload. Will be
removed when
+ # workspace for dynamic shape workload is implemented.
+ workspace_size_bytes = 8
+ return f.with_attr("WorkspaceSize", workspace_size_bytes)
+
+ return f
+
+
[email protected]_pass(opt_level=0)
+def annotate_workspace(mod, _):
+ """Pass to annotate a workspace requirement for each cuDNN-offloaded
function."""
+ annotator = WorkspaceAnnotator(mod)
+ for name, f in mod.functions_items():
+ if isinstance(f, relax.Function):
+ new_f = annotator.visit_expr(f)
+ mod.update_func(name, new_f)
+ return mod
diff --git a/python/tvm/relax/backend/contrib/cutlass.py
b/python/tvm/relax/backend/contrib/cutlass.py
index 0d9f4ff8e9..80979bbe7e 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -383,19 +383,25 @@ def _check_stacked_attention(context:
PatternCheckContext) -> bool:
if not split_op.attrs.axis == 2:
return False
else:
+ get_const_int_list = lambda tup: [int(e.value) for e in tup]
last_end = 0
for name in ["query", "key", "value"]:
assert f"strided_slice_{name}" in context.annotated_expr
strided_slice_op = context.annotated_expr[f"strided_slice_{name}"]
- if list(strided_slice_op.attrs.axes) != [2]:
+ axes = get_const_int_list(strided_slice_op.args[1])
+ begins = get_const_int_list(strided_slice_op.args[2])
+ ends = get_const_int_list(strided_slice_op.args[3])
+ strides = get_const_int_list(strided_slice_op.args[4])
+
+ if axes != [2]:
return False
- if list(strided_slice_op.attrs.begin) != [last_end]:
+ if begins != [last_end]:
return False
- if not len(strided_slice_op.attrs.end) == 1:
+ if not len(ends) == 1:
return False
- last_end = strided_slice_op.attrs.end[0]
- if list(strided_slice_op.attrs.strides) != [1]:
+ if strides != [1]:
return False
+ last_end = ends[0]
return True
@@ -537,7 +543,7 @@ class WorkspaceAnnotator(PyExprMutator):
return new_f
- if "attention" in f.attrs["Composite"]:
+ if "attention" in f.attrs["Composite"] and "cutlass" in
f.attrs["Composite"]:
# Workspace is needed only for larger head sizes, but for
simplicity we always allocate.
out_dtype = f.ret_struct_info.dtype
out_size_1d = _shape_1d(f.ret_struct_info.shape)
diff --git a/python/tvm/relax/backend/patterns.py
b/python/tvm/relax/backend/patterns.py
index 8ec43f1f27..1faef9cceb 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -260,7 +260,7 @@ def make_attention_pattern(with_bias: bool = False,
var_len: bool = False):
return out, annotations
-def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
+def make_stacked_attention_pattern(start_op: str, with_bias: bool = False,
layout="BS3NH"):
"""
Create pattern for fused multi head attention with stacked input.
@@ -272,6 +272,9 @@ def make_stacked_attention_pattern(start_op: str,
with_bias: bool = False):
with_bias: bool
Whether or not to include bias addition
+ layout: str
+ The layout of the stacked input tensor.
+
Returns
-------
pattern: DFPattern
@@ -290,17 +293,28 @@ def make_stacked_attention_pattern(start_op: str,
with_bias: bool = False):
key_raw = is_tuple_get_item(qkv_tuple, 1)
value_raw = is_tuple_get_item(qkv_tuple, 2)
elif start_op == "strided_slice":
- ops["strided_slice_query"] = query_raw =
is_op("relax.strided_slice")(stacked_qkv)
- ops["strided_slice_key"] = key_raw =
is_op("relax.strided_slice")(stacked_qkv)
- ops["strided_slice_value"] = value_raw =
is_op("relax.strided_slice")(stacked_qkv)
+ ops["strided_slice_query"] = query_raw = is_op("relax.strided_slice")(
+ stacked_qkv, varg_default_wildcard=True
+ )
+ ops["strided_slice_key"] = key_raw = is_op("relax.strided_slice")(
+ stacked_qkv, varg_default_wildcard=True
+ )
+ ops["strided_slice_value"] = value_raw = is_op("relax.strided_slice")(
+ stacked_qkv, varg_default_wildcard=True
+ )
else:
raise NotImplementedError()
query_reshape_list = wildcard()
key_reshape_list = wildcard()
value_reshape_list = wildcard()
- query = is_op("relax.reshape")(query_raw, query_reshape_list)
- key = is_op("relax.reshape")(key_raw, key_reshape_list)
- value = is_op("relax.reshape")(value_raw, value_reshape_list)
+ if layout == "BS3NH":
+ query = is_op("relax.reshape")(query_raw, query_reshape_list)
+ key = is_op("relax.reshape")(key_raw, key_reshape_list)
+ value = is_op("relax.reshape")(value_raw, value_reshape_list)
+ elif layout == "SBN3H":
+ ops["q_transpose"] = query = is_op("relax.permute_dims")(query_raw)
+ ops["k_transpose"] = key = is_op("relax.permute_dims")(key_raw)
+ ops["v_transpose"] = value = is_op("relax.permute_dims")(value_raw)
annotations = {
"stacked_qkv": stacked_qkv,
"query_reshape_list": query_reshape_list,
@@ -314,6 +328,10 @@ def make_stacked_attention_pattern(start_op: str,
with_bias: bool = False):
out = is_op("relax.nn.attention_bias")(query, key, value, bias)
else:
out = is_op("relax.nn.attention")(query, key, value)
+
+ if layout == "SBN3H":
+ out = is_op("relax.permute_dims")(out)
+
return out, annotations
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
index 725a930fd6..ec072f663c 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1568,11 +1568,14 @@ def scaled_dot_product_attention(
Parameters
----------
query : Tensor
- Tensor representing current attention lookup.
+ Tensor representing current attention lookup of shape
+ [batch, seq_len, num_heads, head_size].
key : Tensor
- Tensor representing cross attention mapping.
+ Tensor representing cross attention mapping of shape
+ [batch, seq_len_kv, num_heads_kv, head_size].
value : Tensor
- Tensor representing embedded attention values.
+ Tensor representing embedded attention values of shape
+ [batch, seq_len_kv, num_heads_kv, head_size_value].
attn_mask : Optional[Tensor]
Optional mask for attention, not yet supported.
is_causal : Optional[bool]
diff --git a/python/tvm/relax/testing/__init__.py
b/python/tvm/relax/testing/__init__.py
index 4256ebc3be..dc43d6c1f8 100644
--- a/python/tvm/relax/testing/__init__.py
+++ b/python/tvm/relax/testing/__init__.py
@@ -21,3 +21,4 @@ from .nn import *
from .relay_translator import *
from .ast_printer import dump_ast
from .matmul import *
+from .attention import *
diff --git a/python/tvm/relax/testing/attention.py
b/python/tvm/relax/testing/attention.py
new file mode 100644
index 0000000000..a00674394b
--- /dev/null
+++ b/python/tvm/relax/testing/attention.py
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Relax script for attention module."""
+import tvm
+from tvm.script import relax as R, tir as T
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import relax as relax_builder
+
+
+def get_relax_attention_module(
+ q_shape,
+ k_shape,
+ v_shape,
+ *,
+ dtype,
+ bias_shape=None,
+ qk_scale=None,
+ causal_mask=None,
+ window_size=None,
+): # pylint: disable=too-many-arguments, too-many-locals, invalid-name
+ """Get a relax module for attention."""
+
+ if qk_scale is not None:
+ qk_scale = T.FloatImm("float32", qk_scale)
+
+ if window_size is not None:
+ window_size = T.IntImm("int32", window_size)
+
+ with IRBuilder() as builder:
+ with relax_builder.function():
+ R.func_name("main")
+ q = R.arg("q", R.Tensor(q_shape, dtype))
+ k = R.arg("k", R.Tensor(k_shape, dtype))
+ v = R.arg("v", R.Tensor(v_shape, dtype))
+ bias = None
+ if bias_shape is not None and bias_shape != "none":
+ bias = R.arg("bias", R.Tensor(bias_shape, dtype))
+
+ with R.dataflow() as frame:
+ result = R.emit(R.nn.attention(q, k, v, bias, qk_scale,
causal_mask, window_size))
+ R.output(result)
+
+ R.func_ret_value(frame.output_vars[0])
+
+ func = builder.get()
+ return tvm.IRModule({"main": func})
+
+
+def get_relax_stacked_attention_module(
+ qkv,
+ b,
+ s,
+ n,
+ h,
+ h_v,
+ op,
+ bias=None,
+ qk_scale=None,
+ single_shape=False,
+ layout="BS3NH",
+): # pylint: disable=too-many-arguments, too-many-locals, too-many-branches,
invalid-name
+ # pylint: disable=too-many-statements
+ """Get a relax module for stacked attention."""
+ dtype = str(qkv.dtype)
+ assert layout in ["BS3NH", "SBN3H"]
+
+ if qk_scale is not None:
+ qk_scale = T.FloatImm("float32", qk_scale)
+
+ if single_shape:
+ if layout == "BS3NH":
+ qk_shape = R.shape([b, s, n, h])
+ elif layout == "SBN3H":
+ qk_shape = R.shape([b, s, n, h])
+ v_shape = qk_shape
+ else:
+ if layout == "BS3NH":
+ qk_shape = [b, s, n, h]
+ v_shape = [b, s, n, h_v]
+ elif layout == "SBN3H":
+ qk_shape = [s, b, n, h]
+ v_shape = [s, b, n, h_v]
+
+ if layout == "BS3NH":
+ split_axis = 2
+ split_sections = [n * h, n * h * 2]
+ elif layout == "SBN3H":
+ split_axis = 3
+ split_sections = [h, h * 2]
+
+ with IRBuilder() as builder:
+ with relax_builder.function():
+ R.func_name("main")
+ qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype))
+ if bias is not None:
+ bias = R.arg("bias", R.Tensor(bias.shape, dtype))
+ with R.dataflow() as frame:
+ if op == "split":
+ qkv_tuple = R.split(qkv, split_sections, axis=split_axis)
+ q = qkv_tuple[0]
+ k = qkv_tuple[1]
+ v = qkv_tuple[2]
+ elif op == "strided_slice":
+ q = R.strided_slice(qkv, [split_axis], [0],
[split_sections[0]], [1])
+ k = R.strided_slice(
+ qkv, [split_axis], [split_sections[0]],
[split_sections[1]], [1]
+ )
+ v = R.strided_slice(
+ qkv,
+ [split_axis],
+ [split_sections[1]],
+ [int(qkv.struct_info.shape[split_axis])],
+ [1],
+ )
+ else:
+ raise NotImplementedError()
+ if layout == "BS3NH":
+ q = R.reshape(q, qk_shape)
+ k = R.reshape(k, qk_shape)
+ v = R.reshape(v, v_shape)
+ elif layout == "SBN3H":
+ q = R.permute_dims(q, [1, 0, 2, 3])
+ k = R.permute_dims(k, [1, 0, 2, 3])
+ v = R.permute_dims(v, [1, 0, 2, 3])
+ result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
+ if layout == "SBN3H":
+ result = R.emit(R.permute_dims(result, [1, 0, 2, 3]))
+ R.output(result)
+
+ R.func_ret_value(frame.output_vars[0])
+
+ func = builder.get()
+ return tvm.IRModule({"main": func})
diff --git a/python/tvm/topi/testing/__init__.py
b/python/tvm/topi/testing/__init__.py
index 72a7cedc49..1486e9986e 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -84,3 +84,4 @@ from .dense import dense
from .searchsorted import searchsorted_ref
from .conv2d_backcward_weight_python import conv2d_backward_weight_python
from .lstm_python import lstm_python
+from .attention_python import attention_python
diff --git a/python/tvm/topi/testing/attention_python.py
b/python/tvm/topi/testing/attention_python.py
new file mode 100644
index 0000000000..856667aedd
--- /dev/null
+++ b/python/tvm/topi/testing/attention_python.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Attention operator in python"""
+from typing import Optional
+import numpy as np
+from .softmax_python import softmax_python
+
+
+def attention_python(
+ q: np.ndarray,
+ k: np.ndarray,
+ v: np.ndarray,
+ bias: Optional[np.ndarray],
+ qk_scale: float,
+ causal: str,
+ window_size: Optional[int] = None,
+ layout: str = "BSNH",
+): # pylint: disable=too-many-arguments, too-many-locals, invalid-name
+ """Attention operator in python
+
+ Parameters
+ ----------
+ q : np.ndarray
+ Query tensor with shape [batch, seq_length, num_heads, head_dim] in
the layout specified by
+ `layout`.
+ k : np.ndarray
+ Key tensor with shape [batch, seq_length_kv, num_kv_heads, head_dim]
in the layout specified
+ by `layout`.
+ v : np.ndarray
+ Value tensor with shape [batch, seq_length_kv, num_kv_heads,
head_dim_v] in the layout
+ specified by `layout`.
+ bias : np.ndarray
+ Bias tensor with shape [batch, num_heads, seq_length, seq_length]
+ qk_scale : float
+ Scale factor for the query-key product.
+ causal : str
+ The type of causal mask to apply. Can be "none", "TopLeft", or
"BottomRight".
+ window_size : Optional[int]
+ The window size for the causal mask.
+ layout : str
+ The layout of the input tensors, e.g. "BSNH" or "BNSH".
+
+ Returns
+ -------
+ np.ndarray
+ The output tensor with shape [batch, seq_length, num_heads,
head_dim_v] in the layout
+ specified by `layout`.
+ """
+ assert layout in ["BSNH", "BNSH", "SBNH"]
+
+ dim_b = layout.find("B")
+ dim_s = layout.find("S")
+ dim_n = layout.find("N")
+ dim_h = layout.find("H")
+
+ q = q.transpose(dim_b, dim_n, dim_s, dim_h) # b, n, s, h
+ k = k.transpose(dim_b, dim_n, dim_s, dim_h) # b, n, s_kv, h
+ kt = k.transpose(0, 1, 3, 2) # b, n, h, s_kv
+ v = v.transpose(dim_b, dim_n, dim_s, dim_h)
+
+ num_heads = q.shape[1]
+ num_kv_heads = k.shape[1]
+ s = q.shape[2]
+ s_kv = k.shape[2]
+
+ if num_heads != num_kv_heads:
+ assert num_heads % num_kv_heads == 0
+ factor = num_heads // num_kv_heads
+ kt = np.repeat(kt, factor, axis=1)
+ v = np.repeat(v, factor, axis=1)
+
+ if not qk_scale == "none":
+ score = q @ kt * qk_scale # b, n, s, s_kv
+ else:
+ score = q @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
+ if bias is not None:
+ score = score + bias # b, n, s, s_kv
+ if causal == "none":
+ attn = softmax_python(score, -1)
+ else:
+ if causal == "TopLeft":
+ offset = 0
+ elif causal == "BottomRight":
+ offset = abs(s - s_kv)
+ else:
+ raise ValueError(f"Unsupported causal type: {causal}")
+ score_masked = np.tril(score, k=offset)
+
+ if window_size:
+ score_masked = np.triu(
+ score_masked, -window_size + 1 # pylint:
disable=invalid-unary-operand-type
+ )
+
+ score_masked_exp = np.tril(
+ np.exp(score_masked - np.max(score_masked, axis=-1,
keepdims=True)), k=offset
+ )
+
+ if window_size:
+ score_masked_exp = np.triu(
+ score_masked_exp, -window_size + 1 # pylint:
disable=invalid-unary-operand-type
+ )
+
+ score_masked_sum = np.sum(score_masked_exp, axis=-1, keepdims=True)
+ attn = np.divide(score_masked_exp, score_masked_sum)
+
+ out = attn @ v # b, n, s, h_v
+ return out.transpose(*np.argsort([dim_b, dim_n, dim_s, dim_h]).tolist())
diff --git a/src/relax/backend/contrib/cudnn/codegen.cc
b/src/relax/backend/contrib/cudnn/codegen.cc
index 812016b8ea..d8ca5f4e97 100644
--- a/src/relax/backend/contrib/cudnn/codegen.cc
+++ b/src/relax/backend/contrib/cudnn/codegen.cc
@@ -55,6 +55,17 @@ class cuDNNJSONSerializer : public JSONSerializer {
std::string composite_name = composite_opt.value();
+ if (composite_name.find("cudnn.conv2d") != std::string::npos) {
+ return HandleConv2D(call_node, fn, composite_name);
+ } else if (composite_name.find("cudnn.attention") != std::string::npos) {
+ return HandleAttention(call_node, fn, composite_name);
+ } else {
+ LOG(FATAL) << "Unsupported composite function: " << composite_name;
+ }
+ }
+
+ NodeEntries HandleConv2D(const CallNode* call_node, const Function& fn,
+ const std::string& composite_name) {
NodeEntries inputs_tmp;
for (const auto& arg : call_node->args) {
auto res = VisitExpr(arg);
@@ -80,6 +91,42 @@ class cuDNNJSONSerializer : public JSONSerializer {
return AddNode(node, GetRef<Expr>(call_node));
}
+  NodeEntries HandleAttention(const CallNode* call_node, const Function& fn,
+                              const std::string& composite_name) {
+    // The layout tag is the suffix after the last '.' of the composite name,
+    // e.g. "BS3NH" for "cudnn.attention.BS3NH".
+    std::string layout = composite_name.substr(composite_name.find_last_of(".") + 1);
+    NodeEntries inputs;
+    for (const auto& arg : call_node->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+    // NOTE(review): presumably the stacked QKV tensor plus the appended workspace — confirm.
+    ICHECK_EQ(inputs.size(), 2);
+    auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
+                                                "kernel",       /* op_type_ */
+                                                inputs, 1 /* num_outputs_ */);
+    const CallNode* root_call = backend::GetOpInFunction(fn, "relax.nn.attention");
+    ICHECK(root_call != nullptr) << "relax.nn.attention not found in composite function "
+                                 << composite_name;
+
+    // Returns dimension `axis` of the `arg_index`-th argument of the attention call. The
+    // attributes serialized below are plain integers, so a symbolic dimension is reported with
+    // a clear error instead of dereferencing a null IntImmNode.
+    auto static_dim = [root_call, &composite_name](int arg_index, int axis) -> int64_t {
+      auto sinfo = Downcast<TensorStructInfo>(root_call->args[arg_index]->struct_info_.value());
+      auto shape = Downcast<ShapeExpr>(sinfo->shape.value());
+      const auto* dim = shape->values[axis].as<IntImmNode>();
+      ICHECK(dim != nullptr) << composite_name << " requires static shapes, but dimension "
+                             << axis << " of attention argument " << arg_index << " is "
+                             << shape->values[axis];
+      return dim->value;
+    };
+    // Attention operands are indexed as [batch, seq_len, num_heads, head_size].
+    int num_heads = static_dim(/*q=*/0, 2);
+    int num_kv_heads = static_dim(/*k=*/1, 2);
+    int head_size = static_dim(/*q=*/0, 3);
+    int head_size_v = static_dim(/*v=*/2, 3);
+    SetCallNodeAttribute(node, root_call);
+
+    auto to_str_array = [](int val) {
+      return std::vector<dmlc::any>{std::vector<std::string>{std::to_string(val)}};
+    };
+    node->SetAttr("num_heads", to_str_array(num_heads));
+    node->SetAttr("num_kv_heads", to_str_array(num_kv_heads));
+    node->SetAttr("head_size", to_str_array(head_size));
+    node->SetAttr("head_size_v", to_str_array(head_size_v));
+    node->SetAttr("layout", std::vector<dmlc::any>{std::vector<std::string>{layout}});
+    return AddNode(node, GetRef<Expr>(call_node));
+  }
+
private:
/*! \brief The bindings to look up composite functions. */
Map<Var, Expr> bindings_;
diff --git a/src/relax/transform/allocate_workspace.cc
b/src/relax/transform/allocate_workspace.cc
index 1d4a017712..05aa8ce552 100644
--- a/src/relax/transform/allocate_workspace.cc
+++ b/src/relax/transform/allocate_workspace.cc
@@ -66,8 +66,10 @@ class ExternFunctionRewriter : ExprMutator {
}
new_params.push_back(workspace_param);
+ auto new_attrs = func_node->attrs;
+ new_attrs.CopyOnWrite()->dict.erase(attr::kWorkspaceSize);
return Function(new_params, VisitExpr(func_node->body),
func_node->ret_struct_info,
- func_node->is_pure, func_node->attrs);
+ func_node->is_pure, new_attrs);
}
return ExprMutator::VisitExpr_(func_node);
}
@@ -122,6 +124,7 @@ class WorkspaceProvider : ExprMutator {
builder_->UpdateFunction(new_gvar,
WithAttr(f, tvm::attr::kGlobalSymbol,
new_gvar->name_hint));
gvar_map_[gvar] = new_gvar;
+ new_gvars_.insert(new_gvar);
builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar));
}
@@ -164,8 +167,7 @@ class WorkspaceProvider : ExprMutator {
auto new_op = VisitExpr(call_node->op);
if (auto gv = new_op.as<GlobalVar>()) {
- auto callee = builder_->GetContextIRModule()->Lookup(gv.value());
- if (callee->HasNonzeroAttr(attr::kWorkspaceSize)) {
+ if (new_gvars_.count(gv.value())) {
auto new_args = call_node->args;
ICHECK(workspace_var_main_.defined());
new_args.push_back(workspace_var_main_);
@@ -185,6 +187,7 @@ class WorkspaceProvider : ExprMutator {
* the new ones that are transformed to take an additional workspace
parameter. This is only
* needed since the struct info of the global variables changes between
transformation. */
std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_;
+ std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> new_gvars_;
};
} // namespace relax
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 2be7ad41f3..6030a28d93 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -33,6 +33,7 @@
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/function.h>
@@ -595,8 +596,7 @@ class FunctionCreator : public ExprMutator {
}
StructInfo param_sinfo = GetStructInfo(expr);
- // Exclude PrimValues from arg/params to make composite functions
contain PrimValues.
- if (!expr->IsInstance<PrimValueNode>()) {
+ if (!IsInlinableConstants(expr)) {
Var param(std::move(name), GetStructInfo(expr));
arguments_.push_back(expr);
params_.push_back(param);
@@ -621,6 +621,21 @@ class FunctionCreator : public ExprMutator {
return ExprMutator::VisitExpr(expr);
}
+  // Returns true when `expr` is a PrimValue or ShapeExpr that references no free TIR
+  // variables, or a (possibly nested) tuple whose fields all satisfy the same condition.
+  // Such expressions are kept inline in the composite function rather than becoming
+  // arguments/parameters of it.
+  bool IsInlinableConstants(const Expr& expr) {
+    auto has_no_free_vars = [](const PrimExpr& e) { return tvm::tir::UndefinedVars(e).empty(); };
+    if (const auto* prim_value = expr.as<PrimValueNode>()) {
+      return has_no_free_vars(prim_value->value);
+    }
+    if (const auto* shape_expr = expr.as<ShapeExprNode>()) {
+      return std::all_of(shape_expr->values.begin(), shape_expr->values.end(), has_no_free_vars);
+    }
+    if (const auto* tuple = expr.as<TupleNode>()) {
+      for (const Expr& field : tuple->fields) {
+        if (!IsInlinableConstants(field)) {
+          return false;
+        }
+      }
+      return true;
+    }
+    return false;
+  }
+
private:
/*! \brief The variables defined in this function */
std::unordered_set<const VarNode*> defined_vars_;
diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc
b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc
new file mode 100644
index 0000000000..f8b170fe20
--- /dev/null
+++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/cudnn/cudnn_frontend/attention.cc
+ * \brief cuDNN scale dot product attention implementation
+ */
+
+#include "./attention.h"
+
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../../cuda/cuda_common.h"
+#include "../cudnn_utils.h"
+
+namespace tvm {
+namespace contrib {
+
+void CuDNNSDPARunnerNode::Init(int64_t batch, int64_t seq_len, int64_t num_heads,
+                               int64_t num_kv_heads, int64_t head_size, int64_t head_size_v,
+                               double scale, const DLDataType& data_type,
+                               const std::string& layout) {
+  graph_ = std::make_unique<cudnn_frontend::graph::Graph>();
+
+  CHECK(data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 16)
+      << "Only float16 is supported";
+
+  // fp16 inputs/outputs with fp32 intermediate and compute types.
+  graph_->set_io_data_type(cudnn_frontend::DataType_t::HALF)
+      .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT)
+      .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT);
+
+  auto q_desc = cudnn_frontend::graph::Tensor_attributes().set_name("Q").set_uid(kTensorIDQ);
+  auto k_desc = cudnn_frontend::graph::Tensor_attributes().set_name("K").set_uid(kTensorIDK);
+  auto v_desc = cudnn_frontend::graph::Tensor_attributes().set_name("V").set_uid(kTensorIDV);
+
+  // Element strides in the logical order (batch, num_heads, seq_len, head_size).
+  std::vector<int64_t> q_stride, k_stride, v_stride, o_stride;
+
+  if (layout == "BS3NH") {
+    // Each (batch, seq) row packs all Q heads, then all K heads, then all V heads.
+    int64_t stride_H = 1;
+    int64_t q_stride_N = head_size;
+    int64_t k_stride_N = head_size;
+    int64_t v_stride_N = head_size_v;
+    int64_t stride_S =
+        num_heads * q_stride_N + num_kv_heads * k_stride_N + num_kv_heads * v_stride_N;
+    int64_t stride_B = stride_S * seq_len;
+    q_stride = {stride_B, q_stride_N, stride_S, stride_H};
+    k_stride = {stride_B, k_stride_N, stride_S, stride_H};
+    v_stride = {stride_B, v_stride_N, stride_S, stride_H};
+    o_stride = {seq_len * num_heads * head_size_v, head_size_v, num_heads * head_size_v, 1};
+    offset_k_ = num_heads * head_size;
+    offset_v_ = offset_k_ + num_kv_heads * head_size;
+  } else if (layout == "SBN3H") {
+    // Each (seq, batch, head) row packs one Q head, one K head and one V head.
+    CHECK_EQ(num_kv_heads, num_heads);
+    int64_t stride_H = 1;
+    int64_t stride_N = head_size + head_size + head_size_v;
+    int64_t stride_B = num_heads * stride_N;
+    int64_t stride_S = stride_B * batch;
+    q_stride = k_stride = v_stride = {stride_B, stride_N, stride_S, stride_H};
+    o_stride = {num_heads * head_size_v, head_size_v, num_heads * head_size_v * batch, 1};
+    offset_k_ = head_size;
+    offset_v_ = offset_k_ + head_size;  // V follows Q and K, each head_size elements wide.
+  } else {
+    LOG(FATAL) << "Unsupported layout: " << layout;
+  }
+
+  q_desc = q_desc.set_dim({batch, num_heads, seq_len, head_size}).set_stride(q_stride);
+  k_desc = k_desc.set_dim({batch, num_kv_heads, seq_len, head_size}).set_stride(k_stride);
+  v_desc = v_desc.set_dim({batch, num_kv_heads, seq_len, head_size_v}).set_stride(v_stride);
+  auto sdpa_options = cudnn_frontend::graph::SDPA_attributes()
+                          .set_name("flash_attention")
+                          .set_is_inference(true)
+                          .set_alibi_mask(false)
+                          .set_causal_mask(false)
+                          .set_attn_scale(scale);
+
+  auto q = graph_->tensor(q_desc);
+  auto k = graph_->tensor(k_desc);
+  auto v = graph_->tensor(v_desc);
+  auto [o, stats] = graph_->sdpa(q, k, v, sdpa_options);
+  CHECK(stats == nullptr);
+  // Run() binds the output device pointer under kTensorIDOut, so the sdpa output tensor
+  // must carry exactly that uid (previously it was only set on an unused descriptor).
+  o->set_output(true)
+      .set_name("Out")
+      .set_uid(kTensorIDOut)
+      .set_dim({batch, num_heads, seq_len, head_size_v})
+      .set_stride(o_stride);
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  CUDNN_FRONTEND_CALL(graph_->build(entry_ptr->handle, {cudnn_frontend::HeurMode_t::A}));
+}
+
+void CuDNNSDPARunnerNode::Run(const DLTensor* qkv, DLTensor* workspace, DLTensor* out) {
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  CUDNN_CALL(cudnnSetStream(entry_ptr->handle, tvm::runtime::GetCUDAStream()));
+
+  // Q, K and V are views into one packed buffer. The offsets computed by Init() count
+  // fp16 elements (Init() rejects every dtype but float16), hence uint16_t arithmetic.
+  auto* qkv_base = static_cast<uint8_t*>(qkv->data) + qkv->byte_offset;
+  auto* q_ptr = reinterpret_cast<uint16_t*>(qkv_base) + offset_q_;
+  auto* k_ptr = reinterpret_cast<uint16_t*>(qkv_base) + offset_k_;
+  auto* v_ptr = reinterpret_cast<uint16_t*>(qkv_base) + offset_v_;
+  auto* out_ptr = static_cast<uint8_t*>(out->data) + out->byte_offset;
+  // Honor the workspace's byte_offset exactly like qkv and out above.
+  auto* workspace_ptr = static_cast<uint8_t*>(workspace->data) + workspace->byte_offset;
+
+  // NOTE(review): compares a byte count with shape[0], i.e. assumes a 1-D uint8 workspace.
+  const int64_t workspace_size = static_cast<int64_t>(graph_->get_workspace_size());
+  CHECK_LE(workspace_size, workspace->shape[0]) << "Workspace size too small";
+  std::unordered_map<cudnn_frontend::graph::Tensor_attributes::uid_t, void*> variant_pack = {
+      {kTensorIDQ, q_ptr}, {kTensorIDK, k_ptr}, {kTensorIDV, v_ptr}, {kTensorIDOut, out_ptr}};
+
+  CUDNN_FRONTEND_CALL(graph_->execute(entry_ptr->handle, variant_pack, workspace_ptr));
+}
+
+} // namespace contrib
+} // namespace tvm
diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h
b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h
new file mode 100644
index 0000000000..4d0309fb3b
--- /dev/null
+++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/cudnn/cudnn_frontend/attention.h
+ * \brief cuDNN scale dot product attention implementation
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_
+#define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_
+
+#include <cudnn_frontend.h>
+#include <tvm/runtime/registry.h>
+
+#include <memory>
+#include <string>
+
+#define CUDNN_FRONTEND_CALL(func) \
+ do { \
+ auto status = (func); \
+ CHECK(status.is_good()) << status.get_message(); \
+ } while (0)
+
+namespace tvm {
+namespace contrib {
+
+/*!
+ * \brief Runs float16 scaled dot-product attention through a cuDNN frontend graph.
+ *
+ * Init() builds and compiles the graph for one problem size and packed-QKV layout;
+ * Run() binds device pointers by tensor uid and executes the compiled graph.
+ */
+class CuDNNSDPARunnerNode : public tvm::runtime::Object {
+ public:
+  CuDNNSDPARunnerNode() {}
+
+  ~CuDNNSDPARunnerNode() {}
+
+  static constexpr const char* _type_key = "contrib.cudnn.SDPARunner";
+
+  /*!
+   * \brief Build the attention graph and the Q/K/V offsets into the packed QKV buffer.
+   * \param layout Packed QKV layout; "BS3NH" and "SBN3H" are accepted.
+   */
+  void Init(int64_t batch, int64_t seq_len, int64_t num_heads, int64_t num_kv_heads,
+            int64_t head_size, int64_t head_size_v, double scale, const DLDataType& data_type,
+            const std::string& layout);
+
+  /*!
+   * \brief Execute the graph built by Init().
+   * \param qkv Packed Q/K/V input tensor.
+   * \param workspace Scratch buffer; its size is checked against the graph's requirement.
+   * \param out Attention output tensor.
+   */
+  void Run(const DLTensor* qkv, DLTensor* workspace, DLTensor* out);
+
+  /*! \brief Tensor uids used to bind device pointers when the graph is executed. */
+  static constexpr int kTensorIDQ = 0;
+  static constexpr int kTensorIDK = 1;
+  static constexpr int kTensorIDV = 2;
+  static constexpr int kTensorIDOut = 4;
+
+ private:
+  /*! \brief The cuDNN frontend graph built by Init(). */
+  std::unique_ptr<cudnn_frontend::graph::Graph> graph_{nullptr};
+  /*! \brief Element offsets of Q, K and V inside the packed qkv buffer. */
+  int64_t offset_q_{0};
+  int64_t offset_k_{0};
+  int64_t offset_v_{0};
+};
+
+/*! \brief Reference-counted handle to a CuDNNSDPARunnerNode. */
+class CuDNNSDPARunner : public tvm::runtime::ObjectRef {
+ public:
+  /*! \brief Allocate a runner whose graph is not built yet; call Init() before Run(). */
+  static CuDNNSDPARunner Create() {
+    // Fully qualified: this header must not rely on a `using runtime::make_object`
+    // being visible in namespace tvm, and make_object<T>() takes no argument for ADL.
+    auto n = tvm::runtime::make_object<CuDNNSDPARunnerNode>();
+    return CuDNNSDPARunner(n);
+  }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CuDNNSDPARunner, tvm::runtime::ObjectRef,
+                                        CuDNNSDPARunnerNode);
+};
+
+} // namespace contrib
+} // namespace tvm
+
+#endif // TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_
diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc
b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc
index 7d701396d0..3f4b659275 100644
--- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc
+++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc
@@ -31,6 +31,10 @@
#include "../json/json_node.h"
#include "../json/json_runtime.h"
+
+#ifdef TVM_USE_CUDNN_FRONTEND
+#include "./cudnn_frontend/attention.h"
+#endif
#include "cudnn_utils.h"
namespace tvm {
@@ -47,78 +51,19 @@ class cuDNNJSONRuntime : public JSONRuntimeBase {
: JSONRuntimeBase(symbol_name, graph_json, const_names) {}
void Init(const Array<NDArray>& consts) override {
- auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal();
- auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
- ICHECK(func != nullptr);
- stream = static_cast<cudaStream_t>((*func)().operator void*());
-
- auto attr_in_name = [](const std::string& op_name, const std::string&
attr_name) {
- return op_name.find(attr_name) != std::string::npos;
- };
-
- auto vstr2vint = [](const JSONGraphNode& node, const std::string& attrStr)
{
- auto string_to_int = [](const std::string& str) { return std::stoi(str);
};
- auto string_vec = node.GetAttr<std::vector<std::string>>(attrStr);
- std::vector<int> int_vec(string_vec.size());
- std::transform(string_vec.begin(), string_vec.end(), int_vec.begin(),
string_to_int);
- return int_vec;
- };
+ op_execs_.resize(nodes_.size());  // one executor slot per node; non-kernel nodes keep an empty std::function. NOTE(review): `consts` is not used here (no SetupConstants call) — confirm no constant inputs are expected
// get some config from the graph
for (size_t i = 0; i < nodes_.size(); ++i) {
const auto& node = nodes_[i];
if (node.GetOpType() == "kernel") {
- op_name = node.GetOpName();
- std::vector<int> input_dims, kernel_dims, output_dims;
- auto input_node = nodes_[0];
- auto input_shapes = input_node.GetOpShape()[0];
- auto kernel_node = nodes_[1];
- auto kernel_shapes = kernel_node.GetOpShape()[0];
- auto output_shapes = node.GetOpShape()[0];
- for (const auto& _i : input_shapes) {
- input_dims.emplace_back(static_cast<int>(_i));
- }
- for (const auto& _i : kernel_shapes) {
- kernel_dims.emplace_back(static_cast<int>(_i));
+ std::string op_name = node.GetOpName();  // executor choice below is by substring of the composite name
+ if (op_name.find("conv2d") != std::string::npos) {  // conv2d composites -> cuDNN convolution executor
+ op_execs_[i] = GetConv2DExec(node);
+ } else if (op_name.find("attention") != std::string::npos) {  // attention composites -> cuDNN frontend SDPA executor
+ op_execs_[i] = GetAttentionExec(node);
+ } else {
+ LOG(FATAL) << "Unsupported op: " << op_name;  // only conv2d and attention kernels are handled
}
- for (const auto& _i : output_shapes) {
- output_dims.emplace_back(static_cast<int>(_i));
- }
- has_bias = attr_in_name(op_name, "bias");
- groups =
std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
- padding = vstr2vint(node, "padding");
- strides = vstr2vint(node, "strides");
- dilation = vstr2vint(node, "dilation");
- conv_dtype = node.GetAttr<std::vector<std::string>>("out_dtype")[0];
- std::string layout =
node.GetAttr<std::vector<std::string>>("out_layout")[0];
- dims = layout.size() - 2; // remove O and I dims
-
- if (layout == "NCHW")
- format = CUDNN_TENSOR_NCHW;
- else if (layout == "NHWC")
- format = CUDNN_TENSOR_NHWC;
- else
- LOG(FATAL) << "Unsupported layout: " << layout;
-
- if (attr_in_name(op_name, "relu")) {
- act = CUDNN_ACTIVATION_RELU;
- } else if (attr_in_name(op_name, "relu6")) {
- act = CUDNN_ACTIVATION_CLIPPED_RELU;
- coef = 6.0;
- } else if (attr_in_name(op_name, "leaky_relu")) {
- act = CUDNN_ACTIVATION_RELU;
- coef = 0.1;
- }
- this->handle = entry_ptr->handle;
- this->kernel_node = node;
-
- // find best algo
- TVMRetValue best_algo;
-
- tvm::contrib::FindAlgo(format, dims, groups, padding.data(),
strides.data(),
- dilation.data(), input_dims.data(),
kernel_dims.data(),
- output_dims.data(), conv_dtype, conv_dtype,
false, &best_algo);
-
- this->algo = best_algo.operator int();
}
}
}
@@ -126,27 +71,10 @@ class cuDNNJSONRuntime : public JSONRuntimeBase {
const char* type_key() const override { return "cudnn_json"; } // May be
overridden
void Run() override {
- auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) {
- const DLTensor* bias = nullptr;
- if (has_bias) {
- bias = GetInput(node, 2);
+ for (const auto& f : op_execs_) {  // run the per-node executors built by Init(), in node order
+ if (f != nullptr) {  // only "kernel" nodes were given an executor in Init()
+ f();
}
- return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias);
- };
-
- auto [a_ptr, b_ptr, bias_ptr] = get_inputs(kernel_node, has_bias);
- uint32_t output_eid = EntryID(outputs_[0]);
- auto out_ptr = data_entry_[output_eid];
-
- if (this->has_bias) {
- tvm::contrib::ConvolutionBiasActivationForward(
- this->mode, this->format, this->algo, this->dims, this->groups,
this->act, this->coef,
- this->padding.data(), this->strides.data(), this->dilation.data(),
a_ptr, b_ptr, out_ptr,
- bias_ptr, this->conv_dtype);
- } else {
- tvm::contrib::ConvolutionForward(
- this->mode, this->format, this->algo, this->dims, this->groups,
this->padding.data(),
- this->strides.data(), this->dilation.data(), a_ptr, b_ptr, out_ptr,
this->conv_dtype);
}
}
@@ -157,27 +85,150 @@ class cuDNNJSONRuntime : public JSONRuntimeBase {
ICHECK(eid < data_entry_.size());
return data_entry_[eid];
}
- /*conv op name*/
- std::string op_name;
- /*conv mode: CUDNN_CROSS_CORRELATION by default*/
- int mode = CUDNN_CROSS_CORRELATION;
- /*algo: by default we select the implicit gemm algo, will be tuned in the
initial pass.*/
- int algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
- /*if has bias*/
- bool has_bias = false;
- /*args for function call*/
- int act = CUDNN_ACTIVATION_IDENTITY;
- double coef = 1.0;
- int format = CUDNN_TENSOR_NHWC;
- int dims = 2;
- int groups = 1;
- std::vector<int> padding;
- std::vector<int> strides;
- std::vector<int> dilation;
- std::string conv_dtype;
- cudaStream_t stream;
- cudnnHandle_t handle;
- tvm::runtime::json::JSONGraphNode kernel_node;
+
+  /*! \brief True when `attr_name` occurs as a substring of `op_name`. */
+  bool attr_in_name(const std::string& op_name, const std::string& attr_name) {
+    const size_t pos = op_name.find(attr_name);
+    return pos != std::string::npos;
+  }
+
+  /*! \brief Parse the string-list attribute `attrStr` of `node` into ints via std::stoi. */
+  std::vector<int> vstr2vint(const JSONGraphNode& node, const std::string& attrStr) {
+    const auto string_vec = node.GetAttr<std::vector<std::string>>(attrStr);
+    std::vector<int> int_vec;
+    int_vec.reserve(string_vec.size());
+    for (const std::string& str : string_vec) {
+      int_vec.push_back(std::stoi(str));
+    }
+    return int_vec;
+  }
+
+  // Build the executor for a cuDNN conv2d composite, optionally with fused bias and
+  // activation. The algorithm search (FindAlgo) runs once here; the returned closure only
+  // binds the current CUDA stream and launches the convolution.
+  std::function<void()> GetConv2DExec(const JSONGraphNode& node) {
+    auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal();
+    auto op_name = node.GetOpName();
+
+    // NOTE(review): data/weight shapes are read from nodes_[0] and nodes_[1], i.e. this
+    // assumes those two graph nodes are the convolution's inputs.
+    std::vector<int> input_dims, kernel_dims, output_dims;
+    auto input_shapes = nodes_[0].GetOpShape()[0];
+    auto kernel_shapes = nodes_[1].GetOpShape()[0];
+    auto output_shapes = node.GetOpShape()[0];
+    for (const auto& _i : input_shapes) {
+      input_dims.emplace_back(static_cast<int>(_i));
+    }
+    for (const auto& _i : kernel_shapes) {
+      kernel_dims.emplace_back(static_cast<int>(_i));
+    }
+    for (const auto& _i : output_shapes) {
+      output_dims.emplace_back(static_cast<int>(_i));
+    }
+    bool has_bias = attr_in_name(op_name, "bias");
+    int groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
+    std::vector<int> padding = vstr2vint(node, "padding");
+    std::vector<int> strides = vstr2vint(node, "strides");
+    std::vector<int> dilation = vstr2vint(node, "dilation");
+    auto conv_dtype = node.GetAttr<std::vector<std::string>>("out_dtype")[0];
+    std::string layout = node.GetAttr<std::vector<std::string>>("out_layout")[0];
+    int dims = layout.size() - 2;  // spatial rank: layout length minus its N and C axes
+
+    int format = CUDNN_TENSOR_NHWC;
+    if (layout == "NCHW") {
+      format = CUDNN_TENSOR_NCHW;
+    } else if (layout == "NHWC") {
+      format = CUDNN_TENSOR_NHWC;
+    } else {
+      LOG(FATAL) << "Unsupported layout: " << layout;
+    }
+
+    // Test the most specific names first: "relu6" and "leaky_relu" both contain "relu",
+    // so checking "relu" first would make the other two branches unreachable.
+    int act = CUDNN_ACTIVATION_IDENTITY;
+    double coef = 1.0;
+    if (attr_in_name(op_name, "relu6")) {
+      act = CUDNN_ACTIVATION_CLIPPED_RELU;
+      coef = 6.0;
+    } else if (attr_in_name(op_name, "leaky_relu")) {
+      act = CUDNN_ACTIVATION_RELU;
+      coef = 0.1;
+    } else if (attr_in_name(op_name, "relu")) {
+      act = CUDNN_ACTIVATION_RELU;
+    }
+
+    /*conv mode: CUDNN_CROSS_CORRELATION by default*/
+    int mode = CUDNN_CROSS_CORRELATION;
+
+    // find best algo
+    TVMRetValue best_algo;
+    tvm::contrib::FindAlgo(format, dims, groups, padding.data(), strides.data(), dilation.data(),
+                           input_dims.data(), kernel_dims.data(), output_dims.data(), conv_dtype,
+                           conv_dtype, false, &best_algo);
+    int algo = best_algo.operator int();
+
+    std::function<void()> op_exec = [=]() {
+      auto stream = static_cast<cudaStream_t>(GetCUDAStream());
+      CUDNN_CALL(cudnnSetStream(entry_ptr->handle, stream));
+
+      auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) {
+        const DLTensor* bias = nullptr;
+        if (has_bias) {
+          bias = GetInput(node, 2);
+        }
+        return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias);
+      };
+
+      auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, has_bias);
+      uint32_t output_eid = EntryID(outputs_[0]);
+      auto out_ptr = data_entry_[output_eid];
+      if (has_bias) {
+        tvm::contrib::ConvolutionBiasActivationForward(
+            mode, format, algo, dims, groups, act, coef, padding.data(), strides.data(),
+            dilation.data(), a_ptr, b_ptr, out_ptr, bias_ptr, conv_dtype);
+      } else {
+        tvm::contrib::ConvolutionForward(mode, format, algo, dims, groups, padding.data(),
+                                         strides.data(), dilation.data(), a_ptr, b_ptr, out_ptr,
+                                         conv_dtype);
+      }
+    };
+    return op_exec;
+  }
+
+  // Build the executor for a cuDNN-frontend attention composite. Input 0 is the packed QKV
+  // tensor and input 1 is the workspace buffer; the closure runs a pre-built SDPA graph.
+  std::function<void()> GetAttentionExec(const JSONGraphNode& node) {
+#ifdef TVM_USE_CUDNN_FRONTEND
+    auto dtype = node.GetOpDataType()[0];
+    int num_heads = vstr2vint(node, "num_heads")[0];
+    int num_kv_heads = vstr2vint(node, "num_kv_heads")[0];
+    int head_size = vstr2vint(node, "head_size")[0];
+    int head_size_v = vstr2vint(node, "head_size_v")[0];
+    std::string layout = node.GetAttr<std::vector<std::string>>("layout")[0];
+    // nodes_ is indexed by node id. EntryID() returns a data_entry_ index, which only
+    // equals the node id while every preceding node has exactly one output.
+    const auto& input_qkv_node = nodes_[node.GetInputs()[0].id_];
+    auto qkv_shapes = input_qkv_node.GetOpShape()[0];
+
+    // Initialized so that no path can read an indeterminate value.
+    int64_t batch = 0;
+    int64_t seq_len = 0;
+    if (layout == "BS3NH") {
+      ICHECK_EQ(qkv_shapes.size(), 3);
+      batch = qkv_shapes[0];
+      seq_len = qkv_shapes[1];
+    } else if (layout == "SBN3H") {
+      ICHECK_EQ(qkv_shapes.size(), 4);
+      batch = qkv_shapes[1];
+      seq_len = qkv_shapes[0];
+    } else {
+      LOG(FATAL) << "Unsupported layout: " << layout;
+    }
+    // Default scale is 1/sqrt(head_size); a non-empty "scale" attribute overrides it.
+    double scale = 1.0 / std::sqrt(static_cast<double>(head_size));
+    std::string scale_attr = node.GetAttr<std::vector<std::string>>("scale")[0];
+    if (!scale_attr.empty()) {
+      scale = std::stod(scale_attr);
+    }
+
+    auto runner = tvm::contrib::CuDNNSDPARunner::Create();
+    runner->Init(batch, seq_len, num_heads, num_kv_heads, head_size, head_size_v, scale, dtype,
+                 layout);
+    return [=]() {
+      auto qkv = GetInput(node, 0);
+      auto workspace = const_cast<DLTensor*>(GetInput(node, 1));
+      auto out = const_cast<DLTensor*>(data_entry_[EntryID(outputs_[0])]);
+      runner->Run(qkv, workspace, out);
+    };
+#else
+    LOG(FATAL) << "Please build with CUDNN frontend to use attention op";
+    return nullptr;  // not reached after LOG(FATAL); keeps every path returning a value
+#endif
+  }
+
+  /*! \brief Per-node executors built by Init(); empty for non-kernel nodes. */
+  std::vector<std::function<void()>> op_execs_;
};
runtime::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json,
diff --git a/tests/python/relax/test_codegen_cudnn.py
b/tests/python/relax/test_codegen_cudnn.py
index 0f911905f8..59f49bfde8 100644
--- a/tests/python/relax/test_codegen_cudnn.py
+++ b/tests/python/relax/test_codegen_cudnn.py
@@ -22,7 +22,8 @@ import tvm.testing
import tvm.topi.testing
from tvm import relax
from tvm.relax.backend.contrib.cudnn import partition_for_cudnn
-from tvm.relax.testing import get_relax_matmul_module
+from tvm.relax.testing import get_relax_matmul_module,
get_relax_stacked_attention_module
+from tvm.contrib.pickle_memoize import memoize
from tvm.script import relax as R
from tvm.script.ir_builder import IRBuilder
@@ -99,7 +100,7 @@ def get_relax_conv2d_module(
def get_result_with_relax_cudnn_offload(mod, np_inputs, cuda_graph=False):
mod = partition_for_cudnn(mod)
mod = relax.transform.RunCodegen()(mod)
- return build_and_run(mod, np_inputs, "cuda", cuda_graph)
+ return build_and_run(mod, np_inputs, "cuda", cuda_graph=cuda_graph)  # by keyword: build_and_run's 4th positional parameter is `legalize`
def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
@@ -244,5 +245,65 @@ def test_conv2d_nchw_oihw_offload(data_shape,
weight_shape, dtype, with_bias, ac
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+@memoize("topi.tests.test_codegen_cudnn.test_stacked_attention_offload")
+def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, qk_scale,
dtype, layout):
+ if layout == "BS3NH":
+ qkv = np.random.randn(b, s, n * h * 2 + n * h_v).astype(dtype)
+ split_qkv = np.split(qkv, [n * h, n * h * 2], axis=2)
+ q = split_qkv[0].reshape(b, s, n, h)
+ k = split_qkv[1].reshape(b, s, n, h)
+ v = split_qkv[2].reshape(b, s, n, h_v)
+ layout = "BSNH"
+ elif layout == "SBN3H":
+ qkv = np.random.randn(s, b, n, h * 2 + h_v).astype(dtype)
+ q, k, v = np.split(qkv, [h, h * 2], axis=3)
+ layout = "SBNH"
+ else:
+ raise ValueError("Unsupported layout: {}".format(layout))
+ if not bias_shape == "none":
+ bias = np.random.randn(*bias_shape).astype(dtype)
+ score = score + bias # b, n, s, s
+ else:
+ bias = None
+ ref = tvm.topi.testing.attention_python(q, k, v, bias, qk_scale, "none",
None, layout)
+ return qkv, bias, ref
+
+
[email protected](
+ params=[
+ # B, S, N, H, bias_shape scale, single_shape, layout
+ (4, 8, 32, (64, 32), "none", 1.0, False, "BS3NH"),
+ (4, 8, 32, (64, 64), "none", "none", True, "BS3NH"),
+ (4, 8, 32, (64, 32), "none", 1.0, False, "SBN3H"),
+ (4, 8, 32, (64, 64), "none", "none", True, "SBN3H"),
+ ]
+)
+def stacked_attention_size(request):
+ return request.param
+
+
[email protected](reason="require cudnn frontend")
+def test_stacked_attention_split_offload(stacked_attention_size):
+ b, s, n, (h, h_v), bias_shape, scale, single_shape, layout =
stacked_attention_size
+ qkv, bias, ref = get_numpy_stacked_attention_ref(
+ b, s, n, h, h_v, bias_shape, scale, "float16", layout
+ )
+ if scale == "none":
+ mod = get_relax_stacked_attention_module(
+ qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape,
layout=layout
+ )
+ scale = 1.0 / np.sqrt(h)
+ else:
+ mod = get_relax_stacked_attention_module(
+ qkv, b, s, n, h, h_v, "split", bias, scale,
single_shape=single_shape, layout=layout
+ )
+
+ if bias is None:
+ out = get_result_with_relax_cudnn_offload(mod, [qkv])
+ else:
+ out = get_result_with_relax_cudnn_offload(mod, [qkv, bias])
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=2e-2)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index 969651f72f..3fa3f2d914 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -24,7 +24,11 @@ from tvm import relax
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
from tvm.contrib.pickle_memoize import memoize
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
-from tvm.relax.testing import get_relax_matmul_module
+from tvm.relax.testing import (
+ get_relax_matmul_module,
+ get_relax_attention_module,
+ get_relax_stacked_attention_module,
+)
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
@@ -594,47 +598,6 @@ def attention_size(request):
return request.param
-def get_relax_attention_module(
- q_shape,
- k_shape,
- v_shape,
- *,
- dtype,
- bias_shape=None,
- qk_scale=None,
- causal_mask=None,
- window_size=None,
-):
- from tvm.script.ir_builder import IRBuilder
- from tvm.script.ir_builder import relax as relax_builder
- from tvm.script.ir_builder import tir as T
-
- if qk_scale is not None:
- qk_scale = T.FloatImm("float32", qk_scale)
-
- if window_size is not None:
- window_size = T.IntImm("int32", window_size)
-
- with IRBuilder() as builder:
- with relax_builder.function():
- R.func_name("main")
- q = R.arg("q", R.Tensor(q_shape, dtype))
- k = R.arg("k", R.Tensor(k_shape, dtype))
- v = R.arg("v", R.Tensor(v_shape, dtype))
- bias = None
- if bias_shape is not None and bias_shape != "none":
- bias = R.arg("bias", R.Tensor(bias_shape, dtype))
-
- with R.dataflow() as frame:
- result = R.emit(R.nn.attention(q, k, v, bias, qk_scale,
causal_mask, window_size))
- R.output(result)
-
- R.func_ret_value(frame.output_vars[0])
-
- func = builder.get()
- return tvm.IRModule({"main": func})
-
-
def get_numpy_attention_ref(
b,
s,
@@ -649,59 +612,20 @@ def get_numpy_attention_ref(
window_size=None,
num_kv_head=None,
):
- if num_kv_head is None:
- num_kv_head = n
-
+ num_kv_head = num_kv_head or n
q = np.random.randn(b, s, n, h).astype(dtype)
- k_orig = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype)
- v_orig = np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype)
-
- if num_kv_head is None:
- k = k_orig
- v = v_orig
- else:
- factor = n // num_kv_head
- k = np.repeat(k_orig, factor, axis=2)
- v = np.repeat(v_orig, factor, axis=2)
-
- qt = q.transpose(0, 2, 1, 3) # b, n, s, h
- kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
- if not qk_scale == "none":
- score = qt @ kt * qk_scale # b, n, s, s_kv
- else:
- score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
- if not bias_shape == "none":
- bias = np.random.randn(*bias_shape).astype(dtype)
- score = score + bias # b, n, s, s_kv
- else:
+ k = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype)
+ v = np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype)
+ if bias_shape == "none":
bias = None
- if causal == "none":
- attn = tvm.topi.testing.softmax_python(score, -1)
else:
- if causal == "TopLeft":
- offset = 0
- elif causal == "BottomRight":
- offset = abs(s - s_kv)
- else:
- raise NotImplementedError()
- score_masked = np.tril(score, k=offset)
-
- if window_size:
- score_masked = np.triu(score_masked, -window_size + 1)
-
- score_masked_exp = np.tril(
- np.exp(score_masked - np.max(score_masked, axis=-1,
keepdims=True)), k=offset
- )
-
- if window_size:
- score_masked_exp = np.triu(score_masked_exp, -window_size + 1)
+ bias = np.random.randn(*bias_shape).astype(dtype)
- score_masked_sum = np.sum(score_masked_exp, axis=-1, keepdims=True)
- attn = np.divide(score_masked_exp, score_masked_sum)
+ ref = tvm.topi.testing.attention_python(
+ q, k, v, bias, qk_scale, causal=causal, window_size=window_size,
layout="BSNH"
+ )
- vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
- ref = attn @ vt # b, n, s, h_v
- return q, k_orig, v_orig, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
+ return q, k, v, bias, ref
def test_attention_offload(attention_size, attention_dtype):
@@ -844,69 +768,14 @@ def get_numpy_stacked_attention_ref(b, s, n, h, h_v,
bias_shape, qk_scale, dtype
q = np.reshape(split_qkv[0], (b, s, n, h))
k = np.reshape(split_qkv[1], (b, s, n, h))
v = np.reshape(split_qkv[2], (b, s, n, h_v))
- qt = q.transpose(0, 2, 1, 3) # b, n, s, h
- kt = k.transpose(0, 2, 3, 1) # b, n, h, s
- if not qk_scale == "none":
- score = qt @ kt * qk_scale # b, n, s, s
- else:
- score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s
if not bias_shape == "none":
bias = np.random.randn(*bias_shape).astype(dtype)
- score = score + bias # b, n, s, s
else:
bias = None
- attn = tvm.topi.testing.softmax_python(score, -1)
- vt = v.transpose(0, 2, 1, 3) # b, n, s, h_v
- ref = attn @ vt # b, n, s, h_v
- return qkv, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
-
-
-def get_relax_stacked_attention_module(
- qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None, single_shape=False
-):
- dtype = str(qkv.dtype)
-
- from tvm.script.ir_builder import IRBuilder
- from tvm.script.ir_builder import relax as relax_builder
- from tvm.script.ir_builder import tir as T
-
- if qk_scale is not None:
- qk_scale = T.FloatImm("float32", qk_scale)
-
- if single_shape:
- qk_shape = R.shape([b, s, n, h])
- v_shape = qk_shape
- else:
- qk_shape = [b, s, n, h]
- v_shape = [b, s, n, h_v]
-
- with IRBuilder() as builder:
- with relax_builder.function():
- R.func_name("main")
- qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype))
- if bias is not None:
- bias = R.arg("bias", R.Tensor(bias.shape, dtype))
- with R.dataflow() as frame:
- if op == "split":
- qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
- q = R.reshape(qkv_tuple[0], qk_shape)
- k = R.reshape(qkv_tuple[1], qk_shape)
- v = R.reshape(qkv_tuple[2], v_shape)
- elif op == "strided_slice":
- q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h],
[1]), qk_shape)
- k = R.reshape(R.strided_slice(qkv, [2], [n * h], [n * h *
2], [1]), qk_shape)
- v = R.reshape(
- R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n
* h_v], [1]), v_shape
- )
- else:
- raise NotImplementedError()
- result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
- R.output(result)
-
- R.func_ret_value(frame.output_vars[0])
-
- func = builder.get()
- return tvm.IRModule({"main": func})
+ ref = tvm.topi.testing.attention_python(
+ q, k, v, bias, qk_scale, causal="none", window_size=None, layout="BSNH"
+ )
+ return qkv, bias, ref
@pytest.fixture(
@@ -926,11 +795,30 @@ def
test_stacked_attention_split_offload(stacked_attention_size):
qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v,
bias_shape, scale, "float16")
if scale == "none":
mod = get_relax_stacked_attention_module(
- qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape
+ qkv,
+ b,
+ s,
+ n,
+ h,
+ h_v,
+ "split",
+ bias,
+ single_shape=single_shape,
+ layout="BS3NH",
)
else:
mod = get_relax_stacked_attention_module(
- qkv, b, s, n, h, h_v, "split", bias, scale,
single_shape=single_shape
+ qkv,
+ b,
+ s,
+ n,
+ h,
+ h_v,
+ "split",
+ bias,
+ scale,
+ single_shape=single_shape,
+ layout="BS3NH",
)
if bias is None:
@@ -945,11 +833,30 @@ def
test_stacked_attention_strided_slice_offload(stacked_attention_size):
qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v,
bias_shape, scale, "float32")
if scale == "none":
mod = get_relax_stacked_attention_module(
- qkv, b, s, n, h, h_v, "strided_slice", bias,
single_shape=single_shape
+ qkv,
+ b,
+ s,
+ n,
+ h,
+ h_v,
+ "strided_slice",
+ bias,
+ single_shape=single_shape,
+ layout="BS3NH",
)
else:
mod = get_relax_stacked_attention_module(
- qkv, b, s, n, h, h_v, "strided_slice", bias, scale,
single_shape=single_shape
+ qkv,
+ b,
+ s,
+ n,
+ h,
+ h_v,
+ "strided_slice",
+ bias,
+ scale,
+ single_shape=single_shape,
+ layout="BS3NH",
)
if bias is None:
out = get_result_with_relax_cutlass_offload(mod, qkv,
num_final_bindings=2)
diff --git a/tests/python/relax/test_transform_allocate_workspace.py
b/tests/python/relax/test_transform_allocate_workspace.py
index 1198642d3f..248d195d65 100644
--- a/tests/python/relax/test_transform_allocate_workspace.py
+++ b/tests/python/relax/test_transform_allocate_workspace.py
@@ -95,7 +95,6 @@ class Expected:
R.func_attr(
{
"Codegen": "cutlass",
- "WorkspaceSize": 65536,
"global_symbol": "fused_relax_nn_attention_cutlass1",
}
)
@@ -107,7 +106,7 @@ class Expected:
v_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
workspace_1: R.Tensor((65536,), dtype="uint8"),
) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
- R.func_attr({"Composite": "cutlass.attention", "Primitive": 1,
"WorkspaceSize": 65536})
+ R.func_attr({"Composite": "cutlass.attention", "Primitive": 1})
with R.dataflow():
gv_2: R.Tensor((32, 8, 16, 8), dtype="float16") =
R.nn.attention(
q_1, k_1, v_1, scale=None
diff --git a/tests/python/relax/test_transform_merge_composite_functions.py
b/tests/python/relax/test_transform_merge_composite_functions.py
index 6a36314a74..cff832a21f 100644
--- a/tests/python/relax/test_transform_merge_composite_functions.py
+++ b/tests/python/relax/test_transform_merge_composite_functions.py
@@ -1053,7 +1053,6 @@ def test_reshape():
@R.function
def fused_relax_reshape_relax_matmul_tensorrt(
inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"),
- param_0: R.Shape([1, 784]),
lv1: R.Tensor((784, 512), dtype="float32"),
) -> R.Tensor((1, 512), dtype="float32"):
R.func_attr({"Codegen": "tensorrt"})
@@ -1069,7 +1068,7 @@ def test_reshape():
R.output(gv)
return gv
- lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, param_0)
+ lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0,
R.shape([1, 784]))
@R.function
def lv1_1_1(
@@ -1100,7 +1099,7 @@ def test_reshape():
)
gv: R.Tensor(
(1, 512), dtype="float32"
- ) = cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0,
R.shape([1, 784]), lv1)
+ ) = cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0, lv1)
R.output(gv)
return gv