This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f622e7f180 [ARITH][BUGFIX] Fix a bug of iter map floormod(x,2)
simplify (#14571)
f622e7f180 is described below
commit f622e7f180bdb95c9d3121bb1f78fe459a57812a
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Apr 11 16:59:29 2023 -0400
[ARITH][BUGFIX] Fix a bug of iter map floormod(x,2) simplify (#14571)
This PR fixes a bug previously introduced in itermap detection.
Specifically, y - (x % 2) was simplified to y + (x % 2) - 1,
which is wrong. The correct rule is y + ((x + 1) % 2) - 1,
but that rule would change the base iterator, which is not desirable here.
We also removed the rule that simplifies (x + 1) % 2 => 1 - x % 2,
as the benefit is minimal and it introduces extra negative coefficients
that hurt analysis in general (negative coefficients are
harder to handle in many cases).
---
src/arith/iter_affine_map.cc | 5 --
src/arith/rewrite_simplify.cc | 17 +++-
.../unittest/test_arith_canonical_simplify.py | 7 ++
.../python/unittest/test_arith_iter_affine_map.py | 10 +--
.../python/unittest/test_arith_rewrite_simplify.py | 22 ++++--
.../test_tir_transform_inject_software_pipeline.py | 90 +++++++++++-----------
6 files changed, 86 insertions(+), 65 deletions(-)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index e7fc4f2663..05af5b4070 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -898,11 +898,6 @@ class IterMapRewriter : public ExprMutator {
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);
static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
- if (sign < 0 && is_const_int(rhs->extent, 2)) {
- lhs->base -= rhs->scale;
- sign = 1;
- }
-
tir::ExprDeepEqual equal;
for (size_t i = 0; i < lhs->args.size(); ++i) {
IterSplitExpr lvalue = lhs->args[i];
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 528a272ef4..acd74b7031 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -306,6 +306,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
AddNode* op) {
TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) + floormod(x, 2), floordiv(x + 1,
2));
+ // Simplify (x + 1) % 2 + x % 2 => 1
+ // NOTE: we should avoid simplifying (x + 1) % 2 => 1 - x % 2 though,
+ // mainly because introducing extra negative signs to expressions can harm
iterator
+ // analysis, which usually relies on positive iterator coefficients.
+ TVM_TRY_REWRITE_IF(floormod(x + c1, 2) + floormod(x, 2),
OneWithTypeLike(x),
+ floormod(c1.Eval()->value, 2) == 1);
+ TVM_TRY_REWRITE_IF(floormod(x, 2) + floormod(x + c1, 2),
OneWithTypeLike(x),
+ floormod(c1.Eval()->value, 2) == 1);
+
// canonicalization rule
// will try rewrite again after canonicalization.
@@ -1018,10 +1027,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
FloorModNode* op) {
TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2)
+ y, c2),
c2.Eval()->value > 0);
- TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + c1, 2), floormod(x, 2) * (-1) +
1,
- floormod(c1.Eval()->value, 2) == 1);
- TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
- c2.Eval()->value > 0 && c1.Eval()->value %
c2.Eval()->value == 0);
+ // (x + 5) % 2 => (x + 1) % 2, (x + 3) % 3 => x % 3
+ TVM_TRY_REWRITE_IF(
+ floormod(x + c1, c2), floormod(x + floormod(c1, c2), c2),
+ c2.Eval()->value > 0 && (c1.Eval()->value >= c2.Eval()->value ||
c1.Eval()->value < 0));
TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1,
c2), c2),
c2.Eval()->value > 0);
diff --git a/tests/python/unittest/test_arith_canonical_simplify.py
b/tests/python/unittest/test_arith_canonical_simplify.py
index 1a4277d924..c1d7587f43 100644
--- a/tests/python/unittest/test_arith_canonical_simplify.py
+++ b/tests/python/unittest/test_arith_canonical_simplify.py
@@ -415,5 +415,12 @@ def test_proddiv_simplify():
ck.verify(tdiv(x * (2 * y) * 3, 3 * y * z), tdiv(x * 2, z))
+def test_floormod_two():
+ ck = CanonicalChecker()
+ flm = tvm.te.floormod
+ x, y = te.var("x"), te.var("y")
+ ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py
b/tests/python/unittest/test_arith_iter_affine_map.py
index 0bb4c98b7b..5ce7296045 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -199,14 +199,14 @@ def test_compound():
assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)},
var_dom([(x, 10), (y, 9)]))
-def test_compound_floormod_two():
+def test_compound_floormod_two_regression():
x = tvm.tir.Var("x", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod
-
- # extent of 2 are normalized to positive scale
- assert_iter_sum_pattern(
- expect_dict={fld(x, 2) * 2 - flm(x, 2) + 1: (8, 0, 1)},
+ # regression
+ # extent of 2 of negative scale cannot be normalized
+ assert_iter_sum_failure(
+ [fld(x, 2) * 2 - flm(x, 2) + 1],
dom_map=var_dom([(x, 8)]),
)
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py
b/tests/python/unittest/test_arith_rewrite_simplify.py
index 7ecc34c385..46ac0f9751 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -392,8 +392,8 @@ class TestSubIndex(BaseCompare):
TestCase(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3)),
TestCase(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1),
TestCase(fld(y, 3) * 3 - y, 0 - flm(y, 3)),
- TestCase(y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6),
- TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5)),
+ TestCase(y - fld(y - 6, 5) * 5, flm(y + 4, 5) + 6),
+ TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + 4, 5)),
TestCase(y - fld(y + z, 5) * 5, flm(y + z, 5) - z),
TestCase(fld(y + z, 5) * 5 - y, z - flm(y + z, 5)),
TestCase(y - fld(y - z, 5) * 5, flm(y - z, 5) + z),
@@ -554,13 +554,15 @@ class TestFloormodIndex(BaseCompare):
TestCase(flm(x + 10, 2), flm(x, 2)),
TestCase(flm(x + y * 10, 2), flm(x, 2)),
TestCase(flm(x + y * 360, 16), flm(x + y * 8, 16)),
- TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1),
TestCase(flm(x * (-10), 2), 0),
TestCase(flm(x * (-10) + y, 2), flm(y, 2)),
TestCase(flm(x + (-10), 2), flm(x, 2)),
TestCase(flm(x + y * (-10), 2), flm(x, 2)),
TestCase(flm(x * 32 + y, 64), flm(x, 2) * 32 + y, [y >= 0, y < 32]),
TestCase(flm(x * 32 - y, 64), flm(x * 32 - y, 64), [y >= 0, y < 32]),
+ # NOTE: the following case is covered by canonical simplify;
+ # long-range simplification in general can be covered by canonical
simplify
+ # TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1),
)
@@ -574,13 +576,14 @@ class TestFloorModTwo(BaseCompare):
require identifying more related terms in order to apply.
(x + c1)//2 - (x+c2)//2 => (x%2)*( c1%2 - c2%2 ) + (c1//2 - c2//2)
+
+ However, we should not introduce extra negative coefficients to
+ iterators during simplification.
"""
x, y, z = te.var("x"), te.var("y"), te.var("z")
test_case = tvm.testing.parameter(
# Removing offsets from floormod
- TestCase(flm(x + 1, 2), flm(x, 2) * (-1) + 1),
- TestCase(flm(x + 5, 2), flm(x, 2) * (-1) + 1),
TestCase(flm(x, 2) + flm(x + 1, 2), 1),
TestCase(flm(x + 1, 2) + flm(x, 2), 1),
# Difference of floordiv yields floormod
@@ -592,8 +595,13 @@ class TestFloorModTwo(BaseCompare):
# Sum of floordiv and floormod to yield floordiv
TestCase(fld(x + 1, 2) - flm(x, 2), fld(x, 2)),
TestCase(fld(x, 2) + flm(x, 2), fld(x + 1, 2)),
- # Removal of floormod where possible
- TestCase(flm(x + 1, 2) * 8192, x * (-8192) + 8192, [x >= 0, x < 2]),
+ # regression: although we can rewrite (x + 1) % 2 => 1 - x % 2,
+ # doing so would introduce negative coefficients to iterators,
+ # which makes later iter map detection harder. In principle we
+ # should not introduce additional negative iterator signs when
rewriting
+ TestCase(flm(x + 1, 2), flm(x + 1, 2)),
+ TestCase(flm(x + 5, 2), flm(x + 1, 2)),
+ TestCase(flm(x + 1, 2) * 8192, flm(x + 1, 2) * 8192, [x >= 0, x < 2]),
)
diff --git
a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index 7e59172bdd..b9f35ed553 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -139,8 +139,8 @@ def transformed_simple_compute(
for i in T.serial(0, 15):
with T.block():
T.reads([A[tx, i + 1]])
- T.writes([B[1 - i % 2, tx, 0]])
- B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
+ T.writes([B[(i + 1) % 2, tx, 0]])
+ B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
with T.block():
T.reads([B[i % 2, tx, 0]])
T.writes([C[tx, i]])
@@ -202,8 +202,8 @@ def transformed_simple_compute_with_other_annotation(
):
with T.block():
T.reads([A[tx, i + 1]])
- T.writes([B[1 - i % 2, tx, 0]])
- B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
+ T.writes([B[(i + 1) % 2, tx, 0]])
+ B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
with T.block():
T.reads([B[i % 2, tx, 0]])
T.writes([C[tx, i]])
@@ -266,7 +266,7 @@ def transformed_three_stage_compute(
T.where(i == 1)
T.reads(B[0:2, tx, 0])
T.writes(C[0:2, tx, 0])
- C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2)
+ C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] +
T.float32(2)
with T.block():
T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
@@ -278,7 +278,7 @@ def transformed_three_stage_compute(
with T.block():
T.reads(B[0:2, tx, 0])
T.writes(C[0:2, tx, 0])
- C[1 - i % 2, tx, 0] = B[1 - i % 2, tx, 0] +
T.float32(2)
+ C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] +
T.float32(2)
with T.block():
T.reads(C[0:2, tx, 0])
T.writes(D[tx, i])
@@ -291,7 +291,7 @@ def transformed_three_stage_compute(
T.where(i < 1)
T.reads(B[0:2, tx, 0])
T.writes(C[0:2, tx, 0])
- C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2)
+ C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] +
T.float32(2)
with T.block():
T.reads(C[0:2, tx, 0])
T.writes(D[tx, i + 14])
@@ -391,12 +391,12 @@ def transformed_dag_interleaving(
BS[tx, 0] = B[tx, i + 1] + T.float32(2)
with T.block():
T.reads(AS[tx, 0])
- T.writes(AL[1 - i % 2, 0, 0])
- AL[1 - i % 2, 0, 0] = AS[tx, 0]
+ T.writes(AL[(i + 1) % 2, 0, 0])
+ AL[(i + 1) % 2, 0, 0] = AS[tx, 0]
with T.block():
T.reads(BS[tx, 0])
- T.writes(BL[1 - i % 2, 0, 0])
- BL[1 - i % 2, 0, 0] = BS[tx, 0]
+ T.writes(BL[(i + 1) % 2, 0, 0])
+ BL[(i + 1) % 2, 0, 0] = BS[tx, 0]
with T.block():
T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0])
T.writes(C[tx, i])
@@ -475,12 +475,12 @@ def transformed_nested_pipeline_simple(
for i in T.serial(0, 15):
with T.block():
T.reads([A[tx, i + 1, 0:16]])
- T.writes([A_shared[1 - i % 2, tx, 0, 0:16]])
+ T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
for j in T.serial(0, 16):
with T.block():
T.reads([A[tx, i + 1, j]])
- T.writes([A_shared[1 - i % 2, tx, 0, j]])
- A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j]
+ T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
+ A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1,
j]
with T.block():
T.reads([A_shared[i % 2, tx, i, 0]])
T.writes([B[0, tx, i, 0]])
@@ -491,10 +491,10 @@ def transformed_nested_pipeline_simple(
for j in T.serial(0, 15):
with T.block():
T.reads([A_shared[i % 2, tx, i, j + 1]])
- T.writes([B[1 - j % 2, tx, i, 0]])
- B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx,
0, j + 1] * T.float32(
- 2
- )
+ T.writes([B[(j + 1) % 2, tx, i, 0]])
+ B[(j + 1) % 2, tx, i, 0] = A_shared[
+ i % 2, tx, 0, j + 1
+ ] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, i, 0]])
T.writes([C[tx, i, j]])
@@ -516,8 +516,8 @@ def transformed_nested_pipeline_simple(
for j in T.serial(0, 15):
with T.block():
T.reads([A_shared[1, tx, 15, j + 1]])
- T.writes([B[1 - j % 2, tx, 15, 0]])
- B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j +
1] * T.float32(2)
+ T.writes([B[(j + 1) % 2, tx, 15, 0]])
+ B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j +
1] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, 15, 0]])
T.writes([C[tx, 15, j]])
@@ -603,30 +603,30 @@ def transformed_nested_pipeline_prefetch_inner(
for i in T.serial(0, 15):
with T.block():
T.reads([A[tx, i + 1, 0:16]])
- T.writes([A_shared[1 - i % 2, tx, 0, 0:16]])
+ T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
for j in T.serial(0, 16):
with T.block():
T.reads([A[tx, i + 1, j]])
- T.writes([A_shared[1 - i % 2, tx, 0, j]])
- A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j]
+ T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
+ A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1,
j]
with T.block():
T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i,
0]])
T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
for j in T.serial(0, 15):
with T.block():
T.reads([A_shared[i % 2, tx, i, j + 1]])
- T.writes([B[1 - j % 2, tx, i, 0]])
- B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx,
0, j + 1] * T.float32(
- 2
- )
+ T.writes([B[(j + 1) % 2, tx, i, 0]])
+ B[(j + 1) % 2, tx, i, 0] = A_shared[
+ i % 2, tx, 0, j + 1
+ ] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, i, 0]])
T.writes([C[tx, i, j]])
C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
with T.block():
- T.reads([A_shared[1 - i % 2, tx, i + 1, 0]])
+ T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]])
T.writes([B[0, tx, i + 1, 0]])
- B[0, tx, i + 1, 0] = A_shared[1 - i % 2, tx, 0, 0] *
T.float32(2)
+ B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] *
T.float32(2)
with T.block():
T.reads([B[1, tx, i, 0]])
T.writes([C[tx, i, 15]])
@@ -640,8 +640,8 @@ def transformed_nested_pipeline_prefetch_inner(
for j in T.serial(0, 15):
with T.block():
T.reads([A_shared[1, tx, 15, j + 1]])
- T.writes([B[1 - j % 2, tx, 15, 0]])
- B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j +
1] * T.float32(2)
+ T.writes([B[(j + 1) % 2, tx, 15, 0]])
+ B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j +
1] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, 15, 0]])
T.writes([C[tx, 15, j]])
@@ -768,8 +768,8 @@ def transformed_nested_pipeline_interleaving(
for j in T.serial(0, 15):
with T.block():
T.reads([A_local[tx, i, j + 1]])
- T.writes([B[1 - j % 2, tx, i, 0]])
- B[1 - j % 2, tx, i, 0] = A_local[0, 0, j + 1]
* T.float32(2)
+ T.writes([B[(j + 1) % 2, tx, i, 0]])
+ B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j +
1] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, i, 0]])
T.writes([C[tx, i, j]])
@@ -799,8 +799,8 @@ def transformed_nested_pipeline_interleaving(
for j in T.serial(0, 15):
with T.block():
T.reads([A_local[tx, 15, j + 1]])
- T.writes([B[1 - j % 2, tx, 15, 0]])
- B[1 - j % 2, tx, 15, 0] = A_local[0, 0, j + 1] *
T.float32(2)
+ T.writes([B[(j + 1) % 2, tx, 15, 0]])
+ B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] *
T.float32(2)
with T.block():
T.reads([B[j % 2, tx, 15, 0]])
T.writes([C[tx, 15, j]])
@@ -929,25 +929,27 @@ def transformed_nested_pipeline_double_buffer(
for j in T.serial(0, 15):
with T.block():
T.reads([A_local[i % 2, tx, i, j + 1]])
- T.writes([B[1 - j % 2, tx, i, 0]])
- B[1 - j % 2, tx, i, 0] = A_local[i % 2, 0, 0,
j + 1] * T.float32(2)
+ T.writes([B[(j + 1) % 2, tx, i, 0]])
+ B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0,
0, j + 1] * T.float32(
+ 2
+ )
with T.block():
T.reads([B[j % 2, tx, i, 0]])
T.writes([C[tx, i, j]])
C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
with T.block():
T.reads([A_shared[tx, 0, 0:16]])
- T.writes([A_local[1 - i % 2, 0, 0, 0:16]])
+ T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]])
for j in T.serial(0, 16):
with T.block():
T.reads([A_shared[tx, 0, j]])
- T.writes([A_local[1 - i % 2, 0, 0, j]])
+ T.writes([A_local[(i + 1) % 2, 0, 0, j]])
T.block_attr({"double_buffer_scope": 0})
- A_local[1 - i % 2, 0, 0, j] = A_shared[tx, i +
1, j]
+ A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i
+ 1, j]
with T.block():
- T.reads([A_local[1 - i % 2, tx, i + 1, 0]])
+ T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]])
T.writes([B[0, tx, i + 1, 0]])
- B[0, tx, i + 1, 0] = A_local[1 - i % 2, 0, 0, 0] *
T.float32(2)
+ B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] *
T.float32(2)
with T.block():
T.reads([B[1, tx, i, 0]])
T.writes([C[tx, i, 15]])
@@ -961,8 +963,8 @@ def transformed_nested_pipeline_double_buffer(
for j in T.serial(0, 15):
with T.block():
T.reads([A_local[1, tx, 15, j + 1]])
- T.writes([B[1 - j % 2, tx, 15, 0]])
- B[1 - j % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1]
* T.float32(2)
+ T.writes([B[(j + 1) % 2, tx, 15, 0]])
+ B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j +
1] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, 15, 0]])
T.writes([C[tx, 15, j]])