This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 465994eb05 [Unity] Handle extern func calls in static memory planning (#14419)
465994eb05 is described below

commit 465994eb058e820a59e6d419a2aa87a4e37f897f
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Mar 29 02:24:23 2023 -0700

    [Unity] Handle extern func calls in static memory planning (#14419)
---
 src/relax/transform/static_plan_block_memory.cc    |  6 ++--
 .../test_transform_static_plan_block_memory.py     | 42 ++++++++++++++++++++++
 2 files changed, 46 insertions(+), 2 deletions(-)

diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 3f0dfcf149..952513db4c 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -325,10 +325,12 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
     }
 
     // - Increase the reference counters of the arguments when the callee is
-    // a PrimFunc of the context module.
+    // a PrimFunc of the context module or an external function via 'call_packed'.
+    // It assumes external function calls via 'call_packed' do not retain memory
+    // from the arguments.
     // - Otherwise, discard the tokens used by the arguments, as there might be
     // potential external reference.
-    if (IsPrimFuncGlobalVar(call->op)) {
+    if (IsPrimFuncGlobalVar(call->op) || call->op->IsInstance<ExternFuncNode>()) {
       ICHECK(!block_stack_.empty());
       for (const Expr& arg : call->args) {
         Tokens tokens = GetTokensWithAllocSiteCheck(arg, block_stack_.back());
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 5198d9e075..521fcc1924 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -586,6 +586,48 @@ def test_call_func_other_than_primfunc():
     tvm.ir.assert_structural_equal(mod, Module)
 
 
+def test_call_packed_external_func():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")):
+            alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _ = R.call_packed("extern_func", x, alloc, sinfo_args=[R.Tuple()])
+            y: R.Tensor((2, 3), dtype="float32") = alloc
+            alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _1 = R.call_packed("extern_func", y, alloc1, sinfo_args=[R.Tuple()])
+            z: R.Tensor((2, 3), dtype="float32") = alloc1
+            return z
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([24]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage, R.prim_value(0), R.shape([2, 3]), R.dtype("float32")
+            )
+            _: R.Tuple = R.call_packed("extern_func", x, alloc, sinfo_args=(R.Tuple(),))
+            y: R.Tensor((2, 3), dtype="float32") = alloc
+            alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), R.dtype("float32"), R.prim_value(0)
+            )
+            _1: R.Tuple = R.call_packed("extern_func", y, alloc1, sinfo_args=(R.Tuple(),))
+            _2: R.Tuple = R.memory.kill_tensor(alloc)
+            z: R.Tensor((2, 3), dtype="float32") = alloc1
+            _3: R.Tuple = R.memory.kill_storage(storage)
+            return z
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_symbolic_shape():
     @tvm.script.ir_module
     class Module:

Reply via email to