This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f622e7f180 [ARITH][BUGFIX] Fix a bug of iter map floormod(x,2) 
simplify (#14571)
f622e7f180 is described below

commit f622e7f180bdb95c9d3121bb1f78fe459a57812a
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Apr 11 16:59:29 2023 -0400

    [ARITH][BUGFIX] Fix a bug of iter map floormod(x,2) simplify (#14571)
    
    This PR fixes a bug that was previously introduced in iter map detection.
    
    Specifically, y - (x % 2) was simplified to y + (x % 2) - 1,
    which is wrong: for x = 0 the original expression is y, but the
    rewrite gives y - 1. The correct rule is y + ((x + 1) % 2) - 1,
    but that rule changes the base iterator, which is not desirable here.
    
    We also removed the rule that simplifies (x + 1) % 2 => 1 - x % 2,
    as its benefit is minimal and it introduces extra negative coefficients
    that hurt analysis in general (negative coefficients are harder to
    handle in many cases).
---
 src/arith/iter_affine_map.cc                       |  5 --
 src/arith/rewrite_simplify.cc                      | 17 +++-
 .../unittest/test_arith_canonical_simplify.py      |  7 ++
 .../python/unittest/test_arith_iter_affine_map.py  | 10 +--
 .../python/unittest/test_arith_rewrite_simplify.py | 22 ++++--
 .../test_tir_transform_inject_software_pipeline.py | 90 +++++++++++-----------
 6 files changed, 86 insertions(+), 65 deletions(-)

diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index e7fc4f2663..05af5b4070 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -898,11 +898,6 @@ class IterMapRewriter : public ExprMutator {
   PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);
 
   static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
-    if (sign < 0 && is_const_int(rhs->extent, 2)) {
-      lhs->base -= rhs->scale;
-      sign = 1;
-    }
-
     tir::ExprDeepEqual equal;
     for (size_t i = 0; i < lhs->args.size(); ++i) {
       IterSplitExpr lvalue = lhs->args[i];
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 528a272ef4..acd74b7031 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -306,6 +306,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
AddNode* op) {
 
     TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) + floormod(x, 2), floordiv(x + 1, 
2));
 
+    // Simplify (x + 1) % 2 + x % 2 => 1
+    // NOTE: we should avoid simplifying (x + 1) %2 => 1 - x % 2 though
+    // mainly because introducing extra negative signs to expression can harm 
itertaor
+    // analysis which usually relies on positive itertator co-efficients.
+    TVM_TRY_REWRITE_IF(floormod(x + c1, 2) + floormod(x, 2), 
OneWithTypeLike(x),
+                       floormod(c1.Eval()->value, 2) == 1);
+    TVM_TRY_REWRITE_IF(floormod(x, 2) + floormod(x + c1, 2), 
OneWithTypeLike(x),
+                       floormod(c1.Eval()->value, 2) == 1);
+
     // canonicalization rule
     // will try rewrite again after canonicalization.
 
@@ -1018,10 +1027,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
FloorModNode* op) {
     TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) 
+ y, c2),
                        c2.Eval()->value > 0);
 
-    TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + c1, 2), floormod(x, 2) * (-1) + 
1,
-                                 floormod(c1.Eval()->value, 2) == 1);
-    TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
-                       c2.Eval()->value > 0 && c1.Eval()->value % 
c2.Eval()->value == 0);
+    // (x + 5) % 2 -> (x + 1) %2,  (x + 3) % 3 => x
+    TVM_TRY_REWRITE_IF(
+        floormod(x + c1, c2), floormod(x + floormod(c1, c2), c2),
+        c2.Eval()->value > 0 && (c1.Eval()->value >= c2.Eval()->value || 
c1.Eval()->value < 0));
 
     TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, 
c2), c2),
                        c2.Eval()->value > 0);
