This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ab93b31d0d [ARITH][TensorIR] Improve CompactBufferRegion for symbolic shape (#14596)
ab93b31d0d is described below

commit ab93b31d0ddae1c35e285404c8c7fe7c8dc35ae7
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Apr 12 13:41:30 2023 -0400

    [ARITH][TensorIR] Improve CompactBufferRegion for symbolic shape (#14596)
    
    This PR improves CompactBufferRegion for programs with symbolic shapes.
    We introduce a new function CanProveLessEqualThanSymbolicShapeValue,
    which takes into account the extra information that rhs is a positive shape.
    
    This change is necessary to enable symbolic shape schedules.
    Test cases are added to cover this case.
---
 include/tvm/arith/analyzer.h                       | 19 ++++-
 src/arith/analyzer.cc                              | 21 ++++++
 src/arith/iter_affine_map.cc                       | 50 +++++++++++--
 src/tir/transforms/compact_buffer_region.cc        | 14 +++-
 .../python/unittest/test_arith_iter_affine_map.py  | 11 ++-
 .../test_tir_transform_compact_buffer_region.py    | 85 +++++++++++++++++++++-
 6 files changed, 188 insertions(+), 12 deletions(-)

diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index e64426aca3..9cad4ffd69 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -72,7 +72,7 @@ enum class ProofStrength : int {
   /*!
    * \brief Prove using symbolic bound analysis
    */
-  kSymbolicBound = 1
+  kSymbolicBound = 1,
 };
 
 /*!
@@ -668,6 +668,23 @@ class TVM_DLL Analyzer {
    * \note Analyzer will call into sub-analyzers to get the result.
    */
   bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs);
+  /*!
+   * \brief Whether we can prove lhs is less than or equal to a (possibly symbolic) shape.
+   *
+   * By calling this function, the caller gives an extra hint that shape > 0,
+   * because it appears in a buffer shape.
+   *
+   * This is useful to prove conditions such as 32 <= 32 * n, where 32 * n
+   * is known to be a shape. Use this routine to reduce the symbolic comparisons
+   * in buffer compaction.
+   *
+   * The underlying analyzer will use the kSymbolicBound proof.
+   *
+   * \param lhs The input lhs.
+   * \param shape The symbolic shape.
+   * \return Whether we can prove lhs <= shape.
+   */
+  bool CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs, const PrimExpr& shape);
   /*!
    * \brief Whether can we prove condition.
    *
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 89dcb8301a..f744bed4f4 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -25,6 +25,8 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
 
+#include "product_normal_form.h"
+
 namespace tvm {
 namespace arith {
 
@@ -115,6 +117,24 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) {
   return CanProve(lhs - rhs == 0);
 }
 
+bool Analyzer::CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs, const PrimExpr& shape) {
+  if (this->CanProve(lhs <= shape, ProofStrength::kSymbolicBound)) return true;
+  // no need to do further attempt if shape is already a constant.
+  if (tir::is_const_int(shape)) return false;
+  // collect constant scale and ignore symbolic part
+  // so 32 * n => cscale = 32
+  int64_t cscale = 1;
+  auto fcollect = [&](const PrimExpr& expr) {
+    if (auto* ptr = expr.as<IntImmNode>()) {
+      cscale *= ptr->value;
+    }
+  };
+  UnpackReduction<tir::MulNode>(shape, fcollect);
+  PrimExpr const_shape_bound = IntImm(shape.dtype(), std::abs(cscale));
+  if (this->CanProve(lhs <= const_shape_bound, ProofStrength::kSymbolicBound)) return true;
+  return false;
+}
+
 bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
   // Avoid potentially expensive simplification unless required.
   if (const auto* ptr = expr.as<IntImmNode>()) {
@@ -155,6 +175,7 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
       }
     }
   }
+
   return false;
 }
 
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 18fbc75286..ed2c40da72 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -766,6 +766,7 @@ class IterMapRewriter : public ExprMutator {
     // use reverse search as usually smallest is ordered on the right
     int base_index = -1;
     int64_t min_const_scale = 0;
+
     for (int i = rbegin; i >= 0; --i) {
       if (skip_flag[i]) continue;
       if (match_source.defined() && !match_source.same_as(expr->args[i]->source)) continue;
@@ -773,6 +774,12 @@ class IterMapRewriter : public ExprMutator {
         if (base_index == -1 || op->value < min_const_scale) {
           min_const_scale = op->value;
           base_index = static_cast<int>(i);
+        } else if (op->value == min_const_scale) {
+          // On ties, prefer trivial iterators with unit extent over
+          // iterators whose extent is not one.
+          if (is_one(expr->args[i]->extent) && !is_one(expr->args[base_index]->extent)) {
+            base_index = static_cast<int>(i);
+          }
         }
       }
     }
@@ -794,6 +801,22 @@ class IterMapRewriter : public ExprMutator {
     return base_index;
   }
 
+  /*!
+   * \brief Find the first (lowest) index whose iterator has extent equal to 1.
+   *
+   * Unit-extent iterators are rare in simplifications; knowing where the first
+   * one is lets FindIterWithExactScale (which may be called N times) exit early.
+   *
+   * \param expr The input expression.
+   * \return That index, or expr->args.size() if no iterator has unit extent.
+   */
+  int FindFirstPossibleUnitExtentIndex(const IterSumExpr& expr) {
+    for (size_t i = 0; i < expr->args.size(); ++i) {
+      if (is_one(expr->args[i]->extent)) return static_cast<int>(i);
+    }
+    return static_cast<int>(expr->args.size());
+  }
+
   /*!
    * \brief Helper method to find iterator with exact the expected scale.
    * \param expr The expression.
@@ -801,14 +824,16 @@ class IterMapRewriter : public ExprMutator {
    * \param match_source Must match the same source.
    * \param expected_scale The expected_scale.
    * \param rbegin The last index to start reverse searching, -1 means everything.
+   * \param first_possible_unit_extent_pos The lowest index that may hold a split with extent == 1.
    * \return -1 if not no match found, otherwise return the index.
    */
   int FindIterWithExactScale(const IterSumExpr& expr, const std::vector<bool>& skip_flag,
                              const PrimExpr& expected_scale, Optional<IterMark> match_source,
-                             int rbegin = -1) {
+                             int rbegin = -1, int first_possible_unit_extent_pos = 0) {
     if (rbegin == -1) {
       rbegin = static_cast<int>(expr->args.size()) - 1;
     }
+    int matched_pos = -1;
     // use reverse search, as smallest scale usually are near the end.
     for (int j = rbegin; j >= 0; --j) {
       if (skip_flag[j]) continue;
@@ -816,10 +841,18 @@ class IterMapRewriter : public ExprMutator {
       const PrimExpr& cur_scale = expr->args[j]->scale;
       // for bijective mapping, the matched scale must equal to expected scale
       if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
-        return j;
+        if (is_one(expr->args[j]->extent)) return j;
+        // If the extent is not one and an extent=1 split may exist
+        // further out, we need to extend the search:
+        // extent=1 splits get higher priority since they don't change the scale.
+        // Among several non-unit matches, we keep the first one encountered.
+        if (matched_pos == -1) {
+          matched_pos = j;
+        }
+        if (j <= first_possible_unit_extent_pos) return matched_pos;
       }
     }
-    return -1;
+    return matched_pos;
   }
 
   /*!
@@ -867,6 +900,7 @@ class IterMapRewriter : public ExprMutator {
     // most iter map are small n < 5
     // so we can afford N^2 complexity
     bool has_overlap = false;
+
     for (size_t i = 0; i < expr->args.size(); ++i) {
       auto it = hit_count.find(expr->args[i]->source);
       if (it != hit_count.end()) {
@@ -880,6 +914,7 @@ class IterMapRewriter : public ExprMutator {
 
     std::vector<bool> visited(expr->args.size(), false);
     std::vector<IterSplitExpr> reverse_flattened_iters;
+    int first_possible_unit_extent_pos = FindFirstPossibleUnitExtentIndex(expr);
 
     // Start eliminating the iterators
     for (int rend = static_cast<int>(expr->args.size()) - 1; rend >= 0;) {
@@ -925,7 +960,8 @@ class IterMapRewriter : public ExprMutator {
       while (true) {
         // NOTE: mul order [lower_factor, extent, scale]
         PrimExpr lhs_scale = MulAndNormalize(rhs_iter->extent, rhs_iter->scale);
-        matched_index = FindIterWithExactScale(expr, visited, lhs_scale, rhs_iter->source, rend);
+        matched_index = FindIterWithExactScale(expr, visited, lhs_scale, rhs_iter->source, rend,
+                                               first_possible_unit_extent_pos);
         if (matched_index == -1) break;
         IterSplitExpr lhs_iter = expr->args[matched_index];
         ICHECK(rhs_iter->source.same_as(lhs_iter->source));
@@ -974,14 +1010,16 @@ class IterMapRewriter : public ExprMutator {
     PrimExpr expected_extra_base = 0;
     PrimExpr tail_extent = 0;
     PrimExpr expected_scale = base_scale;
+    int first_possible_unit_extent_pos = FindFirstPossibleUnitExtentIndex(expr);
 
     for (size_t i = 0; i < expr->args.size();) {
       PrimExpr matched_scale{nullptr};
       bool is_exact_match{false};
       // find position such that expr->args[j] match expected scale
       // if it is first step, we can simply start with base index
-      int matched_pos =
-          i == 0 ? base_index : FindIterWithExactScale(expr, visited, expected_scale, NullOpt);
+      int matched_pos = i == 0 ? base_index
+                               : FindIterWithExactScale(expr, visited, expected_scale, NullOpt, -1,
+                                                        first_possible_unit_extent_pos);
       if (matched_pos != -1) {
         matched_scale = expected_scale;
         is_exact_match = true;
diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
index 3cfb1a4740..a047191f16 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -54,7 +54,14 @@ Region SimplifyAndNarrowBufferRegionFromNDIntSet(
     const arith::IntSet& int_set = nd_int_set[i];
     Range range = int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i]));
     PrimExpr min = analyzer->Simplify(tvm::max(0, range->min));
-    PrimExpr extent = analyzer->Simplify(tvm::min(original_shape[i], range->extent));
+    PrimExpr extent = range->extent;
+
+    // Apply stronger symbolic proof to help us remove symbolic min here.
+    if (!analyzer->CanProveLessEqualThanSymbolicShapeValue(range->extent, original_shape[i])) {
+      extent = tvm::min(original_shape[i], range->extent);
+    }
+
+    extent = analyzer->Simplify(extent);
 
     // Check the buffer region is not loop dependent, since loop dependent
     // allocation is not supported yet.
@@ -89,6 +96,7 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
   }
   Optional<Array<arith::IntSet>> eval_res =
       arith::EstimateRegionUpperBound(region, var_dom, predicate, analyzer);
+
   if (eval_res.defined()) {
     return NDIntSet(eval_res.value().begin(), eval_res.value().end());
   }
@@ -203,8 +211,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
   }
 
   void VisitStmt_(const BlockNode* op) final {
-    // Step 0. Check there is no init part.
+    // Step 0. Check there is no init part and block is opaque
     ICHECK(!op->init.defined());
+    ICHECK_EQ(op->iter_vars.size(), 0) << "CompactBufferRegion only works on opaque blocks";
     // Step 1. Record and update current read/write region annotations
     std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
         cur_access_annotations;
@@ -281,6 +290,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
       // Step 2. Relax the access region
       NDIntSet nd_int_set =
           NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_, &dom_analyzer_);
+
       // Step 3. Restore the non-relaxed ancestor loops domain
       for (size_t i = 0; i < n_ancestor_loops; ++i) {
         const VarNode* v = ancestor_loops_[i]->loop_var.get();
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 45ec5f1e27..cbca1bb325 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from xml import dom
 import tvm
 import tvm.testing
 from tvm.tir import floormod, floordiv
@@ -1185,7 +1184,7 @@ def test_iter_map_simplify_unit_loop_order():
 
     # Even with simplifcation, it should follow the original order
     assert_iter_map_simplfy(
-        {x + y + (z // 4) * 4 + z % 4: x + y + z},
+        {x + y + (z // 4) * 4 + z % 4: z + x + y},
         var_dom([(x, 1), (y, 1), (z, 32)]),
         simplify_trivial_iterators=False,
     )
@@ -1196,6 +1195,14 @@ def test_iter_map_simplify_unit_loop_order():
         simplify_trivial_iterators=False,
     )
 
+    # When iterators have the same scale but one of them comes
+    # with unit extent, we should prioritize the unit-extent one
+    assert_iter_map_simplfy(
+        {x // 128 + y + z: y + x // 128 + z},
+        var_dom([(x, 128), (y, 128), (z, 1)]),
+        simplify_trivial_iterators=False,
+    )
+
 
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index 1a2a47a170..e90539f3ef 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -223,7 +223,7 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None:
         with T.block():
             T.reads(A[i * 8 : i * 8 + 8])
             T.writes(C[i * 8 : i * 8 + 8])
-            B = T.alloc_buffer((T.min(n, 1) * 8,), "float32")
+            B = T.alloc_buffer((8,), "float32")
             for j in range(0, 8):
                 with T.block():
                     T.reads(A[i * 8 + j])
@@ -1009,5 +1009,88 @@ def test_compact_dependent_buffer_indices_of_packed_matmul():
     _check(nonuniform_packed_matmul_write_cache, nonuniform_packed_matmul_write_cache_compacted)
 
 
+def test_compact_symbolic_bound0():
+    """Test symbolic bound that gets compacted to a constant (blocks nested in loops)"""
+
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(x: T.handle, y: T.handle, n: T.int64):
+            X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+            Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+            for i, k_0 in T.grid(T.int64(8), n):
+                with T.block(""):
+                    X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+                    for ax0 in range(T.int64(32)):
+                        with T.block("X_global"):
+                            X_global[i, k_0 * T.int64(32) + ax0] = X[i, k_0 * T.int64(32) + ax0]
+                    for k_1 in range(T.int64(32)):
+                        with T.block("Y"):
+                            Y[i, k_0 * T.int64(32) + k_1] = X_global[i, k_0 * T.int64(32) + k_1]
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def main(x: T.handle, y: T.handle, n: T.int64):
+            X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+            Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+            for i, k_0 in T.grid(T.int64(8), n):
+                with T.block(""):
+                    X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
+                    for ax0 in range(T.int64(32)):
+                        with T.block("X_global"):
+                            X_global[T.int64(0), ax0] = X[i, k_0 * T.int64(32) + ax0]
+                    for k_1 in range(T.int64(32)):
+                        with T.block("Y"):
+                            Y[i, k_0 * T.int64(32) + k_1] = X_global[T.int64(0), k_1]
+
+    mod = Before
+    mod = tvm.tir.transform.CompactBufferAllocation()(mod)
+    after = tvm.tir.transform.Simplify()(mod)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_compact_symbolic_bound1():
+    """Test symbolic bound that gets compacted to a constant (loops nested in blocks)"""
+
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(x: T.handle, y: T.handle, n: T.int64):
+            X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+            Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+            for i, k_0 in T.grid(T.int64(8), n):
+                with T.block(""):
+                    X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+                    with T.block("X_global"):
+                        for x0 in range(T.int64(32)):
+                            X_global[i, k_0 * T.int64(32) + x0] = X[i, k_0 * T.int64(32) + x0]
+                    with T.block("Y"):
+                        for x1 in range(T.int64(32)):
+                            Y[i, k_0 * T.int64(32) + x1] = X_global[i, k_0 * T.int64(32) + x1]
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def main(x: T.handle, y: T.handle, n: T.int64):
+            X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+            Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+            # with T.block("root"):
+            for i, k_0 in T.grid(T.int64(8), n):
+                with T.block(""):
+                    X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
+                    with T.block("X_global"):
+                        for x0 in range(T.int64(32)):
+                            X_global[T.int64(0), x0] = X[i, k_0 * T.int64(32) + x0]
+                    with T.block("Y"):
+                        for x1 in range(T.int64(32)):
+                            Y[i, k_0 * T.int64(32) + x1] = X_global[T.int64(0), x1]
+
+    mod = Before
+    mod = tvm.tir.transform.CompactBufferAllocation()(mod)
+    after = tvm.tir.transform.Simplify()(mod)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to