This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b9aa3564dc [TIR] Revert #11428 and move loop dependent alloc extent check after region union (#12019)
b9aa3564dc is described below
commit b9aa3564dcde74a4d8aa7b70cd6f09ed476cb67c
Author: wrongtest <[email protected]>
AuthorDate: Fri Jul 8 00:29:58 2022 +0800
[TIR] Revert #11428 and move loop dependent alloc extent check after region union (#12019)
---
src/tir/transforms/compact_buffer_region.cc | 58 +++++++++++----------
.../test_tir_transform_compact_buffer_region.py | 60 ++++++++++++++++++++++
2 files changed, 90 insertions(+), 28 deletions(-)
diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
index 46f64d4edf..2844f1b35e 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -45,17 +45,36 @@ using support::NDIntSet;
  * \brief simplify and return the region collected by NDIntSet. return the original
  * buffer shape if the int_set is empty.
  */
-Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
-                                                 const Array<PrimExpr>& original_shape,
-                                                 arith::Analyzer* analyzer) {
+Region SimplifyAndNarrowBufferRegionFromNDIntSet(
+    const NDIntSet& nd_int_set, const Array<PrimExpr>& original_shape, arith::Analyzer* analyzer,
+    const std::vector<const ForNode*>& ancestor_loops) {
   Array<Range> result;
   result.reserve(nd_int_set.size());
   for (size_t i = 0; i < nd_int_set.size(); ++i) {
     const arith::IntSet& int_set = nd_int_set[i];
     Range range = int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i]));
-    result.push_back(
-        Range::FromMinExtent(analyzer->Simplify(max(0, range->min)),
-                             analyzer->Simplify(min(original_shape[i], range->extent))));
+    PrimExpr min = analyzer->Simplify(tvm::max(0, range->min));
+    PrimExpr extent = analyzer->Simplify(tvm::min(original_shape[i], range->extent));
+
+    // Check the buffer region is not loop dependent, since loop dependent
+    // allocation is not supported yet.
+    auto is_loop_var = [&ancestor_loops](const VarNode* v) {
+      return std::any_of(ancestor_loops.begin(), ancestor_loops.end(),
+                         [v](const ForNode* n) { return n->loop_var.get() == v; });
+    };
+    if (UsesVar(extent, is_loop_var)) {
+      // try estimate a constant upperbound on region's extent
+      int64_t upperbound = analyzer->const_int_bound(extent)->max_value;
+      if (upperbound != arith::ConstIntBound::kPosInf) {
+        extent = make_const(extent->dtype, upperbound);
+      } else {
+        // or else we have to fallback to full region
+        min = make_zero(original_shape[i]->dtype);
+        extent = original_shape[i];
+      }
+    }
+
+    result.push_back(Range::FromMinExtent(min, extent));
   }
   return result;
 }
@@ -63,7 +82,6 @@ Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
 /*! \brief a more constrained bound estimate for n-dimentional int set */
 NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
                       const std::unordered_map<const VarNode*, arith::IntSet>& dom_map,
-                      const std::vector<const VarNode*>& ancestor_loop_vars,
                       arith::Analyzer* analyzer) {
   std::unordered_map<Var, Range, ObjectPtrHash, ObjectEqual> var_dom;
   for (const auto& it : dom_map) {
@@ -72,21 +90,7 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
   Optional<Array<arith::IntSet>> eval_res =
       arith::EstimateRegionLowerBound(region, var_dom, predicate, analyzer);
   if (eval_res.defined()) {
-    NDIntSet res(0);
-    for (const auto& it : eval_res.value()) {
-      PrimExpr extent = analyzer->Simplify(it.max() - it.min() + 1);
-      // skip accurate region analysis result if there are outer loop dependencies.
-      if (UsesVar(extent, [&ancestor_loop_vars](const VarNode* v) {
-            return std::find(ancestor_loop_vars.begin(), ancestor_loop_vars.end(), v) !=
-                   ancestor_loop_vars.end();
-          })) {
-        break;
-      }
-      res.push_back(it);
-    }
-    if (res.size() == region.size()) {
-      return res;
-    }
+    return NDIntSet(eval_res.value().begin(), eval_res.value().end());
   }
   return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map);
 }
@@ -247,8 +251,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
       ICHECK(it != relaxed_accesses_.end())
           << buffer << " is allocated but not accessed within block scope";
       const NDIntSet& nd_int_set = it->second;
-      buffer_access_region_[buffer] =
-          SimplifyAndNarrowBufferRegionFromNDIntSet(nd_int_set, buffer->shape, &dom_analyzer_);
+      buffer_access_region_[buffer] = SimplifyAndNarrowBufferRegionFromNDIntSet(
+          nd_int_set, buffer->shape, &dom_analyzer_, ancestor_loops_);
     }
   }
 
@@ -270,7 +274,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
       // Step 1. Stop ancestor loop vars out of the allocation block from
       // being relaxed unless NeedRelaxThread() is true.
       std::vector<arith::IntSet> non_relaxed(n_ancestor_loops);
