This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cf14eddebb [Unity][Transform] Memory planning for dynamic-shape func 
return (#16111)
cf14eddebb is described below

commit cf14eddebb2b4121792a3449cf069cc2cd0c7e69
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Jan 14 22:38:03 2024 -0500

    [Unity][Transform] Memory planning for dynamic-shape func return (#16111)
    
    This PR enhances the static block memory planning pass.
    Prior to this PR, memory planning only applied to memory
    allocations that are not externally referenced. In dynamic
    shape settings, the remaining (externally referenced)
    allocations are not fully static and may lead to memory
    fragmentation.
    
    This PR enhances the behavior so that, for such an
    allocation, we first allocate a storage sized to the
    allocation's estimated upper bound (when known), and then
    allocate the tensor with its actual dynamic shape out of
    that storage. This ensures static memory allocation and
    avoids memory fragmentation.
---
 src/relax/transform/static_plan_block_memory.cc    | 183 ++++++++++++++-------
 .../test_transform_static_plan_block_memory.py     |  62 +++++++
 2 files changed, 190 insertions(+), 55 deletions(-)

diff --git a/src/relax/transform/static_plan_block_memory.cc 
b/src/relax/transform/static_plan_block_memory.cc
index 4ac41d92c0..e8e65c6ecc 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -296,6 +296,82 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
   std::vector<const BindingBlockNode*> block_stack_;
 };
 
+/*!
+ * \brief Set the upper bound of the TIR variables that appear in
+ * the input function signature in the analyzer.
+ * \param func The function to be analyzed.
+ * \param ana The analyzer which contains the TIR var upper bounds.
+ */
+void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
+  // Use the attribute-annotated TIR var upper bounds as the TIR var values for
+  // memory planning.
+  // NOTE: we only apply the annotated upper bounds to the TIR variables that
+  // appear in the **function signature**.
+  Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
+      func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
+          .value_or(Map<ObjectRef, ObjectRef>());
+  std::unordered_map<String, IntImm> var_upper_bound_attr;
+  // We manually check the value type to ensure the values are all positive 
IntImm.
+  for (auto it : var_upper_bound_attr_raw) {
+    const auto* key = it.first.as<StringObj>();
+    const auto* value = it.second.as<IntImmNode>();
+    CHECK(key != nullptr)
+        << "The entry key of attr `tir_var_upper_bound` should be string. 
However "
+        << it.first->GetTypeKey() << " is got.";
+    CHECK(value != nullptr)
+        << "The entry value of attr `tir_var_upper_bound` should be integer. 
However "
+        << it.second->GetTypeKey() << " is got.";
+    CHECK_GT(value->value, 0)
+        << "The entry value of attr `tir_var_upper_bound` should be a positive 
integer, while "
+        << value->value << " is got.";
+    var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
+  }
+  Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
+  for (const tir::Var& tir_var : var_in_signature) {
+    auto it = var_upper_bound_attr.find(tir_var->name_hint);
+    if (it != var_upper_bound_attr.end()) {
+      ana->Bind(tir_var,
+                tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
+                                          tvm::IntImm(DataType::Int(64), 
(*it).second->value + 1)));
+    }
+  }
+}
+
+/*!
+ * \brief Use the upper bounds of TIR vars to compute the upper
+ * bound of a given shape.
+ * \param shape The input shape to be computed.
+ * \param ana The arithmetic analyzer that contains the upper bounds
+ * of TIR variables
+ * \return The upper-bounded shape. When a dimension's upper bound
+ * cannot be determined, we keep the dimension unchanged.
+ */
+Array<PrimExpr> GetUpperBoundShape(Array<PrimExpr> shape, arith::Analyzer* 
ana) {
+  // Use the upper bounds of TIR vars as their values.
+  Array<PrimExpr> upper_bounded_shape;
+  upper_bounded_shape.reserve(shape.size());
+  for (const PrimExpr& dim_len : shape) {
+    int64_t max_bound = ana->const_int_bound(dim_len)->max_value;
+    if (max_bound == std::numeric_limits<int64_t>::max()) {
+      upper_bounded_shape.push_back(dim_len);
+    } else {
+      upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
+    }
+  }
+  return upper_bounded_shape;
+}
+
+/*! \brief Check if a shape is static (a.k.a., has no TIR variable). */
+bool IsStaticShape(Array<PrimExpr> shape) {
+  for (const PrimExpr& dim : shape) {
+    const auto* int_len = dim.as<IntImmNode>();
+    if (!int_len) {
+      return false;
+    }
+  }
+  return true;
+}
+
 /*!
  * \brief The visitor class for storage token initialization.
  * \details It goes through the entire function to get the storage tokens
@@ -330,40 +406,8 @@ class StorageAllocatorInit : public 
StorageAllocatorBaseVisitor {
   explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
 
   void VisitExpr_(const FunctionNode* func) final {
-    // Use the attribute-annotated TIR var upper bounds as the TIR var values 
for
-    // memory planning.
-    // NOTE: we only apply the annotated upper bounds to the TIR variables that
-    // appear in the **function signature**.
-    Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
-        func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
-            .value_or(Map<ObjectRef, ObjectRef>());
-    std::unordered_map<String, IntImm> var_upper_bound_attr;
-    // We manually check the value type to ensure the values are all positive 
IntImm.
-    for (auto it : var_upper_bound_attr_raw) {
-      const auto* key = it.first.as<StringObj>();
-      const auto* value = it.second.as<IntImmNode>();
-      CHECK(key != nullptr)
-          << "The entry key of attr `tir_var_upper_bound` should be string. 
However "
-          << it.first->GetTypeKey() << " is got.";
-      CHECK(value != nullptr)
-          << "The entry value of attr `tir_var_upper_bound` should be integer. 
However "
-          << it.second->GetTypeKey() << " is got.";
-      CHECK_GT(value->value, 0)
-          << "The entry value of attr `tir_var_upper_bound` should be a 
positive integer, while "
-          << value->value << " is got.";
-      var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
-    }
-    Array<tir::Var> var_in_signature = 
TIRVarsInStructInfo(GetStructInfo(GetRef<Function>(func)));
-    var_upper_bound_.clear();
-    for (const tir::Var& tir_var : var_in_signature) {
-      auto it = var_upper_bound_attr.find(tir_var->name_hint);
-      if (it != var_upper_bound_attr.end()) {
-        ana_.Bind(tir_var, tvm::Range::FromMinExtent(
-                               tvm::IntImm(DataType::Int(64), 0),
-                               tvm::IntImm(DataType::Int(64), 
(*it).second->value + 1)));
-      }
-    }
-
+    // Set the upper bound of TIR variables in the analyzer.
+    SetTIRVarUpperBound(GetRef<Function>(func), &ana_);
     // Recurse into the function to get its tokens.
     Tokens body_tokens = GetTokens(func->body);
     // Discard the tokens used by the function return value, as they are 
external referenced.
@@ -457,32 +501,20 @@ class StorageAllocatorInit : public 
StorageAllocatorBaseVisitor {
     // - the tensor has known dtype;
     // - no storage token was created for this call before.
     const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
-    const auto* shape = sinfo->shape.as<ShapeExprNode>();
     ICHECK_NOTNULL(sinfo);
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
     ICHECK_NOTNULL(shape);
     ICHECK(!sinfo->IsUnknownDtype());
     ICHECK(sinfo->dtype == Downcast<DataTypeImm>(call->args[1])->value);
     ICHECK(!token_map_.count(call));
 
     // Use the upper bounds of TIR vars as their values.
-    Array<PrimExpr> upper_bounded_shape;
-    upper_bounded_shape.reserve(shape->values.size());
-    for (const PrimExpr& dim_len : shape->values) {
-      int64_t max_bound = ana_.const_int_bound(dim_len)->max_value;
-      if (max_bound == std::numeric_limits<int64_t>::max()) {
-        upper_bounded_shape.push_back(dim_len);
-      } else {
-        upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), 
max_bound));
-      }
-    }
+    Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, 
&ana_);
 
     // No support for TIR vars that are not bounded.
-    for (const PrimExpr& dim_len : upper_bounded_shape) {
-      const auto* int_len = dim_len.as<IntImmNode>();
-      if (!int_len) {
-        token_map_[call] = Tokens();
-        return Tokens();
-      }
+    if (!IsStaticShape(upper_bounded_shape)) {
+      token_map_[call] = Tokens();
+      return Tokens();
     }
 
     // Create and set token.
@@ -558,8 +590,6 @@ class StorageAllocatorInit : public 
StorageAllocatorBaseVisitor {
    * a PrimFunc inside the IRModule.
    */
   const IRModule& ctx_mod_;
