junrushao commented on PR #14547:
URL: https://github.com/apache/tvm/pull/14547#issuecomment-1502377735

   Interestingly, this PR happened to reveal a potential issue with the 
existing async pipeline when scheduling with Hexagon targets, where arithmetic 
analysis leads `FlattenBuffer` to generate wrong IRs.
   
   The TIR before `FlattenBuffer`:
   
   ```python
   @T.prim_func
   def main(a_buffer: T.Buffer((8, 64), "uint8"), out: T.Buffer((8, 64), 
"uint8")):
       T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
       a_buffer_global_vtcm = T.decl_buffer((2, 1, 64), "uint8", 
scope="global.vtcm")
       out_global_vtcm = T.decl_buffer((2, 1, 64), "uint8", scope="global.vtcm")
       for i in T.unroll(2):
           if i < 8:
               T.attr(0, "async_commit_queue_scope", 0)
               T.attr(0, "async_scope", 1)
               for ax0 in range(64):
                   a_buffer_global_vtcm[i % 2, i - i, ax0] = a_buffer[i, ax0]
           if i == 1 and i - 1 < 8:
               T.attr(0, "async_wait_queue_scope", 0)
               T.attr(0, "async_wait_inflight_count", 1)
               for j in range(64):
                   out_global_vtcm[(i - 1) % 2, i - 1 - (i - 1), j] = 
a_buffer_global_vtcm[(i - 1) % 2, i - 1 - (i - 1), j] + T.uint8(1)
       for i in range(6):
           if i + 2 < 8:
               T.attr(0, "async_commit_queue_scope", 0)
               T.attr(0, "async_scope", 1)
               for ax0 in range(64):
                   a_buffer_global_vtcm[(i + 2) % 2, i + 2 - (i + 2), ax0] = 
a_buffer[i + 2, ax0]
           if i + 2 - 1 < 8:
               T.attr(0, "async_wait_queue_scope", 0)
               T.attr(0, "async_wait_inflight_count", 1)
               for j in range(64):
                   out_global_vtcm[(i - 1 + 2) % 2, i - 1 + 2 - (i - 1 + 2), j] 
= a_buffer_global_vtcm[(i - 1 + 2) % 2, i - 1 + 2 - (i - 1 + 2), j] + T.uint8(1)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   #0: Note the indexing here. The number in 
the middle is essentially 0.
           if i + 2 - 2 < 8:
               T.attr(0, "async_commit_queue_scope", 2)
               T.attr(0, "async_scope", 1)
               for ax0 in range(64):
                   out[i - 2 + 2, ax0] = out_global_vtcm[(i - 2 + 2) % 2, i - 2 
+ 2 - (i - 2 + 2), ax0]
       for i in T.unroll(2):
           if i + 8 - 1 < 8:
               T.attr(0, "async_wait_queue_scope", 0)
               T.attr(0, "async_wait_inflight_count", 0 - i)
               for j in range(64):
                   out_global_vtcm[(i - 1 + 8) % 2, i - 1 + 8 - (i - 1 + 8), j] 
= a_buffer_global_vtcm[(i - 1 + 8) % 2, i - 1 + 8 - (i - 1 + 8), j] + T.uint8(1)
           if i + 8 - 2 < 8:
               T.attr(0, "async_commit_queue_scope", 2)
               T.attr(0, "async_scope", 1)
               for ax0 in range(64):
                   out[i - 2 + 8, ax0] = out_global_vtcm[(i - 2 + 8) % 2, i - 2 
+ 8 - (i - 2 + 8), ax0]
   ```
   
   The TIR after `FlattenBuffer`, before this PR:
   
   ```python
   @I.ir_module
   class Module:
       @T.prim_func
       def main(a_buffer: T.Buffer((8, 64), "uint8"), out: T.Buffer((8, 64), 
"uint8")):
           T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
           a_buffer_global_vtcm = T.allocate([128], "uint8", "global.vtcm")
           out_global_vtcm = T.allocate([128], "uint8", "global.vtcm")
           a_buffer_global_vtcm_1 = T.Buffer((128,), "uint8", 
data=a_buffer_global_vtcm, scope="global.vtcm")
           a_buffer_1 = T.Buffer((512,), "uint8", data=a_buffer.data)
           out_global_vtcm_1 = T.Buffer((128,), "uint8", data=out_global_vtcm, 
scope="global.vtcm")
           for i in T.unroll(2):
               if i < 8:
                   T.attr(0, "async_commit_queue_scope", 0)
                   T.attr(0, "async_scope", 1)
                   for ax0 in range(64):
                       a_buffer_global_vtcm_1[i % 2 * 64 + ax0] = a_buffer_1[i 
* 64 + ax0]
               if i == 1 and i - 1 < 8:
                   T.attr(0, "async_wait_queue_scope", 0)
                   T.attr(0, "async_wait_inflight_count", 1)
                   for j in range(64):
                       out_global_vtcm_1[j + 64 - i % 2 * 64] = 
a_buffer_global_vtcm_1[j + 64 - i % 2 * 64] + T.uint8(1)
           out_1 = T.Buffer((512,), "uint8", data=out.data)
           for i in range(6):
               if i + 2 < 8:
                   T.attr(0, "async_commit_queue_scope", 0)
                   T.attr(0, "async_scope", 1)
                   for ax0 in range(64):
                       a_buffer_global_vtcm_1[i % 2 * 64 + ax0] = a_buffer_1[i 
* 64 + ax0 + 128]
               if i + 2 - 1 < 8:
                   T.attr(0, "async_wait_queue_scope", 0)
                   T.attr(0, "async_wait_inflight_count", 1)
                   for j in range(64):
                       out_global_vtcm_1[j + 64 - i % 2 * 64] = 
a_buffer_global_vtcm_1[j + 64 - i % 2 * 64] + T.uint8(1)
                                         ^^^^^^^^^^^^^^^^^^^
                                         #1: The `-i % 2` thing here doesn't 
look right
               if i + 2 - 2 < 8:
                   T.attr(0, "async_commit_queue_scope", 2)
                   T.attr(0, "async_scope", 1)
                   for ax0 in range(64):
                       out_1[i * 64 + ax0] = out_global_vtcm_1[i % 2 * 64 + ax0]
           for i in T.unroll(2):
               if i + 8 - 1 < 8:
                   T.attr(0, "async_wait_queue_scope", 0)
                   T.attr(0, "async_wait_inflight_count", 0 - i)
                   for j in range(64):
                       out_global_vtcm_1[j + 64 - i % 2 * 64] = 
a_buffer_global_vtcm_1[j + 64 - i % 2 * 64] + T.uint8(1)
               if i + 8 - 2 < 8:
                   T.attr(0, "async_commit_queue_scope", 2)
                   T.attr(0, "async_scope", 1)
                   for ax0 in range(64):
                       out_1[i * 64 + ax0 + 384] = out_global_vtcm_1[i % 2 * 64 
+ ax0]
   ```
   
   This PR fixes this issue and generates the correct IR:
   
   ```python
   @T.prim_func
   def main(a_buffer: T.Buffer((8, 64), "uint8"), out: T.Buffer((8, 64), 
"uint8")):
       T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
       a_buffer_global_vtcm = T.allocate([128], "uint8", "global.vtcm")
       out_global_vtcm = T.allocate([128], "uint8", "global.vtcm")
       a_buffer_global_vtcm_1 = T.Buffer((128,), "uint8", 
data=a_buffer_global_vtcm, scope="global.vtcm")
       a_buffer_1 = T.Buffer((512,), "uint8", data=a_buffer.data)
       out_global_vtcm_1 = T.Buffer((128,), "uint8", data=out_global_vtcm, 
scope="global.vtcm")
       for i in T.unroll(2):
           if i < 8:
               T.attr(0, "async_commit_queue_scope", 0)
               T.attr(0, "async_scope", 1)
               for ax0 in range(64):
                   a_buffer_global_vtcm_1[i * 64 + ax0] = a_buffer_1[i * 64 + 
ax0]
           if i == 1 and i - 1 < 8:
               T.attr(0, "async_wait_queue_scope", 0)
               T.attr(0, "async_wait_inflight_count", 1)
               for j in range(64):
                   out_global_vtcm_1[j + 64 - i % 2 * 64] = 
a_buffer_global_vtcm_1[j + 64 - i % 2 * 64] + T.uint8(1)
       out_1 = T.Buffer((512,), "uint8", data=out.data)
       for i in range(6):
           if i + 2 < 8:
               T.attr(0, "async_commit_queue_scope", 0)
               T.attr(0, "async_scope", 1)
               for ax0 in range(64):
                   a_buffer_global_vtcm_1[i % 2 * 64 + ax0] = a_buffer_1[i * 64 
+ ax0 + 128]
           if i + 2 - 1 < 8:
               T.attr(0, "async_wait_queue_scope", 0)
               T.attr(0, "async_wait_inflight_count", 1)
               for j in range(64):
                   out_global_vtcm_1[i % 2 * 64 + j] = a_buffer_global_vtcm_1[i 
% 2 * 64 + j] + T.uint8(1)
                                     ^^^^^^^^^^^^^^^^^^^
                                     #2: This is the correct index after buffer 
flattening
           if i + 2 - 2 < 8:
               T.attr(0, "async_commit_queue_scope", 2)
               T.attr(0, "async_scope", 1)
               for ax0 in range(64):
                   out_1[i * 64 + ax0] = out_global_vtcm_1[i % 2 * 64 + ax0]
       for i in T.unroll(2):
           if i + 8 - 1 < 8:
               T.attr(0, "async_wait_queue_scope", 0)
               T.attr(0, "async_wait_inflight_count", 0 - i)
               for j in range(64):
                   out_global_vtcm_1[i * 64 + j] = a_buffer_global_vtcm_1[i * 
64 + j] + T.uint8(1)
           if i + 8 - 2 < 8:
               T.attr(0, "async_commit_queue_scope", 2)
               T.attr(0, "async_scope", 1)
               for ax0 in range(64):
                   out_1[i * 64 + ax0 + 384] = out_global_vtcm_1[i * 64 + ax0]
   ```
   
   However, it seems that there might be extra orthogonal bugs in the lowering 
pipeline that broke the Hexagon unittests. It would be great if @Lunderberg or 
anyone else who's familiar with the Hexagon pipeline could take a further look 
into this issue.
   
   Meanwhile, given that this PR fixes existing issues and brings in extremely 
valuable new functionality, I would love to propose to @tqchen that we should 
add unittests to catch the regression, skip the Hexagon tests to get this PR 
merged, and recover those tests once we have another fix.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to