This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 03ef29e8ad [TIR][Schedule] Derive Nonnegative Bounds from Shape Var (#15210)
03ef29e8ad is described below

commit 03ef29e8ad9c3073881422db3f4f935a9538370d
Author: Junru Shao <[email protected]>
AuthorDate: Mon Jul 3 08:08:11 2023 -0700

    [TIR][Schedule] Derive Nonnegative Bounds from Shape Var (#15210)
    
    This PR enhances the arithmetic analysis used in compute-at to further
    help symbolic bound simplification.
    
    When a variable `n` appears in the shape of an input buffer, e.g.
    `T.Buffer((n * 32), "float32")`, we can safely assume that `n` is
    nonnegative because it is part of the shape. Previously, compute-at did
    not pass this fact to the analyzer; doing so helps simplify some bounds
    during scheduling as well as lowering.
    
    For example, for integers `n` and `bx` where `bx` has a symbolic bound
    `[0, 32 * n)`, if `n` is nonnegative, we could simplify the following
    expressions to True:
    
    ```
    0 <= floordiv(bx, n) < 32
    0 <= floormod(bx, n) < n
    ```
    
    This PR depends on #15193, which provides an interface for passing such
    hints to the analyzer.
---
 src/tir/schedule/primitive/compute_at.cc           | 15 +++++
 .../unittest/test_tir_schedule_compute_at.py       | 69 ++++++++++++++++++++++
 2 files changed, 84 insertions(+)

diff --git a/src/tir/schedule/primitive/compute_at.cc 
b/src/tir/schedule/primitive/compute_at.cc
index 8210274960..45d0c81050 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -680,6 +680,20 @@ void CalculateProvidedRequiredRegions(
 
 /******** Main Implementation ********/
 
+void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref,
+                       arith::Analyzer* analyzer) {
+  while (sref->parent != nullptr) {
+    sref = sref->parent;
+  }
+  const PrimFuncNode* f = GetRootPrimFunc(state->mod, sref->stmt, nullptr);
+  for (const auto& kv : f->buffer_map) {
+    const Buffer& buffer = kv.second;
+    for (const PrimExpr& e : buffer->shape) {
+      analyzer->MarkGlobalNonNegValue(e);
+    }
+  }
+}
+
 template <bool is_compute_at>
 void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& 
block_sref,
                                      const StmtSRef& loop_sref, bool 
preserve_unit_loops,
@@ -692,6 +706,7 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, 
const StmtSRef& block_s
   StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
                                           /*require_stage_pipeline=*/true);
   Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
+  AddShapeVarBounds(self, scope_root_sref.get(), analyzer);
   BlockScope scope = self->GetBlockScope(scope_root_sref);
   Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
   Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py 
b/tests/python/unittest/test_tir_schedule_compute_at.py
index a1b4cf1559..2e44776a0f 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1823,5 +1823,74 @@ def test_compute_inline_allocate_const(use_decl_buffer):
     verify_trace_roundtrip(sch=sch, mod=before)
 
 
+def test_shape_var_as_bound():
+    # fmt: off
+    @T.prim_func
+    def before(a: T.handle, b: T.handle, c: T.handle):
+        n = T.int32()
+        A = T.match_buffer(a, (32, 1, 128))
+        B = T.match_buffer(b, (32, n, 128))
+        C = T.match_buffer(c, (32, 1, n))
+        # with T.block("root"):
+        C_rf = T.alloc_buffer((128, 32, 1, n))
+        for ax0_ax1_fused, ax2_fused_1, ax2_fused_0 in T.grid(n * 32, 128, 1):
+            with T.block("NT_matmul_rf"):
+                vax2_fused_1 = T.axis.spatial(128, ax2_fused_1)
+                v0 = T.axis.spatial(32, ax0_ax1_fused // n)
+                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                vax2_fused_0 = T.axis.reduce(1, ax2_fused_0)
+                T.reads(A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1], B[v0, v1, 
vax2_fused_0 * 128 + vax2_fused_1])
+                T.writes(C_rf[vax2_fused_1, v0, 0, v1])
+                with T.init():
+                    C_rf[vax2_fused_1, v0, 0, v1] = T.float32(0)
+                C_rf[vax2_fused_1, v0, 0, v1] = C_rf[vax2_fused_1, v0, 0, v1] 
+ A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1] * B[v0, v1, vax2_fused_0 * 128 + 
vax2_fused_1]
+        for ax0_ax1_fused, ax2_fused_1 in T.grid(n * 32, 128):
+            with T.block("NT_matmul"):
+                vax2_fused_1 = T.axis.reduce(128, ax2_fused_1)
+                v0 = T.axis.spatial(32, ax0_ax1_fused // n)
+                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                T.reads(C_rf[vax2_fused_1, v0, 0, v1])
+                T.writes(C[v0, 0, v1])
+                with T.init():
+                    C[v0, 0, v1] = T.float32(0)
+                C[v0, 0, v1] = C[v0, 0, v1] + C_rf[vax2_fused_1, v0, 0, v1]
+
+    @T.prim_func
+    def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: 
T.handle):
+        n = T.int32()
+        B = T.match_buffer(b, (32, n, 128))
+        C = T.match_buffer(c, (32, 1, n))
+        # with T.block("root"):
+        C_rf = T.alloc_buffer((128, 32, 1, n))
+        for ax0_ax1_fused in range(n * 32):
+            for ax2_fused_1, ax2_fused_0 in T.grid(128, 1):
+                with T.block("NT_matmul_rf"):
+                    vax2_fused_1 = T.axis.spatial(128, ax2_fused_1)
+                    v0 = T.axis.spatial(32, ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    vax2_fused_0 = T.axis.reduce(1, ax2_fused_0)
+                    T.reads(A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1], B[v0, 
v1, vax2_fused_0 * 128 + vax2_fused_1])
+                    T.writes(C_rf[vax2_fused_1, v0, 0, v1])
+                    with T.init():
+                        C_rf[vax2_fused_1, v0, 0, v1] = T.float32(0)
+                    C_rf[vax2_fused_1, v0, 0, v1] = C_rf[vax2_fused_1, v0, 0, 
v1] + A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1] * B[v0, v1, vax2_fused_0 * 
128 + vax2_fused_1]
+            for ax0, ax1, ax2 in T.grid(128, 1, 1):
+                with T.block("NT_matmul"):
+                    vax2_fused_1 = T.axis.reduce(128, ax0)
+                    v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax1)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax2)
+                    T.reads(C_rf[vax2_fused_1, v0, 0, v1])
+                    T.writes(C[v0, 0, v1])
+                    with T.init():
+                        C[v0, 0, v1] = T.float32(0)
+                    C[v0, 0, v1] = C[v0, 0, v1] + C_rf[vax2_fused_1, v0, 0, v1]
+    # fmt: on
+    sch = tir.Schedule(before, debug_mask="all")
+    block = sch.get_block("NT_matmul")
+    loop, _, _ = sch.get_loops(sch.get_block("NT_matmul_rf"))
+    sch.reverse_compute_at(block, loop, preserve_unit_loops=True)
+    tvm.ir.assert_structural_equal(sch.mod["main"], expected, True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to