This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 26ad703338 [MetaSchedule] Handle 'warp_execution' in 
RewriteCooperativeFetch (#11955)
26ad703338 is described below

commit 26ad70333875c55eec438b840d004a5fb9255572
Author: Wuwei Lin <[email protected]>
AuthorDate: Thu Jun 30 16:16:58 2022 -0700

    [MetaSchedule] Handle 'warp_execution' in RewriteCooperativeFetch (#11955)
    
    Updated `RewriteCooperativeFetch` to handle 'warp_execution' annotation 
when the extent of `threadIdx.x` is not specified
---
 .../postproc/rewrite_cooperative_fetch.cc          |  33 ++++-
 ..._schedule_postproc_rewrite_cooperative_fetch.py | 151 ++++++++++++++++++++-
 2 files changed, 182 insertions(+), 2 deletions(-)

diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc 
b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
index 798f00423f..d111bdb42a 100644
--- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
+++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
@@ -65,6 +65,23 @@ Optional<BlockRV> ParseAnnotate(const Schedule& sch, const 
Instruction& inst,
   return Downcast<BlockRV>(inst->inputs[0]);
 }
 
+/*!
+ * \brief Parse instruction: sch.annotate(..., attr::warp_execution)
+ * \param sch The schedule
+ * \param inst The instruction to be parsed
+ * \return Whether the parsing is successful
+ */
+bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) {
+  static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate");
+  if (!inst->kind.same_as(inst_kind_annotate)) {
+    return false;
+  }
+  ICHECK_EQ(inst->inputs.size(), 2);
+  ICHECK_EQ(inst->attrs.size(), 1);
+  String ann_key = Downcast<String>(inst->attrs[0]);
+  return ann_key == attr::warp_execution;
+}
+
 }  // namespace tir
 
 namespace meta_schedule {
@@ -76,7 +93,14 @@ namespace meta_schedule {
 class RewriteCooperativeFetchNode : public PostprocNode {
  public:
   // Inherited from PostprocNode
-  void InitializeWithTuneContext(const TuneContext& context) final {}
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    if (Optional<Integer> v = 
context->target.value()->GetAttr<Integer>("thread_warp_size")) {
+      this->thread_warp_size_ = v.value()->value;
+    } else {
+      TVM_PY_LOG(INFO, context->logging_func) << "'thread_warp_size' is not 
defined in the target";
+    }
+  }
+
   // Inherited from PostprocNode
   bool Apply(const tir::Schedule& sch) final;
 
@@ -84,6 +108,9 @@ class RewriteCooperativeFetchNode : public PostprocNode {
 
   static constexpr const char* _type_key = 
"meta_schedule.RewriteCooperativeFetch";
   TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode);
+
+ private:
+  int thread_warp_size_ = -1;
 };
 
 bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
@@ -101,6 +128,10 @@ bool RewriteCooperativeFetchNode::Apply(const 
tir::Schedule& sch) {
       thread_extent_y = new_thread_extent.value()->value;
       continue;
     }
+    if (tir::ParseWarpExecutionAnn(sch, inst)) {
+      thread_extent_x = thread_warp_size_;
+      continue;
+    }
     Optional<tir::BlockRV> opt_block_rv = tir::ParseAnnotate(sch, inst, 
&vector_lane);
     if (!opt_block_rv.defined()) {
       continue;
diff --git 
a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py
 
b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py
index 5460c59009..e55f693e72 100644
--- 
a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py
+++ 
b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py
@@ -17,6 +17,7 @@
 # pylint: 
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 
 import tvm
+import tvm.testing
 from tvm import tir
 from tvm.meta_schedule import TuneContext
 from tvm.meta_schedule.postproc import RewriteCooperativeFetch
@@ -99,6 +100,108 @@ class AfterRewrite0:
                             C[v0, v1] = C_local[v0, v1]
 
 
+@tvm.script.ir_module
+class WarpExecutionAfterRewrite:
+    @T.prim_func
+    def main(
+        A: T.Buffer[(512, 512), "float32"],
+        B: T.Buffer[(512, 512), "float32"],
+        C: T.Buffer[(512, 512), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
+        A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
+        B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
+        for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
+            for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
+                for i0_2_i1_2_fused in T.thread_binding(0, 8, 
thread="threadIdx.y"):
+                    for i2_0 in T.serial(0, 1):
+                        for ax0_ax1_fused_0 in T.serial(0, 1024):
+                            for ax0_ax1_fused_1 in T.thread_binding(0, 8, 
thread="threadIdx.y"):
+                                for ax0_ax1_fused_2 in T.thread_binding(
+                                    0, 32, thread="threadIdx.x"
+                                ):
+                                    with T.block("A_shared"):
+                                        v0 = T.axis.spatial(
+                                            512,
+                                            (
+                                                ax0_ax1_fused_0 * 256
+                                                + ax0_ax1_fused_1 * 32
+                                                + ax0_ax1_fused_2
+                                            )
+                                            // 512,
+                                        )
+                                        v1 = T.axis.spatial(
+                                            512,
+                                            (
+                                                ax0_ax1_fused_0 * 256
+                                                + ax0_ax1_fused_1 * 32
+                                                + ax0_ax1_fused_2
+                                            )
+                                            % 512,
+                                        )
+                                        T.reads([A[v0, v1]])
+                                        T.writes([A_shared[v0, v1]])
+                                        A_shared[v0, v1] = A[v0, v1]
+                        for ax0_ax1_fused_0 in T.serial(0, 32):
+                            for ax0_ax1_fused_1 in T.thread_binding(0, 8, 
thread="threadIdx.y"):
+                                for ax0_ax1_fused_2 in T.thread_binding(
+                                    0, 32, thread="threadIdx.x"
+                                ):
+                                    for ax0_ax1_fused_3 in T.vectorized(0, 2):
+                                        with T.block("B_shared"):
+                                            v0 = T.axis.spatial(
+                                                512,
+                                                (
+                                                    ax0_ax1_fused_0 * 512
+                                                    + ax0_ax1_fused_1 * 64
+                                                    + ax0_ax1_fused_2 * 2
+                                                    + ax0_ax1_fused_3
+                                                )
+                                                // 32,
+                                            )
+                                            v1 = T.axis.spatial(
+                                                512,
+                                                i0_0_i1_0_fused * 32
+                                                + (
+                                                    ax0_ax1_fused_0 * 512
+                                                    + ax0_ax1_fused_1 * 64
+                                                    + ax0_ax1_fused_2 * 2
+                                                    + ax0_ax1_fused_3
+                                                )
+                                                % 32,
+                                            )
+                                            T.reads([B[v0, v1]])
+                                            T.writes([B_shared[v0, v1]])
+                                            B_shared[v0, v1] = B[v0, v1]
+                        for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 
2, 2, 32, 16, 2):
+                            with T.block("C"):
+                                i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + 
i0_3 * 16 + i0_4)
+                                j = T.axis.spatial(
+                                    512,
+                                    i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 
+ i1_3 * 2 + i1_4,
+                                )
+                                k = T.axis.reduce(512, i2_0 * 512 + i2_1 * 32 
+ i2_2)
+                                T.reads([A_shared[i, k], B_shared[k, j]])
+                                T.writes([C_local[i, j]])
+                                T.block_attr({"warp_execution": 1})
+                                with T.init():
+                                    C_local[i, j] = T.float32(0)
+                                C_local[i, j] = C_local[i, j] + A_shared[i, k] 
* B_shared[k, j]
+                    for ax0, ax1 in T.grid(32, 4):
+                        with T.block("C_local"):
+                            v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + 
ax0)
+                            v1 = T.axis.spatial(
+                                512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 
4 + ax1
+                            )
+                            T.reads([C_local[v0, v1]])
+                            T.writes([C[v0, v1]])
+                            C[v0, v1] = C_local[v0, v1]
+
+
 # pylint: 
enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
 # fmt: on
 
@@ -147,5 +250,51 @@ def test_rewrite_cooperative_fetch():
     tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0)
 
 
