This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 870246a369 [LoopPartition] Fix a bug of LoopPartition in single-point scenarios (#16104)
870246a369 is described below

commit 870246a369d4f51e8c4848f755c30c9c8f22920b
Author: lightzhan <[email protected]>
AuthorDate: Fri Dec 15 23:31:33 2023 +0800

    [LoopPartition] Fix a bug of LoopPartition in single-point scenarios (#16104)
    
    Fix a bug of LoopPartition in single-point scenarios, i.e. when a loop condition holds at exactly one iteration (such as `i1 == 63`).
    
    Co-authored-by: lightzhan-intellif <[email protected]>
---
 src/tir/transforms/loop_partition.cc               |  35 ++++
 .../test_tir_transform_loop_partition.py           | 228 +++++++++++++++++++++
 2 files changed, 263 insertions(+)

diff --git a/src/tir/transforms/loop_partition.cc 
b/src/tir/transforms/loop_partition.cc
index 0d08852669..c1c2d4644e 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -588,12 +588,47 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var 
var, PrimExpr min, Prim
       }
     }
 
+    bool all_singlepoints_outside = true;
+
+    // Check all partitions to see if they are single points and outside 
`for_interval`
+    for (const auto& partition : finder.partitions) {
+      const auto& intset = partition.second;
+      // Only proceed if the interval set is a single point
+      if (intset.IsSinglePoint()) {
+        auto single_point = intset.PointValue();
+        // Check if the single point is outside the `for_interval`
+        bool is_inside = analyzer_.CanProve(single_point >= 
for_interval.min()) &&
+                         analyzer_.CanProve(single_point <= 
for_interval.max());
+        if (is_inside) {
+          // If any single point is inside, this is an error condition
+          LOG(ERROR) << "unexpected case happened.";
+          all_singlepoints_outside = false;
+          break;
+        }
+      } else {
+        // If there is any intset that is not a single point, follow default 
logic
+        // For now, we set all_singlepoints_outside to false to indicate 
default logic was used
+        all_singlepoints_outside = false;
+        break;
+      }
+    }
+
+    if (all_singlepoints_outside) {
+      // If all single points are outside `for_interval`, return a nothing 
interval and false
+      return {IntSet::Nothing(), ExpressionSet(), false};
+    }
+
     // we couldn't find an interval in which the conditions are
     // provably true or false.  Therefore, we can't partition the loop
     // based on those conds
     return {{}, {}, std::nullopt};
   }();
 
