This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new cf14eddebb [Unity][Transform] Memory planning for dynamic-shape func
return (#16111)
cf14eddebb is described below
commit cf14eddebb2b4121792a3449cf069cc2cd0c7e69
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Jan 14 22:38:03 2024 -0500
[Unity][Transform] Memory planning for dynamic-shape func return (#16111)
This PR enhances the static block memory planning pass.
Prior to this PR, the memory planning only worked on memory
allocations that are not externally referenced. In dynamic-shape
settings, externally referenced allocations are not fully static
and may lead to memory fragmentation.
This PR enhances the behavior, so that for such memory
allocation, we first allocate a storage with regard to its
estimated upper bound (when known), and then allocate the
tensor with the actual dynamic shape out from the storage.
This will ensure the static memory allocation and avoid
memory fragmentation.
---
src/relax/transform/static_plan_block_memory.cc | 183 ++++++++++++++-------
.../test_transform_static_plan_block_memory.py | 62 +++++++
2 files changed, 190 insertions(+), 55 deletions(-)
diff --git a/src/relax/transform/static_plan_block_memory.cc
b/src/relax/transform/static_plan_block_memory.cc
index 4ac41d92c0..e8e65c6ecc 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -296,6 +296,82 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
std::vector<const BindingBlockNode*> block_stack_;
};
+/*!
+ * \brief Set the upper bound of the TIR variables that appear in
+ * the input function signature in the analyzer.
+ * \param func The function to be analyzed.
+ * \param ana The analyzer which contains the TIR var upper bounds.
+ */
+void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
+ // Use the attribute-annotated TIR var upper bounds as the TIR var values for
+ // memory planning.
+ // NOTE: we only apply the annotated upper bounds to the TIR variables that
+ // appear in the **function signature**.
+ Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
+ func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
+ .value_or(Map<ObjectRef, ObjectRef>());
+ std::unordered_map<String, IntImm> var_upper_bound_attr;
+ // We manually check the value type to ensure the values are all positive
IntImm.
+ for (auto it : var_upper_bound_attr_raw) {
+ const auto* key = it.first.as<StringObj>();
+ const auto* value = it.second.as<IntImmNode>();
+ CHECK(key != nullptr)
+ << "The entry key of attr `tir_var_upper_bound` should be string.
However "
+ << it.first->GetTypeKey() << " is got.";
+ CHECK(value != nullptr)
+ << "The entry value of attr `tir_var_upper_bound` should be integer.
However "
+ << it.second->GetTypeKey() << " is got.";
+ CHECK_GT(value->value, 0)
+ << "The entry value of attr `tir_var_upper_bound` should be a positive
integer, while "
+ << value->value << " is got.";
+ var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
+ }
+ Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
+ for (const tir::Var& tir_var : var_in_signature) {
+ auto it = var_upper_bound_attr.find(tir_var->name_hint);
+ if (it != var_upper_bound_attr.end()) {
+ ana->Bind(tir_var,
+ tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
+ tvm::IntImm(DataType::Int(64),
(*it).second->value + 1)));
+ }
+ }
+}
+
+/*!
+ * \brief Use the upper bounds of TIR vars to compute the upper
+ * bound of a given shape.
+ * \param shape The input shape to be computed.
+ * \param ana The arithmetic analyzer that contains the upper bounds
+ * of TIR variables
+ * \return The upper-bounded shape. When a dimension's upper bound
+ * cannot be determined, we keep the dimension unchanged.
+ */
+Array<PrimExpr> GetUpperBoundShape(Array<PrimExpr> shape, arith::Analyzer*
ana) {
+ // Use the upper bounds of TIR vars as their values.
+ Array<PrimExpr> upper_bounded_shape;
+ upper_bounded_shape.reserve(shape.size());
+ for (const PrimExpr& dim_len : shape) {
+ int64_t max_bound = ana->const_int_bound(dim_len)->max_value;
+ if (max_bound == std::numeric_limits<int64_t>::max()) {
+ upper_bounded_shape.push_back(dim_len);
+ } else {
+ upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
+ }
+ }
+ return upper_bounded_shape;
+}
+
+/*! \brief Check if a shape is static (a.k.a., has no TIR variable). */
+bool IsStaticShape(Array<PrimExpr> shape) {
+ for (const PrimExpr& dim : shape) {
+ const auto* int_len = dim.as<IntImmNode>();
+ if (!int_len) {
+ return false;
+ }
+ }
+ return true;
+}
+
/*!
* \brief The visitor class for storage token initialization.
* \details It goes through the entire function to get the storage tokens
@@ -330,40 +406,8 @@ class StorageAllocatorInit : public
StorageAllocatorBaseVisitor {
explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
void VisitExpr_(const FunctionNode* func) final {
- // Use the attribute-annotated TIR var upper bounds as the TIR var values
for
- // memory planning.
- // NOTE: we only apply the annotated upper bounds to the TIR variables that
- // appear in the **function signature**.
- Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
- func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
- .value_or(Map<ObjectRef, ObjectRef>());
- std::unordered_map<String, IntImm> var_upper_bound_attr;
- // We manually check the value type to ensure the values are all positive
IntImm.
- for (auto it : var_upper_bound_attr_raw) {
- const auto* key = it.first.as<StringObj>();
- const auto* value = it.second.as<IntImmNode>();
- CHECK(key != nullptr)
- << "The entry key of attr `tir_var_upper_bound` should be string.
However "
- << it.first->GetTypeKey() << " is got.";
- CHECK(value != nullptr)
- << "The entry value of attr `tir_var_upper_bound` should be integer.
However "
- << it.second->GetTypeKey() << " is got.";
- CHECK_GT(value->value, 0)
- << "The entry value of attr `tir_var_upper_bound` should be a
positive integer, while "
- << value->value << " is got.";
- var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
- }
- Array<tir::Var> var_in_signature =
TIRVarsInStructInfo(GetStructInfo(GetRef<Function>(func)));
- var_upper_bound_.clear();
- for (const tir::Var& tir_var : var_in_signature) {
- auto it = var_upper_bound_attr.find(tir_var->name_hint);
- if (it != var_upper_bound_attr.end()) {
- ana_.Bind(tir_var, tvm::Range::FromMinExtent(
- tvm::IntImm(DataType::Int(64), 0),
- tvm::IntImm(DataType::Int(64),
(*it).second->value + 1)));
- }
- }
-
+ // Set the upper bound of TIR variables in the analyzer.
+ SetTIRVarUpperBound(GetRef<Function>(func), &ana_);
// Recurse into the function to get its tokens.
Tokens body_tokens = GetTokens(func->body);
// Discard the tokens used by the function return value, as they are
external referenced.
@@ -457,32 +501,20 @@ class StorageAllocatorInit : public
StorageAllocatorBaseVisitor {
// - the tensor has known dtype;
// - no storage token was created for this call before.
const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
- const auto* shape = sinfo->shape.as<ShapeExprNode>();
ICHECK_NOTNULL(sinfo);
+ const auto* shape = sinfo->shape.as<ShapeExprNode>();
ICHECK_NOTNULL(shape);
ICHECK(!sinfo->IsUnknownDtype());
ICHECK(sinfo->dtype == Downcast<DataTypeImm>(call->args[1])->value);
ICHECK(!token_map_.count(call));
// Use the upper bounds of TIR vars as their values.
- Array<PrimExpr> upper_bounded_shape;
- upper_bounded_shape.reserve(shape->values.size());
- for (const PrimExpr& dim_len : shape->values) {
- int64_t max_bound = ana_.const_int_bound(dim_len)->max_value;
- if (max_bound == std::numeric_limits<int64_t>::max()) {
- upper_bounded_shape.push_back(dim_len);
- } else {
- upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64),
max_bound));
- }
- }
+ Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values,
&ana_);
// No support for TIR vars that are not bounded.
- for (const PrimExpr& dim_len : upper_bounded_shape) {
- const auto* int_len = dim_len.as<IntImmNode>();
- if (!int_len) {
- token_map_[call] = Tokens();
- return Tokens();
- }
+ if (!IsStaticShape(upper_bounded_shape)) {
+ token_map_[call] = Tokens();
+ return Tokens();
}
// Create and set token.
@@ -558,8 +590,6 @@ class StorageAllocatorInit : public
StorageAllocatorBaseVisitor {
* a PrimFunc inside the IRModule.
*/
const IRModule& ctx_mod_;
- /*! \brief The mapping from TIR variables to their respective upper bound
values. */
- std::unordered_map<tir::Var, IntImm, ObjectPtrHash, ObjectPtrEqual>
var_upper_bound_;
/*! \brief The mapping from each token to the binding block where it is
created. */
std::unordered_map<const StorageTokenNode*, const BindingBlockNode*>
token2block_;
/*! \brief The mapping from each token to the Exprs that are using this
token. */
@@ -729,8 +759,17 @@ class StorageAllocationRewriter : public ExprMutator {
if (func_ == nullptr) {
continue;
}
+ constexpr static const char* plan_dyn_attr_ =
"relax.memory_plan_dynamic_func_output";
+ plan_dynamic_output_ = static_cast<bool>(
+
func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32),
0))->value);
+ if (plan_dynamic_output_) {
+ SetTIRVarUpperBound(GetRef<Function>(func_), &ana_);
+ }
token2storage_var_.clear();
Function func = Downcast<Function>(this->VisitExpr_(func_));
+ if (plan_dynamic_output_) {
+ func = WithoutAttr(func, plan_dyn_attr_);
+ }
builder_->UpdateFunction(gv, func);
}
return builder_->GetContextIRModule();
@@ -740,8 +779,13 @@ class StorageAllocationRewriter : public ExprMutator {
using ExprMutator::VisitExpr_;
Expr VisitExpr_(const CallNode* call) final {
+ static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+ static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
+ static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
auto it = alloc_tensor2token_.find(call);
if (it != alloc_tensor2token_.end()) {
+ // Case 1. This `alloc_tensor` is planned for memory reuse.
+ ICHECK_EQ(call->op, alloc_tensor_op);
const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
ICHECK_NOTNULL(sinfo);
ICHECK_NOTNULL(sinfo->shape.as<ShapeExprNode>());
@@ -753,7 +797,6 @@ class StorageAllocationRewriter : public ExprMutator {
Var storage_var{nullptr};
auto it_token = token2storage_var_.find(token.get());
if (it_token == token2storage_var_.end()) {
- static const Op& mem_alloc_storage =
Op::Get("relax.memory.alloc_storage");
ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
PrimValue virtual_device_index = runtime_device_index;
std::string storage_scope = "global";
@@ -769,16 +812,46 @@ class StorageAllocationRewriter : public ExprMutator {
}
// And always create a `memory.alloc_tensor` for the old
`builtin.alloc_tensor`.
- static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
PrimValue offset = PrimValue::Int64(0);
DataType dtype = sinfo->dtype;
return Call(mem_alloc_tensor, {storage_var, offset,
sinfo->shape.value(), DataTypeImm(dtype)},
Attrs());
+ } else if (plan_dynamic_output_ && call->op == alloc_tensor_op) {
+ // Case 2. For a `alloc_tensor` that is not planned for memory reuse,
+ // we would still like to allocate **static** memory for the tensor.
+ // So in case the tensor shape is dynamic but has an upper bound
+ // estimation, we allocate a storage to its upper bound size, and
+ // allocate a tensor out from it with the actual symbolic shape.
+
+ const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
+ ICHECK_NOTNULL(sinfo);
+ const auto* shape = sinfo->shape.as<ShapeExprNode>();
+ ICHECK_NOTNULL(shape);
+ Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values,
&ana_);
+ if (!IsStaticShape(shape->values) && IsStaticShape(upper_bounded_shape))
{
+ ICHECK(!sinfo->IsUnknownDtype());
+ ICHECK_EQ(sinfo->dtype, Downcast<DataTypeImm>(call->args[1])->value);
+ StorageToken token(upper_bounded_shape, sinfo->dtype);
+ Call alloc_storage(mem_alloc_storage,
+ {/*size=*/ShapeExpr({tvm::IntImm(DataType::Int(64),
token->bytes)}),
+
/*virtual_device_index=*/Downcast<PrimValue>(call->args[2]),
+ /*storage_scope=*/StringImm("global"), //
+ /*dtype=*/DataTypeImm(token->dtype)});
+ Var storage = builder_->Emit(alloc_storage, "storage");
+ return Call(mem_alloc_tensor, {storage, //
+ /*offset=*/PrimValue::Int64(0),
+ /*shape=*/GetRef<ShapeExpr>(shape), //
+ /*dtype=*/DataTypeImm(sinfo->dtype)});
+ }
}
return ExprMutator::VisitExpr_(call);
}
+ /*! \brief The arithmetic analyzer. */
+ arith::Analyzer ana_;
+ /*! \brief A boolean indicating whether to plan dynamic-shape function
output tensors. */
+ bool plan_dynamic_output_;
/*!
* \brief The mapping from each memory-reusable `builtin.alloc_tensor` to
its corresponding underlying storage token that it is using.
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 9d1fe4fd40..783f18ee98 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1109,6 +1109,68 @@ def test_call_tir_dyn():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_call_tir_dyn_plan_dynamic_func_output():
+ # With "tir_var_upper_bound" n=20 and "relax.memory_plan_dynamic_func_output"
+ # set, each dynamic (n,) float32 alloc_tensor is expected to be rewritten into
+ # a static 80-byte (20 * 4) memory.alloc_storage plus a memory.alloc_tensor
+ # that keeps the symbolic shape [n]; the planning attribute is then removed.
+ # fmt: off
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def tir_full(var_full: T.handle, n: T.int64):
+ T.evaluate(0)
+
+ @T.prim_func
+ def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+ T.evaluate(0)
+
+ @R.function
+ def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
+ n = T.int64()
+ R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure":
True, "relax.memory_plan_dynamic_func_output": True})
+ cls = Module
+ alloc: R.Tensor((n,), dtype="float32") =
R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+ _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
+ full: R.Tensor((n,), dtype="float32") = alloc
+ alloc1: R.Tensor((n,), dtype="float32") =
R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+ _1: R.Tuple = cls.tir_exp(full, alloc1)
+ lv2: R.Tensor((n,), dtype="float32") = alloc1
+ alloc2: R.Tensor((n,), dtype="float32") =
R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+ _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+ lv3: R.Tensor((n,), dtype="float32") = alloc2
+ return lv3
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def tir_full(var_full: T.handle, n: T.int64):
+ T.evaluate(0)
+
+ @R.function
+ def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
+ n = T.int64()
+ R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure":
True})
+ cls = Expected
+ storage: R.Object = R.memory.alloc_storage(R.shape([80]),
R.prim_value(0), R.str("global"), R.dtype("float32"))
+ alloc: R.Tensor((n,), dtype="float32") =
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]),
R.dtype("float32"))
+ _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
+ full: R.Tensor((n,), dtype="float32") = alloc
+ storage1: R.Object = R.memory.alloc_storage(R.shape([80]),
R.prim_value(0), R.str("global"), R.dtype("float32"))
+ alloc1: R.Tensor((n,), dtype="float32") =
R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n]),
R.dtype("float32"))
+ _1: R.Tuple = cls.tir_exp(full, alloc1)
+ lv2: R.Tensor((n,), dtype="float32") = alloc1
+ storage2: R.Object = R.memory.alloc_storage(R.shape([80]),
R.prim_value(0), R.str("global"), R.dtype("float32"))
+ alloc2: R.Tensor((n,), dtype="float32") =
R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n]),
R.dtype("float32"))
+ _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+ lv3: R.Tensor((n,), dtype="float32") = alloc2
+ return lv3
+ # fmt: on
+
+ mod = relax.transform.StaticPlanBlockMemory()(Module)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_function_independence():
# fmt: off
@tvm.script.ir_module