This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 093720230b [Unity] Introduce Default GPU Schedule Pass (#14182)
093720230b is described below
commit 093720230bc2403d278e820bd164b93158ab8823
Author: Xiyou Zhou <[email protected]>
AuthorDate: Mon Mar 6 10:19:19 2023 -0800
[Unity] Introduce Default GPU Schedule Pass (#14182)
* Implement default schedule.
* Add test.
* Add tests.
* Fix linting.
* Skip scheduled blocks.
* Address issues.
* Use target current.
* Minor fixes.
* Remove Mutator.
* Move pass to tir given it's all primfunc scheduling.
* Move unit tests.
* Remove redundant headers and adjust comments.
* Add more explanation to the pass documents.
* Modify pass name.
* Fix comment.
---
include/tvm/tir/transform.h | 12 +
python/tvm/tir/transform/transform.py | 18 +
.../space_generator/post_order_apply.cc | 62 ---
src/meta_schedule/utils.h | 62 +++
src/tir/transforms/default_gpu_schedule.cc | 116 ++++++
.../test_transform_default_gpu_schedule.py | 417 +++++++++++++++++++++
6 files changed, 625 insertions(+), 62 deletions(-)
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index d212578b8d..4653fa3640 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -717,6 +717,18 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage();
*/
TVM_DLL Pass InstrumentProfileIntrinsics();
+/*!
+ * \brief The pass sets default thread bindings for PrimFuncs, including
symbolic shape functions,
+ * allowing their build and execution on GPU devices. It examines all the
blocks within the
+ * PrimFunc and conducts loop fusion, splitting, and reordering operations
based on the loop extent
+ * and target information, such as the maximum thread block number and
maximum thread per block.
+ * \note The primary objective of this pass is not to optimize performance,
but rather to
+ * generate a valid GPU kernel for unscheduled or symbolic shape PrimFuncs.
The pass is
+ * currently only working for CUDA targets.
+ * \return The Pass.
+ */
+TVM_DLL Pass DefaultGPUSchedule();
+
} // namespace transform
} // namespace tir
} // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index a6e0cf06cb..69d3119531 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -1055,3 +1055,21 @@ def InstallDebugSpans():
The result pass
"""
return _ffi_api.InstallDebugSpans() # type: ignore
+
+
+def DefaultGPUSchedule():
+ """The pass sets default thread bindings for PrimFuncs, including symbolic
shape functions,
+ allowing their build and execution on GPU devices. It examines all the
blocks within the
+    PrimFunc and conducts loop fusion, splitting, and reordering operations
based on the loop
+ extent and target information, such as the maximum thread block number and
maximum thread
+ per block.
+
+ The primary objective of this pass is not to optimize performance, but
rather to generate
+ a valid GPU kernel for unscheduled or symbolic shape PrimFuncs. The pass
is currently only
+ working for CUDA targets.
+
+ Returns
+ -------
+ ret: tvm.transform.Pass
+ """
+ return _ffi_api.DefaultGPUSchedule() # type: ignore
diff --git a/src/meta_schedule/space_generator/post_order_apply.cc
b/src/meta_schedule/space_generator/post_order_apply.cc
index 491af6e28f..1bde99869e 100644
--- a/src/meta_schedule/space_generator/post_order_apply.cc
+++ b/src/meta_schedule/space_generator/post_order_apply.cc
@@ -21,68 +21,6 @@
namespace tvm {
namespace meta_schedule {
-/*! \brief Collecting all the blocks */
-class BlockCollector : public tir::StmtVisitor {
- public:
- static Array<tir::BlockRV> Collect(const tir::Schedule& sch,
- const runtime::PackedFunc f_block_filter
= nullptr) { //
- return BlockCollector(sch, f_block_filter).Run();
- }
-
- private:
- /*! \brief Entry point */
- Array<tir::BlockRV> Run() {
- std::vector<tir::BlockRV> results;
- for (const auto& kv : sch_->mod()->functions) {
- const GlobalVar& gv = kv.first; // `gv->name_hint` is the name
of the function
- const BaseFunc& base_func = kv.second; // this can be PrimFunc or
relay::Function
- if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
- func_name_ = gv->name_hint;
- block_names_.clear();
- blocks_to_collect_.clear();
- VisitStmt(func->body);
- for (const String& name : blocks_to_collect_) {
- results.push_back(sch_->GetBlock(name, func_name_));
- }
- }
- }
- return results;
- }
- /*! \brief Constructor */
- explicit BlockCollector(const tir::Schedule& sch,
- const runtime::PackedFunc f_block_filter = nullptr)
- : sch_(sch), f_block_filter_(f_block_filter) {}
- /*! \brief Override the Stmt visiting behaviour */
- void VisitStmt_(const tir::BlockNode* block) override {
- tir::StmtVisitor::VisitStmt_(block);
- CHECK(block_names_.count(block->name_hint) == 0)
- << "Duplicated block name " << block->name_hint << " in function " <<
func_name_
- << " not supported!";
- block_names_.insert(block->name_hint);
-
- // If filter function is provided, use it to selectively collect blocks.
- // Otherwise collect all blocks.
- Bool collect_block = Bool(true);
- if (f_block_filter_ != nullptr) {
- collect_block = f_block_filter_(GetRef<tir::Block>(block));
- }
- if (collect_block) {
- blocks_to_collect_.push_back(block->name_hint);
- }
- }
-
- /*! \brief The schedule to be collected */
- const tir::Schedule& sch_;
- /*! \brief An optional packed func that allows only certain blocks to be
collected. */
- const runtime::PackedFunc f_block_filter_;
- /*! \brief The set of func name and block name pair */
- std::unordered_set<String> block_names_;
- /* \brief The list of blocks to collect in order */
- Array<String> blocks_to_collect_;
- /*! \brief Name of the current PrimFunc */
- String func_name_;
-};
-
/*!
* \brief Design Space Generator that generates design spaces by applying
schedule rules to blocks
* in post-DFS order.
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 9a372dde8f..955381b740 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -554,6 +554,68 @@ inline double Sum(const Array<FloatImm>& arr) {
return sum;
}
+/*! \brief Collecting all the blocks */
+class BlockCollector : public tir::StmtVisitor {
+ public:
+ static Array<tir::BlockRV> Collect(const tir::Schedule& sch,
+ const runtime::PackedFunc f_block_filter
= nullptr) { //
+ return BlockCollector(sch, f_block_filter).Run();
+ }
+
+ private:
+ /*! \brief Entry point */
+ Array<tir::BlockRV> Run() {
+ std::vector<tir::BlockRV> results;
+ for (const auto& [gv, base_func] : sch_->mod()->functions) {
+ // `gv->name_hint` is the name of the function
+ // `base_func` can be PrimFunc or relay::Function
+ if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
+ func_name_ = gv->name_hint;
+ block_names_.clear();
+ blocks_to_collect_.clear();
+ VisitStmt(func->body);
+ for (const String& name : blocks_to_collect_) {
+ results.push_back(sch_->GetBlock(name, func_name_));
+ }
+ }
+ }
+ return results;
+ }
+ /*! \brief Constructor */
+ explicit BlockCollector(const tir::Schedule& sch,
+ const runtime::PackedFunc f_block_filter = nullptr)
+ : sch_(sch), f_block_filter_(f_block_filter) {}
+ /*! \brief Override the Stmt visiting behaviour */
+ void VisitStmt_(const tir::BlockNode* block) override {
+ tir::StmtVisitor::VisitStmt_(block);
+ CHECK(block_names_.count(block->name_hint) == 0)
+ << "Duplicated block name " << block->name_hint << " in function " <<
func_name_
+ << " not supported!";
+ block_names_.insert(block->name_hint);
+
+ // If filter function is provided, use it to selectively collect blocks.
+ // Otherwise collect all blocks.
+ Bool collect_block = Bool(true);
+ if (f_block_filter_ != nullptr) {
+ collect_block = f_block_filter_(GetRef<tir::Block>(block));
+ }
+ if (collect_block) {
+ blocks_to_collect_.push_back(block->name_hint);
+ }
+ }
+
+ /*! \brief The schedule to be collected */
+ const tir::Schedule& sch_;
+ /*! \brief An optional packed func that allows only certain blocks to be
collected. */
+ const runtime::PackedFunc f_block_filter_;
+  /*! \brief The set of block names in the current PrimFunc */
+ std::unordered_set<String> block_names_;
+  /*! \brief The list of blocks to collect in order */
+ Array<String> blocks_to_collect_;
+ /*! \brief Name of the current PrimFunc */
+ String func_name_;
+};
+
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/tir/transforms/default_gpu_schedule.cc
b/src/tir/transforms/default_gpu_schedule.cc
new file mode 100644
index 0000000000..7877c86d0a
--- /dev/null
+++ b/src/tir/transforms/default_gpu_schedule.cc
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../../meta_schedule/utils.h"
+
+namespace tvm {
+namespace tir {
+namespace transform {
+/*!
+ * \brief A helper function to do default thread binding for a block.
+ * \param sch The schedule to work on.
+ * \param block The block to be scheduled.
+ * \param max_thread_per_block The maximum number of threads per block.
+ * \param max_threadblocks The maximum number of threadblocks.
+ */
+void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t
max_thread_per_block,
+ int64_t max_threadblocks = 256) {
+ // fetch the loops
+ Array<tir::LoopRV> loops = sch->GetLoops(block);
+ for (const tir::LoopRV& loop : loops) {
+ // skip block if already scheduled
+ if (sch->Get(loop)->thread_binding.defined()) {
+ return;
+ }
+ }
+ Array<tir::IterVar> iters = sch->Get(block)->iter_vars;
+ ICHECK_EQ(loops.size(), iters.size());
+ Array<tir::LoopRV> data_parallel_loops;
+ // only fuse data parallel loops
+ for (size_t i = 0; i < loops.size(); ++i) {
+ if (iters[i]->iter_type == tir::IterVarType::kDataPar) {
+ data_parallel_loops.push_back(loops[i]);
+ }
+ }
+ // skip if no data parallel loops
+ if (data_parallel_loops.size() == 0) {
+ return;
+ }
+ // fuse all data parallel loops
+ tir::LoopRV fused = sch->Fuse(data_parallel_loops,
/*preserve_unit_iters=*/false);
+ int64_t product = std::numeric_limits<int64_t>::max();
+ if (sch->Get(fused)->extent->IsInstance<tir::IntImmNode>()) {
+ product = sch->Get(fused)->extent.as<tir::IntImmNode>()->value;
+ }
+ // schedule the fused loop
+ if (product > max_thread_per_block * max_threadblocks) {
+ Array<tir::LoopRV> splits =
+ sch->Split(fused,
+ /*factors=*/{NullOpt, Integer(max_threadblocks),
Integer(max_thread_per_block)});
+ sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]});
+ sch->Bind(splits[1], "blockIdx.x");
+ sch->Bind(splits[2], "threadIdx.x");
+ } else {
+ Array<tir::LoopRV> splits =
+ sch->Split(fused, /*factors=*/{NullOpt, Integer(std::min(product,
max_thread_per_block))});
+ sch->Bind(splits[0], "blockIdx.x");
+ sch->Bind(splits[1], "threadIdx.x");
+ }
+}
+
+Pass DefaultGPUSchedule() {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
+ [=](IRModule m, PassContext pc) {
+ // get the target from context.
+ tvm::Target target = tvm::Target::Current();
+ ICHECK(target.defined()) << "Target is not set in current context";
+ // skip non-cuda targets.
+ if (target->kind->name != "cuda") {
+ return m;
+ }
+ // get the max thread per block from target.
+ Optional<Integer> opt_max_thread_per_block =
target->GetAttr<Integer>("max_num_threads");
+ ICHECK(opt_max_thread_per_block.defined())
+ << "max_num_threads is not set for target " << target;
+ int64_t max_thread_per_block =
opt_max_thread_per_block.value().IntValue();
+ tir::Schedule sch = tir::Schedule::Traced(m, /*seed=*/-1,
/*debug_mask=*/0,
+
tir::ScheduleErrorRenderLevel::kDetail);
+ for (const auto& [gv, func] : m->functions) {
+ if (func->IsInstance<tir::PrimFuncNode>()) {
+ sch->WorkOn(gv->name_hint);
+ Array<tir::BlockRV> blocks =
meta_schedule::BlockCollector::Collect(sch);
+ for (const tir::BlockRV& block : blocks) {
+ ThreadBind(sch, block, max_thread_per_block);
+ }
+ }
+ }
+ return sch->mod();
+ };
+ return CreateModulePass(/*pass_function=*/pass_func, //
+ /*opt_level=*/0, //
+ /*pass_name=*/"DefaultGPUSchedule", //
+ /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.DefaultGPUSchedule").set_body_typed(DefaultGPUSchedule);
+
+} // namespace transform
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/unittest/test_transform_default_gpu_schedule.py
b/tests/python/unittest/test_transform_default_gpu_schedule.py
new file mode 100644
index 0000000000..644a9aede0
--- /dev/null
+++ b/tests/python/unittest/test_transform_default_gpu_schedule.py
@@ -0,0 +1,417 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring
+import tvm
+from tvm.tir.transform import DefaultGPUSchedule
+from tvm.script import tir as T
+import tvm.testing
+
+
+def test_broadcast_to_symbolic():
+ # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+ # fmt: off
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def broadcast_to(
+ rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"),
+ var_T_broadcast_to: T.handle,
+ ):
+ T.func_attr({"tir.noalias": True})
+ x_0 = T.int64()
+ x_1 = T.int64()
+ T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1))
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(x_0, x_1):
+ with T.block("T_broadcast_to"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax0, T.int64(0)])
+ T.writes(T_broadcast_to[v_ax0, v_ax1])
+ T_broadcast_to[v_ax0, v_ax1] = rxplaceholder[v_ax0,
T.int64(0)]
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def broadcast_to(
+ rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"),
+ var_T_broadcast_to: T.handle,
+ ):
+ T.func_attr({"tir.noalias": True})
+ x_0 = T.int64()
+ x_1 = T.int64()
+ T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1))
+ # with T.block("root"):
+ for ax0_ax1_fused_1 in T.thread_binding(T.int64(256),
thread="blockIdx.x"):
+ for ax0_ax1_fused_2 in T.thread_binding(
+ T.int64(1024), thread="threadIdx.x"
+ ):
+ for ax0_ax1_fused_0 in range(
+ (x_0 * x_1 + T.int64(262143)) // T.int64(262144)
+ ):
+ with T.block("T_broadcast_to"):
+ v_ax0 = T.axis.spatial(
+ x_0,
+ (
+ (ax0_ax1_fused_0 * T.int64(256) +
ax0_ax1_fused_1)
+ * T.int64(1024)
+ + ax0_ax1_fused_2
+ )
+ // x_1,
+ )
+ v_ax1 = T.axis.spatial(
+ x_1,
+ (
+ (ax0_ax1_fused_0 * T.int64(256) +
ax0_ax1_fused_1)
+ * T.int64(1024)
+ + ax0_ax1_fused_2
+ )
+ % x_1,
+ )
+ T.where(
+ (ax0_ax1_fused_0 * T.int64(256) +
ax0_ax1_fused_1)
+ * T.int64(1024)
+ + ax0_ax1_fused_2
+ < x_0 * x_1
+ )
+ T.reads(rxplaceholder[v_ax0, T.int64(0)])
+ T.writes(T_broadcast_to[v_ax0, v_ax1])
+ T_broadcast_to[v_ax0, v_ax1] =
rxplaceholder[v_ax0, T.int64(0)]
+ # fmt: on
+ # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+ target = tvm.target.Target("nvidia/geforce-rtx-3070")
+ with target, tvm.transform.PassContext(opt_level=3):
+ After = DefaultGPUSchedule()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_matmul():
+ # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+ # fmt: off
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def matmul(
+ A: T.Buffer((32, 32), "float16"),
+ B: T.Buffer((32, 32), "float16"),
+ C: T.Buffer((32, 32), "float16"),
+ ):
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # with T.block("root"):
+ for i, j, k in T.grid(32, 32, 32):
+ with T.block("C"):
+ v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+ T.reads(A[v_i, v_k], B[v_k, v_j])
+ T.writes(C[v_i, v_j])
+ with T.init():
+ C[v_i, v_j] = T.float16(0)
+ C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def matmul(
+ A: T.Buffer((32, 32), "float16"),
+ B: T.Buffer((32, 32), "float16"),
+ C: T.Buffer((32, 32), "float16"),
+ ):
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # with T.block("root"):
+ for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+ for i_j_fused_1 in T.thread_binding(1024,
thread="threadIdx.x"):
+ for k in range(32):
+ with T.block("C"):
+ v_i = T.axis.spatial(
+ 32, (i_j_fused_0 * 1024 + i_j_fused_1) // 32
+ )
+ v_j = T.axis.spatial(
+ 32, (i_j_fused_0 * 1024 + i_j_fused_1) % 32
+ )
+ v_k = T.axis.reduce(32, k)
+ T.reads(A[v_i, v_k], B[v_k, v_j])
+ T.writes(C[v_i, v_j])
+ with T.init():
+ C[v_i, v_j] = T.float16(0)
+ C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k,
v_j]
+ # fmt: on
+ # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+ target = tvm.target.Target("nvidia/geforce-rtx-3070")
+ with target, tvm.transform.PassContext(opt_level=3):
+ After = DefaultGPUSchedule()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_add():
+ # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+ # fmt: off
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)),
"float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2),
T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2),
T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2),
T.int64(3)):
+ with T.block("T_add"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(rxplaceholder[T.int64(0), ax2, ax3],
rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+ T.writes(T_add[ax0, ax1, ax2, ax3])
+ T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2,
ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def add(
+ rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)),
"float32"),
+ rxplaceholder_1: T.Buffer(
+ (T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"
+ ),
+ T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)),
"float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
+ for i0_i1_i2_i3_fused_1 in T.thread_binding(
+ T.int64(72), thread="threadIdx.x"
+ ):
+ with T.block("T_add"):
+ ax0 = T.axis.spatial(
+ T.int64(4),
+ (i0_i1_i2_i3_fused_0 * T.int64(72) +
i0_i1_i2_i3_fused_1)
+ // T.int64(18),
+ )
+ ax1 = T.axis.spatial(
+ T.int64(3),
+ (i0_i1_i2_i3_fused_0 * T.int64(72) +
i0_i1_i2_i3_fused_1)
+ % T.int64(18)
+ // T.int64(6),
+ )
+ ax2 = T.axis.spatial(
+ T.int64(2),
+ (i0_i1_i2_i3_fused_0 * T.int64(72) +
i0_i1_i2_i3_fused_1)
+ % T.int64(6)
+ // T.int64(3),
+ )
+ ax3 = T.axis.spatial(
+ T.int64(3),
+ (i0_i1_i2_i3_fused_0 * T.int64(72) +
i0_i1_i2_i3_fused_1)
+ % T.int64(3),
+ )
+ T.reads(
+ rxplaceholder[T.int64(0), ax2, ax3],
+ rxplaceholder_1[ax0, ax1, ax2, T.int64(0)],
+ )
+ T.writes(T_add[ax0, ax1, ax2, ax3])
+ T_add[ax0, ax1, ax2, ax3] = (
+ rxplaceholder[T.int64(0), ax2, ax3]
+ + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+ )
+
+ # fmt: on
+ # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+ target = tvm.target.Target("nvidia/geforce-rtx-3070")
+ with target, tvm.transform.PassContext(opt_level=3):
+ After = DefaultGPUSchedule()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_full():
+ # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+ # fmt: off
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def full(rxplaceholder: T.Buffer((), "int32"), T_full:
T.Buffer((T.int64(2), T.int64(3)), "int32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_full"):
+ ax0, ax1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[()])
+ T.writes(T_full[ax0, ax1])
+ T_full[ax0, ax1] = rxplaceholder[()]
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def full(
+ rxplaceholder: T.Buffer((), "int32"),
+ T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for i0_i1_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
+ for i0_i1_fused_1 in T.thread_binding(T.int64(6),
thread="threadIdx.x"):
+ with T.block("T_full"):
+ ax0 = T.axis.spatial(
+ T.int64(2),
+ (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) //
T.int64(3),
+ )
+ ax1 = T.axis.spatial(
+ T.int64(3),
+ (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) %
T.int64(3),
+ )
+ T.reads(rxplaceholder[()])
+ T.writes(T_full[ax0, ax1])
+ T_full[ax0, ax1] = rxplaceholder[()]
+
+ # fmt: on
+ # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+ target = tvm.target.Target("nvidia/geforce-rtx-3070")
+ with target, tvm.transform.PassContext(opt_level=3):
+ After = DefaultGPUSchedule()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_scheduled():
+ # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+ # fmt: off
+
+ @tvm.script.ir_module
+ class Scheduled:
+ @T.prim_func
+ def full(
+ rxplaceholder: T.Buffer((), "int32"),
+ T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for i0_i1_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
+ for i0_i1_fused_1 in T.thread_binding(T.int64(6),
thread="threadIdx.x"):
+ with T.block("T_full"):
+ ax0 = T.axis.spatial(
+ T.int64(2),
+ (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) //
T.int64(3),
+ )
+ ax1 = T.axis.spatial(
+ T.int64(3),
+ (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) %
T.int64(3),
+ )
+ T.reads(rxplaceholder[()])
+ T.writes(T_full[ax0, ax1])
+ T_full[ax0, ax1] = rxplaceholder[()]
+
+ # fmt: on
+ # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+ target = tvm.target.Target("nvidia/geforce-rtx-3070")
+ with target, tvm.transform.PassContext(opt_level=3):
+ # should do nothing
+ After = DefaultGPUSchedule()(Scheduled)
+ tvm.ir.assert_structural_equal(After, Scheduled)
+
+
+def test_multiple():
+ # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+ # fmt: off
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)),
"float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2),
T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2),
T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2),
T.int64(3)):
+ with T.block("T_add"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(rxplaceholder[T.int64(0), ax2, ax3],
rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+ T.writes(T_add[ax0, ax1, ax2, ax3])
+ T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2,
ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+
+ @T.prim_func
+ def full(rxplaceholder: T.Buffer((), "int32"), T_full:
T.Buffer((T.int64(2), T.int64(3)), "int32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_full"):
+ ax0, ax1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[()])
+ T.writes(T_full[ax0, ax1])
+ T_full[ax0, ax1] = rxplaceholder[()]
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def add(
+ rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)),
"float32"),
+ rxplaceholder_1: T.Buffer(
+ (T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"
+ ),
+ T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)),
"float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
+ for i0_i1_i2_i3_fused_1 in T.thread_binding(
+ T.int64(72), thread="threadIdx.x"
+ ):
+ with T.block("T_add"):
+ ax0 = T.axis.spatial(
+ T.int64(4),
+ (i0_i1_i2_i3_fused_0 * T.int64(72) +
i0_i1_i2_i3_fused_1)
+ // T.int64(18),
+ )
+ ax1 = T.axis.spatial(
+ T.int64(3),
+ (i0_i1_i2_i3_fused_0 * T.int64(72) +
i0_i1_i2_i3_fused_1)
+ % T.int64(18)
+ // T.int64(6),
+ )
+ ax2 = T.axis.spatial(
+ T.int64(2),
+ (i0_i1_i2_i3_fused_0 * T.int64(72) +
i0_i1_i2_i3_fused_1)
+ % T.int64(6)
+ // T.int64(3),
+ )
+ ax3 = T.axis.spatial(
+ T.int64(3),
+ (i0_i1_i2_i3_fused_0 * T.int64(72) +
i0_i1_i2_i3_fused_1)
+ % T.int64(3),
+ )
+ T.reads(
+ rxplaceholder[T.int64(0), ax2, ax3],
+ rxplaceholder_1[ax0, ax1, ax2, T.int64(0)],
+ )
+ T.writes(T_add[ax0, ax1, ax2, ax3])
+ T_add[ax0, ax1, ax2, ax3] = (
+ rxplaceholder[T.int64(0), ax2, ax3]
+ + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+ )
+
+ @T.prim_func
+ def full(
+ rxplaceholder: T.Buffer((), "int32"),
+ T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for i0_i1_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
+ for i0_i1_fused_1 in T.thread_binding(T.int64(6),
thread="threadIdx.x"):
+ with T.block("T_full"):
+ ax0 = T.axis.spatial(
+ T.int64(2),
+ (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) //
T.int64(3),
+ )
+ ax1 = T.axis.spatial(
+ T.int64(3),
+ (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) %
T.int64(3),
+ )
+ T.reads(rxplaceholder[()])
+ T.writes(T_full[ax0, ax1])
+ T_full[ax0, ax1] = rxplaceholder[()]
+ # fmt: on
+ # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+ target = tvm.target.Target("nvidia/geforce-rtx-3070")
+ with target, tvm.transform.PassContext(opt_level=3):
+ After = DefaultGPUSchedule()(Before)
+ assert tvm.ir.structural_equal(After, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()