This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 10fb8c52d9 [MetaSchedule] Introduce Async Pipeline in MultiLevelTiling 
(#14009)
10fb8c52d9 is described below

commit 10fb8c52d9de47693dc45c3013666bad0547f52b
Author: Tian Xia <[email protected]>
AuthorDate: Sat Feb 25 22:21:21 2023 +0800

    [MetaSchedule] Introduce Async Pipeline in MultiLevelTiling (#14009)
    
    This PR introduces an async pipeline into TVM's MultiLevelTiling schedule
    rule. It builds on apache/tvm#13966 (already merged), because some conv2d
    workloads use `tir.if_then_else` to pad the input to the correct size, and
    this PR also applies async copy to such copy statements.
    
    1. Add a subrule in `src/meta_schedule/schedule_rule/multi_level_tiling.h/.cc`
    that annotates async copy for MultiLevelTiling on supported architectures
    (sm_80 and newer).
    
    On CUDA cores, this PR gives a performance boost of around 1 TFLOP/s
    (1,000 GFLOP/s) in most Conv2d test cases and 1-2 TFLOP/s in most GEMM
    test cases.
    
    All generated code, scripts, and traces are available at
    https://github.com/Rainy-Memory/tvm-async-rule-benchmark.
    
    Tested at commit `afbfb7aa7e43732cb716f8e443df696110be6afc` on the conv2d
    NHWC workload, with an RTX 3080 GPU.
    
    **Notice: given the stochastic nature of evolutionary search, performance
    might become worse when this PR is enabled.**
    
    Workload: Conv2d NHWC
    
    |Shape|Mainline TVM|Mainline TVM with Async|Performance Boost|
    |-|-|-|-|
    
|N=1_H=224_W=224_C=3_K=64_R=7_S=7_STR=2_PAD=3_DIL=1|13838.05219|14687.89452|6.141343581679319%|
    
|N=1_H=56_W=56_C=64_K=64_R=1_S=1_STR=1_PAD=0_DIL=1|5398.305085|5613.892553|3.9936140067192905%|
    
|N=1_H=56_W=56_C=64_K=64_R=3_S=3_STR=1_PAD=1_DIL=1|11652.96825|13157.88249|12.91442839038028%|
    
|N=1_H=56_W=56_C=64_K=256_R=1_S=1_STR=1_PAD=0_DIL=1|10638.8309|11674.68499|9.736540600527816%|
    
|N=1_H=56_W=56_C=256_K=64_R=1_S=1_STR=1_PAD=0_DIL=1|8692.32829|9469.264089|8.938178277203573%|
    
|N=1_H=56_W=56_C=256_K=128_R=1_S=1_STR=2_PAD=0_DIL=1|4685.767442|5698.19634|21.606469175684712%|
    
|N=1_H=28_W=28_C=128_K=128_R=3_S=3_STR=1_PAD=1_DIL=1|9872.787087|10404.60405|5.38669535070061%|
    
|N=1_H=28_W=28_C=128_K=512_R=1_S=1_STR=1_PAD=0_DIL=1|9974.281496|10073.31657|0.9929043414276753%|
    
|N=1_H=28_W=28_C=512_K=128_R=1_S=1_STR=1_PAD=0_DIL=1|7075.866932|8564.572712|21.039199780135142%|
    
|N=1_H=28_W=28_C=512_K=256_R=1_S=1_STR=2_PAD=0_DIL=1|3648.330914|4021.923142|10.240086132713124%|
    
|N=1_H=14_W=14_C=256_K=256_R=3_S=3_STR=1_PAD=1_DIL=1|8192.954618|9160.182054|11.805599824451525%|
    
|N=1_H=14_W=14_C=256_K=1024_R=1_S=1_STR=1_PAD=0_DIL=1|8008.870153|9362.825279|16.90569456283206%|
    
|N=1_H=14_W=14_C=1024_K=256_R=1_S=1_STR=1_PAD=0_DIL=1|5210.062241|6051.208379|16.144646629759908%|
    
|N=1_H=14_W=14_C=1024_K=512_R=1_S=1_STR=2_PAD=0_DIL=1|2550.787202|3587.902938|40.65865373586739%|
    
|N=1_H=7_W=7_C=512_K=512_R=3_S=3_STR=1_PAD=1_DIL=1|4350.626084|5432.788068|24.873706981617943%|
    
|N=1_H=7_W=7_C=512_K=2048_R=1_S=1_STR=1_PAD=0_DIL=1|6672.068026|7663.725217|14.862815953549454%|
    
|N=1_H=7_W=7_C=2048_K=512_R=1_S=1_STR=1_PAD=0_DIL=1|3142.564263|4297.988014|36.766909259541826%|
    
    Workload: GEMM NN
    
    |Shape|Mainline TVM|Mainline TVM with Async|Performance Boost|
    |-|-|-|-|
    |M=512_N=256_K=640|8678.46|10607.37|22.226408832903555%|
    |M=512_N=384_K=256|8109.13|10290.72|26.902886006267003%|
    |M=512_N=512_K=512|11419.83|14000.86|22.601299669084398%|
    |M=512_N=3072_K=768|19709.39|18351.61|-6.8890006235606425%|
    |M=512_N=768_K=3072|12844.59|13730.88|6.90010346768561%|
    |M=896_N=896_K=896|16149.91|16131.39|-0.11467556165947945%|
    |M=1024_N=1024_K=1024|18842.11|19662.8|4.355616223448428%|
    |M=1152_N=1152_K=1152|15386.79|16736.1|8.769275462913303%|
    |M=1536_N=1536_K=1536|18522.67|18872.06|1.88628313304725%|
    |M=2048_N=2048_K=2048|19515.42|18874.85|-3.282378754851291%|
    |M=3072_N=3072_K=3072|19233.9|19291.42|0.2990553137948975%|
    |M=4096_N=4096_K=4096|17122.17|19259.01|12.479960191961652%|
---
 .../schedule_rule/multi_level_tiling.cc            |  56 ++++
 .../schedule_rule/multi_level_tiling.h             |   4 +
 .../test_meta_schedule_schedule_rule_mlt.py        |   6 +-
 .../unittest/test_meta_schedule_space_cuda.py      |   2 +-
 .../test_meta_schedule_space_cuda_async.py         | 340 +++++++++++++++++++++
 .../test_meta_schedule_space_cuda_winograd.py      |   2 +-
 6 files changed, 405 insertions(+), 5 deletions(-)

diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc 
b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 324eedafb9..54407c46c8 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -87,6 +87,23 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const 
TuneContext& context)
       TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined 
in the target";
     }
   }
