This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 03ef29e8ad [TIR][Schedule] Derive Nonnegative Bounds from Shape Var
(#15210)
03ef29e8ad is described below
commit 03ef29e8ad9c3073881422db3f4f935a9538370d
Author: Junru Shao <[email protected]>
AuthorDate: Mon Jul 3 08:08:11 2023 -0700
[TIR][Schedule] Derive Nonnegative Bounds from Shape Var (#15210)
This PR enhances the arithmetic analysis used in compute-at to further
help symbolic bound simplification.
When a variable `n` appears in the shape of an input buffer
`T.Buffer((n * 32), "float32")`, we can safely assume that `n` is
nonnegative, as it is part of the shape. Previously, compute-at did not
pass this fact to the analyzer; doing so helps us simplify some
bounds during scheduling as well as lowering.
For example, for integers `n` and `bx` where `bx` has a symbolic bound
`[0, 32 * n)`, if `n` is nonnegative, we could simplify the following
expressions to True:
```
0 <= floordiv(bx, n) < 32
0 <= floormod(bx, n) < n
```
This PR depends on #15193, which provides an interface for passing hints to the analyzer.
---
src/tir/schedule/primitive/compute_at.cc | 15 +++++
.../unittest/test_tir_schedule_compute_at.py | 69 ++++++++++++++++++++++
2 files changed, 84 insertions(+)
diff --git a/src/tir/schedule/primitive/compute_at.cc
b/src/tir/schedule/primitive/compute_at.cc
index 8210274960..45d0c81050 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -680,6 +680,20 @@ void CalculateProvidedRequiredRegions(
/******** Main Implementation ********/
+void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref,
+ arith::Analyzer* analyzer) {
+ while (sref->parent != nullptr) {
+ sref = sref->parent;
+ }
+ const PrimFuncNode* f = GetRootPrimFunc(state->mod, sref->stmt, nullptr);
+ for (const auto& kv : f->buffer_map) {
+ const Buffer& buffer = kv.second;
+ for (const PrimExpr& e : buffer->shape) {
+ analyzer->MarkGlobalNonNegValue(e);
+ }
+ }
+}
+
template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef&
block_sref,
const StmtSRef& loop_sref, bool
preserve_unit_loops,
@@ -692,6 +706,7 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self,
const StmtSRef& block_s
StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/true);
Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
+ AddShapeVarBounds(self, scope_root_sref.get(), analyzer);
BlockScope scope = self->GetBlockScope(scope_root_sref);
Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py
b/tests/python/unittest/test_tir_schedule_compute_at.py
index a1b4cf1559..2e44776a0f 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1823,5 +1823,74 @@ def test_compute_inline_allocate_const(use_decl_buffer):
verify_trace_roundtrip(sch=sch, mod=before)
+def test_shape_var_as_bound():
+ # fmt: off
+ @T.prim_func
+ def before(a: T.handle, b: T.handle, c: T.handle):
+ n = T.int32()
+ A = T.match_buffer(a, (32, 1, 128))
+ B = T.match_buffer(b, (32, n, 128))
+ C = T.match_buffer(c, (32, 1, n))
+ # with T.block("root"):
+ C_rf = T.alloc_buffer((128, 32, 1, n))
+ for ax0_ax1_fused, ax2_fused_1, ax2_fused_0 in T.grid(n * 32, 128, 1):
+ with T.block("NT_matmul_rf"):
+ vax2_fused_1 = T.axis.spatial(128, ax2_fused_1)
+ v0 = T.axis.spatial(32, ax0_ax1_fused // n)
+ v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+ vax2_fused_0 = T.axis.reduce(1, ax2_fused_0)
+ T.reads(A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1], B[v0, v1,
vax2_fused_0 * 128 + vax2_fused_1])
+ T.writes(C_rf[vax2_fused_1, v0, 0, v1])
+ with T.init():
+ C_rf[vax2_fused_1, v0, 0, v1] = T.float32(0)
+ C_rf[vax2_fused_1, v0, 0, v1] = C_rf[vax2_fused_1, v0, 0, v1]
+ A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1] * B[v0, v1, vax2_fused_0 * 128 +
vax2_fused_1]
+ for ax0_ax1_fused, ax2_fused_1 in T.grid(n * 32, 128):
+ with T.block("NT_matmul"):
+ vax2_fused_1 = T.axis.reduce(128, ax2_fused_1)
+ v0 = T.axis.spatial(32, ax0_ax1_fused // n)
+ v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+ T.reads(C_rf[vax2_fused_1, v0, 0, v1])
+ T.writes(C[v0, 0, v1])
+ with T.init():
+ C[v0, 0, v1] = T.float32(0)
+ C[v0, 0, v1] = C[v0, 0, v1] + C_rf[vax2_fused_1, v0, 0, v1]
+
+ @T.prim_func
+ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c:
T.handle):
+ n = T.int32()
+ B = T.match_buffer(b, (32, n, 128))
+ C = T.match_buffer(c, (32, 1, n))
+ # with T.block("root"):
+ C_rf = T.alloc_buffer((128, 32, 1, n))
+ for ax0_ax1_fused in range(n * 32):
+ for ax2_fused_1, ax2_fused_0 in T.grid(128, 1):
+ with T.block("NT_matmul_rf"):
+ vax2_fused_1 = T.axis.spatial(128, ax2_fused_1)
+ v0 = T.axis.spatial(32, ax0_ax1_fused // n)
+ v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+ vax2_fused_0 = T.axis.reduce(1, ax2_fused_0)
+ T.reads(A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1], B[v0,
v1, vax2_fused_0 * 128 + vax2_fused_1])
+ T.writes(C_rf[vax2_fused_1, v0, 0, v1])
+ with T.init():
+ C_rf[vax2_fused_1, v0, 0, v1] = T.float32(0)
+ C_rf[vax2_fused_1, v0, 0, v1] = C_rf[vax2_fused_1, v0, 0,
v1] + A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1] * B[v0, v1, vax2_fused_0 *
128 + vax2_fused_1]
+ for ax0, ax1, ax2 in T.grid(128, 1, 1):
+ with T.block("NT_matmul"):
+ vax2_fused_1 = T.axis.reduce(128, ax0)
+ v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax1)
+ v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax2)
+ T.reads(C_rf[vax2_fused_1, v0, 0, v1])
+ T.writes(C[v0, 0, v1])
+ with T.init():
+ C[v0, 0, v1] = T.float32(0)
+ C[v0, 0, v1] = C[v0, 0, v1] + C_rf[vax2_fused_1, v0, 0, v1]
+ # fmt: on
+ sch = tir.Schedule(before, debug_mask="all")
+ block = sch.get_block("NT_matmul")
+ loop, _, _ = sch.get_loops(sch.get_block("NT_matmul_rf"))
+ sch.reverse_compute_at(block, loop, preserve_unit_loops=True)
+ tvm.ir.assert_structural_equal(sch.mod["main"], expected, True)
+
+
if __name__ == "__main__":
tvm.testing.main()