This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2f7e0d578f [Unity] Ensure memory planning cross-function independence 
(#16318)
2f7e0d578f is described below

commit 2f7e0d578f777a630bb6e4a79d9c2ec52b7be461
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Dec 31 13:10:02 2023 -0500

    [Unity] Ensure memory planning cross-function independence (#16318)
    
    Prior to this PR, the memory planning for different Relax functions
    is not independent -- storage tokens are shared across different
    Relax functions.
    
    This will incur memory overuse sometimes. For example, tensor `A`
    in `func1` has 128 bytes, tensor `B` in `func2` has 2048 bytes.
    If the memory planning decides to share the storage token for `A`
    and `B`, the shared token will have size 2048 bytes.
    
    Consider the case when at runtime only `func1` is executed, and
    `func2` is never invoked. In this case, only 128 bytes for tensor
    `A` are needed, while a 2048-byte chunk is allocated in total,
    which is a 16x memory overuse.
    
    This PR makes the memory planning across different Relax functions
    independent. That means in the example above, when only `func1`
    is executed, only 128 bytes will be allocated.
---
 src/relax/transform/static_plan_block_memory.cc    |  8 +++
 .../test_transform_static_plan_block_memory.py     | 73 +++++++++++++++++++++-
 2 files changed, 80 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/static_plan_block_memory.cc 
b/src/relax/transform/static_plan_block_memory.cc
index 4a2a1555ff..3873f624ef 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -219,6 +219,12 @@ class TokenAllocator1D {
     available_pool_[token->dtype].insert({token->bytes, token});
   }
 
+  /*! \brief Clear the allocator. */
+  void Clear() {
+    available_pool_.clear();
+    full_pool_.clear();
+  }
+
  private:
   /*! \brief A constant scale representing the token search range. */
   const int match_range_{16};
@@ -569,6 +575,8 @@ class StorageAllocator : public StorageAllocatorBaseVisitor 
{
       if (func == nullptr) {
         continue;
       }
+      // Clear the allocator to make the planning of different functions 
independent.
+      allocator_.Clear();
       this->VisitExpr_(func);
     }
   }
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py 
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 0c24f90efc..f12b5b9fc1 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -18,7 +18,9 @@
 import tvm
 import tvm.testing
 from tvm import relax
-from tvm.script import ir as I, relax as R, tir as T
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 def test_basic():
@@ -1105,5 +1107,74 @@ def test_call_tir_dyn():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_function_independence():
+    # fmt: off
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def exp(A: T.handle, B: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def func1(x: R.Tensor((8,), dtype="float32")) -> R.Tensor((8,), 
dtype="float32"):
+            R.func_attr({"relax.force_pure": 1})
+            cls = Module
+            alloc: R.Tensor((8,), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([8,]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = cls.exp(x, alloc)
+            lv: R.Tensor((8,), dtype="float32") = alloc
+            alloc1: R.Tensor((8,), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([8,]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = cls.exp(lv, alloc1)
+            gv: R.Tensor((8,), dtype="float32") = alloc1
+            return gv
+
+        @R.function
+        def func2(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            R.func_attr({"relax.force_pure": 1})
+            cls = Module
+            alloc: R.Tensor((10,), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([10,]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = cls.exp(x, alloc)
+            lv: R.Tensor((10,), dtype="float32") = alloc
+            alloc1: R.Tensor((10,), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([10,]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = cls.exp(lv, alloc1)
+            gv: R.Tensor((10,), dtype="float32") = alloc1
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def exp(A: T.handle, B: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def func1(x: R.Tensor((8,), dtype="float32")) -> R.Tensor((8,), 
dtype="float32"):
+            R.func_attr({"relax.force_pure": 1})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([32]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((8,), dtype="float32") = 
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([8]), 
R.dtype("float32"))
+            _: R.Tuple = cls.exp(x, alloc)
+            lv: R.Tensor((8,), dtype="float32") = alloc
+            alloc1: R.Tensor((8,), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([8]), R.dtype("float32"), R.prim_value(0))
+            _1: R.Tuple = cls.exp(lv, alloc1)
+            gv: R.Tensor((8,), dtype="float32") = alloc1
+            return gv
+
+        @R.function
+        def func2(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            R.func_attr({"relax.force_pure": 1})
+            cls = Expected
+            storage1: R.Object = R.memory.alloc_storage(R.shape([40]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((10,), dtype="float32") = 
R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([10]), 
R.dtype("float32"))
+            _: R.Tuple = cls.exp(x, alloc)
+            lv: R.Tensor((10,), dtype="float32") = alloc
+            alloc1: R.Tensor((10,), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0))
+            _1: R.Tuple = cls.exp(lv, alloc1)
+            gv: R.Tensor((10,), dtype="float32") = alloc1
+            return gv
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to