This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a0d0859ebe [FIX][RELAX] fix fusion of transpose + matmul when constant weight (#17761)
a0d0859ebe is described below

commit a0d0859ebe711716ce2eb7e57729136061361ccb
Author: PatrikPerssonInceptron <[email protected]>
AuthorDate: Wed Mar 19 08:44:11 2025 +0100

    [FIX][RELAX] fix fusion of transpose + matmul when constant weight (#17761)
    
    fix fusion of transpose + matmul when the weight is a constant
---
 .../tvm/relax/transform/fuse_transpose_matmul.py   |  3 +-
 .../relax/test_transform_fuse_transpose_matmul.py  | 54 ++++++++++++++++++++++
 2 files changed, 56 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/transform/fuse_transpose_matmul.py b/python/tvm/relax/transform/fuse_transpose_matmul.py
index 1d2324a28b..141f926cd3 100644
--- a/python/tvm/relax/transform/fuse_transpose_matmul.py
+++ b/python/tvm/relax/transform/fuse_transpose_matmul.py
@@ -41,7 +41,8 @@ class FuseTransposeMatmul:  # pylint: disable=too-few-public-methods
                     "transpose_matmul_fuse",
                     *_pattern(),
                 ),
-            ]
+            ],
+            bind_constants=False,
         )(mod)
         transpose_matmul_codegen = _TransposeMatmulFuser(mod)
         for g_var, func in mod.functions_items():
diff --git a/tests/python/relax/test_transform_fuse_transpose_matmul.py b/tests/python/relax/test_transform_fuse_transpose_matmul.py
index 4b2b1fff8a..446102dcbb 100644
--- a/tests/python/relax/test_transform_fuse_transpose_matmul.py
+++ b/tests/python/relax/test_transform_fuse_transpose_matmul.py
@@ -22,6 +22,7 @@ from tvm import relax
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
+import numpy as np
 
 
 def test_transform_fuse_transpose_matmul():
@@ -78,5 +79,58 @@ def test_transform_fuse_transpose_matmul():
     tvm.ir.assert_structural_equal(after, Expected)
 
 
+def test_transform_fuse_transpose_matmul_const():
+    w = relax.const(np.random.uniform(-1e-3, 1e-3, (128, 256)), "float32")
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((128, 256), "float32"),
+        ) -> R.Tensor((128, 128), "float32"):
+            with R.dataflow():
+                wT = R.permute_dims(w, [1, 0])
+                o = R.matmul(x, wT)
+                R.output(o)
+            return o
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def NT_matmul(
+            x: T.Buffer((T.int64(128), T.int64(256)), "float32"),
+            w: T.Buffer((T.int64(128), T.int64(256)), "float32"),
+            NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)):
+                with T.block("NT_matmul"):
+                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                    T.reads(x[v_i0, v_k], w[v_i1, v_k])
+                    T.writes(NT_matmul[v_i0, v_i1])
+                    with T.init():
+                        NT_matmul[v_i0, v_i1] = T.float32(0)
+                    NT_matmul[v_i0, v_i1] = NT_matmul[v_i0, v_i1] + x[v_i0, v_k] * w[v_i1, v_k]
+
+        @R.function
+        def main(x: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.NT_matmul, (x, w), out_sinfo=R.Tensor((128, 128), dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+    after = tvm.ir.transform.Sequential(
+        [
+            relax.transform.FuseTransposeMatmul(),
+            relax.transform.FuseTIR(),  # Only used for remove unused primitive function
+        ]
+    )(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to