masahi commented on code in PR #15837:
URL: https://github.com/apache/tvm/pull/15837#discussion_r1344611811


##########
tests/python/relax/test_codegen_cutlass.py:
##########
@@ -1992,5 +1992,107 @@ def main(
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_batched_var_len_attention():
+    if not tvm.get_global_func("tvm.contrib.thrust.sum_scan", True):
+        return
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
+            keys: R.Tensor(("num_tokens", 4096), dtype="float16"),
+            values: R.Tensor(("num_tokens", 4096), dtype="float16"),
+            seq_lens: R.Tensor(("num_seq",), dtype="int32"),
+        ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
+            cls = Module
+            num_tokens = T.int64()
+            num_seq = T.int64()
+
+            with R.dataflow():
+                # TODO(masahi): Workaround for the broken Relax cumsum op on 
GPU.
+                # https://github.com/apache/tvm/issues/15851
+                cumsum = R.call_dps_packed(
+                    "tvm.contrib.thrust.sum_scan", seq_lens, 
out_sinfo=seq_lens.struct_info
+                )
+                max_seqlen_q = R.max(seq_lens)
+                seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
+                q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
+                k = R.reshape(keys, R.shape([1, num_tokens, 128, 32]))
+                v = R.reshape(values, R.shape([1, num_tokens, 128, 32]))
+                attn_out = R.nn.attention_var_len(
+                    q,
+                    k,
+                    v,
+                    seqstart_q,
+                    max_seqlen_q,
+                    causal_mask="BottomRight",
+                )
+                out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
+                R.output(out)
+            return out
+
+    seq_lens = [5, 3, 8]
+    num_head = 128
+    head_size = 32
+    hidden_size = num_head * head_size
+
+    batched_queries = []
+    batched_keys = []
+    batched_values = []
+    batched_refs = []
+
+    for s in seq_lens:
+        q, k, v, _, ref = get_numpy_attention_ref(
+            1, s, s, num_head, head_size, head_size, "none", "none", 
"BottomRight", "float16"
+        )
+        batched_queries.append(np.reshape(q, [-1, hidden_size]))
+        batched_keys.append(np.reshape(k, [-1, hidden_size]))
+        batched_values.append(np.reshape(v, [-1, hidden_size]))
+        batched_refs.append(np.reshape(ref, [-1, hidden_size]))
+
+    batched_queries = np.vstack(batched_queries)
+    batched_keys = np.vstack(batched_keys)
+    batched_values = np.vstack(batched_values)
+    ref = np.vstack(batched_refs)
+
+    mod = partition_for_cutlass(Module)
+    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80}})
+    mod = codegen_pass(mod)
+
+    with tvm.target.Target("cuda"):
+        mod = relax.transform.LegalizeOps()(mod)
+        mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+
+    out = build_and_run(
+        mod,
+        [
+            batched_queries,
+            batched_keys,
+            batched_values,
+            np.array(seq_lens, dtype="int32"),
+        ],
+        "cuda",
+    )
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+    ############# xformer reference for verification #############
+
+    # attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)

Review Comment:
   Yes, it depends on torch and xformers.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to