This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b9aa3564dc [TIR] Revert #11428 and move loop dependent alloc extent check after region union (#12019)
b9aa3564dc is described below

commit b9aa3564dcde74a4d8aa7b70cd6f09ed476cb67c
Author: wrongtest <[email protected]>
AuthorDate: Fri Jul 8 00:29:58 2022 +0800

    [TIR] Revert #11428 and move loop dependent alloc extent check after region union (#12019)
---
 src/tir/transforms/compact_buffer_region.cc        | 58 +++++++++++----------
 .../test_tir_transform_compact_buffer_region.py    | 60 ++++++++++++++++++++++
 2 files changed, 90 insertions(+), 28 deletions(-)

diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
index 46f64d4edf..2844f1b35e 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -45,17 +45,36 @@ using support::NDIntSet;
  * \brief simplify and return the region collected by NDIntSet. return the 
original
  * buffer shape if the int_set is empty.
  */
-Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
-                                                 const Array<PrimExpr>& 
original_shape,
-                                                 arith::Analyzer* analyzer) {
+Region SimplifyAndNarrowBufferRegionFromNDIntSet(
+    const NDIntSet& nd_int_set, const Array<PrimExpr>& original_shape, 
arith::Analyzer* analyzer,
+    const std::vector<const ForNode*>& ancestor_loops) {
   Array<Range> result;
   result.reserve(nd_int_set.size());
   for (size_t i = 0; i < nd_int_set.size(); ++i) {
     const arith::IntSet& int_set = nd_int_set[i];
     Range range = int_set.CoverRange(Range(/*begin=*/0, 
/*end=*/original_shape[i]));
-    result.push_back(
-        Range::FromMinExtent(analyzer->Simplify(max(0, range->min)),
-                             analyzer->Simplify(min(original_shape[i], 
range->extent))));
+    PrimExpr min = analyzer->Simplify(tvm::max(0, range->min));
+    PrimExpr extent = analyzer->Simplify(tvm::min(original_shape[i], 
range->extent));
+
+    // Check the buffer region is not loop dependent, since loop dependent
+    // allocation is not supported yet.
+    auto is_loop_var = [&ancestor_loops](const VarNode* v) {
+      return std::any_of(ancestor_loops.begin(), ancestor_loops.end(),
+                         [v](const ForNode* n) { return n->loop_var.get() == 
v; });
+    };
+    if (UsesVar(extent, is_loop_var)) {
+      // try estimate a constant upperbound on region's extent
+      int64_t upperbound = analyzer->const_int_bound(extent)->max_value;
+      if (upperbound != arith::ConstIntBound::kPosInf) {
+        extent = make_const(extent->dtype, upperbound);
+      } else {
+        // or else we have to fallback to full region
+        min = make_zero(original_shape[i]->dtype);
+        extent = original_shape[i];
+      }
+    }
+
+    result.push_back(Range::FromMinExtent(min, extent));
   }
   return result;
 }
@@ -63,7 +82,6 @@ Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
 /*! \brief a more constrained bound estimate for n-dimentional int set */
 NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
                       const std::unordered_map<const VarNode*, arith::IntSet>& 
dom_map,
-                      const std::vector<const VarNode*>& ancestor_loop_vars,
                       arith::Analyzer* analyzer) {
   std::unordered_map<Var, Range, ObjectPtrHash, ObjectEqual> var_dom;
   for (const auto& it : dom_map) {
@@ -72,21 +90,7 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
   Optional<Array<arith::IntSet>> eval_res =
       arith::EstimateRegionLowerBound(region, var_dom, predicate, analyzer);
   if (eval_res.defined()) {
-    NDIntSet res(0);
-    for (const auto& it : eval_res.value()) {
-      PrimExpr extent = analyzer->Simplify(it.max() - it.min() + 1);
-      // skip accurate region analysis result if there are outer loop 
dependencies.
-      if (UsesVar(extent, [&ancestor_loop_vars](const VarNode* v) {
-            return std::find(ancestor_loop_vars.begin(), 
ancestor_loop_vars.end(), v) !=
-                   ancestor_loop_vars.end();
-          })) {
-        break;
-      }
-      res.push_back(it);
-    }
-    if (res.size() == region.size()) {
-      return res;
-    }
+    return NDIntSet(eval_res.value().begin(), eval_res.value().end());
   }
   return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map);
 }
@@ -247,8 +251,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
       ICHECK(it != relaxed_accesses_.end())
           << buffer << " is allocated but not accessed within block scope";
       const NDIntSet& nd_int_set = it->second;
-      buffer_access_region_[buffer] =
-          SimplifyAndNarrowBufferRegionFromNDIntSet(nd_int_set, buffer->shape, 
&dom_analyzer_);
+      buffer_access_region_[buffer] = 
SimplifyAndNarrowBufferRegionFromNDIntSet(
+          nd_int_set, buffer->shape, &dom_analyzer_, ancestor_loops_);
     }
   }
 
@@ -270,7 +274,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
       // Step 1. Stop ancestor loop vars out of the allocation block from
       // being relaxed unless NeedRelaxThread() is true.
       std::vector<arith::IntSet> non_relaxed(n_ancestor_loops);
-      std::vector<const VarNode*> ancestor_loop_vars(n_ancestor_loops);
       for (size_t i = 0; i < n_ancestor_loops; ++i) {
         const ForNode* loop = ancestor_loops_[i];
         const VarNode* v = loop->loop_var.get();
@@ -281,12 +284,11 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
         ICHECK(dom_it != dom_map_.end())
             << "Could not find domain for loop variable " << v->name_hint;
         non_relaxed[i] = dom_it->second;
-        ancestor_loop_vars[i] = v;
         dom_map_.erase(dom_it);
       }
       // Step 2. Relax the access region
