This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c1c6d93b09 [ARITH] NormalizeToIterSum (#15120)
c1c6d93b09 is described below

commit c1c6d93b09e887235f32caf67f9d5cbdb2abb3a7
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Jun 21 12:22:23 2023 -0400

    [ARITH] NormalizeToIterSum (#15120)
---
 include/tvm/arith/iter_affine_map.h                | 19 +++++
 python/tvm/arith/__init__.py                       |  1 +
 python/tvm/arith/iter_affine_map.py                | 31 +++++++
 src/arith/iter_affine_map.cc                       | 98 ++++++++++++++++++++++
 src/arith/product_normal_form.h                    |  9 +-
 .../python/unittest/test_arith_iter_affine_map.py  | 68 +++++++++++++++
 6 files changed, 218 insertions(+), 8 deletions(-)

diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h
index d89d3126a6..53c5b32dd2 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -420,6 +420,25 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
  */
 PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr);
 
+/*!
+ * \brief Rewrite index as IterSumExpr
+ *
+ * ((i0 // b0) % a0) * s0 + ((i0 // b1) % a1) * s1 ... + base
+ *
+ * The splits are sorted (best effort) in descending order of the constant
+ * factor of their scale, then by the number of symbolic factors in it.
+ *
+ * Parts of index that cannot be detected as this pattern remain in base.
+ *
+ * \param index The input index
+ * \param input_iters The input iterators.
+ * \param analyzer The input analyzer.
+ * \return The normalized iter sum whose scales are (symbol_prod) * cscale.
+ * \note This function is useful to detect iterator stride patterns.
+ */
+IterSumExpr NormalizeToIterSum(PrimExpr index, const Map<Var, Range>& input_iters,
+                               arith::Analyzer* analyzer);
+
 }  // namespace arith
 }  // namespace tvm
 #endif  // TVM_ARITH_ITER_AFFINE_MAP_H_
diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py
index e2f6127e29..30fd86b037 100644
--- a/python/tvm/arith/__init__.py
+++ b/python/tvm/arith/__init__.py
@@ -32,6 +32,7 @@ from .iter_affine_map import (
     detect_iter_map,
     iter_map_simplify,
     normalize_iter_map_to_expr,
+    normalize_to_iter_sum,
     subspace_divide,
     inverse_affine_iter_map,
 )
diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py
index 34487d00f0..f19dd0a1ba 100644
--- a/python/tvm/arith/iter_affine_map.py
+++ b/python/tvm/arith/iter_affine_map.py
@@ -156,6 +156,37 @@ def detect_iter_map(
     )
 
 
+def normalize_to_iter_sum(index, input_iters):
+    """Normalize index to an iter sum.
+
+    The normalized result ensures that each split's scale is in
+    the form of (symbol_prod) * cscale. Splits are also sorted
+    (best effort) in descending order by cscale, then len(symbol_prod).
+
+    Parameters
+    ----------
+    index : PrimExpr
+        The input index.
+
+    input_iters : Map[Var, Range]
+        The domain of each input iterator.
+
+    Returns
+    -------
+    iter_sum: IterSumExpr
+        The resulting iter sum.
+
+    Note
+    ----
+    This function does best-effort detection, so parts of index that
+    cannot be detected are left in iter_sum.base.
+
+    This function is useful to decide the stride multiplier and
+    division factor in buffer access patterns.
+    """
+    return _ffi_api.NormalizeToIterSum(index, input_iters)
+
+
 def iter_map_simplify(
     indices,
     input_iters,
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 377f8bb7c9..47782a91aa 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -224,6 +224,16 @@ class IterMapRewriter : public ExprMutator {
                                       predicate_induced_max);
   }
 
