tqchen commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1120970995


##########
tests/python/relax/test_codegen_cutlass.py:
##########
@@ -438,5 +439,79 @@ def test_matmul_transposed_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-3)
 
 
+@pytest.fixture(
+    params=[
+        # B, S, N, H
+        (32, (4, 4), 16, (8, 8)),
+        (4, (8, 4), 32, (8, 8)),
+        (4, (8, 4), 32, (8, 16)),
+    ]
+)
+def attention_size(request):
+    return request.param
+
+
+@pytest.fixture
+def attention_q(attention_size, target_dtype):
+    b, (s, _), n, (h, _) = attention_size
+    return np.random.randn(b, s, n, h).astype(target_dtype)
+
+
+@pytest.fixture
+def attention_k(attention_size, target_dtype):
+    b, (_, s), n, (h, _) = attention_size
+    return np.random.randn(b, s, n, h).astype(target_dtype)
+
+
+@pytest.fixture
+def attention_v(attention_size, target_dtype):
+    b, (_, s), n, (_, h) = attention_size
+    return np.random.randn(b, s, n, h).astype(target_dtype)
+
+
+def get_relax_attention_module(q, k, v):
+    dtype = str(q.dtype)
+
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import relax as relax_builder
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            q = R.arg("q", R.Tensor(q.shape, dtype))
+            k = R.arg("k", R.Tensor(k.shape, dtype))
+            v = R.arg("v", R.Tensor(v.shape, dtype))
+
+            with R.dataflow() as frame:
+                result = R.emit(R.nn.attention(q, k, v))
+                R.output(result)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
+def get_numpy_attention_ref(q, k, v):

Review Comment:
   Use `tvm.contrib.memoize` to cache this NumPy reference computation, so we don't need to recompute it multiple times.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to