junrushao1994 commented on a change in pull request #9041: URL: https://github.com/apache/tvm/pull/9041#discussion_r713613369
########## File path: python/tvm/tir/schedule/schedule.py ########## @@ -1223,6 +1223,70 @@ def after_inline(a: ty.handle, c: ty.handle) -> None: ########## Schedule: Reduction ########## + def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV: + """Decompose a reduction block into init block and update block, where the newly generated + init block will be before the specified loop. + + 1) The block is a reduction block. + + 2) The loop is the ancestor of the block. + + 3) The loop is not lower than all the loops related to reduce block var. + + Parameters + ---------- + block : BlockRV + The reduction block to be decomposed + loop : LoopRV + The position where init block is inserted + + Examples + -------- + Before decompose-reduction, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_decompose(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + for i, j, k in tir.grid(128, 128, 128): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + Create the schedule and do decompose-reduction with specified loop: + + .. code-block:: python + + sch = tir.Schedule(before_decompose) + C = sch.get_block("C") + i, j, k = sch.get_loops(C) + sch.decompose_reduction(C, i) + print(tvm.script.asscript(sch.mod["main"])) + + After applying decompose-reduction, the IR becomes: +
+ .. code-block:: python + + @tvm.script.tir + def after_decompose(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + for i in tir.serial(128): + for j in tir.serial(128): + with tir.block([128, 128]) as [vi, vj]: + C[vi, vj] = 0.0 + for i, j, k in tir.grid(128, 128, 128): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + """ + _ffi_api.ScheduleDecomposeReduction(self, block, loop) # type: ignore # pylint: disable=no-member Review comment: Looks like we forgot to return anything? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org