This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new eb7cf7051d Revert "support overlapped itersum (#12039)" (#12137)
eb7cf7051d is described below

commit eb7cf7051dd239ef86db13875687ef6d5ebb9418
Author: Florin Blanaru <[email protected]>
AuthorDate: Wed Jul 20 00:36:38 2022 +0100

    Revert "support overlapped itersum (#12039)" (#12137)
    
    This reverts commit 3e7a2ad9568a79fb775c0ca9d09a3fa2f51f792f.
---
 src/arith/iter_affine_map.cc                       | 91 ++++++----------------
 tests/python/unittest/test_arith_intset.py         |  7 +-
 .../python/unittest/test_arith_iter_affine_map.py  | 58 +-------------
 .../unittest/test_meta_schedule_space_cpu.py       | 26 +++----
 .../unittest/test_meta_schedule_space_cuda.py      | 12 +--
 tests/python/unittest/test_tir_schedule_reorder.py | 30 +------
 .../unittest/test_tir_schedule_split_fuse.py       |  8 +-
 .../test_tir_schedule_state_cached_flags.py        |  2 +-
 8 files changed, 58 insertions(+), 176 deletions(-)

diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 83e2821c98..d2aa16ded1 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -177,12 +177,8 @@ class IterMapRewriter : public ExprMutator {
   using Parent = ExprMutator;
 
   explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& 
input_iters,
-                           IterMapLevel check_level, bool 
simplify_trivial_iterators,
-                           Array<String>* errors)
-      : analyzer_(analyzer),
-        check_level_(check_level),
-        errors_(*errors),
-        padding_predicate_(const_false()) {
+                           bool simplify_trivial_iterators, Array<String>* 
errors)
+      : analyzer_(analyzer), errors_(*errors), 
padding_predicate_(const_false()) {
     for (auto kv : input_iters) {
       const Var& var = kv.first;
       const Range& vrng = kv.second;
@@ -423,8 +419,6 @@ class IterMapRewriter : public ExprMutator {
 
   // Internal analyzer
   Analyzer* analyzer_;
-  // Iter map check level
-  IterMapLevel check_level_;
   // Error messages for each unresolved expression.
   Array<String>& errors_;
   // The var map
@@ -657,7 +651,7 @@ class IterMapRewriter : public ExprMutator {
       if (predicate_induced_max.defined())
         predicate_induced_max = predicate_induced_max.value() - base;
     }
-    Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
+    Optional<IterSumExpr> opt = TryFuseIters(expr);
     ICHECK(!opt.defined() || opt.value()->args.size() == 1);
     // scale should be 1
     if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
@@ -708,7 +702,7 @@ class IterMapRewriter : public ExprMutator {
   IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
     // We are normalizing a regular iter
     if (expr->args.size() < 1) return expr;
-    Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
+    Optional<IterSumExpr> opt = TryFuseIters(expr);
     if (opt.defined()) {
       return opt.value();
     } else {
@@ -741,10 +735,9 @@ class IterMapRewriter : public ExprMutator {
    *    return a corresponding IterSumExpr with extra offset if needed.
    *    Try to normalize IterSum into a fused IterMark
    * \param expr The input sum.
-   * \param check_level The check level if iter mapping.
    * \return The sum with the fused IterMark and extra offset if succeed.
    */
-  Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel 
check_level) {
+  Optional<IterSumExpr> TryFuseIters(IterSumExpr expr) {
     // select the iterators in order
     std::vector<bool> visited(expr->args.size(), false);
     std::vector<IterSplitExpr> flattened_iters, grouped_iters;
@@ -765,42 +758,14 @@ class IterMapRewriter : public ExprMutator {
     }
     // check if it can be remapped into a fused pattern.
     PrimExpr expected_extra_base = 0;
-    PrimExpr tail_extent = 0;
     PrimExpr expected_scale = base_scale.value();
     for (size_t i = 0; i < expr->args.size();) {
-      // find position such that expr->args[j] match expected scale
-      int j = i == 0 ? base_index : expr->args.size() - 1;
-
-      size_t matched_pos = expr->args.size();
-      PrimExpr matched_scale{nullptr};
-      bool is_exact_match{false};
-
-      for (; j >= 0; --j) {
-        if (visited[j]) {
-          continue;
-        }
-        const PrimExpr& cur_scale = expr->args[j]->scale;
-
-        // for bijective mapping, the matched scale must equal to expected 
scale
-        if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
-          matched_pos = j;
-          matched_scale = cur_scale;
-          is_exact_match = true;
-          break;
-        }
-        if (check_level != IterMapLevel::Bijective && 
base_scale.value()->value == 1) {
-          // find the closest scale which is less or equal to expected scale
-          if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) &&
-              analyzer_->CanProveGreaterEqual(cur_scale, 0)) {
-            if (matched_pos == expr->args.size() ||
-                analyzer_->CanProveLess(matched_scale - cur_scale, 0)) {
-              matched_pos = j;
-              matched_scale = cur_scale;
-            }
-          }
-        }
+      // find j such that expr->args[j] has expected scale
+      size_t j = i == 0 ? base_index : 0;
+      for (; j < expr->args.size(); ++j) {
+        if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, 
expected_scale)) break;
       }
-      if (matched_pos == expr->args.size()) {
+      if (j == expr->args.size()) {
         return NullOpt;
       }
       // look for the longest constrained iter started from expr->args[j]
@@ -810,8 +775,8 @@ class IterMapRewriter : public ExprMutator {
       // otherwise we expect the scale of i to be 2*5=10
       Optional<IterSumExpr> constraint_to_match;
       for (const IterSumExpr& iter : constrained_iters_flattened_) {
-        if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) 
{
-          // find a predicate started from match position
+        if (IterSplitEqual(expr->args[j], iter->args.back(), false)) {
+          // find a predicate started from expr->args[j]
           if (!constraint_to_match ||
               constraint_to_match.value()->args.size() < iter->args.size()) {
             constraint_to_match = iter;
@@ -828,7 +793,7 @@ class IterMapRewriter : public ExprMutator {
           size_t k = 0;
           for (; k < expr->args.size(); ++k) {
             if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
-              if (analyzer_->CanProveEqual((*it)->scale * matched_scale, 
expr->args[k]->scale))
+              if (analyzer_->CanProveEqual((*it)->scale * expected_scale, 
expr->args[k]->scale))
                 break;
             }
           }
@@ -841,25 +806,20 @@ class IterMapRewriter : public ExprMutator {
         auto iter = sum_fuse_map_.find(constraint_to_match.value());
         ICHECK(iter != sum_fuse_map_.end());
         const IterMarkWithOffset& iter_matched = iter->second;
-        grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, 
base_scale.value()));
-        expected_extra_base += iter_matched.offset * matched_scale;
-        if (!is_exact_match) {
-          tail_extent += expected_scale - matched_scale;
-        }
-        expected_scale = matched_scale * iter_matched.mark->extent;
+        grouped_iters.emplace_back(iter_matched.mark, expected_scale);
+        expected_extra_base += iter_matched.offset * expected_scale;
+        expected_scale *= iter_matched.mark->extent;
         // move forward
         i += constraint_to_match.value()->args.size();
       } else {
         // constraint_to_match not found, skip this iterator
-        visited[matched_pos] = true;
-        IterSplitExpr arg = expr->args[matched_pos];
-        arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, 
base_scale.value()));
+        visited[j] = true;
+        IterSplitExpr arg = expr->args[j];
+        arg.CopyOnWrite()->scale =
+            analyzer_->Simplify(div(expr->args[j]->scale, base_scale.value()));
         flattened_iters.push_back(arg);
         grouped_iters.push_back(arg);
-        if (!is_exact_match) {
-          tail_extent += expected_scale - matched_scale;
-        }
-        expected_scale = matched_scale * expr->args[matched_pos]->extent;
+        expected_scale *= expr->args[j]->extent;
         ++i;
       }
     }
@@ -883,8 +843,7 @@ class IterMapRewriter : public ExprMutator {
                          expr->base + expected_extra_base);
     } else {
       // new iter, form a new mark
-      IterMark mark =
-          IterMark(structured_form, div(expected_scale, base_scale.value()) + 
tail_extent);
+      IterMark mark = IterMark(structured_form, div(expected_scale, 
base_scale.value()));
       sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
       flattened_map_[structured_form] = flattened_form;
       return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
@@ -1127,8 +1086,8 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& 
indices, const Map<Var, Range
       constraints.begin(), constraints.end(),
       [](const IterConstraint& a, const IterConstraint& b) { return 
a.expr_size < b.expr_size; });
 
-  IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level,
-                           simplify_trivial_iterators, &result->errors);
+  IterMapRewriter rewriter(analyzer, constrained_input_iters, 
simplify_trivial_iterators,
+                           &result->errors);
   // Step0.0: rewrite constraints in the order from size-small ones to 
size-big ones
   for (const IterConstraint& constraint : constraints) {
     auto res = rewriter.RewriteIterConstraint(constraint.iter, 
constraint.lower_bound,
@@ -1322,7 +1281,7 @@ IterSumExpr 
IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
     } else if (sum->args.size() == 1) {
       return sum;
     }
-    auto opt_fused = TryFuseIters(sum, check_level_);
+    auto opt_fused = TryFuseIters(sum);
     if (!opt_fused) {
       ErrorLogger(this) << "Dividend  " << tvm::PrettyPrint(original_dividend)
                         << ", can't be written as a single fused IterSum";
diff --git a/tests/python/unittest/test_arith_intset.py 
b/tests/python/unittest/test_arith_intset.py
index 74b53442ec..ca9d1077fe 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -323,6 +323,10 @@ def test_region_lower_bound_for_non_perfect_tile():
 
 
 def test_region_lower_bound_unfusable():
+    # This test is designed to trigger an error in DetectIterMap,
+    # resulting from a numerator which required multiple input
+    # variables.  The bug resulted in an exception being thrown,
+    # rather than a return value of None.
     var_dom = {
         tvm.tir.Var("i", "int32"): tvm.ir.Range(8),
         tvm.tir.Var("j", "int32"): tvm.ir.Range(4),
@@ -332,8 +336,7 @@ def test_region_lower_bound_unfusable():
         tvm.ir.Range.from_min_extent((i + j) // 2, 1),
     ]
     result = tvm.arith.estimate_region_lower_bound(region, var_dom, 
predicate=True)
-    assert result[0].min_value == 0
-    assert result[0].max_value == 5
+    assert result is None
 
 
 def test_union_lower_bound():
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py 
b/tests/python/unittest/test_arith_iter_affine_map.py
index 6a2fdbbb3f..7bc5ead298 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -61,6 +61,7 @@ def assert_iter_sum_pattern(
     )
     indices = res.indices
     assert len(indices) == len(keys), res.errors
+    print(indices)
     for i, input_iter in enumerate(keys):
         spec = expect_dict[input_iter]
         (
@@ -445,13 +446,6 @@ def test_predicate():
         predicate=xo * 129 + xi < 128,
     )
 
-    # strided iteration predicate
-    assert_iter_sum_pattern(
-        {xo * 16 + xi * 4: (10, 0, 4)},
-        var_dom([(xo, 3), (xi, 4)]),
-        predicate=xo * 4 + xi < 10,
-    )
-
 
 def convert_division(divisions):
     if divisions is None or len(divisions) == 0:
@@ -1016,55 +1010,5 @@ def test_padding():
     assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)]))
 
 
-def test_overlapped_fuse():
-    x = tvm.tir.Var("x", "int32")
-    y = tvm.tir.Var("y", "int32")
-    z = tvm.tir.Var("z", "int32")
-    a = tvm.tir.Var("x", "int32")
-    b = tvm.tir.Var("y", "int32")
-
-    # non-bijective fuse of two
-    assert_iter_sum_pattern(
-        {
-            x * 7 + y: (22, 0, 1),
-        },
-        var_dom([(x, 3), (y, 8)]),
-        check_level="surjective",
-    )
-    assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), 
check_level="bijective")
-
-    # non-bijective fuse of three
-    assert_iter_sum_pattern(
-        {
-            x * 18 + y * 7 + z: (40, 0, 1),
-        },
-        var_dom([(x, 2), (y, 3), (z, 8)]),
-        check_level="surjective",
-    )
-    assert_iter_sum_failure([x * 7 + y], var_dom([(x, 2), (y, 3), (z, 8)]), 
check_level="bijective")
-
-    # negative scale fusion is not allowed
-    assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), 
check_level="surjective")
-    assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), 
check_level="surjective")
-
-    # with predicate
-    assert_iter_sum_pattern(
-        {
-            a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1),
-        },
-        var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]),
-        predicate=tvm.tir.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10),
-        check_level="surjective",
-    )
-
-    # stride=1 kernel
-    assert_iter_sum_pattern(
-        {x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), 
check_level="surjective"
-    )
-
-    # do not allow both strided and overlapped
-    assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), 
check_level="surjective")
-
-
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py 
b/tests/python/unittest/test_meta_schedule_space_cpu.py
index 7895fb376e..36f365e732 100644
--- a/tests/python/unittest/test_meta_schedule_space_cpu.py
+++ b/tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -48,11 +48,11 @@ def test_cpu_c1d():
             for i0_0, i1_0, i2_0, i0_1_1, i1_1_1, i2_1_1 in T.grid(1, 1, 2, 1, 
1, 8):
                 for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 
