Hzfengsy commented on a change in pull request #8716:
URL: https://github.com/apache/tvm/pull/8716#discussion_r686552320



##########
File path: tests/python/unittest/test_tir_schedule_for_kind.py
##########
@@ -0,0 +1,320 @@
+# Licensed to the Apache Software Foundation (ASF) under one

Review comment:
       Please add test cases for `bind` and `parallel` working on a block with 
`predicate`

##########
File path: tests/python/unittest/test_tir_schedule_for_kind.py
##########
@@ -0,0 +1,320 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+
+import numpy as np
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import ty
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@tvm.script.tir
+def element_wise(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.match_buffer(b, (128, 128))
+
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+
+@tvm.script.tir
+def element_wise_parallelized(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.match_buffer(b, (128, 128))
+    for i0 in tir.parallel(0, 128):
+        for i1 in tir.serial(0, 128):
+            with tir.block([128, 128], "B") as [vi, vj]:
+                tir.bind(vi, i0)
+                tir.bind(vj, i1)
+                B[vi, vj] = A[vi, vj] * 2.0
+
+
+@tvm.script.tir
+def element_wise_i_bound(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.match_buffer(b, (128, 128))
+    for i0 in tir.thread_binding(0, 128, thread="threadIdx.x"):
+        for i1 in tir.serial(0, 128):
+            with tir.block([128, 128], "B") as [vi, vj]:
+                tir.bind(vi, i0)
+                tir.bind(vj, i1)
+                B[vi, vj] = A[vi, vj] * 2.0
+
+
+@tvm.script.tir
+def element_wise_compute_at_split(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    B = tir.alloc_buffer((128, 128))
+    for i in tir.serial(0, 128):
+        for j0 in tir.serial(0, 128):
+            with tir.block([128, 128], "B") as [vi, vj]:
+                tir.bind(vi, i)
+                tir.bind(vj, j0)
+                B[vi, vj] = A[vi, vj] * 2.0
+        for j1o, j1i in tir.grid(32, 4):
+            with tir.block([128, 128], "C") as [vi, vj]:
+                tir.bind(vi, i)
+                tir.bind(vj, j1o * 4 + j1i)
+                C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def element_wise_compute_at_split_vectorized(a: ty.handle, c: ty.handle) -> 
None:
+    A = tir.match_buffer(a, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    B = tir.alloc_buffer((128, 128))
+    for i in tir.serial(0, 128):
+        for j0 in tir.serial(0, 128):
+            with tir.block([128, 128], "B") as [vi, vj]:
+                tir.bind(vi, i)
+                tir.bind(vj, j0)
+                B[vi, vj] = A[vi, vj] * 2.0
+        for j1o in tir.serial(0, 32):
+            for j1i in tir.vectorized(0, 4):
+                with tir.block([128, 128], "C") as [vi, vj]:
+                    tir.bind(vi, i)
+                    tir.bind(vj, j1o * 4 + j1i)
+                    C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def element_wise_compute_at_split_j0_j1o_bound(a: ty.handle, c: ty.handle) -> 
None:
+    A = tir.match_buffer(a, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    B = tir.alloc_buffer((128, 128))
+    for i in tir.serial(0, 128):
+        for j0 in tir.thread_binding(0, 128, thread="threadIdx.x"):
+            with tir.block([128, 128], "B") as [vi, vj]:
+                tir.bind(vi, i)
+                tir.bind(vj, j0)
+                B[vi, vj] = A[vi, vj] * 2.0
+        for j1o in tir.thread_binding(0, 32, thread="threadIdx.x"):
+            for j1i in tir.serial(0, 4):
+                with tir.block([128, 128], "C") as [vi, vj]:
+                    tir.bind(vi, i)
+                    tir.bind(vj, j1o * 4 + j1i)
+                    C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.match_buffer(b, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+
+    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def rowsum(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.match_buffer(b, (128,))
+
+    with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
+        with tir.init():
+            B[vi] = 0.0
+        B[vi] = B[vi] + A[vi, vk]
+
+
+@tvm.script.tir
+def rowsum_unrolled(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.match_buffer(b, (128,))
+    for i0 in tir.unroll(0, 128):
+        for i1 in tir.serial(0, 128):
+            with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
+                tir.bind(vi, i0)
+                tir.bind(vk, i1)
+                with tir.init():
+                    B[vi] = 0.0
+                B[vi] = B[vi] + A[vi, vk]
+
+
+@tvm.script.tir
+def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.match_buffer(b, (128,))
+
+    for i, k in tir.grid(128, 16):
+        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
+            tir.bind(vi, i)
+            tir.bind(vk, tir.floordiv(k * k, 2))
+            with tir.init():
+                B[vi] = 0.0
+            B[vi] = B[vi] + A[vi, vk]
+
+
+@tvm.script.tir
+def rowsum_not_compact_data_flow(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.match_buffer(b, (128,))
+
+    with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
+        with tir.init():
+            B[vk] = 0.0
+        B[vk] = B[vk] + A[vi, vk]
+
+
+@tvm.script.tir
+def rowsum_cross_thread_reduction(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.match_buffer(b, (128,))
+    for i0 in tir.serial(0, 128):
+        for i1 in tir.thread_binding(0, 128, thread="threadIdx.x"):
+            with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
+                tir.bind(vi, i0)
+                tir.bind(vk, i1)
+                with tir.init():
+                    B[vi] = 0.0
+                B[vi] = B[vi] + A[vi, vk]
+
+
+@tvm.script.tir
+def opaque_block(a: ty.handle) -> None:
+    A = tir.match_buffer(a, (16,))
+    for i in tir.serial(0, 15):
+        with tir.block([], "opaque"):
+            A[i + 1] = A[i + 1] + A[i]
+
+
+# pylint: enable=no-member,invalid-name,unused-variable
+
+
+def test_parallel():
+    s = tir.Schedule(element_wise, debug_mask="all")
+    i, _ = s.get_loops(s.get_block("B"))
+    s.parallel(i)
+    tvm.ir.assert_structural_equal(s.mod["main"], element_wise_parallelized)
+    verify_trace_roundtrip(s, mod=element_wise)
+
+
+def test_parallel_reduction_block_iter():
+    s = tir.Schedule(matmul, debug_mask="all")
+    _, _, k = s.get_loops(s.get_block("C"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.parallel(k)
+
+
+def test_parallel_not_quasi_affine():
+    s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all")
+    i, _ = s.get_loops(s.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.parallel(i)
+
+
+def test_parallel_not_compact_data_flow():
+    s = tir.Schedule(rowsum_not_compact_data_flow, debug_mask="all")
+    i, _ = s.get_loops(s.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.parallel(i)
+
+
+def test_vectorize():
+    s = tir.Schedule(element_wise_compute_at_split, debug_mask="all")
+    _, _, j1i = s.get_loops(s.get_block("C"))
+    s.vectorize(j1i)
+    tvm.ir.assert_structural_equal(s.mod["main"], 
element_wise_compute_at_split_vectorized)
+    verify_trace_roundtrip(s, mod=element_wise_compute_at_split)
+
+
+def test_vectorize_opaque_block():
+    s = tir.Schedule(opaque_block, debug_mask="all")
+    (i,) = s.get_loops(s.get_block("opaque"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.vectorize(i)
+
+
+def test_unroll():
+    s = tir.Schedule(rowsum, debug_mask="all")
+    i, _ = s.get_loops(s.get_block("B"))
+    s.unroll(i)
+    tvm.ir.assert_structural_equal(s.mod["main"], rowsum_unrolled)
+    verify_trace_roundtrip(s, mod=rowsum)
+
+
+def test_unroll_after_bind():
+    s = tir.Schedule(rowsum, debug_mask="all")
+    i, _ = s.get_loops(s.get_block("B"))
+    s.bind(i, "blockIdx.x")
+    s.unroll(i)
+    tvm.ir.assert_structural_equal(s.mod["main"], rowsum_unrolled)
+    verify_trace_roundtrip(s, mod=rowsum)
+
+
+def test_bind1():
+    s = tir.Schedule(element_wise, debug_mask="all")
+    i, _ = s.get_loops(s.get_block("B"))
+    s.bind(i, "threadIdx.x")
+    tvm.ir.assert_structural_equal(s.mod["main"], element_wise_i_bound)
+    verify_trace_roundtrip(s, mod=element_wise)
+
+
+def test_bind2():
+    s = tir.Schedule(element_wise_compute_at_split, debug_mask="all")
+    _, j0 = s.get_loops(s.get_block("B"))
+    _, j1o, _ = s.get_loops(s.get_block("C"))
+    s.bind(j0, "threadIdx.x")
+    s.bind(j1o, "threadIdx.x")
+    tvm.ir.assert_structural_equal(s.mod["main"], 
element_wise_compute_at_split_j0_j1o_bound)
+    verify_trace_roundtrip(s, mod=element_wise_compute_at_split)
+
+
+def test_bind_cross_thread_reduction():
+    s = tir.Schedule(rowsum, debug_mask="all")
+    _, k = s.get_loops(s.get_block("B"))
+    s.bind(k, "threadIdx.x")
+    tvm.ir.assert_structural_equal(s.mod["main"], 
rowsum_cross_thread_reduction)
+    verify_trace_roundtrip(s, mod=rowsum)
+
+
+def test_bind_not_cross_thread_reduction():
+    s = tir.Schedule(rowsum, debug_mask="all")
+    _, k = s.get_loops(s.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.bind(k, "blockIdx.x")
+
+
+def test_bind_after_parallel():

Review comment:
       This test is a bit tricky, since `parallel` is usually for CPU while 
`bind` is for GPU. 
   I agree that we do not need to check it, but please remove it from the 
test cases, as it may confuse users.

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -444,6 +444,233 @@ def after_split(a: ty.handle, b: ty.handle) -> None:
 
     ########## Schedule: Manipulate ForKind ##########
 
+    def parallel(self, loop: LoopRV) -> None:
+        """Parallelize the input loop. It requires:
+        1) The scope block that the loop is in should have stage-pipeline 
property
+        2) All the blocks under the loop are complete blocks or reduction 
blocks, and have affine
+        bindings
+        3) For each block under the loop, the loop can only be contained in 
data-parallel block
+        iters' bindings
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop to be parallelized
+
+        Examples
+        --------
+
+        Before parallel, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_parallel(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, i)
+                        tir.bind(vj, j)
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do parallel:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_parallel)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.parallel(i)
+
+        After applying parallel, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_parallel(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i in tir.parallel(0, 128):
+                    for j in tir.serial(0, 128):
+                        with tir.block([128, 128], "B") as [vi, vj]:
+                            tir.bind(vi, i)
+                            tir.bind(vj, j)
+                            B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        _ffi_api.ScheduleParallel(self, loop)  # type: ignore # pylint: 
disable=no-member
+
+    def vectorize(self, loop: LoopRV) -> None:
+        """Vectorize the input loop. It requires:
+        1) The scope block that the loop is in should have stage-pipeline 
property
+        2) All the blocks under the loop are complete blocks or reduction 
blocks, and have affine
+        bindings
+        3) For each block under the loop, the loop can only be contained in 
data-parallel block
+        iters' bindings
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop to be vectorized
+
+        Examples
+        --------
+
+        Before vectorize, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_vectorize(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, i)
+                        tir.bind(vj, j)
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do vectorize:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_vectorize)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.vectorize(j)
+
+        After applying vectorize, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_vectorize(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i in tir.serial(0, 128):
+                    for j in tir.vectorized(0, 128):
+                        with tir.block([128, 128], "B") as [vi, vj]:
+                            tir.bind(vi, i)
+                            tir.bind(vj, j)
+                            B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        _ffi_api.ScheduleVectorize(self, loop)  # type: ignore # pylint: 
disable=no-member
+
+    def bind(self, loop: LoopRV, thread_axis: str) -> None:
+        """Bind the input loop to the given thread axis. It requires:
+        1) The scope block that the loop is in should have stage-pipeline 
property
+        2) All the blocks under the loop are complete blocks or reduction 
blocks, and have affine
+        bindings
+        3) For each block under the loop, if the thread axis starts with 
"threadIdx", the loop can
+        only be contained in data-parallel block iters' and reduction block 
iters' bindings. Otherwise,
+        the loop can only be contained in data-parallel block iters' bindings
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop to be bound to the thread axis
+        thread_axis : str
+            The thread axis to be bound to the loop. Possible candidates:
+            - blockIdx.x/y/z
+            - threadIdx.x/y/z
+            - vthread
+            - vthread.x/y/z

Review comment:
       ```suggestion
               - vthread.x/y/z (not supported now)
   ```

##########
File path: src/tir/schedule/concrete_schedule.cc
##########
@@ -345,7 +346,40 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& 
loop_rv,
   return CreateRV<LoopRV>(results);
 }
 
-/******** Schedule: compute location ********/
+/******** Schedule: Manipulate ForKind ********/
+
+void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::Parallel(state_, this->GetSRef(loop_rv));
+  this->state_->DebugVerify();
+  TVM_TIR_SCHEDULE_END("parallel", this->error_render_level_);
+}
+
+void ConcreteScheduleNode::Vectorize(const LoopRV& loop_rv) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::Vectorize(state_, this->GetSRef(loop_rv));
+  this->state_->DebugVerify();
+  TVM_TIR_SCHEDULE_END("vectorize", this->error_render_level_);
+}
+
+void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& 
thread_axis) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::Bind(state_, this->GetSRef(loop_rv),
+            IterVar(/*dom=*/Range(nullptr), /*var=*/Var(thread_axis), 
/*iter_type=*/kThreadIndex,

Review comment:
       It's OK for now. But it would be great to make `range` optional in the 
future.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to