+  if (middle_interval.IsNothing() && opt_cond_value == false) {
+    // Return loop directly as it can be simplified.
+    return stmt;
+  }
+
   if (!opt_cond_value.has_value()) {
     if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ &&
         analyzer_.CanProve(max - min > 0)) {
diff --git a/tests/python/tir-transform/test_tir_transform_loop_partition.py 
b/tests/python/tir-transform/test_tir_transform_loop_partition.py
index aa11ae5a5f..2b3f73e24f 100644
--- a/tests/python/tir-transform/test_tir_transform_loop_partition.py
+++ b/tests/python/tir-transform/test_tir_transform_loop_partition.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
 import tvm
 import tvm.testing
 from tvm import te
@@ -834,5 +835,232 @@ def test_loop_partition_with_unit_loop_in_condition():
     assert tvm.ir.structural_equal(mod["main"], 
after.with_attr("global_symbol", "main"))
 
 
[email protected]_func
+def concat_func_single_point(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 63), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+) -> None:
+    for i0 in range(28):
+        for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
+            if i1 > 63:
+                T_concat[i0, i1] = placeholder[i0, i1 - 64]
+            elif i1 == 63:
+                T_concat[i0, i1] = placeholder_1[i0, i1 - 63]
+            else:
+                T_concat[i0, i1] = placeholder_2[i0, i1]
+
+
[email protected]_func
+def expected_partitioned_concat_single_point(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 63), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+):
+    for i0 in range(28):
+        T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
+        for i1 in range(63):
+            placeholder_2_1 = T.Buffer((1764,), "int8", 
data=placeholder_2.data)
+            T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1]
+        placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
+        T_concat_1[i0 * 128 + 63] = placeholder_1_1[i0]
+        for i1 in range(64):
+            placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
+            T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]
+
+
[email protected]_func
+def concat_func_start_point_equality(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 63), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+) -> None:
+    for i0 in range(28):
+        for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}):
+            if i1 == 0:
+                # Special case for i1 == 0
+                T_concat[i0, i1] = placeholder_1[i0, 0]
+            elif i1 < 64:
+                # Normal case for i1 in [1, 63]
+                T_concat[i0, i1] = placeholder_2[i0, i1]
+            else:
+                # Case for i1 in [64, 127]
+                T_concat[i0, i1] = placeholder[i0, i1 - 64]
+
+
[email protected]_func
+def concat_func_start_point_equality_expected(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 63), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+):
+    for i0 in range(28):
+        T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
+        placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
+        T_concat_1[i0 * 128] = placeholder_1_1[i0]
+        for i1 in range(63):
+            placeholder_2_1 = T.Buffer((1764,), "int8", 
data=placeholder_2.data)
+            T_concat_1[i0 * 128 + i1 + 1] = placeholder_2_1[i0 * 63 + i1 + 1]
+        for i1 in range(64):
+            placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
+            T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]
+
+
[email protected]_func
+def concat_func_end_point_equality(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 63), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+) -> None:
+    for i0 in range(28):
+        for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}):
+            if i1 == 127:
+                # Explicit equality check for the end point i1 == 127
+                T_concat[i0, i1] = placeholder_1[i0, 0]
+            elif i1 >= 64:
+                # Case for i1 in [64, 126]
+                T_concat[i0, i1] = placeholder[i0, i1 - 64]
+            else:
+                # Case for i1 in [0, 63]
+                T_concat[i0, i1] = placeholder_2[i0, i1]
+
+
[email protected]_func
+def concat_func_end_point_equality_expected(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 63), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+):
+    for i0 in range(28):
+        T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
+        for i1 in range(64):
+            placeholder_2_1 = T.Buffer((1764,), "int8", 
data=placeholder_2.data)
+            T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1]
+        for i1 in range(63):
+            placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
+            T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]
+        placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
+        T_concat_1[i0 * 128 + 127] = placeholder_1_1[i0]
+
+
[email protected]_func
+def concat_func_edge_equalities(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 1), "int8"),
+    T_concat: T.Buffer((28, 66), "int8"),
+) -> None:
+    for i0 in range(28):
+        for i1 in range(
+            66, annotations={"pragma_loop_partition_hint": 1}
+        ):  # Loop from 0 to 65 inclusive
+            if i1 == 0:
+                # Handle equality at the start of the range: i1 == 0
+                T_concat[i0, i1] = placeholder_2[i0, 0]
+            elif i1 == 65:
+                # Handle equality at the end of the range: i1 == 65
+                T_concat[i0, i1] = placeholder_1[i0, 0]
+            else:
+                # Copying from placeholder (from 0 to 63)
+                T_concat[i0, i1] = placeholder[i0, i1 - 1]
+
+
[email protected]_func
+def concat_func_edge_equalities_expected(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 1), "int8"),
+    T_concat: T.Buffer((28, 66), "int8"),
+):
+    for i0 in range(28):
+        T_concat_1 = T.Buffer((1848,), "int8", data=T_concat.data)
+        placeholder_2_1 = T.Buffer((28,), "int8", data=placeholder_2.data)
+        T_concat_1[i0 * 66] = placeholder_2_1[i0]
+        for i1 in range(64):
+            placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
+            T_concat_1[i0 * 66 + i1 + 1] = placeholder_3[i0 * 64 + i1]
+        placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
+        T_concat_1[i0 * 66 + 65] = placeholder_1_1[i0]
+
+
[email protected]_func
+def concat_five_buffers_with_equalities(
+    buffer_a: T.Buffer((28, 1), "int8"),  # Used for i1 == 0
+    buffer_b: T.Buffer((28, 63), "int8"),  # Fills i1 from 1 to 63
+    buffer_c: T.Buffer((28, 1), "int8"),  # Used for i1 == 64
+    buffer_d: T.Buffer((28, 63), "int8"),  # Fills i1 from 65 to 128
+    buffer_e: T.Buffer((28, 1), "int8"),  # Used for i1 == 129
+    T_concat: T.Buffer((28, 129), "int8"),
+) -> None:
+    for i0 in range(28):
+        for i1 in range(130, annotations={"pragma_loop_partition_hint": 1}):
+            if i1 == 0:
+                T_concat[i0, i1] = buffer_a[i0, 0]
+            elif i1 == 64:
+                T_concat[i0, i1] = buffer_c[i0, 0]
+            elif i1 == 129:
+                T_concat[i0, i1] = buffer_e[i0, 0]
+            elif i1 < 64:
+                T_concat[i0, i1] = buffer_b[i0, i1 - 1]
+            else:  # i1 > 64 and i1 < 128
+                T_concat[i0, i1] = buffer_d[i0, i1 - 65]
+
+
[email protected]_func
+def concat_five_buffers_with_equalities_expected(
+    buffer_a: T.Buffer((28, 1), "int8"),  # Used for i1 == 0
+    buffer_b: T.Buffer((28, 63), "int8"),  # Fills i1 from 1 to 63
+    buffer_c: T.Buffer((28, 1), "int8"),  # Used for i1 == 64
+    buffer_d: T.Buffer((28, 63), "int8"),  # Fills i1 from 65 to 128
+    buffer_e: T.Buffer((28, 1), "int8"),  # Used for i1 == 129
+    T_concat: T.Buffer((28, 129), "int8"),
+):
+    for i0 in range(28):
+        T_concat_1 = T.Buffer((3612,), "int8", data=T_concat.data)
+        buffer_a_1 = T.Buffer((28,), "int8", data=buffer_a.data)
+        T_concat_1[i0 * 129] = buffer_a_1[i0]
+        for i1 in range(63):
+            buffer_b_1 = T.Buffer((1764,), "int8", data=buffer_b.data)
+            T_concat_1[i0 * 129 + i1 + 1] = buffer_b_1[i0 * 63 + i1]
+        buffer_c_1 = T.Buffer((28,), "int8", data=buffer_c.data)
+        T_concat_1[i0 * 129 + 64] = buffer_c_1[i0]
+        for i1 in range(64):
+            buffer_d_1 = T.Buffer((1764,), "int8", data=buffer_d.data)
+            T_concat_1[i0 * 129 + i1 + 65] = buffer_d_1[i0 * 63 + i1]
+        buffer_e_1 = T.Buffer((28,), "int8", data=buffer_e.data)
+        T_concat_1[i0 * 129 + 129] = buffer_e_1[i0]
+
+
[email protected](
+    "origin,expected",
+    [
+        (concat_func_single_point, expected_partitioned_concat_single_point),
+        (concat_func_start_point_equality, 
concat_func_start_point_equality_expected),
+        (concat_func_end_point_equality, 
concat_func_end_point_equality_expected),
+        (concat_func_edge_equalities, concat_func_edge_equalities_expected),
+        (concat_five_buffers_with_equalities, 
concat_five_buffers_with_equalities_expected),
+    ],
+)
+def test_single_point_partition(origin, expected):
+    origin = origin.with_attr({"global_symbol": "main"})
+    expected = expected.with_attr({"global_symbol": "main"})
+    mod = partition_from_scheduled_tir(
+        origin,
+        {
+            "tir.LoopPartition": {
+                "partition_const_loop": True,
+                "unroll_loop_with_partition_hint_no_interval": True,
+            }
+        },
+    )
+    assert tvm.ir.structural_equal(mod["main"], expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to