This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 465994eb05 [Unity] Handle extern func calls in static memory planning (#14419)
465994eb05 is described below
commit 465994eb058e820a59e6d419a2aa87a4e37f897f
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Mar 29 02:24:23 2023 -0700
[Unity] Handle extern func calls in static memory planning (#14419)
---
src/relax/transform/static_plan_block_memory.cc | 6 ++--
.../test_transform_static_plan_block_memory.py | 42 ++++++++++++++++++++++
2 files changed, 46 insertions(+), 2 deletions(-)
diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 3f0dfcf149..952513db4c 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -325,10 +325,12 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
}
// - Increase the reference counters of the arguments when the callee is
- // a PrimFunc of the context module.
+ // a PrimFunc of the context module or an external function via 'call_packed'.
+ // It assumes external function calls via 'call_packed' do not retain memory
+ // from the arguments.
// - Otherwise, discard the tokens used by the arguments, as there might be
// potential external reference.
- if (IsPrimFuncGlobalVar(call->op)) {
+ if (IsPrimFuncGlobalVar(call->op) || call->op->IsInstance<ExternFuncNode>()) {
ICHECK(!block_stack_.empty());
for (const Expr& arg : call->args) {
Tokens tokens = GetTokensWithAllocSiteCheck(arg, block_stack_.back());
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 5198d9e075..521fcc1924 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -586,6 +586,48 @@ def test_call_func_other_than_primfunc():
tvm.ir.assert_structural_equal(mod, Module)
+def test_call_packed_external_func():
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")):
+ alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+ R.shape([2, 3]), dtype="float32", runtime_device_index=0
+ )
+ _ = R.call_packed("extern_func", x, alloc, sinfo_args=[R.Tuple()])
+ y: R.Tensor((2, 3), dtype="float32") = alloc
+ alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+ R.shape([2, 3]), dtype="float32", runtime_device_index=0
+ )
+ _1 = R.call_packed("extern_func", y, alloc1, sinfo_args=[R.Tuple()])
+ z: R.Tensor((2, 3), dtype="float32") = alloc1
+ return z
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+ storage: R.Object = R.memory.alloc_storage(
+ R.shape([24]), R.prim_value(0), R.str("global"), R.dtype("float32")
+ )
+ alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+ storage, R.prim_value(0), R.shape([2, 3]), R.dtype("float32")
+ )
+ _: R.Tuple = R.call_packed("extern_func", x, alloc, sinfo_args=(R.Tuple(),))
+ y: R.Tensor((2, 3), dtype="float32") = alloc
+ alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+ R.shape([2, 3]), R.dtype("float32"), R.prim_value(0)
+ )
+ _1: R.Tuple = R.call_packed("extern_func", y, alloc1, sinfo_args=(R.Tuple(),))
+ _2: R.Tuple = R.memory.kill_tensor(alloc)
+ z: R.Tensor((2, 3), dtype="float32") = alloc1
+ _3: R.Tuple = R.memory.kill_storage(storage)
+ return z
+
+ mod = relax.transform.StaticPlanBlockMemory()(Module)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_symbolic_shape():
@tvm.script.ir_module
class Module: