This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ec127cb0d6 [Unity][BYOC] Make CUTLASS attention rewriting aware of fp16 <-> f32 casting (#14957)
ec127cb0d6 is described below

commit ec127cb0d606036b8b9a4ad3c2ed8088b5f8728f
Author: masahi <[email protected]>
AuthorDate: Fri May 26 16:33:26 2023 +0900

    [Unity][BYOC] Make CUTLASS attention rewriting aware of fp16 <-> f32 casting (#14957)
    
    * fixed pattern rewriting logic in partition_for_cutlass
    
    * update attention rewrite pattern
    
    * add test
    
    * add missing doc for the `scale` parameter of relax.op.nn.attention
    
    * black
    
    * fix
    
    * lint
---
 python/tvm/relax/backend/contrib/cutlass.py |   9 ++-
 python/tvm/relax/backend/patterns.py        |  41 +++++++++--
 python/tvm/relax/op/nn/nn.py                |   3 +
 tests/python/relax/test_codegen_cutlass.py  | 102 ++++++++++++++++++++++++++++
 4 files changed, 146 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cutlass.py 
b/python/tvm/relax/backend/contrib/cutlass.py
index f27e0d9e50..dd2739fd98 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -360,7 +360,10 @@ def attention_rewrite_patterns():
     for qkv_layout in ["BSNH", "BSH"]:
         for out_layout in ["BSNH", "BSH"]:
             for with_bias in [True, False]:
-                patterns.append(make_attention_rewrite_pattern(qkv_layout, 
out_layout, with_bias))
+                for with_cast in [True, False]:
+                    patterns.append(
+                        make_attention_rewrite_pattern(qkv_layout, out_layout, 
with_bias, with_cast)
+                    )
     return patterns
 
 
@@ -441,7 +444,9 @@ def partition_for_cutlass(mod, annotate_codegen=True):
     for func_name, func in mod.functions.items():
         if isinstance(func, Function):
             for pattern, rewriter in _REWRITE_PATTERNS:
-                mod[func_name] = rewrite_call(pattern, rewriter, func)
+                func = rewrite_call(pattern, rewriter, func)
+        mod[func_name] = func
+
     patterns = get_patterns_with_prefix("cutlass")
     return tvm.transform.Sequential(
         [
diff --git a/python/tvm/relax/backend/patterns.py 
b/python/tvm/relax/backend/patterns.py
index 53c3879c3a..f7e8dd0406 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -263,7 +263,9 @@ def make_layer_norm_pattern():
     return is_op("relax.nn.layer_norm")(inp, gamma, beta), {}
 
 
-def make_attention_rewrite_pattern(qkv_layout: str, out_layout: str, 
with_bias: bool):
+def make_attention_rewrite_pattern(
+    qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool
+):
     """
     Create pattern for implicit fused multi head attention rewriting.
 
@@ -276,7 +278,11 @@ def make_attention_rewrite_pattern(qkv_layout: str, 
out_layout: str, with_bias:
         The layout of the output tensor, i.e. BSNH or BSH.
 
     with_bias: bool
-        Whether or not to include bias addition
+        Whether or not to include bias addition.
+
+    with_cast: bool
+        Whether or not rewriting is intended to be applied to a module after 
the FP16 conversion
+        pass.
 
     Returns
     -------
@@ -378,14 +384,31 @@ def make_attention_rewrite_pattern(qkv_layout: str, 
out_layout: str, with_bias:
     v, v_rewriter = handle_input(v_raw, qkv_layout, False)
     matmul_1 = is_op("relax.matmul")(q, k)
     scale = is_const()
-    multiply = is_op("relax.multiply")(matmul_1, scale)
+
+    if with_cast:
+        multiply = is_op("relax.multiply")(matmul_1, 
is_op("relax.astype")(scale))
+    else:
+        multiply = is_op("relax.multiply")(matmul_1, scale)
+
     if with_bias:
         bias_raw = wildcard()
         add = is_op("relax.add")(multiply, bias_raw)
-        softmax = is_op("relax.nn.softmax")(add)
+        softmax_input = add
     else:
-        softmax = is_op("relax.nn.softmax")(multiply)
-    matmul_2 = is_op("relax.matmul")(softmax, v)
+        softmax_input = multiply
+
+    if with_cast:
+        softmax_input = is_op("relax.astype")(softmax_input)
+
+    softmax = is_op("relax.nn.softmax")(softmax_input)
+
+    if with_cast:
+        softmax_output = is_op("relax.astype")(softmax)
+    else:
+        softmax_output = softmax
+
+    matmul_2 = is_op("relax.matmul")(softmax_output, v)
+
     out, out_rewriter = handle_output(matmul_2, out_layout)
 
     def rewriter(original, matchings):
@@ -394,7 +417,11 @@ def make_attention_rewrite_pattern(qkv_layout: str, 
out_layout: str, with_bias:
         value, _ = v_rewriter(matchings, matchings[v_raw])
         if query is None or key is None or value is None:
             return original
-        if matchings[softmax].attrs.axis != -1:
+        softmax_axis = matchings[softmax].attrs.axis
+        softmax_input_rank = len(matchings[softmax].struct_info.shape)
+        if softmax_axis == -1:
+            softmax_axis += softmax_input_rank
+        if softmax_axis != softmax_input_rank - 1:
             return original
         b, s, n, _ = query_shape
         _, s_kv, _, _ = key_shape
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 92bb7c2042..be2f968501 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1036,6 +1036,9 @@ def attention(
         a 4-D tensor ending with seq_len_kv, and broadcastable to
         (batch_size, num_head, seq_len, seq_len_kv).
 
+    scale: Optional[float]
+        The scale value to be applied to the attention score, by default 1 / 
sqrt(head_dim).
+
     causal_mask: Optional[str]
         The optional causal mask, i.e. 'TopLeft' and 'BottomRight'.
         For 'TopLeft', the mask matrix is as `np.tril(*, k=0)`,
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index c0a72fffb9..339b22ca9b 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -25,6 +25,8 @@ from tvm.contrib.cutlass.build import 
is_shape_valid_for_cutlass_matmul
 from tvm.contrib.pickle_memoize import memoize
 from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
 from tvm.relax.testing import get_relax_matmul_module
+from tvm.script import tir as T
+from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script.ir_builder import IRBuilder
 from tvm.script.ir_builder import relax as relax_builder
@@ -1086,5 +1088,105 @@ def test_layer_norm(data_shape, dtype, axes):
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_attention_rewrite_fp16():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            q: R.Tensor((4, 16, 32, 8), dtype="float16"),
+            k: R.Tensor((4, 8, 32, 8), dtype="float16"),
+            v: R.Tensor((4, 8, 32, 16), dtype="float16"),
+            bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
+        ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+            with R.dataflow():
+                lv = R.permute_dims(q, axes=[0, 2, 1, 3])
+                lv1 = R.reshape(lv, R.shape([128, 16, 8]))
+                lv2 = R.permute_dims(k, axes=[0, 2, 1, 3])
+                lv3 = R.reshape(lv2, R.shape([128, 8, 8]))
+                lv4 = R.permute_dims(v, axes=[0, 2, 1, 3])
+                lv5 = R.reshape(lv4, R.shape([128, 8, 16]))
+                lv6 = R.permute_dims(lv3, axes=[0, 2, 1])
+                lv7 = R.matmul(lv1, lv6, out_dtype="float16")
+                lv3_1 = R.astype(R.const(0.5, "float32"), dtype="float16")
+                lv8 = R.multiply(lv7, lv3_1)
+                lv9 = R.reshape(bias, R.shape([128, 16, 8]))
+                lv10 = R.add(lv8, lv9)
+                lv10_fp16 = R.astype(lv10, dtype="float16")
+                lv11 = R.nn.softmax(lv10_fp16, axis=2)
+                lv5_1 = R.astype(lv11, dtype="float16")
+                lv12 = R.matmul(lv5_1, lv5, out_dtype="float16")
+                lv13 = R.reshape(lv12, R.shape([4, 32, 16, 16]))
+                lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3])
+                lv14 = R.astype(lv6_1, dtype="float32")
+                R.output(lv14)
+            return lv14
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def fused_relax_nn_attention_bias_cutlass1(
+            q: R.Tensor((4, 16, 32, 8), dtype="float16"),
+            k: R.Tensor((4, 8, 32, 8), dtype="float16"),
+            v: R.Tensor((4, 8, 32, 16), dtype="float16"),
+            lv1: R.Tensor((4, 32, 16, 8), dtype="float16"),
+            workspace: R.Tensor((65536,), dtype="uint8"),
+        ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+            R.func_attr(
+                {
+                    "Codegen": "cutlass",
+                    "WorkspaceSize": T.int64(65536),
+                    "global_symbol": "fused_relax_nn_attention_bias_cutlass1",
+                }
+            )
+
+            @R.function
+            def gv_1(
+                q_1: R.Tensor((4, 16, 32, 8), dtype="float16"),
+                k_1: R.Tensor((4, 8, 32, 8), dtype="float16"),
+                v_1: R.Tensor((4, 8, 32, 16), dtype="float16"),
+                lv1_1: R.Tensor((4, 32, 16, 8), dtype="float16"),
+                workspace_1: R.Tensor((65536,), dtype="uint8"),
+            ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+                R.func_attr(
+                    {
+                        "Composite": "cutlass.attention_bias",
+                        "Primitive": 1,
+                        "WorkspaceSize": T.int64(65536),
+                    }
+                )
+                with R.dataflow():
+                    gv_2 = R.nn.attention(
+                        q_1, k_1, v_1, lv1_1, scale=T.float32(0.5), 
causal_mask=None
+                    )
+                    R.output(gv_2)
+                return gv_2
+
+            gv1: R.Tensor((4, 16, 32, 16), dtype="float16") = gv_1(q, k, v, 
lv1, workspace)
+            return gv1
+
+        @R.function
+        def main(
+            q: R.Tensor((4, 16, 32, 8), dtype="float16"),
+            k: R.Tensor((4, 8, 32, 8), dtype="float16"),
+            v: R.Tensor((4, 8, 32, 16), dtype="float16"),
+            bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
+        ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+            cls = Expected
+            with R.dataflow():
+                lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), 
R.dtype("uint8"))
+                workspace_main = R.vm.alloc_tensor(
+                    lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+                )
+                lv_1 = R.reshape(bias, R.shape([128, 16, 8]))
+                lv1 = R.reshape(lv_1, R.shape([4, 32, 16, 8]))
+                lv_2 = cls.fused_relax_nn_attention_bias_cutlass1(q, k, v, 
lv1, workspace_main)
+                lv14 = R.astype(lv_2, dtype="float32")
+                R.output(lv14)
+            return lv14
+
+    mod = partition_for_cutlass(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to