This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new bd003e847b [Unity] Memory planning with TIR var upper bound (#14511)
bd003e847b is described below

commit bd003e847bdfef6655b5e1bbea7218c5f0864f67
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Apr 5 17:49:42 2023 -0400

    [Unity] Memory planning with TIR var upper bound (#14511)
    
    This PR enhances the memory planning pass with TIR var upper bound
    annotation, so that it supports dynamic shape cases when we have
    an estimate for the maximum values of the dynamic vars.
---
 include/tvm/relax/transform.h                      |  13 ++
 python/tvm/relax/analysis/estimate_memory_usage.py |   7 +-
 python/tvm/relax/transform/transform.py            |  14 ++
 src/relax/transform/static_plan_block_memory.cc    |  52 ++++++-
 .../test_transform_static_plan_block_memory.py     | 173 +++++++++++++++++++++
 5 files changed, 254 insertions(+), 5 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index f6acf80beb..7ebb48ef12 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -109,6 +109,19 @@ TVM_DLL Pass RewriteDataflowReshape();
   * The pass will reuse allocated memory to its best effort, in order to
   * reduce the total amount of allocated memory size.
   *
+ * The pass "supports" dynamic shape in the way of TIR variable upper bound
+ * annotation. We can optionally annotate the attribute "tir_var_upper_bound"
+ * to Relax functions. The attribute value is a dict from strings to integers,
+ * denoting the name of TIR variables to the upper bound values of the TIR vars.
+ * Note: The annotated upper bound attribute only applies to TIR vars in the
+ * function signature for clarity.
+ *
+ * For example, we can annotate a Relax function with
+ *   `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`.
+ * It means the maximum value of variable that names "n" in the function
+ * signature will have upper bound 1024. And we will use 1024 as its value
+ * during memory planning.
+ *
   * \return The pass.
   */
  TVM_DLL Pass StaticPlanBlockMemory();
diff --git a/python/tvm/relax/analysis/estimate_memory_usage.py b/python/tvm/relax/analysis/estimate_memory_usage.py
index 55f82740ec..014a8e0d49 100644
--- a/python/tvm/relax/analysis/estimate_memory_usage.py
+++ b/python/tvm/relax/analysis/estimate_memory_usage.py
@@ -153,9 +153,10 @@ def estimate_memory_usage(mod: Union[IRModule, Function]) -> str:
                 "memory allocation(s) with total size "
                 "{0:.4} GB.\n".format(self.planned_alloc_mem / 2**30)
             )
-            est += " * Memory planning reduces constant memory size to " "{0:.1%}.".format(
-                self.planned_alloc_mem / self.total_alloc_tensor_mem
-            )
+            if self.total_alloc_tensor_mem != 0:
+                est += " * Memory planning reduces constant memory size to " "{0:.1%}.".format(
+                    self.planned_alloc_mem / self.total_alloc_tensor_mem
+                )
             return "- Function " + func_name + ":\n" + est
 
     if isinstance(mod, Function):
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 049ac2947f..85e36186c2 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -131,6 +131,20 @@ def StaticPlanBlockMemory() -> tvm.ir.transform.Pass:
     """The static memory planning pass on BindingBlock level.
     The pass will reuse allocated memory to its best effort, in order to
     reduce the total amount of allocated memory size.
+
+    The pass "supports" dynamic shape in the way of TIR variable upper bound
+    annotation. We can optionally annotate the attribute "tir_var_upper_bound"
+    to Relax functions. The attribute value is a dict from strings to integers,
+    denoting the name of TIR variables to the upper bound values of the TIR vars.
+    Note: The annotated upper bound attribute only applies to TIR vars in the
+    function signature for clarity.
+
+    For example, we can annotate a Relax function with
+      `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`.
+    It means the maximum value of variable that names "n" in the function
+    signature will have upper bound 1024. And we will use 1024 as its value
+    during memory planning.
+
     Returns
     -------
     ret : tvm.ir.transform.Pass
diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 952513db4c..84e69f3d47 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -48,11 +48,26 @@
  * - insert kill_storage at the end of each binding block, for all the storage
  * tokens that are allocated inside the binding block, as the memory planning
  * only works on block level.
+ *
+ * The memory planning pass "supports" dynamic shape in the way of TIR variable
+ * upper bound annotation. To be more specific, we can annotate the attribute
+ * "tir_var_upper_bound" to Relax functions. The attribute value is a dict from
+ * strings to integers, denoting the name of TIR variables to the upper bound
+ * values of the TIR vars. **The annotated upper bound attribute only applies
+ * to TIR vars in the function signature for clarity.**
+ *
+ * For example, we can annotate a Relax function with
+ *   `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`.
+ * It means the maximum value of variable that names "n" in the function
+ * signature will have upper bound 1024. And we will use 1024 as its value
+ * during memory planning.
  */
+#include <tvm/arith/analyzer.h>
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/nested_msg.h>
 #include <tvm/relax/transform.h>
+#include <tvm/tir/stmt_functor.h>
 
 #include <map>
 #include <set>
@@ -306,6 +321,23 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
   explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
 
   void VisitExpr_(const FunctionNode* func) final {
+    // Use the attribute-annotated TIR var upper bounds as the TIR var values for
+    // memory planning.
+    // NOTE: we only apply the annotated upper bounds to the TIR variables that
+    // appear in the **function signature**.
+    Map<String, IntImm> var_upper_bound_attr =
+        func->GetAttr<Map<String, IntImm>>("tir_var_upper_bound").value_or(Map<String, IntImm>());
+    Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(GetRef<Function>(func)));
+    var_upper_bound_.clear();
+    for (const tir::Var& tir_var : var_in_signature) {
+      auto it = var_upper_bound_attr.find(tir_var->name_hint);
+      if (it != var_upper_bound_attr.end()) {
+        ana_.Bind(tir_var, tvm::Range::FromMinExtent(
+                               tvm::IntImm(DataType::Int(64), 0),
+                               tvm::IntImm(DataType::Int(64), (*it).second->value + 1)));
+      }
+    }
+
     // Recurse into the function to get its tokens.
     Tokens body_tokens = GetTokens(func->body);
     // Discard the tokens used by the function return value, as they are external referenced.
@@ -401,8 +433,20 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
     ICHECK(sinfo->dtype == Downcast<DataTypeImm>(call->args[1])->value);
     ICHECK(!token_map_.count(call));
 
-    // No support for symbolic shape at this moment.
+    // Use the upper bounds of TIR vars as their values.
+    Array<PrimExpr> upper_bounded_shape;
+    upper_bounded_shape.reserve(shape->values.size());
     for (const PrimExpr& dim_len : shape->values) {
+      int64_t max_bound = ana_.const_int_bound(dim_len)->max_value;
+      if (max_bound == std::numeric_limits<int64_t>::max()) {
+        upper_bounded_shape.push_back(dim_len);
+      } else {
+        upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
+      }
+    }
+
+    // No support for TIR vars that are not bounded.
+    for (const PrimExpr& dim_len : upper_bounded_shape) {
       const auto* int_len = dim_len.as<IntImmNode>();
       if (!int_len) {
         token_map_[call] = Tokens();
@@ -411,7 +455,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
     }
 
     // Create and set token.
-    StorageToken token(shape->values, sinfo->dtype);
+    StorageToken token(upper_bounded_shape, sinfo->dtype);
 
     Tokens tokens(token);
     SetTokens(call, tokens);
@@ -476,11 +520,15 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
     token2block_.erase(token_to_discard.get());
   }
 
+  /*! \brief The arithmetic analyzer. */
+  arith::Analyzer ana_;
   /*!
    * \brief The context IRModule, used for checking if a callee function is
    * a PrimFunc inside the IRModule.
    */
   const IRModule& ctx_mod_;
+  /*! \brief The mapping from TIR variables to their respective upper bound values. */
+  std::unordered_map<tir::Var, IntImm, ObjectPtrHash, ObjectPtrEqual> var_upper_bound_;
   /*! \brief The mapping from each token to the binding block where it is created. */
   std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> token2block_;
   /*! \brief The mapping from each token to the Exprs that are using this token. */
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 521fcc1924..06fdd04daa 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -836,5 +836,178 @@ def test_multiple_functions():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_tir_var_upper_bound():
+    # fmt: off
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            R.func_attr({"tir_var_upper_bound": {"n": 4}})
+            n = T.int64()
+            cls = Module
+            alloc: R.Tensor((2, n), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, (2 * n,))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
+            _3: R.Tuple() = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+            _4: R.Tuple() = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 4}})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((2, n), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]), R.dtype("float32"))
+            _: R.Tuple = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, R.shape([2 * n]))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([40]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _1: R.Tuple = cls.relu(lv1, alloc1)
+            __1: R.Tuple = R.memory.kill_tensor(alloc)
+            _1_1: R.Tuple = R.memory.kill_tensor(lv1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
+            _2_1: R.Tuple = R.memory.kill_tensor(alloc1)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n + 2]), R.dtype("float32"))
+            _3: R.Tuple = cls.pad(lv3, alloc3)
+            _3_1: R.Tuple = R.memory.kill_tensor(alloc2)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0))
+            _4: R.Tuple = cls.log(lv4, alloc4)
+            _4_1: R.Tuple = R.memory.kill_tensor(alloc3)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            _5: R.Tuple = R.memory.kill_storage(storage)
+            _6: R.Tuple = R.memory.kill_storage(storage1)
+            return gv
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_tir_var_decreasing_monotone():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor(("n", "m", "T.max(n - m, 1)"), dtype="float32")) -> R.Tensor(("n", "m", "T.max(n - m, 1)"), dtype="float32"):
+            n = T.int64()
+            m = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"m": 5, "n": 20}})
+            cls = Module
+            alloc: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m, T.max(n - m, 1)]), R.dtype("float32"), R.prim_value(0))
+            _: R.Tuple = cls.tir_exp(x, alloc)
+            y: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = alloc
+            alloc1: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m, T.max(n - m, 1)]), R.dtype("float32"), R.prim_value(0))
+            _1: R.Tuple = cls.tir_exp(y, alloc1)
+            z: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = alloc1
+            alloc2: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m, T.max(n - m, 1)]), R.dtype("float32"), R.prim_value(0))
+            _2: R.Tuple = cls.tir_exp(z, alloc2)
+            r: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = alloc2
+            return r
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor(("n", "m", "T.max(n - m, 1)"), dtype="float32")) -> R.Tensor(("n", "m", "T.max(n - m, 1)"), dtype="float32"):
+            n = T.int64()
+            m = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"m": 5, "n": 20}})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([8000]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n, m, T.max(n - m, 1)]), R.dtype("float32"))
+            _: R.Tuple = cls.tir_exp(x, alloc)
+            y: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = alloc
+            storage1: R.Object = R.memory.alloc_storage(R.shape([8000]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n, m, T.max(n - m, 1)]), R.dtype("float32"))
+            _1: R.Tuple = cls.tir_exp(y, alloc1)
+            __1: R.Tuple = R.memory.kill_tensor(alloc)
+            z: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = alloc1
+            alloc2: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m, T.max(n - m, 1)]), R.dtype("float32"), R.prim_value(0))
+            _2: R.Tuple = cls.tir_exp(z, alloc2)
+            _1_1: R.Tuple = R.memory.kill_tensor(alloc1)
+            r: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = alloc2
+            _2_1: R.Tuple = R.memory.kill_storage(storage)
+            _3: R.Tuple = R.memory.kill_storage(storage1)
+            return r
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to