Lunderberg commented on code in PR #11287:
URL: https://github.com/apache/tvm/pull/11287#discussion_r879804939
##########
include/tvm/arith/iter_affine_map.h:
##########
@@ -259,53 +259,29 @@ class IterSumExpr : public IterMapExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode);
};
+/*! \brief Mapping level for iterators. */
+enum IterMapLevel {
+ // Require the mapping to be bijective.
+ Bijective = 0,
+ // Require the mapping to be subjective.
Review Comment:
Nit: subjective -> surjective
##########
src/arith/iter_affine_map.cc:
##########
@@ -1260,140 +1300,132 @@ IterSumExpr
IterMapRewriter::PreprocessDividend(IterMapExpr dividend) {
}
}
+PrimExpr NearLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b,
Analyzer* analyzer) {
+ auto fsplit = [](const PrimExpr& e) -> std::pair<PrimExpr, int64_t> {
+ if (const IntImmNode* imm = e.as<IntImmNode>()) {
+ return {1, imm->value};
+ }
+ PVar<PrimExpr> pv;
+ PVar<IntImm> pc;
+ if ((pv * pc).Match(e) || (pc * pv).Match(e)) {
+ return {pv.Eval(), pc.Eval()->value};
+ } else {
+ return {e, 1};
+ }
+ };
+
+ auto p1 = fsplit(a);
+ auto p2 = fsplit(b);
+ auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second));
+ if (analyzer->CanProveEqual(p1.first, p2.first)) {
+ return p1.first * const_lcm;
+ } else {
+ return (p1.first * p2.first) * const_lcm;
+ }
+}
+
std::pair<IterSplitExpr, PrimExpr>
IterMapRewriter::PadDividendToDivisor(IterSplitExpr split,
PrimExpr base,
PrimExpr divisor) {
// If FloorDiv: (((source//lower_factor) % extent) + base) // divisor
// If FloorMod: (((source//lower_factor) % extent) + base) % divisor
- PrimExpr lookup_key = split;
-
- auto modified_divisor = [&]() {
- if (update_iterator_padding_) {
- return divisor;
- }
-
- auto it = padded_iter_map_.find(lookup_key);
- if (it == padded_iter_map_.end()) {
- return divisor;
- }
-
- const std::vector<PrimExpr>& divisors = it->second.divisors;
- PrimExpr largest_divisor = divisor;
- for (const auto& other : divisors) {
- if (CanProveDivisible(other, largest_divisor)) {
- // New one is bigger, use it
- largest_divisor = other;
- } else if (CanProveDivisible(largest_divisor, other)) {
- // Current is bigger, keep it
- } else {
- ErrorLogger(this) << "Iterator appears in multiple terms with
incompatible divisors "
- << tvm::PrettyPrint(largest_divisor) << " and "
- << tvm::PrettyPrint(other);
- }
- }
- return largest_divisor;
- }();
-
- divisor = modified_divisor;
-
+ // Update current iteration split's padding.
// First, adding any padding that is on the lower side of a
// FloorDiv/FloorMod, such that floormod(iter-left_pad,divisor) == 0
// when iter==0.
-
- PrimExpr left_pad;
-
- if (is_zero(base)) {
- // Padding on the left is unnecessary if base is known to be zero.
- left_pad = make_zero(base->dtype);
- } else {
- left_pad = analyzer_->Simplify(floormod(base, divisor));
- }
+ PrimExpr left_pad = analyzer_->Simplify(floormod(base, divisor));
// Next, adding any padding that is on the upper side of a
// FloorDiv/FloorMod, such that floormod(left_pad + iter + right_pad,
divisor) == 0
// when iter==extent.
-
PrimExpr right_edge = left_pad + split->extent;
PrimExpr right_pad;
-
if (CanProveDivisible(right_edge, divisor)) {
- // Padding on the right is unnecessary if the extent is a multiple of
- // the divisor.
right_pad = 0;
} else {
- right_pad = analyzer_->Simplify(floormod(-right_edge, divisor));
- }
-
- if (is_zero(left_pad) && is_zero(right_pad)) {
- return {split, left_pad};
+ right_pad = analyzer_->Simplify(floormod(-right_edge, divisor), 9);
}
if (update_iterator_padding_) {
+ IterMark mark = split->source;
+ auto& info = padded_iter_map_[mark];
+ info.padding_factor =
+ NearLeastCommonMultiple(info.padding_factor, divisor *
split->lower_factor, analyzer_);
+
+ if (is_zero(left_pad) && is_zero(right_pad)) {
+ return {split, 0};
+ }
+
// In the first pass, the primary goal is to collect all the divisors
// that may be used for padding. These will impact the divisor used
// to determine padding in the second pass.
- IterPaddingInfo& info = padded_iter_map_[lookup_key];
-
- info.divisors.push_back(divisor);
+ PrimExpr padded_extent = analyzer_->Simplify(left_pad + split->extent +
right_pad);
- PrimExpr padded_extent = left_pad + split->extent + right_pad;
-
- IterSumExpr as_sum({split}, left_pad);
- IterMark mark(as_sum, padded_extent);
- IterSplitExpr new_split(mark);
-
- return {new_split, left_pad};
+ PrimExpr mark_left_pad = left_pad * split->lower_factor;
+ if (!is_zero(left_pad)) {
+ if (info.left_pad.defined()) {
+ info.left_pad = max(info.left_pad, mark_left_pad);
+ } else {
+ info.left_pad = mark_left_pad;
+ }
+ }
+ split.CopyOnWrite()->extent = padded_extent;
+ return {split, left_pad};
}
- // Any padding that is required during parsing should have been found
- // during the first pass that determines the GCD.
- auto it = padded_iter_map_.find(lookup_key);
+ // In the second pass, update iteration mark's to padded
+ const IterMark& mark = split->source;
+ auto it = padded_iter_map_.find(mark);
if (it == padded_iter_map_.end()) {
- ErrorLogger(this) << "Dividend has extent " <<
tvm::PrettyPrint(split->extent) << " and offset "
- << tvm::PrettyPrint(base) << ", which requires padding
for divisor "
- << tvm::PrettyPrint(divisor) << ".";
- return {IterSplitExpr(), left_pad};
+ return {split, left_pad};
}
- IterPaddingInfo& info = it->second;
-
- if (info.padded.defined()) {
- // A previous visit already applied padding to this iterator.
- // (e.g. Visiting `(i+1)//4`, then visiting `(i+1)%4`).
- ICHECK(analyzer_->CanProveEqual(info.left_pad, left_pad));
- ICHECK(analyzer_->CanProveEqual(info.right_pad, right_pad));
-
- return {info.padded, left_pad};
+ auto& info = it->second;
+ if (is_zero(info.left_pad.defined() ? info.left_pad : 0) &&
+ CanProveDivisible(mark->extent, info.padding_factor)) {
+ return {split, left_pad};
}
- // This is the first encounter with the iterator during the second pass.
- IterSumExpr as_sum({split}, left_pad);
- IterMark mark(as_sum, left_pad + split->extent + right_pad);
- info.padded = IterSplitExpr(mark);
- info.left_pad = left_pad;
- info.right_pad = right_pad;
-
- auto left_padding_introduced = (left_pad != 0);
- // Equivalent to (0 <= split < left_pad), but easier to simplify in
- // terms of the transformed variables.
- auto left_padding_predicate =
- left_padding_introduced && (floordiv(info.padded, divisor) ==
floordiv(base, divisor) &&
- floormod(info.padded, divisor) < left_pad);
-
- PrimExpr nparts = ceildiv(right_edge, divisor);
-
- auto right_padding_introduced = (right_pad != 0);
-
- // Equivalent to (right_edge <= split < right_edge+right_pad), but
- // easier to simplify in terms of the transformed variables.
- auto right_padding_predicate = right_padding_introduced &&
- (floordiv(info.padded, divisor) ==
floordiv(right_edge, divisor) &&
- floormod(info.padded, divisor) >=
floormod(right_edge, divisor));
-
- requires_padding_ = requires_padding_ || (left_padding_introduced ||
right_padding_introduced);
- padding_predicate_ = padding_predicate_ || (left_padding_predicate ||
right_padding_predicate);
-
- return {info.padded, left_pad};
+ if (!info.padded.defined()) {
+ PrimExpr mark_left_pad = info.left_pad.defined() ? info.left_pad : 0;
+ PrimExpr mark_right_pad;
+ if (CanProveDivisible(mark->extent + mark_left_pad, info.padding_factor)) {
+ mark_right_pad = 0;
+ } else {
+ mark_right_pad = floormod(-(mark->extent + mark_left_pad),
info.padding_factor);
+ }
+ PrimExpr padded_extent = analyzer_->Simplify(mark_left_pad + mark->extent
+ mark_right_pad);
+ info.right_pad = mark_right_pad;
+ info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad),
padded_extent);
+
+ auto left_padding_introduced = (mark_left_pad != 0);
+ PrimExpr divisor = info.padding_factor;
+ PrimExpr right_edge = mark_left_pad + mark->extent;
+
+ // Equivalent to (0 <= split < left_pad), but easier to simplify in
+ // terms of the transformed variables.
+ auto left_padding_predicate =
+ left_padding_introduced && (floordiv(info.padded->source, divisor) ==
0 &&
+ floormod(info.padded->source, divisor) <
mark_left_pad);
+
+ auto right_padding_introduced = (mark_right_pad != 0);
+
+ // Equivalent to (right_edge <= split < right_edge+right_pad), but
+ // easier to simplify in terms of the transformed variables.
+ auto right_padding_predicate =
+ right_padding_introduced &&
+ (floordiv(info.padded->source, divisor) == floordiv(right_edge,
divisor) &&
+ floormod(info.padded->source, divisor) >= floormod(right_edge,
divisor));
+
+ requires_padding_ = requires_padding_ || (left_padding_introduced ||
right_padding_introduced);
+ padding_predicate_ = padding_predicate_ || (left_padding_predicate ||
right_padding_predicate);
+ }
+ // ICHECK(CanProveDivisible(info.padded->extent, split->lower_factor));
Review Comment:
Should these `// ICHECK` lines be either uncommented or removed?
##########
src/arith/iter_affine_map.cc:
##########
@@ -1062,58 +1099,59 @@ PaddedIterMapResult DetectPaddedIterMap(const
Array<PrimExpr>& indices,
[](const IterConstraint& a, const IterConstraint& b) { return
a.expr_size < b.expr_size; });
IterMapRewriter rewriter(analyzer, constrained_input_iters,
simplify_trivial_iterators,
- &result.errors);
+ &result->errors);
// Step0.0: rewrite constraints in the order from size-small ones to
size-big ones
for (const IterConstraint& constraint : constraints) {
auto res = rewriter.RewriteIterConstraint(constraint.iter,
constraint.lower_bound,
constraint.upper_bound);
- if (result.errors.size()) {
- return result;
+ if (result->errors.size()) {
+ return result_obj;
}
}
if (!rewriter.CheckConstraints()) {
- result.errors.push_back("Invalid constraints.");
- return result;
+ result->errors.push_back("Invalid constraints.");
+ return result_obj;
}
// Step0.1: Check each index to determine required padding
- bool allow_padding = !require_bijective;
+ bool allow_padding = check_level != IterMapLevel::Bijective;
Review Comment:
This would enable padding for `IterMapLevel::Surjective`, which I don't
think is correct. Since padding is any output value for which no input value
exists, any mapping that introduces padding is not surjective.
##########
include/tvm/arith/iter_affine_map.h:
##########
@@ -259,53 +259,29 @@ class IterSumExpr : public IterMapExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode);
};
+/*! \brief Mapping level for iterators. */
+enum IterMapLevel {
Review Comment:
Do we want to allow the case where the map is neither surjective nor
injective? I don't have any existing use cases, but since non-surjective
mappings are useful to express padded layouts and non-injective mappings may be
useful to express circular buffering, I could imagine wanting to disable all of
the validation checks in order to have a padded circular buffer.
##########
python/tvm/arith/iter_affine_map.py:
##########
@@ -117,14 +117,14 @@ def detect_iter_map(
Returns
-------
- results : List[IterSumExpr]
+ results : IterMapResult
The iter map matching result.
- Empty array if no match can be found.
+ The result's .indices is empty array if no match can be found.
"""
return _ffi_api.DetectIterMap(
indices, input_iters, predicate, require_bijective,
simplify_trivial_iterators
- )
+ ).indices
Review Comment:
I think this line disagrees with the return type in the docstring. We
should either remove the `.indices` to return an `IterMapResult`, or revert the
docstring changes to describe returning a `List[IterSumExpr]`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]