This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 10fb8c52d9 [MetaSchedule] Introduce Async Pipeline in MultiLevelTiling (#14009)
10fb8c52d9 is described below
commit 10fb8c52d9de47693dc45c3013666bad0547f52b
Author: Tian Xia <[email protected]>
AuthorDate: Sat Feb 25 22:21:21 2023 +0800
[MetaSchedule] Introduce Async Pipeline in MultiLevelTiling (#14009)
This PR introduces an async pipeline into TVM's current MultiLevelTiling rules. It builds on apache/tvm#13966 (already merged), because some conv2d workloads use `tir.if_then_else` to pad the input to the correct size, and this PR applies async copy to exactly that kind of copy statement.
1. Add a subrule in `src/meta_schedule/schedule_rule/multi_level_tiling.h/.cc` that annotates async copy for multi-level tiling on supported architectures (>= sm_80); the sketch below shows the annotations it produces.
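For reference, the following sketch (not part of this patch) shows the equivalent manual annotations the subrule attaches to the fused outer reduction loop. Here `sch` is assumed to be an existing `tir.Schedule` over a multi-level-tiled kernel and `k0` its fused reduction loop; both names are illustrative.

```python
# Hypothetical usage: `sch` (a tir.Schedule) and the loop RV `k0` are assumed
# to come from an existing multi-level-tiled schedule, e.g. of a GEMM.
# For stage = 4, software_pipeline_stage becomes [0, 0, stage - 2] = [0, 0, 2].
sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 2])
# Keep the original statement order: the two shared-memory copies, then compute.
sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
# Mark stage 0 (the shared-memory copies) as asynchronous.
sch.annotate(k0, ann_key="software_pipeline_async_stages", ann_val=[0])
```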
On CUDA cores, this PR delivers a performance boost of roughly 1 TFLOP/s in most Conv2d test cases and 1~2 TFLOP/s in most GEMM test cases.
All generated code, scripts, and traces are available at https://github.com/Rainy-Memory/tvm-async-rule-benchmark.
Currently tested at commit `afbfb7aa7e43732cb716f8e443df696110be6afc` on conv2d NHWC workloads with an RTX 3080 GPU.
**Notice: given the stochastic nature of evolutionary search, performance may regress with this PR enabled.**
Workload: Conv2d NHWC
|Shape|Mainline TVM (GFLOP/s)|Mainline TVM with Async (GFLOP/s)|Performance Boost|
|-|-|-|-|
|N=1_H=224_W=224_C=3_K=64_R=7_S=7_STR=2_PAD=3_DIL=1|13838.05|14687.89|6.14%|
|N=1_H=56_W=56_C=64_K=64_R=1_S=1_STR=1_PAD=0_DIL=1|5398.31|5613.89|3.99%|
|N=1_H=56_W=56_C=64_K=64_R=3_S=3_STR=1_PAD=1_DIL=1|11652.97|13157.88|12.91%|
|N=1_H=56_W=56_C=64_K=256_R=1_S=1_STR=1_PAD=0_DIL=1|10638.83|11674.68|9.74%|
|N=1_H=56_W=56_C=256_K=64_R=1_S=1_STR=1_PAD=0_DIL=1|8692.33|9469.26|8.94%|
|N=1_H=56_W=56_C=256_K=128_R=1_S=1_STR=2_PAD=0_DIL=1|4685.77|5698.20|21.61%|
|N=1_H=28_W=28_C=128_K=128_R=3_S=3_STR=1_PAD=1_DIL=1|9872.79|10404.60|5.39%|
|N=1_H=28_W=28_C=128_K=512_R=1_S=1_STR=1_PAD=0_DIL=1|9974.28|10073.32|0.99%|
|N=1_H=28_W=28_C=512_K=128_R=1_S=1_STR=1_PAD=0_DIL=1|7075.87|8564.57|21.04%|
|N=1_H=28_W=28_C=512_K=256_R=1_S=1_STR=2_PAD=0_DIL=1|3648.33|4021.92|10.24%|
|N=1_H=14_W=14_C=256_K=256_R=3_S=3_STR=1_PAD=1_DIL=1|8192.95|9160.18|11.81%|
|N=1_H=14_W=14_C=256_K=1024_R=1_S=1_STR=1_PAD=0_DIL=1|8008.87|9362.83|16.91%|
|N=1_H=14_W=14_C=1024_K=256_R=1_S=1_STR=1_PAD=0_DIL=1|5210.06|6051.21|16.14%|
|N=1_H=14_W=14_C=1024_K=512_R=1_S=1_STR=2_PAD=0_DIL=1|2550.79|3587.90|40.66%|
|N=1_H=7_W=7_C=512_K=512_R=3_S=3_STR=1_PAD=1_DIL=1|4350.63|5432.79|24.87%|
|N=1_H=7_W=7_C=512_K=2048_R=1_S=1_STR=1_PAD=0_DIL=1|6672.07|7663.73|14.86%|
|N=1_H=7_W=7_C=2048_K=512_R=1_S=1_STR=1_PAD=0_DIL=1|3142.56|4297.99|36.77%|
Workload: GEMM NN
|Shape|Mainline TVM (GFLOP/s)|Mainline TVM with Async (GFLOP/s)|Performance Boost|
|-|-|-|-|
|M=512_N=256_K=640|8678.46|10607.37|22.23%|
|M=512_N=384_K=256|8109.13|10290.72|26.90%|
|M=512_N=512_K=512|11419.83|14000.86|22.60%|
|M=512_N=3072_K=768|19709.39|18351.61|-6.89%|
|M=512_N=768_K=3072|12844.59|13730.88|6.90%|
|M=896_N=896_K=896|16149.91|16131.39|-0.11%|
|M=1024_N=1024_K=1024|18842.11|19662.80|4.36%|
|M=1152_N=1152_K=1152|15386.79|16736.10|8.77%|
|M=1536_N=1536_K=1536|18522.67|18872.06|1.89%|
|M=2048_N=2048_K=2048|19515.42|18874.85|-3.28%|
|M=3072_N=3072_K=3072|19233.90|19291.42|0.30%|
|M=4096_N=4096_K=4096|17122.17|19259.01|12.48%|
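In both tables, throughput is in GFLOP/s and Performance Boost = (Async - Mainline) / Mainline x 100%; for example, for M=512_N=256_K=640: (10607.37 - 8678.46) / 8678.46 ≈ 22.23%.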
---
.../schedule_rule/multi_level_tiling.cc | 56 ++++
.../schedule_rule/multi_level_tiling.h | 4 +
.../test_meta_schedule_schedule_rule_mlt.py | 6 +-
.../unittest/test_meta_schedule_space_cuda.py | 2 +-
.../test_meta_schedule_space_cuda_async.py | 340 +++++++++++++++++++++
.../test_meta_schedule_space_cuda_winograd.py | 2 +-
6 files changed, 405 insertions(+), 5 deletions(-)
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 324eedafb9..54407c46c8 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -87,6 +87,23 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context)
       TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target";
     }
   }
+  if (Optional<String> opt_sm = context->target.value()->GetAttr<String>("arch")) {
+    std::string sm = opt_sm.value();
+    if (support::StartsWith(sm, "sm_")) {
+      sm = sm.substr(3);
+      try {
+        // Only sm_80 or higher supports async memcopy.
+        if (std::stoi(sm) >= 80) {
+          // Only stage = 4 and 5 are tested; any integer greater than 2 is
+          // theoretically feasible, but there is no guarantee of good performance.
+          this->stages.insert(this->stages.end(), {4, 5});
+        }
+      } catch (const std::invalid_argument& e) {
+        LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm
+                     << ". Details: " << e.what();
+      }
+    }
+  }
   logger = context->logger;
 }
@@ -115,6 +132,8 @@ std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states
   states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); });
+  states =
+      SubRule(std::move(states), [&](State state) { return AddAsyncPipeline(std::move(state)); });
   return states;
 }
@@ -280,6 +299,43 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
   return results;
 }
 
+std::vector<State> MultiLevelTilingNode::AddAsyncPipeline(State state) const {
+  // For archs that do not support async pipelines, this->stages will be an empty vector.
+  if (r_indices_.size() < 1 || this->stages.empty()) {
+    return {state};
+  }
+  // Currently we only support the default config used by ScheduleRule::DefaultCUDA.
+  // @see src/meta_schedule/schedule_rule/schedule_rule.cc
+  // Check that the reduce loop contains exactly 3 for loops,
+  // so that it matches the annotation array size in the code below.
+  tir::StmtSRef r_loop_sref = state->sch->GetSRef(state->tiles[r_indices_[0]].back());
+  const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref);
+  Array<tir::Stmt> seq = Downcast<tir::SeqStmt>(r_for_loop->body)->seq;
+  if (seq.size() != 3) {
+    return {state};
+  }
+  for (auto& stmt : seq) {
+    if (!stmt.as<tir::ForNode>()) {
+      return {state};
+    }
+  }
+
+  std::vector<State> ret;
+  ret.push_back(state);
+  for (int stage : this->stages) {
+    State new_state = state->Copy();
+    LoopRV r_loop_fused = new_state->sch->Fuse(new_state->tiles[r_indices_[0]]);
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage,
+                             Array<Integer>{0, 0, stage - 2});
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order,
+                             Array<Integer>{0, 1, 2});
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_async_stages,
+                             Array<Integer>{0});
+    ret.push_back(std::move(new_state));
+  }
+  return ret;
+}
+
 void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
                                                        const tir::BlockRV& block) const {
   // Filter out invalid vector lanes according to the data type.
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index d8725a3060..ff38756ff0 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -148,6 +148,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   std::vector<State> TileLoopNest(State state) const;
   // SubRule 3. add read cache
   std::vector<State> AddReadReuse(State state) const;
+  // SubRule 4. add async pipeline
+  std::vector<State> AddAsyncPipeline(State state) const;
   // Do nothing; Inherited from ScheduleRuleNode
   void InitializeWithTuneContext(const TuneContext& context) final;
@@ -192,6 +194,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   int thread_warp_size_;
   /*! \brief The maximum number of threads to be used size of a thread warp */
   int max_threads_per_block_;
+  /*! \brief All available async pipeline stages. */
+  std::vector<int> stages;
   /*! \brief The logging function */
   PackedFunc logger;
   /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
index 66eb819122..497915cd65 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
@@ -365,7 +365,7 @@ def test_cuda_matmul():
     actual = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
@@ -483,7 +483,7 @@ def test_cuda_matmul_relu():
     actual = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
@@ -723,7 +723,7 @@ def test_cache_read_specify_consumer():
     space = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 241fe63e1d..bc674064d1 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -27,7 +27,7 @@ from tvm.target import Target
 
 
 def _target():
-    return Target("nvidia/geforce-rtx-3070")
+    return Target("nvidia/geforce-rtx-2080")  # disable async trace using sm75
 
 
 def _design_space(mod):
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_async.py b/tests/python/unittest/test_meta_schedule_space_cuda_async.py
new file mode 100644
index 0000000000..d31d626696
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_space_cuda_async.py
@@ -0,0 +1,340 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for MetaSchedule search space on CUDA"""
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.space_generation import (
+ check_sketches,
+ generate_design_space,
+ print_sketches,
+)
+from tvm.meta_schedule.testing.te_workload import create_te_workload
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def _target():
+ return Target("nvidia/geforce-rtx-3070")
+
+
+def _design_space(mod):
+ return generate_design_space(
+ kind="cuda",
+ mod=mod,
+ target=_target(),
+ types=ms.ScheduleRule,
+ )
+
+
+def get_c2d_prim_func(stage: int):
+    if stage == 0:
+        # fmt: off
+        @T.prim_func
+        def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                T.block_attr({"meta_schedule.unroll_explicit": 1024})
+                conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 64), scope="local")
+                PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared")
+                weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared")
+                for n_0_h_0_w_0_co_0_fused in T.thread_binding(112, thread="blockIdx.x"):
+                    for n_1_h_1_w_1_co_1_fused in T.thread_binding(8, thread="vthread.x"):
+                        for n_2_h_2_w_2_co_2_fused in T.thread_binding(64, thread="threadIdx.x"):
+                            for rh_0, rw_0, rc_0 in T.grid(1, 1, 3):
+                                for ax0_ax1_ax2_ax3_fused in range(693):
+                                    with T.block("PadInput_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused // 8 * 16 + ax0_ax1_ax2_ax3_fused // 33)
+                                        v2 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused % 8 * 28 + ax0_ax1_ax2_ax3_fused % 33)
+                                        v3 = T.axis.spatial(3, rc_0)
+                                        T.reads(inputs[v0, v1 - 3, v2 - 3, v3])
+                                        T.writes(PadInput_shared[v0, v1, v2, v3])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0))
+                                for ax0_ax1_ax2_ax3_fused in range(3136):
+                                    with T.block("weight_shared"):
+                                        v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 448)
+                                        v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused % 448 // 64)
+                                        v2 = T.axis.spatial(3, rc_0)
+                                        v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
+                                        T.reads(weight[v0, v1, v2, v3])
+                                        T.writes(weight_shared[v0, v1, v2, v3])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 3})
+                                        weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
+                                for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(7, 1, 1, 1, 1, 14, 1, 1, 7, 1, 1, 1, 1, 1):
+                                    with T.block("conv2d_nhwc"):
+                                        v_n = T.axis.spatial(1, n_3 + n_4)
+                                        v_h = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + h_3 + h_4)
+                                        v_w = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + w_3 + w_4)
+                                        v_co = T.axis.spatial(64, co_3 + co_4 + n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16)
+                                        v_rh = T.axis.reduce(7, rh_0 * 7 + rh_1 + rh_2)
+                                        v_rw = T.axis.reduce(7, rw_0 * 7 + rw_1 * 7 + rw_2)
+                                        v_rc = T.axis.reduce(3, rc_1 + rc_2 + rc_0)
+                                        T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co])
+                                        T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co])
+                                        T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
+                                        with T.init():
+                                            conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0)
+                                        conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co]
+                            for ax0, ax1, ax2, ax3 in T.grid(1, 1, 14, 1):
+                                with T.block("conv2d_nhwc_local"):
+                                    v0 = T.axis.spatial(1, ax0)
+                                    v1 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + ax1)
+                                    v2 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + ax2)
+                                    v3 = T.axis.spatial(64, n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + ax3)
+                                    T.reads(conv2d_nhwc_local[v0, v1, v2, v3])
+                                    T.writes(conv2d_nhwc[v0, v1, v2, v3])
+                                    conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3]
+
+        # fmt: on
+    else:
+        # fmt: off
+        @T.prim_func
+        def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                T.block_attr({"meta_schedule.unroll_explicit": 1024})
+                conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 64), scope="local")
+                PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared")
+                weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared")
+                for n_0_h_0_w_0_co_0_fused in T.thread_binding(112, thread="blockIdx.x"):
+                    for n_1_h_1_w_1_co_1_fused in T.thread_binding(8, thread="vthread.x"):
+                        for n_2_h_2_w_2_co_2_fused in T.thread_binding(64, thread="threadIdx.x"):
+                            for rh_0_rw_0_rc_0_fused in T.serial(3, annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, stage - 2]}):
+                                for ax0_ax1_ax2_ax3_fused in range(693):
+                                    with T.block("PadInput_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused // 8 * 16 + ax0_ax1_ax2_ax3_fused // 33)
+                                        v2 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused % 8 * 28 + ax0_ax1_ax2_ax3_fused % 33)
+                                        v3 = T.axis.spatial(3, rh_0_rw_0_rc_0_fused)
+                                        T.reads(inputs[v0, v1 - 3, v2 - 3, v3])
+                                        T.writes(PadInput_shared[v0, v1, v2, v3])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0))
+                                for ax0_ax1_ax2_ax3_fused in range(3136):
+                                    with T.block("weight_shared"):
+                                        v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 448)
+                                        v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused % 448 // 64)
+                                        v2 = T.axis.spatial(3, rh_0_rw_0_rc_0_fused)
+                                        v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
+                                        T.reads(weight[v0, v1, v2, v3])
+                                        T.writes(weight_shared[v0, v1, v2, v3])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 3})
+                                        weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
+                                for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(7, 1, 1, 1, 1, 14, 1, 1, 7, 1, 1, 1, 1, 1):
+                                    with T.block("conv2d_nhwc"):
+                                        v_n = T.axis.spatial(1, n_4 + n_3)
+                                        v_h = T.axis.spatial(112, h_3 + h_4 + n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16)
+                                        v_w = T.axis.spatial(112, w_4 + n_0_h_0_w_0_co_0_fused % 8 * 14 + w_3)
+                                        v_co = T.axis.spatial(64, co_4 + n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + co_3)
+                                        v_rh = T.axis.reduce(7, rh_2 + rh_1)
+                                        v_rw = T.axis.reduce(7, rw_1 * 7 + rw_2)
+                                        v_rc = T.axis.reduce(3, rc_2 + rh_0_rw_0_rc_0_fused + rc_1)
+                                        T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co])
+                                        T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co])
+                                        T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
+                                        with T.init():
+                                            conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0)
+                                        conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co]
+                            for ax0, ax1, ax2, ax3 in T.grid(1, 1, 14, 1):
+                                with T.block("conv2d_nhwc_local"):
+                                    v0 = T.axis.spatial(1, ax0)
+                                    v1 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + ax1)
+                                    v2 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + ax2)
+                                    v3 = T.axis.spatial(64, n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + ax3)
+                                    T.reads(conv2d_nhwc_local[v0, v1, v2, v3])
+                                    T.writes(conv2d_nhwc[v0, v1, v2, v3])
+                                    conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3]
+        # fmt: on
+    return c2d
+
+
+def test_cuda_c2d():
+    c2d_decision = [
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+        ("SamplePerfectTile", [14, 2, 4, 1, 1]),
+        ("SamplePerfectTile", [8, 1, 1, 14, 1]),
+        ("SamplePerfectTile", [1, 4, 16, 1, 1]),
+        ("SamplePerfectTile", [1, 7, 1]),
+        ("SamplePerfectTile", [1, 1, 7]),
+        ("SamplePerfectTile", [3, 1, 1]),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 4),
+    ]
+
+    mod = create_te_workload("C2D", 0)
+    actual = _design_space(mod)
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[
+            get_c2d_prim_func(stage=0),
+            get_c2d_prim_func(stage=4),
+            get_c2d_prim_func(stage=5),
+        ],
+        expected_decisions=[c2d_decision, c2d_decision, c2d_decision],
+    )
+
+
+def get_gmm_prim_func(stage: int):
+    if stage == 0:
+        # fmt: off
+        @T.prim_func
+        def gmm(A: T.Buffer((1, 1024, 1024), "float32"), B: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                T.block_attr({"meta_schedule.unroll_explicit": 16})
+                Y_local = T.alloc_buffer((1, 1024, 1024), scope="local")
+                A_shared = T.alloc_buffer((1, 1024, 1024), scope="shared")
+                B_shared = T.alloc_buffer((1, 1024, 1024), scope="shared")
+                for b_0_i_0_j_0_fused in T.thread_binding(256, thread="blockIdx.x"):
+                    for b_1_i_1_j_1_fused in T.thread_binding(32, thread="vthread.x"):
+                        for b_2_i_2_j_2_fused in T.thread_binding(64, thread="threadIdx.x"):
+                            for k_0 in range(64):
+                                for ax0_ax1_ax2_fused in range(1024):
+                                    with T.block("A_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + ax0_ax1_ax2_fused // 16)
+                                        v2 = T.axis.spatial(1024, k_0 * 16 + ax0_ax1_ax2_fused % 16)
+                                        T.reads(A[v0, v1, v2])
+                                        T.writes(A_shared[v0, v1, v2])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        A_shared[v0, v1, v2] = A[v0, v1, v2]
+                                for ax0_ax1_ax2_fused in range(1024):
+                                    with T.block("B_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(1024, k_0 * 16 + ax0_ax1_ax2_fused // 64)
+                                        v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + ax0_ax1_ax2_fused % 64)
+                                        T.reads(B[v0, v1, v2])
+                                        T.writes(B_shared[v0, v1, v2])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        B_shared[v0, v1, v2] = B[v0, v1, v2]
+                                for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(2, 1, 1, 1, 8, 1, 1, 2):
+                                    with T.block("Y"):
+                                        v_b = T.axis.spatial(1, b_4 + b_3)
+                                        v_i = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + i_3 + i_4)
+                                        v_j = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + j_3 * 2 + j_4)
+                                        v_k = T.axis.reduce(1024, k_0 * 16 + k_1 * 8 + k_2)
+                                        T.reads(A_shared[v_b, v_i, v_k], B_shared[v_b, v_k, v_j])
+                                        T.writes(Y_local[v_b, v_i, v_j])
+                                        T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
+                                        with T.init():
+                                            Y_local[v_b, v_i, v_j] = T.float32(0)
+                                        Y_local[v_b, v_i, v_j] = Y_local[v_b, v_i, v_j] + A_shared[v_b, v_i, v_k] * B_shared[v_b, v_k, v_j]
+                            for ax0, ax1, ax2 in T.grid(1, 1, 2):
+                                with T.block("Y_local"):
+                                    v0 = T.axis.spatial(1, ax0)
+                                    v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + ax1)
+                                    v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + ax2)
+                                    T.reads(Y_local[v0, v1, v2])
+                                    T.writes(Y[v0, v1, v2])
+                                    Y[v0, v1, v2] = Y_local[v0, v1, v2]
+
+        # fmt: on
+    else:
+        # fmt: off
+        @T.prim_func
+        def gmm(A: T.Buffer((1, 1024, 1024), "float32"), B: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                T.block_attr({"meta_schedule.unroll_explicit": 16})
+                Y_local = T.alloc_buffer((1, 1024, 1024), scope="local")
+                A_shared = T.alloc_buffer((1, 1024, 1024), scope="shared")
+                B_shared = T.alloc_buffer((1, 1024, 1024), scope="shared")
+                for b_0_i_0_j_0_fused in T.thread_binding(256, thread="blockIdx.x"):
+                    for b_1_i_1_j_1_fused in T.thread_binding(32, thread="vthread.x"):
+                        for b_2_i_2_j_2_fused in T.thread_binding(64, thread="threadIdx.x"):
+                            for k_0_fused in T.serial(64, annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, stage - 2]}):
+                                for ax0_ax1_ax2_fused in range(1024):
+                                    with T.block("A_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + ax0_ax1_ax2_fused // 16)
+                                        v2 = T.axis.spatial(1024, k_0_fused * 16 + ax0_ax1_ax2_fused % 16)
+                                        T.reads(A[v0, v1, v2])
+                                        T.writes(A_shared[v0, v1, v2])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        A_shared[v0, v1, v2] = A[v0, v1, v2]
+                                for ax0_ax1_ax2_fused in range(1024):
+                                    with T.block("B_shared"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(1024, k_0_fused * 16 + ax0_ax1_ax2_fused // 64)
+                                        v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + ax0_ax1_ax2_fused % 64)
+                                        T.reads(B[v0, v1, v2])
+                                        T.writes(B_shared[v0, v1, v2])
+                                        T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                        B_shared[v0, v1, v2] = B[v0, v1, v2]
+                                for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(2, 1, 1, 1, 8, 1, 1, 2):
+                                    with T.block("Y"):
+                                        v_b = T.axis.spatial(1, b_3 + b_4)
+                                        v_i = T.axis.spatial(1024, i_3 + i_4 + b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8)
+                                        v_j = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + j_3 * 2 + j_4)
+                                        v_k = T.axis.reduce(1024, k_0_fused * 16 + k_1 * 8 + k_2)
+                                        T.reads(A_shared[v_b, v_i, v_k], B_shared[v_b, v_k, v_j])
+                                        T.writes(Y_local[v_b, v_i, v_j])
+                                        T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
+                                        with T.init():
+                                            Y_local[v_b, v_i, v_j] = T.float32(0)
+                                        Y_local[v_b, v_i, v_j] = Y_local[v_b, v_i, v_j] + A_shared[v_b, v_i, v_k] * B_shared[v_b, v_k, v_j]
+                            for ax0, ax1, ax2 in T.grid(1, 1, 2):
+                                with T.block("Y_local"):
+                                    v0 = T.axis.spatial(1, ax0)
+                                    v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + ax1)
+                                    v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + ax2)
+                                    T.reads(Y_local[v0, v1, v2])
+                                    T.writes(Y[v0, v1, v2])
+                                    Y[v0, v1, v2] = Y_local[v0, v1, v2]
+
+        # fmt: on
+    return gmm
+
+
+def test_cuda_gmm():
+    gmm_decision = [
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+        ("SamplePerfectTile", [16, 8, 8, 1, 1]),
+        ("SamplePerfectTile", [16, 4, 8, 1, 2]),
+        ("SamplePerfectTile", [64, 2, 8]),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 1),
+    ]
+
+    mod = create_te_workload("GMM", 3)
+    actual = _design_space(mod)
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[
+            get_gmm_prim_func(stage=0),
+            get_gmm_prim_func(stage=4),
+            get_gmm_prim_func(stage=5),
+        ],
+        expected_decisions=[gmm_decision, gmm_decision, gmm_decision],
+    )
+
+
+if __name__ == "__main__":
+    test_cuda_c2d()
+    test_cuda_gmm()
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py
index 87a8fcac98..e8ed3bb8b2 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py
@@ -27,7 +27,7 @@ from tvm.target import Target
 
 
 def _target():
-    return Target("nvidia/geforce-rtx-3070")
+    return Target("nvidia/geforce-rtx-2080")  # disable async trace using sm75
 
 
 def _design_space(mod):
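A closing note on the test changes above: the new subrule keys off the target's `arch` attribute, which is why these tests switch to an sm_75 tag to keep their design spaces unchanged. A minimal sketch, assuming a TVM build that includes this patch:

```python
# A minimal sketch, assuming a TVM build with this patch: the async subrule
# fires only when the target's "arch" attribute parses to sm_80 or higher.
from tvm.target import Target

for tag in ["nvidia/geforce-rtx-2080", "nvidia/geforce-rtx-3080"]:
    arch = str(Target(tag).attrs["arch"])  # e.g. "sm_75" or "sm_86"
    enabled = arch.startswith("sm_") and int(arch[3:]) >= 80
    print(f"{tag}: arch={arch}, async pipeline candidates generated: {enabled}")
```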