This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3a02309ed8 [Relax] FuseTransposeMatmul Pass (#17234)
3a02309ed8 is described below
commit 3a02309ed85d308da1b1af127bc97b5b22589a43
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri Aug 2 22:14:32 2024 +0800
[Relax] FuseTransposeMatmul Pass (#17234)
Introduce a new pass to fuse transpose and matmul, which is intended specifically for
`Linear` ops in PyTorch and NNModule. Note that this pass is migrated
from MLC-LLM.
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
---
python/tvm/relax/transform/__init__.py | 1 +
.../tvm/relax/transform/fuse_transpose_matmul.py | 175 +++++++++++++++++++++
.../relax/test_transform_fuse_transpose_matmul.py | 82 ++++++++++
3 files changed, 258 insertions(+)
diff --git a/python/tvm/relax/transform/__init__.py
b/python/tvm/relax/transform/__init__.py
index 5e76fff6bd..5789e2fcf2 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -90,6 +90,7 @@ from .lower_gpu_ipc_alloc_storage import
LowerGPUIPCAllocStorage
from .optimize_layout_transform import OptimizeLayoutTransform
from .remove_redundant_reshape import RemoveRedundantReshape
from .fast_math import FastMathTransform
+from .fuse_transpose_matmul import FuseTransposeMatmul
from .attach_external_modules import AttachExternModules
# Import to register the legalization functions.
diff --git a/python/tvm/relax/transform/fuse_transpose_matmul.py
b/python/tvm/relax/transform/fuse_transpose_matmul.py
new file mode 100644
index 0000000000..1d2324a28b
--- /dev/null
+++ b/python/tvm/relax/transform/fuse_transpose_matmul.py
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""A compiler pass that fuses transpose + matmul and generates a TIR function.
+Note that
+1. Please put the pass before LegalizeOps pass.
+2. The pass only works for XW^T but not X^TW
+3. The pass would rewrite the relax ops into TIR functions. If you'd like to
dispatch the
+ ops into library (e.g. cuBLAS) calls, please run dispatch pass before this
pass.
+"""
+
+import tvm
+from tvm import IRModule, relax, te, tir
+from tvm.relax.dpl.pattern import is_op, wildcard
+from tvm.relax.expr_functor import PyExprMutator, mutator
+
+
[email protected]_pass(opt_level=0, name="FuseTransposeMatmul")
+class FuseTransposeMatmul: # pylint: disable=too-few-public-methods
+ """A compiler pass that fuses transpose + matmul."""
+
+ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext)
-> IRModule:
+ """IRModule-level transformation"""
+ mod = relax.transform.FuseOpsByPattern(
+ [
+ (
+ "transpose_matmul_fuse",
+ *_pattern(),
+ ),
+ ]
+ )(mod)
+ transpose_matmul_codegen = _TransposeMatmulFuser(mod)
+ for g_var, func in mod.functions_items():
+ if isinstance(func, relax.Function):
+ func = transpose_matmul_codegen.visit_expr(func)
+ transpose_matmul_codegen.builder_.update_func(g_var, func)
+ return transpose_matmul_codegen.builder_.get()
+
+
+def _pattern():
+ """Pattern for transpose + matmul."""
+ # pylint: disable=invalid-name
+ w = wildcard()
+ x = wildcard()
+ wT = is_op("relax.permute_dims")(w)
+ o = is_op("relax.matmul")(x, wT)
+ # pylint: enable=invalid-name
+ annotations = {"o": o, "w": w, "x": x, "wT": wT}
+
+ def _check(context: relax.transform.PatternCheckContext) -> bool:
+ transpose_call = context.annotated_expr["wT"]
+ ndim = transpose_call.args[0].struct_info.ndim
+ if ndim == -1:
+ return False
+ if ndim == 2 and transpose_call.attrs.axes is None:
+ return True
+ axes = list(range(ndim))
+ axes[-1], axes[-2] = axes[-2], axes[-1]
+ return list(transpose_call.attrs.axes) == axes
+
+ return o, annotations, _check
+
+
+# pylint: disable=missing-docstring,invalid-name
+
+
+@mutator
+class _TransposeMatmulFuser(PyExprMutator): # pylint: disable=abstract-method
+ def __init__(self, mod):
+ super().__init__(mod)
+
+ def visit_call_( # pylint: disable=arguments-renamed
+ self,
+ call: relax.Call,
+ ) -> relax.Expr:
+ out_dtype = None
+
+ def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor:
+ nonlocal out_dtype
+ a_shape = list(a.shape)
+ b_shape = list(b.shape)
+ a_prepended = False
+ b_appended = False
+ if len(a_shape) == 1:
+ a_prepended = True
+ a_shape.insert(0, 1)
+ if len(b_shape) == 1:
+ b_appended = True
+ b_shape.append(1)
+
+ is_a_larger = len(a_shape) > len(b_shape)
+ offset = len(a_shape) - len(b_shape) if is_a_larger else
len(b_shape) - len(a_shape)
+
+ a_relax = relax.Var("a", relax.TensorStructInfo(a.shape))
+ bT_shape = list(b.shape)
+ bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1]
+ bT_relax = relax.Var("b", relax.TensorStructInfo(bT_shape))
+ output_shape = self.builder_.normalize(
+ relax.op.matmul(a_relax, bT_relax)
+ ).struct_info.shape
+
+ def matmul_compute(*idx_spatial):
+ k = te.reduce_axis((0, a_shape[-1]), name="k")
+
+ def multiply_compute(idx_reduce):
+ a_indices = []
+ b_indices = []
+
+ for i in range(offset):
+ if is_a_larger:
+ a_indices.append(idx_spatial[i])
+ else:
+ b_indices.append(idx_spatial[i])
+ for i in range(offset, len(output_shape) - (2 -
a_prepended - b_appended)):
+ a_dim = a_shape[i if is_a_larger else i - offset]
+ b_dim = b_shape[i if not is_a_larger else i - offset]
+ dim_equal = a_dim == b_dim
+ if not isinstance(dim_equal, tir.IntImm) or dim_equal
== 0:
+ a_dim_is_one = isinstance(a_dim, tir.IntImm) and
a_dim == 1
+ b_dim_is_one = isinstance(b_dim, tir.IntImm) and
b_dim == 1
+ a_indices.append(0 if a_dim_is_one else
idx_spatial[i])
+ b_indices.append(0 if b_dim_is_one else
idx_spatial[i])
+ else:
+ a_indices.append(idx_spatial[i])
+ b_indices.append(idx_spatial[i])
+
+ if not a_prepended:
+ a_indices.append(idx_spatial[-2 + b_appended])
+ a_indices.append(idx_reduce)
+ if not b_appended:
+ b_indices.append(idx_spatial[-1])
+ b_indices.append(idx_reduce)
+
+ dtype = out_dtype
+ if dtype != "":
+ return a(*a_indices).astype(dtype) *
b(*b_indices).astype(dtype)
+ return a(*a_indices) * b(*b_indices)
+
+ return te.sum(multiply_compute(k), axis=k)
+
+ return te.compute(
+ output_shape,
+ lambda *idx: matmul_compute(*idx), # pylint:
disable=unnecessary-lambda
+ name="NT_matmul",
+ )
+
+ if isinstance(call.op, relax.GlobalVar):
+ function = self.builder_.get()[call.op]
+ if (
+ "Composite" in function.attrs
+ and function.attrs["Composite"] == "transpose_matmul_fuse"
+ ):
+ out_dtype = function.ret_struct_info.dtype
+ return self.builder_.call_te(
+ te_transposed_matmul,
+ call.args[1],
+ call.args[0],
+ primfunc_name_hint="NT_matmul",
+ )
+
+ return super().visit_call_(call)
diff --git a/tests/python/relax/test_transform_fuse_transpose_matmul.py
b/tests/python/relax/test_transform_fuse_transpose_matmul.py
new file mode 100644
index 0000000000..4b2b1fff8a
--- /dev/null
+++ b/tests/python/relax/test_transform_fuse_transpose_matmul.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, missing-docstring
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_transform_fuse_transpose_matmul():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((128, 256), "float32"),
+ w: R.Tensor((128, 256), "float32"),
+ ) -> R.Tensor((128, 128), "float32"):
+ with R.dataflow():
+ wT = R.permute_dims(w, [1, 0])
+ o = R.matmul(x, wT)
+ R.output(o)
+ return o
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def NT_matmul(
+ x: T.Buffer((T.int64(128), T.int64(256)), "float32"),
+ w: T.Buffer((T.int64(128), T.int64(256)), "float32"),
+ NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)):
+ with T.block("NT_matmul"):
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+ T.reads(x[v_i0, v_k], w[v_i1, v_k])
+ T.writes(NT_matmul[v_i0, v_i1])
+ with T.init():
+ NT_matmul[v_i0, v_i1] = T.float32(0)
+ NT_matmul[v_i0, v_i1] = NT_matmul[v_i0, v_i1] + x[v_i0,
v_k] * w[v_i1, v_k]
+
+ @R.function
+ def main(
+ x: R.Tensor((128, 256), dtype="float32"), w: R.Tensor((128, 256),
dtype="float32")
+ ) -> R.Tensor((128, 128), dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ gv = R.call_tir(
+ cls.NT_matmul, (x, w), out_sinfo=R.Tensor((128, 128),
dtype="float32")
+ )
+ R.output(gv)
+ return gv
+
+ after = tvm.ir.transform.Sequential(
+ [
+ relax.transform.FuseTransposeMatmul(),
+ relax.transform.FuseTIR(), # Only used for remove unused
primitive function
+ ]
+ )(Before)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()