This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2bd506fa92 [Unity][BYOC] Support implicit attention patterns (#14744)
2bd506fa92 is described below

commit 2bd506fa92a37f6a3ca8e9fdd4fcf132f694d93f
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Apr 28 22:24:57 2023 -0700

    [Unity][BYOC] Support implicit attention patterns (#14744)
    
    * [Unity][BYOC] Support implicit attention patterns
    
    In this PR, we add the support for the implicit attention pattern matching 
in BYOC. By adding a rewriting stage before `FuseOpsByPattern`, we are able to 
match and rewrite the patterns like `matmul - softmax - matmul` to 
`R.nn.attention` with other operators needed.
    
    * fix lint
---
 python/tvm/relax/backend/contrib/cutlass.py |  19 ++-
 python/tvm/relax/backend/patterns.py        | 173 +++++++++++++++++++++++++++-
 tests/python/relax/test_codegen_cutlass.py  | 169 +++++++++++++++++++++++++++
 3 files changed, 358 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cutlass.py 
b/python/tvm/relax/backend/contrib/cutlass.py
index 86eab773cc..36f43c6c21 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -22,6 +22,7 @@ from typing import Mapping, Sequence
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
 from tvm.relax import DataflowVar, Var, transform, Call
 from tvm.relax.transform import PatternCheckContext
+from tvm.relax.dpl import rewrite_call
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
 from ..patterns import (
@@ -31,6 +32,7 @@ from ..patterns import (
     make_residual_block_pattern,
     make_stacked_attention_pattern,
     make_layer_norm_pattern,
+    make_attention_rewrite_pattern,
 )
 
 
@@ -346,6 +348,18 @@ def layer_norm_pattern():
     ]
 
 
+def attention_rewrite_patterns():
+    """
+    Returns a list of all attention rewriting patterns in cutlass BYOC backend.
+    """
+    patterns = []
+    for qkv_layout in ["BSNH", "BSH"]:
+        for out_layout in ["BSNH", "BSH"]:
+            for with_bias in [True, False]:
+                patterns.append(make_attention_rewrite_pattern(qkv_layout, 
out_layout, with_bias))
+    return patterns
+
+
 register_patterns(
     [
         *conv2d_patterns(),
@@ -356,6 +370,8 @@ register_patterns(
     ]
 )
 
+_REWRITE_PATTERNS = [*attention_rewrite_patterns()]
+
 
 def partition_for_cutlass(mod, annotate_codegen=True):
     """
@@ -377,7 +393,8 @@ def partition_for_cutlass(mod, annotate_codegen=True):
         The resulting IRModule, containing partitioned subgraphs to be
         compiled by the CUTLASS backend.
     """
-
+    for pattern, rewriter in _REWRITE_PATTERNS:
+        mod["main"] = rewrite_call(pattern, rewriter, mod["main"])
     patterns = get_patterns_with_prefix("cutlass")
     return transform.FuseOpsByPattern(
         patterns, bind_constants=False, annotate_codegen=annotate_codegen
diff --git a/python/tvm/relax/backend/patterns.py 
b/python/tvm/relax/backend/patterns.py
index 7242cb3f0d..53c3879c3a 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -18,8 +18,8 @@
 """Common patterns used in BYOC"""
 
 from typing import Dict, Mapping, Tuple, Union
-
-from tvm.relax.dpl.pattern import DFPattern, is_op, is_tuple_get_item, wildcard
+from tvm.script import relax as R, tir as T
+from tvm.relax.dpl.pattern import DFPattern, is_const, is_op, 
is_tuple_get_item, wildcard
 
 
 def _with_bias_activation_pattern(
@@ -261,3 +261,172 @@ def make_layer_norm_pattern():
     beta = wildcard()
 
     return is_op("relax.nn.layer_norm")(inp, gamma, beta), {}
+
+
+def make_attention_rewrite_pattern(qkv_layout: str, out_layout: str, 
with_bias: bool):
+    """
+    Create pattern for implicit fused multi head attention rewriting.
+
+    Parameters
+    ----------
+    qkv_layout: str
+        The layout of the query, key and value tensor, i.e. BSNH or BSH.
+
+    out_layout: str
+        The layout of the output tensor, i.e. BSNH or BSH.
+
+    with_bias: bool
+        Whether or not to include bias addition
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing an implicit fused multi head 
attention.
+
+    rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr]
+        The rewriter for the pattern. It will check the matched patterns, and 
rewrite.
+        If the matched pattern is not able to be rewritten to 
`R.nn.attention`, the rewriter
+        returns the original IR.
+    """
+
+    # pylint: disable=invalid-name
+    def handle_input(tensor, layout, transpose):
+        if layout == "BSNH":
+            permuted = is_op("relax.permute_dims")(tensor)
+            shape = wildcard()
+            reshaped = is_op("relax.reshape")(permuted, shape)
+            if transpose:
+                transposed = is_op("relax.permute_dims")(reshaped)
+
+            def rewriter(matchings, x):
+                if matchings[tensor].struct_info.ndim != 4:
+                    return None
+                if list(matchings[permuted].attrs.axes) != [0, 2, 1, 3]:
+                    return None
+                before_reshape = matchings[permuted].struct_info.shape.values
+                after_reshape = matchings[shape].struct_info.values
+                if not (
+                    len(before_reshape) == 4
+                    and len(after_reshape) == 3
+                    and before_reshape[-2:] == after_reshape[-2:]
+                ):
+                    return None
+                if transpose and list(matchings[transposed].attrs.axes) != [0, 
2, 1]:
+                    return None
+                return x, x.struct_info.shape
+
+            if transpose:
+                return transposed, rewriter
+            else:
+                return reshaped, rewriter
+        elif layout == "BSH":
+            if transpose:
+                transposed = is_op("relax.permute_dims")(tensor)
+
+            def rewriter(matchings, x):
+                if matchings[tensor].struct_info.ndim != 3:
+                    return None
+                if transpose and list(matchings[transposed].attrs.axes) != [0, 
2, 1]:
+                    return None
+                before_reshape = x.struct_info.shape.values
+                after_reshape = [before_reshape[0], before_reshape[1], 1, 
before_reshape[2]]
+                return R.reshape(x, after_reshape), after_reshape
+
+            if transpose:
+                return transposed, rewriter
+            else:
+                return tensor, rewriter
+        else:
+            raise NotImplementedError()
+
+    def handle_output(tensor, layout):
+        if layout == "BSNH":
+            shape = wildcard()
+            reshaped = is_op("relax.reshape")(tensor, shape)
+            permuted = is_op("relax.permute_dims")(reshaped)
+
+            def rewriter(matchings, x):
+                if matchings[tensor].struct_info.ndim != 3:
+                    return None
+                before_reshape = matchings[tensor].struct_info.shape.values
+                after_reshape = matchings[shape].struct_info.values
+                if not (
+                    len(before_reshape) == 3
+                    and len(after_reshape) == 4
+                    and before_reshape[-2:] == after_reshape[-2:]
+                ):
+                    return None
+                if list(matchings[permuted].attrs.axes) != [0, 2, 1, 3]:
+                    return None
+                return x
+
+            return permuted, rewriter
+        elif layout == "BSH":
+
+            def rewriter(matchings, x):
+                if matchings[tensor].struct_info.ndim != 3:
+                    return None
+                return R.reshape(x, matchings[tensor].struct_info.shape.values)
+
+            return tensor, rewriter
+        else:
+            raise NotImplementedError()
+
+    q_raw, k_raw, v_raw = wildcard(), wildcard(), wildcard()
+    q, q_rewriter = handle_input(q_raw, qkv_layout, False)
+    k, k_rewriter = handle_input(k_raw, qkv_layout, True)
+    v, v_rewriter = handle_input(v_raw, qkv_layout, False)
+    matmul_1 = is_op("relax.matmul")(q, k)
+    scale = is_const()
+    multiply = is_op("relax.multiply")(matmul_1, scale)
+    if with_bias:
+        bias_raw = wildcard()
+        add = is_op("relax.add")(multiply, bias_raw)
+        softmax = is_op("relax.nn.softmax")(add)
+    else:
+        softmax = is_op("relax.nn.softmax")(multiply)
+    matmul_2 = is_op("relax.matmul")(softmax, v)
+    out, out_rewriter = handle_output(matmul_2, out_layout)
+
+    def rewriter(original, matchings):
+        query, query_shape = q_rewriter(matchings, matchings[q_raw])
+        key, key_shape = k_rewriter(matchings, matchings[k_raw])
+        value, _ = v_rewriter(matchings, matchings[v_raw])
+        if query is None or key is None or value is None:
+            return original
+        if matchings[softmax].attrs.axis != -1:
+            return original
+        b, s, n, _ = query_shape
+        _, s_kv, _, _ = key_shape
+        if with_bias:
+            bias = matchings[bias_raw]
+            bias_shape = list(bias.struct_info.shape)
+            if bias_shape == [b * n, s, s_kv]:
+                bias = R.reshape(bias, [b, n, s, s_kv])
+            elif bias_shape == [b * n, 1, s_kv]:
+                bias = R.reshape(bias, [b, n, 1, s_kv])
+            elif bias_shape == [b, s, s_kv]:
+                bias = R.reshape(bias, [b, 1, s, s_kv])
+            elif bias_shape == [b, 1, s_kv]:
+                bias = R.reshape(bias, [b, 1, 1, s_kv])
+            elif bias_shape in [[1, s, s_kv], [s, s_kv]]:
+                bias = R.reshape(bias, [1, 1, s, s_kv])
+            elif bias_shape in [[1, 1, s_kv], [1, s_kv], [s_kv]]:
+                bias = R.reshape(bias, [1, 1, 1, s_kv])
+            else:
+                return original
+        else:
+            bias = None
+        out = out_rewriter(
+            matchings,
+            R.nn.attention(
+                query,
+                key,
+                value,
+                bias,
+                T.FloatImm(matchings[scale].data.dtype, 
float(matchings[scale].data.numpy())),
+            ),
+        )
+        return out
+
+    return out, rewriter
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index 85f96b5e96..a2c5b4da96 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -761,6 +761,175 @@ def 
test_stacked_attention_strided_slice_offload(stacked_attention_size):
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
[email protected](
+    params=[
+        # B, S, N, H, bias_shape, scale
+        (4, (16, 8), 32, (8, 16), "none", 0.5),
+        (4, (16, 8), 32, (8, 16), (4, 32, 16, 8), 0.5),
+        (4, (16, 8), "none", (8, 16), "none", 0.5),
+        (4, (16, 8), "none", (8, 16), (4, 32, 16, 8), 0.5),
+    ]
+)
+def attention_rewrite_size(request):
+    return request.param
+
+
+def get_relax_attention_rewrite_module(
+    q_shape, k_shape, v_shape, out_shape, dtype, bias_shape=None, scale=None
+):
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import relax as relax_builder, tir as T
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            q = R.arg("q", R.Tensor(q_shape, dtype))
+            k = R.arg("k", R.Tensor(k_shape, dtype))
+            v = R.arg("v", R.Tensor(v_shape, dtype))
+            if bias_shape is not None:
+                bias = R.arg("bias", R.Tensor(bias_shape, dtype))
+            with R.dataflow() as frame:
+                if len(q_shape) == 4:
+                    q = R.emit(R.permute_dims(q, axes=[0, 2, 1, 3]))
+                    q = R.emit(R.reshape(q, [q_shape[0] * q_shape[2], 
q_shape[1], q_shape[3]]))
+
+                if len(k_shape) == 4:
+                    k = R.emit(R.permute_dims(k, axes=[0, 2, 1, 3]))
+                    k = R.emit(R.reshape(k, [k_shape[0] * k_shape[2], 
k_shape[1], k_shape[3]]))
+
+                if len(v_shape) == 4:
+                    v = R.emit(R.permute_dims(v, axes=[0, 2, 1, 3]))
+                    v = R.emit(R.reshape(v, [v_shape[0] * v_shape[2], 
v_shape[1], v_shape[3]]))
+
+                k = R.emit(R.permute_dims(k, axes=[0, 2, 1]))
+                qk = R.emit(R.matmul(q, k))
+                qk_scaled = R.emit(R.multiply(qk, R.const(scale, "float32")))
+                if bias_shape is not None:
+                    if len(bias_shape) == 4:
+                        bias = R.emit(
+                            R.reshape(bias, [bias_shape[0] * bias_shape[1], 
*bias_shape[2:]])
+                        )
+                    qk_added = R.emit(R.add(qk_scaled, bias))
+                    softmax = R.emit(R.nn.softmax(qk_added, axis=-1))
+                else:
+                    softmax = R.emit(R.nn.softmax(qk_scaled, axis=-1))
+                out = R.emit(R.matmul(softmax, v))
+
+                if len(out_shape) == 4:
+                    out = R.emit(
+                        R.reshape(
+                            out,
+                            [out_shape[0], out_shape[2], out_shape[1], 
out_shape[3]],
+                        )
+                    )
+                    out = R.emit(R.permute_dims(out, axes=[0, 2, 1, 3]))
+                R.output(out)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    original_func = builder.get()
+
+    if scale is not None:
+        scale = T.FloatImm("float32", scale)
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            q = R.arg("q", R.Tensor(q_shape, dtype))
+            k = R.arg("k", R.Tensor(k_shape, dtype))
+            v = R.arg("v", R.Tensor(v_shape, dtype))
+            if bias_shape is not None:
+                bias = R.arg("bias", R.Tensor(bias_shape, dtype))
+            with R.dataflow() as frame:
+                if len(q_shape) == 3:
+                    q = R.emit(R.reshape(q, [q_shape[0], q_shape[1], 1, 
q_shape[2]]))
+
+                if len(k_shape) == 3:
+                    k = R.emit(R.reshape(k, [k_shape[0], k_shape[1], 1, 
k_shape[2]]))
+
+                if len(v_shape) == 3:
+                    v = R.emit(R.reshape(v, [v_shape[0], v_shape[1], 1, 
v_shape[2]]))
+
+                if bias_shape is not None:
+                    if len(bias_shape) == 4:
+                        bias = R.emit(
+                            R.reshape(
+                                bias,
+                                [
+                                    bias_shape[0] * bias_shape[1],
+                                    bias_shape[2],
+                                    bias_shape[3],
+                                ],
+                            )
+                        )
+                        bias = R.emit(
+                            R.reshape(
+                                bias,
+                                [
+                                    bias_shape[0],
+                                    bias_shape[1],
+                                    bias_shape[2],
+                                    bias_shape[3],
+                                ],
+                            )
+                        )
+                    elif len(bias_shape) == 3:
+                        bias = R.emit(
+                            R.reshape(bias, [bias_shape[0], 1, bias_shape[1], 
bias_shape[2]])
+                        )
+                else:
+                    bias = None
+                out = R.emit(R.nn.attention(q, k, v, bias, scale))
+
+                if len(out_shape) == 3:
+                    out = R.emit(R.reshape(out, [out_shape[0], out_shape[1], 
out_shape[2]]))
+                R.output(out)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    expected_func = builder.get()
+    return tvm.IRModule({"main": original_func}), tvm.IRModule({"main": 
expected_func})
+
+
+def get_numpy_attention_input(q_shape, k_shape, v_shape, bias_shape, dtype):
+    q = np.random.randn(*q_shape).astype(dtype)
+    k = np.random.randn(*k_shape).astype(dtype)
+    v = np.random.randn(*v_shape).astype(dtype)
+    if not bias_shape == "none":
+        bias = np.random.randn(*bias_shape).astype(dtype)
+    else:
+        bias = None
+    return q, k, v, bias
+
+
+def test_attention_rewrite_offload(attention_rewrite_size):
+    b, (s, s_kv), n, (h, h_v), bias_shape, scale = attention_rewrite_size
+    q_shape = [b, s, n, h] if n != "none" else [b, s, h]
+    k_shape = [b, s_kv, n, h] if n != "none" else [b, s_kv, h]
+    v_shape = [b, s_kv, n, h_v] if n != "none" else [b, s_kv, h_v]
+    out_shape = [b, s, n, h_v] if n != "none" else [b, s, h_v]
+    bias_shape = [b, n, s, s_kv] if n != "none" else [b, s, s_kv]
+    q, k, v, bias = get_numpy_attention_input(q_shape, k_shape, v_shape, 
bias_shape, "float32")
+    original_mod, expected_mod = get_relax_attention_rewrite_module(
+        q_shape, k_shape, v_shape, out_shape, "float32", bias_shape, scale
+    )
+    original_mod = partition_for_cutlass(original_mod, True)
+    expected_mod = partition_for_cutlass(expected_mod, True)
+    tvm.ir.assert_structural_equal(original_mod, expected_mod, True)
+
+    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, 
"find_first_valid": True}})
+    original_mod = codegen_pass(original_mod)
+    expected_mod = codegen_pass(expected_mod)
+    if bias is None:
+        original_out = build_and_run(original_mod, [q, k, v], "cuda")
+        expected_out = build_and_run(expected_mod, [q, k, v], "cuda")
+        tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, 
atol=1e-5)
+    else:
+        original_out = build_and_run(original_mod, [q, k, v, bias], "cuda")
+        expected_out = build_and_run(expected_mod, [q, k, v, bias], "cuda")
+        tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, 
atol=1e-5)
+
+
 def test_invalid_residual():
     @tvm.script.ir_module
     class Module:

Reply via email to