junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713613369



##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -1223,6 +1223,70 @@ def after_inline(a: ty.handle, c: ty.handle) -> None:
 
     ########## Schedule: Reduction ##########
 
+    def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV:
+        """Decompose a reduction block into an init block and an update block, where the newly
+        generated init block is inserted before the specified loop.
+
+        The schedule primitive requires that:
+
+        1) The block is a reduction block.
+
+        2) The loop is an ancestor of the block.
+
+        3) The loop is not lower than all the loops related to the reduction block vars.
+
+        Parameters
+        ----------
+        block : BlockRV
+            The reduction block to be decomposed
+        loop : LoopRV
+            The position where the init block is inserted
+
+        Returns
+        -------
+        init_block : BlockRV
+            The newly generated init block
+
+        Examples
+        --------
+        Before decompose-reduction, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, [128, 128])
+                B = tir.match_buffer(b, [128, 128])
+                C = tir.match_buffer(c, [128, 128])
+                for i, j, k in tir.grid(128, 128, 128):
+                    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
+                        with tir.init():
+                            C[vi, vj] = 0.0
+                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        Create the schedule and do decompose-reduction with specified loop:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_decompose)
+            C = sch.get_block("C")
+            i, j, k = sch.get_loops(C)
+            sch.decompose_reduction(C, i)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying decompose-reduction, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, [128, 128])
+                B = tir.match_buffer(b, [128, 128])
+                C = tir.match_buffer(c, [128, 128])
+                for i in tir.serial(128):
+                    for j in tir.serial(128):
+                        with tir.block([128, 128]) as [vi, vj]:
+                            C[vi, vj] = 0.0
+                for i, j, k in tir.grid(128, 128, 128):
+                    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
+                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        """
+        # Propagate the BlockRV produced by the FFI call, as promised by the `-> BlockRV`
+        # annotation; without `return`, callers would always receive None.
+        return _ffi_api.ScheduleDecomposeReduction(self, block, loop)  # type: ignore # pylint: disable=no-member

Review comment:
       Looks like we forgot to return anything?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to