This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new dd2f452e4d [Unity][Transform] Fix bug for tir expression in shape in fuse_tir (#14931)
dd2f452e4d is described below

commit dd2f452e4d7da1e673267b67f1b0f27cdc35931d
Author: Yixin Dong <[email protected]>
AuthorDate: Fri May 26 20:23:31 2023 +0800

    [Unity][Transform] Fix bug for tir expression in shape in fuse_tir (#14931)
---
 src/relax/transform/fuse_tir.cc               |  3 +-
 tests/python/relax/test_transform_fuse_tir.py | 76 +++++++++++++++++++++++++++
 2 files changed, 77 insertions(+), 2 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index d5dcd64cc7..601a4cff23 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -450,8 +450,7 @@ class FusedTIRConstructor : public ExprVisitor {
         ICHECK_GE(num_params, vars.size());
         for (size_t i = 0; i < vars.size(); ++i) {
           const tir::Var& param = prim_func->params[num_params - vars.size() + i];
-          ICHECK(!func_info_.symbolic_var_remap.count(param));
-          func_info_.symbolic_var_remap.Set(param, vars[i]);
+          func_info_.symbolic_var_matcher.Match(param, vars[i]);
         }
       } else {
         LOG(FATAL) << "TIR vars should be a shape expr, but got: " << tir_vars->GetTypeKey();
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index aabbd544bd..af770e0fc6 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -1003,5 +1003,81 @@ def test_same_buffer_multiple_read():
     _check(Module, Expected)
 
 
+def test_tir_expression_in_shape():
+    @I.ir_module
+    class Module:
+        @R.function
+        def fused_transpose_matmul(
+            x: R.Tensor((3, 4), dtype="float32"),
+            y: R.Tensor(("n - 1", 4), dtype="float32"),
+            tir_vars: R.Shape(["n"]),
+        ) -> R.Tensor(("n - 1", 3), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            with R.dataflow():
+                lv = R.emit_te(topi.transpose, x)
+                gv = R.emit_te(topi.matmul, y, lv)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((3, 4), dtype="float32"),
+            y: R.Tensor(("n - 1", 4), dtype="float32"),
+            tir_vars: R.Shape(["n"]),
+        ) -> R.Tensor(("n - 1", 3), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                lv = cls.fused_transpose_matmul(x, y, tir_vars)
+                R.output(lv)
+            return lv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def fused_transpose_matmul(
+            x: T.Buffer((T.int64(3), T.int64(4)), "float32"),
+            p_y: T.handle,
+            p_output0: T.handle,
+            n: T.int64,
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            y = T.match_buffer(p_y, (n - T.int64(1), T.int64(4)))
+            var_T_matmul_intermediate = T.match_buffer(p_output0, (n - T.int64(1), T.int64(3)))
+            var_T_transpose_intermediate = T.alloc_buffer((T.int64(4), T.int64(3)))
+            for ax0, ax1 in T.grid(T.int64(4), T.int64(3)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    var_T_transpose_intermediate[v_ax0, v_ax1] = x[v_ax1, v_ax0]
+            for ax0, ax1, k in T.grid(n - T.int64(1), T.int64(3), T.int64(4)):
+                with T.block("T_matmul"):
+                    v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
+                    with T.init():
+                        var_T_matmul_intermediate[v_ax0, v_ax1] = T.float32(0)
+                    var_T_matmul_intermediate[v_ax0, v_ax1] = (
+                        var_T_matmul_intermediate[v_ax0, v_ax1]
+                        + y[v_ax0, v_k] * var_T_transpose_intermediate[v_k, v_ax1]
+                    )
+
+        @R.function
+        def main(
+            x: R.Tensor((3, 4), dtype="float32"),
+            y: R.Tensor(("n - 1", 4), dtype="float32"),
+            tir_vars: R.Shape(["n"]),
+        ) -> R.Tensor(("n - 1", 3), dtype="float32"):
+            n = T.int64()
+            cls = Expected
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.fused_transpose_matmul,
+                    (x, y),
+                    out_sinfo=R.Tensor((n - 1, 3), dtype="float32"),
+                    tir_vars=R.shape([n]),
+                )
+                R.output(lv)
+            return lv
+
+    _check(Module, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to