+def test_rewrite_warp_execution():
+    mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512))
+    target = _target()
+    ctx = _create_context(mod, target)
+
+    sch = tir.Schedule(mod, debug_mask="all")
+    # fmt: off
+    # pylint: disable=line-too-long,invalid-name
+    b0 = sch.get_block(name="C", func_name="main")
+    b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")
+    l2, l3, l4 = sch.get_loops(block=b0)
+    sch.annotate(b0, "warp_execution", 1)
+    v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, 
max_innermost_factor=64, decision=[1, 16, 1, 2, 16])
+    l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])
+    v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, 
max_innermost_factor=64, decision=[16, 1, 8, 2, 2])
+    l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, 
v19])
+    v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, 
max_innermost_factor=64, decision=[1, 16, 32])
+    l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])
+    sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, 
l24)
+    l31 = sch.fuse(l10, l20)
+    sch.bind(loop=l31, thread_axis="blockIdx.x")
+    l32 = sch.fuse(l11, l21)
+    sch.bind(loop=l32, thread_axis="vthread.x")
+    l33 = sch.fuse(l12, l22)
+    sch.bind(loop=l33, thread_axis="threadIdx.y")
+    b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")
+    sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)
+    _, _, _, _, l39, l40 = sch.get_loops(block=b34)
+    l41 = sch.fuse(l39, l40)
+    _, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, 
decision=[262144, 1])
+    sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", 
ann_val=v43)
+    b44 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")
+    sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)
+    _, _, _, _, l49, l50 = sch.get_loops(block=b44)
+    l51 = sch.fuse(l49, l50)
+    _, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, 
decision=[8192, 2])
+    sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", 
ann_val=v53)
+    sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)
+    # pylint: enable=line-too-long,invalid-name
+    # fmt: on
+    sch.enter_postproc()
+    assert ctx.postprocs[0].apply(sch)
+    print(sch.mod["main"].script())
+    tvm.ir.assert_structural_equal(sch.mod, WarpExecutionAfterRewrite)
+
+
 if __name__ == "__main__":
-    test_rewrite_cooperative_fetch()
+    tvm.testing.main()

Reply via email to