This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 107ef2b65e [Unity][Relax] Memory planning for call_tir_dyn (#14750)
107ef2b65e is described below

commit 107ef2b65eb4b005d7d5380115f31fd99c159ae7
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue May 2 09:31:56 2023 -0400

    [Unity][Relax] Memory planning for call_tir_dyn (#14750)
    
    This PR supports memory planning for call_tir_dyn. Previously the pass
    did not take call_tir_dyn into account.
    
    One unit test is added, and the effect of this PR is validated on
    MLC-LLM.
---
 src/relax/transform/static_plan_block_memory.cc    |  6 +-
 .../test_transform_static_plan_block_memory.py     | 65 ++++++++++++++++++++++
 2 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 84e69f3d47..05cf498a4b 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -346,6 +346,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
 
   void VisitExpr_(const CallNode* call) final {
     static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");
     if (call->op == alloc_tensor_op) {
       // Create a storage token for builtin alloc_tensor.
       this->CreateToken(call);
@@ -362,7 +363,10 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
     // from the arguments.
     // - Otherwise, discard the tokens used by the arguments, as there might be
     // potential external reference.
-    if (IsPrimFuncGlobalVar(call->op) || call->op->IsInstance<ExternFuncNode>()) {
+    if (IsPrimFuncGlobalVar(call->op) || call->op->IsInstance<ExternFuncNode>() ||
+        call->op == call_tir_dyn_op) {
+      Array<Expr> args =
+          call->op == call_tir_dyn_op ? Downcast<Tuple>(call->args[1])->fields : call->args;
       ICHECK(!block_stack_.empty());
       for (const Expr& arg : call->args) {
         Tokens tokens = GetTokensWithAllocSiteCheck(arg, block_stack_.back());
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 06fdd04daa..85036db347 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1009,5 +1009,70 @@ def test_tir_var_decreasing_monotone():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_call_tir_dyn():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def tir_full(var_full: T.handle, n: T.int64):
+            T.evaluate(0)
+
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 20}})
+            cls = Module
+            alloc: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+            _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
+            full: R.Tensor((n,), dtype="float32") = alloc
+            alloc1: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+            _1: R.Tuple = cls.tir_exp(full, alloc1)
+            lv2: R.Tensor((n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+            _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+            lv3: R.Tensor((n,), dtype="float32") = alloc2
+            return lv3
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def tir_full(var_full: T.handle, n: T.int64):
+            T.evaluate(0)
+
+        @R.function
+        def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 20}})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), R.dtype("float32"))
+            _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
+            full: R.Tensor((n,), dtype="float32") = alloc
+            storage1: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n]), R.dtype("float32"))
+            _1: R.Tuple = cls.tir_exp(full, alloc1)
+            __1: R.Tuple = R.memory.kill_tensor(alloc)
+            lv2: R.Tensor((n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+            _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+            _1_1: R.Tuple = R.memory.kill_tensor(alloc1)
+            lv3: R.Tensor((n,), dtype="float32") = alloc2
+            _2_1: R.Tuple = R.memory.kill_storage(storage)
+            _3: R.Tuple = R.memory.kill_storage(storage1)
+            return lv3
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to