This is an automated email from the ASF dual-hosted git repository.
kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new eb7cf7051d Revert "support overlapped itersum (#12039)" (#12137)
eb7cf7051d is described below
commit eb7cf7051dd239ef86db13875687ef6d5ebb9418
Author: Florin Blanaru <[email protected]>
AuthorDate: Wed Jul 20 00:36:38 2022 +0100
Revert "support overlapped itersum (#12039)" (#12137)
This reverts commit 3e7a2ad9568a79fb775c0ca9d09a3fa2f51f792f.
---
src/arith/iter_affine_map.cc | 91 ++++++----------------
tests/python/unittest/test_arith_intset.py | 7 +-
.../python/unittest/test_arith_iter_affine_map.py | 58 +-------------
.../unittest/test_meta_schedule_space_cpu.py | 26 +++----
.../unittest/test_meta_schedule_space_cuda.py | 12 +--
tests/python/unittest/test_tir_schedule_reorder.py | 30 +------
.../unittest/test_tir_schedule_split_fuse.py | 8 +-
.../test_tir_schedule_state_cached_flags.py | 2 +-
8 files changed, 58 insertions(+), 176 deletions(-)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 83e2821c98..d2aa16ded1 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -177,12 +177,8 @@ class IterMapRewriter : public ExprMutator {
using Parent = ExprMutator;
explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>&
input_iters,
- IterMapLevel check_level, bool
simplify_trivial_iterators,
- Array<String>* errors)
- : analyzer_(analyzer),
- check_level_(check_level),
- errors_(*errors),
- padding_predicate_(const_false()) {
+ bool simplify_trivial_iterators, Array<String>*
errors)
+ : analyzer_(analyzer), errors_(*errors),
padding_predicate_(const_false()) {
for (auto kv : input_iters) {
const Var& var = kv.first;
const Range& vrng = kv.second;
@@ -423,8 +419,6 @@ class IterMapRewriter : public ExprMutator {
// Internal analyzer
Analyzer* analyzer_;
- // Iter map check level
- IterMapLevel check_level_;
// Error messages for each unresolved expression.
Array<String>& errors_;
// The var map
@@ -657,7 +651,7 @@ class IterMapRewriter : public ExprMutator {
if (predicate_induced_max.defined())
predicate_induced_max = predicate_induced_max.value() - base;
}
- Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
+ Optional<IterSumExpr> opt = TryFuseIters(expr);
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
// scale should be 1
if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
@@ -708,7 +702,7 @@ class IterMapRewriter : public ExprMutator {
IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
// We are normalizing a regular iter
if (expr->args.size() < 1) return expr;
- Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
+ Optional<IterSumExpr> opt = TryFuseIters(expr);
if (opt.defined()) {
return opt.value();
} else {
@@ -741,10 +735,9 @@ class IterMapRewriter : public ExprMutator {
* return a corresponding IterSumExpr with extra offset if needed.
* Try to normalize IterSum into a fused IterMark
* \param expr The input sum.
- * \param check_level The check level if iter mapping.
* \return The sum with the fused IterMark and extra offset if succeed.
*/
- Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel
check_level) {
+ Optional<IterSumExpr> TryFuseIters(IterSumExpr expr) {
// select the iterators in order
std::vector<bool> visited(expr->args.size(), false);
std::vector<IterSplitExpr> flattened_iters, grouped_iters;
@@ -765,42 +758,14 @@ class IterMapRewriter : public ExprMutator {
}
// check if it can be remapped into a fused pattern.
PrimExpr expected_extra_base = 0;
- PrimExpr tail_extent = 0;
PrimExpr expected_scale = base_scale.value();
for (size_t i = 0; i < expr->args.size();) {
- // find position such that expr->args[j] match expected scale
- int j = i == 0 ? base_index : expr->args.size() - 1;
-
- size_t matched_pos = expr->args.size();
- PrimExpr matched_scale{nullptr};
- bool is_exact_match{false};
-
- for (; j >= 0; --j) {
- if (visited[j]) {
- continue;
- }
- const PrimExpr& cur_scale = expr->args[j]->scale;
-
- // for bijective mapping, the matched scale must equal to expected
scale
- if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
- matched_pos = j;
- matched_scale = cur_scale;
- is_exact_match = true;
- break;
- }
- if (check_level != IterMapLevel::Bijective &&
base_scale.value()->value == 1) {
- // find the closest scale which is less or equal to expected scale
- if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) &&
- analyzer_->CanProveGreaterEqual(cur_scale, 0)) {
- if (matched_pos == expr->args.size() ||
- analyzer_->CanProveLess(matched_scale - cur_scale, 0)) {
- matched_pos = j;
- matched_scale = cur_scale;
- }
- }
- }
+ // find j such that expr->args[j] has expected scale
+ size_t j = i == 0 ? base_index : 0;
+ for (; j < expr->args.size(); ++j) {
+ if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale,
expected_scale)) break;
}
- if (matched_pos == expr->args.size()) {
+ if (j == expr->args.size()) {
return NullOpt;
}
// look for the longest constrained iter started from expr->args[j]
@@ -810,8 +775,8 @@ class IterMapRewriter : public ExprMutator {
// otherwise we expect the scale of i to be 2*5=10
Optional<IterSumExpr> constraint_to_match;
for (const IterSumExpr& iter : constrained_iters_flattened_) {
- if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false))
{
- // find a predicate started from match position
+ if (IterSplitEqual(expr->args[j], iter->args.back(), false)) {
+ // find a predicate started from expr->args[j]
if (!constraint_to_match ||
constraint_to_match.value()->args.size() < iter->args.size()) {
constraint_to_match = iter;
@@ -828,7 +793,7 @@ class IterMapRewriter : public ExprMutator {
size_t k = 0;
for (; k < expr->args.size(); ++k) {
if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
- if (analyzer_->CanProveEqual((*it)->scale * matched_scale,
expr->args[k]->scale))
+ if (analyzer_->CanProveEqual((*it)->scale * expected_scale,
expr->args[k]->scale))
break;
}
}
@@ -841,25 +806,20 @@ class IterMapRewriter : public ExprMutator {
auto iter = sum_fuse_map_.find(constraint_to_match.value());
ICHECK(iter != sum_fuse_map_.end());
const IterMarkWithOffset& iter_matched = iter->second;
- grouped_iters.emplace_back(iter_matched.mark, div(matched_scale,
base_scale.value()));
- expected_extra_base += iter_matched.offset * matched_scale;
- if (!is_exact_match) {
- tail_extent += expected_scale - matched_scale;
- }
- expected_scale = matched_scale * iter_matched.mark->extent;
+ grouped_iters.emplace_back(iter_matched.mark, expected_scale);
+ expected_extra_base += iter_matched.offset * expected_scale;
+ expected_scale *= iter_matched.mark->extent;
// move forward
i += constraint_to_match.value()->args.size();
} else {
// constraint_to_match not found, skip this iterator
- visited[matched_pos] = true;
- IterSplitExpr arg = expr->args[matched_pos];
- arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale,
base_scale.value()));
+ visited[j] = true;
+ IterSplitExpr arg = expr->args[j];
+ arg.CopyOnWrite()->scale =
+ analyzer_->Simplify(div(expr->args[j]->scale, base_scale.value()));
flattened_iters.push_back(arg);
grouped_iters.push_back(arg);
- if (!is_exact_match) {
- tail_extent += expected_scale - matched_scale;
- }
- expected_scale = matched_scale * expr->args[matched_pos]->extent;
+ expected_scale *= expr->args[j]->extent;
++i;
}
}
@@ -883,8 +843,7 @@ class IterMapRewriter : public ExprMutator {
expr->base + expected_extra_base);
} else {
// new iter, form a new mark
- IterMark mark =
- IterMark(structured_form, div(expected_scale, base_scale.value()) +
tail_extent);
+ IterMark mark = IterMark(structured_form, div(expected_scale,
base_scale.value()));
sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
flattened_map_[structured_form] = flattened_form;
return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
@@ -1127,8 +1086,8 @@ IterMapResult DetectIterMap(const Array<PrimExpr>&
indices, const Map<Var, Range
constraints.begin(), constraints.end(),
[](const IterConstraint& a, const IterConstraint& b) { return
a.expr_size < b.expr_size; });
- IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level,
- simplify_trivial_iterators, &result->errors);
+ IterMapRewriter rewriter(analyzer, constrained_input_iters,
simplify_trivial_iterators,
+ &result->errors);
// Step0.0: rewrite constraints in the order from size-small ones to
size-big ones
for (const IterConstraint& constraint : constraints) {
auto res = rewriter.RewriteIterConstraint(constraint.iter,
constraint.lower_bound,
@@ -1322,7 +1281,7 @@ IterSumExpr
IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
} else if (sum->args.size() == 1) {
return sum;
}
- auto opt_fused = TryFuseIters(sum, check_level_);
+ auto opt_fused = TryFuseIters(sum);
if (!opt_fused) {
ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend)
<< ", can't be written as a single fused IterSum";
diff --git a/tests/python/unittest/test_arith_intset.py
b/tests/python/unittest/test_arith_intset.py
index 74b53442ec..ca9d1077fe 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -323,6 +323,10 @@ def test_region_lower_bound_for_non_perfect_tile():
def test_region_lower_bound_unfusable():
+ # This test is designed to trigger an error in DetectIterMap,
+ # resulting from a numerator which required multiple input
+ # variables. The bug resulted in an exception being thrown,
+ # rather than a return value of None.
var_dom = {
tvm.tir.Var("i", "int32"): tvm.ir.Range(8),
tvm.tir.Var("j", "int32"): tvm.ir.Range(4),
@@ -332,8 +336,7 @@ def test_region_lower_bound_unfusable():
tvm.ir.Range.from_min_extent((i + j) // 2, 1),
]
result = tvm.arith.estimate_region_lower_bound(region, var_dom,
predicate=True)
- assert result[0].min_value == 0
- assert result[0].max_value == 5
+ assert result is None
def test_union_lower_bound():
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py
b/tests/python/unittest/test_arith_iter_affine_map.py
index 6a2fdbbb3f..7bc5ead298 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -61,6 +61,7 @@ def assert_iter_sum_pattern(
)
indices = res.indices
assert len(indices) == len(keys), res.errors
+ print(indices)
for i, input_iter in enumerate(keys):
spec = expect_dict[input_iter]
(
@@ -445,13 +446,6 @@ def test_predicate():
predicate=xo * 129 + xi < 128,
)
- # strided iteration predicate
- assert_iter_sum_pattern(
- {xo * 16 + xi * 4: (10, 0, 4)},
- var_dom([(xo, 3), (xi, 4)]),
- predicate=xo * 4 + xi < 10,
- )
-
def convert_division(divisions):
if divisions is None or len(divisions) == 0:
@@ -1016,55 +1010,5 @@ def test_padding():
assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)]))
-def test_overlapped_fuse():
- x = tvm.tir.Var("x", "int32")
- y = tvm.tir.Var("y", "int32")
- z = tvm.tir.Var("z", "int32")
- a = tvm.tir.Var("x", "int32")
- b = tvm.tir.Var("y", "int32")
-
- # non-bijective fuse of two
- assert_iter_sum_pattern(
- {
- x * 7 + y: (22, 0, 1),
- },
- var_dom([(x, 3), (y, 8)]),
- check_level="surjective",
- )
- assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]),
check_level="bijective")
-
- # non-bijective fuse of three
- assert_iter_sum_pattern(
- {
- x * 18 + y * 7 + z: (40, 0, 1),
- },
- var_dom([(x, 2), (y, 3), (z, 8)]),
- check_level="surjective",
- )
- assert_iter_sum_failure([x * 7 + y], var_dom([(x, 2), (y, 3), (z, 8)]),
check_level="bijective")
-
- # negative scale fusion is not allowed
- assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]),
check_level="surjective")
- assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]),
check_level="surjective")
-
- # with predicate
- assert_iter_sum_pattern(
- {
- a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1),
- },
- var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]),
- predicate=tvm.tir.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10),
- check_level="surjective",
- )
-
- # stride=1 kernel
- assert_iter_sum_pattern(
- {x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]),
check_level="surjective"
- )
-
- # do not allow both strided and overlapped
- assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]),
check_level="surjective")
-
-
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py
b/tests/python/unittest/test_meta_schedule_space_cpu.py
index 7895fb376e..36f365e732 100644
--- a/tests/python/unittest/test_meta_schedule_space_cpu.py
+++ b/tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -48,11 +48,11 @@ def test_cpu_c1d():
for i0_0, i1_0, i2_0, i0_1_1, i1_1_1, i2_1_1 in T.grid(1, 1, 2, 1,
1, 8):
for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3
in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
with T.block("conv1d_nlc"):
- n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_3 + i0_0)
- l = T.axis.spatial(128, i1_0 * 128 + i1_1_1 * 128 +
i1_2 * 2 + i1_3)
- co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1_1 * 8
+ i2_2)
+ n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3)
+ l = T.axis.spatial(128, i1_1_1 * 128 + i1_0 * 128 +
i1_2 * 2 + i1_3)
+ co = T.axis.spatial(128, (i2_0 * 8 + i2_1_1) * 8 +
i2_2 + i2_3)
rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
- rc = T.axis.reduce(64, i4_1 + i4_0)
+ rc = T.axis.reduce(64, i4_0 + i4_1)
T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc],
weight[rl, rc, co])
T.writes(conv1d_nlc_global[n, l, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -89,11 +89,11 @@ def test_cpu_c1d():
PadInput[i0, i1, i2] = T.if_then_else(1 <= i1 and
i1 < 257, inputs[i0, i1 - 1, i2], T.float32(0), dtype="float32")
for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3,
i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
with T.block("conv1d_nlc"):
- n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
- l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 +
i1_2 * 2 + i1_3)
- co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 *
8 + i2_2)
+ n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+ l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 +
i1_2 * 2 + i1_3)
+ co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 +
i2_2 + i2_3)
rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
- rc = T.axis.reduce(64, i4_1 + i4_0)
+ rc = T.axis.reduce(64, i4_0 + i4_1)
T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 +
rc], weight[rl, rc, co])
T.writes(conv1d_nlc_global[n, l, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -107,7 +107,7 @@ def test_cpu_c1d():
T.reads(conv1d_nlc_global[v0, v1, v2])
T.writes(conv1d_nlc[v0, v1, v2])
conv1d_nlc[v0, v1, v2] = conv1d_nlc_global[v0, v1, v2]
-
+
@T.prim_func
def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3,
64, 128), "float32"], conv1d_nlc: T.Buffer[(1, 128, 128), "float32"]) -> None:
# function attr dict
@@ -119,11 +119,11 @@ def test_cpu_c1d():
T.block_attr({"meta_schedule.parallel":288,
"meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i4_0, i0_2, i1_2,
i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 8, 1, 64, 1, 64, 8,
3, 1, 1, 2, 1):
with T.block("conv1d_nlc"):
- n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
- l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2
+ i1_3)
- co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 +
i2_2)
+ n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+ l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2
+ i1_3)
+ co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 +
i2_3)
rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
- rc = T.axis.reduce(64, i4_1 + i4_0)
+ rc = T.axis.reduce(64, i4_0 + i4_1)
T.reads(inputs[n, l * 2 + rl - 1, co // 128 * 64 + rc],
weight[rl, rc, co])
T.writes(conv1d_nlc[n, l, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py
b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 86edb373ec..b8723e286a 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -47,7 +47,7 @@ def test_cuda_c1d():
for ax0_ax1_ax2_fused in T.serial(260):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(258,
i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused // 4)
+ v1 = T.axis.spatial(258,
i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused % 260 // 4)
v2 = T.axis.spatial(64, i4_0 * 4 +
ax0_ax1_ax2_fused % 4)
T.reads(inputs[v0, v1 - 1, v2])
T.writes(PadInput_shared[v0, v1, v2])
@@ -64,11 +64,11 @@ def test_cuda_c1d():
weight_shared[v0, v1, v2] = weight[v0, v1,
v2]
for i3_1, i4_1, i0_3, i1_3, i2_3, i3_2, i4_2,
i0_4, i1_4, i2_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8):
with T.block("conv1d_nlc"):
- n = T.axis.spatial(1, i0_4 + i0_3)
- l = T.axis.spatial(128,
i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + i1_3 * 4 + i1_4)
- co = T.axis.spatial(128,
i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + i2_3 * 8 + i2_4)
- rl = T.axis.reduce(3, i3_0 * 3 + i3_1 * 3
+ i3_2)
- rc = T.axis.reduce(64, i4_0 * 4 + i4_1 * 2
+ i4_2)
+ n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0
+ 0)
+ l = T.axis.spatial(128,
(i0_0_i1_0_i2_0_fused % 4 * 8 + i0_1_i1_1_i2_1_fused % 16 // 2 + 0 + i1_3) * 4
+ i1_4)
+ co = T.axis.spatial(128, (((0 * 2 +
i0_1_i1_1_i2_1_fused % 2) * 4 + i0_2_i1_2_i2_2_fused % 4) * 2 + i2_3) * 8 +
i2_4)
+ rl = T.axis.reduce(3, (i3_0 + i3_1) * 3 +
i3_2)
+ rc = T.axis.reduce(64, (i4_0 * 2 + i4_1) *
2 + i4_2)
T.reads(PadInput_shared[n, l * 2 + rl, co
// 128 * 64 + rc], weight_shared[rl, rc, co])
T.writes(conv1d_nlc_local[n, l, co])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024,
"meta_schedule.thread_extent_low_inclusive":32,
"meta_schedule.tiling_structure":"SSSRRSRS"})
diff --git a/tests/python/unittest/test_tir_schedule_reorder.py
b/tests/python/unittest/test_tir_schedule_reorder.py
index b859b655ef..4351fe5b63 100644
--- a/tests/python/unittest/test_tir_schedule_reorder.py
+++ b/tests/python/unittest/test_tir_schedule_reorder.py
@@ -214,9 +214,9 @@ def test_reorder_with_opaque_access():
verify_trace_roundtrip(sch=sch, mod=opaque_access)
-def test_reorder_overlapped_access():
+def test_reorder_with_partial_affineness():
@T.prim_func
- def overlapped_access(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14,
4), "float32"]):
+ def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4),
"float32"]):
# example to write first axis multiple times
for v0, v1, v2 in T.grid(6, 4, 4):
with T.block("block"):
@@ -225,7 +225,7 @@ def test_reorder_overlapped_access():
B[i, j] = A[i, j] + 1.0
@T.prim_func
- def overlapped_access_reorder(A: T.Buffer[(14, 4), "float32"], B:
T.Buffer[(14, 4), "float32"]):
+ def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B:
T.Buffer[(14, 4), "float32"]):
# example to write first axis multiple times
for v0, v2, v1 in T.grid(6, 4, 4):
with T.block("block"):
@@ -233,30 +233,6 @@ def test_reorder_overlapped_access():
j = T.axis.spatial(4, v2)
B[i, j] = A[i, j] + 1.0
- sch = tir.Schedule(overlapped_access, debug_mask="all")
- v0, v1, v2 = sch.get_loops(sch.get_block("block"))
- sch.reorder(v0, v2, v1)
- tvm.ir.assert_structural_equal(overlapped_access_reorder, sch.mod["main"])
- verify_trace_roundtrip(sch=sch, mod=overlapped_access)
-
-
-def test_reorder_with_partial_affineness():
- @T.prim_func
- def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4),
"float32"]):
- for v0, v1, v2 in T.grid(6, 4, 4):
- with T.block("block"):
- i = T.axis.spatial(14, v0 * v0 + v1)
- j = T.axis.spatial(4, v2)
- B[i, j] = A[i, j] + 1.0
-
- @T.prim_func
- def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B:
T.Buffer[(14, 4), "float32"]):
- for v0, v2, v1 in T.grid(6, 4, 4):
- with T.block("block"):
- i = T.axis.spatial(14, v0 * v0 + v1)
- j = T.axis.spatial(4, v2)
- B[i, j] = A[i, j] + 1.0
-
sch = tir.Schedule(non_affine_func, debug_mask="all")
v0, v1, v2 = sch.get_loops(sch.get_block("block"))
with pytest.raises(tvm.tir.ScheduleError):
diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py
b/tests/python/unittest/test_tir_schedule_split_fuse.py
index 3ae88e0abb..9fd678174d 100644
--- a/tests/python/unittest/test_tir_schedule_split_fuse.py
+++ b/tests/python/unittest/test_tir_schedule_split_fuse.py
@@ -177,7 +177,7 @@ def elementwise_split_case0(a: T.handle, b: T.handle) ->
None:
B = T.match_buffer(b, [128, 128, 128])
for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8):
with T.block("B"):
- vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3)
+ vi = T.axis.S(128, (i1 + i2) * 64 + i3)
vj = T.axis.S(128, j1 * 32 + j2)
vk = T.axis.S(128, k1 * 8 + k2)
T.reads([A[vi, vj, vk]])
@@ -191,9 +191,9 @@ def elementwise_split_case1(a: T.handle, b: T.handle) ->
None:
B = T.match_buffer(b, [128, 128, 128])
for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1,
64):
with T.block("B"):
- vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3)
- vj = T.axis.S(128, j1 * 64 + j2 * 64 + j3)
- vk = T.axis.S(128, k1 * 64 + k2 * 64 + k3)
+ vi = T.axis.S(128, (i1 + i2) * 64 + i3)
+ vj = T.axis.S(128, (j1 + j2) * 64 + j3)
+ vk = T.axis.S(128, (k1 + k2) * 64 + k3)
T.reads([A[vi, vj, vk]])
T.writes([B[vi, vj, vk]])
B[vi, vj, vk] = A[vi, vj, vk] * 2.0
diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
index bbeb8d8760..1b4c34973f 100644
--- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
+++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
@@ -758,7 +758,7 @@ def test_non_perfect_tiling_cache():
s = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all")
# pylint: disable=protected-access
assert s._get_cached_flags(_get_block(s, "cache")) == CachedFlags(
- affine_binding=True,
+ affine_binding=False,
region_cover=True,
stage_pipeline=True,
)