This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ab93b31d0d [ARITH][TensorIR] Improve CompactBufferRegion for symbolic
shape (#14596)
ab93b31d0d is described below
commit ab93b31d0ddae1c35e285404c8c7fe7c8dc35ae7
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Apr 12 13:41:30 2023 -0400
[ARITH][TensorIR] Improve CompactBufferRegion for symbolic shape (#14596)
This PR improves compact buffer region for symbolic shape programs.
We introduce a new function CanProveLessEqualThanSymbolicShapeValue
which can take the extra information such that rhs is a non-negative shape.
The change is necessary to enable symbolic shape schedules.
Test cases are added to cover these cases.
---
include/tvm/arith/analyzer.h | 19 ++++-
src/arith/analyzer.cc | 21 ++++++
src/arith/iter_affine_map.cc | 50 +++++++++++--
src/tir/transforms/compact_buffer_region.cc | 14 +++-
.../python/unittest/test_arith_iter_affine_map.py | 11 ++-
.../test_tir_transform_compact_buffer_region.py | 85 +++++++++++++++++++++-
6 files changed, 188 insertions(+), 12 deletions(-)
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index e64426aca3..9cad4ffd69 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -72,7 +72,7 @@ enum class ProofStrength : int {
/*!
* \brief Prove using symbolic bound analysis
*/
- kSymbolicBound = 1
+ kSymbolicBound = 1,
};
/*!
@@ -668,6 +668,23 @@ class TVM_DLL Analyzer {
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs);
+ /*!
+ * \brief Whether we can prove lhs is smaller than a possibly symbolic shape.
+ *
+ * By calling this function, the caller gives an extra hint that shape > 0,
+ * because it appeared in buffer shape.
+ *
+ * This is useful to prove conditions such as 32 <= 32 * n where the 32 * n
+ * is known to be a shape. Use this routine to reduce the symbolic
comparisons
+ * in buffer compaction.
+ *
+ * The underlying analyzer will use the kSymbolicBound proof.
+ *
+ * \param lhs The input lhs.
+ * \param shape The symbolic shape.
+ * \return Whether we can prove lhs <= shape.
+ */
+ bool CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs, const
PrimExpr& shape);
/*!
* \brief Whether can we prove condition.
*
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 89dcb8301a..f744bed4f4 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -25,6 +25,8 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
+#include "product_normal_form.h"
+
namespace tvm {
namespace arith {
@@ -115,6 +117,24 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const
PrimExpr& rhs) {
return CanProve(lhs - rhs == 0);
}
+bool Analyzer::CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs,
const PrimExpr& shape) {
+ if (this->CanProve(lhs <= shape, ProofStrength::kSymbolicBound)) return true;
+ // no need to make further attempts if shape is already a constant.
+ if (tir::is_const_int(shape)) return false;
+ // collect constant scale and ignore symbolic part
+ // so 32 * n => cscale = 32
+ int64_t cscale = 1;
+ auto fcollect = [&](const PrimExpr& expr) {
+ if (auto* ptr = expr.as<IntImmNode>()) {
+ cscale *= ptr->value;
+ }
+ };
+ UnpackReduction<tir::MulNode>(shape, fcollect);
+ PrimExpr const_shape_bound = IntImm(shape.dtype(), std::abs(cscale));
+ if (this->CanProve(lhs <= const_shape_bound, ProofStrength::kSymbolicBound))
return true;
+ return false;
+}
+
bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
// Avoid potentially expensive simplification unless required.
if (const auto* ptr = expr.as<IntImmNode>()) {
@@ -155,6 +175,7 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength
strength) {
}
}
}
+
return false;
}
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 18fbc75286..ed2c40da72 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -766,6 +766,7 @@ class IterMapRewriter : public ExprMutator {
// use reverse search as usually smallest is ordered on the right
int base_index = -1;
int64_t min_const_scale = 0;
+
for (int i = rbegin; i >= 0; --i) {
if (skip_flag[i]) continue;
if (match_source.defined() &&
!match_source.same_as(expr->args[i]->source)) continue;
@@ -773,6 +774,12 @@ class IterMapRewriter : public ExprMutator {
if (base_index == -1 || op->value < min_const_scale) {
min_const_scale = op->value;
base_index = static_cast<int>(i);
+ } else if (op->value == min_const_scale) {
+ // for ties, we want to look into extent-1 trivial iters
+ // prioritize trivial iterators
+ if (is_one(expr->args[i]->extent) &&
!is_one(expr->args[base_index]->extent)) {
+ base_index = static_cast<int>(i);
+ }
}
}
}
@@ -794,6 +801,22 @@ class IterMapRewriter : public ExprMutator {
return base_index;
}
+ /*!
+ * \brief Find the first possible location that has extent equal to 1
+ *
+ * Unit extent can be rare in simplifications and not having them can
+ * help us do early exit in scale matching.
+ *
+ * The result is reused by FindIterWithExactScale, which is called N times.
+ * \param expr The input expression.
+ */
+ int FindFirstPossibleUnitExtentIndex(const IterSumExpr& expr) {
+ for (size_t i = 0; i < expr->args.size(); ++i) {
+ if (is_one(expr->args[i]->extent)) return static_cast<int>(i);
+ }
+ return static_cast<int>(expr->args.size());
+ }
+
/*!
* \brief Helper method to find iterator with exact the expected scale.
* \param expr The expression.
@@ -801,14 +824,16 @@ class IterMapRewriter : public ExprMutator {
* \param match_source Must match the same source.
* \param expected_scale The expected_scale.
* \param rbegin The last index to start reverse searching, -1 means
everything.
+ * \param first_possible_unit_extent_pos The first possible location with
split->extent == 1
+ * \return -1 if no match is found, otherwise return the index.
*/
int FindIterWithExactScale(const IterSumExpr& expr, const std::vector<bool>&
skip_flag,
const PrimExpr& expected_scale,
Optional<IterMark> match_source,
- int rbegin = -1) {
+ int rbegin = -1, int
first_possible_unit_extent_pos = 0) {
if (rbegin == -1) {
rbegin = static_cast<int>(expr->args.size()) - 1;
}
+ int matched_pos = -1;
// use reverse search, as smallest scale usually are near the end.
for (int j = rbegin; j >= 0; --j) {
if (skip_flag[j]) continue;
@@ -816,10 +841,18 @@ class IterMapRewriter : public ExprMutator {
const PrimExpr& cur_scale = expr->args[j]->scale;
// for bijective mapping, the matched scale must equal to expected scale
if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
- return j;
+ if (is_one(expr->args[j]->extent)) return j;
+ // if extent is not one and there is a possible extent=1 split
+ // further out, we need to extend the search
+ // extent=1 gets higher priority since they don't change the scale
+ // if there are multiple of them, we match the first encountered
+ if (matched_pos == -1) {
+ matched_pos = j;
+ }
+ if (j <= first_possible_unit_extent_pos) return matched_pos;
}
}
- return -1;
+ return matched_pos;
}
/*!
@@ -867,6 +900,7 @@ class IterMapRewriter : public ExprMutator {
// most iter map are small n < 5
// so we can afford N^2 complexity
bool has_overlap = false;
+
for (size_t i = 0; i < expr->args.size(); ++i) {
auto it = hit_count.find(expr->args[i]->source);
if (it != hit_count.end()) {
@@ -880,6 +914,7 @@ class IterMapRewriter : public ExprMutator {
std::vector<bool> visited(expr->args.size(), false);
std::vector<IterSplitExpr> reverse_flattened_iters;
+ int first_possible_unit_extent_pos =
FindFirstPossibleUnitExtentIndex(expr);
// Start eliminating the iterators
for (int rend = static_cast<int>(expr->args.size()) - 1; rend >= 0;) {
@@ -925,7 +960,8 @@ class IterMapRewriter : public ExprMutator {
while (true) {
// NOTE: mul order [lower_factor, extent, scale]
PrimExpr lhs_scale = MulAndNormalize(rhs_iter->extent,
rhs_iter->scale);
- matched_index = FindIterWithExactScale(expr, visited, lhs_scale,
rhs_iter->source, rend);
+ matched_index = FindIterWithExactScale(expr, visited, lhs_scale,
rhs_iter->source, rend,
+ first_possible_unit_extent_pos);
if (matched_index == -1) break;
IterSplitExpr lhs_iter = expr->args[matched_index];
ICHECK(rhs_iter->source.same_as(lhs_iter->source));
@@ -974,14 +1010,16 @@ class IterMapRewriter : public ExprMutator {
PrimExpr expected_extra_base = 0;
PrimExpr tail_extent = 0;
PrimExpr expected_scale = base_scale;
+ int first_possible_unit_extent_pos =
FindFirstPossibleUnitExtentIndex(expr);
for (size_t i = 0; i < expr->args.size();) {
PrimExpr matched_scale{nullptr};
bool is_exact_match{false};
// find position such that expr->args[j] match expected scale
// if it is first step, we can simply start with base index
- int matched_pos =
- i == 0 ? base_index : FindIterWithExactScale(expr, visited,
expected_scale, NullOpt);
+ int matched_pos = i == 0 ? base_index
+ : FindIterWithExactScale(expr, visited,
expected_scale, NullOpt, -1,
+
first_possible_unit_extent_pos);
if (matched_pos != -1) {
matched_scale = expected_scale;
is_exact_match = true;
diff --git a/src/tir/transforms/compact_buffer_region.cc
b/src/tir/transforms/compact_buffer_region.cc
index 3cfb1a4740..a047191f16 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -54,7 +54,14 @@ Region SimplifyAndNarrowBufferRegionFromNDIntSet(
const arith::IntSet& int_set = nd_int_set[i];
Range range = int_set.CoverRange(Range(/*begin=*/0,
/*end=*/original_shape[i]));
PrimExpr min = analyzer->Simplify(tvm::max(0, range->min));
- PrimExpr extent = analyzer->Simplify(tvm::min(original_shape[i],
range->extent));
+ PrimExpr extent = range->extent;
+
+ // Apply stronger symbolic proof to help us remove symbolic min here.
+ if (!analyzer->CanProveLessEqualThanSymbolicShapeValue(range->extent,
original_shape[i])) {
+ extent = tvm::min(original_shape[i], range->extent);
+ }
+
+ extent = analyzer->Simplify(extent);
// Check the buffer region is not loop dependent, since loop dependent
// allocation is not supported yet.
@@ -89,6 +96,7 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
}
Optional<Array<arith::IntSet>> eval_res =
arith::EstimateRegionUpperBound(region, var_dom, predicate, analyzer);
+
if (eval_res.defined()) {
return NDIntSet(eval_res.value().begin(), eval_res.value().end());
}
@@ -203,8 +211,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
}
void VisitStmt_(const BlockNode* op) final {
- // Step 0. Check there is no init part.
+ // Step 0. Check there is no init part and block is opaque
ICHECK(!op->init.defined());
+ ICHECK_EQ(op->iter_vars.size(), 0) << "CompactBufferRegion only works on
opaque blocks";
// Step 1. Record and update current read/write region annotations
std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash,
ObjectPtrEqual>
cur_access_annotations;
@@ -281,6 +290,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
// Step 2. Relax the access region
NDIntSet nd_int_set =
NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_,
&dom_analyzer_);
+
// Step 3. Restore the non-relaxed ancestor loops domain
for (size_t i = 0; i < n_ancestor_loops; ++i) {
const VarNode* v = ancestor_loops_[i]->loop_var.get();
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py
b/tests/python/unittest/test_arith_iter_affine_map.py
index 45ec5f1e27..cbca1bb325 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from xml import dom
import tvm
import tvm.testing
from tvm.tir import floormod, floordiv
@@ -1185,7 +1184,7 @@ def test_iter_map_simplify_unit_loop_order():
# Even with simplifcation, it should follow the original order
assert_iter_map_simplfy(
- {x + y + (z // 4) * 4 + z % 4: x + y + z},
+ {x + y + (z // 4) * 4 + z % 4: z + x + y},
var_dom([(x, 1), (y, 1), (z, 32)]),
simplify_trivial_iterators=False,
)
@@ -1196,6 +1195,14 @@ def test_iter_map_simplify_unit_loop_order():
simplify_trivial_iterators=False,
)
+ # When we have iterators that have the same scale but one of them comes
+ # with unit extent, we should prioritize unit extent
+ assert_iter_map_simplfy(
+ {x // 128 + y + z: y + x // 128 + z},
+ var_dom([(x, 128), (y, 128), (z, 1)]),
+ simplify_trivial_iterators=False,
+ )
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index 1a2a47a170..e90539f3ef 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -223,7 +223,7 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n:
T.int32) -> None:
with T.block():
T.reads(A[i * 8 : i * 8 + 8])
T.writes(C[i * 8 : i * 8 + 8])
- B = T.alloc_buffer((T.min(n, 1) * 8,), "float32")
+ B = T.alloc_buffer((8,), "float32")
for j in range(0, 8):
with T.block():
T.reads(A[i * 8 + j])
@@ -1009,5 +1009,88 @@ def
test_compact_dependent_buffer_indices_of_packed_matmul():
_check(nonuniform_packed_matmul_write_cache,
nonuniform_packed_matmul_write_cache_compacted)
+def test_compact_symbolic_bound0():
+ """Test symbolic bound that get compacted to constant"""
+
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(x: T.handle, y: T.handle, n: T.int64):
+ X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+ Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+ for i, k_0 in T.grid(T.int64(8), n):
+ with T.block(""):
+ X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+ for ax0 in range(T.int64(32)):
+ with T.block("X_global"):
+ X_global[i, k_0 * T.int64(32) + ax0] = X[i, k_0 *
T.int64(32) + ax0]
+ for k_1 in range(T.int64(32)):
+ with T.block("Y"):
+ Y[i, k_0 * T.int64(32) + k_1] = X_global[i, k_0 *
T.int64(32) + k_1]
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def main(x: T.handle, y: T.handle, n: T.int64):
+ X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+ Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+ for i, k_0 in T.grid(T.int64(8), n):
+ with T.block(""):
+ X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
+ for ax0 in range(T.int64(32)):
+ with T.block("X_global"):
+ X_global[T.int64(0), ax0] = X[i, k_0 * T.int64(32)
+ ax0]
+ for k_1 in range(T.int64(32)):
+ with T.block("Y"):
+ Y[i, k_0 * T.int64(32) + k_1] =
X_global[T.int64(0), k_1]
+
+ mod = Before
+ mod = tvm.tir.transform.CompactBufferAllocation()(mod)
+ after = tvm.tir.transform.Simplify()(mod)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_compact_symbolic_bound1():
+ """Test symbolic bound that get compacted to constant"""
+
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(x: T.handle, y: T.handle, n: T.int64):
+ X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+ Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+ for i, k_0 in T.grid(T.int64(8), n):
+ with T.block(""):
+ X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+ with T.block("X_global"):
+ for x0 in range(T.int64(32)):
+ X_global[i, k_0 * T.int64(32) + x0] = X[i, k_0 *
T.int64(32) + x0]
+ with T.block("Y"):
+ for x1 in range(T.int64(32)):
+ Y[i, k_0 * T.int64(32) + x1] = X_global[i, k_0 *
T.int64(32) + x1]
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def main(x: T.handle, y: T.handle, n: T.int64):
+ X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+ Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+ # with T.block("root"):
+ for i, k_0 in T.grid(T.int64(8), n):
+ with T.block(""):
+ X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
+ with T.block("X_global"):
+ for x0 in range(T.int64(32)):
+ X_global[T.int64(0), x0] = X[i, k_0 * T.int64(32)
+ x0]
+ with T.block("Y"):
+ for x1 in range(T.int64(32)):
+ Y[i, k_0 * T.int64(32) + x1] =
X_global[T.int64(0), x1]
+
+ mod = Before
+ mod = tvm.tir.transform.CompactBufferAllocation()(mod)
+ after = tvm.tir.transform.Simplify()(mod)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()