This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ec127cb0d6 [Unity][BYOC] Make CUTLASS attention rewriting aware of fp16 <-> f32 casting (#14957)
ec127cb0d6 is described below
commit ec127cb0d606036b8b9a4ad3c2ed8088b5f8728f
Author: masahi <[email protected]>
AuthorDate: Fri May 26 16:33:26 2023 +0900
[Unity][BYOC] Make CUTLASS attention rewriting aware of fp16 <-> f32 casting (#14957)
* fixed pattern rewriting logic in partition_for_cutlass
* update attention rewrite pattern
* add test
* add missing doc for scale
* black
* fix
* lint
---
python/tvm/relax/backend/contrib/cutlass.py | 9 ++-
python/tvm/relax/backend/patterns.py | 41 +++++++++--
python/tvm/relax/op/nn/nn.py | 3 +
tests/python/relax/test_codegen_cutlass.py | 102 ++++++++++++++++++++++++++++
4 files changed, 146 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index f27e0d9e50..dd2739fd98 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -360,7 +360,10 @@ def attention_rewrite_patterns():
for qkv_layout in ["BSNH", "BSH"]:
for out_layout in ["BSNH", "BSH"]:
for with_bias in [True, False]:
- patterns.append(make_attention_rewrite_pattern(qkv_layout, out_layout, with_bias))
+ for with_cast in [True, False]:
+ patterns.append(
+ make_attention_rewrite_pattern(qkv_layout, out_layout, with_bias, with_cast)
+ )
return patterns
@@ -441,7 +444,9 @@ def partition_for_cutlass(mod, annotate_codegen=True):
for func_name, func in mod.functions.items():
if isinstance(func, Function):
for pattern, rewriter in _REWRITE_PATTERNS:
- mod[func_name] = rewrite_call(pattern, rewriter, func)
+ func = rewrite_call(pattern, rewriter, func)
+ mod[func_name] = func
+
patterns = get_patterns_with_prefix("cutlass")
return tvm.transform.Sequential(
[
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 53c3879c3a..f7e8dd0406 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -263,7 +263,9 @@ def make_layer_norm_pattern():
return is_op("relax.nn.layer_norm")(inp, gamma, beta), {}
-def make_attention_rewrite_pattern(qkv_layout: str, out_layout: str, with_bias: bool):
+def make_attention_rewrite_pattern(
+ qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool
+):
"""
Create pattern for implicit fused multi head attention rewriting.
@@ -276,7 +278,11 @@ def make_attention_rewrite_pattern(qkv_layout: str, out_layout: str, with_bias:
The layout of the output tensor, i.e. BSNH or BSH.
with_bias: bool
- Whether or not to include bias addition
+ Whether or not to include bias addition.
+
+ with_cast: bool
+ Whether or not rewriting is intended to be applied to a module after the FP16 conversion
+ pass.
Returns
-------
@@ -378,14 +384,31 @@ def make_attention_rewrite_pattern(qkv_layout: str, out_layout: str, with_bias:
v, v_rewriter = handle_input(v_raw, qkv_layout, False)
matmul_1 = is_op("relax.matmul")(q, k)
scale = is_const()
- multiply = is_op("relax.multiply")(matmul_1, scale)
+
+ if with_cast:
+ multiply = is_op("relax.multiply")(matmul_1, is_op("relax.astype")(scale))
+ else:
+ multiply = is_op("relax.multiply")(matmul_1, scale)
+
if with_bias:
bias_raw = wildcard()
add = is_op("relax.add")(multiply, bias_raw)
- softmax = is_op("relax.nn.softmax")(add)
+ softmax_input = add
else:
- softmax = is_op("relax.nn.softmax")(multiply)
- matmul_2 = is_op("relax.matmul")(softmax, v)
+ softmax_input = multiply
+
+ if with_cast:
+ softmax_input = is_op("relax.astype")(softmax_input)
+
+ softmax = is_op("relax.nn.softmax")(softmax_input)
+
+ if with_cast:
+ softmax_output = is_op("relax.astype")(softmax)
+ else:
+ softmax_output = softmax
+
+ matmul_2 = is_op("relax.matmul")(softmax_output, v)
+
out, out_rewriter = handle_output(matmul_2, out_layout)
def rewriter(original, matchings):
@@ -394,7 +417,11 @@ def make_attention_rewrite_pattern(qkv_layout: str, out_layout: str, with_bias:
value, _ = v_rewriter(matchings, matchings[v_raw])
if query is None or key is None or value is None:
return original
- if matchings[softmax].attrs.axis != -1:
+ softmax_axis = matchings[softmax].attrs.axis
+ softmax_input_rank = len(matchings[softmax].struct_info.shape)
+ if softmax_axis == -1:
+ softmax_axis += softmax_input_rank
+ if softmax_axis != softmax_input_rank - 1:
return original
b, s, n, _ = query_shape
_, s_kv, _, _ = key_shape
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 92bb7c2042..be2f968501 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1036,6 +1036,9 @@ def attention(
a 4-D tensor ending with seq_len_kv, and broadcastable to
(batch_size, num_head, seq_len, seq_len_kv).
+ scale: Optional[float]
+ The scale value to be applied to the attention score, by default 1 / sqrt(head_dim).
+
causal_mask: Optional[str]
The optional causal mask, i.e. 'TopLeft' and 'BottomRight'.
For 'TopLeft', the mask matrix is as `np.tril(*, k=0)`,
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index c0a72fffb9..339b22ca9b 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -25,6 +25,8 @@ from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
from tvm.contrib.pickle_memoize import memoize
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from tvm.relax.testing import get_relax_matmul_module
+from tvm.script import tir as T
+from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder
@@ -1086,5 +1088,105 @@ def test_layer_norm(data_shape, dtype, axes):
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+def test_attention_rewrite_fp16():
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(
+ q: R.Tensor((4, 16, 32, 8), dtype="float16"),
+ k: R.Tensor((4, 8, 32, 8), dtype="float16"),
+ v: R.Tensor((4, 8, 32, 16), dtype="float16"),
+ bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
+ ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+ with R.dataflow():
+ lv = R.permute_dims(q, axes=[0, 2, 1, 3])
+ lv1 = R.reshape(lv, R.shape([128, 16, 8]))
+ lv2 = R.permute_dims(k, axes=[0, 2, 1, 3])
+ lv3 = R.reshape(lv2, R.shape([128, 8, 8]))
+ lv4 = R.permute_dims(v, axes=[0, 2, 1, 3])
+ lv5 = R.reshape(lv4, R.shape([128, 8, 16]))
+ lv6 = R.permute_dims(lv3, axes=[0, 2, 1])
+ lv7 = R.matmul(lv1, lv6, out_dtype="float16")
+ lv3_1 = R.astype(R.const(0.5, "float32"), dtype="float16")
+ lv8 = R.multiply(lv7, lv3_1)
+ lv9 = R.reshape(bias, R.shape([128, 16, 8]))
+ lv10 = R.add(lv8, lv9)
+ lv10_fp16 = R.astype(lv10, dtype="float16")
+ lv11 = R.nn.softmax(lv10_fp16, axis=2)
+ lv5_1 = R.astype(lv11, dtype="float16")
+ lv12 = R.matmul(lv5_1, lv5, out_dtype="float16")
+ lv13 = R.reshape(lv12, R.shape([4, 32, 16, 16]))
+ lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3])
+ lv14 = R.astype(lv6_1, dtype="float32")
+ R.output(lv14)
+ return lv14
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def fused_relax_nn_attention_bias_cutlass1(
+ q: R.Tensor((4, 16, 32, 8), dtype="float16"),
+ k: R.Tensor((4, 8, 32, 8), dtype="float16"),
+ v: R.Tensor((4, 8, 32, 16), dtype="float16"),
+ lv1: R.Tensor((4, 32, 16, 8), dtype="float16"),
+ workspace: R.Tensor((65536,), dtype="uint8"),
+ ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+ R.func_attr(
+ {
+ "Codegen": "cutlass",
+ "WorkspaceSize": T.int64(65536),
+ "global_symbol": "fused_relax_nn_attention_bias_cutlass1",
+ }
+ )
+
+ @R.function
+ def gv_1(
+ q_1: R.Tensor((4, 16, 32, 8), dtype="float16"),
+ k_1: R.Tensor((4, 8, 32, 8), dtype="float16"),
+ v_1: R.Tensor((4, 8, 32, 16), dtype="float16"),
+ lv1_1: R.Tensor((4, 32, 16, 8), dtype="float16"),
+ workspace_1: R.Tensor((65536,), dtype="uint8"),
+ ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+ R.func_attr(
+ {
+ "Composite": "cutlass.attention_bias",
+ "Primitive": 1,
+ "WorkspaceSize": T.int64(65536),
+ }
+ )
+ with R.dataflow():
+ gv_2 = R.nn.attention(
+ q_1, k_1, v_1, lv1_1, scale=T.float32(0.5), causal_mask=None
+ )
+ R.output(gv_2)
+ return gv_2
+
+ gv1: R.Tensor((4, 16, 32, 16), dtype="float16") = gv_1(q, k, v, lv1, workspace)
+ return gv1
+
+ @R.function
+ def main(
+ q: R.Tensor((4, 16, 32, 8), dtype="float16"),
+ k: R.Tensor((4, 8, 32, 8), dtype="float16"),
+ v: R.Tensor((4, 8, 32, 16), dtype="float16"),
+ bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
+ ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+ cls = Expected
+ with R.dataflow():
+ lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
+ workspace_main = R.vm.alloc_tensor(
+ lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+ )
+ lv_1 = R.reshape(bias, R.shape([128, 16, 8]))
+ lv1 = R.reshape(lv_1, R.shape([4, 32, 16, 8]))
+ lv_2 = cls.fused_relax_nn_attention_bias_cutlass1(q, k, v, lv1, workspace_main)
+ lv14 = R.astype(lv_2, dtype="float32")
+ R.output(lv14)
+ return lv14
+
+ mod = partition_for_cutlass(Module)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()