This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3a02309ed8 [Relax] FuseTransposeMatmul Pass (#17234)
3a02309ed8 is described below

commit 3a02309ed85d308da1b1af127bc97b5b22589a43
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri Aug 2 22:14:32 2024 +0800

    [Relax] FuseTransposeMatmul Pass (#17234)
    
    Introduce a new pass that fuses transpose and matmul, aimed specifically
    at `Linear` ops in PyTorch and NNModule. Note that this pass is migrated
    from MLC-LLM.
    
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Junru Shao <[email protected]>
---
 python/tvm/relax/transform/__init__.py             |   1 +
 .../tvm/relax/transform/fuse_transpose_matmul.py   | 175 +++++++++++++++++++++
 .../relax/test_transform_fuse_transpose_matmul.py  |  82 ++++++++++
 3 files changed, 258 insertions(+)

diff --git a/python/tvm/relax/transform/__init__.py 
b/python/tvm/relax/transform/__init__.py
index 5e76fff6bd..5789e2fcf2 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -90,6 +90,7 @@ from .lower_gpu_ipc_alloc_storage import 
LowerGPUIPCAllocStorage
 from .optimize_layout_transform import OptimizeLayoutTransform
 from .remove_redundant_reshape import RemoveRedundantReshape
 from .fast_math import FastMathTransform
+from .fuse_transpose_matmul import FuseTransposeMatmul
 from .attach_external_modules import AttachExternModules
 
 # Import to register the legalization functions.
diff --git a/python/tvm/relax/transform/fuse_transpose_matmul.py 
b/python/tvm/relax/transform/fuse_transpose_matmul.py
new file mode 100644
index 0000000000..1d2324a28b
--- /dev/null
+++ b/python/tvm/relax/transform/fuse_transpose_matmul.py
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""A compiler pass that fuses transpose + matmul and generate TIR function.
+Note that
+1. Please put the pass before LegalizeOps pass.
+2. The pass only works for XW^T but not X^TW
+3. The pass would rewrite the relax ops into TIR functions. If you'd like to 
dispatch the
+   ops into library (e.g. cuBLAS) calls, please run dispatch pass before this 
pass.
+"""
+
+import tvm
+from tvm import IRModule, relax, te, tir
+from tvm.relax.dpl.pattern import is_op, wildcard
+from tvm.relax.expr_functor import PyExprMutator, mutator
+
+
[email protected]_pass(opt_level=0, name="FuseTransposeMatmul")
+class FuseTransposeMatmul:  # pylint: disable=too-few-public-methods
+    """A compiler pass that fuses transpose + matmul."""
+
+    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) 
-> IRModule:
+        """IRModule-level transformation"""
+        mod = relax.transform.FuseOpsByPattern(
+            [
+                (
+                    "transpose_matmul_fuse",
+                    *_pattern(),
+                ),
+            ]
+        )(mod)
+        transpose_matmul_codegen = _TransposeMatmulFuser(mod)
+        for g_var, func in mod.functions_items():
+            if isinstance(func, relax.Function):
+                func = transpose_matmul_codegen.visit_expr(func)
+                transpose_matmul_codegen.builder_.update_func(g_var, func)
+        return transpose_matmul_codegen.builder_.get()
+
+
+def _pattern():
+    """Pattern for transpose + matmul."""
+    # pylint: disable=invalid-name
+    w = wildcard()
+    x = wildcard()
+    wT = is_op("relax.permute_dims")(w)
+    o = is_op("relax.matmul")(x, wT)
+    # pylint: enable=invalid-name
+    annotations = {"o": o, "w": w, "x": x, "wT": wT}
+
+    def _check(context: relax.transform.PatternCheckContext) -> bool:
+        transpose_call = context.annotated_expr["wT"]
+        ndim = transpose_call.args[0].struct_info.ndim
+        if ndim == -1:
+            return False
+        if ndim == 2 and transpose_call.attrs.axes is None:
+            return True
+        axes = list(range(ndim))
+        axes[-1], axes[-2] = axes[-2], axes[-1]
+        return list(transpose_call.attrs.axes) == axes
+
+    return o, annotations, _check
+
+
+# pylint: disable=missing-docstring,invalid-name
+
+
+@mutator
+class _TransposeMatmulFuser(PyExprMutator):  # pylint: disable=abstract-method
+    """Rewrite calls to ``transpose_matmul_fuse`` composite functions.
+
+    A call whose callee is a Relax function with attribute
+    ``Composite == "transpose_matmul_fuse"`` is replaced by a ``call_te`` of
+    an ``NT_matmul`` TE compute. Every other call goes through the default
+    ``PyExprMutator`` traversal.
+    """
+
+    def __init__(self, mod):
+        # ``mod`` seeds ``self.builder_``. ``visit_call_`` looks callees up in
+        # ``self.builder_.get()`` and emits the generated TIR through it.
+        super().__init__(mod)
+
+    def visit_call_(  # pylint: disable=arguments-renamed
+        self,
+        call: relax.Call,
+    ) -> relax.Expr:
+        """Replace a call to a transpose-matmul composite with an NT matmul.
+
+        Parameters
+        ----------
+        call : relax.Call
+            The call being visited.
+
+        Returns
+        -------
+        relax.Expr
+            The ``call_te`` result for a matching composite call; otherwise
+            the result of the default call visitor.
+        """
+        # Set from the composite function's return dtype right before
+        # ``call_te`` invokes ``te_transposed_matmul``.
+        out_dtype = None
+
+        def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor:
+            """Build a TE compute for ``a @ b^T`` that reads ``b`` untransposed.
+
+            For indexing purposes, a 1-D ``a`` gets a leading unit axis and a
+            1-D ``b`` gets a trailing unit axis. The output shape comes from
+            Relax's own ``matmul`` shape inference on ``a`` and on ``b`` with
+            its last two axes swapped.
+
+            NOTE(review): ``bT_shape[-2]`` below requires ``b`` to have rank
+            >= 2, so the 1-D ``b`` path looks unreachable. Confirm.
+            """
+            nonlocal out_dtype
+            a_shape = list(a.shape)
+            b_shape = list(b.shape)
+            a_prepended = False
+            b_appended = False
+            if len(a_shape) == 1:
+                a_prepended = True
+                a_shape.insert(0, 1)
+            if len(b_shape) == 1:
+                b_appended = True
+                b_shape.append(1)
+
+            # ``offset`` is the number of leading batch axes that exist only
+            # in the higher-rank operand.
+            is_a_larger = len(a_shape) > len(b_shape)
+            offset = len(a_shape) - len(b_shape) if is_a_larger else 
len(b_shape) - len(a_shape)
+
+            # Let Relax infer the output shape of ``matmul(a, b^T)`` from
+            # placeholder vars.
+            a_relax = relax.Var("a", relax.TensorStructInfo(a.shape))
+            bT_shape = list(b.shape)
+            bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1]
+            bT_relax = relax.Var("b", relax.TensorStructInfo(bT_shape))
+            output_shape = self.builder_.normalize(
+                relax.op.matmul(a_relax, bT_relax)
+            ).struct_info.shape
+
+            def matmul_compute(*idx_spatial):
+                # Contracted axis: the last axis of ``a``, which is also the
+                # last index used for ``b``.
+                k = te.reduce_axis((0, a_shape[-1]), name="k")
+
+                def multiply_compute(idx_reduce):
+                    a_indices = []
+                    b_indices = []
+
+                    # Extra leading batch axes index only the higher-rank
+                    # operand.
+                    for i in range(offset):
+                        if is_a_larger:
+                            a_indices.append(idx_spatial[i])
+                        else:
+                            b_indices.append(idx_spatial[i])
+                    # Shared batch axes. When the two extents are not provably
+                    # equal, an operand whose extent is the constant 1 is
+                    # indexed at 0 (broadcast).
+                    for i in range(offset, len(output_shape) - (2 - 
a_prepended - b_appended)):
+                        a_dim = a_shape[i if is_a_larger else i - offset]
+                        b_dim = b_shape[i if not is_a_larger else i - offset]
+                        dim_equal = a_dim == b_dim
+                        if not isinstance(dim_equal, tir.IntImm) or dim_equal 
== 0:
+                            a_dim_is_one = isinstance(a_dim, tir.IntImm) and 
a_dim == 1
+                            b_dim_is_one = isinstance(b_dim, tir.IntImm) and 
b_dim == 1
+                            a_indices.append(0 if a_dim_is_one else 
idx_spatial[i])
+                            b_indices.append(0 if b_dim_is_one else 
idx_spatial[i])
+                        else:
+                            a_indices.append(idx_spatial[i])
+                            b_indices.append(idx_spatial[i])
+
+                    # ``a`` gets the output row index. ``b`` gets the output
+                    # column index, i.e. a row of the untransposed ``b``.
+                    # Both then take ``k`` last, so ``b`` is read as
+                    # ``b[..., j, k]`` and ``b^T`` is never materialized.
+                    if not a_prepended:
+                        a_indices.append(idx_spatial[-2 + b_appended])
+                    a_indices.append(idx_reduce)
+                    if not b_appended:
+                        b_indices.append(idx_spatial[-1])
+                    b_indices.append(idx_reduce)
+
+                    # Cast both operands to the composite's return dtype,
+                    # unless that dtype string is empty.
+                    dtype = out_dtype
+                    if dtype != "":
+                        return a(*a_indices).astype(dtype) * 
b(*b_indices).astype(dtype)
+                    return a(*a_indices) * b(*b_indices)
+
+                return te.sum(multiply_compute(k), axis=k)
+
+            return te.compute(
+                output_shape,
+                lambda *idx: matmul_compute(*idx),  # pylint: 
disable=unnecessary-lambda
+                name="NT_matmul",
+            )
+
+        if isinstance(call.op, relax.GlobalVar):
+            function = self.builder_.get()[call.op]
+            if (
+                "Composite" in function.attrs
+                and function.attrs["Composite"] == "transpose_matmul_fuse"
+            ):
+                out_dtype = function.ret_struct_info.dtype
+                # Operands are passed as (a=args[1], b=args[0]). This assumes
+                # the composite's parameters are ordered (w, x), so that
+                # ``a`` is x and ``b`` is the untransposed w. TODO confirm
+                # against FuseOpsByPattern's parameter ordering.
+                return self.builder_.call_te(
+                    te_transposed_matmul,
+                    call.args[1],
+                    call.args[0],
+                    primfunc_name_hint="NT_matmul",
+                )
+
+        return super().visit_call_(call)
diff --git a/tests/python/relax/test_transform_fuse_transpose_matmul.py 
b/tests/python/relax/test_transform_fuse_transpose_matmul.py
new file mode 100644
index 0000000000..4b2b1fff8a
--- /dev/null
+++ b/tests/python/relax/test_transform_fuse_transpose_matmul.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, missing-docstring
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_transform_fuse_transpose_matmul():
+    """Check the 2-D ``matmul(x, permute_dims(w, [1, 0]))`` rewrite.
+
+    The pair should become a single ``call_tir`` to an ``NT_matmul`` PrimFunc
+    that reads ``w`` without transposing it.
+    """
+
+    # Input: ``w`` is explicitly transposed before the matmul.
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((128, 256), "float32"),
+            w: R.Tensor((128, 256), "float32"),
+        ) -> R.Tensor((128, 128), "float32"):
+            with R.dataflow():
+                wT = R.permute_dims(w, [1, 0])
+                o = R.matmul(x, wT)
+                R.output(o)
+            return o
+
+    # Expected: the transpose is folded into the indexing (``w[v_i1, v_k]``)
+    # of one NT_matmul PrimFunc, which main calls with (x, w).
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def NT_matmul(
+            x: T.Buffer((T.int64(128), T.int64(256)), "float32"),
+            w: T.Buffer((T.int64(128), T.int64(256)), "float32"),
+            NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)):
+                with T.block("NT_matmul"):
+                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                    T.reads(x[v_i0, v_k], w[v_i1, v_k])
+                    T.writes(NT_matmul[v_i0, v_i1])
+                    with T.init():
+                        NT_matmul[v_i0, v_i1] = T.float32(0)
+                    NT_matmul[v_i0, v_i1] = NT_matmul[v_i0, v_i1] + x[v_i0, 
v_k] * w[v_i1, v_k]
+
+        @R.function
+        def main(
+            x: R.Tensor((128, 256), dtype="float32"), w: R.Tensor((128, 256), 
dtype="float32")
+        ) -> R.Tensor((128, 128), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.NT_matmul, (x, w), out_sinfo=R.Tensor((128, 128), 
dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+    # Run the pass, then compare structurally against the hand-written module.
+    after = tvm.ir.transform.Sequential(
+        [
+            relax.transform.FuseTransposeMatmul(),
+            relax.transform.FuseTIR(),  # Only used for remove unused 
primitive function
+        ]
+    )(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to