This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7dd76cca03 [Unity] Cutlass attention with dynamic sequence length 
(#15028)
7dd76cca03 is described below

commit 7dd76cca032909dec6306b3339ed83b884fbd2ea
Author: Lite Ye <[email protected]>
AuthorDate: Tue Jun 6 14:48:53 2023 -0400

    [Unity] Cutlass attention with dynamic sequence length (#15028)
    
    * Dynamic attention
    
    * Fix lint
---
 python/tvm/contrib/cutlass/attention_operation.py | 17 +++++-
 python/tvm/contrib/cutlass/gen_tensor_op.py       | 37 ++++++------
 python/tvm/relax/backend/contrib/cutlass.py       |  4 ++
 tests/python/relax/test_codegen_cutlass.py        | 74 +++++++++++++++++------
 4 files changed, 97 insertions(+), 35 deletions(-)

diff --git a/python/tvm/contrib/cutlass/attention_operation.py 
b/python/tvm/contrib/cutlass/attention_operation.py
index 8a96e70fe4..55c1ccd616 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -96,9 +96,20 @@ def instantiate_attention_template(attrs):
   typename Attention::Params p;
   p.logsumexp_ptr = nullptr;
   p.output_ptr = reinterpret_cast<T *>(out0->data);
+
   p.output_accum_ptr = nullptr;
+  uint64_t accumulator_buf_size = ${output_size} * 
sizeof(Attention::output_accum_t);
+  bool accumulator_buf_allocated = false;
   if (Attention::kNeedsOutputAccumulatorBuffer) {
-    p.output_accum_ptr = static_cast<float*>(${workspace}->data);
+    if (accumulator_buf_size <= ${workspace}->shape[0]) {
+        p.output_accum_ptr = static_cast<float*>(${workspace}->data);
+    } else {
+        accumulator_buf_allocated = true;  // mark ownership so the cudaFree after the launch runs; size stays intact
+        cudaMalloc(
+          &p.output_accum_ptr,
+          accumulator_buf_size
+        );
+    }
   }
 
   p.num_heads = ${num_heads}; // N
@@ -129,6 +140,10 @@ def instantiate_attention_template(attrs):
 
   CHECK(Attention::check_supported(p));
   kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
+
+  if (accumulator_buf_allocated) {
+    cudaFree(p.output_accum_ptr);
+  }
 """
 
     template = substitute_template(
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py 
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index be3bb289cf..48e285998a 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -679,10 +679,27 @@ def instantiate_template(func_name, annotations, 
func_args):
     elif "attention" in func_name:
         headers.append("kernel_forward.h")
         data_type = dtype_map[annotations["arg0_dtype"]]
+
+        attrs["qkv_layout"] = annotations["qkv_layout"]
+        if attrs["qkv_layout"] == "default":
+            attrs["query"] = func_args[0]
+            attrs["key"] = func_args[1]
+            attrs["value"] = func_args[2]
+            attrs["num_queries"] = s = get_dim(annotations["num_queries"], 
func_args[0], 1)
+            attrs["num_keys"] = get_dim(annotations["num_keys"], func_args[1], 
1)
+            if len(func_args) > 4:  # +1 for workspace, the last arg
+                attrs["bias"] = func_args[3]
+        elif attrs["qkv_layout"] == "qkv_stacked":
+            attrs["qkv"] = func_args[0]
+            attrs["num_queries"] = s = annotations["num_queries"]
+            attrs["num_keys"] = annotations["num_keys"]
+            if len(func_args) > 5:  # +1 for workspace, the last arg
+                attrs["bias"] = func_args[4]
+        else:
+            raise NotImplementedError()
+
         attrs["data_type"] = DataTypeTag[data_type]
         attrs["num_batches"] = b = annotations["num_batches"]
-        attrs["num_queries"] = s = annotations["num_queries"]
-        attrs["num_keys"] = annotations["num_keys"]
         attrs["num_heads"] = n = annotations["num_heads"]
         attrs["head_dim"] = h = annotations["head_dim"]
         attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
@@ -701,7 +718,7 @@ def instantiate_template(func_name, annotations, func_args):
             attrs["kQueriesPerBlock"] = 64
             attrs["kKeysPerBlock"] = 64
             attrs["kSingleValueIteration"] = True
-        attrs["output_size"] = b * s * n * h_v
+        attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
         attrs["scale"] = (
             float(1 / math.sqrt(h.value)) if annotations["scale"] is None else 
annotations["scale"]
         )
@@ -712,24 +729,10 @@ def instantiate_template(func_name, annotations, 
func_args):
         ), "Cutlass may generate nan occasionally when scale == 0.0"
         attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
         attrs["kSupportsDropout"] = False
-        attrs["qkv_layout"] = annotations["qkv_layout"]
 
         for arg in func_args:
             if "workspace" in arg:
                 attrs["workspace"] = arg
-
-        if attrs["qkv_layout"] == "default":
-            attrs["query"] = func_args[0]
-            attrs["key"] = func_args[1]
-            attrs["value"] = func_args[2]
-            if len(func_args) > 4:  # +1 for workspace, the last arg
-                attrs["bias"] = func_args[3]
-        elif attrs["qkv_layout"] == "qkv_stacked":
-            attrs["qkv"] = func_args[0]
-            if len(func_args) > 5:  # +1 for workspace, the last arg
-                attrs["bias"] = func_args[4]
-        else:
-            raise NotImplementedError()
         if "bias" in attrs:
             attrs["kSupportsBias"] = True
             if len(annotations["bias_shape"]) == 4:
diff --git a/python/tvm/relax/backend/contrib/cutlass.py 
b/python/tvm/relax/backend/contrib/cutlass.py
index dffd7c401c..bdd230f189 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -412,6 +412,10 @@ class WorkspaceAnnotator(PyExprMutator):
             out_size_1d = _shape_1d(f.ret_struct_info.shape)
             # This needs to be in sync with the actual value that the kernel 
expects.
             workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 
4}[out_dtype]
+            if not isinstance(workspace_size_bytes, (int, 
tvm.tir.expr.IntImm)):
+                # Temporary workaround for dynamic shape workload. Will be 
removed when
+                # workspace for dynamic shape workload is implemented.
+                workspace_size_bytes = 8
             return f.with_attr("WorkspaceSize", workspace_size_bytes)
 
         return f
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index 8a1675ad35..7bbdd630a5 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -25,9 +25,9 @@ from tvm.contrib.cutlass.build import 
is_shape_valid_for_cutlass_matmul
 from tvm.contrib.pickle_memoize import memoize
 from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
 from tvm.relax.testing import get_relax_matmul_module
-from tvm.script import tir as T
 from tvm.script import ir as I
 from tvm.script import relax as R
+from tvm.script import tir as T
 from tvm.script.ir_builder import IRBuilder
 from tvm.script.ir_builder import relax as relax_builder
 
@@ -169,9 +169,16 @@ def get_relax_conv2d_module(
     return tvm.IRModule({"main": func})
 
 
-def _to_concrete_shape(symbolic_shape, var_table):
+def _to_concrete_shape(symbolic_shape, var_table=None):
+    if var_table is None:
+        var_table = {}
+
     result = []
     for dim in symbolic_shape:
+        if isinstance(dim, tuple):
+            result.append(_to_concrete_shape(dim, var_table))
+            continue
+
         if not isinstance(dim, tvm.tir.expr.Var):
             result.append(dim)
             continue
@@ -543,6 +550,7 @@ def attention_dtype(request):
 @pytest.fixture(
     params=[
         # B, S, N, H
+        (32, (_vars["a"], 8), 16, (8, 8)),
         (32, (8, 8), 16, (8, 8)),
         (4, (16, 8), 32, (8, 8)),  # s != s_kv
         (4, (16, 8), 32, (8, 16)),  # h != h_v
@@ -554,9 +562,9 @@ def attention_size(request):
     return request.param
 
 
-def get_relax_attention_module(q, k, v, bias=None, qk_scale=None, causal=None):
-    dtype = str(q.dtype)
-
+def get_relax_attention_module(
+    q_shape, k_shape, v_shape, *, dtype, bias_shape=None, qk_scale=None, 
causal_mask=None
+):
     from tvm.script.ir_builder import IRBuilder
     from tvm.script.ir_builder import relax as relax_builder
     from tvm.script.ir_builder import tir as T
@@ -567,13 +575,15 @@ def get_relax_attention_module(q, k, v, bias=None, 
qk_scale=None, causal=None):
     with IRBuilder() as builder:
         with relax_builder.function():
             R.func_name("main")
-            q = R.arg("q", R.Tensor(q.shape, dtype))
-            k = R.arg("k", R.Tensor(k.shape, dtype))
-            v = R.arg("v", R.Tensor(v.shape, dtype))
-            if bias is not None:
-                bias = R.arg("bias", R.Tensor(bias.shape, dtype))
+            q = R.arg("q", R.Tensor(q_shape, dtype))
+            k = R.arg("k", R.Tensor(k_shape, dtype))
+            v = R.arg("v", R.Tensor(v_shape, dtype))
+            bias = None
+            if bias_shape is not None and bias_shape != "none":
+                bias = R.arg("bias", R.Tensor(bias_shape, dtype))
+
             with R.dataflow() as frame:
-                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, 
causal))
+                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, 
causal_mask))
                 R.output(result)
 
             R.func_ret_value(frame.output_vars[0])
@@ -620,11 +630,16 @@ def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, 
bias_shape, qk_scale, causal,
 
 def test_attention_offload(attention_size, attention_dtype):
     b, (s, s_kv), n, (h, h_v) = attention_size
+    concrete_s, concrete_s_kv = _to_concrete_shape((s, s_kv))
     q, k, v, _, ref = get_numpy_attention_ref(
-        b, s, s_kv, n, h, h_v, "none", "none", "none", attention_dtype
+        b, concrete_s, concrete_s_kv, n, h, h_v, "none", "none", "none", 
attention_dtype
     )
 
-    mod = get_relax_attention_module(q, k, v)
+    q_shape = (b, s, n, h)
+    k_shape = (b, s_kv, n, h)
+    v_shape = (b, s_kv, n, h_v)
+
+    mod = get_relax_attention_module(q_shape, k_shape, v_shape, 
dtype=attention_dtype)
     out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=3)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -649,11 +664,19 @@ def attention_bias_size(request):
 
 def test_attention_bias_offload(attention_bias_size):
     b, (s, s_kv), n, (h, h_v), bias_shape = attention_bias_size
+    concrete_s, concrete_s_kv, concrete_bias_shape = _to_concrete_shape((s, 
s_kv, bias_shape))
+
     q, k, v, bias, ref = get_numpy_attention_ref(
-        b, s, s_kv, n, h, h_v, bias_shape, "none", "none", "float32"
+        b, concrete_s, concrete_s_kv, n, h, h_v, concrete_bias_shape, "none", 
"none", "float32"
     )
 
-    mod = get_relax_attention_module(q, k, v, bias)
+    q_shape = (b, s, n, h)
+    k_shape = (b, s_kv, n, h)
+    v_shape = (b, s_kv, n, h_v)
+
+    mod = get_relax_attention_module(
+        q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32"
+    )
     out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, 
num_final_bindings=3)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -681,7 +704,13 @@ def test_attention_scale_offload(attention_scale_size, 
attention_scale):
         b, s, s_kv, n, h, h_v, bias_shape, attention_scale, "none", "float32"
     )
 
-    mod = get_relax_attention_module(q, k, v, bias, attention_scale)
+    q_shape = (b, s, n, h)
+    k_shape = (b, s_kv, n, h)
+    v_shape = (b, s_kv, n, h_v)
+
+    mod = get_relax_attention_module(
+        q_shape, k_shape, v_shape, dtype="float32", bias_shape=bias_shape, 
qk_scale=attention_scale
+    )
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=3)
     else:
@@ -712,7 +741,18 @@ def test_attention_causal_offload(attention_causal_size, 
attention_causal):
         b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float32"
     )
 
-    mod = get_relax_attention_module(q, k, v, bias, None, attention_causal)
+    q_shape = (b, s, n, h)
+    k_shape = (b, s_kv, n, h)
+    v_shape = (b, s_kv, n, h_v)
+
+    mod = get_relax_attention_module(
+        q_shape,
+        k_shape,
+        v_shape,
+        dtype="float32",
+        bias_shape=bias_shape,
+        causal_mask=attention_causal,
+    )
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=3)
     else:

Reply via email to