liangW-intellif opened a new pull request, #13033:
URL: https://github.com/apache/tvm/pull/13033

   Hi, this PR ports the rolling_buffer primitive from the TE schedule to the 
TensorIR schedule. For background, see [[RFC] Introducing a ‘rolling_buffer’ scheduling 
primitive](https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836).
   The primitive performs the following steps to transform the target buffer 
into a 'rolling buffer':
   1. Collect the bound overlaps on the target buffer, and select the outermost 
rollable axis that appears in the block's loop nest as the 'rolling axis'.
   2. Append a block predicate to the producer block of the target buffer to 
avoid recomputation.
   3. Use modulo arithmetic to rewrite the target buffer's load and store indices 
so that the buffer becomes circular along the rolling dimension.
   
   Note: The region_cover property of the consumer block of the target buffer 
will become false.
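
The three steps can be illustrated without TVM by a minimal one-dimensional sketch in plain Python. Everything in it is an assumption made for illustration: `produce`, `consume_with_rolling_buffer`, `stride`, and `window` are hypothetical names, and the code models the idea rather than the primitive's actual implementation.

```python
def produce(x):
    # Stand-in for the producer computation (hypothetical).
    return x * x


def consume_with_rolling_buffer(num_tiles, stride, window):
    # Step 1: consecutive tiles share `overlap` elements along the rolling axis.
    overlap = window - stride
    # The buffer only needs to hold one window of elements.
    buf = [None] * window
    evaluations = 0
    sums = []
    for i in range(num_tiles):
        for ax in range(window):
            # Step 2: the predicate skips elements that an earlier tile already produced.
            if i < 1 or overlap <= ax:
                x = i * stride + ax
                # Step 3: the store index is wrapped with modulo arithmetic.
                buf[x % window] = produce(x)
                evaluations += 1
        # The consumer reads the window through the same modulo mapping.
        sums.append(sum(buf[(i * stride + ax) % window] for ax in range(window)))
    return sums, evaluations


sums, evaluations = consume_with_rolling_buffer(num_tiles=3, stride=4, window=6)
expected = [sum(produce(x) for x in range(i * 4, i * 4 + 6)) for i in range(3)]
assert sums == expected
print(evaluations)  # 14 producer evaluations, versus 18 without the predicate
```

With a stride of 4 and a window of 6, each tile after the first reuses 2 elements, so the producer runs 6 + 4 + 4 times instead of 3 × 6.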
   ### Example
   - Before
   ```python
   @T.prim_func
   def before_rolling_buffer(
       A: T.Buffer[(12, 12), "int8"], C: T.Buffer[(8, 8), "int8"]
   ) -> None:
       # body
       # with T.block("root")
       B = T.alloc_buffer([10, 10], dtype="int8")
       for i0, i1 in T.grid(2, 2):
           for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
               with T.block("B"):
                   ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
                   ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
                   rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                   B[ax0_1, ax1_1] = T.max(B[ax0_1, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1])
           for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
               with T.block("C"):
                   ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
                   ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
                   rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                   C[ax0_1, ax1_1] = T.max(C[ax0_1, ax1_1], B[ax0_1 + rv0, ax1_1 + rv1])
   
   ```
   - After `sch.rolling_buffer(sch.get_block("B"), buffer_index=0)`
   ```python
   @T.prim_func
   def after_rolling_buffer(
       A: T.Buffer[(12, 12), "int8"],
       C: T.Buffer[(8, 8), "int8"]
   ) -> None:
       # body
       # with T.block("root")
       B = T.alloc_buffer([6, 10], dtype="int8")
       for i0, i1 in T.grid(2, 2):
           for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
               with T.block("B"):
                   T.where((i0 < 1 or 2 <= ax0) and (i1 < 1 or 2 <= ax1))
                   ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
                   ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
                   rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                   B[ax0_1 % 6, ax1_1] = T.max(B[ax0_1 % 6, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1])
           for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
               with T.block("C"):
                   ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
                   ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
                   rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                   C[ax0_1, ax1_1] = T.max(C[ax0_1, ax1_1], B[(ax0_1 + rv0) % 6, ax1_1 + rv1])
   ```
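
As a sanity check on the example above, the following plain-Python sketch replays the same loop nest with a 6-row buffer and compares the result against a full 10×10 intermediate. It is an illustration under stated assumptions, not TVM code: `max3x3`, `reference`, and `rolled` are hypothetical helpers, each `B` element is computed in one step instead of being accumulated in place, and reads wrap the full row index with `% 6`.

```python
ROLL = 6  # rows kept in the rolling buffer, as in after_rolling_buffer


def max3x3(read, r, c):
    """Max over the 3x3 window whose top-left corner is (r, c)."""
    return max(read(r + dr, c + dc) for dr in range(3) for dc in range(3))


def reference(A):
    # Full-size intermediate buffer, as in before_rolling_buffer.
    B = [[max3x3(lambda y, x: A[y][x], r, c) for c in range(10)] for r in range(10)]
    return [[max3x3(lambda y, x: B[y][x], r, c) for c in range(8)] for r in range(8)]


def rolled(A):
    B = [[None] * 10 for _ in range(ROLL)]
    C = [[None] * 8 for _ in range(8)]
    b_evaluations = 0
    for i0 in range(2):
        for i1 in range(2):
            for ax0 in range(6):
                for ax1 in range(6):
                    # Block predicate copied from the example: skip recomputation.
                    if (i0 < 1 or 2 <= ax0) and (i1 < 1 or 2 <= ax1):
                        r, c = i0 * 4 + ax0, i1 * 4 + ax1
                        B[r % ROLL][c] = max3x3(lambda y, x: A[y][x], r, c)
                        b_evaluations += 1
            for ax0 in range(4):
                for ax1 in range(4):
                    r, c = i0 * 4 + ax0, i1 * 4 + ax1
                    # Reads go through the same modulo mapping as the writes.
                    C[r][c] = max3x3(lambda y, x: B[y % ROLL][x], r, c)
    return C, b_evaluations


# Deterministic 12x12 input (arbitrary values chosen for the sketch).
A = [[(7 * r + 13 * c) % 31 - 15 for c in range(12)] for r in range(12)]
C_rolled, b_evaluations = rolled(A)
assert C_rolled == reference(A)
print(b_evaluations)  # 100: each of the 10x10 elements of B is computed once
```

Without the predicate the producer would run 4 × 36 = 144 times; with it, the four tiles contribute 36 + 24 + 24 + 16 = 100 evaluations.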
   ### Difference from TE rolling_buffer
   TIR rolling_buffer only selects a dimension with a positive bound overlap 
as the rolling dimension. Consider the following example, in which the collected 
bound overlap for buffer B is [0, 0, 2, 0].
   ```python
   @T.prim_func
   def before(
       A: T.Buffer[(1, 12, 14, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"]
   ):
       B = T.alloc_buffer([1, 12, 14, 16], dtype="int8")
       for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 3, 3, 1):
           for ax0, ax1, ax2 in T.grid(4, 6, 16):
               with T.block("B"):
                   ax0_1 = T.axis.spatial(1, 0)
                   ax1_1 = T.axis.spatial(12, i1_0 * 4 + ax0)
                   ax2_1 = T.axis.spatial(12, i2_0 * 4 + ax1)
                   ax3 = T.axis.spatial(16, ax2)
                   B[ax0_1, ax1_1, ax2_1, ax3] = A[ax0_1, ax1_1, ax2_1, ax3]
           for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 1, 3):
               with T.block("C"):
                   ax0 = T.axis.spatial(1, i0_0 + i0_1)
                   ax1 = T.axis.spatial(12, i1_0 * 4 + i1_1)
                   ax2 = T.axis.spatial(12, i2_0 * 4 + i2_1)
                   ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1)
                   rv0, rv1 = T.axis.remap("RR", [i4, i5])
                   C[ax0, ax1, ax2, ax3] = T.max(C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, ax2 + rv1, ax3])
   ```
   Under the TE rolling_buffer logic, i1_0 would be selected as the rolling 
axis and its range folded to [0, 4] to compact and minimize the buffer size. 
In TensorIR, however, buffer region compaction is performed by the 
CompactBufferAllocation pass, so the primitive instead selects i2_0, which has a 
positive bound overlap, as the rolling axis to circularize the buffer.
   ```python
   @T.prim_func
   def after(A: T.Buffer[(1, 12, 14, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"]) -> None:
       # body
       # with T.block("root")
       B = T.alloc_buffer([1, 12, 6, 16], dtype="int8")
       for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 3, 3, 1):
           for ax0, ax1, ax2 in T.grid(4, 6, 16):
               with T.block("B"):
                   T.where(i2_0 < 1 or 2 <= ax1)
                   ax0_1 = T.axis.spatial(1, 0)
                   ax1_1 = T.axis.spatial(12, i1_0 * 4 + ax0)
                   ax2_1 = T.axis.opaque(12, i2_0 * 4 + ax1)
                   ax3 = T.axis.spatial(16, ax2)
                   B[ax0_1, ax1_1, ax2_1 % 6, ax3] = T.max(B[ax0_1, ax1_1, ax2_1 % 6, ax3], A[ax0_1, ax1_1, ax2_1, ax3])
           for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 1, 3):
               with T.block("C"):
                   ax0 = T.axis.spatial(1, i0_0 + i0_1)
                   ax1 = T.axis.spatial(12, i1_0 * 4 + i1_1)
                   ax2 = T.axis.opaque(12, i2_0 * 4 + i2_1)
                   ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1)
                   rv0, rv1 = T.axis.remap("RR", [i4, i5])
                   C[ax0, ax1, ax2, ax3] = T.max(C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, (ax2 + rv1) % 6, ax3])
   ```
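
The difference in axis selection described in this section can be modelled with a small plain-Python sketch. This is a simplified model and not the implementation in either scheduler: `bound_overlaps`, `select_te`, and `select_tir` are hypothetical names, and each buffer dimension is described only by the stride of the outer loop that moves it (`None` if no outer loop does) and by the extent touched per outer iteration.

```python
from typing import List, Optional


def bound_overlaps(strides: List[Optional[int]], extents: List[int]) -> List[int]:
    """Elements shared by consecutive outer-loop iterations, per buffer dimension."""
    return [0 if s is None else max(e - s, 0) for s, e in zip(strides, extents)]


def select_te(strides: List[Optional[int]]) -> Optional[int]:
    """Simplified TE-style choice: first dimension moved by an outer loop."""
    return next((d for d, s in enumerate(strides) if s is not None), None)


def select_tir(overlaps: List[int]) -> Optional[int]:
    """TensorIR-style choice: first dimension with a positive bound overlap."""
    return next((d for d, o in enumerate(overlaps) if o > 0), None)


# Buffer B in the example above: dimension 1 advances by 4 and touches 4 rows,
# dimension 2 advances by 4 and touches 6 columns, and the other two
# dimensions do not move with any outer loop.
strides = [None, 4, 4, None]
extents = [1, 4, 6, 16]

overlaps = bound_overlaps(strides, extents)
assert overlaps == [0, 0, 2, 0]
assert select_te(strides) == 1   # dimension driven by i1_0
assert select_tir(overlaps) == 2  # dimension driven by i2_0
```

In this model the rolled dimension keeps one tile's worth of elements, which matches the extent of 6 in `T.alloc_buffer([1, 12, 6, 16])` for the TensorIR result.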
   cc @wrongtest-intellif 

