spectrometerHBH commented on a change in pull request #9699:
URL: https://github.com/apache/tvm/pull/9699#discussion_r772807552
##########
File path: tests/python/unittest/test_arith_intset.py
##########
@@ -213,6 +214,56 @@ def test_region_lower_bound_multiple_variables():
assert k_int_set.max_value.value == 31
+def test_region_lower_bound_for_non_perfect_tile():
+ h1 = tvm.tir.Var("h1", "int32")
+ h2 = tvm.tir.Var("h2", "int32")
+ h3 = tvm.tir.Var("h3", "int32")
+ # h1, h2 are bounded, h3 is free
+ var_dom = {
+ h2: tvm.ir.Range(begin=0, end=2),
+ h1: tvm.ir.Range(begin=0, end=5),
+ }
+ analyzer = tvm.arith.Analyzer()
+
+ def do_test_point_access(point, predicates, expect):
+ regions = tvm.arith.estimate_region_lower_bound(
+ region=[
+ tvm.ir.Range.from_min_extent(min_value=point, extent=1),
+ ],
+ var_dom=var_dom,
+ predicate=tvm.tir.all(*predicates),
+ )
+ if expect is None: # expect a failure
+ assert regions is None
+ else:
+ assert len(regions) == 1
+ assert structural_equal(
+ analyzer.simplify(expect[0], 3),
analyzer.simplify(regions[0].min_value, 3)
+ )
+ assert structural_equal(
+ analyzer.simplify(expect[1], 3),
analyzer.simplify(regions[0].max_value, 3)
+ )
+
+ # normal case of a non-uniform tiling
+ # h3 == 0: region is [1, 9]
+ # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 9]
+ # h3 > 26: region is [h3 * 8, 223]
+ do_test_point_access(
Review comment:
It would be great to add three more test cases, one for each of the h3 cases above :)
##########
File path: src/arith/iter_affine_map.cc
##########
@@ -347,11 +372,24 @@ class IterMapRewriter : public ExprMutator {
// IterSplit(k,
scale=1)),
// extent=9)
// scale=1))
- std::unordered_map<IterSumExpr, IterMark, IterSumHash, IterSumEqual>
sum_fuse_map_;
+ // Example(2): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
+ // predicate: 1 <= j*2 + k < 9
+ // Then, flattened form = IterSum(IterSplit(i, scale=9),
+ // IterSplit(j, scale=2),
+ // IterSplit(k, scale=1))
+ // normal form = IterSum(IterSplit(i, scale=9),
Review comment:
I don't think it's a good idea to rewrite `i*9 + j*2 + k` into a
normal form like `(i*9 + (j*2 + k - 1) + 1)`.
The reason is that `j*2 + k - 1` has extent 8, so it is odd for the scale of
i to be 9.
In fact, `i*9 + j*2 + k` with the predicate `1 <= j*2 + k < 9` does not
correspond to an iter at all, because its values are not contiguous: they are
`[1, 8] \cup [10, 17] \cup [19, 26] \cup [28, 35]`.
So it seems to me that we should not allow lower-bound constraints on
intermediate iters; we can only bound the final iter. In the example above,
bounding `i*9 + j*2 + k >= 1` is legal, while `j*2 + k >= 1` is not acceptable.
##########
File path: tests/python/unittest/test_arith_iter_affine_map.py
##########
@@ -199,8 +199,58 @@ def test_predicate():
x = tvm.tir.Var("x", "int32"), 13
y = tvm.tir.Var("y", "int32"), 10
+ # available contraints
Review comment:
It would be great to add a complex test case that includes combinations
of split/fuse and predicates (min <= iter < max). Some interesting
boundary test cases are:
1. Limiting the lower bound of an iter that is to be split, e.g. floordiv(i, 18)
where i >= 2.
2. Fusing several iters, each with its own lower-bound constraint.
##########
File path: src/arith/iter_affine_map.cc
##########
@@ -407,62 +445,105 @@ class IterMapRewriter : public ExprMutator {
}
if (j == splits.size()) {
// we do not allow incomplete split if the bindings should be bijective
- if (require_bijective) return Array<IterSplitExpr>();
+ if (require_bijective) {
+ diag_ctx_.Emit(
+ Diagnostic::Error(mark->source->span)
+ << "Do not allow incomplete split in bijective checking,
expected_lower_factor="
+ << expected_lower_factor);
+ return Array<IterSplitExpr>();
+ }
// look for the next split skipping this lower factor
// For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2]
// It is valid to only have [y / 6, y % 2] if bijective is not required
// We can skip (y / 2) % 6
j = SearchSkipLowerFactor(splits, used, expected_lower_factor);
// split not found
- if (j == splits.size()) return Array<IterSplitExpr>();
+ if (j == splits.size()) {
+ diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
+ << "Fail to find split skipping the lower factor in
bijective-free "
+ "checking, expected_lower_factor="
+ << expected_lower_factor);
+ return Array<IterSplitExpr>();
+ }
}
used[j] = true;
iters.push_back(splits[j]);
expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
}
+
// Case 1. bijective is required.
// We check the extent we calculate is consistent with the extent
of the mark
// Case 2. bijective is not required.
// We check the extent we calculate is a factor of the extent of
the mark
// For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y
\in [0, 25) is not.
if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor,
mark->extent)) ||
(!require_bijective && !CanProveDivisible(mark->extent,
expected_lower_factor))) {
+ diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
+ << "Mark extent of " << mark
+ << " is not compatible with expected_lower_factor=" <<
expected_lower_factor);
return Array<IterSplitExpr>();
}
return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
}
/*!
- * \brief Normalize the left hand side of iter constraint(expr <
predicate_induced_extent)
- * \param expr The left hand side of iter constraint.
- * \param predicate_induced_extent Extent from iter constraint.
+ * \brief Normalize the iter expression with constraint (min <= expr < max)
+ * \param expr The iter expression.
+ * \param predicate_induced_min Closed lower bound from iter constraint,
maybe undefined.
+ * \param predicate_induced_max Open upper bound from iter constraint, maybe
undefined.
* \return The Normalized expression.
*/
- IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr,
- const PrimExpr&
predicate_induced_extent) {
- // We are normalizing the left hand side of iter constraint(iter <
predicate_induced_extent)
- Optional<IterSplitExpr> opt = TryFuseIters(expr);
+ IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr
predicate_induced_min,
+ PrimExpr predicate_induced_max) {
+ // remove base temporarily since `TryFuseIters` require zero base iter sum
+ PrimExpr base = expr->base;
+ if (!is_zero(base)) {
+ expr.CopyOnWrite()->base = 0;
+ if (predicate_induced_min.defined()) predicate_induced_min =
predicate_induced_min - base;
Review comment:
It looks like you need the type `Optional<PrimExpr>` for `predicate_induced_min`
and `predicate_induced_max`, since both may be undefined?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]