This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 18472a779d [Unity][Transform] Memory plan across the IRModule (#14220)
18472a779d is described below
commit 18472a779d67835a4fafd94f157a2d97a8b10704
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Mar 7 08:37:18 2023 -0500
[Unity][Transform] Memory plan across the IRModule (#14220)
Previously the static memory planning pass only worked at the single-function
level - each function inside an IRModule was independently
planned. This is not ideal, as it prevents the VM from reusing allocated
memory across different functions.
Therefore, this PR turns the static memory planning pass into a module
pass. Now the planning is done across the IRModule, so that memory allocation
in different functions can share the same storage token when planning.
With this PR, the VM is expected to find more opportunities for
memory reuse.
---
src/relax/transform/static_plan_block_memory.cc | 101 +++++++++++------
.../test_transform_static_plan_block_memory.py | 123 ++++++++++++++++++++-
2 files changed, 191 insertions(+), 33 deletions(-)
diff --git a/src/relax/transform/static_plan_block_memory.cc
b/src/relax/transform/static_plan_block_memory.cc
index 8b7adae246..ba5177fec0 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -86,11 +86,6 @@ class StorageTokenNode : public Object {
DataType dtype;
/*! \brief The storage id, reserved for debug and demo use. */
int storage_id{-1};
- /*!
- * \brief The variable corresponding to the allocated storage, which is
NullOpt
- * before definition.
- */
- Optional<Var> storage{NullOpt};
static constexpr const char* _type_key = "relax.transform.StorageToken";
TVM_DECLARE_BASE_OBJECT_INFO(StorageTokenNode, Object);
@@ -287,23 +282,36 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
*/
class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
public:
- explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
-
/*!
* \brief The entry of the initialization.
+ * \param mod The IRModule to be planned
* \return The mapping from each Expr to the token it uses.
*/
- std::unordered_map<const ExprNode*, Tokens> Initialize(const Function& func)
{
+ static std::unordered_map<const ExprNode*, Tokens> Initialize(const
IRModule& mod) {
+ StorageAllocatorInit initializer(mod);
+
+ for (auto it : mod->functions) {
+ const auto* func = it.second.as<FunctionNode>();
+ if (func == nullptr) {
+ continue;
+ }
+ initializer(GetRef<Function>(func));
+ }
+ return initializer.token_map_;
+ }
+
+ private:
+ using ExprVisitor::VisitExpr_;
+
+ explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
+
+ void VisitExpr_(const FunctionNode* func) final {
// Recurse into the function to get its tokens.
Tokens body_tokens = GetTokens(func->body);
// Discard the tokens used by the function return value, as they are
external referenced.
DiscardTokensIn(body_tokens);
- return this->token_map_;
}
- private:
- using ExprVisitor::VisitExpr_;
-
void VisitExpr_(const CallNode* call) final {
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
if (call->op == alloc_tensor_op) {
@@ -501,6 +509,16 @@ class StorageAllocator : public
StorageAllocatorBaseVisitor {
this->token_map_ = std::move(token_map);
}
+ void Allocate(const IRModule& mod) {
+ for (auto it : mod->functions) {
+ const auto* func = it.second.as<FunctionNode>();
+ if (func == nullptr) {
+ continue;
+ }
+ this->VisitExpr_(func);
+ }
+ }
+
/*!
* \brief The mapping from each `builtin.alloc_tensor` to its corresponding
* underlying storage token that it is using.
@@ -629,14 +647,29 @@ class StorageAllocator : public
StorageAllocatorBaseVisitor {
class StorageAllocationRewriter : public ExprMutator {
public:
explicit StorageAllocationRewriter(
- std::unordered_map<const ExprNode*, StorageToken> alloc_tensor2token,
+ IRModule mod, std::unordered_map<const ExprNode*, StorageToken>
alloc_tensor2token,
std::unordered_map<const ExprNode*, std::vector<Var>>
expr2killed_tensors,
std::unordered_map<const BindingBlockNode*, std::vector<const
StorageTokenNode*>>
block2tokens)
- : alloc_tensor2token_(std::move(alloc_tensor2token)),
+ : ExprMutator(std::move(mod)),
+ alloc_tensor2token_(std::move(alloc_tensor2token)),
expr2killed_tensors_(std::move(expr2killed_tensors)),
block2tokens_(std::move(block2tokens)) {}
+ IRModule Rewrite() {
+ const IRModule& mod = builder_->GetContextIRModule();
+ for (const auto& [gv, base_func] : mod->functions) {
+ const auto* func_ = base_func.as<FunctionNode>();
+ if (func_ == nullptr) {
+ continue;
+ }
+ token2storage_var_.clear();
+ Function func = Downcast<Function>(this->VisitExpr_(func_));
+ builder_->UpdateFunction(gv, func);
+ }
+ return builder_->GetContextIRModule();
+ }
+
private:
using ExprMutator::VisitExpr_;
@@ -648,9 +681,10 @@ class StorageAllocationRewriter : public ExprMutator {
// Insert `memory.kill_storage` for the storage tokens allocated inside
this block.
for (const StorageTokenNode* token : block2tokens_[block]) {
- ICHECK(token->storage.defined());
+ auto it_token = token2storage_var_.find(token);
+ ICHECK(it_token != token2storage_var_.end());
static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage");
- this->builder_->Emit(Call(mem_kill_storage, {token->storage.value()}),
/*name_hint=*/"_");
+ this->builder_->Emit(Call(mem_kill_storage, {it_token->second}),
/*name_hint=*/"_");
}
BindingBlock new_block = builder_->EndBlock();
@@ -682,7 +716,9 @@ class StorageAllocationRewriter : public ExprMutator {
// If the token is visited for the first time, create a storage variable
using
// `memory.alloc_storage` for it.
StorageToken token = it->second;
- if (!token->storage.defined()) {
+ Var storage_var{nullptr};
+ auto it_token = token2storage_var_.find(token.get());
+ if (it_token == token2storage_var_.end()) {
static const Op& mem_alloc_storage =
Op::Get("relax.memory.alloc_storage");
ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
PrimValue virtual_device_index = runtime_device_index;
@@ -692,15 +728,17 @@ class StorageAllocationRewriter : public ExprMutator {
mem_alloc_storage,
{std::move(size), virtual_device_index, StringImm(storage_scope),
DataTypeImm(dtype)},
Attrs());
- token->storage = builder_->Emit(alloc_storage, "storage");
+ storage_var = builder_->Emit(alloc_storage, "storage");
+ token2storage_var_[token.get()] = storage_var;
+ } else {
+ storage_var = it_token->second;
}
// And always create a `memory.alloc_tensor` for the old
`builtin.alloc_tensor`.
static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
PrimValue offset = PrimValue::Int64(0);
DataType dtype = sinfo->dtype;
- return Call(mem_alloc_tensor,
- {token->storage.value(), offset, sinfo->shape.value(),
DataTypeImm(dtype)},
+ return Call(mem_alloc_tensor, {storage_var, offset,
sinfo->shape.value(), DataTypeImm(dtype)},
Attrs());
}
@@ -716,31 +754,30 @@ class StorageAllocationRewriter : public ExprMutator {
std::unordered_map<const ExprNode*, std::vector<Var>> expr2killed_tensors_;
/*! \brief The mapping from each binding block to the storage tokens that
are create inside. */
std::unordered_map<const BindingBlockNode*, std::vector<const
StorageTokenNode*>> block2tokens_;
+ /*! \brief The mapping from each token to its corresponding storage var in
each function. */
+ std::unordered_map<const StorageTokenNode*, Var> token2storage_var_;
};
-Expr StaticPlanBlockMemory(Function func, const IRModule& ctx_mod) {
+IRModule StaticPlanBlockMemory(IRModule mod) {
// Step 1. Initialize.
- StorageAllocatorInit initializer(ctx_mod);
- std::unordered_map<const ExprNode*, Tokens> token_map =
initializer.Initialize(func);
+ std::unordered_map<const ExprNode*, Tokens> token_map =
StorageAllocatorInit::Initialize(mod);
// Step 2. Collect the memory allocation info.
StorageAllocator allocator(std::move(token_map));
- allocator(func);
+ allocator.Allocate(mod);
// Step 3. Rewrite the function.
- StorageAllocationRewriter rewriter(std::move(allocator.alloc_tensor2token),
+ StorageAllocationRewriter rewriter(std::move(mod), //
+ std::move(allocator.alloc_tensor2token),
std::move(allocator.expr2killed_tensors),
std::move(allocator.block2tokens));
- func = Downcast<Function>(rewriter(func));
- return func;
+ return rewriter.Rewrite();
}
namespace transform {
Pass StaticPlanBlockMemory() {
- runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(StaticPlanBlockMemory(std::move(f), m));
- };
- return CreateFunctionPass(pass_func, /*opt_level=*/0,
"StaticPlanBlockMemory", {});
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule m, PassContext pc) { return
relax::StaticPlanBlockMemory(std::move(m)); };
+ return CreateModulePass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory",
{});
}
TVM_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory").set_body_typed(StaticPlanBlockMemory);
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 1b556139cc..2f04e74062 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -18,7 +18,7 @@
import tvm
import tvm.testing
from tvm import relax
-from tvm.script import relax as R, tir as T
+from tvm.script import ir as I, relax as R, tir as T
def test_basic():
@@ -608,5 +608,126 @@ def test_reshape_param():
tvm.ir.assert_structural_equal(mod, Module)
+def test_multiple_functions():
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def add(
+ A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+ B: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+ C: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+ ):
+ T.evaluate(0)
+
+ @T.prim_func
+ def add1(
+ A: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+ B: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+ C: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+ ):
+ T.evaluate(0)
+
+ @R.function
+ def func1(
+ x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3),
dtype="int32")
+ ) -> R.Tensor((2, 3), dtype="float32"):
+ alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+ R.shape([2, 3]), dtype="float32", runtime_device_index=0
+ )
+ _: R.Tuple() = add(x, x, alloc)
+ gv: R.Tensor((2, 3), dtype="float32") = alloc
+ alloc1: R.Tensor((2, 3), dtype="int32") = R.builtin.alloc_tensor(
+ R.shape([2, 3]), dtype="int32", runtime_device_index=0
+ )
+ _1: R.Tuple() = add1(y, y, alloc1)
+ gv1: R.Tensor((2, 3), dtype="int32") = alloc1
+ return x
+
+ @R.function
+ def func2(
+ x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3),
dtype="float32")
+ ) -> R.Tensor((2, 3), dtype="float32"):
+ alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+ R.shape([2, 3]), dtype="float32", runtime_device_index=0
+ )
+ _: R.Tuple() = add(x, x, alloc)
+ gv: R.Tensor((2, 3), dtype="float32") = alloc
+ alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+ R.shape([2, 3]), dtype="float32", runtime_device_index=0
+ )
+ _1: R.Tuple() = add(y, y, alloc1)
+ gv1: R.Tensor((2, 3), dtype="float32") = alloc1
+ return x
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def add(
+ A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+ B: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+ C: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+ ):
+ T.evaluate(0)
+
+ @T.prim_func
+ def add1(
+ A: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+ B: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+ C: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+ ):
+ T.evaluate(0)
+
+ @R.function
+ def func1(
+ x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3),
dtype="int32")
+ ) -> R.Tensor((2, 3), dtype="float32"):
+ storage: R.Object = R.memory.alloc_storage(
+ R.shape([24]), virtual_device_index=0, storage_scope="global",
dtype="float32"
+ )
+ alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+ storage, 0, R.shape([2, 3]), dtype="float32"
+ )
+ _: R.Tuple() = add(x, x, alloc)
+ _1: R.Tuple() = R.memory.kill_tensor(alloc)
+ gv1: R.Tensor((2, 3), dtype="float32") = alloc
+ storage1: R.Object = R.memory.alloc_storage(
+ R.shape([24]), virtual_device_index=0, storage_scope="global",
dtype="int32"
+ )
+ alloc1: R.Tensor((2, 3), dtype="int32") = R.memory.alloc_tensor(
+ storage1, 0, R.shape([2, 3]), dtype="int32"
+ )
+ _2: R.Tuple() = add1(y, y, alloc1)
+ _3: R.Tuple() = R.memory.kill_tensor(alloc1)
+ gv12: R.Tensor((2, 3), dtype="int32") = alloc1
+ _5: R.Tuple() = R.memory.kill_storage(storage)
+ _4: R.Tuple() = R.memory.kill_storage(storage1)
+ return x
+
+ @R.function
+ def func2(
+ x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3),
dtype="float32")
+ ) -> R.Tensor((2, 3), dtype="float32"):
+ storage: R.Object = R.memory.alloc_storage(
+ R.shape([24]), virtual_device_index=0, storage_scope="global",
dtype="float32"
+ )
+ alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+ storage, 0, R.shape([2, 3]), dtype="float32"
+ )
+ _: R.Tuple() = add(x, x, alloc)
+ _1: R.Tuple() = R.memory.kill_tensor(alloc)
+ gv1: R.Tensor((2, 3), dtype="float32") = alloc
+ alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+ storage, 0, R.shape([2, 3]), dtype="float32"
+ )
+ _2: R.Tuple() = add(y, y, alloc1)
+ _3: R.Tuple() = R.memory.kill_tensor(alloc1)
+ gv12: R.Tensor((2, 3), dtype="float32") = alloc1
+ _4: R.Tuple() = R.memory.kill_storage(storage)
+ return x
+
+ mod = relax.transform.StaticPlanBlockMemory()(Module)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()