This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new dd2f452e4d [Unity][Transform] Fix bug for tir expression in shape in fuse_tir (#14931)
dd2f452e4d is described below
commit dd2f452e4d7da1e673267b67f1b0f27cdc35931d
Author: Yixin Dong <[email protected]>
AuthorDate: Fri May 26 20:23:31 2023 +0800
[Unity][Transform] Fix bug for tir expression in shape in fuse_tir (#14931)
---
src/relax/transform/fuse_tir.cc | 3 +-
tests/python/relax/test_transform_fuse_tir.py | 76 +++++++++++++++++++++++++++
2 files changed, 77 insertions(+), 2 deletions(-)
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index d5dcd64cc7..601a4cff23 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -450,8 +450,7 @@ class FusedTIRConstructor : public ExprVisitor {
ICHECK_GE(num_params, vars.size());
for (size_t i = 0; i < vars.size(); ++i) {
const tir::Var& param = prim_func->params[num_params - vars.size() + i];
- ICHECK(!func_info_.symbolic_var_remap.count(param));
- func_info_.symbolic_var_remap.Set(param, vars[i]);
+ func_info_.symbolic_var_matcher.Match(param, vars[i]);
}
} else {
LOG(FATAL) << "TIR vars should be a shape expr, but got: " << tir_vars->GetTypeKey();
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index aabbd544bd..af770e0fc6 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -1003,5 +1003,81 @@ def test_same_buffer_multiple_read():
_check(Module, Expected)
+def test_tir_expression_in_shape():
+ @I.ir_module
+ class Module:
+ @R.function
+ def fused_transpose_matmul(
+ x: R.Tensor((3, 4), dtype="float32"),
+ y: R.Tensor(("n - 1", 4), dtype="float32"),
+ tir_vars: R.Shape(["n"]),
+ ) -> R.Tensor(("n - 1", 3), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ with R.dataflow():
+ lv = R.emit_te(topi.transpose, x)
+ gv = R.emit_te(topi.matmul, y, lv)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor((3, 4), dtype="float32"),
+ y: R.Tensor(("n - 1", 4), dtype="float32"),
+ tir_vars: R.Shape(["n"]),
+ ) -> R.Tensor(("n - 1", 3), dtype="float32"):
+ cls = Module
+ with R.dataflow():
+ lv = cls.fused_transpose_matmul(x, y, tir_vars)
+ R.output(lv)
+ return lv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def fused_transpose_matmul(
+ x: T.Buffer((T.int64(3), T.int64(4)), "float32"),
+ p_y: T.handle,
+ p_output0: T.handle,
+ n: T.int64,
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ y = T.match_buffer(p_y, (n - T.int64(1), T.int64(4)))
+ var_T_matmul_intermediate = T.match_buffer(p_output0, (n -
T.int64(1), T.int64(3)))
+ var_T_transpose_intermediate = T.alloc_buffer((T.int64(4),
T.int64(3)))
+ for ax0, ax1 in T.grid(T.int64(4), T.int64(3)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ var_T_transpose_intermediate[v_ax0, v_ax1] = x[v_ax1,
v_ax0]
+ for ax0, ax1, k in T.grid(n - T.int64(1), T.int64(3), T.int64(4)):
+ with T.block("T_matmul"):
+ v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
+ with T.init():
+ var_T_matmul_intermediate[v_ax0, v_ax1] = T.float32(0)
+ var_T_matmul_intermediate[v_ax0, v_ax1] = (
+ var_T_matmul_intermediate[v_ax0, v_ax1]
+ + y[v_ax0, v_k] * var_T_transpose_intermediate[v_k,
v_ax1]
+ )
+
+ @R.function
+ def main(
+ x: R.Tensor((3, 4), dtype="float32"),
+ y: R.Tensor(("n - 1", 4), dtype="float32"),
+ tir_vars: R.Shape(["n"]),
+ ) -> R.Tensor(("n - 1", 3), dtype="float32"):
+ n = T.int64()
+ cls = Expected
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.fused_transpose_matmul,
+ (x, y),
+ out_sinfo=R.Tensor((n - 1, 3), dtype="float32"),
+ tir_vars=R.shape([n]),
+ )
+ R.output(lv)
+ return lv
+
+ _check(Module, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()