This is an automated email from the ASF dual-hosted git repository.
xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 3cce9738bd [BugFix][TIR] Affine-binding check should not simplify trivial iterators (#13203)
3cce9738bd is described below
commit 3cce9738bd4a6d94657d2979b47a20238956e5cd
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Oct 28 04:36:08 2022 -0400
    [BugFix][TIR] Affine-binding check should not simplify trivial iterators (#13203)
* Fix affine bindings
* Regression test and test update
---
src/tir/schedule/analysis/analysis.cc | 3 +-
.../test_tir_schedule_state_cached_flags.py | 30 ++++++++++++
...t_tir_transform_lower_cross_thread_reduction.py | 54 +++++++++++-----------
3 files changed, 60 insertions(+), 27 deletions(-)
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index d8b4f31f4c..a2c0bc7594 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -540,7 +540,8 @@ bool IsAffineBinding(const BlockRealize& realize, const
Map<Var, Range>& loop_va
/*input_iters=*/loop_var_ranges,
/*predicate=*/realize->predicate,
/*check_level=*/arith::IterMapLevel::Surjective,
- /*analyzer=*/analyzer);
+ /*analyzer=*/analyzer,
+ /*simplify_trivial_iterators=*/false);
if (res->indices.empty()) {
return false;
}
diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
index 9878217140..70935814ba 100644
--- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
+++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
@@ -438,6 +438,25 @@ def matmul_relu_padding(A: T.Buffer[(127, 127),
"float16"], B: T.Buffer[(127, 12
compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
[email protected]_func
+def splitted_square_sum_with_predicate(
+ A: T.Buffer[(1, 7, 7, 512), "float32"], B: T.Buffer[(1, 1, 1, 512),
"float32"]
+) -> None:
+ for i0_i1_i2_i3_0_fused, ax0, ax1, ax2, ax3 in T.grid(2, 1, 1, 1, 256):
+ for ax4_ax5_fused_0, ax4_ax5_fused_1 in T.grid(1, 256):
+ with T.block("B"):
+ T.where(ax4_ax5_fused_0 * 256 + ax4_ax5_fused_1 < 49)
+ ax0_1, ax1_1, ax2_1 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ ax3_1 = T.axis.spatial(512, i0_i1_i2_i3_0_fused * 256 + ax3)
+ rv0 = T.axis.reduce(7, (ax4_ax5_fused_0 * 256 +
ax4_ax5_fused_1) // 7)
+ rv1 = T.axis.reduce(7, (ax4_ax5_fused_0 * 256 +
ax4_ax5_fused_1) % 7)
+ T.reads(A[ax0_1, ax1_1 * 7 + rv0, ax2_1 * 7 + rv1, ax3_1])
+ T.writes(B[ax0_1, ax1_1, ax2_1, ax3_1])
+ with T.init():
+ B[ax0_1, ax1_1, ax2_1, ax3_1] = T.float32(0)
+ B[ax0_1, ax1_1, ax2_1, ax3_1] += A[ax0_1, ax1_1 * 7 + rv0,
ax2_1 * 7 + rv1, ax3_1]
+
+
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
# fmt: on
@@ -865,5 +884,16 @@ def test_matmul_relu_padding():
# pylint: enable=protected-access
+def test_splitted_square_sum_with_predicate():
+ s = tir.ScheduleState(splitted_square_sum_with_predicate, debug_mask="all")
+ # pylint: disable=protected-access
+ assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+ affine_binding=True,
+ region_cover=True,
+ stage_pipeline=True,
+ )
+ # pylint: enable=protected-access
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git
a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
index 8c139b710e..9ae4f4cf86 100644
--- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
+++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
@@ -341,8 +341,8 @@ def single_reduction_loop_with_block_predicate(
for ax0, ax1_0 in T.grid(1, 1):
for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
with T.block("T_softmax_maxelem"):
- i0_1 = T.axis.spatial(256, i0)
- k = T.axis.reduce(256, ax1_1)
+ i0_1 = T.axis.spatial(256, i0 + ax0)
+ k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
T.where(ax1_0 * 512 + ax1_1 < 256)
T.reads(A[i0_1, k])
T.writes(T_softmax_maxelem_shared[i0_1])
@@ -354,8 +354,8 @@ def single_reduction_loop_with_block_predicate(
for ax0, ax1_0 in T.grid(1, 1):
for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
with T.block("T_softmax_expsum"):
- i0_2 = T.axis.spatial(256, i0)
- k = T.axis.reduce(256, ax1_1)
+ i0_2 = T.axis.spatial(256, i0 + ax0)
+ k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
T.where(ax1_0 * 512 + ax1_1 < 256)
T.reads(A[i0_2, k], T_softmax_maxelem_shared[i0_2])
T.writes(T_softmax_expsum_shared[i0_2])
@@ -368,7 +368,7 @@ def single_reduction_loop_with_block_predicate(
for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
with T.block("T_softmax_norm"):
i0_3 = T.axis.spatial(256, i0)
- i1 = T.axis.spatial(256, i1_1)
+ i1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
T.where(i1_0 * 512 + i1_1 < 256)
T.reads(
A[i0_3, i1], T_softmax_maxelem_shared[i0_3],
T_softmax_expsum_shared[i0_3]
@@ -392,19 +392,20 @@ def lowered_single_reduction_loop_with_block_predicate(
cross_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
in_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
for i0 in T.serial(256):
- for ax0, ax1_0 in T.grid(1, 1):
+ for ax0 in T.serial(1):
for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
with T.block("T_softmax_maxelem_in_thread_init"):
T.reads()
T.writes(in_thread_0[0])
in_thread_0[0] = T.float32(-3.4028234663852886e38)
- with T.block("T_softmax_maxelem_in_thread"):
- i0_1 = T.axis.spatial(256, i0)
- k = T.axis.reduce(256, ax1_1)
- T.where(ax1_0 * 512 + ax1_1 < 256)
- T.reads(A[i0_1, k])
- T.writes(in_thread_0[0])
- in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
+ for ax1_0 in T.serial(1):
+ with T.block("T_softmax_maxelem_in_thread"):
+ T.where(ax1_0 * 512 + ax1_1 < 256)
+ i0_1 = T.axis.spatial(256, i0 + ax0)
+ k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+ T.reads(A[i0_1, k])
+ T.writes(in_thread_0[0])
+ in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
with T.block("T_softmax_maxelem_cross_thread"):
T.reads(in_thread_0[0])
T.writes(cross_thread_0[0])
@@ -426,25 +427,26 @@ def lowered_single_reduction_loop_with_block_predicate(
)
)
with T.block("T_softmax_maxelem_write_back"):
- i0_2 = T.axis.spatial(256, i0)
+ i0_2 = T.axis.spatial(256, i0 + ax0)
T.reads(cross_thread_0[0])
T.writes(T_softmax_maxelem_shared[i0_2])
T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
- for ax0, ax1_0 in T.grid(1, 1):
+ for ax0 in T.serial(1):
for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
with T.block("T_softmax_expsum_in_thread_init"):
T.reads()
T.writes(in_thread_1[0])
in_thread_1[0] = T.float32(0)
- with T.block("T_softmax_expsum_in_thread"):
- i0_3 = T.axis.spatial(256, i0)
- k = T.axis.reduce(256, ax1_1)
- T.where(ax1_0 * 512 + ax1_1 < 256)
- T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3])
- T.writes(in_thread_1[0])
- in_thread_1[0] = in_thread_1[0] + T.exp(
- A[i0_3, k] - T_softmax_maxelem_shared[i0_3],
dtype="float32"
- )
+ for ax1_0 in T.serial(1):
+ with T.block("T_softmax_expsum_in_thread"):
+ T.where(ax1_0 * 512 + ax1_1 < 256)
+ i0_3 = T.axis.spatial(256, i0 + ax0)
+ k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+ T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3])
+ T.writes(in_thread_1[0])
+ in_thread_1[0] = in_thread_1[0] + T.exp(
+ A[i0_3, k] - T_softmax_maxelem_shared[i0_3],
dtype="float32"
+ )
with T.block("T_softmax_expsum_cross_thread"):
T.reads(in_thread_1[0])
T.writes(cross_thread_1[0])
@@ -464,7 +466,7 @@ def lowered_single_reduction_loop_with_block_predicate(
)
)
with T.block("T_softmax_expsum_write_back"):
- i0_4 = T.axis.spatial(256, i0)
+ i0_4 = T.axis.spatial(256, i0 + ax0)
T.reads(cross_thread_1[0])
T.writes(T_softmax_expsum_shared[i0_4])
T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
@@ -472,7 +474,7 @@ def lowered_single_reduction_loop_with_block_predicate(
for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
with T.block("T_softmax_norm"):
i0_5 = T.axis.spatial(256, i0)
- i1 = T.axis.spatial(256, i1_1)
+ i1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
T.where(i1_0 * 512 + i1_1 < 256)
T.reads(
A[i0_5, i1], T_softmax_maxelem_shared[i0_5],
T_softmax_expsum_shared[i0_5]