-      std::vector<const VarNode*> ancestor_loop_vars(n_ancestor_loops);
       for (size_t i = 0; i < n_ancestor_loops; ++i) {
         const ForNode* loop = ancestor_loops_[i];
         const VarNode* v = loop->loop_var.get();
@@ -281,12 +284,11 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
         ICHECK(dom_it != dom_map_.end())
             << "Could not find domain for loop variable " << v->name_hint;
         non_relaxed[i] = dom_it->second;
-        ancestor_loop_vars[i] = v;
         dom_map_.erase(dom_it);
       }
       // Step 2. Relax the access region
-      NDIntSet nd_int_set = NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_,
-                                         ancestor_loop_vars, &dom_analyzer_);
+      NDIntSet nd_int_set =
+          NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_, &dom_analyzer_);
       // Step 3. Restore the non-relaxed ancestor loops domain
       for (size_t i = 0; i < n_ancestor_loops; ++i) {
         const VarNode* v = ancestor_loops_[i]->loop_var.get();
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index af206ef186..5d8b99e7d0 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -849,5 +849,65 @@ def test_compact_spatial_tiled_pad_and_pooling():
     _check(spatial_tiled_pad_and_pooling, compacted_spatial_tiled_pad_and_pooling)
 
 
+def test_complex_case_1():
+    """Meta-schedule matmul case for compact shared A, B matrix"""
+
+    # fmt: off
+    @T.prim_func
+    def func(A: T.Buffer[(960, 770), "float32"], B: T.Buffer[(770, 2304), "float32"], C: T.Buffer[(960, 2304), "float32"]) -> None:
+        for bx in T.thread_binding(144, thread="blockIdx.x"):
+            for vx in T.thread_binding(2, thread="vthread.x"):
+                for tx_p in T.thread_binding(256, thread="threadIdx.x"):
+                    with T.block():
+                        for k_0 in T.serial(193):
+                            with T.block():
+                                A_shared = T.alloc_buffer([960, 770], dtype="float32", scope="shared")
+                                B_shared = T.alloc_buffer([770, 2304], dtype="float32", scope="shared")
+                                for _u in T.serial(1):
+                                    for tx in T.thread_binding(256, thread="threadIdx.x"):
+                                        for vec in T.vectorized(3):
+                                            with T.block("A_shared"):
+                                                T.where(bx // 18 * 128 + ((_u * 256 + tx) * 3 + vec) // 4 < 960 and k_0 * 4 + ((_u * 256 + tx) * 3 + vec) % 4 < 770 and (_u * 256 + tx) * 3 + vec < 512)
+                                                A_shared[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4] = A[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4]
+                                for _u in T.serial(1):
+                                    for tx in T.thread_binding(256, thread="threadIdx.x"):
+                                        for vec in T.vectorized(4):
+                                            with T.block("B_shared"):
+                                                T.where(k_0 * 4 + ((_u * 256 + tx) * 4 + vec) // 128 < 770 and (_u * 256 + tx) * 4 + vec < 512)
+                                                B_shared[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128] = B[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128]
+                                for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2):
+                                    with T.block("update_update"):
+                                        C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] = C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] + A_shared[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, (k_0 + k_1) * 4 + k_2] * B_shared[(k_0 + k_1) * 4 + k_2, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4]
+
+    @T.prim_func
+    def compacted_func(A: T.Buffer[(960, 770), "float32"], B: T.Buffer[(770, 2304), "float32"], C: T.Buffer[(960, 2304), "float32"]) -> None:
+        for bx in T.thread_binding(144, thread="blockIdx.x"):
+            for vx in T.thread_binding(2, thread="vthread.x"):
+                for tx_p in T.thread_binding(256, thread="threadIdx.x"):
+                    with T.block():
+                        for k_0 in T.serial(193):
+                            with T.block():
+                                A_shared = T.alloc_buffer([128, 4], dtype="float32", scope="shared")
+                                B_shared = T.alloc_buffer([4, 128], dtype="float32", scope="shared")
+                                for v_u in T.serial(1):
+                                    for tx in T.thread_binding(256, thread="threadIdx.x"):
+                                        for vec in T.vectorized(3):
+                                            with T.block("A_shared"):
+                                                T.where(bx // 18 * 128 + (tx * 3 + vec) // 4 < 960 and k_0 * 4 + (tx * 3 + vec) % 4 < 770 and tx * 3 + vec < 512)
+                                                A_shared[(tx * 3 + vec) // 4, (tx * 3 + vec) % 4] = A[bx // 18 * 128 + (tx * 3 + vec) // 4, k_0 * 4 + (tx * 3 + vec) % 4]
+                                for v_u in T.serial(1):
+                                    for tx in T.thread_binding(256, thread="threadIdx.x"):
+                                        for vec in T.vectorized(4):
+                                            with T.block("B_shared"):
+                                                T.where(k_0 * 4 + tx // 32 < 770 and tx * 4 + vec < 512)
+                                                B_shared[tx // 32, tx % 32 * 4 + vec] = B[k_0 * 4 + tx // 32, bx % 18 * 128 + tx % 32 * 4 + vec]
+                                for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2):
+                                    with T.block("update_update"):
+                                        C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] = C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] + A_shared[tx_p // 32 * 16 + i_3 * 2 + i_4, k_2] * B_shared[k_2, vx * 64 + tx_p % 32 * 2 + j_4]
+    # fmt: on
+
+    _check(func, compacted_func)
+
+
 if __name__ == "__main__":
     tvm.testing.main()