This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 870246a369 [LoopPartition] Fix a bug of LoopPartition in single point scenarios (#16104)
870246a369 is described below
commit 870246a369d4f51e8c4848f755c30c9c8f22920b
Author: lightzhan <[email protected]>
AuthorDate: Fri Dec 15 23:31:33 2023 +0800
[LoopPartition] Fix a bug of LoopPartition in single point scenarios
(#16104)
Fix a bug of LoopPartition in single point scenarios.
Co-authored-by: lightzhan-intellif <[email protected]>
---
src/tir/transforms/loop_partition.cc | 35 ++++
.../test_tir_transform_loop_partition.py | 228 +++++++++++++++++++++
2 files changed, 263 insertions(+)
diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc
index 0d08852669..c1c2d4644e 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -588,12 +588,47 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
}
}
+ bool all_singlepoints_outside = true;
+
+ // Check all partitions to see if they are single points and outside
`for_interval`
+ for (const auto& partition : finder.partitions) {
+ const auto& intset = partition.second;
+ // Only proceed if the interval set is a single point
+ if (intset.IsSinglePoint()) {
+ auto single_point = intset.PointValue();
+ // Check if the single point is outside the `for_interval`
+ bool is_inside = analyzer_.CanProve(single_point >=
for_interval.min()) &&
+ analyzer_.CanProve(single_point <=
for_interval.max());
+ if (is_inside) {
+ // If any single point is inside, this is an error condition
+ LOG(ERROR) << "unexpected case happened.";
+ all_singlepoints_outside = false;
+ break;
+ }
+ } else {
+ // If there is any intset that is not a single point, follow default
logic
+ // For now, we set all_singlepoints_outside to false to indicate
default logic was used
+ all_singlepoints_outside = false;
+ break;
+ }
+ }
+
+ if (all_singlepoints_outside) {
+ // If all single points are outside `for_interval`, return a nothing
interval and false
+ return {IntSet::Nothing(), ExpressionSet(), false};
+ }
+
// we couldn't find an interval in which the conditions are
// provably true or false. Therefore, we can't partition the loop
// based on those conds
return {{}, {}, std::nullopt};
}();
+ if (middle_interval.IsNothing() && opt_cond_value == false) {
+ // Return loop directly as it can be simplified.
+ return stmt;
+ }
+
if (!opt_cond_value.has_value()) {
if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ &&
analyzer_.CanProve(max - min > 0)) {
diff --git a/tests/python/tir-transform/test_tir_transform_loop_partition.py b/tests/python/tir-transform/test_tir_transform_loop_partition.py
index aa11ae5a5f..2b3f73e24f 100644
--- a/tests/python/tir-transform/test_tir_transform_loop_partition.py
+++ b/tests/python/tir-transform/test_tir_transform_loop_partition.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import pytest
import tvm
import tvm.testing
from tvm import te
@@ -834,5 +835,232 @@ def test_loop_partition_with_unit_loop_in_condition():
assert tvm.ir.structural_equal(mod["main"],
after.with_attr("global_symbol", "main"))
+# Single point strictly inside the loop range: column 63 of T_concat comes from
+# placeholder_1, columns [0, 62] from placeholder_2 and [64, 127] from
+# placeholder.
+@T.prim_func
+def concat_func_single_point(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 63), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+) -> None:
+    for i0 in range(28):
+        for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
+            if i1 > 63:
+                T_concat[i0, i1] = placeholder[i0, i1 - 64]
+            elif i1 == 63:
+                T_concat[i0, i1] = placeholder_1[i0, i1 - 63]
+            else:
+                T_concat[i0, i1] = placeholder_2[i0, i1]
+
+
+# Expected LoopPartition result for `concat_func_single_point`: the i1 loop is
+# split into [0, 63), the single point 63, and a 64-iteration tail, with all
+# buffer accesses flattened to 1-D.
+@T.prim_func
+def expected_partitioned_concat_single_point(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 63), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+):
+    for i0 in range(28):
+        T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
+        for i1 in range(63):
+            placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data)
+            T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1]
+        placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
+        T_concat_1[i0 * 128 + 63] = placeholder_1_1[i0]
+        for i1 in range(64):
+            placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
+            T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]
+
+
+# Single point at the start of the range: column 0 of T_concat comes from
+# placeholder_1, columns [1, 63] from placeholder_2 (63 columns, so column
+# i1 maps to placeholder_2 column i1 - 1) and [64, 127] from placeholder.
+@T.prim_func
+def concat_func_start_point_equality(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 63), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+) -> None:
+    for i0 in range(28):
+        for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}):
+            if i1 == 0:
+                # Special case for i1 == 0
+                T_concat[i0, i1] = placeholder_1[i0, 0]
+            elif i1 < 64:
+                # i1 in [1, 63] reads placeholder_2 columns [0, 62]; indexing
+                # with i1 directly would read column 63, past the buffer end.
+                T_concat[i0, i1] = placeholder_2[i0, i1 - 1]
+            else:
+                # Case for i1 in [64, 127]
+                T_concat[i0, i1] = placeholder[i0, i1 - 64]
+
+
+# Expected LoopPartition result for `concat_func_start_point_equality`: the
+# point i1 == 0 is peeled off, followed by a 63-iteration loop (original
+# i1 = loop var + 1) and a 64-iteration loop, with accesses flattened to 1-D.
+@T.prim_func
+def concat_func_start_point_equality_expected(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 63), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+):
+    for i0 in range(28):
+        T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
+        placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
+        T_concat_1[i0 * 128] = placeholder_1_1[i0]
+        for i1 in range(63):
+            placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data)
+            T_concat_1[i0 * 128 + i1 + 1] = placeholder_2_1[i0 * 63 + i1]
+        for i1 in range(64):
+            placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
+            T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]
+
+
+# Single point at the end of the range: columns [0, 63] of T_concat come from
+# placeholder_2 (64 columns), [64, 126] from placeholder and column 127 from
+# placeholder_1. placeholder_2 must have 64 columns, since the else-branch
+# reads it at column i1 for every i1 in [0, 63].
+@T.prim_func
+def concat_func_end_point_equality(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 64), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+) -> None:
+    for i0 in range(28):
+        for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}):
+            if i1 == 127:
+                # Explicit equality check for the end point i1 == 127
+                T_concat[i0, i1] = placeholder_1[i0, 0]
+            elif i1 >= 64:
+                # Case for i1 in [64, 126]
+                T_concat[i0, i1] = placeholder[i0, i1 - 64]
+            else:
+                # Case for i1 in [0, 63]
+                T_concat[i0, i1] = placeholder_2[i0, i1]
+
+
+# Expected LoopPartition result for `concat_func_end_point_equality`: a
+# 64-iteration loop, a 63-iteration loop, then the peeled point i1 == 127,
+# with accesses flattened to 1-D.
+@T.prim_func
+def concat_func_end_point_equality_expected(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 64), "int8"),
+    T_concat: T.Buffer((28, 128), "int8"),
+):
+    for i0 in range(28):
+        T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
+        for i1 in range(64):
+            placeholder_2_1 = T.Buffer((1792,), "int8", data=placeholder_2.data)
+            T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 64 + i1]
+        for i1 in range(63):
+            placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
+            T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]
+        placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
+        T_concat_1[i0 * 128 + 127] = placeholder_1_1[i0]
+
+
+# Single points at both ends of the range: column 0 of T_concat comes from
+# placeholder_2, column 65 from placeholder_1 and columns [1, 64] from
+# placeholder.
+@T.prim_func
+def concat_func_edge_equalities(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 1), "int8"),
+    T_concat: T.Buffer((28, 66), "int8"),
+) -> None:
+    for i0 in range(28):
+        for i1 in range(
+            66, annotations={"pragma_loop_partition_hint": 1}
+        ):  # Loop from 0 to 65 inclusive
+            if i1 == 0:
+                # Handle equality at the start of the range: i1 == 0
+                T_concat[i0, i1] = placeholder_2[i0, 0]
+            elif i1 == 65:
+                # Handle equality at the end of the range: i1 == 65
+                T_concat[i0, i1] = placeholder_1[i0, 0]
+            else:
+                # i1 in [1, 64]: copy placeholder columns 0 to 63
+                T_concat[i0, i1] = placeholder[i0, i1 - 1]
+
+
+# Expected LoopPartition result for `concat_func_edge_equalities`: both end
+# points are peeled off around a 64-iteration loop (original i1 = loop var + 1),
+# with accesses flattened to 1-D.
+@T.prim_func
+def concat_func_edge_equalities_expected(
+    placeholder: T.Buffer((28, 64), "int8"),
+    placeholder_1: T.Buffer((28, 1), "int8"),
+    placeholder_2: T.Buffer((28, 1), "int8"),
+    T_concat: T.Buffer((28, 66), "int8"),
+):
+    for i0 in range(28):
+        T_concat_1 = T.Buffer((1848,), "int8", data=T_concat.data)
+        placeholder_2_1 = T.Buffer((28,), "int8", data=placeholder_2.data)
+        T_concat_1[i0 * 66] = placeholder_2_1[i0]
+        for i1 in range(64):
+            placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
+            T_concat_1[i0 * 66 + i1 + 1] = placeholder_3[i0 * 64 + i1]
+        placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
+        T_concat_1[i0 * 66 + 65] = placeholder_1_1[i0]
+
+
+# Five-way concat with single points at both ends (i1 == 0, i1 == 129) and in
+# the middle (i1 == 64). Column layout of T_concat: 0 <- buffer_a,
+# [1, 63] <- buffer_b, 64 <- buffer_c, [65, 128] <- buffer_d,
+# 129 <- buffer_e. The loop covers 130 columns, so T_concat needs 130 columns
+# and buffer_d needs 64 for every access to stay in bounds.
+@T.prim_func
+def concat_five_buffers_with_equalities(
+    buffer_a: T.Buffer((28, 1), "int8"),  # Used for i1 == 0
+    buffer_b: T.Buffer((28, 63), "int8"),  # Fills i1 from 1 to 63
+    buffer_c: T.Buffer((28, 1), "int8"),  # Used for i1 == 64
+    buffer_d: T.Buffer((28, 64), "int8"),  # Fills i1 from 65 to 128
+    buffer_e: T.Buffer((28, 1), "int8"),  # Used for i1 == 129
+    T_concat: T.Buffer((28, 130), "int8"),
+) -> None:
+    for i0 in range(28):
+        for i1 in range(130, annotations={"pragma_loop_partition_hint": 1}):
+            if i1 == 0:
+                T_concat[i0, i1] = buffer_a[i0, 0]
+            elif i1 == 64:
+                T_concat[i0, i1] = buffer_c[i0, 0]
+            elif i1 == 129:
+                T_concat[i0, i1] = buffer_e[i0, 0]
+            elif i1 < 64:
+                T_concat[i0, i1] = buffer_b[i0, i1 - 1]
+            else:  # 64 < i1 < 129
+                T_concat[i0, i1] = buffer_d[i0, i1 - 65]
+
+
+# Expected LoopPartition result for `concat_five_buffers_with_equalities`: the
+# three single points are peeled off, leaving a 63-iteration loop and a
+# 64-iteration loop, with accesses flattened to 1-D.
+@T.prim_func
+def concat_five_buffers_with_equalities_expected(
+    buffer_a: T.Buffer((28, 1), "int8"),  # Used for i1 == 0
+    buffer_b: T.Buffer((28, 63), "int8"),  # Fills i1 from 1 to 63
+    buffer_c: T.Buffer((28, 1), "int8"),  # Used for i1 == 64
+    buffer_d: T.Buffer((28, 64), "int8"),  # Fills i1 from 65 to 128
+    buffer_e: T.Buffer((28, 1), "int8"),  # Used for i1 == 129
+    T_concat: T.Buffer((28, 130), "int8"),
+):
+    for i0 in range(28):
+        T_concat_1 = T.Buffer((3640,), "int8", data=T_concat.data)
+        buffer_a_1 = T.Buffer((28,), "int8", data=buffer_a.data)
+        T_concat_1[i0 * 130] = buffer_a_1[i0]
+        for i1 in range(63):
+            buffer_b_1 = T.Buffer((1764,), "int8", data=buffer_b.data)
+            T_concat_1[i0 * 130 + i1 + 1] = buffer_b_1[i0 * 63 + i1]
+        buffer_c_1 = T.Buffer((28,), "int8", data=buffer_c.data)
+        T_concat_1[i0 * 130 + 64] = buffer_c_1[i0]
+        for i1 in range(64):
+            buffer_d_1 = T.Buffer((1792,), "int8", data=buffer_d.data)
+            T_concat_1[i0 * 130 + i1 + 65] = buffer_d_1[i0 * 64 + i1]
+        buffer_e_1 = T.Buffer((28,), "int8", data=buffer_e.data)
+        T_concat_1[i0 * 130 + 129] = buffer_e_1[i0]
+
+
+@pytest.mark.parametrize(
+    "origin,expected",
+    [
+        (concat_func_single_point, expected_partitioned_concat_single_point),
+        (concat_func_start_point_equality, concat_func_start_point_equality_expected),
+        (concat_func_end_point_equality, concat_func_end_point_equality_expected),
+        (concat_func_edge_equalities, concat_func_edge_equalities_expected),
+        (concat_five_buffers_with_equalities, concat_five_buffers_with_equalities_expected),
+    ],
+)
+def test_single_point_partition(origin, expected):
+    """Check LoopPartition on loops whose branch conditions contain single points.
+
+    Both functions are given ``global_symbol = "main"`` so the transformed
+    ``mod["main"]`` can be compared structurally against ``expected``. The pass
+    runs via ``partition_from_scheduled_tir`` with ``partition_const_loop`` and
+    ``unroll_loop_with_partition_hint_no_interval`` enabled.
+    """
+    origin = origin.with_attr({"global_symbol": "main"})
+    expected = expected.with_attr({"global_symbol": "main"})
+    mod = partition_from_scheduled_tir(
+        origin,
+        {
+            "tir.LoopPartition": {
+                "partition_const_loop": True,
+                "unroll_loop_with_partition_hint_no_interval": True,
+            }
+        },
+    )
+    assert tvm.ir.structural_equal(mod["main"], expected)
+
+
if __name__ == "__main__":
tvm.testing.main()