This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 18472a779d [Unity][Transform] Memory plan across the IRModule (#14220)
18472a779d is described below

commit 18472a779d67835a4fafd94f157a2d97a8b10704
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Mar 7 08:37:18 2023 -0500

    [Unity][Transform] Memory plan across the IRModule (#14220)
    
    Previously the static memory planning pass only worked at the
    single-function level - each function inside an IRModule was planned
    independently. This is not ideal for the VM, which cannot then reuse
    allocated memory across different functions.
    
    Therefore, this PR turns the static memory planning pass into a module
    pass. Now the plan is done across the IRModule, so that memory allocations
    in different functions can share the same storage token when planning.
    With this PR, we expect the VM to find more opportunities for memory
    reuse.
---
 src/relax/transform/static_plan_block_memory.cc    | 101 +++++++++++------
 .../test_transform_static_plan_block_memory.py     | 123 ++++++++++++++++++++-
 2 files changed, 191 insertions(+), 33 deletions(-)

diff --git a/src/relax/transform/static_plan_block_memory.cc 
b/src/relax/transform/static_plan_block_memory.cc
index 8b7adae246..ba5177fec0 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -86,11 +86,6 @@ class StorageTokenNode : public Object {
   DataType dtype;
   /*! \brief The storage id, reserved for debug and demo use. */
   int storage_id{-1};
-  /*!
-   * \brief The variable corresponding to the allocated storage, which is 
NullOpt
-   * before definition.
-   */
-  Optional<Var> storage{NullOpt};
 
   static constexpr const char* _type_key = "relax.transform.StorageToken";
   TVM_DECLARE_BASE_OBJECT_INFO(StorageTokenNode, Object);
@@ -287,23 +282,36 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
  */
 class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
  public:
-  explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
-
   /*!
    * \brief The entry of the initialization.
+   * \param mod The IRModule to be planned
    * \return The mapping from each Expr to the token it uses.
    */
-  std::unordered_map<const ExprNode*, Tokens> Initialize(const Function& func) 
{
+  static std::unordered_map<const ExprNode*, Tokens> Initialize(const 
IRModule& mod) {
+    StorageAllocatorInit initializer(mod);
+
+    for (auto it : mod->functions) {
+      const auto* func = it.second.as<FunctionNode>();
+      if (func == nullptr) {
+        continue;
+      }
+      initializer(GetRef<Function>(func));
+    }
+    return initializer.token_map_;
+  }
+
+ private:
+  using ExprVisitor::VisitExpr_;
+
+  explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
+
+  void VisitExpr_(const FunctionNode* func) final {
     // Recurse into the function to get its tokens.
     Tokens body_tokens = GetTokens(func->body);
     // Discard the tokens used by the function return value, as they are 
external referenced.
     DiscardTokensIn(body_tokens);
-    return this->token_map_;
   }
 
- private:
-  using ExprVisitor::VisitExpr_;
-
   void VisitExpr_(const CallNode* call) final {
     static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
     if (call->op == alloc_tensor_op) {
@@ -501,6 +509,16 @@ class StorageAllocator : public 
StorageAllocatorBaseVisitor {
     this->token_map_ = std::move(token_map);
   }
 
+  void Allocate(const IRModule& mod) {
+    for (auto it : mod->functions) {
+      const auto* func = it.second.as<FunctionNode>();
+      if (func == nullptr) {
+        continue;
+      }
+      this->VisitExpr_(func);
+    }
+  }
+
   /*!
    * \brief The mapping from each `builtin.alloc_tensor` to its corresponding
    * underlying storage token that it is using.
@@ -629,14 +647,29 @@ class StorageAllocator : public 
StorageAllocatorBaseVisitor {
 class StorageAllocationRewriter : public ExprMutator {
  public:
   explicit StorageAllocationRewriter(
-      std::unordered_map<const ExprNode*, StorageToken> alloc_tensor2token,
+      IRModule mod, std::unordered_map<const ExprNode*, StorageToken> 
alloc_tensor2token,
       std::unordered_map<const ExprNode*, std::vector<Var>> 
expr2killed_tensors,
       std::unordered_map<const BindingBlockNode*, std::vector<const 
StorageTokenNode*>>
           block2tokens)
-      : alloc_tensor2token_(std::move(alloc_tensor2token)),
+      : ExprMutator(std::move(mod)),
+        alloc_tensor2token_(std::move(alloc_tensor2token)),
         expr2killed_tensors_(std::move(expr2killed_tensors)),
         block2tokens_(std::move(block2tokens)) {}
 
+  IRModule Rewrite() {
+    const IRModule& mod = builder_->GetContextIRModule();
+    for (const auto& [gv, base_func] : mod->functions) {
+      const auto* func_ = base_func.as<FunctionNode>();
+      if (func_ == nullptr) {
+        continue;
+      }
+      token2storage_var_.clear();
+      Function func = Downcast<Function>(this->VisitExpr_(func_));
+      builder_->UpdateFunction(gv, func);
+    }
+    return builder_->GetContextIRModule();
+  }
+
  private:
   using ExprMutator::VisitExpr_;
 
@@ -648,9 +681,10 @@ class StorageAllocationRewriter : public ExprMutator {
 
     // Insert `memory.kill_storage` for the storage tokens allocated inside 
this block.
     for (const StorageTokenNode* token : block2tokens_[block]) {
-      ICHECK(token->storage.defined());
+      auto it_token = token2storage_var_.find(token);
+      ICHECK(it_token != token2storage_var_.end());
       static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage");
-      this->builder_->Emit(Call(mem_kill_storage, {token->storage.value()}), 
/*name_hint=*/"_");
+      this->builder_->Emit(Call(mem_kill_storage, {it_token->second}), 
/*name_hint=*/"_");
     }
 
     BindingBlock new_block = builder_->EndBlock();
@@ -682,7 +716,9 @@ class StorageAllocationRewriter : public ExprMutator {
       // If the token is visited for the first time, create a storage variable 
using
       // `memory.alloc_storage` for it.
       StorageToken token = it->second;
-      if (!token->storage.defined()) {
+      Var storage_var{nullptr};
+      auto it_token = token2storage_var_.find(token.get());
+      if (it_token == token2storage_var_.end()) {
         static const Op& mem_alloc_storage = 
Op::Get("relax.memory.alloc_storage");
         ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
         PrimValue virtual_device_index = runtime_device_index;
@@ -692,15 +728,17 @@ class StorageAllocationRewriter : public ExprMutator {
             mem_alloc_storage,
             {std::move(size), virtual_device_index, StringImm(storage_scope), 
DataTypeImm(dtype)},
             Attrs());
-        token->storage = builder_->Emit(alloc_storage, "storage");
+        storage_var = builder_->Emit(alloc_storage, "storage");
+        token2storage_var_[token.get()] = storage_var;
+      } else {
+        storage_var = it_token->second;
       }
 
       // And always create a `memory.alloc_tensor` for the old 
`builtin.alloc_tensor`.
       static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
       PrimValue offset = PrimValue::Int64(0);
       DataType dtype = sinfo->dtype;
-      return Call(mem_alloc_tensor,
-                  {token->storage.value(), offset, sinfo->shape.value(), 
DataTypeImm(dtype)},
+      return Call(mem_alloc_tensor, {storage_var, offset, 
sinfo->shape.value(), DataTypeImm(dtype)},
                   Attrs());
     }
 
@@ -716,31 +754,30 @@ class StorageAllocationRewriter : public ExprMutator {
   std::unordered_map<const ExprNode*, std::vector<Var>> expr2killed_tensors_;
   /*! \brief The mapping from each binding block to the storage tokens that 
are create inside. */
   std::unordered_map<const BindingBlockNode*, std::vector<const 
StorageTokenNode*>> block2tokens_;
+  /*! \brief The mapping from each token to its corresponding storage var in 
each function. */
+  std::unordered_map<const StorageTokenNode*, Var> token2storage_var_;
 };
 
-Expr StaticPlanBlockMemory(Function func, const IRModule& ctx_mod) {
+IRModule StaticPlanBlockMemory(IRModule mod) {
   // Step 1. Initialize.
-  StorageAllocatorInit initializer(ctx_mod);
-  std::unordered_map<const ExprNode*, Tokens> token_map = 
initializer.Initialize(func);
+  std::unordered_map<const ExprNode*, Tokens> token_map = 
StorageAllocatorInit::Initialize(mod);
   // Step 2. Collect the memory allocation info.
   StorageAllocator allocator(std::move(token_map));
-  allocator(func);
+  allocator.Allocate(mod);
   // Step 3. Rewrite the function.
-  StorageAllocationRewriter rewriter(std::move(allocator.alloc_tensor2token),
+  StorageAllocationRewriter rewriter(std::move(mod),  //
+                                     std::move(allocator.alloc_tensor2token),
                                      std::move(allocator.expr2killed_tensors),
                                      std::move(allocator.block2tokens));
-  func = Downcast<Function>(rewriter(func));
-  return func;
+  return rewriter.Rewrite();
 }
 
 namespace transform {
 
 Pass StaticPlanBlockMemory() {
-  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> 
pass_func =
-      [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(StaticPlanBlockMemory(std::move(f), m));
-      };
-  return CreateFunctionPass(pass_func, /*opt_level=*/0, 
"StaticPlanBlockMemory", {});
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule m, PassContext pc) { return 
relax::StaticPlanBlockMemory(std::move(m)); };
+  return CreateModulePass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", 
{});
 }
 
 
TVM_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory").set_body_typed(StaticPlanBlockMemory);
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py 
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 1b556139cc..2f04e74062 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -18,7 +18,7 @@
 import tvm
 import tvm.testing
 from tvm import relax
-from tvm.script import relax as R, tir as T
+from tvm.script import ir as I, relax as R, tir as T
 
 
 def test_basic():
@@ -608,5 +608,126 @@ def test_reshape_param():
     tvm.ir.assert_structural_equal(mod, Module)
 
 
+def test_multiple_functions():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(
+            A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+        ):
+            T.evaluate(0)
+
+        @T.prim_func
+        def add1(
+            A: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+        ):
+            T.evaluate(0)
+
+        @R.function
+        def func1(
+            x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), 
dtype="int32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _: R.Tuple() = add(x, x, alloc)
+            gv: R.Tensor((2, 3), dtype="float32") = alloc
+            alloc1: R.Tensor((2, 3), dtype="int32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="int32", runtime_device_index=0
+            )
+            _1: R.Tuple() = add1(y, y, alloc1)
+            gv1: R.Tensor((2, 3), dtype="int32") = alloc1
+            return x
+
+        @R.function
+        def func2(
+            x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), 
dtype="float32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _: R.Tuple() = add(x, x, alloc)
+            gv: R.Tensor((2, 3), dtype="float32") = alloc
+            alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _1: R.Tuple() = add(y, y, alloc1)
+            gv1: R.Tensor((2, 3), dtype="float32") = alloc1
+            return x
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add(
+            A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+        ):
+            T.evaluate(0)
+
+        @T.prim_func
+        def add1(
+            A: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+        ):
+            T.evaluate(0)
+
+        @R.function
+        def func1(
+            x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), 
dtype="int32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([24]), virtual_device_index=0, storage_scope="global", 
dtype="float32"
+            )
+            alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _: R.Tuple() = add(x, x, alloc)
+            _1: R.Tuple() = R.memory.kill_tensor(alloc)
+            gv1: R.Tensor((2, 3), dtype="float32") = alloc
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([24]), virtual_device_index=0, storage_scope="global", 
dtype="int32"
+            )
+            alloc1: R.Tensor((2, 3), dtype="int32") = R.memory.alloc_tensor(
+                storage1, 0, R.shape([2, 3]), dtype="int32"
+            )
+            _2: R.Tuple() = add1(y, y, alloc1)
+            _3: R.Tuple() = R.memory.kill_tensor(alloc1)
+            gv12: R.Tensor((2, 3), dtype="int32") = alloc1
+            _5: R.Tuple() = R.memory.kill_storage(storage)
+            _4: R.Tuple() = R.memory.kill_storage(storage1)
+            return x
+
+        @R.function
+        def func2(
+            x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), 
dtype="float32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([24]), virtual_device_index=0, storage_scope="global", 
dtype="float32"
+            )
+            alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _: R.Tuple() = add(x, x, alloc)
+            _1: R.Tuple() = R.memory.kill_tensor(alloc)
+            gv1: R.Tensor((2, 3), dtype="float32") = alloc
+            alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _2: R.Tuple() = add(y, y, alloc1)
+            _3: R.Tuple() = R.memory.kill_tensor(alloc1)
+            gv12: R.Tensor((2, 3), dtype="float32") = alloc1
+            _4: R.Tuple() = R.memory.kill_storage(storage)
+            return x
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to