junrushao commented on PR #14547:
URL: https://github.com/apache/tvm/pull/14547#issuecomment-1502377735
Interestingly, this PR happened to reveal a potential issue with the
existing async pipeline when scheduling with Hexagon targets, where arithmetic
analysis leads `FlattenBuffer` to generate wrong IRs.
The TIR before `FlattenBuffer`:
```python
@T.prim_func
def main(a_buffer: T.Buffer((8, 64), "uint8"), out: T.Buffer((8, 64),
"uint8")):
T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
a_buffer_global_vtcm = T.decl_buffer((2, 1, 64), "uint8",
scope="global.vtcm")
out_global_vtcm = T.decl_buffer((2, 1, 64), "uint8", scope="global.vtcm")
for i in T.unroll(2):
if i < 8:
T.attr(0, "async_commit_queue_scope", 0)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
a_buffer_global_vtcm[i % 2, i - i, ax0] = a_buffer[i, ax0]
if i == 1 and i - 1 < 8:
T.attr(0, "async_wait_queue_scope", 0)
T.attr(0, "async_wait_inflight_count", 1)
for j in range(64):
out_global_vtcm[(i - 1) % 2, i - 1 - (i - 1), j] =
a_buffer_global_vtcm[(i - 1) % 2, i - 1 - (i - 1), j] + T.uint8(1)
for i in range(6):
if i + 2 < 8:
T.attr(0, "async_commit_queue_scope", 0)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
a_buffer_global_vtcm[(i + 2) % 2, i + 2 - (i + 2), ax0] =
a_buffer[i + 2, ax0]
if i + 2 - 1 < 8:
T.attr(0, "async_wait_queue_scope", 0)
T.attr(0, "async_wait_inflight_count", 1)
for j in range(64):
out_global_vtcm[(i - 1 + 2) % 2, i - 1 + 2 - (i - 1 + 2), j]
= a_buffer_global_vtcm[(i - 1 + 2) % 2, i - 1 + 2 - (i - 1 + 2), j] + T.uint8(1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#0: Note the indexing here. The number in
the middle is essentially 0.
if i + 2 - 2 < 8:
T.attr(0, "async_commit_queue_scope", 2)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
out[i - 2 + 2, ax0] = out_global_vtcm[(i - 2 + 2) % 2, i - 2
+ 2 - (i - 2 + 2), ax0]
for i in T.unroll(2):
if i + 8 - 1 < 8:
T.attr(0, "async_wait_queue_scope", 0)
T.attr(0, "async_wait_inflight_count", 0 - i)
for j in range(64):
out_global_vtcm[(i - 1 + 8) % 2, i - 1 + 8 - (i - 1 + 8), j]
= a_buffer_global_vtcm[(i - 1 + 8) % 2, i - 1 + 8 - (i - 1 + 8), j] + T.uint8(1)
if i + 8 - 2 < 8:
T.attr(0, "async_commit_queue_scope", 2)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
out[i - 2 + 8, ax0] = out_global_vtcm[(i - 2 + 8) % 2, i - 2
+ 8 - (i - 2 + 8), ax0]
```
The TIR after `FlattenBuffer`, before this PR:
```python
@I.ir_module
class Module:
@T.prim_func
def main(a_buffer: T.Buffer((8, 64), "uint8"), out: T.Buffer((8, 64),
"uint8")):
T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
a_buffer_global_vtcm = T.allocate([128], "uint8", "global.vtcm")
out_global_vtcm = T.allocate([128], "uint8", "global.vtcm")
a_buffer_global_vtcm_1 = T.Buffer((128,), "uint8",
data=a_buffer_global_vtcm, scope="global.vtcm")
a_buffer_1 = T.Buffer((512,), "uint8", data=a_buffer.data)
out_global_vtcm_1 = T.Buffer((128,), "uint8", data=out_global_vtcm,
scope="global.vtcm")
for i in T.unroll(2):
if i < 8:
T.attr(0, "async_commit_queue_scope", 0)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
a_buffer_global_vtcm_1[i % 2 * 64 + ax0] = a_buffer_1[i
* 64 + ax0]
if i == 1 and i - 1 < 8:
T.attr(0, "async_wait_queue_scope", 0)
T.attr(0, "async_wait_inflight_count", 1)
for j in range(64):
out_global_vtcm_1[j + 64 - i % 2 * 64] =
a_buffer_global_vtcm_1[j + 64 - i % 2 * 64] + T.uint8(1)
out_1 = T.Buffer((512,), "uint8", data=out.data)
for i in range(6):
if i + 2 < 8:
T.attr(0, "async_commit_queue_scope", 0)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
a_buffer_global_vtcm_1[i % 2 * 64 + ax0] = a_buffer_1[i
* 64 + ax0 + 128]
if i + 2 - 1 < 8:
T.attr(0, "async_wait_queue_scope", 0)
T.attr(0, "async_wait_inflight_count", 1)
for j in range(64):
out_global_vtcm_1[j + 64 - i % 2 * 64] =
a_buffer_global_vtcm_1[j + 64 - i % 2 * 64] + T.uint8(1)
^^^^^^^^^^^^^^^^^^^
#1: The `-i % 2` thing here doesn't
look right
if i + 2 - 2 < 8:
T.attr(0, "async_commit_queue_scope", 2)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
out_1[i * 64 + ax0] = out_global_vtcm_1[i % 2 * 64 + ax0]
for i in T.unroll(2):
if i + 8 - 1 < 8:
T.attr(0, "async_wait_queue_scope", 0)
T.attr(0, "async_wait_inflight_count", 0 - i)
for j in range(64):
out_global_vtcm_1[j + 64 - i % 2 * 64] =
a_buffer_global_vtcm_1[j + 64 - i % 2 * 64] + T.uint8(1)
if i + 8 - 2 < 8:
T.attr(0, "async_commit_queue_scope", 2)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
out_1[i * 64 + ax0 + 384] = out_global_vtcm_1[i % 2 * 64
+ ax0]
```
This PR fixes this issue and generates the correct IR:
```python
@T.prim_func
def main(a_buffer: T.Buffer((8, 64), "uint8"), out: T.Buffer((8, 64),
"uint8")):
T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
a_buffer_global_vtcm = T.allocate([128], "uint8", "global.vtcm")
out_global_vtcm = T.allocate([128], "uint8", "global.vtcm")
a_buffer_global_vtcm_1 = T.Buffer((128,), "uint8",
data=a_buffer_global_vtcm, scope="global.vtcm")
a_buffer_1 = T.Buffer((512,), "uint8", data=a_buffer.data)
out_global_vtcm_1 = T.Buffer((128,), "uint8", data=out_global_vtcm,
scope="global.vtcm")
for i in T.unroll(2):
if i < 8:
T.attr(0, "async_commit_queue_scope", 0)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
a_buffer_global_vtcm_1[i * 64 + ax0] = a_buffer_1[i * 64 +
ax0]
if i == 1 and i - 1 < 8:
T.attr(0, "async_wait_queue_scope", 0)
T.attr(0, "async_wait_inflight_count", 1)
for j in range(64):
out_global_vtcm_1[j + 64 - i % 2 * 64] =
a_buffer_global_vtcm_1[j + 64 - i % 2 * 64] + T.uint8(1)
out_1 = T.Buffer((512,), "uint8", data=out.data)
for i in range(6):
if i + 2 < 8:
T.attr(0, "async_commit_queue_scope", 0)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
a_buffer_global_vtcm_1[i % 2 * 64 + ax0] = a_buffer_1[i * 64
+ ax0 + 128]
if i + 2 - 1 < 8:
T.attr(0, "async_wait_queue_scope", 0)
T.attr(0, "async_wait_inflight_count", 1)
for j in range(64):
out_global_vtcm_1[i % 2 * 64 + j] = a_buffer_global_vtcm_1[i
% 2 * 64 + j] + T.uint8(1)
^^^^^^^^^^^^^^^^^^^
#2: This is the correct index after buffer
flattening
if i + 2 - 2 < 8:
T.attr(0, "async_commit_queue_scope", 2)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
out_1[i * 64 + ax0] = out_global_vtcm_1[i % 2 * 64 + ax0]
for i in T.unroll(2):
if i + 8 - 1 < 8:
T.attr(0, "async_wait_queue_scope", 0)
T.attr(0, "async_wait_inflight_count", 0 - i)
for j in range(64):
out_global_vtcm_1[i * 64 + j] = a_buffer_global_vtcm_1[i *
64 + j] + T.uint8(1)
if i + 8 - 2 < 8:
T.attr(0, "async_commit_queue_scope", 2)
T.attr(0, "async_scope", 1)
for ax0 in range(64):
out_1[i * 64 + ax0 + 384] = out_global_vtcm_1[i * 64 + ax0]
```
However, it seems that there might be extra orthogonal bugs in the lowering
pipeline that broke the Hexagon unittests. It would be great if @Lunderberg or
anyone else who's familiar with the Hexagon pipeline could take a further look
into this issue.
Meanwhile, given that this PR fixes existing issues and brings in extremely
valuable new functionality, I would love to propose to @tqchen that we
add unittests to catch the regression, skip the Hexagon tests to get this PR
merged, and restore those tests once we have another fix.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]