This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7dd76cca03 [Unity] Cutlass attention with dynamic sequence length
(#15028)
7dd76cca03 is described below
commit 7dd76cca032909dec6306b3339ed83b884fbd2ea
Author: Lite Ye <[email protected]>
AuthorDate: Tue Jun 6 14:48:53 2023 -0400
[Unity] Cutlass attention with dynamic sequence length (#15028)
* Dynamic attention
* Fix lint
---
python/tvm/contrib/cutlass/attention_operation.py | 17 +++++-
python/tvm/contrib/cutlass/gen_tensor_op.py | 37 ++++++------
python/tvm/relax/backend/contrib/cutlass.py | 4 ++
tests/python/relax/test_codegen_cutlass.py | 74 +++++++++++++++++------
4 files changed, 97 insertions(+), 35 deletions(-)
diff --git a/python/tvm/contrib/cutlass/attention_operation.py
b/python/tvm/contrib/cutlass/attention_operation.py
index 8a96e70fe4..55c1ccd616 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -96,9 +96,20 @@ def instantiate_attention_template(attrs):
typename Attention::Params p;
p.logsumexp_ptr = nullptr;
p.output_ptr = reinterpret_cast<T *>(out0->data);
+
p.output_accum_ptr = nullptr;
+ uint64_t accumulator_buf_size = ${output_size} *
sizeof(Attention::output_accum_t);
+ bool accumulator_buf_allocated = false;
if (Attention::kNeedsOutputAccumulatorBuffer) {
- p.output_accum_ptr = static_cast<float*>(${workspace}->data);
+ if (accumulator_buf_size <= ${workspace}->shape[0]) {
+ p.output_accum_ptr = static_cast<float*>(${workspace}->data);
+ } else {
+ accumulator_buf_allocated = true;
+ CHECK(cudaMalloc(
+ &p.output_accum_ptr,
+ accumulator_buf_size
+ ) == cudaSuccess);
+ }
}
p.num_heads = ${num_heads}; // N
@@ -129,6 +140,10 @@ def instantiate_attention_template(attrs):
CHECK(Attention::check_supported(p));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
+
+ if (accumulator_buf_allocated) {
+ cudaFree(p.output_accum_ptr);
+ }
"""
template = substitute_template(
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index be3bb289cf..48e285998a 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -679,10 +679,27 @@ def instantiate_template(func_name, annotations,
func_args):
elif "attention" in func_name:
headers.append("kernel_forward.h")
data_type = dtype_map[annotations["arg0_dtype"]]
+
+ attrs["qkv_layout"] = annotations["qkv_layout"]
+ if attrs["qkv_layout"] == "default":
+ attrs["query"] = func_args[0]
+ attrs["key"] = func_args[1]
+ attrs["value"] = func_args[2]
+ attrs["num_queries"] = s = get_dim(annotations["num_queries"],
func_args[0], 1)
+ attrs["num_keys"] = get_dim(annotations["num_keys"], func_args[1],
1)
+ if len(func_args) > 4: # +1 for workspace, the last arg
+ attrs["bias"] = func_args[3]
+ elif attrs["qkv_layout"] == "qkv_stacked":
+ attrs["qkv"] = func_args[0]
+ attrs["num_queries"] = s = annotations["num_queries"]
+ attrs["num_keys"] = annotations["num_keys"]
+ if len(func_args) > 5: # +1 for workspace, the last arg
+ attrs["bias"] = func_args[4]
+ else:
+ raise NotImplementedError()
+
attrs["data_type"] = DataTypeTag[data_type]
attrs["num_batches"] = b = annotations["num_batches"]
- attrs["num_queries"] = s = annotations["num_queries"]
- attrs["num_keys"] = annotations["num_keys"]
attrs["num_heads"] = n = annotations["num_heads"]
attrs["head_dim"] = h = annotations["head_dim"]
attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
@@ -701,7 +718,7 @@ def instantiate_template(func_name, annotations, func_args):
attrs["kQueriesPerBlock"] = 64
attrs["kKeysPerBlock"] = 64
attrs["kSingleValueIteration"] = True
- attrs["output_size"] = b * s * n * h_v
+ attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"  # C++ expression text: s may come from get_dim (presumably a runtime dim)
attrs["scale"] = (
float(1 / math.sqrt(h.value)) if annotations["scale"] is None else
annotations["scale"]
)
@@ -712,24 +729,10 @@ def instantiate_template(func_name, annotations,
func_args):
), "Cutlass may generate nan occasionally when scale == 0.0"
attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
attrs["kSupportsDropout"] = False
- attrs["qkv_layout"] = annotations["qkv_layout"]
for arg in func_args:
if "workspace" in arg:
attrs["workspace"] = arg
-
- if attrs["qkv_layout"] == "default":
- attrs["query"] = func_args[0]
- attrs["key"] = func_args[1]
- attrs["value"] = func_args[2]
- if len(func_args) > 4: # +1 for workspace, the last arg
- attrs["bias"] = func_args[3]
- elif attrs["qkv_layout"] == "qkv_stacked":
- attrs["qkv"] = func_args[0]
- if len(func_args) > 5: # +1 for workspace, the last arg
- attrs["bias"] = func_args[4]
- else:
- raise NotImplementedError()
if "bias" in attrs:
attrs["kSupportsBias"] = True
if len(annotations["bias_shape"]) == 4:
diff --git a/python/tvm/relax/backend/contrib/cutlass.py
b/python/tvm/relax/backend/contrib/cutlass.py
index dffd7c401c..bdd230f189 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -412,6 +412,10 @@ class WorkspaceAnnotator(PyExprMutator):
out_size_1d = _shape_1d(f.ret_struct_info.shape)
# This needs to be in sync with the actual value that the kernel
expects.
workspace_size_bytes = out_size_1d * {"float16": 2, "float32":
4}[out_dtype]
+ if not isinstance(workspace_size_bytes, (int,
tvm.tir.expr.IntImm)):
+ # Temporary workaround for dynamic shape workload. Will be
removed when
+ # workspace for dynamic shape workload is implemented.
+ workspace_size_bytes = 8
return f.with_attr("WorkspaceSize", workspace_size_bytes)
return f
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index 8a1675ad35..7bbdd630a5 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -25,9 +25,9 @@ from tvm.contrib.cutlass.build import
is_shape_valid_for_cutlass_matmul
from tvm.contrib.pickle_memoize import memoize
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from tvm.relax.testing import get_relax_matmul_module
-from tvm.script import tir as T
from tvm.script import ir as I
from tvm.script import relax as R
+from tvm.script import tir as T
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder
@@ -169,9 +169,16 @@ def get_relax_conv2d_module(
return tvm.IRModule({"main": func})
-def _to_concrete_shape(symbolic_shape, var_table):
+def _to_concrete_shape(symbolic_shape, var_table=None):
+ if var_table is None:
+ var_table = {}
+
result = []
for dim in symbolic_shape:
+ if isinstance(dim, tuple):
+ result.append(_to_concrete_shape(dim, var_table))
+ continue
+
if not isinstance(dim, tvm.tir.expr.Var):
result.append(dim)
continue
@@ -543,6 +550,7 @@ def attention_dtype(request):
@pytest.fixture(
params=[
# B, S, N, H
+ (32, (_vars["a"], 8), 16, (8, 8)),  # symbolic query length via _vars["a"] (presumably a tir.Var)
(32, (8, 8), 16, (8, 8)),
(4, (16, 8), 32, (8, 8)), # s != s_kv
(4, (16, 8), 32, (8, 16)), # h != h_v
@@ -554,9 +562,9 @@ def attention_size(request):
return request.param
-def get_relax_attention_module(q, k, v, bias=None, qk_scale=None, causal=None):
- dtype = str(q.dtype)
-
+def get_relax_attention_module(
+ q_shape, k_shape, v_shape, *, dtype, bias_shape=None, qk_scale=None,
causal_mask=None
+):
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder
from tvm.script.ir_builder import tir as T
@@ -567,13 +575,15 @@ def get_relax_attention_module(q, k, v, bias=None,
qk_scale=None, causal=None):
with IRBuilder() as builder:
with relax_builder.function():
R.func_name("main")
- q = R.arg("q", R.Tensor(q.shape, dtype))
- k = R.arg("k", R.Tensor(k.shape, dtype))
- v = R.arg("v", R.Tensor(v.shape, dtype))
- if bias is not None:
- bias = R.arg("bias", R.Tensor(bias.shape, dtype))
+ q = R.arg("q", R.Tensor(q_shape, dtype))
+ k = R.arg("k", R.Tensor(k_shape, dtype))
+ v = R.arg("v", R.Tensor(v_shape, dtype))
+ bias = None
+ if bias_shape is not None and bias_shape != "none":
+ bias = R.arg("bias", R.Tensor(bias_shape, dtype))
+
with R.dataflow() as frame:
- result = R.emit(R.nn.attention(q, k, v, bias, qk_scale,
causal))
+ result = R.emit(R.nn.attention(q, k, v, bias, qk_scale,
causal_mask))
R.output(result)
R.func_ret_value(frame.output_vars[0])
@@ -620,11 +630,16 @@ def get_numpy_attention_ref(b, s, s_kv, n, h, h_v,
bias_shape, qk_scale, causal,
def test_attention_offload(attention_size, attention_dtype):
b, (s, s_kv), n, (h, h_v) = attention_size
+ concrete_s, concrete_s_kv = _to_concrete_shape((s, s_kv))
q, k, v, _, ref = get_numpy_attention_ref(
- b, s, s_kv, n, h, h_v, "none", "none", "none", attention_dtype
+ b, concrete_s, concrete_s_kv, n, h, h_v, "none", "none", "none",
attention_dtype
)
- mod = get_relax_attention_module(q, k, v)
+ q_shape = (b, s, n, h)
+ k_shape = (b, s_kv, n, h)
+ v_shape = (b, s_kv, n, h_v)
+
+ mod = get_relax_attention_module(q_shape, k_shape, v_shape,
dtype=attention_dtype)
out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=3)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -649,11 +664,19 @@ def attention_bias_size(request):
def test_attention_bias_offload(attention_bias_size):
b, (s, s_kv), n, (h, h_v), bias_shape = attention_bias_size
+ concrete_s, concrete_s_kv, concrete_bias_shape = _to_concrete_shape((s,
s_kv, bias_shape))
+
q, k, v, bias, ref = get_numpy_attention_ref(
- b, s, s_kv, n, h, h_v, bias_shape, "none", "none", "float32"
+ b, concrete_s, concrete_s_kv, n, h, h_v, concrete_bias_shape, "none",
"none", "float32"
)
- mod = get_relax_attention_module(q, k, v, bias)
+ q_shape = (b, s, n, h)
+ k_shape = (b, s_kv, n, h)
+ v_shape = (b, s_kv, n, h_v)
+
+ mod = get_relax_attention_module(
+ q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32"
+ )
out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias,
num_final_bindings=3)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -681,7 +704,13 @@ def test_attention_scale_offload(attention_scale_size,
attention_scale):
b, s, s_kv, n, h, h_v, bias_shape, attention_scale, "none", "float32"
)
- mod = get_relax_attention_module(q, k, v, bias, attention_scale)
+ q_shape = (b, s, n, h)
+ k_shape = (b, s_kv, n, h)
+ v_shape = (b, s_kv, n, h_v)
+
+ mod = get_relax_attention_module(
+ q_shape, k_shape, v_shape, dtype="float32", bias_shape=bias_shape,
qk_scale=attention_scale
+ )
if bias is None:
out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=3)
else:
@@ -712,7 +741,18 @@ def test_attention_causal_offload(attention_causal_size,
attention_causal):
b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float32"
)
- mod = get_relax_attention_module(q, k, v, bias, None, attention_causal)
+ q_shape = (b, s, n, h)
+ k_shape = (b, s_kv, n, h)
+ v_shape = (b, s_kv, n, h_v)
+
+ mod = get_relax_attention_module(
+ q_shape,
+ k_shape,
+ v_shape,
+ dtype="float32",
+ bias_shape=bias_shape,
+ causal_mask=attention_causal,
+ )
if bias is None:
out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=3)
else: