This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 2bd506fa92 [Unity][BYOC] Support implicit attention patterns (#14744)
2bd506fa92 is described below
commit 2bd506fa92a37f6a3ca8e9fdd4fcf132f694d93f
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Apr 28 22:24:57 2023 -0700
[Unity][BYOC] Support implicit attention patterns (#14744)
* [Unity][BYOC] Support implicit attention patterns
In this PR, we add support for implicit attention pattern matching
in BYOC. By adding a rewriting stage before `FuseOpsByPattern`, we are able to
match and rewrite patterns like `matmul - softmax - matmul` to
`R.nn.attention` together with the other operators needed.
* fix lint
---
python/tvm/relax/backend/contrib/cutlass.py | 19 ++-
python/tvm/relax/backend/patterns.py | 173 +++++++++++++++++++++++++++-
tests/python/relax/test_codegen_cutlass.py | 169 +++++++++++++++++++++++++++
3 files changed, 358 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py
b/python/tvm/relax/backend/contrib/cutlass.py
index 86eab773cc..36f43c6c21 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -22,6 +22,7 @@ from typing import Mapping, Sequence
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
from tvm.relax import DataflowVar, Var, transform, Call
from tvm.relax.transform import PatternCheckContext
+from tvm.relax.dpl import rewrite_call
from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import (
@@ -31,6 +32,7 @@ from ..patterns import (
make_residual_block_pattern,
make_stacked_attention_pattern,
make_layer_norm_pattern,
+ make_attention_rewrite_pattern,
)
@@ -346,6 +348,18 @@ def layer_norm_pattern():
]
+def attention_rewrite_patterns():
+ """
+ Returns a list of all attention rewriting patterns in cutlass BYOC backend.
+ """
+ patterns = []
+ for qkv_layout in ["BSNH", "BSH"]:
+ for out_layout in ["BSNH", "BSH"]:
+ for with_bias in [True, False]:
+ patterns.append(make_attention_rewrite_pattern(qkv_layout,
out_layout, with_bias))
+ return patterns
+
+
register_patterns(
[
*conv2d_patterns(),
@@ -356,6 +370,8 @@ register_patterns(
]
)
+_REWRITE_PATTERNS = [*attention_rewrite_patterns()]
+
def partition_for_cutlass(mod, annotate_codegen=True):
"""
@@ -377,7 +393,8 @@ def partition_for_cutlass(mod, annotate_codegen=True):
The resulting IRModule, containing partitioned subgraphs to be
compiled by the CUTLASS backend.
"""
-
+ for pattern, rewriter in _REWRITE_PATTERNS:
+ mod["main"] = rewrite_call(pattern, rewriter, mod["main"])
patterns = get_patterns_with_prefix("cutlass")
return transform.FuseOpsByPattern(
patterns, bind_constants=False, annotate_codegen=annotate_codegen
diff --git a/python/tvm/relax/backend/patterns.py
b/python/tvm/relax/backend/patterns.py
index 7242cb3f0d..53c3879c3a 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -18,8 +18,8 @@
"""Common patterns used in BYOC"""
from typing import Dict, Mapping, Tuple, Union
-
-from tvm.relax.dpl.pattern import DFPattern, is_op, is_tuple_get_item, wildcard
+from tvm.script import relax as R, tir as T
+from tvm.relax.dpl.pattern import DFPattern, is_const, is_op,
is_tuple_get_item, wildcard
def _with_bias_activation_pattern(
@@ -261,3 +261,172 @@ def make_layer_norm_pattern():
beta = wildcard()
return is_op("relax.nn.layer_norm")(inp, gamma, beta), {}
+
+
+def make_attention_rewrite_pattern(qkv_layout: str, out_layout: str,
with_bias: bool):
+ """
+ Create pattern for implicit fused multi head attention rewriting.
+
+ Parameters
+ ----------
+ qkv_layout: str
+ The layout of the query, key and value tensor, i.e. BSNH or BSH.
+
+ out_layout: str
+ The layout of the output tensor, i.e. BSNH or BSH.
+
+ with_bias: bool
+ Whether or not to include bias addition
+
+ Returns
+ -------
+ pattern: DFPattern
+ The resulting pattern describing an implicit fused multi head
attention.
+
+ rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr]
+ The rewriter for the pattern. It will check the matched patterns, and
rewrite.
+ If the matched pattern is not able to be rewritten to
`R.nn.attention`, the rewriter
+ returns the original IR.
+ """
+
+ # pylint: disable=invalid-name
+ def handle_input(tensor, layout, transpose):
+ if layout == "BSNH":
+ permuted = is_op("relax.permute_dims")(tensor)
+ shape = wildcard()
+ reshaped = is_op("relax.reshape")(permuted, shape)
+ if transpose:
+ transposed = is_op("relax.permute_dims")(reshaped)
+
+ def rewriter(matchings, x):
+ if matchings[tensor].struct_info.ndim != 4:
+ return None
+ if list(matchings[permuted].attrs.axes) != [0, 2, 1, 3]:
+ return None
+ before_reshape = matchings[permuted].struct_info.shape.values
+ after_reshape = matchings[shape].struct_info.values
+ if not (
+ len(before_reshape) == 4
+ and len(after_reshape) == 3
+ and before_reshape[-2:] == after_reshape[-2:]
+ ):
+ return None
+ if transpose and list(matchings[transposed].attrs.axes) != [0,
2, 1]:
+ return None
+ return x, x.struct_info.shape
+
+ if transpose:
+ return transposed, rewriter
+ else:
+ return reshaped, rewriter
+ elif layout == "BSH":
+ if transpose:
+ transposed = is_op("relax.permute_dims")(tensor)
+
+ def rewriter(matchings, x):
+ if matchings[tensor].struct_info.ndim != 3:
+ return None
+ if transpose and list(matchings[transposed].attrs.axes) != [0,
2, 1]:
+ return None
+ before_reshape = x.struct_info.shape.values
+ after_reshape = [before_reshape[0], before_reshape[1], 1,
before_reshape[2]]
+ return R.reshape(x, after_reshape), after_reshape
+
+ if transpose:
+ return transposed, rewriter
+ else:
+ return tensor, rewriter
+ else:
+ raise NotImplementedError()
+
+ def handle_output(tensor, layout):
+ if layout == "BSNH":
+ shape = wildcard()
+ reshaped = is_op("relax.reshape")(tensor, shape)
+ permuted = is_op("relax.permute_dims")(reshaped)
+
+ def rewriter(matchings, x):
+ if matchings[tensor].struct_info.ndim != 3:
+ return None
+ before_reshape = matchings[tensor].struct_info.shape.values
+ after_reshape = matchings[shape].struct_info.values
+ if not (
+ len(before_reshape) == 3
+ and len(after_reshape) == 4
+ and before_reshape[-2:] == after_reshape[-2:]
+ ):
+ return None
+ if list(matchings[permuted].attrs.axes) != [0, 2, 1, 3]:
+ return None
+ return x
+
+ return permuted, rewriter
+ elif layout == "BSH":
+
+ def rewriter(matchings, x):
+ if matchings[tensor].struct_info.ndim != 3:
+ return None
+ return R.reshape(x, matchings[tensor].struct_info.shape.values)
+
+ return tensor, rewriter
+ else:
+ raise NotImplementedError()
+
+ q_raw, k_raw, v_raw = wildcard(), wildcard(), wildcard()
+ q, q_rewriter = handle_input(q_raw, qkv_layout, False)
+ k, k_rewriter = handle_input(k_raw, qkv_layout, True)
+ v, v_rewriter = handle_input(v_raw, qkv_layout, False)
+ matmul_1 = is_op("relax.matmul")(q, k)
+ scale = is_const()
+ multiply = is_op("relax.multiply")(matmul_1, scale)
+ if with_bias:
+ bias_raw = wildcard()
+ add = is_op("relax.add")(multiply, bias_raw)
+ softmax = is_op("relax.nn.softmax")(add)
+ else:
+ softmax = is_op("relax.nn.softmax")(multiply)
+ matmul_2 = is_op("relax.matmul")(softmax, v)
+ out, out_rewriter = handle_output(matmul_2, out_layout)
+
+ def rewriter(original, matchings):
+ query, query_shape = q_rewriter(matchings, matchings[q_raw])
+ key, key_shape = k_rewriter(matchings, matchings[k_raw])
+ value, _ = v_rewriter(matchings, matchings[v_raw])
+ if query is None or key is None or value is None:
+ return original
+ if matchings[softmax].attrs.axis != -1:
+ return original
+ b, s, n, _ = query_shape
+ _, s_kv, _, _ = key_shape
+ if with_bias:
+ bias = matchings[bias_raw]
+ bias_shape = list(bias.struct_info.shape)
+ if bias_shape == [b * n, s, s_kv]:
+ bias = R.reshape(bias, [b, n, s, s_kv])
+ elif bias_shape == [b * n, 1, s_kv]:
+ bias = R.reshape(bias, [b, n, 1, s_kv])
+ elif bias_shape == [b, s, s_kv]:
+ bias = R.reshape(bias, [b, 1, s, s_kv])
+ elif bias_shape == [b, 1, s_kv]:
+ bias = R.reshape(bias, [b, 1, 1, s_kv])
+ elif bias_shape in [[1, s, s_kv], [s, s_kv]]:
+ bias = R.reshape(bias, [1, 1, s, s_kv])
+ elif bias_shape in [[1, 1, s_kv], [1, s_kv], [s_kv]]:
+ bias = R.reshape(bias, [1, 1, 1, s_kv])
+ else:
+ return original
+ else:
+ bias = None
+ out = out_rewriter(
+ matchings,
+ R.nn.attention(
+ query,
+ key,
+ value,
+ bias,
+ T.FloatImm(matchings[scale].data.dtype,
float(matchings[scale].data.numpy())),
+ ),
+ )
+ return out
+
+ return out, rewriter
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index 85f96b5e96..a2c5b4da96 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -761,6 +761,175 @@ def
test_stacked_attention_strided_slice_offload(stacked_attention_size):
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+@pytest.fixture(
+ params=[
+ # B, S, N, H, bias_shape, scale
+ (4, (16, 8), 32, (8, 16), "none", 0.5),
+ (4, (16, 8), 32, (8, 16), (4, 32, 16, 8), 0.5),
+ (4, (16, 8), "none", (8, 16), "none", 0.5),
+ (4, (16, 8), "none", (8, 16), (4, 32, 16, 8), 0.5),
+ ]
+)
+def attention_rewrite_size(request):
+ return request.param
+
+
+def get_relax_attention_rewrite_module(
+ q_shape, k_shape, v_shape, out_shape, dtype, bias_shape=None, scale=None
+):
+ from tvm.script.ir_builder import IRBuilder
+ from tvm.script.ir_builder import relax as relax_builder, tir as T
+
+ with IRBuilder() as builder:
+ with relax_builder.function():
+ R.func_name("main")
+ q = R.arg("q", R.Tensor(q_shape, dtype))
+ k = R.arg("k", R.Tensor(k_shape, dtype))
+ v = R.arg("v", R.Tensor(v_shape, dtype))
+ if bias_shape is not None:
+ bias = R.arg("bias", R.Tensor(bias_shape, dtype))
+ with R.dataflow() as frame:
+ if len(q_shape) == 4:
+ q = R.emit(R.permute_dims(q, axes=[0, 2, 1, 3]))
+ q = R.emit(R.reshape(q, [q_shape[0] * q_shape[2],
q_shape[1], q_shape[3]]))
+
+ if len(k_shape) == 4:
+ k = R.emit(R.permute_dims(k, axes=[0, 2, 1, 3]))
+ k = R.emit(R.reshape(k, [k_shape[0] * k_shape[2],
k_shape[1], k_shape[3]]))
+
+ if len(v_shape) == 4:
+ v = R.emit(R.permute_dims(v, axes=[0, 2, 1, 3]))
+ v = R.emit(R.reshape(v, [v_shape[0] * v_shape[2],
v_shape[1], v_shape[3]]))
+
+ k = R.emit(R.permute_dims(k, axes=[0, 2, 1]))
+ qk = R.emit(R.matmul(q, k))
+ qk_scaled = R.emit(R.multiply(qk, R.const(scale, "float32")))
+ if bias_shape is not None:
+ if len(bias_shape) == 4:
+ bias = R.emit(
+ R.reshape(bias, [bias_shape[0] * bias_shape[1],
*bias_shape[2:]])
+ )
+ qk_added = R.emit(R.add(qk_scaled, bias))
+ softmax = R.emit(R.nn.softmax(qk_added, axis=-1))
+ else:
+ softmax = R.emit(R.nn.softmax(qk_scaled, axis=-1))
+ out = R.emit(R.matmul(softmax, v))
+
+ if len(out_shape) == 4:
+ out = R.emit(
+ R.reshape(
+ out,
+ [out_shape[0], out_shape[2], out_shape[1],
out_shape[3]],
+ )
+ )
+ out = R.emit(R.permute_dims(out, axes=[0, 2, 1, 3]))
+ R.output(out)
+
+ R.func_ret_value(frame.output_vars[0])
+
+ original_func = builder.get()
+
+ if scale is not None:
+ scale = T.FloatImm("float32", scale)
+
+ with IRBuilder() as builder:
+ with relax_builder.function():
+ R.func_name("main")
+ q = R.arg("q", R.Tensor(q_shape, dtype))
+ k = R.arg("k", R.Tensor(k_shape, dtype))
+ v = R.arg("v", R.Tensor(v_shape, dtype))
+ if bias_shape is not None:
+ bias = R.arg("bias", R.Tensor(bias_shape, dtype))
+ with R.dataflow() as frame:
+ if len(q_shape) == 3:
+ q = R.emit(R.reshape(q, [q_shape[0], q_shape[1], 1,
q_shape[2]]))
+
+ if len(k_shape) == 3:
+ k = R.emit(R.reshape(k, [k_shape[0], k_shape[1], 1,
k_shape[2]]))
+
+ if len(v_shape) == 3:
+ v = R.emit(R.reshape(v, [v_shape[0], v_shape[1], 1,
v_shape[2]]))
+
+ if bias_shape is not None:
+ if len(bias_shape) == 4:
+ bias = R.emit(
+ R.reshape(
+ bias,
+ [
+ bias_shape[0] * bias_shape[1],
+ bias_shape[2],
+ bias_shape[3],
+ ],
+ )
+ )
+ bias = R.emit(
+ R.reshape(
+ bias,
+ [
+ bias_shape[0],
+ bias_shape[1],
+ bias_shape[2],
+ bias_shape[3],
+ ],
+ )
+ )
+ elif len(bias_shape) == 3:
+ bias = R.emit(
+ R.reshape(bias, [bias_shape[0], 1, bias_shape[1],
bias_shape[2]])
+ )
+ else:
+ bias = None
+ out = R.emit(R.nn.attention(q, k, v, bias, scale))
+
+ if len(out_shape) == 3:
+ out = R.emit(R.reshape(out, [out_shape[0], out_shape[1],
out_shape[2]]))
+ R.output(out)
+
+ R.func_ret_value(frame.output_vars[0])
+
+ expected_func = builder.get()
+ return tvm.IRModule({"main": original_func}), tvm.IRModule({"main":
expected_func})
+
+
+def get_numpy_attention_input(q_shape, k_shape, v_shape, bias_shape, dtype):
+ q = np.random.randn(*q_shape).astype(dtype)
+ k = np.random.randn(*k_shape).astype(dtype)
+ v = np.random.randn(*v_shape).astype(dtype)
+ if not bias_shape == "none":
+ bias = np.random.randn(*bias_shape).astype(dtype)
+ else:
+ bias = None
+ return q, k, v, bias
+
+
+def test_attention_rewrite_offload(attention_rewrite_size):
+ b, (s, s_kv), n, (h, h_v), bias_shape, scale = attention_rewrite_size
+ q_shape = [b, s, n, h] if n != "none" else [b, s, h]
+ k_shape = [b, s_kv, n, h] if n != "none" else [b, s_kv, h]
+ v_shape = [b, s_kv, n, h_v] if n != "none" else [b, s_kv, h_v]
+ out_shape = [b, s, n, h_v] if n != "none" else [b, s, h_v]
+ bias_shape = [b, n, s, s_kv] if n != "none" else [b, s, s_kv]
+ q, k, v, bias = get_numpy_attention_input(q_shape, k_shape, v_shape,
bias_shape, "float32")
+ original_mod, expected_mod = get_relax_attention_rewrite_module(
+ q_shape, k_shape, v_shape, out_shape, "float32", bias_shape, scale
+ )
+ original_mod = partition_for_cutlass(original_mod, True)
+ expected_mod = partition_for_cutlass(expected_mod, True)
+ tvm.ir.assert_structural_equal(original_mod, expected_mod, True)
+
+ codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80,
"find_first_valid": True}})
+ original_mod = codegen_pass(original_mod)
+ expected_mod = codegen_pass(expected_mod)
+ if bias is None:
+ original_out = build_and_run(original_mod, [q, k, v], "cuda")
+ expected_out = build_and_run(expected_mod, [q, k, v], "cuda")
+ tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5,
atol=1e-5)
+ else:
+ original_out = build_and_run(original_mod, [q, k, v, bias], "cuda")
+ expected_out = build_and_run(expected_mod, [q, k, v, bias], "cuda")
+ tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5,
atol=1e-5)
+
+
def test_invalid_residual():
@tvm.script.ir_module
class Module: