This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6b1594c [BugFix] Fix NeedsMultiLevelTiling by skipping trivial block
iterators (#10804)
6b1594c is described below
commit 6b1594c852154882d41b15d7ceba2c3a7a6322c7
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Mar 29 11:31:16 2022 +0800
[BugFix] Fix NeedsMultiLevelTiling by skipping trivial block iterators
(#10804)
This PR fixes a bug in `NeedsMultiLevelTiling`, which didn't consider the
effect of trivial block iterators (iterators whose domains are `[0, 1)`). Such
iterators skew the subsequent analysis by inflating the count of iterators
that are not used to index the block read regions, which might lead to
the application of multi-level tiling where the rule is not supposed to apply.
To fix the problem, we simply skip such trivial block iterators.
---
src/tir/schedule/analysis/analysis.cc | 9 +++++--
...ta_schedule_schedule_rule_multi_level_tiling.py | 31 +++++++++++++++++++++-
2 files changed, 37 insertions(+), 3 deletions(-)
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index f3aa250..4358704 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1834,11 +1834,16 @@ bool NeedsMultiLevelTiling(const ScheduleState& self,
const StmtSRef& block_sref
return false;
}
const BufferNode* write_buffer = block->writes[0]->buffer.get();
- // Step 1. Sort out spatial block variables
+ // Step 1. Sort out spatial block variables. Skip the block iters of domain
[0, 1), since such
+ // block iters distracts the following check of the unused block iters.
std::vector<const VarNode*> spatial_block_vars;
spatial_block_vars.reserve(block->iter_vars.size());
for (const IterVar& block_var : block->iter_vars) {
- if (block_var->iter_type == IterVarType::kDataPar) {
+ const int64_t* dom_min = as_const_int(block_var->dom->min);
+ const int64_t* dom_extent = as_const_int(block_var->dom->extent);
+ bool has_trivial_dom =
+ dom_min != nullptr && dom_extent != nullptr && *dom_min == 0 &&
*dom_extent == 1;
+ if (block_var->iter_type == IterVarType::kDataPar && !has_trivial_dom) {
spatial_block_vars.push_back(block_var->var.get());
}
}
diff --git
a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 52218e6..555a1a8 100644
---
a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++
b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -17,13 +17,14 @@
# pylint:
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
+from tvm.meta_schedule.testing import te_workload
from tvm.meta_schedule.testing.schedule_rule import (
multi_level_tiling,
)
from tvm.meta_schedule.testing.space_generation import check_trace
from tvm.meta_schedule.tune_context import TuneContext
+from tvm.script import tir as T
from tvm.te import create_prim_func
-from tvm.meta_schedule.testing import te_workload
from tvm.target import Target
@@ -273,8 +274,36 @@ def test_cuda_matmul_relu():
check_trace(spaces, expected)
+def test_cuda_sum_with_trivial_block_iter():
+ @T.prim_func
+ def sum_with_trivial_block_iter(
+ A: T.Buffer[(1, 64, 768), "float32"], B: T.Buffer[(1, 64, 1),
"float32"]
+ ) -> None:
+ for i0, i1, i2, i3 in T.grid(1, 64, 1, 768):
+ with T.block("sum"):
+ ax0, ax1, ax2, k2 = T.axis.remap("SSSR", [i0, i1, i2, i3])
+ T.reads(A[ax0, ax1, k2])
+ T.writes(B[ax0, ax1, ax2])
+ with T.init():
+ B[ax0, ax1, ax2] = T.float32(0)
+ B[ax0, ax1, ax2] = B[ax0, ax1, ax2] + A[ax0, ax1, k2]
+
+ # Expect nothing to happen - the rule is not supposed to be applied in
this case
+ expected = [[]]
+ target = Target("cuda", host="llvm")
+ ctx = _create_context(
+ sum_with_trivial_block_iter,
+ target=target,
+ rule=multi_level_tiling(target=target),
+ )
+ spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+ assert len(spaces) == 1
+ check_trace(spaces, expected)
+
+
if __name__ == "__main__":
test_cpu_matmul()
test_cpu_matmul_relu()
test_cuda_matmul()
test_cuda_matmul_relu()
+ test_cuda_sum_with_trivial_block_iter()