in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
                     with T.block("conv1d_nlc"):
-                        n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_3 + i0_0)
-                        l = T.axis.spatial(128, i1_0 * 128 + i1_1_1 * 128 + 
i1_2 * 2 + i1_3)
-                        co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1_1 * 8 
+ i2_2)
+                        n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3)
+                        l = T.axis.spatial(128, i1_1_1 * 128 + i1_0 * 128 + 
i1_2 * 2 + i1_3)
+                        co = T.axis.spatial(128, (i2_0 * 8 + i2_1_1) * 8 + 
i2_2 + i2_3)
                         rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                        rc = T.axis.reduce(64, i4_1 + i4_0)
+                        rc = T.axis.reduce(64, i4_0 + i4_1)
                         T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], 
weight[rl, rc, co])
                         T.writes(conv1d_nlc_global[n, l, co])
                         
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -89,11 +89,11 @@ def test_cpu_c1d():
                             PadInput[i0, i1, i2] = T.if_then_else(1 <= i1 and 
i1 < 257, inputs[i0, i1 - 1, i2], T.float32(0), dtype="float32")
                     for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, 
i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
                         with T.block("conv1d_nlc"):
-                            n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
-                            l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + 
i1_2 * 2 + i1_3)
-                            co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 
8 + i2_2)
+                            n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+                            l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + 
i1_2 * 2 + i1_3)
+                            co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + 
i2_2 + i2_3)
                             rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                            rc = T.axis.reduce(64, i4_1 + i4_0)
+                            rc = T.axis.reduce(64, i4_0 + i4_1)
                             T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + 
rc], weight[rl, rc, co])
                             T.writes(conv1d_nlc_global[n, l, co])
                             
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -107,7 +107,7 @@ def test_cpu_c1d():
                         T.reads(conv1d_nlc_global[v0, v1, v2])
                         T.writes(conv1d_nlc[v0, v1, v2])
                         conv1d_nlc[v0, v1, v2] = conv1d_nlc_global[v0, v1, v2]
-
+                        
     @T.prim_func
     def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 
64, 128), "float32"], conv1d_nlc: T.Buffer[(1, 128, 128), "float32"]) -> None:
         # function attr dict
@@ -119,11 +119,11 @@ def test_cpu_c1d():
             T.block_attr({"meta_schedule.parallel":288, 
"meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
             for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i4_0, i0_2, i1_2, 
i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 8, 1, 64, 1, 64, 8, 
3, 1, 1, 2, 1):
                 with T.block("conv1d_nlc"):
-                    n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
-                    l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 
+ i1_3)
-                    co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + 
i2_2)
+                    n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+                    l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 
+ i1_3)
+                    co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + 
i2_3)
                     rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                    rc = T.axis.reduce(64, i4_1 + i4_0)
+                    rc = T.axis.reduce(64, i4_0 + i4_1)
                     T.reads(inputs[n, l * 2 + rl - 1, co // 128 * 64 + rc], 
weight[rl, rc, co])
                     T.writes(conv1d_nlc[n, l, co])
                     T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py 
b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 86edb373ec..b8723e286a 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -47,7 +47,7 @@ def test_cuda_c1d():
                             for ax0_ax1_ax2_fused in T.serial(260):
                                 with T.block("PadInput_shared"):
                                     v0 = T.axis.spatial(1, 0)
-                                    v1 = T.axis.spatial(258, 
i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused // 4)
+                                    v1 = T.axis.spatial(258, 
i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused % 260 // 4)
                                     v2 = T.axis.spatial(64, i4_0 * 4 + 
ax0_ax1_ax2_fused % 4)
                                     T.reads(inputs[v0, v1 - 1, v2])
                                     T.writes(PadInput_shared[v0, v1, v2])
@@ -64,11 +64,11 @@ def test_cuda_c1d():
                                     weight_shared[v0, v1, v2] = weight[v0, v1, 
v2]
                             for i3_1, i4_1, i0_3, i1_3, i2_3, i3_2, i4_2, 
i0_4, i1_4, i2_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8):
                                 with T.block("conv1d_nlc"):
-                                    n = T.axis.spatial(1, i0_4 + i0_3)
-                                    l = T.axis.spatial(128, 
i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + i1_3 * 4 + i1_4)
-                                    co = T.axis.spatial(128, 
i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + i2_3 * 8 + i2_4)
-                                    rl = T.axis.reduce(3, i3_0 * 3 + i3_1 * 3 
+ i3_2)
-                                    rc = T.axis.reduce(64, i4_0 * 4 + i4_1 * 2 
+ i4_2)
+                                    n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 
+ 0)
+                                    l = T.axis.spatial(128, 
(i0_0_i1_0_i2_0_fused % 4 * 8 + i0_1_i1_1_i2_1_fused % 16 // 2 + 0 + i1_3) * 4 
+ i1_4)
+                                    co = T.axis.spatial(128, (((0 * 2 + 
i0_1_i1_1_i2_1_fused % 2) * 4 + i0_2_i1_2_i2_2_fused % 4) * 2 + i2_3) * 8 + 
i2_4)
+                                    rl = T.axis.reduce(3, (i3_0 + i3_1) * 3 + 
i3_2)
+                                    rc = T.axis.reduce(64, (i4_0 * 2 + i4_1) * 
2 + i4_2)
                                     T.reads(PadInput_shared[n, l * 2 + rl, co 
// 128 * 64 + rc], weight_shared[rl, rc, co])
                                     T.writes(conv1d_nlc_local[n, l, co])
                                     
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, 
"meta_schedule.thread_extent_low_inclusive":32, 
"meta_schedule.tiling_structure":"SSSRRSRS"})
diff --git a/tests/python/unittest/test_tir_schedule_reorder.py 
b/tests/python/unittest/test_tir_schedule_reorder.py
index b859b655ef..4351fe5b63 100644
--- a/tests/python/unittest/test_tir_schedule_reorder.py
+++ b/tests/python/unittest/test_tir_schedule_reorder.py
@@ -214,9 +214,9 @@ def test_reorder_with_opaque_access():
     verify_trace_roundtrip(sch=sch, mod=opaque_access)
 
 
-def test_reorder_overlapped_access():
+def test_reorder_with_partial_affineness():
     @T.prim_func
-    def overlapped_access(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 
4), "float32"]):
+    def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), 
"float32"]):
         # example to write first axis multiple times
         for v0, v1, v2 in T.grid(6, 4, 4):
             with T.block("block"):
