junrushao1994 commented on a change in pull request #8716:
URL: https://github.com/apache/tvm/pull/8716#discussion_r687320417



##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -444,6 +444,233 @@ def after_split(a: ty.handle, b: ty.handle) -> None:
 
     ########## Schedule: Manipulate ForKind ##########
 
+    def parallel(self, loop: LoopRV) -> None:
+        """Parallelize the input loop. It requires:
+        1) The scope block that the loop is in should have the stage-pipeline property
+        2) All the blocks under the loop are complete blocks or reduction blocks and have
+        affine bindings
+        3) For each block under the loop, the loop can only be contained in the bindings of
+        data-parallel block iters
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop to be parallelized
+
+        Examples
+        --------
+
+        Before parallel, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_parallel(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, i)
+                        tir.bind(vj, j)
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do parallel:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_parallel)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.parallel(i)
+
+        After applying parallel, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_parallel(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i in tir.parallel(0, 128):
+                    for j in tir.serial(0, 128):
+                        with tir.block([128, 128], "B") as [vi, vj]:
+                            tir.bind(vi, i)
+                            tir.bind(vj, j)
+                            B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        _ffi_api.ScheduleParallel(self, loop)  # type: ignore # pylint: disable=no-member
+
+    def vectorize(self, loop: LoopRV) -> None:
+        """Vectorize the input loop. It requires:
+        1) The scope block that the loop is in should have the stage-pipeline property
+        2) All the blocks under the loop are complete blocks or reduction blocks and have
+        affine bindings
+        3) For each block under the loop, the loop can only be contained in the bindings of
+        data-parallel block iters
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop to be vectorized
+
+        Examples
+        --------
+
+        Before vectorize, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_vectorize(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, i)
+                        tir.bind(vj, j)
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do vectorize:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_vectorize)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.vectorize(j)
+
+        After applying vectorize, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_vectorize(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i in tir.serial(0, 128):
+                    for j in tir.vectorized(0, 128):
+                        with tir.block([128, 128], "B") as [vi, vj]:
+                            tir.bind(vi, i)
+                            tir.bind(vj, j)
+                            B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        _ffi_api.ScheduleVectorize(self, loop)  # type: ignore # pylint: disable=no-member
+
+    def bind(self, loop: LoopRV, thread_axis: str) -> None:
+        """Bind the input loop to the given thread axis. It requires:
+        1) The scope block that the loop is in should have the stage-pipeline property
+        2) All the blocks under the loop are complete blocks or reduction blocks and have
+        affine bindings
+        3) For each block under the loop, if the thread axis starts with "threadIdx", the
+        loop can only be contained in the bindings of data-parallel block iters and reduction
+        block iters. Otherwise the loop can only be contained in the bindings of data-parallel
+        block iters
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop to be bound to the thread axis
+        thread_axis : str
+            The thread axis to be bound to the loop. Possible candidates:
+            - blockIdx.x/y/z
+            - threadIdx.x/y/z
+            - vthread
+            - vthread.x/y/z

Review comment:
       There is no hardware correspondence for vthread either; it's just a definition :-)
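
       To make the distinction concrete, here is a minimal sketch of how `bind` might be
       applied, reusing the `before_parallel` workload from the docstring above: `vthread.x`
       is a purely logical axis that the compiler lowers away, while `threadIdx.x` maps to a
       physical GPU thread index.

           sch = tir.Schedule(before_parallel)
           i, j = sch.get_loops(sch.get_block("B"))
           sch.bind(i, "vthread.x")    # virtual thread: no hardware counterpart, lowered in software
           sch.bind(j, "threadIdx.x")  # physical GPU thread index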




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

