This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 107ef2b65e [Unity][Relax] Memory planning for call_tir_dyn (#14750)
107ef2b65e is described below
commit 107ef2b65eb4b005d7d5380115f31fd99c159ae7
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue May 2 09:31:56 2023 -0400
[Unity][Relax] Memory planning for call_tir_dyn (#14750)
This PR supports memory planning for call_tir_dyn. Previously the pass
did not take call_tir_dyn into account.
One unit test is added, and the effect of this PR is validated on
MLC-LLM.
---
src/relax/transform/static_plan_block_memory.cc | 6 +-
.../test_transform_static_plan_block_memory.py | 65 ++++++++++++++++++++++
2 files changed, 70 insertions(+), 1 deletion(-)
diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 84e69f3d47..05cf498a4b 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -346,6 +346,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
void VisitExpr_(const CallNode* call) final {
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+ static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");
if (call->op == alloc_tensor_op) {
// Create a storage token for builtin alloc_tensor.
this->CreateToken(call);
@@ -362,7 +363,10 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
// from the arguments.
// - Otherwise, discard the tokens used by the arguments, as there might be
// potential external reference.
-    if (IsPrimFuncGlobalVar(call->op) || call->op->IsInstance<ExternFuncNode>()) {
+    if (IsPrimFuncGlobalVar(call->op) || call->op->IsInstance<ExternFuncNode>() ||
+        call->op == call_tir_dyn_op) {
+      Array<Expr> args =
+          call->op == call_tir_dyn_op ? Downcast<Tuple>(call->args[1])->fields : call->args;
ICHECK(!block_stack_.empty());
for (const Expr& arg : call->args) {
Tokens tokens = GetTokensWithAllocSiteCheck(arg, block_stack_.back());
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 06fdd04daa..85036db347 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1009,5 +1009,70 @@ def test_tir_var_decreasing_monotone():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_call_tir_dyn():
+ # fmt: off
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def tir_full(var_full: T.handle, n: T.int64):
+ T.evaluate(0)
+
+ @T.prim_func
+ def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+ T.evaluate(0)
+
+ @R.function
+ def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
+ n = T.int64()
+ R.func_attr({"tir_var_upper_bound": {"n": 20}})
+ cls = Module
+            alloc: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+ _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
+ full: R.Tensor((n,), dtype="float32") = alloc
+            alloc1: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+ _1: R.Tuple = cls.tir_exp(full, alloc1)
+ lv2: R.Tensor((n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+ _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+ lv3: R.Tensor((n,), dtype="float32") = alloc2
+ return lv3
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def tir_full(var_full: T.handle, n: T.int64):
+ T.evaluate(0)
+
+ @R.function
+ def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
+ n = T.int64()
+ R.func_attr({"tir_var_upper_bound": {"n": 20}})
+ cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), R.dtype("float32"))
+ _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
+ full: R.Tensor((n,), dtype="float32") = alloc
+            storage1: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n]), R.dtype("float32"))
+ _1: R.Tuple = cls.tir_exp(full, alloc1)
+ __1: R.Tuple = R.memory.kill_tensor(alloc)
+ lv2: R.Tensor((n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+ _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+ _1_1: R.Tuple = R.memory.kill_tensor(alloc1)
+ lv3: R.Tensor((n,), dtype="float32") = alloc2
+ _2_1: R.Tuple = R.memory.kill_storage(storage)
+ _3: R.Tuple = R.memory.kill_storage(storage1)
+ return lv3
+ # fmt: on
+
+ mod = relax.transform.StaticPlanBlockMemory()(Module)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()