This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3cce9738bd [BugFix][TIR] Affine-binding check should not simplify trivial iterators (#13203)
3cce9738bd is described below

commit 3cce9738bd4a6d94657d2979b47a20238956e5cd
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Oct 28 04:36:08 2022 -0400

    [BugFix][TIR] Affine-binding check should not simplify trivial iterators (#13203)
    
    * Fix affine bindings
    
    * Regression test and test update
---
 src/tir/schedule/analysis/analysis.cc              |  3 +-
 .../test_tir_schedule_state_cached_flags.py        | 30 ++++++++++++
 ...t_tir_transform_lower_cross_thread_reduction.py | 54 +++++++++++-----------
 3 files changed, 60 insertions(+), 27 deletions(-)

diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index d8b4f31f4c..a2c0bc7594 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -540,7 +540,8 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
       /*input_iters=*/loop_var_ranges,
       /*predicate=*/realize->predicate,
       /*check_level=*/arith::IterMapLevel::Surjective,
-      /*analyzer=*/analyzer);
+      /*analyzer=*/analyzer,
+      /*simplify_trivial_iterators=*/false);
   if (res->indices.empty()) {
     return false;
   }
diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
index 9878217140..70935814ba 100644
--- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
+++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
@@ -438,6 +438,25 @@ def matmul_relu_padding(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 12
             compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
 
 
+@T.prim_func
+def splitted_square_sum_with_predicate(
+    A: T.Buffer[(1, 7, 7, 512), "float32"], B: T.Buffer[(1, 1, 1, 512), 
"float32"]
+) -> None:
+    for i0_i1_i2_i3_0_fused, ax0, ax1, ax2, ax3 in T.grid(2, 1, 1, 1, 256):
+        for ax4_ax5_fused_0, ax4_ax5_fused_1 in T.grid(1, 256):
+            with T.block("B"):
+                T.where(ax4_ax5_fused_0 * 256 + ax4_ax5_fused_1 < 49)
+                ax0_1, ax1_1, ax2_1 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                ax3_1 = T.axis.spatial(512, i0_i1_i2_i3_0_fused * 256 + ax3)
+                rv0 = T.axis.reduce(7, (ax4_ax5_fused_0 * 256 + 
ax4_ax5_fused_1) // 7)
+                rv1 = T.axis.reduce(7, (ax4_ax5_fused_0 * 256 + 
ax4_ax5_fused_1) % 7)
+                T.reads(A[ax0_1, ax1_1 * 7 + rv0, ax2_1 * 7 + rv1, ax3_1])
+                T.writes(B[ax0_1, ax1_1, ax2_1, ax3_1])
+                with T.init():
+                    B[ax0_1, ax1_1, ax2_1, ax3_1] = T.float32(0)
+                B[ax0_1, ax1_1, ax2_1, ax3_1] += A[ax0_1, ax1_1 * 7 + rv0, 
ax2_1 * 7 + rv1, ax3_1]
+
+
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 # fmt: on
 
@@ -865,5 +884,16 @@ def test_matmul_relu_padding():
     # pylint: enable=protected-access
 
 
+def test_splitted_square_sum_with_predicate():
+    s = tir.ScheduleState(splitted_square_sum_with_predicate, debug_mask="all")
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
index 8c139b710e..9ae4f4cf86 100644
--- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
+++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
@@ -341,8 +341,8 @@ def single_reduction_loop_with_block_predicate(
         for ax0, ax1_0 in T.grid(1, 1):
             for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                 with T.block("T_softmax_maxelem"):
-                    i0_1 = T.axis.spatial(256, i0)
-                    k = T.axis.reduce(256, ax1_1)
+                    i0_1 = T.axis.spatial(256, i0 + ax0)
+                    k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
                     T.where(ax1_0 * 512 + ax1_1 < 256)
                     T.reads(A[i0_1, k])
                     T.writes(T_softmax_maxelem_shared[i0_1])
@@ -354,8 +354,8 @@ def single_reduction_loop_with_block_predicate(
         for ax0, ax1_0 in T.grid(1, 1):
             for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                 with T.block("T_softmax_expsum"):
-                    i0_2 = T.axis.spatial(256, i0)
-                    k = T.axis.reduce(256, ax1_1)
+                    i0_2 = T.axis.spatial(256, i0 + ax0)
+                    k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
                     T.where(ax1_0 * 512 + ax1_1 < 256)
                     T.reads(A[i0_2, k], T_softmax_maxelem_shared[i0_2])
                     T.writes(T_softmax_expsum_shared[i0_2])
@@ -368,7 +368,7 @@ def single_reduction_loop_with_block_predicate(
             for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                 with T.block("T_softmax_norm"):
                     i0_3 = T.axis.spatial(256, i0)
-                    i1 = T.axis.spatial(256, i1_1)
+                    i1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
                     T.where(i1_0 * 512 + i1_1 < 256)
                     T.reads(
                         A[i0_3, i1], T_softmax_maxelem_shared[i0_3], 
T_softmax_expsum_shared[i0_3]
@@ -392,19 +392,20 @@ def lowered_single_reduction_loop_with_block_predicate(
     cross_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
     in_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
     for i0 in T.serial(256):
-        for ax0, ax1_0 in T.grid(1, 1):
+        for ax0 in T.serial(1):
             for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                 with T.block("T_softmax_maxelem_in_thread_init"):
                     T.reads()
                     T.writes(in_thread_0[0])
                     in_thread_0[0] = T.float32(-3.4028234663852886e38)
-                with T.block("T_softmax_maxelem_in_thread"):
-                    i0_1 = T.axis.spatial(256, i0)
-                    k = T.axis.reduce(256, ax1_1)
-                    T.where(ax1_0 * 512 + ax1_1 < 256)
-                    T.reads(A[i0_1, k])
-                    T.writes(in_thread_0[0])
-                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
+                for ax1_0 in T.serial(1):
+                    with T.block("T_softmax_maxelem_in_thread"):
+                        T.where(ax1_0 * 512 + ax1_1 < 256)
+                        i0_1 = T.axis.spatial(256, i0 + ax0)
+                        k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+                        T.reads(A[i0_1, k])
+                        T.writes(in_thread_0[0])
+                        in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
                 with T.block("T_softmax_maxelem_cross_thread"):
                     T.reads(in_thread_0[0])
                     T.writes(cross_thread_0[0])
@@ -426,25 +427,26 @@ def lowered_single_reduction_loop_with_block_predicate(
                         )
                     )
                 with T.block("T_softmax_maxelem_write_back"):
-                    i0_2 = T.axis.spatial(256, i0)
+                    i0_2 = T.axis.spatial(256, i0 + ax0)
                     T.reads(cross_thread_0[0])
                     T.writes(T_softmax_maxelem_shared[i0_2])
                     T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
-        for ax0, ax1_0 in T.grid(1, 1):
+        for ax0 in T.serial(1):
             for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                 with T.block("T_softmax_expsum_in_thread_init"):
                     T.reads()
                     T.writes(in_thread_1[0])
                     in_thread_1[0] = T.float32(0)
-                with T.block("T_softmax_expsum_in_thread"):
-                    i0_3 = T.axis.spatial(256, i0)
-                    k = T.axis.reduce(256, ax1_1)
-                    T.where(ax1_0 * 512 + ax1_1 < 256)
-                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3])
-                    T.writes(in_thread_1[0])
-                    in_thread_1[0] = in_thread_1[0] + T.exp(
-                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], 
dtype="float32"
-                    )
+                for ax1_0 in T.serial(1):
+                    with T.block("T_softmax_expsum_in_thread"):
+                        T.where(ax1_0 * 512 + ax1_1 < 256)
+                        i0_3 = T.axis.spatial(256, i0 + ax0)
+                        k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+                        T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3])
+                        T.writes(in_thread_1[0])
+                        in_thread_1[0] = in_thread_1[0] + T.exp(
+                            A[i0_3, k] - T_softmax_maxelem_shared[i0_3], 
dtype="float32"
+                        )
                 with T.block("T_softmax_expsum_cross_thread"):
                     T.reads(in_thread_1[0])
                     T.writes(cross_thread_1[0])
@@ -464,7 +466,7 @@ def lowered_single_reduction_loop_with_block_predicate(
                         )
                     )
                 with T.block("T_softmax_expsum_write_back"):
-                    i0_4 = T.axis.spatial(256, i0)
+                    i0_4 = T.axis.spatial(256, i0 + ax0)
                     T.reads(cross_thread_1[0])
                     T.writes(T_softmax_expsum_shared[i0_4])
                     T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
@@ -472,7 +474,7 @@ def lowered_single_reduction_loop_with_block_predicate(
             for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                 with T.block("T_softmax_norm"):
                     i0_5 = T.axis.spatial(256, i0)
-                    i1 = T.axis.spatial(256, i1_1)
+                    i1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
                     T.where(i1_0 * 512 + i1_1 < 256)
                     T.reads(
                         A[i0_5, i1], T_softmax_maxelem_shared[i0_5], 
T_softmax_expsum_shared[i0_5]

Reply via email to