This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f28fcd1239 [TensorIR] Fix ComputeAt with perfect symbolic bound 
(#14592)
f28fcd1239 is described below

commit f28fcd1239315059b66c3b61badc99ffc8181c92
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Apr 11 18:01:08 2023 -0400

    [TensorIR] Fix ComputeAt with perfect symbolic bound (#14592)
    
    This PR fixes a case where we have a perfect symbolic bound
    in compute_at whose predicate can be eliminated.
    
    Test cases are added.
---
 src/tir/schedule/primitive/compute_at.cc           |  9 +++--
 .../unittest/test_tir_schedule_compute_at.py       | 41 ++++++++++++++++++++++
 2 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/src/tir/schedule/primitive/compute_at.cc 
b/src/tir/schedule/primitive/compute_at.cc
index 75ea308de8..b161bf954d 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -271,13 +271,18 @@ class ScopeReconstructor : private StmtMutator {
       }
       const arith::IntSet& pred_bound = iter_doms[i].bound;
       if (!pred_bound.IsNothing()) {
+        // NOTE: Apply strong analyzer proofs to get rid of symbolic bound
         if (pred_bound.HasLowerBound()) {
           PrimExpr lower_bound = iter_values[i] >= pred_bound.min();
-          predicate = predicate && lower_bound;
+          if (!analyzer->CanProve(lower_bound, 
arith::ProofStrength::kSymbolicBound)) {
+            predicate = predicate && lower_bound;
+          }
         }
         if (pred_bound.HasUpperBound()) {
           PrimExpr upper_bound = iter_values[i] < pred_bound.max() + 1;
-          predicate = predicate && upper_bound;
+          if (!analyzer->CanProve(upper_bound, 
arith::ProofStrength::kSymbolicBound)) {
+            predicate = predicate && upper_bound;
+          }
         }
       }
     }
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py 
b/tests/python/unittest/test_tir_schedule_compute_at.py
index 364a43acda..0623fb02f3 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1282,6 +1282,47 @@ def 
test_compute_at_simplify_static_bound(use_block_name):
     verify_trace_roundtrip(sch=sch, mod=static_bound)
 
 
+def test_compute_at_simplify_symbolic_predicate():
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(x: T.handle, y: T.handle, n: T.int64):
+            X = T.match_buffer(x, (T.int64(8), n * 32), "float32")
+            Y = T.match_buffer(y, (T.int64(8), n * 32), "float32")
+            for i, k in T.grid(T.int64(8), n * 32):
+                with T.block("Y"):
+                    vi, vk = T.axis.remap("SS", [i, k])
+                    Y[vi, vk] = X[vi, vk]
+
+    @tvm.script.ir_module
+    class After:
+        @T.prim_func
+        def main(x: T.handle, y: T.handle, n: T.int64):
+            X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+            Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+            X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+
+            for i, k_0 in T.grid(T.int64(8), n):
+                for ax0 in range(T.int64(32)):
+                    with T.block("X_global"):
+                        v0 = T.axis.spatial(T.int64(8), i)
+                        v1 = T.axis.spatial(n * T.int64(32), k_0 * T.int64(32) 
+ ax0)
+                        X_global[v0, v1] = X[v0, v1]
+                for k_1 in range(T.int64(32)):
+                    with T.block("Y"):
+                        vi = T.axis.spatial(T.int64(8), i)
+                        vk = T.axis.spatial(n * T.int64(32), k_0 * T.int64(32) 
+ k_1)
+                        Y[vi, vk] = X_global[vi, vk]
+
+    sch = tir.Schedule(Before, debug_mask="all")
+    block = sch.get_block("Y")
+    i, k = sch.get_loops(sch.get_block("Y"))
+    ko, ki = sch.split(k, [None, 32])
+    XX = sch.cache_read(block, 0, "global")
+    sch.compute_at(XX, ko)
+    tvm.ir.assert_structural_equal(sch.mod, After)
+
+
 def test_compute_at_non_perfect_channel_group(use_block_name):
     @T.prim_func
     def grouped_channel_bias(

Reply via email to