diff --git a/tests/python/unittest/test_arith_canonical_simplify.py 
b/tests/python/unittest/test_arith_canonical_simplify.py
index 1a4277d924..c1d7587f43 100644
--- a/tests/python/unittest/test_arith_canonical_simplify.py
+++ b/tests/python/unittest/test_arith_canonical_simplify.py
@@ -415,5 +415,12 @@ def test_proddiv_simplify():
     ck.verify(tdiv(x * (2 * y) * 3, 3 * y * z), tdiv(x * 2, z))
 
 
+def test_floormod_two():
+    ck = CanonicalChecker()
+    flm = tvm.te.floormod
+    x, y = te.var("x"), te.var("y")
+    ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py 
b/tests/python/unittest/test_arith_iter_affine_map.py
index 0bb4c98b7b..5ce7296045 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -199,14 +199,14 @@ def test_compound():
     assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, 
var_dom([(x, 10), (y, 9)]))
 
 
-def test_compound_floormod_two():
+def test_compound_floormod_two_regression():
     x = tvm.tir.Var("x", "int32")
     fld = tvm.tir.floordiv
     flm = tvm.tir.floormod
-
-    # extent of 2 are normalized to positive scale
-    assert_iter_sum_pattern(
-        expect_dict={fld(x, 2) * 2 - flm(x, 2) + 1: (8, 0, 1)},
+    # regression
+    # extent of 2 of negative scale cannot be normalized
+    assert_iter_sum_failure(
+        [fld(x, 2) * 2 - flm(x, 2) + 1],
         dom_map=var_dom([(x, 8)]),
     )
 
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py 
b/tests/python/unittest/test_arith_rewrite_simplify.py
index 7ecc34c385..46ac0f9751 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -392,8 +392,8 @@ class TestSubIndex(BaseCompare):
         TestCase(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3)),
         TestCase(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1),
         TestCase(fld(y, 3) * 3 - y, 0 - flm(y, 3)),
-        TestCase(y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6),
-        TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5)),
+        TestCase(y - fld(y - 6, 5) * 5, flm(y + 4, 5) + 6),
+        TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + 4, 5)),
         TestCase(y - fld(y + z, 5) * 5, flm(y + z, 5) - z),
         TestCase(fld(y + z, 5) * 5 - y, z - flm(y + z, 5)),
         TestCase(y - fld(y - z, 5) * 5, flm(y - z, 5) + z),
@@ -554,13 +554,15 @@ class TestFloormodIndex(BaseCompare):
         TestCase(flm(x + 10, 2), flm(x, 2)),
         TestCase(flm(x + y * 10, 2), flm(x, 2)),
         TestCase(flm(x + y * 360, 16), flm(x + y * 8, 16)),
-        TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1),
         TestCase(flm(x * (-10), 2), 0),
         TestCase(flm(x * (-10) + y, 2), flm(y, 2)),
         TestCase(flm(x + (-10), 2), flm(x, 2)),
         TestCase(flm(x + y * (-10), 2), flm(x, 2)),
         TestCase(flm(x * 32 + y, 64), flm(x, 2) * 32 + y, [y >= 0, y < 32]),
         TestCase(flm(x * 32 - y, 64), flm(x * 32 - y, 64), [y >= 0, y < 32]),
+        # NOTE: the followng case is covered by canonical simplify
+        # long range simplifcation in general can be covered by canonical 
simplify
+        # TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1),
     )
 
 
@@ -574,13 +576,14 @@ class TestFloorModTwo(BaseCompare):
     require identifying more related terms in order to apply.
 
     (x + c1)//2 - (x+c2)//2 => (x%2)*( c1%2 - c1%2 ) + (c1//2 - c2//2)