-      NDIntSet nd_int_set = NDIntSetEval(buffer_region->region, 
predicate_in_scope, dom_map_,
-                                         ancestor_loop_vars, &dom_analyzer_);
+      NDIntSet nd_int_set =
+          NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_, 
&dom_analyzer_);
       // Step 3. Restore the non-relaxed ancestor loops domain
       for (size_t i = 0; i < n_ancestor_loops; ++i) {
         const VarNode* v = ancestor_loops_[i]->loop_var.get();
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index af206ef186..5d8b99e7d0 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -849,5 +849,65 @@ def test_compact_spatial_tiled_pad_and_pooling():
     _check(spatial_tiled_pad_and_pooling, 
compacted_spatial_tiled_pad_and_pooling)
 
 
+def test_complex_case_1():
+    """Meta-schedule matmul case for compact shared A, B matrix"""
+
+    # fmt: off
+    @T.prim_func
+    def func(A: T.Buffer[(960, 770), "float32"], B: T.Buffer[(770, 2304), 
"float32"], C: T.Buffer[(960, 2304), "float32"]) -> None:
+        for bx in T.thread_binding(144, thread="blockIdx.x"):
+            for vx in T.thread_binding(2, thread="vthread.x"):
+                for tx_p in T.thread_binding(256, thread="threadIdx.x"):
+                    with T.block():
+                        for k_0 in T.serial(193):
+                            with T.block():
+                                A_shared = T.alloc_buffer([960, 770], 
dtype="float32", scope="shared")
+                                B_shared = T.alloc_buffer([770, 2304], 
dtype="float32", scope="shared")
+                                for _u in T.serial(1):
+                                    for tx in T.thread_binding(256, 
thread="threadIdx.x"):
+                                        for vec in T.vectorized(3):
+                                            with T.block("A_shared"):
+                                                T.where(bx // 18 * 128 + ((_u 
* 256 + tx) * 3 + vec) // 4 < 960 and k_0 * 4 + ((_u * 256 + tx) * 3 + vec) % 4 
< 770 and (_u * 256 + tx) * 3 + vec < 512)
+                                                A_shared[bx // 18 * 128 + (_u 
* 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4] = A[bx // 
18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) 
% 4]
+                                for _u in T.serial(1):
+                                    for tx in T.thread_binding(256, 
thread="threadIdx.x"):
+                                        for vec in T.vectorized(4):
+                                            with T.block("B_shared"):
+                                                T.where(k_0 * 4 + ((_u * 256 + 
tx) * 4 + vec) // 128 < 770 and (_u * 256 + tx) * 4 + vec < 512)
+                                                B_shared[k_0 * 4 + (_u * 1024 
+ tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128] = 
B[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx 
* 4 + vec) % 128]
+                                for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 
8, 1, 4, 2, 2):
+                                    with T.block("update_update"):
+                                        C[(((bx // 18 + 0) * 8 + tx_p // 32) * 
8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] 
= C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx 
% 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] + A_shared[(((bx // 18 + 0) * 8 + tx_p 
// 32) * 8 + i_3) * 2 + i_4, (k_0 + k_1) * 4 + k_2] * B_shared[(k_0 + k_1) * 4 
+ k_2, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4]
+    
+    @T.prim_func
+    def compacted_func(A: T.Buffer[(960, 770), "float32"], B: T.Buffer[(770, 
2304), "float32"], C: T.Buffer[(960, 2304), "float32"]) -> None:
+        for bx in T.thread_binding(144, thread="blockIdx.x"):
+            for vx in T.thread_binding(2, thread="vthread.x"):
+                for tx_p in T.thread_binding(256, thread="threadIdx.x"):
+                    with T.block():
+                        for k_0 in T.serial(193):
+                            with T.block():
+                                A_shared = T.alloc_buffer([128, 4], 
dtype="float32", scope="shared")
+                                B_shared = T.alloc_buffer([4, 128], 
dtype="float32", scope="shared")
+                                for v_u in T.serial(1):
+                                    for tx in T.thread_binding(256, 
thread="threadIdx.x"):
+                                        for vec in T.vectorized(3):
+                                            with T.block("A_shared"):
+                                                T.where(bx // 18 * 128 + (tx * 
3 + vec) // 4 < 960 and k_0 * 4 + (tx * 3 + vec) % 4 < 770 and tx * 3 + vec < 
512)
+                                                A_shared[(tx * 3 + vec) // 4, 
(tx * 3 + vec) % 4] = A[bx // 18 * 128 + (tx * 3 + vec) // 4, k_0 * 4 + (tx * 3 
+ vec) % 4]
+                                for v_u in T.serial(1):
+                                    for tx in T.thread_binding(256, 
thread="threadIdx.x"):
+                                        for vec in T.vectorized(4):
+                                            with T.block("B_shared"):
+                                                T.where(k_0 * 4 + tx // 32 < 
770 and tx * 4 + vec < 512)
+                                                B_shared[tx // 32, tx % 32 * 4 
+ vec] = B[k_0 * 4 + tx // 32, bx % 18 * 128 + tx % 32 * 4 + vec]
+                                for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 
8, 1, 4, 2, 2):
+                                    with T.block("update_update"):
+                                        C[bx // 18 * 128 + tx_p // 32 * 16 + 
i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] = C[bx // 18 * 
128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 
+ j_4] + A_shared[tx_p // 32 * 16 + i_3 * 2 + i_4, k_2] * B_shared[k_2, vx * 64 
+ tx_p % 32 * 2 + j_4]
+    # fmt: on
+
+    _check(func, compacted_func)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to