This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d6c8f7933b [Unity] Preserve symbolic var args when applying call_tir (#14555)
d6c8f7933b is described below
commit d6c8f7933bf5cd2d92d8e048d10cdaf18db6e3b8
Author: Hongyi Jin <[email protected]>
AuthorDate: Mon Apr 10 04:36:44 2023 -0400
[Unity] Preserve symbolic var args when applying call_tir (#14555)
Currently fuse_tir will remove the symbolic var args of call_tir. This PR
fixes this behavior.
---
src/relax/transform/fuse_tir.cc | 4 ++-
tests/python/relax/test_transform_fuse_tir.py | 40 +++++++++++++++++++++++++++
2 files changed, 43 insertions(+), 1 deletion(-)
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 5ddda93705..f5a05a0539 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -794,7 +794,9 @@ class TIRFuseMutator : public ExprMutator {
if (const auto* gv = call->args[0].as<GlobalVarNode>()) {
tir::PrimFunc func =
Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(gv)));
GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
- return Call(call->op, {new_gv, call->args[1]}, call->attrs, call->sinfo_args, call->span);
+ Array<Expr> new_args = call->args;
+ new_args.Set(0, new_gv);
+ return Call(call->op, new_args, call->attrs, call->sinfo_args, call->span);
}
}
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index bdbd9be966..47480346e7 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -793,5 +793,45 @@ def test_symbolic_shape_aware_fuse_with_allocation():
_check(Before, Expected)
+def test_symbolic_var_in_call_tir_args():
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def foo(
+ rxplaceholder: T.Buffer((1, 1, 32, 128), "float32"),
+ rxplaceholder_1: T.Buffer((2048, 128), "float32"),
+ rotary: T.Buffer((1, 1, 32, 128), "float32"),
+ m: T.int64,
+ ):
+ # with T.block("root"):
+ for i0, i1, i2, i3 in T.grid(1, 1, 32, 128):
+ with T.block("rotary"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ rotary[v_i0, v_i1, v_i2, v_i3] = (
+ rxplaceholder_1[m + v_i1 - 1, v_i3] * rxplaceholder[v_i0, v_i1, v_i2, v_i3]
+ )
+
+ @R.function
+ def main(
+ x: R.Tensor((1, 1, 32, 128), dtype="float32"),
+ y: R.Tensor((2048, 128), dtype="float32"),
+ len: R.Shape(["m"]),
+ ):
+ m = T.int64()
+ cls = Before
+ with R.dataflow():
+ gv = R.call_tir(
+ cls.foo,
+ [x, y],
+ out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float32"),
+ tir_vars=R.shape([m]),
+ )
+ R.output(gv)
+ return gv
+
+ Expected = Before
+ _check(Before, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()