+  /*!
+   * \brief Rewrite expr to the normalized iter sum pattern.
+   * \param expr The input expression.
+   * \return The rewritten iter sum pattern.
+   * \note The result base may contain parts that are not detected as iter splits.
+   */
+  IterSumExpr RewriteToNormalizedIterSum(const PrimExpr& expr) {
+    return NormalizeToIterSum(ToIterSumExpr(DirectMutate(expr)));
+  }
+
   /*!
    * \brief If require bijective mapping, this function checks two conditions:
    *   - C0: Each iter mark should be fully covered by non-overlapping splits.
@@ -735,6 +745,72 @@ class IterMapRewriter : public ExprMutator {
     }
   }
 
+  /*!
+   * \brief Normalize expr to iter sum.
+   *
+   * The normalized result ensures that
+   * each scale is in the form of (symbol_prod) * cscale
+   *
+   * It will also sort entries in desc order by cscale then len(symbol_prod).
+   *
+   * This is a best effort sorting since some scale can be symbolic.
+   * We first order them by the constant factors, then the number of symbols
+   * involved in a multiply.
+   *
+   * \param expr The input expression.
+   * \return The normalized expression; its base is always rewritten with
+   *         NormalizeIterMapToExpr, including when no split is detected.
+   */
+  IterSumExpr NormalizeToIterSum(IterSumExpr expr) {
+    // Normalize base first so the early returns below never leak internal
+    // IterMapExpr nodes (e.g. from undetected parts such as max(z, a)).
+    expr.CopyOnWrite()->base = NormalizeIterMapToExpr(expr->base);
+    if (expr->args.empty()) return expr;
+    // TryCombineSplitFromSameSource keeps the (already normalized) base.
+    if (auto opt = TryCombineSplitFromSameSource(expr)) expr = opt.value();
+    if (expr->args.empty()) return expr;
+    struct Item {
+      int64_t cscale;
+      int64_t symbol_prod_count;
+      IterSplitExpr split;
+    };
+
+    std::vector<Item> items;
+    items.reserve(expr->args.size());
+    for (IterSplitExpr split : expr->args) {
+      int64_t symbol_prod_count = 0;
+      int64_t cscale = 1;
+      PrimExpr res = tir::make_const(split.dtype(), 1);
+      auto fcollect = [&](PrimExpr val) {
+        if (const auto* intimm = val.as<IntImmNode>()) {
+          cscale *= intimm->value;
+        } else {
+          res = res * val;
+          ++symbol_prod_count;
+        }
+      };
+      UnpackReduction<tir::MulNode>(split->scale, fcollect);
+      if (cscale != 1) {
+        res = res * tir::make_const(res.dtype(), cscale);
+      }
+      split.CopyOnWrite()->scale = res;
+      items.emplace_back(Item{cscale, symbol_prod_count, split});
+    }
+
+    std::stable_sort(items.begin(), items.end(), [](const Item& lhs, const Item& rhs) {
+      if (lhs.cscale != rhs.cscale) return lhs.cscale > rhs.cscale;
+      return lhs.symbol_prod_count > rhs.symbol_prod_count;
+    });
+
+    Array<IterSplitExpr> args;
+    for (const Item& item : items) {
+      args.push_back(item.split);
+    }
+
+    expr.CopyOnWrite()->args = args;
+    return expr;
+  }
+
   /*!
    * \brief Create a IterSumExpr from expr.
    * \param expr The input expr.
@@ -1426,6 +1502,28 @@ TVM_REGISTER_GLOBAL("arith.DetectIterMap")
                            simplify_trivial_iterators);
     });
 
+IterSumExpr NormalizeToIterSum(PrimExpr index, const Map<Var, Range>& input_iters,
+                               arith::Analyzer* analyzer) {
+  IterMapResult result;
+  ICHECK(IterRangeSanityCheck(input_iters))
+      << "Invalid iterators.  Iterators may not be expressions of each other.";
+
+  // We skip the constraint check (NoCheck) because only the detected pattern matters
+  // here, so no IterConstraint list is built; rewrite errors are collected and ignored.
+  IterMapLevel check_level = IterMapLevel::NoCheck;
+  bool simplify_trivial_iterators = true;
+  IterMapRewriter rewriter(analyzer, input_iters, check_level, simplify_trivial_iterators,
+                           &result->errors);
+
+  return rewriter.RewriteToNormalizedIterSum(index);
+}
+
+TVM_REGISTER_GLOBAL("arith.NormalizeToIterSum")
+    .set_body_typed([](PrimExpr index, const Map<Var, Range>& input_iters) {
+      arith::Analyzer ana;
+      return NormalizeToIterSum(index, input_iters, &ana);
+    });
+
 PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) {
   auto var = GetRef<Var>(op);
   auto it = var_map_.find(var);
diff --git a/src/arith/product_normal_form.h b/src/arith/product_normal_form.h
index 178b372b28..768a3a2b8b 100644
--- a/src/arith/product_normal_form.h
+++ b/src/arith/product_normal_form.h
@@ -54,14 +54,7 @@ inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) {
  *
  * NOTE on multiplication order: when have have shape (s[0], s[1], s[2]),
  * we prefer to multiple in order of s[0] * s[1] * s[2]
- *
- * That means when we are looking at the pattern of split iterator:
- *
- * - result = (source // lower_factor) % extent * scale
- *
- * We should take the order of lower_factor, extent, scale.
- * Please do best keeping this order to make future simplifcation easy.
- *
+ *
  * \param lhs The lhs iterator
  * \param rhs The rhs iterator
  * \return the result.
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 640d7592ad..594dec73ea 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -1217,5 +1217,73 @@ def test_iter_map_simplify_unit_loop_order():
     )
 
 
+def assert_normalize_to_iter_sum(index, input_iters, args, base):
+    """Assert the splits match args, a list of (unscaled split expr, scale), and the base."""
+    result = tvm.arith.normalize_to_iter_sum(index, input_iters)
+
+    assert isinstance(result, tvm.arith.IterSumExpr)
+    assert len(result.args) == len(args)
+    for split, (expected_expr, expected_scale) in zip(result.args, args):
+        tvm.testing.assert_prim_expr_equal(split.scale, expected_scale)
+        normalized = tvm.arith.normalize_iter_map_to_expr(split)
+        tvm.testing.assert_prim_expr_equal(normalized, expected_expr * expected_scale)
+    tvm.testing.assert_prim_expr_equal(result.base, base)
+
+
+def test_normalize_to_iter_sum():
+    x = tvm.tir.Var("x", "int64")
+    y = tvm.tir.Var("y", "int64")
+    z = tvm.tir.Var("z", "int64")
+    a = tvm.tir.Var("a", "int64")
+    n = tvm.tir.Var("n", "int64")
+
+    assert_normalize_to_iter_sum(
+        z + ((y + x * 4 + 2) * n) + 3,
+        var_dom([(x, 9), (y, 4), (z, 3)]),
+        [(x, n * 4), (y, n), (z, 1)],
+        2 * n + 3,
+    )
+
+    # max cannot be detected, so it goes into base
+    assert_normalize_to_iter_sum(
+        tvm.tir.max(z, a) + ((y + x * 4 + 2) * n) + 3,
+        var_dom([(x, 9), (y, 4), (z, 3)]),
+        [(x, n * 4), (y, n)],
+        tvm.tir.max(z, a) + 2 * n + 3,
+    )
+
+    # order by symbolic prod count when cscale ties
+    assert_normalize_to_iter_sum(
+        z + ((y * 4 * a + x * 4 + 2) * n) + 3,
+        var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
+        [(y, a * n * 4), (x, n * 4), (z, 1)],
+        2 * n + 3,
+    )
+
+    # order by cscale
+    assert_normalize_to_iter_sum(
+        z + 2 * y * 3 + 4 * x,
+        var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
+        [(y, 6), (x, 4), (z, 1)],
+        0,
+    )
+
+    # split pattern
+    assert_normalize_to_iter_sum(
+        z + 2 * y * 3 + 4 * (x // 2),
+        var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
+        [(y, 6), (x // 2, 4), (z, 1)],
+        0,
+    )
+
+    # iter simplify: 4 * (x // 4) + x % 4 combines back into x
+    assert_normalize_to_iter_sum(
+        z * 2 + 2 * y * 3 + 4 * (x // 4) + (x % 4),
+        var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
+        [(y, 6), (z, 2), (x, 1)],
+        0,
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to