This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new d6c8f7933b [Unity] Preserve symbolic var args when applying call_tir (#14555)
d6c8f7933b is described below

commit d6c8f7933bf5cd2d92d8e048d10cdaf18db6e3b8
Author: Hongyi Jin <[email protected]>
AuthorDate: Mon Apr 10 04:36:44 2023 -0400

    [Unity] Preserve symbolic var args when applying call_tir (#14555)
    
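    Currently, fuse_tir drops the symbolic var args (the `tir_vars` shape) of
    call_tir when it re-adds the callee PrimFunc. This PR fixes this behavior by
    preserving all of the call's original arguments.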
---
 src/relax/transform/fuse_tir.cc               |  4 ++-
 tests/python/relax/test_transform_fuse_tir.py | 40 +++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 5ddda93705..f5a05a0539 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -794,7 +794,9 @@ class TIRFuseMutator : public ExprMutator {
       if (const auto* gv = call->args[0].as<GlobalVarNode>()) {
         tir::PrimFunc func = 
Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(gv)));
         GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
-        return Call(call->op, {new_gv, call->args[1]}, call->attrs, 
call->sinfo_args, call->span);
+        Array<Expr> new_args = call->args;
+        new_args.Set(0, new_gv);
+        return Call(call->op, new_args, call->attrs, call->sinfo_args, 
call->span);
       }
     }
 
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index bdbd9be966..47480346e7 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -793,5 +793,45 @@ def test_symbolic_shape_aware_fuse_with_allocation():
     _check(Before, Expected)
 
 
+def test_symbolic_var_in_call_tir_args():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def foo(
+            rxplaceholder: T.Buffer((1, 1, 32, 128), "float32"),
+            rxplaceholder_1: T.Buffer((2048, 128), "float32"),
+            rotary: T.Buffer((1, 1, 32, 128), "float32"),
+            m: T.int64,
+        ):
+            # with T.block("root"):
+            for i0, i1, i2, i3 in T.grid(1, 1, 32, 128):
+                with T.block("rotary"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    rotary[v_i0, v_i1, v_i2, v_i3] = (
+                        rxplaceholder_1[m + v_i1 - 1, v_i3] * 
rxplaceholder[v_i0, v_i1, v_i2, v_i3]
+                    )
+
+        @R.function
+        def main(
+            x: R.Tensor((1, 1, 32, 128), dtype="float32"),
+            y: R.Tensor((2048, 128), dtype="float32"),
+            len: R.Shape(["m"]),
+        ):
+            m = T.int64()
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.foo,
+                    [x, y],
+                    out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float32"),
+                    tir_vars=R.shape([m]),
+                )
+                R.output(gv)
+            return gv
+
+    Expected = Before
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to