This is an automated email from the ASF dual-hosted git repository.
kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6a6093bc18 fold const or empty iter partition (#12080)
6a6093bc18 is described below
commit 6a6093bc180ed762b3e0d19eb37fcf10d97289c1
Author: wrongtest <[email protected]>
AuthorDate: Wed Jul 13 22:52:35 2022 +0800
fold const or empty iter partition (#12080)
---
src/tir/transforms/loop_partition.cc | 30 ++++++++++++++++--------------
1 file changed, 16 insertions(+), 14 deletions(-)
diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc
index 59ac339006..677506889e 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -587,16 +587,17 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
if (middle_interval_i->HasLowerBound()) {
body_begin = analyzer_.Simplify(middle_interval.min());
if (!analyzer_.CanProve(body_begin == min)) {
- PrimExpr cond = (body_begin - min >= 0);
- if (!analyzer_.CanProve(cond)) {
- LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre
doubt loop";
- body_begin = Max(body_begin, min);
+ PrimExpr extent = analyzer_.Simplify(body_begin - min);
+ if (!analyzer_.CanProve(extent > 0)) {
+ body_begin = tvm::max(body_begin, min);
// stop recursing on this interval if we can't prove it has
non-negative length
pre_stmt_recurse = false;
}
- if (!partition_thread_scope) {
- Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
- pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
+ if (!analyzer_.CanProve(extent <= 0)) {
+ if (!partition_thread_scope) {
+ Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
+ pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
+ }
}
}
} else {
@@ -612,16 +613,17 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1);
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
- PrimExpr cond = (max - post_doubt_begin + 1 >= 0);
- if (!analyzer_.CanProve(cond)) {
- LOG(WARNING) << "Cannot prove: " << cond << ", when generating the
post doubt loop";
- post_doubt_begin = Min(post_doubt_begin, max + 1);
+ PrimExpr extent = analyzer_.Simplify(max - post_doubt_begin + 1);
+ if (!analyzer_.CanProve(extent > 0)) {
+ post_doubt_begin = tvm::min(post_doubt_begin, max + 1);
// stop recursing on this interval if we can't prove it has
non-negative length
post_stmt_recurse = false;
}
- if (!partition_thread_scope) {
- Stmt post_body = Substitute(body, {{Var{var}, var +
post_doubt_begin}});
- post_stmt = MakeFor(stmt.get(), max - post_doubt_begin + 1, post_body);
+ if (!analyzer_.CanProve(extent <= 0)) {
+ if (!partition_thread_scope) {
+ Stmt post_body = Substitute(body, {{Var{var}, var +
post_doubt_begin}});
+ post_stmt = MakeFor(stmt.get(), extent, post_body);
+ }
}
}
} else {