-  /*! \brief The mapping from TIR variables to their respective upper bound 
values. */
-  std::unordered_map<tir::Var, IntImm, ObjectPtrHash, ObjectPtrEqual> 
var_upper_bound_;
   /*! \brief The mapping from each token to the binding block where it is 
created. */
   std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> 
token2block_;
   /*! \brief The mapping from each token to the Exprs that are using this 
token. */
@@ -729,8 +759,17 @@ class StorageAllocationRewriter : public ExprMutator {
       if (func_ == nullptr) {
         continue;
       }
+      constexpr static const char* plan_dyn_attr_ = 
"relax.memory_plan_dynamic_func_output";
+      plan_dynamic_output_ = static_cast<bool>(
+          
func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 
0))->value);
+      if (plan_dynamic_output_) {
+        SetTIRVarUpperBound(GetRef<Function>(func_), &ana_);
+      }
       token2storage_var_.clear();
       Function func = Downcast<Function>(this->VisitExpr_(func_));
+      if (plan_dynamic_output_) {
+        func = WithoutAttr(func, plan_dyn_attr_);
+      }
       builder_->UpdateFunction(gv, func);
     }
     return builder_->GetContextIRModule();
@@ -740,8 +779,13 @@ class StorageAllocationRewriter : public ExprMutator {
   using ExprMutator::VisitExpr_;
 
   Expr VisitExpr_(const CallNode* call) final {
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
+    static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
     auto it = alloc_tensor2token_.find(call);
     if (it != alloc_tensor2token_.end()) {
+      // Case 1. This `alloc_tensor` is planned for memory reuse.
+      ICHECK_EQ(call->op, alloc_tensor_op);
       const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
       ICHECK_NOTNULL(sinfo);
       ICHECK_NOTNULL(sinfo->shape.as<ShapeExprNode>());
@@ -753,7 +797,6 @@ class StorageAllocationRewriter : public ExprMutator {
       Var storage_var{nullptr};
       auto it_token = token2storage_var_.find(token.get());
       if (it_token == token2storage_var_.end()) {
-        static const Op& mem_alloc_storage = 
Op::Get("relax.memory.alloc_storage");
         ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
         PrimValue virtual_device_index = runtime_device_index;
         std::string storage_scope = "global";
@@ -769,16 +812,46 @@ class StorageAllocationRewriter : public ExprMutator {
       }
 
       // And always create a `memory.alloc_tensor` for the old 
`builtin.alloc_tensor`.
-      static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
       PrimValue offset = PrimValue::Int64(0);
       DataType dtype = sinfo->dtype;
       return Call(mem_alloc_tensor, {storage_var, offset, 
sinfo->shape.value(), DataTypeImm(dtype)},
                   Attrs());
+    } else if (plan_dynamic_output_ && call->op == alloc_tensor_op) {
+      // Case 2. For a `alloc_tensor` that is not planned for memory reuse,
+      // we would still like to allocate **static** memory for the tensor.
+      // So in case the tensor shape is dynamic but has an upper bound
+      // estimation, we allocate a storage to its upper bound size, and
+      // allocate a tensor out from it with the actual symbolic shape.
+
+      const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
+      ICHECK_NOTNULL(sinfo);
+      const auto* shape = sinfo->shape.as<ShapeExprNode>();
+      ICHECK_NOTNULL(shape);
+      Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, 
&ana_);
+      if (!IsStaticShape(shape->values) && IsStaticShape(upper_bounded_shape)) 
{
+        ICHECK(!sinfo->IsUnknownDtype());
+        ICHECK_EQ(sinfo->dtype, Downcast<DataTypeImm>(call->args[1])->value);
+        StorageToken token(upper_bounded_shape, sinfo->dtype);
+        Call alloc_storage(mem_alloc_storage,
+                           {/*size=*/ShapeExpr({tvm::IntImm(DataType::Int(64), 
token->bytes)}),
+                            
/*virtual_device_index=*/Downcast<PrimValue>(call->args[2]),
+                            /*storage_scope=*/StringImm("global"),  //
+                            /*dtype=*/DataTypeImm(token->dtype)});
+        Var storage = builder_->Emit(alloc_storage, "storage");
+        return Call(mem_alloc_tensor, {storage,  //
+                                       /*offset=*/PrimValue::Int64(0),
+                                       /*shape=*/GetRef<ShapeExpr>(shape),  //
+                                       /*dtype=*/DataTypeImm(sinfo->dtype)});
+      }
     }
 
     return ExprMutator::VisitExpr_(call);
   }
 
+  /*! \brief The arithmetic analyzer. */
+  arith::Analyzer ana_;
+  /*! \brief A boolean indicating whether to plan dynamic-shape function 
output tensors. */
+  bool plan_dynamic_output_;
   /*!
    * \brief The mapping from each memory-reusable `builtin.alloc_tensor` to
    its corresponding underlying storage token that it is using.
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py 
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 9d1fe4fd40..783f18ee98 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1109,6 +1109,68 @@ def test_call_tir_dyn():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_call_tir_dyn_plan_dynamic_func_output():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def tir_full(var_full: T.handle, n: T.int64):
+            T.evaluate(0)
+
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": 
True, "relax.memory_plan_dynamic_func_output": True})
+            cls = Module
+            alloc: R.Tensor((n,), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+            _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
+            full: R.Tensor((n,), dtype="float32") = alloc
+            alloc1: R.Tensor((n,), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+            _1: R.Tuple = cls.tir_exp(full, alloc1)
+            lv2: R.Tensor((n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((n,), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
+            _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+            lv3: R.Tensor((n,), dtype="float32") = alloc2
+            return lv3
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def tir_full(var_full: T.handle, n: T.int64):
+            T.evaluate(0)
+
+        @R.function
+        def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": 
True})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([80]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((n,), dtype="float32") = 
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), 
R.dtype("float32"))
+            _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
+            full: R.Tensor((n,), dtype="float32") = alloc
+            storage1: R.Object = R.memory.alloc_storage(R.shape([80]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((n,), dtype="float32") = 
R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n]), 
R.dtype("float32"))
+            _1: R.Tuple = cls.tir_exp(full, alloc1)
+            lv2: R.Tensor((n,), dtype="float32") = alloc1
+            storage2: R.Object = R.memory.alloc_storage(R.shape([80]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc2: R.Tensor((n,), dtype="float32") = 
R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n]), 
R.dtype("float32"))
+            _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+            lv3: R.Tensor((n,), dtype="float32") = alloc2
+            return lv3
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_function_independence():
     # fmt: off
     @tvm.script.ir_module

Reply via email to