This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 5e8ad8e90a [Unity][Pass] Block-level static memory planning (#14038)
5e8ad8e90a is described below
commit 5e8ad8e90ab146532ae98e9a2fe4a67455f558cc
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Feb 18 11:38:32 2023 -0500
[Unity][Pass] Block-level static memory planning (#14038)
This PR introduces the static memory planning pass on binding block level,
as well as an analysis function that estimates the memory usage after the memory
planning pass. It supports the following features: nested tuples, reusing the
memory of the input of reshape ops, an estimator that returns the total memory size
needed to be allocated before and after memory planning, as well as the number of
tensors / memory blocks to be allocated before and after memory planning.
The estimation is static -- it does not consider control flows (such as
“if” and cross-function calls). It simply accumulates the size of every
alloc_tensor and alloc_storage.
We regard “`relax.memory.alloc_tensor/storage`” as the results
produced by memory planning.
---
include/tvm/relax/transform.h | 9 +
python/tvm/relax/analysis/__init__.py | 1 +
python/tvm/relax/analysis/estimate_memory_usage.py | 164 +++++
python/tvm/relax/transform/transform.py | 11 +
python/tvm/relax/vm.py | 1 +
src/relax/transform/static_plan_block_memory.cc | 750 +++++++++++++++++++++
.../relax/test_analysis_estimate_memory_usage.py | 125 ++++
.../test_transform_static_plan_block_memory.py | 612 +++++++++++++++++
8 files changed, 1673 insertions(+)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 8b7c7880b9..1934a9f9f2 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -95,6 +95,15 @@ TVM_DLL Pass CallTIRRewrite();
*/
TVM_DLL Pass RewriteDataflowReshape();
+/*!
+ * \brief The static memory planning pass on BindingBlock level.
+ * The pass will reuse allocated memory to its best effort, in order to
+ * reduce the total amount of allocated memory size.
+ *
+ * \note Planning is done per binding block: only `relax.builtin.alloc_tensor`
+ * calls with constant shapes, whose tensors are not referenced outside their
+ * allocating block, are planned.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass StaticPlanBlockMemory();
+
/*!
* \brief Bind params of function of the module to constant tensors.
*
diff --git a/python/tvm/relax/analysis/__init__.py
b/python/tvm/relax/analysis/__init__.py
index cc0089ff31..7ba56ff408 100644
--- a/python/tvm/relax/analysis/__init__.py
+++ b/python/tvm/relax/analysis/__init__.py
@@ -18,3 +18,4 @@
"""Relax IR analysis. """
from .analysis import *
+from .estimate_memory_usage import estimate_memory_usage
diff --git a/python/tvm/relax/analysis/estimate_memory_usage.py
b/python/tvm/relax/analysis/estimate_memory_usage.py
new file mode 100644
index 0000000000..55f82740ec
--- /dev/null
+++ b/python/tvm/relax/analysis/estimate_memory_usage.py
@@ -0,0 +1,164 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=abstract-method,unused-argument
+# pylint: disable=missing-function-docstring,missing-module-docstring
+from typing import Union
+import tvm
+from tvm.ir import Op
+from tvm.ir.module import IRModule
+
+from ..expr import Call, Expr, Function, ShapeExpr
+from ..expr_functor import visitor, PyExprVisitor
+
+
+def estimate_memory_usage(mod: Union[IRModule, Function]) -> str:
+ """Analysis function that estimates the memory usage of Relax functions
+ in an IRModule. The estimation includes the total memory size needed to
+ be allocated before and after memory planning.
+
+ The result might be over-estimated, as the estimation is static, which
+ does not consider control flows (such as "if" and cross-function calls).
+ It simply accumulates the size of every alloc_tensor and alloc_storage.
+
+ This analysis function is used to demonstrate the effect of memory
+ planning.
+
+ Parameters
+ ----------
+ mod : Union[IRModule, Function]
+ The input IRModule whose functions inside are to be analyzed.
+ If the input is a Function, we will wrap it with a IRModule, with
+ the function named "main".
+
+ Returns
+ -------
+ est : str
+ The estimation information, in the form of a string.
+
+ Notes
+ -----
+ We regards "relax.memory.alloc_tensor/storage" as the results produced by
memory planning.
+ """
+
+ @visitor
+ class MemoryEstimator(PyExprVisitor):
+ """The IR visitor which estimates the memory usage of each Relax
function.
+
+ Attributes
+ ----------
+ total_alloc_tensor_mem : int
+ The total memory size of alloc_tensor, in bytes.
+
+ total_const_size_tensor_num : int
+ The number of constant-size tensors.
+
+ total_dyn_size_tensor_num : int
+ The number of dynamic-size tensors.
+
+ planned_alloc_mem : int
+ The total memory size of memory.alloc_storage after memory
planning, in bytes.
+
+ planned_mem_num : int
+ The number of memory.alloc_storages.
+ """
+
+ total_alloc_tensor_mem: int
+ total_const_size_tensor_num: int
+ total_dyn_size_tensor_num: int
+ planned_alloc_mem: int
+ planned_mem_num: int
+ builtin_alloc_tensor_op = Op.get("relax.builtin.alloc_tensor")
+ memory_alloc_tensor_op = Op.get("relax.memory.alloc_tensor")
+ memory_alloc_storage_op = Op.get("relax.memory.alloc_storage")
+
+ def estimate(self, mod: IRModule) -> str:
+ estimation: str = ""
+ for global_var, func in mod.functions.items():
+ if not isinstance(func, Function):
+ continue
+
+ self.cleanup()
+ self.visit_expr(func)
+ estimation += self.generate_est_string(global_var.name_hint)
+
+ if estimation != "":
+ estimation = "Memory usage estimation:\n" + estimation
+ return estimation
+
+ def cleanup(self) -> None:
+ self.total_alloc_tensor_mem = 0
+ self.total_const_size_tensor_num = 0
+ self.total_dyn_size_tensor_num = 0
+ self.planned_alloc_mem = 0
+ self.planned_mem_num = 0
+
+ def visit_call_(self, call: Call) -> None: # pylint:
disable=arguments-differ
+ if call.op == self.builtin_alloc_tensor_op:
+ self.accumulate_tensor_alloc(shape=call.args[0],
dtype_str=call.args[1].value)
+ elif call.op == self.memory_alloc_tensor_op:
+ self.accumulate_tensor_alloc(shape=call.args[2],
dtype_str=call.args[3].value)
+ elif call.op == self.memory_alloc_storage_op:
+ self.accumulate_storage_alloc(size=call.args[0])
+
+ def accumulate_tensor_alloc(self, shape: Expr, dtype_str: str) -> None:
+ if not isinstance(shape, ShapeExpr):
+ raise TypeError(
+ "The shape of relax.builtin.alloc_tensor and "
+ "relax.memory.alloc_tensor is expected to be ShapeExpr"
+ )
+ size: int = 1
+ for dim_len in shape.values:
+ if not isinstance(dim_len, tvm.tir.IntImm):
+ self.total_dyn_size_tensor_num += 1
+ return
+ size *= dim_len.value
+
+ dtype = tvm.DataType(dtype_str)
+ self.total_const_size_tensor_num += 1
+ self.total_alloc_tensor_mem += (size * dtype.bits * dtype.lanes +
7) // 8
+
+ def accumulate_storage_alloc(self, size: Expr) -> None:
+ if not isinstance(size, ShapeExpr):
+ raise TypeError(
+ "The size of relax.memory.alloc_storage is expected to be
ShapeExpr"
+ )
+
+ self.planned_mem_num += 1
+ self.planned_alloc_mem += size.values[0].value
+
+ def generate_est_string(self, func_name: str) -> str:
+ est = (
+ f" * Without memory planning, there are
{self.total_const_size_tensor_num} "
+ "constant-size memory allocation(s) with total size "
+ "{0:.4} GB".format(self.total_alloc_tensor_mem / 2**30)
+ )
+ if self.total_dyn_size_tensor_num > 0:
+ est += f", and {self.total_dyn_size_tensor_num} dynamic-size
allocation(s)"
+ est += (
+ f".\n * With memory planning, there are {self.planned_mem_num}
constant-size "
+ "memory allocation(s) with total size "
+ "{0:.4} GB.\n".format(self.planned_alloc_mem / 2**30)
+ )
+ est += " * Memory planning reduces constant memory size to "
"{0:.1%}.".format(
+ self.planned_alloc_mem / self.total_alloc_tensor_mem
+ )
+ return "- Function " + func_name + ":\n" + est
+
+ if isinstance(mod, Function):
+ mod = tvm.IRModule({tvm.ir.GlobalVar("foo"): mod})
+
+ return MemoryEstimator().estimate(mod)
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 4ba967935b..1f14823b5a 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -93,6 +93,17 @@ def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
return _ffi_api.RewriteDataflowReshape() # type: ignore
+def StaticPlanBlockMemory() -> tvm.ir.transform.Pass:
+ """The static memory planning pass on BindingBlock level.
+ The pass will reuse allocated memory to its best effort, in order to
+ reduce the total amount of allocated memory size.
+ Returns
+ -------
+ ret : tvm.ir.transform.Pass
+ """
+ return _ffi_api.StaticPlanBlockMemory() # type: ignore
+
+
def VMBuiltinLower() -> tvm.ir.transform.Pass:
"""Lowering generic intrinsic to VM intrinsics.
diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py
index ff6bf816b6..2cf1250690 100644
--- a/python/tvm/relax/vm.py
+++ b/python/tvm/relax/vm.py
@@ -585,6 +585,7 @@ def build(
passes.append(relax.transform.RewriteDataflowReshape())
passes.append(relax.transform.ToNonDataflow())
passes.append(relax.transform.CallTIRRewrite())
+ passes.append(relax.transform.StaticPlanBlockMemory())
passes.append(relax.transform.VMBuiltinLower())
passes.append(relax.transform.VMShapeLower())
passes.append(relax.transform.AttachGlobalSymbol())
diff --git a/src/relax/transform/static_plan_block_memory.cc
b/src/relax/transform/static_plan_block_memory.cc
new file mode 100644
index 0000000000..8b7adae246
--- /dev/null
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -0,0 +1,750 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/static_plan_block_memory.cc
+ * \brief The static memory planning pass on BindingBlock level.
+ * \details
+ * The core data structure of the planning pass is StorageToken, which denotes
+ * reusable memory in this planning pass.
+ *
+ * The memory planning pass contains three stages:
+ *
+ * The first stage is initialization. A storage token object will be created
+ * for each builtin alloc_tensor as long as the allocated storage satisfies
+ * the requirements (which are described in the code). The reference counter
+ * (i.e., the times of reference) for each token is recorded.
+ *
+ * The second stage is allocation planning. We maintain a pool of available
+ * allocated storage, in the form of storage tokens. For the storage token of
+ * each builtin alloc_tensor, we check whether there is an appropriate available
+ * token in the pool under certain criteria. If there is, we reuse that storage
+ * for this alloc_tensor. Otherwise, we allocate new storage for the
+ * alloc_tensor.
+ *
+ * The third stage is IR rewrite. Based on the decision made in the second
+ * stage, we insert memory alloc_storage, alloc_tensor, kill_tensor, and
+ * kill_storage accordingly. Specifically, we
+ * - insert alloc_storage before the site where each storage token is first
+ * used,
+ * - insert memory alloc_tensor for each builtin alloc_tensor,
+ * - insert kill_tensor after the site where a tensor created by alloc_tensor
+ * is last referenced, and
+ * - insert kill_storage at the end of each binding block, for all the storage
+ * tokens that are allocated inside the binding block, as the memory planning
+ * only works on block level.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/nested_msg.h>
+#include <tvm/relax/transform.h>
+
+#include <map>
+#include <set>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief A representation of a block of reusable memory required at runtime.
+ * \details Only tensors whose memory can possibly be reused get a storage
+ * token. In other words, there is no storage token for a tensor
+ * - that is a function parameter,
+ * - that is a function return value,
+ * - one of whose use sites is in a BindingBlock different from its allocation site,
+ * - that is used as the condition or a branch return value of an IfNode,
+ * - that is used as the body of a SeqExprNode,
+ * - that is used as an argument of a Call whose op is not a PrimFunc.
+ *
+ * In practice, a storage token is created for such a tensor first, and it is
+ * erased as soon as the tensor is found to satisfy any of the conditions above.
+ */
+class StorageTokenNode : public Object {
+ public:
+  /*! \brief Number of outstanding references to this token. */
+  int ref_counter = 0;
+  /*! \brief Number of bytes that this token requires. */
+  int64_t bytes;
+  /*! \brief The dtype of this token. */
+  DataType dtype;
+  /*!
+   * \brief The storage id, reserved for debug and demo use.
+   * It stays -1 until the token is allocated with actual storage.
+   */
+  int storage_id = -1;
+  /*! \brief The variable bound to the allocated storage; NullOpt before it is defined. */
+  Optional<Var> storage{NullOpt};
+
+  static constexpr const char* _type_key = "relax.transform.StorageToken";
+  TVM_DECLARE_BASE_OBJECT_INFO(StorageTokenNode, Object);
+};
+
+/*!
+ * \brief Managed reference to StorageTokenNode.
+ * \sa StorageTokenNode
+ */
+class StorageToken : public ObjectRef {
+ public:
+  /*!
+   * \brief Create a token sized for a tensor of the given constant shape and dtype.
+   * \param shape The tensor shape; every dimension must be an IntImm.
+   * \param dtype The tensor dtype.
+   */
+  explicit StorageToken(Array<PrimExpr> shape, DataType dtype) {
+    // Number of elements: the product of all (constant) dimension lengths.
+    int64_t num_elements = 1;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      const IntImmNode* dim = shape[i].as<IntImmNode>();
+      ICHECK_NOTNULL(dim);
+      num_elements *= dim->value;
+    }
+
+    // Round the total number of bits up to whole bytes.
+    int64_t total_bits = num_elements * dtype.bits() * dtype.lanes();
+    ObjectPtr<StorageTokenNode> node = make_object<StorageTokenNode>();
+    node->dtype = dtype;
+    node->bytes = (total_bits + 7) / 8;
+    data_ = std::move(node);
+  }
+
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(StorageToken, ObjectRef, StorageTokenNode);
+};
+
+// The tokens used by each Expr are stored as a NestedMsg, mirroring (nested) tuple structure.
+using Tokens = NestedMsg<StorageToken>;
+
+/*!
+ * \brief Memory manager for flattened 1d memory (buffers)
+ * \details Released tokens are kept in per-dtype pools keyed by byte size, and
+ * a request is only matched against tokens of the same dtype.
+ * \note We can generalize this implementation to multi-dimensional memory
+ * following the same flow in the future.
+ */
+class TokenAllocator1D {
+ public:
+  /*!
+   * \brief Request a storage token from the available token pool for a
+   * given prototype, or report no appropriate available token in the pool.
+   * \param prototype The requesting prototype storage token.
+   * \return The request result token. Return NullOpt if there is no
+   * appropriate available token in the pool.
+   * \note On success, the returned token takes over the prototype's
+   * ref_counter and is removed from the available pool. If the token was
+   * smaller than requested, its `bytes` is enlarged to the requested size.
+   */
+  Optional<StorageToken> RequestReuse(StorageToken prototype) {
+    // Step 0. Sanity check: the prototype token is supposed not to be allocated with actual storage
+    ICHECK_EQ(prototype->storage_id, -1) << "The token is expected not to be allocated before.";
+    // If the prototype has no reference at all, feel free to allocate new storage.
+    // The unused binding can be removed by cleaning passes.
+    if (prototype->ref_counter == 0) {
+      return NullOpt;
+    }
+
+    // Step 1. Get the available pool of the token dtype.
+    std::multimap<int64_t, StorageToken>& pool = available_pool_[prototype->dtype];
+
+    // Step 2. Get the range of memory blocks in [size / match_range_, size * match_range_].
+    // The upper end is inclusive because `end` comes from upper_bound. `mid` splits the
+    // range into blocks smaller than `size` ([begin, mid)) and blocks of at least `size`
+    // ([mid, end)).
+    int64_t size = prototype->bytes;
+    auto begin = pool.lower_bound(size / match_range_);
+    auto mid = pool.lower_bound(size);
+    auto end = pool.upper_bound(size * match_range_);
+    // Step 3. Search for memory block that equals or is larger than the requested size.
+    // `mid` is the smallest such block, since the multimap is ordered by size.
+    if (mid != end) {
+      StorageToken available_token = mid->second;
+      ICHECK_EQ(available_token->ref_counter, 0)
+          << "Available tokens are expected to have 0 reference.";
+      ICHECK_LE(size, available_token->bytes);
+      available_token->ref_counter = prototype->ref_counter;
+      pool.erase(mid);
+      return available_token;
+    }
+    // Step 4. Then search for memory block that is smaller than the requested size.
+    // Stepping back from `mid` yields the largest block below the requested size.
+    if (mid != begin) {
+      --mid;
+      StorageToken available_token = mid->second;
+      ICHECK_EQ(available_token->ref_counter, 0)
+          << "Available tokens are expected to have 0 reference.";
+      ICHECK_GE(size, available_token->bytes);
+      // Enlarge the token size. The entry is erased below, so no stale size key remains.
+      available_token->bytes = size;
+      available_token->ref_counter = prototype->ref_counter;
+      pool.erase(mid);
+      return available_token;
+    }
+    // Return `NullOpt` indicating that no satisfiable storage token is found in the available pool.
+    return NullOpt;
+  }
+
+  /*!
+   * \brief Allocate a storage token for the input prototype token.
+   * \param prototype The prototype token.
+   * \param storage_id The id of this token.
+   * \return The prototype itself, now carrying `storage_id`.
+   */
+  StorageToken Alloc(StorageToken prototype, int storage_id) {
+    // Sanity check: the prototype token is supposed not to be allocated with actual storage yet
+    ICHECK_EQ(prototype->storage_id, -1) << "The token is expected not to be allocated before.";
+    prototype->storage_id = storage_id;
+    full_pool_.push_back(prototype);
+    return prototype;
+  }
+
+  /*!
+   * \brief Release the input token, putting it into the available pool.
+   * \param token The token to be released.
+   */
+  void Release(StorageToken token) {
+    // Sanity check: the token has been allocated with actual storage, and should have 0 reference.
+    ICHECK_GE(token->storage_id, 0)
+        << "The token to be released is expected to be allocated before";
+    ICHECK_EQ(token->ref_counter, 0) << "The token to be released is expected to have 0 reference.";
+    available_pool_[token->dtype].insert({token->bytes, token});
+  }
+
+ private:
+  /*! \brief A constant scale representing the token search range. */
+  const int match_range_{16};
+  /*! \brief The pool of available storage tokens for each dtype. */
+  std::unordered_map<DataType, std::multimap<int64_t, StorageToken>> available_pool_;
+  /*!
+   * \brief All the storage tokens that have been allocated with actual storage.
+   * NOTE(review): only appended to in the visible code, never read.
+   */
+  std::vector<StorageToken> full_pool_;
+};
+
+/*!
+ * \brief Check if the input op is "relax.reshape".
+ * \param op The op of a call.
+ * \return Whether the op is the registered "relax.reshape" operator.
+ */
+bool IsReshape(const Expr& op) {
+  // Look the op up in the registry only once, like the other
+  // `static const Op&` lookups in this file.
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  return op.same_as(reshape_op);
+}
+
+/*!
+ * \brief The base class for the storage allocation visitor.
+ * \details It maintains the stack of binding blocks being visited. It also
+ * propagates storage tokens through variable bindings, tuples and tuple-item
+ * accesses. Subclasses decide how tokens are created, counted and released,
+ * and may override SetTokens to observe every token assignment.
+ */
+class StorageAllocatorBaseVisitor : public ExprVisitor {
+ protected:
+  using ExprVisitor::VisitExpr_;
+
+  void VisitBindingBlock_(const BindingBlockNode* block) override {
+    // We maintain a block stack for token allocation-site and use-site check.
+    block_stack_.push_back(block);
+    ExprVisitor::VisitBindingBlock_(block);
+    ICHECK(!block_stack_.empty());
+    ICHECK(block_stack_.back() == block);
+    block_stack_.pop_back();
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // Visit the bound value first (presumably dispatched by ExprVisitor to the
+    // value-specific VisitBinding_ overloads), so that its tokens are recorded.
+    ExprVisitor::VisitBinding_(binding);
+    // The binding var has the same tokens as the binding value.
+    SetTokens(binding->var.get(), token_map_[binding->value.get()]);
+  }
+
+  void VisitExpr_(const TupleNode* tuple) final {
+    // A tuple's tokens are the nested collection of its fields' tokens, in field order.
+    Array<Tokens> tokens;
+    tokens.reserve(tuple->fields.size());
+    for (const Expr& field : tuple->fields) {
+      Tokens field_tokens = GetTokens(field);
+      tokens.push_back(field_tokens);
+    }
+    SetTokens(tuple, Tokens(tokens));
+  }
+
+  void VisitExpr_(const TupleGetItemNode* tuple_item) final {
+    Tokens tokens = GetTokens(tuple_item->tuple);
+    // If the tuple has no token, every of its field has no token as well.
+    if (tokens.IsNull()) {
+      // NOTE(review): this writes token_map_ directly rather than calling the virtual
+      // SetTokens. The value is empty, so an override that iterates leaves records nothing.
+      token_map_[tuple_item] = Tokens();
+      return;
+    }
+    ICHECK(tokens.IsNested());
+    Array<Tokens> field_tokens = tokens.NestedArray();
+    ICHECK_GT(static_cast<int>(field_tokens.size()), tuple_item->index);
+    ICHECK_GE(tuple_item->index, 0);
+    SetTokens(tuple_item, field_tokens[tuple_item->index]);
+  }
+
+  /******************** Utilities ********************/
+
+  /*!
+   * \brief Visit the expr, then return the tokens recorded for it.
+   * An empty entry is inserted (and returned) if nothing was recorded.
+   */
+  Tokens GetTokens(const Expr& expr) {
+    this->VisitExpr(expr);
+    return token_map_[expr.get()];
+  }
+
+  /*! \brief Record the tokens of an Expr. Subclasses may override this to track extra state. */
+  virtual void SetTokens(const ExprNode* expr, Tokens tokens) { token_map_[expr] = tokens; }
+
+  /*! \brief The mapping from each Expr to its corresponding storage tokens. */
+  std::unordered_map<const ExprNode*, Tokens> token_map_;
+  /*! \brief The binding block stack. */
+  std::vector<const BindingBlockNode*> block_stack_;
+};
+
+/*!
+ * \brief The visitor class for storage token initialization.
+ * \details It goes through the entire function to get the storage tokens
+ * used by each Expr. After the initialization, we
+ * - know the tokens that each Expr is using,
+ * - know the number of references for each token,
+ * - rule out the builtin alloc_tensors to which the planning does not apply.
+ */
+class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
+ public:
+  explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
+
+  /*!
+   * \brief The entry of the initialization.
+   * \param func The function whose storage tokens are to be initialized.
+   * \return The mapping from each Expr to the token it uses.
+   */
+  std::unordered_map<const ExprNode*, Tokens> Initialize(const Function& func) {
+    // Recurse into the function to get its tokens.
+    Tokens body_tokens = GetTokens(func->body);
+    // Discard the tokens used by the function return value, as they are externally referenced.
+    DiscardTokensIn(body_tokens);
+    return this->token_map_;
+  }
+
+ private:
+  using ExprVisitor::VisitExpr_;
+
+  void VisitExpr_(const CallNode* call) final {
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    if (call->op == alloc_tensor_op) {
+      // Create a storage token for builtin alloc_tensor.
+      this->CreateToken(call);
+      return;
+    } else if (IsReshape(call->op)) {
+      // Reuse the input's token for builtin reshape.
+      SetTokens(call, GetTokens(call->args[0]));
+      return;
+    }
+
+    // - Increase the reference counters of the arguments when the callee is
+    //   a PrimFunc of the context module.
+    // - Otherwise, discard the tokens used by the arguments, as there might be
+    //   potential external reference.
+    if (IsPrimFuncGlobalVar(call->op)) {
+      ICHECK(!block_stack_.empty());
+      for (const Expr& arg : call->args) {
+        Tokens tokens = GetTokensWithAllocSiteCheck(arg, block_stack_.back());
+        ForEachLeaf(tokens, [](StorageToken token) { token->ref_counter += 1; });
+      }
+    } else {
+      for (const Expr& arg : call->args) {
+        DiscardTokensIn(GetTokens(arg));
+      }
+    }
+  }
+
+  void VisitExpr_(const IfNode* if_node) final {
+    Tokens cond_tokens = GetTokens(if_node->cond);
+    Tokens then_tokens = GetTokens(if_node->true_branch);
+    Tokens else_tokens = GetTokens(if_node->false_branch);
+    // Discard the tokens used by the condition, then-body and else-body,
+    // as the planning works on block level.
+    DiscardTokensIn(cond_tokens);
+    DiscardTokensIn(then_tokens);
+    DiscardTokensIn(else_tokens);
+  }
+
+  void VisitExpr_(const SeqExprNode* seq) final {
+    for (const BindingBlock& binding_block : seq->blocks) {
+      this->VisitBindingBlock(binding_block);
+    }
+    Tokens body_tokens = GetTokens(seq->body);
+    // Discard the tokens used by the body, as the planning works on block level.
+    DiscardTokensIn(body_tokens);
+  }
+
+  /******************** Utilities ********************/
+
+  /*!
+   * \brief Check if the input op is GlobalVar corresponding to a PrimFunc inside the ctx module.
+   * \param op The op to be checked
+   * \return A boolean indicating if the input op corresponds to a PrimFunc.
+   */
+  bool IsPrimFuncGlobalVar(const Expr& op) {
+    const auto* global_var = op.as<GlobalVarNode>();
+    if (global_var == nullptr) {
+      return false;
+    }
+    auto func_it = ctx_mod_->functions.find(GetRef<GlobalVar>(global_var));
+    if (func_it == ctx_mod_->functions.end()) {
+      return false;
+    }
+    return (*func_it).second->IsInstance<tir::PrimFuncNode>();
+  }
+
+  /*!
+   * \brief Create a storage token for the builtin alloc_tensor call.
+   * \param call The call to be processed.
+   * \return The created token, or an empty Tokens if the shape is symbolic.
+   */
+  Tokens CreateToken(const CallNode* call) {
+    // Sanity checks about
+    // - the call return value is a Tensor;
+    // - the shape of the tensor is known, in the form of ShapeExpr;
+    // - the tensor has known dtype;
+    // - no storage token was created for this call before.
+    // `sinfo` must be checked before it is dereferenced to read its shape.
+    const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
+    ICHECK_NOTNULL(sinfo);
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    ICHECK_NOTNULL(shape);
+    ICHECK(!sinfo->IsUnknownDtype());
+    ICHECK(sinfo->dtype == Downcast<DataTypeImm>(call->args[1])->value);
+    ICHECK(!token_map_.count(call));
+
+    // No support for symbolic shape at this moment.
+    for (const PrimExpr& dim_len : shape->values) {
+      const auto* int_len = dim_len.as<IntImmNode>();
+      if (!int_len) {
+        token_map_[call] = Tokens();
+        return Tokens();
+      }
+    }
+
+    // Create and set token.
+    StorageToken token(shape->values, sinfo->dtype);
+
+    Tokens tokens(token);
+    SetTokens(call, tokens);
+    ICHECK(!block_stack_.empty());
+    token2block_[token.get()] = block_stack_.back();
+    return tokens;
+  }
+
+  /*!
+   * \brief Override the token setter in the base visitor.
+   * For each token, we keep record of all Expr that are using that token.
+   * When we want to discard one token, we use the records to remove the token
+   * from the Expr that are using it.
+   */
+  void SetTokens(const ExprNode* expr, Tokens tokens) final {
+    StorageAllocatorBaseVisitor::SetTokens(expr, tokens);
+    ForEachLeaf(tokens, [this, expr](StorageToken token) {
+      this->token2exprs_[token.get()].push_back(expr);
+    });
+  }
+
+  /*!
+   * \brief Token getter with allocation site check.
+   * We first get the tokens used by the input Expr, and check if the allocation
+   * site of each token is the input current block.
+   * Since the planning works on block level, if some token's allocation site
+   * is not the current block, we discard the token so that it will not be planned.
+   * \param expr The Expr whose tokens are to be got.
+   * \param cur_block The pointer to the current block.
+   * \return The tokens used by the input Expr, after any discarding.
+   */
+  Tokens GetTokensWithAllocSiteCheck(const Expr& expr, const BindingBlockNode* cur_block) {
+    Tokens tokens = GetTokens(expr);
+    ForEachLeaf(tokens, [this, cur_block](StorageToken token) {
+      auto it = this->token2block_.find(token.get());
+      ICHECK(it != this->token2block_.end());
+      if (it->second != cur_block) {
+        this->DiscardToken(token);
+      }
+    });
+    // Re-read the map, since DiscardToken may have removed tokens from this Expr's entry.
+    return token_map_[expr.get()];
+  }
+
+  /*! \brief Discard the input tokens. */
+  void DiscardTokensIn(Tokens tokens) {
+    ForEachLeaf(tokens, [this](StorageToken token) { this->DiscardToken(token); });
+  }
+
+  /*!
+   * \brief Discard the input token.
+   * For each Expr that is using the input token, remove the token from the Expr's token set.
+   * \param token_to_discard The token to be discarded.
+   */
+  void DiscardToken(StorageToken token_to_discard) {
+    const std::vector<const ExprNode*>& exprs = token2exprs_[token_to_discard.get()];
+    for (const ExprNode* expr : exprs) {
+      token_map_[expr] = MapNestedMsg(token_map_[expr], [token_to_discard](StorageToken token) {
+        return token.same_as(token_to_discard) ? Tokens() : Tokens(token);
+      });
+    }
+    token2exprs_.erase(token_to_discard.get());
+    token2block_.erase(token_to_discard.get());
+  }
+
+  /*!
+   * \brief The context IRModule, used for checking if a callee function is
+   * a PrimFunc inside the IRModule. Held by value (a cheap reference-counted
+   * handle), so the visitor stays valid even if constructed from a temporary.
+   */
+  IRModule ctx_mod_;
+  /*! \brief The mapping from each token to the binding block where it is created. */
+  std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> token2block_;
+  /*! \brief The mapping from each token to the Exprs that are using this token. */
+  std::unordered_map<const StorageTokenNode*, std::vector<const ExprNode*>> token2exprs_;
+};
+
+/*!
+ * \brief The visitor class for storage token allocation planning.
+ * \details
+ * - For each builtin alloc_tensor whose token is not discarded in the
+ *   initialization stage, we request a storage reuse or decide to allocate
+ *   storage for this token, depending on if there is appropriate available
+ *   token in the token pool we maintain.
+ * - For each VM builtin reshape, we reuse the input's tokens.
+ *
+ * After the allocation planning, we
+ * - know the token that each builtin alloc_tensor plans to use. Compared
+ *   with the initialization, here the token is possibly a reuse of some
+ *   previous token, rather than we having one token for each alloc_tensor.
+ * - know the last referenced site for each builtin alloc_tensor. This
+ *   information is used for inserting kill_tensor in the rewrite stage.
+ * - know the tokens allocated in each binding block. This information
+ *   is used for inserting kill_storage in the rewrite stage.
+ */
+class StorageAllocator : public StorageAllocatorBaseVisitor {
+ public:
+  /*!
+   * \param token_map The Expr-to-tokens mapping produced by the initialization
+   * stage, whose tokens carry the reference counts decremented here.
+   */
+  explicit StorageAllocator(std::unordered_map<const ExprNode*, Tokens> token_map) {
+    this->token_map_ = std::move(token_map);
+  }
+
+  /*!
+   * \brief The mapping from each `builtin.alloc_tensor` to its corresponding
+   * underlying storage token that it is using.
+   */
+  std::unordered_map<const ExprNode*, StorageToken> alloc_tensor2token;
+  /*! \brief The mapping from each Expr to the tensors that need to be killed after it. */
+  std::unordered_map<const ExprNode*, std::vector<Var>> expr2killed_tensors;
+  /*! \brief The mapping from each binding block to the storage tokens that are created inside. */
+  std::unordered_map<const BindingBlockNode*, std::vector<const StorageTokenNode*>> block2tokens;
+
+ private:
+  using ExprVisitor::VisitBinding_;
+  using ExprVisitor::VisitExpr_;
+
+  void VisitBindingBlock_(const BindingBlockNode* block) final {
+    StorageAllocatorBaseVisitor::VisitBindingBlock_(block);
+    // Sanity check: each token allocated inside the block should not be
+    // referenced by anyone at the end of the block.
+    for (const StorageTokenNode* token : block2tokens[block]) {
+      ICHECK_EQ(token->ref_counter, 0);
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final {
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    if (call->op == alloc_tensor_op) {
+      auto it = token_map_.find(call);
+      ICHECK(it != token_map_.end());
+
+      if (it->second.IsNull()) {
+        // IsNull being true means the token was discarded, and this alloc_tensor
+        // is not considered by the planning.
+        return;
+      }
+      ICHECK(it->second.IsLeaf());
+      StorageToken new_token = this->RequestReuseOrAlloc(it->second.LeafValue());
+
+      // Record that this alloc_tensor is using the token.
+      alloc_tensor2token.insert({call, new_token});
+      token2cur_tensor_[new_token.get()].push_back(binding->var);
+      SetTokens(call, Tokens(new_token));
+      // Record that the token is allocated in the current block.
+      // A reused token may already be recorded for this block, so avoid duplicates.
+      ICHECK(!block_stack_.empty());
+      std::vector<const StorageTokenNode*>& block_tokens = block2tokens[block_stack_.back()];
+      if (std::find(block_tokens.begin(), block_tokens.end(), new_token.get()) ==
+          block_tokens.end()) {
+        block_tokens.push_back(new_token.get());
+      }
+      return;
+    } else if (IsReshape(call->op)) {
+      Tokens tokens = GetTokens(call->args[0]);
+      ICHECK(!tokens.IsNested());
+      if (tokens.IsLeaf()) {
+        // If the input is using a token, record that the reshape uses the token as well.
+        token2cur_tensor_[tokens.LeafValue().get()].push_back(binding->var);
+        SetTokens(call, tokens);
+      } else {
+        ICHECK(token_map_[call].IsNull());
+      }
+      return;
+    }
+
+    // Decrease the reference counter by one for each token that the arguments use.
+    // Check if a token can be released (i.e., has no reference) after decrease.
+    // And release it if so.
+    // Arguments of non-PrimFunc calls had their tokens discarded during
+    // initialization, so in practice only PrimFunc-call arguments carry tokens here,
+    // matching the increments made in StorageAllocatorInit.
+    for (const Expr& arg : call->args) {
+      Tokens tokens = GetTokens(arg);
+      ForEachLeaf(tokens, [this, call](StorageToken token) {
+        ICHECK_GT(token->ref_counter, 0);
+        token->ref_counter -= 1;
+        this->CheckForRelease(token, call);
+      });
+    }
+  }
+
+  /*! \brief Request a storage reuse, or allocate storage if no appropriate storage is reusable. */
+  StorageToken RequestReuseOrAlloc(StorageToken prototype) {
+    Optional<StorageToken> token = allocator_.RequestReuse(prototype);
+    if (!token.defined()) {
+      return allocator_.Alloc(prototype, this->n_storage_++);
+    } else {
+      return token.value();
+    }
+  }
+
+  /*!
+   * \brief Check if a token has no reference and thus can be released. And release it if so.
+   * \param token The token to be checked.
+   * \param release_site The CallNode where the input token is sent for release.
+   * If the token is checked to release here, we keep record of the release site so that
+   * kill_tensor can be inserted here at the rewrite stage.
+   */
+  void CheckForRelease(StorageToken token, const CallNode* release_site) {
+    // Sanity check: the token was allocated before and has non-negative reference.
+    ICHECK_GE(token->storage_id, 0);
+    ICHECK_GE(token->ref_counter, 0);
+
+    if (token->ref_counter == 0) {
+      allocator_.Release(token);
+      auto it = token2cur_tensor_.find(token.get());
+      ICHECK(it != token2cur_tensor_.end());
+      // Record that the tensors that are using this token will be killed
+      // immediately after the release site.
+      std::vector<Var>& killed_tensors = expr2killed_tensors[release_site];
+      killed_tensors.insert(killed_tensors.end(), it->second.begin(), it->second.end());
+      token2cur_tensor_.erase(it);
+    }
+  }
+
+  /*! \brief Number of allocated storages; also the next storage id to hand out. */
+  int n_storage_{0};
+  /*! \brief The 1D memory allocator. */
+  TokenAllocator1D allocator_;
+  /*! \brief The mapping from each token to the tensors that are currently using it. */
+  std::unordered_map<const StorageTokenNode*, std::vector<Var>> token2cur_tensor_;
+};
+
+/*!
+ * \brief The rewriter class based on the token allocation planning.
+ * \details
+ * - For each builtin alloc_tensor that was planned, substitute it with a memory
+ * alloc_tensor. If no memory alloc_storage was created for it before, create one.
+ * - Insert memory kill_tensor at the release site of each tensor.
+ * - Insert memory kill_storage at the end of each binding block, for the tokens allocated in it.
+ */
+class StorageAllocationRewriter : public ExprMutator {
+ public:
+  explicit StorageAllocationRewriter(
+      std::unordered_map<const ExprNode*, StorageToken> alloc_tensor2token,
+      std::unordered_map<const ExprNode*, std::vector<Var>> expr2killed_tensors,
+      std::unordered_map<const BindingBlockNode*, std::vector<const StorageTokenNode*>>
+          block2tokens)
+      : alloc_tensor2token_(std::move(alloc_tensor2token)),
+        expr2killed_tensors_(std::move(expr2killed_tensors)),
+        block2tokens_(std::move(block2tokens)) {}
+
+ private:
+  using ExprMutator::VisitExpr_;
+
+  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) final {
+    builder_->BeginBindingBlock();
+    for (Binding binding : block->bindings) {
+      this->VisitBinding(binding);
+    }
+
+    // Insert `memory.kill_storage` for the storage tokens allocated inside this block.
+    for (const StorageTokenNode* token : block2tokens_[block]) {
+      ICHECK(token->storage.defined());
+      static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage");
+      this->builder_->Emit(Call(mem_kill_storage, {token->storage.value()}), /*name_hint=*/"_");
+    }
+
+    BindingBlock new_block = builder_->EndBlock();
+    return new_block;
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) final {
+    ExprMutator::VisitBinding_(binding);
+
+    // Insert `memory.kill_tensor` for the tensors that need to be killed after this binding.
+    auto it = expr2killed_tensors_.find(binding->value.get());
+    if (it != expr2killed_tensors_.end()) {
+      for (const Var& var : it->second) {
+        static const Op& mem_kill_tensor = Op::Get("relax.memory.kill_tensor");
+        this->builder_->Emit(Call(mem_kill_tensor, {Downcast<Var>(this->VisitExpr(var))}),
+                             /*name_hint=*/"_");
+      }
+    }
+  }
+
+  Expr VisitExpr_(const CallNode* call) final {
+    auto it = alloc_tensor2token_.find(call);
+    if (it != alloc_tensor2token_.end()) {
+      const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
+      ICHECK_NOTNULL(sinfo);
+      ICHECK_NOTNULL(sinfo->shape.as<ShapeExprNode>());
+      PrimValue runtime_device_index = Downcast<PrimValue>(call->args[2]);
+
+      // If the token is visited for the first time, create a storage variable using
+      // `memory.alloc_storage` for it.
+      StorageToken token = it->second;
+      if (!token->storage.defined()) {
+        static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
+        ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
+        PrimValue virtual_device_index = runtime_device_index;
+        std::string storage_scope = "global";
+        DataType dtype = token->dtype;
+        Call alloc_storage(
+            mem_alloc_storage,
+            {std::move(size), virtual_device_index, StringImm(storage_scope), DataTypeImm(dtype)},
+            Attrs());
+        token->storage = builder_->Emit(alloc_storage, "storage");
+      }
+
+      // And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`.
+      static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
+      PrimValue offset = PrimValue::Int64(0);
+      DataType dtype = sinfo->dtype;
+      return Call(mem_alloc_tensor,
+                  {token->storage.value(), offset, sinfo->shape.value(), DataTypeImm(dtype)},
+                  Attrs());
+    }
+
+    return ExprMutator::VisitExpr_(call);
+  }
+
+  /*!
+   * \brief The mapping from each memory-reusable `builtin.alloc_tensor` to
+   * its corresponding underlying storage token that it is using.
+   */
+  std::unordered_map<const ExprNode*, StorageToken> alloc_tensor2token_;
+  /*! \brief The mapping from each Expr to the tensors that need to be killed after it. */
+  std::unordered_map<const ExprNode*, std::vector<Var>> expr2killed_tensors_;
+  /*! \brief The mapping from each binding block to the storage tokens that are created inside. */
+  std::unordered_map<const BindingBlockNode*, std::vector<const StorageTokenNode*>> block2tokens_;
+};
+
+Expr StaticPlanBlockMemory(Function func, const IRModule& ctx_mod) {
+  // Step 1. Initialize: compute the storage tokens of the expressions in `func`.
+  StorageAllocatorInit initializer(ctx_mod);
+  std::unordered_map<const ExprNode*, Tokens> token_map = initializer.Initialize(func);
+  // Step 2. Collect the memory allocation info (token reuse, release sites, per-block tokens).
+  StorageAllocator allocator(std::move(token_map));
+  allocator(func);
+  // Step 3. Rewrite the function with `memory.*` allocation / kill operators.
+  StorageAllocationRewriter rewriter(std::move(allocator.alloc_tensor2token),
+                                     std::move(allocator.expr2killed_tensors),
+                                     std::move(allocator.block2tokens));
+  func = Downcast<Function>(rewriter(func));
+  return func;
+}
+
+namespace transform {
+
+Pass StaticPlanBlockMemory() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(relax::StaticPlanBlockMemory(std::move(f), m));
+      };
+  return CreateFunctionPass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory").set_body_typed(StaticPlanBlockMemory);
+
+}  // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_analysis_estimate_memory_usage.py b/tests/python/relax/test_analysis_estimate_memory_usage.py
new file mode 100644
index 0000000000..3e6ba4499f
--- /dev/null
+++ b/tests/python/relax/test_analysis_estimate_memory_usage.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import relax as R, tir as T
+from tvm.relax.analysis import estimate_memory_usage
+
+
+def test_basic():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(
+            rxplaceholder: T.Buffer(T.int64(8), "float32"),
+            rxplaceholder_1: T.Buffer((), "float32"),
+            T_add: T.Buffer(T.int64(8), "float32"),
+        ):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(
+            rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"),
+            T_reshape: T.Buffer(T.int64(8), "float32"),
+        ):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(
+            rxplaceholder: T.Buffer(T.int64(8), "float32"), compute: T.Buffer(T.int64(8), "float32")
+        ):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(
+            rxplaceholder: T.Buffer(T.int64(10), "float32"),
+            compute: T.Buffer(T.int64(10), "float32"),
+        ):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(
+            rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"),
+            compute: T.Buffer((T.int64(2), T.int64(4)), "float32"),
+        ):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(
+            rxplaceholder: T.Buffer(T.int64(8), "float32"),
+            PadInput: T.Buffer(T.int64(10), "float32"),
+        ):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([32]), virtual_device_index=0, storage_scope="global", dtype="float32"
+            )
+            alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(
+                storage, offset=0, shape=R.shape([2, 4]), dtype="float32"
+            )
+            _: R.Tuple() = exp(x, alloc)
+            lv: R.Tensor((2, 4), dtype="float32") = alloc
+            lv1: R.Tensor((8,), dtype="float32") = R.call_packed(
+                "vm.builtin.reshape", lv, R.shape([8]), sinfo_args=[R.Tensor((8,), dtype="float32")]
+            )
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([40]), virtual_device_index=0, storage_scope="global", dtype="float32"
+            )
+            alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+                storage1, offset=0, shape=R.shape([8]), dtype="float32"
+            )
+            _1: R.Tuple() = relu(lv1, alloc1)
+            _2: R.Tuple() = R.memory.kill_tensor(alloc)
+            _3: R.Tuple() = R.memory.kill_tensor(lv1)
+            lv2: R.Tensor((8,), dtype="float32") = alloc1
+            alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+                storage, offset=0, shape=R.shape([8]), dtype="float32"
+            )
+            _4: R.Tuple() = add(lv2, R.const(1, "float32"), alloc2)
+            _5: R.Tuple() = R.memory.kill_tensor(alloc1)
+            lv3: R.Tensor((8,), dtype="float32") = alloc2
+            alloc3: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor(
+                storage1, offset=0, shape=R.shape([10]), dtype="float32"
+            )
+            _6: R.Tuple() = pad(lv3, alloc3)
+            _7: R.Tuple() = R.memory.kill_tensor(alloc2)
+            lv4: R.Tensor((10,), dtype="float32") = alloc3
+            alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([10]), dtype="float32", runtime_device_index=0
+            )
+            _8: R.Tuple() = log(lv4, alloc4)
+            _9: R.Tuple() = R.memory.kill_tensor(alloc3)
+            gv5: R.Tensor((10,), dtype="float32") = alloc4
+            _11: R.Tuple() = R.memory.kill_storage(storage)
+            _10: R.Tuple() = R.memory.kill_storage(storage1)
+            return gv5
+
+    assert (
+        estimate_memory_usage(Module)
+        == r"""Memory usage estimation:
+- Function main:
+ * Without memory planning, there are 5 constant-size memory allocation(s) with total size 1.639e-07 GB.
+ * With memory planning, there are 2 constant-size memory allocation(s) with total size 6.706e-08 GB.
+ * Memory planning reduces constant memory size to 40.9%."""
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()  # run the tests defined in this file
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
new file mode 100644
index 0000000000..f11df58b26
--- /dev/null
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -0,0 +1,612 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R, tir as T
+
+
+def test_basic():
+    # fmt: off
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(rxplaceholder: T.Buffer(T.int64(8), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer(T.int64(8), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer(T.int64(8), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.Buffer(T.int64(8), "float32"), compute: T.Buffer(T.int64(8), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.Buffer(T.int64(10), "float32"), compute: T.Buffer(T.int64(10), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.Buffer(T.int64(8), "float32"), PadInput: T.Buffer(T.int64(10), "float32")):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            alloc: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = exp(x, alloc)
+            lv: R.Tensor((2, 4), dtype="float32") = alloc
+            lv1: R.Tensor((8,), dtype="float32") = R.reshape(lv, (8,))
+            alloc1: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor(R.shape([8]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = relu(lv1, alloc1)
+            lv2: R.Tensor((8,), dtype="float32") = alloc1
+            alloc2: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor(R.shape([8]), dtype="float32", runtime_device_index=0)
+            _2: R.Tuple() = add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((8,), dtype="float32") = alloc2
+            alloc3: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+            _3: R.Tuple() = pad(lv3, alloc3)
+            lv4: R.Tensor((10,), dtype="float32") = alloc3
+            alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+            _4: R.Tuple() = log(lv4, alloc4)
+            gv: R.Tensor((10,), dtype="float32") = alloc4
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.Buffer(T.int64(8), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer(T.int64(8), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer(T.int64(8), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.Buffer(T.int64(8), "float32"), compute: T.Buffer(T.int64(8), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.Buffer(T.int64(10), "float32"), compute: T.Buffer(T.int64(10), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.Buffer(T.int64(8), "float32"), PadInput: T.Buffer(T.int64(10), "float32")):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            storage: R.Object = R.memory.alloc_storage(R.shape([32]), virtual_device_index=0, storage_scope="global", dtype="float32")
+            alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), dtype="float32")
+            _: R.Tuple() = exp(x, alloc)
+            lv: R.Tensor((2, 4), dtype="float32") = alloc
+            lv1: R.Tensor((8,), dtype="float32") = R.reshape(lv, (8,))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([40]), virtual_device_index=0, storage_scope="global", dtype="float32")
+            alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage1, 0, R.shape([8]), dtype="float32")
+            _1: R.Tuple() = relu(lv1, alloc1)
+            _2: R.Tuple() = R.memory.kill_tensor(alloc)
+            _3: R.Tuple() = R.memory.kill_tensor(lv1)
+            lv2: R.Tensor((8,), dtype="float32") = alloc1
+            alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([8]), dtype="float32")
+            _4: R.Tuple() = add(lv2, R.const(1, "float32"), alloc2)
+            _5: R.Tuple() = R.memory.kill_tensor(alloc1)
+            lv3: R.Tensor((8,), dtype="float32") = alloc2
+            alloc3: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor(storage1, 0, R.shape([10]), dtype="float32")
+            _6: R.Tuple() = pad(lv3, alloc3)
+            _7: R.Tuple() = R.memory.kill_tensor(alloc2)
+            lv4: R.Tensor((10,), dtype="float32") = alloc3
+            alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+            _8: R.Tuple() = log(lv4, alloc4)
+            _9: R.Tuple() = R.memory.kill_tensor(alloc3)
+            gv5: R.Tensor((10,), dtype="float32") = alloc4
+            _11: R.Tuple() = R.memory.kill_storage(storage)
+            _10: R.Tuple() = R.memory.kill_storage(storage1)
+            return gv5
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_different_dtype():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(
+            A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+        ):
+            T.evaluate(0)
+
+        @T.prim_func
+        def add1(
+            A: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+        ):
+            T.evaluate(0)
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _: R.Tuple() = add(x, x, alloc)
+            gv: R.Tensor((2, 3), dtype="float32") = alloc
+            alloc1: R.Tensor((2, 3), dtype="int32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="int32", runtime_device_index=0
+            )
+            _1: R.Tuple() = add1(y, y, alloc1)
+            gv1: R.Tensor((2, 3), dtype="int32") = alloc1
+            return x
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def add(
+            A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+        ):
+            T.evaluate(0)
+
+        @T.prim_func
+        def add1(
+            A: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+        ):
+            T.evaluate(0)
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
+            )
+            alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _: R.Tuple() = add(x, x, alloc)
+            _1: R.Tuple() = R.memory.kill_tensor(alloc)
+            gv1: R.Tensor((2, 3), dtype="float32") = alloc
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="int32"
+            )
+            alloc1: R.Tensor((2, 3), dtype="int32") = R.memory.alloc_tensor(
+                storage1, 0, R.shape([2, 3]), dtype="int32"
+            )
+            _2: R.Tuple() = add1(y, y, alloc1)
+            _3: R.Tuple() = R.memory.kill_tensor(alloc1)
+            gv12: R.Tensor((2, 3), dtype="int32") = alloc1
+            _5: R.Tuple() = R.memory.kill_storage(storage)
+            _4: R.Tuple() = R.memory.kill_storage(storage1)
+            return x
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_same_dtype():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(
+            A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+        ):
+            T.evaluate(0)
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _: R.Tuple() = add(x, x, alloc)
+            gv: R.Tensor((2, 3), dtype="float32") = alloc
+            alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _1: R.Tuple() = add(y, y, alloc1)
+            gv1: R.Tensor((2, 3), dtype="float32") = alloc1
+            return x
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def add(
+            A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+        ):
+            T.evaluate(0)
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
+            )
+            alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _: R.Tuple() = add(x, x, alloc)
+            _1: R.Tuple() = R.memory.kill_tensor(alloc)
+            gv1: R.Tensor((2, 3), dtype="float32") = alloc
+            alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _2: R.Tuple() = add(y, y, alloc1)
+            _3: R.Tuple() = R.memory.kill_tensor(alloc1)
+            gv12: R.Tensor((2, 3), dtype="float32") = alloc1
+            _4: R.Tuple() = R.memory.kill_storage(storage)
+            return x
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_if_cond():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def all_less_than_zero(A: T.Buffer((2, 3), "float32"), B: T.Buffer((), "bool")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            alloc: R.Tensor((), dtype="bool") = R.builtin.alloc_tensor(
+                R.shape([]), dtype="bool", runtime_device_index=0
+            )
+            _: R.Tuple() = all_less_than_zero(x, alloc)
+            x1: R.Tensor((), dtype="bool") = alloc
+            if x1:
+                y: R.Tensor((2, 3), dtype="float32") = x
+            else:
+                alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                    R.shape([2, 3]), dtype="float32", runtime_device_index=0
+                )
+                _1: R.Tuple() = exp(x, alloc1)
+                gv3: R.Tensor((2, 3), dtype="float32") = alloc1
+                y: R.Tensor((2, 3), dtype="float32") = gv3
+            return x
+
+    # The pass makes no change.
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Module)
+
+
+def test_if_then_else():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")):
+            T.evaluate(0)
+
+        @R.function
+        def main(
+            cond: R.Tensor((), dtype="bool"), x: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _: R.Tuple() = exp(x, alloc)
+            y: R.Tensor((2, 3), dtype="float32") = alloc
+            if cond:
+                z: R.Tensor((2, 3), dtype="float32") = y
+            else:
+                z: R.Tensor((2, 3), dtype="float32") = y
+            return x
+
+    # The pass makes no change.
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Module)
+
+
+def test_cross_block_use():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")):
+            T.evaluate(0)
+
+        @R.function
+        def main(
+            cond: R.Tensor((), dtype="bool"), x: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _: R.Tuple() = exp(x, alloc)
+            y: R.Tensor((2, 3), dtype="float32") = alloc
+            if cond:
+                alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                    R.shape([2, 3]), dtype="float32", runtime_device_index=0
+                )
+                _1: R.Tuple() = exp(y, alloc1)
+                y2: R.Tensor((2, 3), dtype="float32") = alloc1
+                z: R.Tensor((2, 3), dtype="float32") = y2
+            else:
+                alloc2: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                    R.shape([2, 3]), dtype="float32", runtime_device_index=0
+                )
+                _2: R.Tuple() = exp(y, alloc2)
+                y2: R.Tensor((2, 3), dtype="float32") = alloc2
+                z: R.Tensor((2, 3), dtype="float32") = y2
+            return x
+
+    # The pass makes no change.
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Module)
+
+
+def test_nested_tuple():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _: R.Tuple() = exp(x, alloc)
+            y1: R.Tensor((2, 3), dtype="float32") = alloc
+            alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _1: R.Tuple() = exp(x, alloc1)
+            y2: R.Tensor((2, 3), dtype="float32") = alloc1
+            alloc2: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _2: R.Tuple() = exp(x, alloc2)
+            y3: R.Tensor((2, 3), dtype="float32") = alloc2
+            t: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")) = (
+                y1,
+                y2,
+            )
+            nt: R.Tuple(
+                R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")),
+                R.Tensor((2, 3), dtype="float32"),
+            ) = (t, y3)
+            nt0: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")) = nt[
+                0
+            ]
+            y1_: R.Tensor((2, 3), dtype="float32") = nt0[0]
+            y2_: R.Tensor((2, 3), dtype="float32") = nt0[1]
+            y3_: R.Tensor((2, 3), dtype="float32") = nt[1]
+            alloc3: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _3: R.Tuple() = exp(y1_, alloc3)
+            z1: R.Tensor((2, 3), dtype="float32") = alloc3
+            alloc4: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _4: R.Tuple() = exp(y2_, alloc4)
+            z2: R.Tensor((2, 3), dtype="float32") = alloc4
+            alloc5: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _5: R.Tuple() = exp(y3_, alloc5)
+            z3: R.Tensor((2, 3), dtype="float32") = alloc5
+            return x
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
+            )
+            alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _: R.Tuple() = exp(x, alloc)
+            y1: R.Tensor((2, 3), dtype="float32") = alloc
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
+            )
+            alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage1, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _1: R.Tuple() = exp(x, alloc1)
+            y2: R.Tensor((2, 3), dtype="float32") = alloc1
+            storage2: R.Object = R.memory.alloc_storage(
+                R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
+            )
+            alloc2: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage2, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _2: R.Tuple() = exp(x, alloc2)
+            y3: R.Tensor((2, 3), dtype="float32") = alloc2
+            t: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")) = (
+                y1,
+                y2,
+            )
+            nt: R.Tuple(
+                R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")),
+                R.Tensor((2, 3), dtype="float32"),
+            ) = (t, y3)
+            nt0: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")) = nt[
+                0
+            ]
+            y1_: R.Tensor((2, 3), dtype="float32") = nt0[0]
+            y2_: R.Tensor((2, 3), dtype="float32") = nt0[1]
+            y3_: R.Tensor((2, 3), dtype="float32") = nt[1]
+            storage3: R.Object = R.memory.alloc_storage(
+                R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
+            )
+            alloc3: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage3, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _3: R.Tuple() = exp(y1_, alloc3)
+            _4: R.Tuple() = R.memory.kill_tensor(alloc)
+            _11: R.Tuple() = R.memory.kill_tensor(alloc3)
+            z1: R.Tensor((2, 3), dtype="float32") = alloc3
+            alloc4: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _41: R.Tuple() = exp(y2_, alloc4)
+            _21: R.Tuple() = R.memory.kill_tensor(alloc1)
+            _31: R.Tuple() = R.memory.kill_tensor(alloc4)
+            z2: R.Tensor((2, 3), dtype="float32") = alloc4
+            alloc5: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage3, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _5: R.Tuple() = exp(y3_, alloc5)
+            _42: R.Tuple() = R.memory.kill_tensor(alloc2)
+            _51: R.Tuple() = R.memory.kill_tensor(alloc5)
+            z3: R.Tensor((2, 3), dtype="float32") = alloc5
+            _9: R.Tuple() = R.memory.kill_storage(storage)
+            _7: R.Tuple() = R.memory.kill_storage(storage1)
+            _8: R.Tuple() = R.memory.kill_storage(storage2)
+            _6: R.Tuple() = R.memory.kill_storage(storage3)
+            return x
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_call_func_other_than_primfunc():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")):
+            alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            _ = R.add(x, alloc)
+            y: R.Tensor((2, 3), dtype="float32") = alloc
+            return x
+
+    # The pass makes no change when the allocated tensor is passed to a non-PrimFunc op (R.add).
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Module)
+
+
+def test_symbolic_shape():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def exp(var_A: T.handle, var_B: T.handle):
+            m = T.var("int64")
+            n = T.var("int64")
+            A = T.match_buffer(var_A, (m, n), "float32")
+            B = T.match_buffer(var_B, (m, n), "float32")
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")):
+            m = T.var("int64")
+            n = T.var("int64")
+            alloc: R.Tensor((m, n), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([m, n]), dtype="float32", runtime_device_index=0
+            )
+            _ = exp(x, alloc)
+            y: R.Tensor((m, n), dtype="float32") = alloc
+            return x
+
+    # The pass makes no change when the allocation shape is symbolic (m, n).
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Module)
+
+
+def test_zero_reference():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")):
+            alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
+            )
+            return x
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")):
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
+            )
+            alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([2, 3]), dtype="float32"
+            )
+            _: R.Tuple() = R.memory.kill_storage(storage)
+            return x
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_reshape_param():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(
+            A: T.Buffer((T.int64(2), T.int64(25), T.int64(2)), "float32"),
+            B: T.Buffer((T.int64(2), T.int64(25), T.int64(2)), "float32"),
+            C: T.Buffer((T.int64(2), T.int64(25), T.int64(2)), "float32"),
+        ):
+            T.evaluate(0)
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 50), dtype="float32"), y: R.Tensor((100,), dtype="float32")
+        ) -> R.Tensor((2, 25, 2), dtype="float32"):
+            lv: R.Tensor((2, 25, 2), dtype="float32") = R.reshape(x, (2, 25, 2))
+            lv1: R.Tensor((2, 25, 2), dtype="float32") = R.reshape(y, (2, 25, 2))
+            alloc: R.Tensor((2, 25, 2), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([2, 25, 2]), dtype="float32", runtime_device_index=0
+            )
+            _: R.Tuple() = add(lv, lv1, alloc)
+            gv: R.Tensor((2, 25, 2), dtype="float32") = alloc
+            return gv
+
+    # The pass makes no change.
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Module)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()  # run the tests defined in this file