This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c1c6d93b09 [ARITH] NormalizeToIterSum (#15120)
c1c6d93b09 is described below
commit c1c6d93b09e887235f32caf67f9d5cbdb2abb3a7
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Jun 21 12:22:23 2023 -0400
[ARITH] NormalizeToIterSum (#15120)
---
include/tvm/arith/iter_affine_map.h | 19 +++++
python/tvm/arith/__init__.py | 1 +
python/tvm/arith/iter_affine_map.py | 31 +++++++
src/arith/iter_affine_map.cc | 98 ++++++++++++++++++++++
src/arith/product_normal_form.h | 9 +-
.../python/unittest/test_arith_iter_affine_map.py | 68 +++++++++++++++
6 files changed, 218 insertions(+), 8 deletions(-)
diff --git a/include/tvm/arith/iter_affine_map.h
b/include/tvm/arith/iter_affine_map.h
index d89d3126a6..53c5b32dd2 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -420,6 +420,25 @@ Array<Array<IterMark>> SubspaceDivide(const
Array<PrimExpr>& bindings,
*/
PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr);
+/*!
+ * \brief Rewrite index as an IterSumExpr of the form
+ *
+ * ((i0 // b0) % a0) * s0 + ((i1 // b1) % a1) * s1 ... + base
+ *
+ * The splits are sorted (best effort) by descending constant factor of the
+ * scale, then by number of symbolic factors, aiming for s0 > s1 ...
+ *
+ * Note that base may contain expressions that cannot be detected
+ * as the right pattern.
+ *
+ * \param index The input index.
+ * \param input_iters The domain of each input iterator.
+ * \param analyzer The analyzer used while rewriting the index.
+ * \note This function is useful to detect iterator stride patterns.
+ */
+IterSumExpr NormalizeToIterSum(PrimExpr index, const Map<Var, Range>&
input_iters,
+ arith::Analyzer* analyzer);
+
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_ITER_AFFINE_MAP_H_
diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py
index e2f6127e29..30fd86b037 100644
--- a/python/tvm/arith/__init__.py
+++ b/python/tvm/arith/__init__.py
@@ -32,6 +32,7 @@ from .iter_affine_map import (
detect_iter_map,
iter_map_simplify,
normalize_iter_map_to_expr,
+ normalize_to_iter_sum,
subspace_divide,
inverse_affine_iter_map,
)
diff --git a/python/tvm/arith/iter_affine_map.py
b/python/tvm/arith/iter_affine_map.py
index 34487d00f0..f19dd0a1ba 100644
--- a/python/tvm/arith/iter_affine_map.py
+++ b/python/tvm/arith/iter_affine_map.py
@@ -156,6 +156,37 @@ def detect_iter_map(
)
+def normalize_to_iter_sum(index, input_iters):
+ """Normalize index to an iter sum.
+
+ The normalized result ensures that
+ each scale is in the form of (symbol_prod) * cscale.
+ The args are sorted in descending order by cscale, then by len(symbol_prod).
+
+ Parameters
+ ----------
+ index : PrimExpr
+ The input index.
+
+ input_iters : Map[Var, Range]
+ The domain of each input iterator.
+
+ Returns
+ -------
+ iter_sum : IterSumExpr
+ The resulting iter sum.
+
+ Note
+ ----
+ This function does best-effort detection, so undetected
+ parts can go into iter_sum.base.
+
+ This function is useful to decide the stride multiplier and
+ division factor in buffer access patterns.
+ """
+ return _ffi_api.NormalizeToIterSum(index, input_iters)
+
+
def iter_map_simplify(
indices,
input_iters,
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 377f8bb7c9..47782a91aa 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -224,6 +224,16 @@ class IterMapRewriter : public ExprMutator {
predicate_induced_max);
}
+ /*!
+ * \brief Rewrite expr to the normalized iter sum pattern.
+ * \param expr The input expression
+ * \return The rewritten iter sum pattern
+ * \note The result base may contain items that are not detected as iter splits.
+ */
+ IterSumExpr RewriteToNormalizedIterSum(const PrimExpr& expr) {
+ return NormalizeToIterSum(ToIterSumExpr(DirectMutate(expr)));
+ }
+
/*!
* \brief If require bijective mapping, this function checks two conditions:
* - C0: Each iter mark should be fully covered by non-overlapping splits.
@@ -735,6 +745,72 @@ class IterMapRewriter : public ExprMutator {
}
}
+ /*!
+ * \brief Normalize expr to iter sum.
+ *
+ * The normalized result ensures that
+ * each scale is in the form of (symbol_prod) * cscale
+ *
+ * Entries are then sorted in descending order by cscale, then by len(symbol_prod).
+ *
+ * This is a best-effort ordering since some scales can be symbolic.
+ * We first order them by the constant factors, then by the number of symbols
+ * involved in the multiplication; ties keep their original relative order.
+ *
+ * \param expr The input expression.
+ * \return The normalized expression.
+ */
+ IterSumExpr NormalizeToIterSum(IterSumExpr expr) {
+ // Nothing to reorder when there are no splits.
+ if (expr->args.size() < 1) return expr;
+ if (auto opt = TryCombineSplitFromSameSource(expr)) {  // presumably merges splits sharing a source, e.g. 4*(x//4) + x%4 -> x
+ expr = opt.value();
+ if (expr->args.size() < 1) return expr;
+ }
+ struct Item {
+ int64_t cscale;  // product of the integer-constant factors of the scale
+ int64_t symbol_prod_count;  // number of non-constant factors of the scale
+ IterSplitExpr split;  // the split, with scale rewritten as (symbol_prod) * cscale
+ };
+
+ std::vector<Item> items;
+
+ for (IterSplitExpr split : expr->args) {
+ int64_t symbol_prod_count = 0;
+ int64_t cscale = 1;
+ PrimExpr res = tir::make_const(split.dtype(), 1);  // accumulates symbol_prod
+ auto fcollect = [&](PrimExpr val) {
+ if (const auto* intimm = val.as<IntImmNode>()) {
+ cscale *= intimm->value;
+ } else {
+ res = res * val;
+ ++symbol_prod_count;
+ }
+ };
+ UnpackReduction<tir::MulNode>(split->scale, fcollect);
+ if (cscale != 1) {
+ res = res * tir::make_const(res.dtype(), cscale);
+ }
+ split.CopyOnWrite()->scale = res;  // canonical form: (symbol_prod) * cscale
+ items.emplace_back(Item{cscale, symbol_prod_count, split});
+ }
+
+ std::stable_sort(items.begin(), items.end(), [](const Item& lhs, const
Item& rhs) {
+ if (lhs.cscale > rhs.cscale) return true;
+ if (lhs.cscale < rhs.cscale) return false;
+ return lhs.symbol_prod_count > rhs.symbol_prod_count;  // tie-break: more symbolic factors first
+ });
+
+ Array<IterSplitExpr> args;
+ for (const Item& item : items) {
+ args.push_back(item.split);
+ }
+
+ expr.CopyOnWrite()->args = args;
+ expr.CopyOnWrite()->base = NormalizeIterMapToExpr(expr->base);
+ return expr;
+ }
+
/*!
* \brief Create a IterSumExpr from expr.
* \param expr The input expr.
@@ -1426,6 +1502,28 @@ TVM_REGISTER_GLOBAL("arith.DetectIterMap")
simplify_trivial_iterators);
});
+IterSumExpr NormalizeToIterSum(PrimExpr index, const Map<Var, Range>&
input_iters,
+ arith::Analyzer* analyzer) {
+ IterMapResult result;  // only used as the rewriter's error sink; errors are never inspected (best effort)
+ ICHECK(IterRangeSanityCheck(input_iters))
+ << "Invalid iterators. Iterators may not be expressions of each other.";
+
+ // We skip constraint checks, as the only thing that matters here is the
pattern
+ std::vector<IterConstraint> constraints;  // NOTE(review): unused local; no constraints are applied
+ IterMapLevel check_level = IterMapLevel::NoCheck;
+ bool simplify_trivial_iterators = true;
+ IterMapRewriter rewriter(analyzer, input_iters, check_level,
simplify_trivial_iterators,
+ &result->errors);
+
+ return rewriter.RewriteToNormalizedIterSum(index);
+}
+
+TVM_REGISTER_GLOBAL("arith.NormalizeToIterSum")
+ .set_body_typed([](PrimExpr index, const Map<Var, Range>& input_iters) {
+ arith::Analyzer ana;  // a fresh analyzer per FFI call
+ return NormalizeToIterSum(index, input_iters, &ana);
+ });
+
PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) {
auto var = GetRef<Var>(op);
auto it = var_map_.find(var);
diff --git a/src/arith/product_normal_form.h b/src/arith/product_normal_form.h
index 178b372b28..768a3a2b8b 100644
--- a/src/arith/product_normal_form.h
+++ b/src/arith/product_normal_form.h
@@ -54,14 +54,7 @@ inline void UnpackReduction(const PrimExpr& value, FLeaf
fleaf) {
*
* NOTE on multiplication order: when have have shape (s[0], s[1], s[2]),
* we prefer to multiple in order of s[0] * s[1] * s[2]
- *
- * That means when we are looking at the pattern of split iterator:
- *
- * - result = (source // lower_factor) % extent * scale
- *
- * We should take the order of lower_factor, extent, scale.
- * Please do best keeping this order to make future simplifcation easy.
- *
+ *
* \param lhs The lhs iterator
* \param rhs The rhs iterator
* \return the result.
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py
b/tests/python/unittest/test_arith_iter_affine_map.py
index 640d7592ad..594dec73ea 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -1217,5 +1217,73 @@ def test_iter_map_simplify_unit_loop_order():
)
+def assert_normalize_to_iter_sum(index, input_iters, args, base):
+ res = tvm.arith.normalize_to_iter_sum(index, input_iters)
+
+ assert isinstance(res, tvm.arith.IterSumExpr)
+ assert len(res.args) == len(args)
+ for split, item in zip(res.args, args):  # item = (expected split source expr, expected scale)
+ tvm.testing.assert_prim_expr_equal(split.scale, item[1])
+ tvm.testing.assert_prim_expr_equal(
+ tvm.arith.normalize_iter_map_to_expr(split), item[0] * item[1]
+ )
+ tvm.testing.assert_prim_expr_equal(res.base, base)
+
+
+def test_normalize_to_iter_sum():
+ x = tvm.tir.Var("x", "int64")
+ y = tvm.tir.Var("y", "int64")
+ z = tvm.tir.Var("z", "int64")
+ a = tvm.tir.Var("a", "int64")
+ n = tvm.tir.Var("n", "int64")
+
+ assert_normalize_to_iter_sum(
+ z + ((y + x * 4 + 2) * n) + 3,
+ var_dom([(x, 9), (y, 4), (z, 3)]),
+ [(x, n * 4), (y, n), (z, 1)],
+ 2 * n + 3,
+ )
+
+ # max cannot be detected, so it goes into base
+ assert_normalize_to_iter_sum(
+ tvm.tir.max(z, a) + ((y + x * 4 + 2) * n) + 3,
+ var_dom([(x, 9), (y, 4), (z, 3)]),
+ [(x, n * 4), (y, n)],
+ tvm.tir.max(z, a) + 2 * n + 3,
+ )
+
+ # order by symbolic prod
+ assert_normalize_to_iter_sum(
+ z + ((y * 4 * a + x * 4 + 2) * n) + 3,
+ var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
+ [(y, a * n * 4), (x, n * 4), (z, 1)],
+ 2 * n + 3,
+ )
+
+ # order by cscale
+ assert_normalize_to_iter_sum(
+ z + 2 * y * 3 + 4 * x,
+ var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
+ [(y, 6), (x, 4), (z, 1)],
+ 0,
+ )
+
+ # split pattern
+ assert_normalize_to_iter_sum(
+ z + 2 * y * 3 + 4 * (x // 2),
+ var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
+ [(y, 6), (x // 2, 4), (z, 1)],
+ 0,
+ )
+
+ # iter simplify: 4 * (x // 4) + x % 4 merges back into x
+ assert_normalize_to_iter_sum(
+ z * 2 + 2 * y * 3 + 4 * (x // 4) + (x % 4),
+ var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
+ [(y, 6), (z, 2), (x, 1)],
+ 0,
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()