This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5b2f641b52 [Unity] Support storage reuse for dynamic shapes (#16500)
5b2f641b52 is described below

commit 5b2f641b52f28230d8393c45913fd4085a4274d6
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Feb 2 06:31:57 2024 -0800

    [Unity] Support storage reuse for dynamic shapes (#16500)
    
    Before this PR, dynamic shapes required upper bounds of
    variables to be provided in order to use storage planning.
    This PR relaxes that requirement: for shapes with unknown bounds,
    we look up other tensors with the same symbolic
    shapes. This is helpful for deep learning models,
    where layers with the same configuration are usually
    repeated, so many objects share the same shapes.
    
    This PR changes `StorageToken` to store its size in bytes as a
    `PrimExpr`, which can be either an integer constant or symbolic.
    For symbolic sizes, tokens are placed into a special bucket and are
    reused only when their symbolic size can be proven equal.
---
 src/relax/transform/static_plan_block_memory.cc    | 99 +++++++++++++++-------
 tests/python/relax/test_dataflow_pattern.py        | 35 ++++----
 .../test_transform_static_plan_block_memory.py     | 47 ++++++++--
 3 files changed, 125 insertions(+), 56 deletions(-)

diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index da1e706efe..2d8990d90b 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -99,12 +99,23 @@ class StorageTokenNode : public Object {
   /*! \brief Reference counter. */
   int ref_counter{0};
   /*! \brief Number of bytes that this token requires. */
-  int64_t bytes;
+  PrimExpr bytes;
   /*! \brief The dtype of this token. */
   DataType dtype;
   /*! \brief The storage id, reserved for debug and demo use. */
   int storage_id{-1};
 
+  /*! \brief Get the constant number of bytes that this token requires, or -1 
if the number of bytes
+   * is symbolic */
+  int64_t const_bytes() const {
+    const int64_t* const_val = tir::as_const_int(bytes);
+    if (const_val) {
+      return *const_val;
+    } else {
+      return -1;
+    }
+  }
+
   static constexpr const char* _type_key = "relax.transform.StorageToken";
   TVM_DECLARE_BASE_OBJECT_INFO(StorageTokenNode, Object);
 };
@@ -117,19 +128,22 @@ class StorageToken : public ObjectRef {
  public:
   explicit StorageToken(Array<PrimExpr> shape, DataType dtype) {
     // Compute the tensor size from the shape.
-    int64_t size = 1;
+    int64_t const_coeff = dtype.bytes() * dtype.lanes();
+    PrimExpr size = tir::make_const(DataType::Int(64), 1);
     for (const PrimExpr& dim_len : shape) {
-      const auto* int_len = dim_len.as<IntImmNode>();
-      ICHECK_NOTNULL(int_len);
-      size *= int_len->value;
+      if (const IntImmNode* const_dim_len = dim_len.as<IntImmNode>()) {
+        const_coeff *= const_dim_len->value;
+      } else {
+        size *= dim_len;
+      }
     }
+    size = tir::make_const(DataType::Int(64), const_coeff) * size;
 
     ObjectPtr<StorageTokenNode> n = make_object<StorageTokenNode>();
-    n->bytes = size * dtype.bytes() * dtype.lanes();
+    n->bytes = size;
     n->dtype = dtype;
     data_ = std::move(n);
   }
-
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(StorageToken, ObjectRef, 
StorageTokenNode);
 };
 
@@ -143,6 +157,8 @@ using Tokens = NestedMsg<StorageToken>;
  */
 class TokenAllocator1D {
  public:
+  explicit TokenAllocator1D(arith::Analyzer* analyzer) : analyzer_(analyzer) {}
+
   /*!
    * \brief Request a storage token from the available token pool for a
    * given prototype, or report no appropriate available token in the pool.
@@ -162,8 +178,24 @@ class TokenAllocator1D {
     // Step 1. Get the available pool of the token dtype.
     std::multimap<int64_t, StorageToken>& pool = 
available_pool_[prototype->dtype];
 
+    int64_t size = prototype->const_bytes();
+    if (size == -1) {
+      // Handle the case where the prototype token has dynamic size. Currently 
it requires the
+      // symbolic size to be the same as the prototype token in order to reuse 
the storage.
+      auto [begin, end] = pool.equal_range(size);
+      for (; begin != end; ++begin) {
+        StorageToken available_token = begin->second;
+        if (analyzer_->CanProveEqual(prototype->bytes, 
available_token->bytes)) {
+          ICHECK_EQ(available_token->ref_counter, 0)
+              << "Available tokens are expected to have 0 reference.";
+          available_token->ref_counter = prototype->ref_counter;
+          pool.erase(begin);
+          return available_token;
+        }
+      }
+      return NullOpt;
+    }
     // Step 2. Get the range of memory blocks in [size / match_range_, size * 
match_range_)
-    int64_t size = prototype->bytes;
     auto begin = pool.lower_bound(size / match_range_);
     auto mid = pool.lower_bound(size);
     auto end = pool.upper_bound(size * match_range_);
@@ -172,7 +204,7 @@ class TokenAllocator1D {
       StorageToken available_token = mid->second;
       ICHECK_EQ(available_token->ref_counter, 0)
           << "Available tokens are expected to have 0 reference.";
-      ICHECK_LE(size, available_token->bytes);
+      ICHECK_LE(size, available_token->const_bytes());
       available_token->ref_counter = prototype->ref_counter;
       pool.erase(mid);
       return available_token;
@@ -181,11 +213,13 @@ class TokenAllocator1D {
     if (mid != begin) {
       --mid;
       StorageToken available_token = mid->second;
+      int64_t available_size = available_token->const_bytes();
       ICHECK_EQ(available_token->ref_counter, 0)
           << "Available tokens are expected to have 0 reference.";
-      ICHECK_GE(size, available_token->bytes);
+      ICHECK_GE(available_size, 0);
+      ICHECK_GE(size, available_size);
       // Enlarge the token size.
-      available_token->bytes = size;
+      available_token->bytes = tir::make_const(DataType::Int(64), size);
       available_token->ref_counter = prototype->ref_counter;
       pool.erase(mid);
       return available_token;
@@ -216,7 +250,7 @@ class TokenAllocator1D {
     ICHECK_GE(token->storage_id, 0)
         << "The token to be released is expected to be allocated before";
     ICHECK_EQ(token->ref_counter, 0) << "The token to be released is expected 
to have 0 reference.";
-    available_pool_[token->dtype].insert({token->bytes, token});
+    available_pool_[token->dtype].insert({token->const_bytes(), token});
   }
 
   /*! \brief Clear the allocator. */
@@ -226,6 +260,8 @@ class TokenAllocator1D {
   }
 
  private:
+  /*! \brief The arithmetic analyzer. */
+  arith::Analyzer* analyzer_;
   /*! \brief A constant scale representing the token search range. */
   const int match_range_{16};
   /*! \brief The pool of available storage tokens for each dtype. */
@@ -385,10 +421,12 @@ class StorageAllocatorInit : public 
StorageAllocatorBaseVisitor {
   /*!
    * \brief The entry of the initialization.
    * \param mod The IRModule to be planned
+   * \param analyzer The arithmetic analyzer.
    * \return The mapping from each Expr to the token it uses.
    */
-  static std::unordered_map<const ExprNode*, Tokens> Initialize(const 
IRModule& mod) {
-    StorageAllocatorInit initializer(mod);
+  static std::unordered_map<const ExprNode*, Tokens> Initialize(const 
IRModule& mod,
+                                                                
arith::Analyzer* analyzer) {
+    StorageAllocatorInit initializer(mod, analyzer);
 
     for (auto it : mod->functions) {
       const auto* func = it.second.as<FunctionNode>();
@@ -403,11 +441,12 @@ class StorageAllocatorInit : public 
StorageAllocatorBaseVisitor {
  private:
   using ExprVisitor::VisitExpr_;
 
-  explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
+  explicit StorageAllocatorInit(const IRModule& ctx_mod, arith::Analyzer* 
analyzer)
+      : ctx_mod_(ctx_mod), analyzer_(analyzer) {}
 
   void VisitExpr_(const FunctionNode* func) final {
     // Set the upper bound of TIR variables in the analyzer.
-    SetTIRVarUpperBound(GetRef<Function>(func), &ana_);
+    SetTIRVarUpperBound(GetRef<Function>(func), analyzer_);
     // Recurse into the function to get its tokens.
     Tokens body_tokens = GetTokens(func->body);
     // Discard the tokens used by the function return value, as they are 
external referenced.
@@ -508,14 +547,9 @@ class StorageAllocatorInit : public 
StorageAllocatorBaseVisitor {
     ICHECK(sinfo->dtype == Downcast<DataTypeImm>(call->args[1])->value);
     ICHECK(!token_map_.count(call));
 
-    // Use the upper bounds of TIR vars as their values.
-    Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, 
&ana_);
-
-    // No support for TIR vars that are not bounded.
-    if (!IsStaticShape(upper_bounded_shape)) {
-      token_map_[call] = Tokens();
-      return Tokens();
-    }
+    // Use the upper bounds of TIR vars as their values. The upper bound shape 
can still be dynamic
+    // if the upper bounds of some variables are not provided.
+    Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, 
analyzer_);
 
     // Create and set token.
     StorageToken token(upper_bounded_shape, sinfo->dtype);
@@ -583,13 +617,13 @@ class StorageAllocatorInit : public 
StorageAllocatorBaseVisitor {
     token2block_.erase(token_to_discard.get());
   }
 
-  /*! \brief The arithmetic analyzer. */
-  arith::Analyzer ana_;
   /*!
    * \brief The context IRModule, used for checking if a callee function is
    * a PrimFunc inside the IRModule.
    */
   const IRModule& ctx_mod_;
+  /*! \brief The arithmetic analyzer. */
+  arith::Analyzer* analyzer_;
   /*! \brief The mapping from each token to the binding block where it is 
created. */
   std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> 
token2block_;
   /*! \brief The mapping from each token to the Exprs that are using this 
token. */
@@ -612,7 +646,9 @@ class StorageAllocatorInit : public 
StorageAllocatorBaseVisitor {
  */
 class StorageAllocator : public StorageAllocatorBaseVisitor {
  public:
-  explicit StorageAllocator(std::unordered_map<const ExprNode*, Tokens> 
token_map) {
+  explicit StorageAllocator(std::unordered_map<const ExprNode*, Tokens> 
token_map,
+                            arith::Analyzer* analyzer)
+      : allocator_(analyzer) {
     this->token_map_ = std::move(token_map);
   }
 
@@ -797,7 +833,7 @@ class StorageAllocationRewriter : public ExprMutator {
       Var storage_var{nullptr};
       auto it_token = token2storage_var_.find(token.get());
       if (it_token == token2storage_var_.end()) {
-        ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
+        ShapeExpr size({token->bytes});
         PrimValue virtual_device_index = runtime_device_index;
         std::string storage_scope = "global";
         DataType dtype = token->dtype;
@@ -868,10 +904,13 @@ class StorageAllocationRewriter : public ExprMutator {
 };
 
 IRModule StaticPlanBlockMemory(IRModule mod) {
+  arith::Analyzer ana;
+
   // Step 1. Initialize.
-  std::unordered_map<const ExprNode*, Tokens> token_map = 
StorageAllocatorInit::Initialize(mod);
+  std::unordered_map<const ExprNode*, Tokens> token_map =
+      StorageAllocatorInit::Initialize(mod, &ana);
   // Step 2. Collect the memory allocation info.
-  StorageAllocator allocator(std::move(token_map));
+  StorageAllocator allocator(std::move(token_map), &ana);
   allocator.Allocate(mod);
   // Step 3. Rewrite the function.
   StorageAllocationRewriter rewriter(std::move(mod),  //
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index 39b3c40c26..cf2a0cde84 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1241,29 +1241,27 @@ def test_combine_transposed_matmul_twice():
             lv1: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv, 
axes=None)
             lv2: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(x1, 
lv1, out_dtype="void")
             lv3: R.Tuple(
-                R.Tensor((2, 640, 1280), dtype="float32"),
-                R.Tensor((2, 384, 1280), dtype="float32"),
-                R.Tensor((2, 0, 1280), dtype="float32"),
-            ) = R.split(lv2, indices_or_sections=[640, 1280], axis=1)
-            lv0: R.Tensor((2, 640, 1280), dtype="float32") = lv3[0]
-            lv1_1: R.Tensor((2, 384, 1280), dtype="float32") = lv3[1]
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, 640), dtype="float32"),
+            ) = R.split(lv2, indices_or_sections=[640], axis=-1)
+            lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv3[0]
+            lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3[1]
             lv_1: R.Tensor((1280, 640), dtype="float32") = R.concat((w2, w3), 
axis=0)
             lv1_2: R.Tensor((640, 1280), dtype="float32") = 
R.permute_dims(lv_1, axes=None)
             lv2_1: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(
                 x2, lv1_2, out_dtype="void"
             )
             lv3_1: R.Tuple(
-                R.Tensor((2, 640, 1280), dtype="float32"),
-                R.Tensor((2, 384, 1280), dtype="float32"),
-                R.Tensor((2, 0, 1280), dtype="float32"),
-            ) = R.split(lv2_1, indices_or_sections=[640, 1280], axis=1)
-            lv2_1_1: R.Tensor((2, 640, 1280), dtype="float32") = lv3_1[0]
-            lv3_1_1: R.Tensor((2, 384, 1280), dtype="float32") = lv3_1[1]
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, 640), dtype="float32"),
+            ) = R.split(lv2_1, indices_or_sections=[640], axis=-1)
+            lv2_1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3_1[0]
+            lv3_1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3_1[1]
             out: R.Tuple(
-                R.Tensor((2, 640, 1280), dtype="float32"),
-                R.Tensor((2, 384, 1280), dtype="float32"),
-                R.Tensor((2, 640, 1280), dtype="float32"),
-                R.Tensor((2, 384, 1280), dtype="float32"),
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, 640), dtype="float32"),
             ) = (lv0, lv1_1, lv2_1_1, lv3_1_1)
             R.output(out)
         return out
@@ -1282,9 +1280,9 @@ def test_combine_transposed_matmul_twice():
 
             concat = R.concat([w1, w2], axis=0)
             matmul = R.matmul(inp, R.permute_dims(concat))
-            sections = [w1.struct_info.shape[0], w1.struct_info.shape[0] + 
w2.struct_info.shape[0]]
+            sections = [w1.struct_info.shape[0]]
 
-            chunks = R.split(matmul, sections, 1)
+            chunks = R.split(matmul, sections, -1)
 
             return {
                 matchings[matmul1]: chunks[0],
@@ -1297,6 +1295,7 @@ def test_combine_transposed_matmul_twice():
         # make sure it builds
         mod = tvm.IRModule()
         mod["main"] = rewritten
+        print(mod)
 
         rx.build(mod, target="llvm")
 
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 0398e9cab8..83eff854a4 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -701,9 +701,34 @@ def test_symbolic_shape():
             y: R.Tensor((m, n), dtype="float32") = alloc
             return x
 
-    # The pass does no change.
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def exp(var_A: T.handle, var_B: T.handle):
+            m = T.int64()
+            n = T.int64()
+            A = T.match_buffer(var_A, (m, n), "float32")
+            B = T.match_buffer(var_B, (m, n), "float32")
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", 
"n"), dtype="float32"):
+            m = T.int64()
+            n = T.int64()
+            R.func_attr({"relax.force_pure": True})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([4 * (m * n)]), R.prim_value(0), R.str("global"), 
R.dtype("float32")
+            )
+            alloc: R.Tensor((m, n), dtype="float32") = R.memory.alloc_tensor(
+                storage, R.prim_value(0), R.shape([m, n]), R.dtype("float32")
+            )
+            _: R.Tuple = cls.exp(x, alloc)
+            y: R.Tensor((m, n), dtype="float32") = alloc
+            return x
+
     mod = relax.transform.StaticPlanBlockMemory()(Module)
-    tvm.ir.assert_structural_equal(mod, Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
 
 
 def test_zero_reference():
@@ -1198,7 +1223,10 @@ def test_call_tir_dyn_plan_partially_dynamic():
             alloc2: R.Tensor((n, m), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
             _2: R.Tuple = cls.tir_exp(lv2, alloc2)
             lv3: R.Tensor((n, m), dtype="float32") = alloc2
-            return lv3
+            alloc3: R.Tensor((n, m), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
+            _3: R.Tuple = cls.tir_exp(lv3, alloc3)
+            lv4: R.Tensor((n, m), dtype="float32") = alloc3
+            return lv4
 
     @I.ir_module
     class Expected:
@@ -1216,19 +1244,22 @@ def test_call_tir_dyn_plan_partially_dynamic():
             m = T.int64()
             R.func_attr({"relax.force_pure": True, "tir_var_upper_bound": 
{"n": 20}})
             cls = Expected
-            storage: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            storage: R.Object = R.memory.alloc_storage(R.shape([80 * m]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
             alloc: R.Tensor((n, m), dtype="float32") = 
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n, m]), 
R.dtype("float32"))
             _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n, 
m])))
             full: R.Tensor((n, m), dtype="float32") = alloc
-            storage1: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([80 * m]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
             alloc1: R.Tensor((n, m), dtype="float32") = 
R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n, m]), 
R.dtype("float32"))
             _1: R.Tuple = cls.tir_exp(full, alloc1)
             lv2: R.Tensor((n, m), dtype="float32") = alloc1
-            storage2: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
-            alloc2: R.Tensor((n, m), dtype="float32") = 
R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n, m]), 
R.dtype("float32"))
+            alloc2: R.Tensor((n, m), dtype="float32") = 
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n, m]), 
R.dtype("float32"))
             _2: R.Tuple = cls.tir_exp(lv2, alloc2)
             lv3: R.Tensor((n, m), dtype="float32") = alloc2
-            return lv3
+            storage2: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc3: R.Tensor((n, m), dtype="float32") = 
R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n, m]), 
R.dtype("float32"))
+            _3: R.Tuple = cls.tir_exp(lv3, alloc3)
+            lv4 = alloc3
+            return lv4
     # fmt: on
 
     mod = relax.transform.StaticPlanBlockMemory()(Module)

Reply via email to