@@ -225,7 +225,7 @@ def test_reorder_overlapped_access():
                 B[i, j] = A[i, j] + 1.0
 
     @T.prim_func
-    def overlapped_access_reorder(A: T.Buffer[(14, 4), "float32"], B: 
T.Buffer[(14, 4), "float32"]):
+    def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: 
T.Buffer[(14, 4), "float32"]):
         # example to write first axis multiple times
         for v0, v2, v1 in T.grid(6, 4, 4):
             with T.block("block"):
@@ -233,30 +233,6 @@ def test_reorder_overlapped_access():
                 j = T.axis.spatial(4, v2)
                 B[i, j] = A[i, j] + 1.0
 
-    sch = tir.Schedule(overlapped_access, debug_mask="all")
-    v0, v1, v2 = sch.get_loops(sch.get_block("block"))
-    sch.reorder(v0, v2, v1)
-    tvm.ir.assert_structural_equal(overlapped_access_reorder, sch.mod["main"])
-    verify_trace_roundtrip(sch=sch, mod=overlapped_access)
-
-
-def test_reorder_with_partial_affineness():
-    @T.prim_func
-    def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), 
"float32"]):
-        for v0, v1, v2 in T.grid(6, 4, 4):
-            with T.block("block"):
-                i = T.axis.spatial(14, v0 * v0 + v1)
-                j = T.axis.spatial(4, v2)
-                B[i, j] = A[i, j] + 1.0
-
-    @T.prim_func
-    def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: 
T.Buffer[(14, 4), "float32"]):
-        for v0, v2, v1 in T.grid(6, 4, 4):
-            with T.block("block"):
-                i = T.axis.spatial(14, v0 * v0 + v1)
-                j = T.axis.spatial(4, v2)
-                B[i, j] = A[i, j] + 1.0
-
     sch = tir.Schedule(non_affine_func, debug_mask="all")
     v0, v1, v2 = sch.get_loops(sch.get_block("block"))
     with pytest.raises(tvm.tir.ScheduleError):
diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py 
b/tests/python/unittest/test_tir_schedule_split_fuse.py
index 3ae88e0abb..9fd678174d 100644
--- a/tests/python/unittest/test_tir_schedule_split_fuse.py
+++ b/tests/python/unittest/test_tir_schedule_split_fuse.py
@@ -177,7 +177,7 @@ def elementwise_split_case0(a: T.handle, b: T.handle) -> 
None:
     B = T.match_buffer(b, [128, 128, 128])
     for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8):
         with T.block("B"):
-            vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3)
+            vi = T.axis.S(128, (i1 + i2) * 64 + i3)
             vj = T.axis.S(128, j1 * 32 + j2)
             vk = T.axis.S(128, k1 * 8 + k2)
             T.reads([A[vi, vj, vk]])
@@ -191,9 +191,9 @@ def elementwise_split_case1(a: T.handle, b: T.handle) -> 
None:
     B = T.match_buffer(b, [128, 128, 128])
     for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1, 
64):
         with T.block("B"):
-            vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3)
-            vj = T.axis.S(128, j1 * 64 + j2 * 64 + j3)
-            vk = T.axis.S(128, k1 * 64 + k2 * 64 + k3)
+            vi = T.axis.S(128, (i1 + i2) * 64 + i3)
+            vj = T.axis.S(128, (j1 + j2) * 64 + j3)
+            vk = T.axis.S(128, (k1 + k2) * 64 + k3)
             T.reads([A[vi, vj, vk]])
             T.writes([B[vi, vj, vk]])
             B[vi, vj, vk] = A[vi, vj, vk] * 2.0
diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py 
b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
index bbeb8d8760..1b4c34973f 100644
--- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
+++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
@@ -758,7 +758,7 @@ def test_non_perfect_tiling_cache():
     s = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all")
     # pylint: disable=protected-access
     assert s._get_cached_flags(_get_block(s, "cache")) == CachedFlags(
-        affine_binding=True,
+        affine_binding=False,
         region_cover=True,
         stage_pipeline=True,
     )

Reply via email to the commits mailing list.