+  if (Optional<String> opt_sm = 
context->target.value()->GetAttr<String>("arch")) {
+    std::string sm = opt_sm.value();
+    if (support::StartsWith(sm, "sm_")) {
+      sm = sm.substr(3);
+      try {
+        // only sm_80 or higher supports async memcopy
+        if (std::stoi(sm) >= 80) {
+          // only stage = 4 & 5 is tested. all integer that is bigger than 2
+          // is theoretically feasible, but no guarantee for great performance.
+          this->stages.insert(this->stages.end(), {4, 5});
+        }
+      } catch (const std::invalid_argument& e) {
+        LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm
+                     << ". Details: " << e.what();
+      }
+    }
+  }
   logger = context->logger;
 }
 
@@ -115,6 +132,8 @@ std::vector<State> 
MultiLevelTilingNode::ApplySubRules(std::vector<State> states
   states = SubRule(std::move(states), [&](State state) { return 
TileLoopNest(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return 
AddWriteReuse(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return 
AddReadReuse(std::move(state)); });
+  states =
+      SubRule(std::move(states), [&](State state) { return 
AddAsyncPipeline(std::move(state)); });
   return states;
 }
 
@@ -280,6 +299,43 @@ std::vector<State> 
MultiLevelTilingNode::AddReadReuse(State state) const {
   return results;
 }
 
+std::vector<State> MultiLevelTilingNode::AddAsyncPipeline(State state) const {
+  // SubRule 4: keep `state` unchanged and, for every entry of `this->stages`, add one extra
+  // candidate in which the loops of the outermost reduction tile level are fused and
+  // annotated for an asynchronous software pipeline.
+  // `this->stages` stays empty for targets without async memcopy support (arch < sm_80).
+  if (r_indices_.empty() || this->stages.empty()) {
+    return {state};
+  }
+  const auto& r_tiles = state->tiles[r_indices_[0]];
+  if (r_tiles.empty()) {
+    return {state};
+  }
+  // Only the loop nest produced by the default CUDA config is supported
+  // (@see ScheduleRule::DefaultCUDA in src/meta_schedule/schedule_rule/schedule_rule.cc):
+  // the body of the innermost loop of that tile level must be a sequence of exactly 3
+  // for-loops, matching the 3-element pipeline annotations below. Any other shape, including
+  // a body that is not a SeqStmt at all, is left untouched instead of failing a Downcast.
+  tir::StmtSRef r_loop_sref = state->sch->GetSRef(r_tiles.back());
+  const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref);
+  const auto* seq_stmt = r_for_loop->body.as<tir::SeqStmtNode>();
+  if (seq_stmt == nullptr || seq_stmt->seq.size() != 3) {
+    return {state};
+  }
+  for (const tir::Stmt& stmt : seq_stmt->seq) {
+    if (!stmt.as<tir::ForNode>()) {
+      return {state};
+    }
+  }
+
+  std::vector<State> ret;
+  ret.reserve(this->stages.size() + 1);
+  ret.push_back(state);
+  for (int stage : this->stages) {
+    State new_state = state->Copy();
+    LoopRV r_loop_fused = new_state->sch->Fuse(new_state->tiles[r_indices_[0]]);
+    // Statements 0 and 1 (presumably the read-cache copies in the default config) go to
+    // pipeline stage 0; statement 2 goes to stage `stage - 2`.
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage,
+                             Array<Integer>{0, 0, stage - 2});
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order,
+                             Array<Integer>{0, 1, 2});
+    // Mark pipeline stage 0 as asynchronous.
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_async_stages,
+                             Array<Integer>{0});
+    ret.push_back(std::move(new_state));
+  }
+  return ret;
+}
+
 void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
                                                        const tir::BlockRV& 
