Hzfengsy commented on a change in pull request #10340:
URL: https://github.com/apache/tvm/pull/10340#discussion_r811586894
##########
File path: tests/python/unittest/test_tir_transform_loop_partition.py
##########
@@ -565,6 +566,78 @@ def test_explicit_partition_hint():
assert tvm.ir.structural_equal(mod["main"], partitioned_concat)
+@T.prim_func
+def partitioned_concat_3(
+ placeholder: T.Buffer[(1, 64, 28, 28), "int8"],
+ placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"],
+ placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"],
+ T_concat: T.Buffer[(1, 128, 28, 28), "int8"],
+) -> None:
+ for i1, i2, i3 in T.grid(64, 28, 28):
+ T.store(
+ T_concat.data,
+ i1 * 784 + i2 * 28 + i3,
+ T.load("int8", placeholder.data, i1 * 784 + i2 * 28 + i3),
+ True,
+ )
+ for i1, i2, i3 in T.grid(32, 28, 28):
+ T.store(
+ T_concat.data,
+ i1 * 784 + i2 * 28 + i3 + 50176,
+ T.load("int8", placeholder_1.data, i1 * 784 + i2 * 28 + i3),
+ True,
+ )
+ for i1, i2, i3 in T.grid(32, 28, 28):
+ T.store(
+ T_concat.data,
+ i1 * 784 + i2 * 28 + i3 + 75264,
+ T.load("int8", placeholder_2.data, i1 * 784 + i2 * 28 + i3),
+ True,
+ )
+
+
+@T.prim_func
+def concat_func_3(
+ placeholder: T.Buffer[(1, 64, 28, 28), "int8"],
+ placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"],
+ placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"],
+ T_concat: T.Buffer[(1, 128, 28, 28), "int8"],
+) -> None:
+ for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
+ for i2, i3 in T.grid(28, 28):
+ if 96 <= i1:
+ T.store(
+ T_concat.data,
+ i1 * 784 + i2 * 28 + i3,
+ T.load("int8", placeholder_2.data, i1 * 784 + i2 * 28 + i3 - 75264),
+ True,
+ )
+ if 64 <= i1 and i1 < 96:
+ T.store(
+ T_concat.data,
+ i1 * 784 + i2 * 28 + i3,
+ T.load("int8", placeholder_1.data, i1 * 784 + i2 * 28 + i3 - 50176),
+ True,
+ )
+ if i1 < 64:
+ T.store(
+ T_concat.data,
+ i1 * 784 + i2 * 28 + i3,
+ T.load("int8", placeholder.data, i1 * 784 + i2 * 28 + i3),
+ True,
+ )
+
+
+def test_condition_mutually_exclusive():
+ mod = IRModule.from_expr(concat_func_3)
+ with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
+ mod = tvm.tir.transform.FlattenBuffer()(mod)
Review comment:
Do we really need FlattenBuffer in this testcase?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]