tqchen commented on a change in pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#discussion_r602342033
##########
File path: src/arith/iter_affine_map.cc
##########
@@ -459,27 +665,107 @@ class IterMapRewriter : public ExprMutator {
}
};
+/*! \brief An internal struct to represent range extent on iterators(iter <
upper_bound). */
+struct IterConstraint {
+ // The expr of the iter
+ PrimExpr iter;
+ // The expr of the upper_bound
+ PrimExpr upper_bound;
+ // The size of the iter, which is the number of nodes
+ size_t expr_size = 0;
+
+ IterConstraint(PrimExpr iter, PrimExpr upper_bound, size_t size)
+ : iter(std::move(iter)), upper_bound(std::move(upper_bound)),
expr_size(size) {}
+};
+
+/*!
+ * \brief Split the predicate into `(a < b) && (c < d) && ...`
+ * \param pred The predicate to be split.
+ * \return A list of pairs, each element of which are lhs and rhs of the '<'
sign,
+ * empty if the split failed.
+ */
+std::vector<IterConstraint> MatchUpperBoundConstraints(PrimExpr pred) {
+ std::vector<IterConstraint> result;
+ arith::PVar<PrimExpr> lhs, rhs, rest;
+ for (;;) {
+ if ((rest && (lhs < rhs)).Match(pred)) {
+ result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
+ pred = rest.Eval();
+ } else if ((lhs < rhs).Match(pred)) {
+ result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
+ break;
+ } else {
+ return std::vector<IterConstraint>();
+ }
+ }
+ return result;
+}
+
+/*! \brief Count the size of the PrimExpr. */
+class PrimExprSizeCounter : public ExprVisitor {
Review comment:
Move this to src/analysis/expr_size.cc / analysis.h and expose it as a function
`size_t ExprSize(const PrimExpr& expr)`; document it as the number of expression
nodes in the expression, including its children.
##########
File path: src/arith/iter_affine_map.cc
##########
@@ -381,31 +534,84 @@ class IterMapRewriter : public ExprMutator {
if (!base_scale) return NullOpt;
// check if it can be remapped into a fused pattern.
PrimExpr expected_scale = base_scale.value();
- for (size_t i = 0; i < expr->args.size(); ++i) {
+ for (size_t i = 0; i < expr->args.size();) {
+ // find j such that expr->args[j] has expected scale
size_t j = i == 0 ? base_index : 0;
for (; j < expr->args.size(); ++j) {
- if (!visited[j] && CanProveEqual(expr->args[j]->scale,
expected_scale)) break;
+ if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale,
expected_scale)) break;
}
- if (j == expr->args.size()) {
- return NullOpt;
+ if (j == expr->args.size()) return NullOpt;
+ // look for the longest constrained iter started from expr->args[j]
+ // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
+ // predicate: j*2 + k < 9
+ // We need to match the predicate in expr and adjust the expected scale,
+ // otherwise we expect the scale of i to be 2*5=10
+ Optional<IterSumExpr> constraint_to_match;
+ for (const IterSumExpr& iter : constrained_iters_flattened_) {
+ if (IterSplitEqual(expr->args[j], iter->args.back(), false)) {
+ // find a predicate started from expr->args[j]
+ if (!constraint_to_match ||
+ constraint_to_match.value()->args.size() < iter->args.size()) {
+ constraint_to_match = iter;
+ }
+ }
}
- visited[j] = true;
- auto arg = expr->args[j];
- arg.CopyOnWrite()->scale = div(expr->args[j]->scale, base_scale.value());
- iters.push_back(arg);
- expected_scale *= expr->args[j]->extent;
- }
- // update the iterator to use the canonicalized form
- expr.CopyOnWrite()->args = Array<IterSplitExpr>(iters.rbegin(),
iters.rend());
- auto it = sum_fuse_map_.find(expr);
- if (it != sum_fuse_map_.end()) return it->second;
- auto mark = IterMark(expr, div(expected_scale, base_scale.value()));
- IterSplitExpr split(mark, base_scale.value());
- sum_fuse_map_[expr] = split;
- return split;
+ if (constraint_to_match) {
+ // match the predicate and mark the iterators in the
constraint_to_match as visited
+ // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
+ // predicate = j*2 + k < 9
+ // then j*2 + k matches the lower two splits of expr
+ for (auto it = constraint_to_match.value()->args.rbegin();
+ it != constraint_to_match.value()->args.rend(); ++it) {
+ size_t k = 0;
+ for (; k < expr->args.size(); ++k) {
+ if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
+ if (analyzer_->CanProveEqual((*it)->scale * expected_scale,
expr->args[k]->scale))
+ break;
+ }
+ }
+ if (k == expr->args.size()) return NullOpt;
+ visited[k] = true;
+ flattened_iters.push_back(expr->args[k]);
+ }
+ auto iter = sum_fuse_map_.find(constraint_to_match.value());
+ ICHECK(iter != sum_fuse_map_.end());
+ IterMark iter_matched = iter->second;
+ grouped_iters.emplace_back(iter_matched, expected_scale);
+ expected_scale *= iter_matched->extent;
+ // move forward
+ i += constraint_to_match.value()->args.size();
+ } else {
+ // constraint_to_match not found, skip this iterator
+ visited[j] = true;
+ flattened_iters.push_back(expr->args[j]);
+ grouped_iters.push_back(expr->args[j]);
+ expected_scale *= expr->args[j]->extent;
+ ++i;
+ }
+ }
+ // Get the flattened form and structured form
+ // both forms have splits from outermost to innermost
+ IterSumExpr structured_form = expr, flattened_form = expr;
+ flattened_form.CopyOnWrite()->args =
+ Array<IterSplitExpr>(flattened_iters.rbegin(), flattened_iters.rend());
+ structured_form.CopyOnWrite()->args =
+ Array<IterSplitExpr>(grouped_iters.rbegin(), grouped_iters.rend());
+ auto it = sum_fuse_map_.find(flattened_form);
+ if (it != sum_fuse_map_.end()) {
+ // old iter
+ return IterSplitExpr(it->second, base_scale.value());
+ } else {
+ // new iter, form a new mark
+ IterMark mark = IterMark(structured_form, div(expected_scale,
base_scale.value()));
+ sum_fuse_map_[flattened_form] = mark;
+ flattened_map_[structured_form] = flattened_form;
+ return IterSplitExpr(mark, base_scale.value());
+ }
}
bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) {
+ if (analyzer_->CanProveEqual(lhs, rhs)) return true;
Review comment:
This is not needed, mainly because the constant-integer comparison is the
fast path, while CanProveEqual goes through a slower path.
##########
File path: tests/python/unittest/test_arith_iter_affine_map.py
##########
@@ -107,32 +106,20 @@ def test_fuse():
res = tvm.arith.detect_iter_map([y * 4 + x], var_dom([(x, 3), (y, 4)]))
assert len(res) == 0
- # simple stride pattern
- res = tvm.arith.detect_iter_map([x * 4 + y * 2], var_dom([(x, 3), (y, 2)]))
- assert len(res) == 1
Review comment:
Please check whether these test cases were intentionally deleted. I see they
were moved; I just want to make sure.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]