zxybazh commented on code in PR #13481:
URL: https://github.com/apache/tvm/pull/13481#discussion_r1031806526
##########
tests/python/unittest/test_tir_schedule_compute_at.py:
##########
@@ -1505,5 +1505,56 @@ def main_reverse_compute_at(
tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"])
+def test_reverse_compute_at_with_unit_loop():
+ @T.prim_func
+ def main(A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1),
"float32"]) -> None:
+ B = T.alloc_buffer([128, 128], dtype="float32")
+ C = T.alloc_buffer([128, 128], dtype="float32")
+ for i_0, j_0, i_1 in T.grid(8, 8, 16):
+ for j_1 in T.serial(16):
+ with T.block("B"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + j_1)
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+ for ax0, ax1, ax2 in T.grid(1, 2, 1):
+ with T.block("D"):
+ v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(B[v0, v1])
+ T.writes(D[v0, v1, v2])
+ D[v0, v1, v2] = B[v0, v1] + T.float32(1)
+
+ @T.prim_func
+ def main_reverse_compute_at(
+ A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1), "float32"]
+ ):
+ B = T.alloc_buffer([128, 128], dtype="float32")
+ C = T.alloc_buffer([128, 128], dtype="float32")
+ for i_0, j_0, i_1 in T.grid(8, 8, 16):
+ for j_1 in T.serial(16):
+ with T.block("B"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + j_1)
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+ for ax0, ax1, ax2 in T.grid(1, 16, 1):
+ with T.block("D"):
+ T.where(i_0 * 16 + i_1 < 1 and j_0 * 16 + ax1 < 2)
+ v0 = T.axis.spatial(1, i_0 * 16 + i_1 + ax0)
+ v1 = T.axis.spatial(2, j_0 * 16 + ax1)
+ v2 = T.axis.spatial(1, ax2)
+ T.reads(B[v0, v1])
+ T.writes(D[v0, v1, v2])
+ D[v0, v1, v2] = B[v0, v1] + T.float32(1)
Review Comment:
```suggestion
# Reviewer-suggested TVMScript for `main`: block "B" writes A * 2 into the
# intermediate buffer B, and block "D" writes B + 1 into D. All loop extents
# and axis extents are spelled as T.int64 constants.
# NOTE(review): the mail stripped all leading whitespace; the nesting below is
# reconstructed from which loop variables each block uses -- confirm against
# the PR before applying.
@T.prim_func
def main(A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1),
        "float32"]) -> None:
    # Intermediate 128x128 buffer: written by block "B", read by block "D".
    B = T.alloc_buffer([128, 128], dtype="float32")
    # 8 x 8 x 16 x 16 loop nest; (i_0, i_1) and (j_0, j_1) each index 128 points.
    for i_0, j_0, i_1 in T.grid(T.int64(8), T.int64(8), T.int64(16)):
        for j_1 in T.serial(T.int64(16)):
            with T.block("B"):
                vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) +
                    i_1)
                vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) +
                    j_1)
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] * T.float32(2)
    # Block "D" uses only its own loop variables, so it is placed as a separate
    # top-level loop nest (assumption, see note above). Extents 1, 2, 1 match
    # the shape of D; the first and last loops are unit loops.
    for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(1)):
        with T.block("D"):
            v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2])
            T.reads(B[v0, v1])
            T.writes(D[v0, v1, v2])
            D[v0, v1, v2] = B[v0, v1] + T.float32(1)
# Reviewer-suggested TVMScript for `main_reverse_compute_at`: same computation
# as blocks "B" and "D" of `main`, but block "D" is driven by a (1, 16, 1) loop
# grid and guarded by a T.where predicate. Presumably the expected module after
# the schedule primitive runs -- the scheduling calls are not visible in this
# message.
# NOTE(review): the mail stripped all leading whitespace. Block "D" references
# i_0, j_0 and i_1, so it must live inside that loop nest; it is placed here as
# a sibling of the j_1 loop -- confirm against the PR before applying.
@T.prim_func
def main_reverse_compute_at(
    A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1), "float32"]
):
    # Intermediate 128x128 buffer: written by block "B", read by block "D".
    B = T.alloc_buffer([128, 128], dtype="float32")
    for i_0, j_0, i_1 in T.grid(T.int64(8), T.int64(8), T.int64(16)):
        for j_1 in T.serial(T.int64(16)):
            with T.block("B"):
                vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) +
                    i_1)
                vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) +
                    j_1)
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] * T.float32(2)
        # The middle loop has extent 16 while D's second axis has extent 2;
        # ax0 and ax2 are unit loops.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(16), T.int64(1)):
            with T.block("D"):
                # Predicate keeps only iterations whose outer-loop-derived
                # indices stay below D's axis extents (1 and 2).
                T.where(
                    i_0 * T.int64(16) + i_1 < T.int64(1)
                    and j_0 * T.int64(16) + ax1 < T.int64(2)
                )
                # Axis values are bound to expressions over the enclosing loop
                # variables instead of a plain remap of ax0/ax1/ax2.
                v0 = T.axis.spatial(T.int64(1), i_0 * T.int64(16) + i_1
                    + ax0)
                v1 = T.axis.spatial(T.int64(2), j_0 * T.int64(16) + ax1)
                v2 = T.axis.spatial(T.int64(1), ax2)
                T.reads(B[v0, v1])
                T.writes(D[v0, v1, v2])
                D[v0, v1, v2] = B[v0, v1] + T.float32(1)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]