This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f28fcd1239 [TensorIR] Fix ComputeAt with perfect symbolic bound (#14592)
f28fcd1239 is described below
commit f28fcd1239315059b66c3b61badc99ffc8181c92
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Apr 11 18:01:08 2023 -0400
[TensorIR] Fix ComputeAt with perfect symbolic bound (#14592)
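This PR fixes a case in compute_at where the lower/upper bound checks on a
block iterator have a perfect symbolic bound and can therefore be eliminated:
a check is now added to the block predicate only if the analyzer cannot prove
it with kSymbolicBound strength.
Test cases are added.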
---
src/tir/schedule/primitive/compute_at.cc | 9 +++--
.../unittest/test_tir_schedule_compute_at.py | 41 ++++++++++++++++++++++
2 files changed, 48 insertions(+), 2 deletions(-)
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc
index 75ea308de8..b161bf954d 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -271,13 +271,18 @@ class ScopeReconstructor : private StmtMutator {
}
const arith::IntSet& pred_bound = iter_doms[i].bound;
if (!pred_bound.IsNothing()) {
+ // NOTE: Apply strong analyzer proofs to get rid of symbolic bound
if (pred_bound.HasLowerBound()) {
PrimExpr lower_bound = iter_values[i] >= pred_bound.min();
- predicate = predicate && lower_bound;
+ if (!analyzer->CanProve(lower_bound,
arith::ProofStrength::kSymbolicBound)) {
+ predicate = predicate && lower_bound;
+ }
}
if (pred_bound.HasUpperBound()) {
PrimExpr upper_bound = iter_values[i] < pred_bound.max() + 1;
- predicate = predicate && upper_bound;
+ if (!analyzer->CanProve(upper_bound,
arith::ProofStrength::kSymbolicBound)) {
+ predicate = predicate && upper_bound;
+ }
}
}
}
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py
index 364a43acda..0623fb02f3 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1282,6 +1282,47 @@ def test_compute_at_simplify_static_bound(use_block_name):
verify_trace_roundtrip(sch=sch, mod=static_bound)
+def test_compute_at_simplify_symbolic_predicate():
+ # Regression test for #14592: compute_at must not emit a T.where bound
+ # predicate on the moved block when the bound on its iterator is provable
+ # with symbolic-bound strength (here the split of an n * 32 extent by 32).
+ # Input: a plain element-wise copy Y = X over (8, n * 32) buffers.
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(x: T.handle, y: T.handle, n: T.int64):
+ X = T.match_buffer(x, (T.int64(8), n * 32), "float32")
+ Y = T.match_buffer(y, (T.int64(8), n * 32), "float32")
+ for i, k in T.grid(T.int64(8), n * 32):
+ with T.block("Y"):
+ vi, vk = T.axis.remap("SS", [i, k])
+ Y[vi, vk] = X[vi, vk]
+
+ # Expected result: X_global is filled per (i, k_0) tile of 32 elements
+ # directly under k_0, and the "X_global" block carries no T.where predicate.
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func
+ def main(x: T.handle, y: T.handle, n: T.int64):
+ X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+ Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+ X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+
+ for i, k_0 in T.grid(T.int64(8), n):
+ for ax0 in range(T.int64(32)):
+ with T.block("X_global"):
+ v0 = T.axis.spatial(T.int64(8), i)
+ v1 = T.axis.spatial(n * T.int64(32), k_0 * T.int64(32)
+ ax0)
+ X_global[v0, v1] = X[v0, v1]
+ for k_1 in range(T.int64(32)):
+ with T.block("Y"):
+ vi = T.axis.spatial(T.int64(8), i)
+ vk = T.axis.spatial(n * T.int64(32), k_0 * T.int64(32)
+ k_1)
+ Y[vi, vk] = X_global[vi, vk]
+
+ # Schedule: split k into (ko, 32), cache-read Y's first read buffer (X)
+ # into "global" scope, then move the cache block under the outer loop ko.
+ # NOTE(review): `block` could be passed to get_loops directly; `i` is unused.
+ sch = tir.Schedule(Before, debug_mask="all")
+ block = sch.get_block("Y")
+ i, k = sch.get_loops(sch.get_block("Y"))
+ ko, ki = sch.split(k, [None, 32])
+ XX = sch.cache_read(block, 0, "global")
+ sch.compute_at(XX, ko)
+ # Structural equality with After also checks that no predicate was added.
+ tvm.ir.assert_structural_equal(sch.mod, After)
+
+
def test_compute_at_non_perfect_channel_group(use_block_name):
@T.prim_func
def grouped_channel_bias(