This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a0d0859ebe [FIX][RELAX] fix fusion of transpose + matmul when constant weight (#17761)
a0d0859ebe is described below
commit a0d0859ebe711716ce2eb7e57729136061361ccb
Author: PatrikPerssonInceptron <[email protected]>
AuthorDate: Wed Mar 19 08:44:11 2025 +0100
[FIX][RELAX] fix fusion of transpose + matmul when constant weight (#17761)
fix fusion of transpose + matmul when the weight is a constant
---
.../tvm/relax/transform/fuse_transpose_matmul.py | 3 +-
.../relax/test_transform_fuse_transpose_matmul.py | 54 ++++++++++++++++++++++
2 files changed, 56 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/transform/fuse_transpose_matmul.py b/python/tvm/relax/transform/fuse_transpose_matmul.py
index 1d2324a28b..141f926cd3 100644
--- a/python/tvm/relax/transform/fuse_transpose_matmul.py
+++ b/python/tvm/relax/transform/fuse_transpose_matmul.py
@@ -41,7 +41,8 @@ class FuseTransposeMatmul:  # pylint: disable=too-few-public-methods
"transpose_matmul_fuse",
*_pattern(),
),
- ]
+ ],
+ bind_constants=False,
)(mod)
transpose_matmul_codegen = _TransposeMatmulFuser(mod)
for g_var, func in mod.functions_items():
diff --git a/tests/python/relax/test_transform_fuse_transpose_matmul.py b/tests/python/relax/test_transform_fuse_transpose_matmul.py
index 4b2b1fff8a..446102dcbb 100644
--- a/tests/python/relax/test_transform_fuse_transpose_matmul.py
+++ b/tests/python/relax/test_transform_fuse_transpose_matmul.py
@@ -22,6 +22,7 @@ from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
+import numpy as np
def test_transform_fuse_transpose_matmul():
@@ -78,5 +79,58 @@ def test_transform_fuse_transpose_matmul():
tvm.ir.assert_structural_equal(after, Expected)
+def test_transform_fuse_transpose_matmul_const():
+ w = relax.const(np.random.uniform(-1e-3, 1e-3, (128, 256)), "float32")
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((128, 256), "float32"),
+ ) -> R.Tensor((128, 128), "float32"):
+ with R.dataflow():
+ wT = R.permute_dims(w, [1, 0])
+ o = R.matmul(x, wT)
+ R.output(o)
+ return o
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def NT_matmul(
+ x: T.Buffer((T.int64(128), T.int64(256)), "float32"),
+ w: T.Buffer((T.int64(128), T.int64(256)), "float32"),
+ NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)):
+ with T.block("NT_matmul"):
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+ T.reads(x[v_i0, v_k], w[v_i1, v_k])
+ T.writes(NT_matmul[v_i0, v_i1])
+ with T.init():
+ NT_matmul[v_i0, v_i1] = T.float32(0)
+                        NT_matmul[v_i0, v_i1] = NT_matmul[v_i0, v_i1] + x[v_i0, v_k] * w[v_i1, v_k]
+
+ @R.function
+    def main(x: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ gv = R.call_tir(
+                cls.NT_matmul, (x, w), out_sinfo=R.Tensor((128, 128), dtype="float32")
+ )
+ R.output(gv)
+ return gv
+
+ after = tvm.ir.transform.Sequential(
+ [
+ relax.transform.FuseTransposeMatmul(),
+            relax.transform.FuseTIR(),  # Only used for remove unused primitive function
+ ]
+ )(Before)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()