liangW-intellif opened a new pull request, #13033: URL: https://github.com/apache/tvm/pull/13033
Hi, this PR ports the rolling_buffer primitive from the TE schedule to the TensorIR schedule; see [[RFC] Introducing a ‘rolling_buffer’ scheduling primitive](https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836).

The primitive performs the following steps to transform the target buffer into a 'rolling buffer':
1. Collect bound overlaps on the target buffer, and select the outermost rollable axis that appears in the block's loop nest as the 'rolling axis'.
2. Append a block predicate to the producer block of the target buffer to avoid recomputation.
3. Use modulo arithmetic to rewrite the target buffer's load and store indices so that the buffer becomes circular along the rolling dimension.

Note: The region_cover property of the consumer block of the target buffer will become false.

### Example
- Before
```python
@T.prim_func
def before_rolling_buffer(
    A: T.Buffer[(12, 12), "int8"], C: T.Buffer[(8, 8), "int8"]
) -> None:
    # body
    # with T.block("root")
    B = T.alloc_buffer([10, 10], dtype="int8")
    for i0, i1 in T.grid(2, 2):
        for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
            with T.block("B"):
                ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                B[ax0_1, ax1_1] = T.max(B[ax0_1, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1])
        for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
            with T.block("C"):
                ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                C[ax0_1, ax1_1] = T.max(C[ax0_1, ax1_1], B[ax0_1 + rv0, ax1_1 + rv1])
```
- After `sch.rolling_buffer(sch.get_block("B"), buffer_index=0)`
```python
@T.prim_func
def after_rolling_buffer(
    A: T.Buffer[(12, 12), "int8"], C: T.Buffer[(8, 8), "int8"]
) -> None:
    # body
    # with T.block("root")
    B = T.alloc_buffer([6, 10], dtype="int8")
    for i0, i1 in T.grid(2, 2):
        for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
            with T.block("B"):
                # Skip elements that a previous (i0, i1) tile already produced.
                T.where((i0 < 1 or 2 <= ax0) and (i1 < 1 or 2 <= ax1))
                ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                B[ax0_1 % 6, ax1_1] = T.max(B[ax0_1 % 6, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1])
        for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
            with T.block("C"):
                ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                # The whole row index wraps, so it stays inside the 6-row buffer.
                C[ax0_1, ax1_1] = T.max(C[ax0_1, ax1_1], B[(ax0_1 + rv0) % 6, ax1_1 + rv1])
```

### Difference from TE rolling_buffer
TIR rolling_buffer only selects a dimension with a positive bound overlap as the rolling dimension. Consider the following example, where the collected bound overlap for buffer B is [0, 0, 2, 0].
```python
@T.prim_func
def before(
    A: T.Buffer[(1, 12, 14, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"]
):
    B = T.alloc_buffer([1, 12, 14, 16], dtype="int8")
    for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 3, 3, 1):
        for ax0, ax1, ax2 in T.grid(4, 6, 16):
            with T.block("B"):
                ax0_1 = T.axis.spatial(1, 0)
                ax1_1 = T.axis.spatial(12, i1_0 * 4 + ax0)
                ax2_1 = T.axis.spatial(12, i2_0 * 4 + ax1)
                ax3 = T.axis.spatial(16, ax2)
                B[ax0_1, ax1_1, ax2_1, ax3] = A[ax0_1, ax1_1, ax2_1, ax3]
        for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 1, 3):
            with T.block("C"):
                ax0 = T.axis.spatial(1, i0_0 + i0_1)
                ax1 = T.axis.spatial(12, i1_0 * 4 + i1_1)
                ax2 = T.axis.spatial(12, i2_0 * 4 + i2_1)
                ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1)
                rv0, rv1 = T.axis.remap("RR", [i4, i5])
                C[ax0, ax1, ax2, ax3] = T.max(C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, ax2 + rv1, ax3])
```
Under the TE rolling_buffer logic, i1_0 would be selected as the rolling axis and its range folded to [0, 4] to compact and minimize the buffer size. In TensorIR, however, buffer region compaction is performed by the CompactBufferAllocation pass, so the primitive instead selects i2_0, which has a positive bound overlap, as the rolling axis and makes the buffer circular along it.
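The selection rule described above can be illustrated with a small plain-Python sketch. It is not the primitive's implementation: the helper names (`bound_overlaps`, `pick_rolling_dim`, `rolled_shape`) are hypothetical, and it assumes a simplified model in which the overlap of a dimension is the per-tile extent of the producer region minus the distance the region advances per outer-loop iteration, with buffer dimensions listed in loop-nest order.

```python
from typing import List, Optional


def bound_overlaps(extents: List[int], strides: List[Optional[int]]) -> List[int]:
    """Per-dimension overlap between consecutive producer tiles.

    extents[d]: size of the region written per tile along dimension d.
    strides[d]: how far that region moves per outer-loop step, or None if
                no outer loop indexes dimension d (assumed overlap: 0).
    """
    return [0 if s is None else max(e - s, 0) for e, s in zip(extents, strides)]


def pick_rolling_dim(overlaps: List[int]) -> Optional[int]:
    """Return the first (assumed outermost) dimension with a positive overlap."""
    for dim, overlap in enumerate(overlaps):
        if overlap > 0:
            return dim
    return None


def rolled_shape(shape: List[int], extents: List[int], dim: int) -> List[int]:
    """Shrink only the rolling dimension to one tile's extent."""
    return [extents[d] if d == dim else n for d, n in enumerate(shape)]


# Numbers taken from the second example: B tiles are [1, 4, 6, 16] and the
# i1_0 / i2_0 loops each advance the tile by 4.
overlaps = bound_overlaps([1, 4, 6, 16], [None, 4, 4, None])
dim = pick_rolling_dim(overlaps)
print(overlaps, dim, rolled_shape([1, 12, 14, 16], [1, 4, 6, 16], dim))
```

With the numbers from the first example (tile extents `[6, 6]`, strides `[4, 4]`, shape `[10, 10]`), the same helpers yield overlaps `[2, 2]`, rolling dimension 0 and shape `[6, 10]`.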
```python
@T.prim_func
def after(
    A: T.Buffer[(1, 12, 14, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"]
) -> None:
    # body
    # with T.block("root")
    B = T.alloc_buffer([1, 12, 6, 16], dtype="int8")
    for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 3, 3, 1):
        for ax0, ax1, ax2 in T.grid(4, 6, 16):
            with T.block("B"):
                # Skip columns that the previous i2_0 tile already produced.
                T.where(i2_0 < 1 or 2 <= ax1)
                ax0_1 = T.axis.spatial(1, 0)
                ax1_1 = T.axis.spatial(12, i1_0 * 4 + ax0)
                ax2_1 = T.axis.opaque(12, i2_0 * 4 + ax1)
                ax3 = T.axis.spatial(16, ax2)
                # Same copy as in `before`; only the rolling index is wrapped.
                B[ax0_1, ax1_1, ax2_1 % 6, ax3] = A[ax0_1, ax1_1, ax2_1, ax3]
        for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 1, 3):
            with T.block("C"):
                ax0 = T.axis.spatial(1, i0_0 + i0_1)
                ax1 = T.axis.spatial(12, i1_0 * 4 + i1_1)
                ax2 = T.axis.opaque(12, i2_0 * 4 + i2_1)
                ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1)
                rv0, rv1 = T.axis.remap("RR", [i4, i5])
                C[ax0, ax1, ax2, ax3] = T.max(C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, (ax2 + rv1) % 6, ax3])
```

cc @wrongtest-intellif
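As a sanity check on the three steps listed at the top of this description, here is a minimal plain-Python simulation of the first example (12x12 input, 8x8 output). It is only a sketch under stated assumptions: it does not use TVM, the function names `window_max` and `run` are hypothetical, each B element is computed as a complete 3x3 window maximum and then stored (rather than reduced in place), and the consumer wraps its whole row index, i.e. `(row + rv0) % 6`.

```python
def window_max(src, row, col):
    """Max over the 3x3 window of `src` whose top-left corner is (row, col)."""
    return max(src[row + dr][col + dc] for dr in range(3) for dc in range(3))


def run(rolling):
    """Return (C, number of B elements computed) for the first example."""
    # Arbitrary deterministic input; the values themselves are an assumption.
    A = [[(7 * r + 13 * c) % 31 for c in range(12)] for r in range(12)]
    rows = 6 if rolling else 10  # B is [6, 10] when rolled, [10, 10] otherwise
    B = [[0] * 10 for _ in range(rows)]
    C = [[0] * 8 for _ in range(8)]
    computed = 0
    for i0 in range(2):
        for i1 in range(2):
            # Producer "B": one 6x6 tile per (i0, i1).
            for ax0 in range(6):
                for ax1 in range(6):
                    # Step 2: the predicate skips elements an earlier tile produced.
                    keep = (i0 < 1 or 2 <= ax0) and (i1 < 1 or 2 <= ax1)
                    if rolling and not keep:
                        continue
                    r, c = i0 * 4 + ax0, i1 * 4 + ax1
                    # Step 3: the store index wraps along the rolling dimension.
                    B[r % rows][c] = window_max(A, r, c)
                    computed += 1
            # Consumer "C": one 4x4 tile per (i0, i1), reading wrapped rows of B.
            for ax0 in range(4):
                for ax1 in range(4):
                    r, c = i0 * 4 + ax0, i1 * 4 + ax1
                    C[r][c] = max(
                        B[(r + rv0) % rows][c + rv1]
                        for rv0 in range(3)
                        for rv1 in range(3)
                    )
    return C, computed


full_C, full_count = run(rolling=False)
rolled_C, rolled_count = run(rolling=True)
print(rolled_C == full_C, full_count, rolled_count)
```

In this model the rolled version produces the same C while computing 100 elements of B instead of 144, since each of the 10x10 intermediate values is produced exactly once.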