block) const {
   // Filter out invalid vector lanes according to the data type.
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h 
b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index d8725a3060..ff38756ff0 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -148,6 +148,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   std::vector<State> TileLoopNest(State state) const;
   // SubRule 3. add read cache
   std::vector<State> AddReadReuse(State state) const;
+  // SubRule 4. add async pipeline
+  std::vector<State> AddAsyncPipeline(State state) const;
 
   // Do nothing; Inherited from ScheduleRuleNode
   void InitializeWithTuneContext(const TuneContext& context) final;
@@ -192,6 +194,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   int thread_warp_size_;
   /*! \brief The maximum number of threads to be used size of a thread warp */
   int max_threads_per_block_;
+  /*! \brief All available async pipeline stages. */
+  std::vector<int> stages;
   /*! \brief The logging function */
   PackedFunc logger;
   /*! \brief The function to overwrite the default condition for applying 
MultiLevelTiling. */
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py 
b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
index 66eb819122..497915cd65 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
@@ -365,7 +365,7 @@ def test_cuda_matmul():
     actual = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using 
sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
@@ -483,7 +483,7 @@ def test_cuda_matmul_relu():
     actual = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using 
sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
@@ -723,7 +723,7 @@ def test_cache_read_specify_consumer():
     space = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using 
sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py 
b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 241fe63e1d..bc674064d1 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -27,7 +27,7 @@ from tvm.target import Target
 
 
 def _target():
-    return Target("nvidia/geforce-rtx-3070")
+    # geforce-rtx-2080 is sm_75. MultiLevelTiling only adds async-pipeline sketches for
+    # sm_80 and newer, so the design spaces checked here should contain no async annotations.
+    return Target("nvidia/geforce-rtx-2080")
 
 
 def _design_space(mod):
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_async.py 
b/tests/python/unittest/test_meta_schedule_space_cuda_async.py
new file mode 100644
index 0000000000..d31d626696
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_space_cuda_async.py
@@ -0,0 +1,340 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for MetaSchedule search space on CUDA"""
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.space_generation import (
+    check_sketches,
+    generate_design_space,
+    print_sketches,
+)
+from tvm.meta_schedule.testing.te_workload import create_te_workload
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def _target():
+    # The expected sketches in this module include async-pipeline variants, which
+    # MultiLevelTiling only emits when the target `arch` is sm_80 or newer.
+    target_tag = "nvidia/geforce-rtx-3070"
+    return Target(target_tag)
+
+
+def _design_space(mod):
+    """Build the CUDA design space for `mod` on `_target()` with `types=ms.ScheduleRule`."""
+    options = {
+        "kind": "cuda",
+        "mod": mod,
+        "target": _target(),
+        "types": ms.ScheduleRule,
+    }
+    return generate_design_space(**options)
+
+
+def get_c2d_prim_func(stage: int):
+    """Return the expected C2D sketch used by `test_cuda_c2d`.
+
+    `stage == 0` gives the sketch without software pipelining. Any other value gives the
+    sketch whose fused reduction loop `rh_0_rw_0_rc_0_fused` carries the async-pipeline
+    annotations, with `software_pipeline_stage = [0, 0, stage - 2]`.
+    """
+    if stage == 0:
+        # Sketch without an async pipeline: rh_0 / rw_0 / rc_0 stay separate loops.
+        # fmt: off
+        @T.prim_func
+        def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                T.block_attr({"meta_schedule.unroll_explicit": 1024})
+                conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 64), scope="local")
+                PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared")
+                weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared")
+                for n_0_h_0_w_0_co_0_fused in T.thread_binding(112, thread="blockIdx.x"):
+                    for n_1_h_1_w_1_co_1_fused in T.thread_binding(8, thread="vthread.x"):
+                        for n_2_h_2_w_2_co_2_fused in T.thread_binding(64, thread="threadIdx.x"):
+                            for rh_0, rw_0, rc_0 in T.grid(1, 1, 3):
+                                for ax0_ax1_ax2_ax3_fused in range(693):
+                                    with T.block("PadInput_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused // 8 * 16 + ax0_ax1_ax2_ax3_fused // 33)
+                                        v2 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused % 8 * 28 + ax0_ax1_ax2_ax3_fused % 33)
+                                        v3 = T.axis.spatial(3, rc_0)
+                                        T.reads(inputs[v0, v1 - 3, v2 - 3, v3])
+                                        T.writes(PadInput_shared[v0, v1, v2, v3])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0))
+                                for ax0_ax1_ax2_ax3_fused in range(3136):
+                                    with T.block("weight_shared"):
+                                        v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 448)
+                                        v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused % 448 // 64)
+                                        v2 = T.axis.spatial(3, rc_0)
+                                        v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
+                                        T.reads(weight[v0, v1, v2, v3])
+                                        T.writes(weight_shared[v0, v1, v2, v3])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 3})
+                                        weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
+                                for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(7, 1, 1, 1, 1, 14, 1, 1, 7, 1, 1, 1, 1, 1):
+                                    with T.block("conv2d_nhwc"):
+                                        v_n = T.axis.spatial(1, n_3 + n_4)
+                                        v_h = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + h_3 + h_4)
+                                        v_w = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + w_3 + w_4)
+                                        v_co = T.axis.spatial(64, co_3 + co_4 + n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16)
+                                        v_rh = T.axis.reduce(7, rh_0 * 7 + rh_1 + rh_2)
+                                        v_rw = T.axis.reduce(7, rw_0 * 7 + rw_1 * 7 + rw_2)
+                                        v_rc = T.axis.reduce(3, rc_1 + rc_2 + rc_0)
+                                        T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co])
+                                        T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co])
+                                        T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
+                                        with T.init():
+                                            conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0)
+                                        conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co]
+                            for ax0, ax1, ax2, ax3 in T.grid(1, 1, 14, 1):
+                                with T.block("conv2d_nhwc_local"):
+                                    v0 = T.axis.spatial(1, ax0)
+                                    v1 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + ax1)
+                                    v2 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + ax2)
+                                    v3 = T.axis.spatial(64, n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + ax3)
+                                    T.reads(conv2d_nhwc_local[v0, v1, v2, v3])
+                                    T.writes(conv2d_nhwc[v0, v1, v2, v3])
+                                    conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3]
+
+        # fmt: on
+    else:
+        # Async-pipeline sketch: the reduction tile loops are fused into
+        # rh_0_rw_0_rc_0_fused, which carries the software_pipeline_* annotations.
+        # fmt: off
+        @T.prim_func
+        def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                T.block_attr({"meta_schedule.unroll_explicit": 1024})
+                conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 64), scope="local")
+                PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared")
+                weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared")
+                for n_0_h_0_w_0_co_0_fused in T.thread_binding(112, thread="blockIdx.x"):
+                    for n_1_h_1_w_1_co_1_fused in T.thread_binding(8, thread="vthread.x"):
+                        for n_2_h_2_w_2_co_2_fused in T.thread_binding(64, thread="threadIdx.x"):
+                            for rh_0_rw_0_rc_0_fused in T.serial(3, annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, stage - 2]}):
+                                for ax0_ax1_ax2_ax3_fused in range(693):
+                                    with T.block("PadInput_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused // 8 * 16 + ax0_ax1_ax2_ax3_fused // 33)
+                                        v2 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused % 8 * 28 + ax0_ax1_ax2_ax3_fused % 33)
+                                        v3 = T.axis.spatial(3, rh_0_rw_0_rc_0_fused)
+                                        T.reads(inputs[v0, v1 - 3, v2 - 3, v3])
+                                        T.writes(PadInput_shared[v0, v1, v2, v3])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0))
+                                for ax0_ax1_ax2_ax3_fused in range(3136):
+                                    with T.block("weight_shared"):
+                                        v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 448)
+                                        v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused % 448 // 64)
+                                        v2 = T.axis.spatial(3, rh_0_rw_0_rc_0_fused)
+                                        v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
+                                        T.reads(weight[v0, v1, v2, v3])
+                                        T.writes(weight_shared[v0, v1, v2, v3])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 3})
+                                        weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
+                                for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(7, 1, 1, 1, 1, 14, 1, 1, 7, 1, 1, 1, 1, 1):
+                                    with T.block("conv2d_nhwc"):
+                                        v_n = T.axis.spatial(1, n_4 + n_3)
+                                        v_h = T.axis.spatial(112, h_3 + h_4 + n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16)
+                                        v_w = T.axis.spatial(112, w_4 + n_0_h_0_w_0_co_0_fused % 8 * 14 + w_3)
+                                        v_co = T.axis.spatial(64, co_4 + n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + co_3)
+                                        v_rh = T.axis.reduce(7, rh_2 + rh_1)
+                                        v_rw = T.axis.reduce(7, rw_1 * 7 + rw_2)
+                                        v_rc = T.axis.reduce(3, rc_2 + rh_0_rw_0_rc_0_fused + rc_1)
+                                        T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co])
+                                        T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co])
+                                        T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
+                                        with T.init():
+                                            conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0)
+                                        conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co]
+                            for ax0, ax1, ax2, ax3 in T.grid(1, 1, 14, 1):
+                                with T.block("conv2d_nhwc_local"):
+                                    v0 = T.axis.spatial(1, ax0)
+                                    v1 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + ax1)
+                                    v2 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + ax2)
+                                    v3 = T.axis.spatial(64, n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + ax3)
+                                    T.reads(conv2d_nhwc_local[v0, v1, v2, v3])
+                                    T.writes(conv2d_nhwc[v0, v1, v2, v3])
+                                    conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3]
+        # fmt: on
+    return c2d
+
+
+def test_cuda_c2d():
+    """C2D on an async-capable CUDA target yields one plain sketch plus one per pipeline stage."""
+    # All three sketches share the same sampling decisions; only the pipeline annotation differs.
+    decisions = [
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+        ("SamplePerfectTile", [14, 2, 4, 1, 1]),
+        ("SamplePerfectTile", [8, 1, 1, 14, 1]),
+        ("SamplePerfectTile", [1, 4, 16, 1, 1]),
+        ("SamplePerfectTile", [1, 7, 1]),
+        ("SamplePerfectTile", [1, 1, 7]),
+        ("SamplePerfectTile", [3, 1, 1]),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 4),
+    ]
+    stages = (0, 4, 5)  # 0 selects the sketch without an async pipeline
+
+    workload = create_te_workload("C2D", 0)
+    check_sketches(
+        workload,
+        sketches=_design_space(workload),
+        expected_mods=[get_c2d_prim_func(stage=s) for s in stages],
+        expected_decisions=[decisions] * len(stages),
+    )
+
+
+def get_gmm_prim_func(stage: int):
+    """Return the expected batched-matmul (GMM) sketch used by `test_cuda_gmm`.
+
+    `stage == 0` gives the sketch without software pipelining. Any other value gives the
+    sketch whose reduction loop `k_0_fused` carries the async-pipeline annotations, with
+    `software_pipeline_stage = [0, 0, stage - 2]`.
+    """
+    if stage == 0:
+        # Sketch without an async pipeline: plain k_0 reduction loop.
+        # fmt: off
+        @T.prim_func
+        def gmm(A: T.Buffer((1, 1024, 1024), "float32"), B: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                T.block_attr({"meta_schedule.unroll_explicit": 16})
+                Y_local = T.alloc_buffer((1, 1024, 1024), scope="local")
+                A_shared = T.alloc_buffer((1, 1024, 1024), scope="shared")
+                B_shared = T.alloc_buffer((1, 1024, 1024), scope="shared")
+                for b_0_i_0_j_0_fused in T.thread_binding(256, thread="blockIdx.x"):
+                    for b_1_i_1_j_1_fused in T.thread_binding(32, thread="vthread.x"):
+                        for b_2_i_2_j_2_fused in T.thread_binding(64, thread="threadIdx.x"):
+                            for k_0 in range(64):
+                                for ax0_ax1_ax2_fused in range(1024):
+                                    with T.block("A_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + ax0_ax1_ax2_fused // 16)
+                                        v2 = T.axis.spatial(1024, k_0 * 16 + ax0_ax1_ax2_fused % 16)
+                                        T.reads(A[v0, v1, v2])
+                                        T.writes(A_shared[v0, v1, v2])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        A_shared[v0, v1, v2] = A[v0, v1, v2]
+                                for ax0_ax1_ax2_fused in range(1024):
+                                    with T.block("B_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(1024, k_0 * 16 + ax0_ax1_ax2_fused // 64)
+                                        v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + ax0_ax1_ax2_fused % 64)
+                                        T.reads(B[v0, v1, v2])
+                                        T.writes(B_shared[v0, v1, v2])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        B_shared[v0, v1, v2] = B[v0, v1, v2]
+                                for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(2, 1, 1, 1, 8, 1, 1, 2):
+                                    with T.block("Y"):
+                                        v_b = T.axis.spatial(1, b_4 + b_3)
+                                        v_i = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + i_3 + i_4)
+                                        v_j = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + j_3 * 2 + j_4)
+                                        v_k = T.axis.reduce(1024, k_0 * 16 + k_1 * 8 + k_2)
+                                        T.reads(A_shared[v_b, v_i, v_k], B_shared[v_b, v_k, v_j])
+                                        T.writes(Y_local[v_b, v_i, v_j])
+                                        T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
+                                        with T.init():
+                                            Y_local[v_b, v_i, v_j] = T.float32(0)
+                                        Y_local[v_b, v_i, v_j] = Y_local[v_b, v_i, v_j] + A_shared[v_b, v_i, v_k] * B_shared[v_b, v_k, v_j]
+                            for ax0, ax1, ax2 in T.grid(1, 1, 2):
+                                with T.block("Y_local"):
+                                    v0 = T.axis.spatial(1, ax0)
+                                    v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + ax1)
+                                    v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + ax2)
+                                    T.reads(Y_local[v0, v1, v2])
+                                    T.writes(Y[v0, v1, v2])
+                                    Y[v0, v1, v2] = Y_local[v0, v1, v2]
+
+        # fmt: on
+    else:
+        # Async-pipeline sketch: the reduction tile loop becomes k_0_fused, which carries
+        # the software_pipeline_* annotations.
+        # fmt: off
+        @T.prim_func
+        def gmm(A: T.Buffer((1, 1024, 1024), "float32"), B: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                T.block_attr({"meta_schedule.unroll_explicit": 16})
+                Y_local = T.alloc_buffer((1, 1024, 1024), scope="local")
+                A_shared = T.alloc_buffer((1, 1024, 1024), scope="shared")
+                B_shared = T.alloc_buffer((1, 1024, 1024), scope="shared")
+                for b_0_i_0_j_0_fused in T.thread_binding(256, thread="blockIdx.x"):
+                    for b_1_i_1_j_1_fused in T.thread_binding(32, thread="vthread.x"):
+                        for b_2_i_2_j_2_fused in T.thread_binding(64, thread="threadIdx.x"):
+                            for k_0_fused in T.serial(64, annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, stage - 2]}):
+                                for ax0_ax1_ax2_fused in range(1024):
+                                    with T.block("A_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + ax0_ax1_ax2_fused // 16)
+                                        v2 = T.axis.spatial(1024, k_0_fused * 16 + ax0_ax1_ax2_fused % 16)
+                                        T.reads(A[v0, v1, v2])
+                                        T.writes(A_shared[v0, v1, v2])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        A_shared[v0, v1, v2] = A[v0, v1, v2]
+                                for ax0_ax1_ax2_fused in range(1024):
+                                    with T.block("B_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(1024, k_0_fused * 16 + ax0_ax1_ax2_fused // 64)
+                                        v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + ax0_ax1_ax2_fused % 64)
+                                        T.reads(B[v0, v1, v2])
+                                        T.writes(B_shared[v0, v1, v2])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        B_shared[v0, v1, v2] = B[v0, v1, v2]
+                                for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(2, 1, 1, 1, 8, 1, 1, 2):
+                                    with T.block("Y"):
+                                        v_b = T.axis.spatial(1, b_3 + b_4)
+                                        v_i = T.axis.spatial(1024, i_3 + i_4 + b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8)
+                                        v_j = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + j_3 * 2 + j_4)
+                                        v_k = T.axis.reduce(1024, k_0_fused * 16 + k_1 * 8 + k_2)
+                                        T.reads(A_shared[v_b, v_i, v_k], B_shared[v_b, v_k, v_j])
+                                        T.writes(Y_local[v_b, v_i, v_j])
+                                        T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
+                                        with T.init():
+                                            Y_local[v_b, v_i, v_j] = T.float32(0)
+                                        Y_local[v_b, v_i, v_j] = Y_local[v_b, v_i, v_j] + A_shared[v_b, v_i, v_k] * B_shared[v_b, v_k, v_j]
+                            for ax0, ax1, ax2 in T.grid(1, 1, 2):
+                                with T.block("Y_local"):
+                                    v0 = T.axis.spatial(1, ax0)
+                                    v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + ax1)
+                                    v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + ax2)
+                                    T.reads(Y_local[v0, v1, v2])
+                                    T.writes(Y[v0, v1, v2])
+                                    Y[v0, v1, v2] = Y_local[v0, v1, v2]
+
+        # fmt: on
+    return gmm
+
+
+def test_cuda_gmm():  # Checks the CUDA design space generated for the "GMM" TE workload against three expected sketches.
+    gmm_decision = [  # Sampling decisions expected for each sketch (the same list is reused for all three).
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),  # b loop: 5 spatial tiles, product 1
+        ("SamplePerfectTile", [16, 8, 8, 1, 1]),  # i loop: 5 spatial tiles, product 1024
+        ("SamplePerfectTile", [16, 4, 8, 1, 2]),  # j loop: 5 spatial tiles, product 1024
+        ("SamplePerfectTile", [64, 2, 8]),  # k (reduction) loop: 3 tiles, product 1024
+        ("SampleCategorical", 3),  # NOTE(review): presumably a cooperative-fetch choice for A_shared -- confirm
+        ("SampleCategorical", 3),  # NOTE(review): presumably a cooperative-fetch choice for B_shared -- confirm
+        ("SampleCategorical", 1),  # NOTE(review): meaning not visible in this chunk -- confirm
+    ]
+
+    mod = create_te_workload("GMM", 3)
+    actual = _design_space(mod)
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[  # one expected module per sketch; they differ only in the `stage` argument
+            get_gmm_prim_func(stage=0),  # NOTE(review): presumably the variant without async pipelining -- confirm
+            get_gmm_prim_func(stage=4),  # pipeline stage annotation uses stage - 2 == 2
+            get_gmm_prim_func(stage=5),  # pipeline stage annotation uses stage - 2 == 3
+        ],
+        expected_decisions=[gmm_decision, gmm_decision, gmm_decision],  # identical decisions for every sketch
+    )
+
+
+if __name__ == "__main__":  # Allows running these tests by executing the file directly as a script.
+    test_cuda_c2d()
+    test_cuda_gmm()
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py
index 87a8fcac98..e8ed3bb8b2 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py
@@ -27,7 +27,7 @@ from tvm.target import Target
 
 
 def _target():
-    return Target("nvidia/geforce-rtx-3070")
+    return Target("nvidia/geforce-rtx-2080")  # sm75 target: the async-copy annotation is only applied on sm80+, so async pipelining stays out of these traces
 
 
 def _design_space(mod):


Reply via email to