+
+    We should not introduce extra negative coeficient to iterators
+    however during simplification
     """
 
     x, y, z = te.var("x"), te.var("y"), te.var("z")
     test_case = tvm.testing.parameter(
         # Removing offsets from floormod
-        TestCase(flm(x + 1, 2), flm(x, 2) * (-1) + 1),
-        TestCase(flm(x + 5, 2), flm(x, 2) * (-1) + 1),
         TestCase(flm(x, 2) + flm(x + 1, 2), 1),
         TestCase(flm(x + 1, 2) + flm(x, 2), 1),
         # Difference of floordiv yields floormod
@@ -592,8 +595,13 @@ class TestFloorModTwo(BaseCompare):
         # Sum of floordiv and floormod to yield floordiv
         TestCase(fld(x + 1, 2) - flm(x, 2), fld(x, 2)),
         TestCase(fld(x, 2) + flm(x, 2), fld(x + 1, 2)),
-        # Removal of floormod where possible
-        TestCase(flm(x + 1, 2) * 8192, x * (-8192) + 8192, [x >= 0, x < 2]),
+        # regression: although we can rewrite (x + 1) %2 => 1 - x%2
+        # doing so would introduce negative co-efficient to iterators
+        # which makes later iter map detection harder, in principle we
+        # should not introduce additional negative signs of iterator in 
rewriting
+        TestCase(flm(x + 1, 2), flm(x + 1, 2)),
+        TestCase(flm(x + 5, 2), flm(x + 1, 2)),
+        TestCase(flm(x + 1, 2) * 8192, flm(x + 1, 2) * 8192, [x >= 0, x < 2]),
     )
 
 
diff --git 
a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py 
b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index 7e59172bdd..b9f35ed553 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -139,8 +139,8 @@ def transformed_simple_compute(
                 for i in T.serial(0, 15):
                     with T.block():
                         T.reads([A[tx, i + 1]])
-                        T.writes([B[1 - i % 2, tx, 0]])
-                        B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
+                        T.writes([B[(i + 1) % 2, tx, 0]])
+                        B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
                     with T.block():
                         T.reads([B[i % 2, tx, 0]])
                         T.writes([C[tx, i]])
@@ -202,8 +202,8 @@ def transformed_simple_compute_with_other_annotation(
                 ):
                     with T.block():
                         T.reads([A[tx, i + 1]])
-                        T.writes([B[1 - i % 2, tx, 0]])
-                        B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
+                        T.writes([B[(i + 1) % 2, tx, 0]])
+                        B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
                     with T.block():
                         T.reads([B[i % 2, tx, 0]])
                         T.writes([C[tx, i]])
@@ -266,7 +266,7 @@ def transformed_three_stage_compute(
                         T.where(i == 1)
                         T.reads(B[0:2, tx, 0])
                         T.writes(C[0:2, tx, 0])
-                        C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2)
+                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + 
T.float32(2)
             with T.block():
                 T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
                 T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
@@ -278,7 +278,7 @@ def transformed_three_stage_compute(
                     with T.block():
                         T.reads(B[0:2, tx, 0])
                         T.writes(C[0:2, tx, 0])
-                        C[1 - i % 2, tx, 0] = B[1 - i % 2, tx, 0] + 
T.float32(2)
+                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + 
T.float32(2)
                     with T.block():
                         T.reads(C[0:2, tx, 0])
                         T.writes(D[tx, i])
@@ -291,7 +291,7 @@ def transformed_three_stage_compute(
                         T.where(i < 1)
                         T.reads(B[0:2, tx, 0])
                         T.writes(C[0:2, tx, 0])
-                        C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2)
+                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + 
T.float32(2)
                     with T.block():
                         T.reads(C[0:2, tx, 0])
                         T.writes(D[tx, i + 14])
@@ -391,12 +391,12 @@ def transformed_dag_interleaving(
                         BS[tx, 0] = B[tx, i + 1] + T.float32(2)
                     with T.block():
                         T.reads(AS[tx, 0])
-                        T.writes(AL[1 - i % 2, 0, 0])
-                        AL[1 - i % 2, 0, 0] = AS[tx, 0]
+                        T.writes(AL[(i + 1) % 2, 0, 0])
+                        AL[(i + 1) % 2, 0, 0] = AS[tx, 0]
                     with T.block():
                         T.reads(BS[tx, 0])
-                        T.writes(BL[1 - i % 2, 0, 0])
-                        BL[1 - i % 2, 0, 0] = BS[tx, 0]
+                        T.writes(BL[(i + 1) % 2, 0, 0])
+                        BL[(i + 1) % 2, 0, 0] = BS[tx, 0]
                     with T.block():
                         T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0])
                         T.writes(C[tx, i])
@@ -475,12 +475,12 @@ def transformed_nested_pipeline_simple(
                 for i in T.serial(0, 15):
                     with T.block():
                         T.reads([A[tx, i + 1, 0:16]])
-                        T.writes([A_shared[1 - i % 2, tx, 0, 0:16]])
+                        T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
                         for j in T.serial(0, 16):
                             with T.block():
                                 T.reads([A[tx, i + 1, j]])
-                                T.writes([A_shared[1 - i % 2, tx, 0, j]])
-                                A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j]
+                                T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
+                                A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, 
j]
                     with T.block():
                         T.reads([A_shared[i % 2, tx, i, 0]])
                         T.writes([B[0, tx, i, 0]])
@@ -491,10 +491,10 @@ def transformed_nested_pipeline_simple(
                         for j in T.serial(0, 15):
                             with T.block():
                                 T.reads([A_shared[i % 2, tx, i, j + 1]])
-                                T.writes([B[1 - j % 2, tx, i, 0]])
-                                B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 
0, j + 1] * T.float32(
-                                    2
-                                )
+                                T.writes([B[(j + 1) % 2, tx, i, 0]])
+                                B[(j + 1) % 2, tx, i, 0] = A_shared[
+                                    i % 2, tx, 0, j + 1
+                                ] * T.float32(2)
                             with T.block():
                                 T.reads([B[j % 2, tx, i, 0]])
                                 T.writes([C[tx, i, j]])
@@ -516,8 +516,8 @@ def transformed_nested_pipeline_simple(
                     for j in T.serial(0, 15):
                         with T.block():
                             T.reads([A_shared[1, tx, 15, j + 1]])
-                            T.writes([B[1 - j % 2, tx, 15, 0]])
-                            B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 
1] * T.float32(2)
+                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
+                            B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 
1] * T.float32(2)
                         with T.block():
                             T.reads([B[j % 2, tx, 15, 0]])
                             T.writes([C[tx, 15, j]])
@@ -603,30 +603,30 @@ def transformed_nested_pipeline_prefetch_inner(
                 for i in T.serial(0, 15):
                     with T.block():
                         T.reads([A[tx, i + 1, 0:16]])
-                        T.writes([A_shared[1 - i % 2, tx, 0, 0:16]])
+                        T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
                         for j in T.serial(0, 16):
                             with T.block():
                                 T.reads([A[tx, i + 1, j]])
-                                T.writes([A_shared[1 - i % 2, tx, 0, j]])
-                                A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j]
+                                T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
+                                A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, 
j]
                     with T.block():
                         T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 
0]])
                         T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
                         for j in T.serial(0, 15):
                             with T.block():
                                 T.reads([A_shared[i % 2, tx, i, j + 1]])
-                                T.writes([B[1 - j % 2, tx, i, 0]])
-                                B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 
0, j + 1] * T.float32(
-                                    2
-                                )
+                                T.writes([B[(j + 1) % 2, tx, i, 0]])
+                                B[(j + 1) % 2, tx, i, 0] = A_shared[
+                                    i % 2, tx, 0, j + 1
+                                ] * T.float32(2)
                             with T.block():
                                 T.reads([B[j % 2, tx, i, 0]])
                                 T.writes([C[tx, i, j]])
                                 C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
                     with T.block():
-                        T.reads([A_shared[1 - i % 2, tx, i + 1, 0]])
+                        T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]])
                         T.writes([B[0, tx, i + 1, 0]])
-                        B[0, tx, i + 1, 0] = A_shared[1 - i % 2, tx, 0, 0] * 
T.float32(2)
+                        B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * 
T.float32(2)
                     with T.block():
                         T.reads([B[1, tx, i, 0]])
                         T.writes([C[tx, i, 15]])
@@ -640,8 +640,8 @@ def transformed_nested_pipeline_prefetch_inner(
                     for j in T.serial(0, 15):
                         with T.block():
                             T.reads([A_shared[1, tx, 15, j + 1]])
-                            T.writes([B[1 - j % 2, tx, 15, 0]])
-                            B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 
1] * T.float32(2)
+                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
+                            B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 
1] * T.float32(2)
                         with T.block():
                             T.reads([B[j % 2, tx, 15, 0]])
                             T.writes([C[tx, 15, j]])
@@ -768,8 +768,8 @@ def transformed_nested_pipeline_interleaving(
                         for j in T.serial(0, 15):
                             with T.block():
                                 T.reads([A_local[tx, i, j + 1]])
-                                T.writes([B[1 - j % 2, tx, i, 0]])
-                                B[1 - j % 2, tx, i, 0] = A_local[0, 0, j + 1] 
* T.float32(2)
+                                T.writes([B[(j + 1) % 2, tx, i, 0]])
+                                B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 
1] * T.float32(2)
                             with T.block():
                                 T.reads([B[j % 2, tx, i, 0]])
                                 T.writes([C[tx, i, j]])
@@ -799,8 +799,8 @@ def transformed_nested_pipeline_interleaving(
                     for j in T.serial(0, 15):
                         with T.block():
                             T.reads([A_local[tx, 15, j + 1]])
-                            T.writes([B[1 - j % 2, tx, 15, 0]])
-                            B[1 - j % 2, tx, 15, 0] = A_local[0, 0, j + 1] * 
T.float32(2)
+                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
+                            B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * 
T.float32(2)
                         with T.block():
                             T.reads([B[j % 2, tx, 15, 0]])
                             T.writes([C[tx, 15, j]])
@@ -929,25 +929,27 @@ def transformed_nested_pipeline_double_buffer(
                         for j in T.serial(0, 15):
                             with T.block():
                                 T.reads([A_local[i % 2, tx, i, j + 1]])
-                                T.writes([B[1 - j % 2, tx, i, 0]])
-                                B[1 - j % 2, tx, i, 0] = A_local[i % 2, 0, 0, 
j + 1] * T.float32(2)
+                                T.writes([B[(j + 1) % 2, tx, i, 0]])
+                                B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 
0, j + 1] * T.float32(
+                                    2
+                                )
                             with T.block():
                                 T.reads([B[j % 2, tx, i, 0]])
                                 T.writes([C[tx, i, j]])
                                 C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
                     with T.block():
                         T.reads([A_shared[tx, 0, 0:16]])
-                        T.writes([A_local[1 - i % 2, 0, 0, 0:16]])
+                        T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]])
                         for j in T.serial(0, 16):
                             with T.block():
                                 T.reads([A_shared[tx, 0, j]])
-                                T.writes([A_local[1 - i % 2, 0, 0, j]])
+                                T.writes([A_local[(i + 1) % 2, 0, 0, j]])
                                 T.block_attr({"double_buffer_scope": 0})
-                                A_local[1 - i % 2, 0, 0, j] = A_shared[tx, i + 
1, j]
+                                A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i 
+ 1, j]
                     with T.block():
-                        T.reads([A_local[1 - i % 2, tx, i + 1, 0]])
+                        T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]])
                         T.writes([B[0, tx, i + 1, 0]])
-                        B[0, tx, i + 1, 0] = A_local[1 - i % 2, 0, 0, 0] * 
T.float32(2)
+                        B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] * 
T.float32(2)
                     with T.block():
                         T.reads([B[1, tx, i, 0]])
                         T.writes([C[tx, i, 15]])
@@ -961,8 +963,8 @@ def transformed_nested_pipeline_double_buffer(
                     for j in T.serial(0, 15):
                         with T.block():
                             T.reads([A_local[1, tx, 15, j + 1]])
-                            T.writes([B[1 - j % 2, tx, 15, 0]])
-                            B[1 - j % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] 
* T.float32(2)
+                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
+                            B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j + 
1] * T.float32(2)
                         with T.block():
                             T.reads([B[j % 2, tx, 15, 0]])
                             T.writes([C[tx, 15, j]])

Reply via email to