This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new accd582d3a [BugFix][TIR][Schedule] TileWithTensorIntrin skip
ComputeInline if buffer not padded by PadEinsum (#17440)
accd582d3a is described below
commit accd582d3a006b6c3473187e1c155fa535343d8a
Author: Yongqi <[email protected]>
AuthorDate: Sat Oct 5 15:32:31 2024 +0800
[BugFix][TIR][Schedule] TileWithTensorIntrin skip ComputeInline if buffer
not padded by PadEinsum (#17440)
[BugFix][TIR][Schedule] TileWithTensorIntrin skip ComputeInline if buffer
not padded by PadEinsum
---
src/tir/schedule/transform.cc | 63 ++++-
.../test_meta_schedule_schedule_rule_mlt_tc.py | 295 +++++++++++++++++++++
2 files changed, 346 insertions(+), 12 deletions(-)
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index fec214fa1f..c644fbecdf 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -326,23 +326,62 @@ Optional<LoopRV> TileWithTensorIntrin(const
tir::Schedule& sch, const tir::Block
if (!opt_tensorize_info) return NullOpt;
const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
if (info->block_iter_paddings.defined()) {
+ // We have to track whether each producer or consumer is padded.
+ // To do so, we first record all the Block's.
+ std::unordered_set<const StmtSRefNode*> original_producers,
original_consumers;
+ {
+ for (const auto& p : GetProducers(sch->state(), sch->GetSRef(block_rv)))
+ original_producers.insert(p.get());
+ for (const auto& c : GetConsumers(sch->state(), sch->GetSRef(block_rv)))
+ original_consumers.insert(c.get());
+ }
+
+ // Pad. Maybe we can make PadEinsum return the changes it made, to avoid
bookkeeping?
sch->PadEinsum(block_rv, info->block_iter_paddings.value());
+
+ // Now we need to find out all the padded Block's.
+ Array<BlockRV> inlined_producers, inlined_consumers;
+ for (const auto& producer : sch->GetProducers(block_rv)) {
+ // PadEinsum will not modify the producer if it does not need padding.
+ if (original_producers.count(sch->GetSRef(producer).get())) {
+ // Producer not padded. No inlining.
+ continue;
+ }
+ auto the_original_producers = sch->GetProducers(producer);
+ if (the_original_producers.empty()) {
+ // The original producer is input.
+ continue;
+ }
+ ICHECK_EQ(the_original_producers.size(), 1u);
+ auto the_original_producer = the_original_producers[0];
+
ICHECK(original_producers.count(sch->GetSRef(the_original_producer).get()));
+ inlined_producers.push_back(the_original_producer);
+ }
+ for (const auto& consumer : sch->GetConsumers(block_rv)) {
+ // PadEinsum will not modify the consumer if it does not need padding.
+ if (original_consumers.count(sch->GetSRef(consumer).get())) {
+ // Consumer not padded. No inlining.
+ continue;
+ }
+ auto the_original_consumers = sch->GetConsumers(consumer);
+ if (the_original_consumers.empty()) {
+ // The original consumer is output.
+ continue;
+ }
+ ICHECK_EQ(the_original_consumers.size(), 1u);
+ auto the_original_consumer = the_original_consumers[0];
+
ICHECK(original_consumers.count(sch->GetSRef(the_original_consumer).get()));
+ inlined_consumers.push_back(consumer);
+ }
+
// Inline the producer and consumer padding blocks
- auto producers = sch->GetProducers(block_rv);
- for (const auto& producer : producers) {
- auto original_producers = sch->GetProducers(producer);
- // NOTICE: there may not all producers padded.
+ for (const auto& the_original_producer : inlined_producers) {
// Inline the original producer into the padding block. This ensures
that the new producer
// has the padded shape.
- if (original_producers.size() == 1u) {
- sch->ComputeInline(original_producers[0]);
- }
+ sch->ComputeInline(the_original_producer);
}
- auto consumers = sch->GetConsumers(block_rv);
- for (const auto& consumer : consumers) {
- auto sref = sch->GetSRef(consumer);
- if (!tir::IsOutputBlock(sch->state(), sref,
tir::GetScopeRoot(sch->state(), sref, true)))
- sch->ComputeInline(consumer);
+ for (const auto& consumer : inlined_consumers) {
+ sch->ComputeInline(consumer);
}
}
// Construct a mapping from tir loops back to LoopRVs
diff --git
a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
index 1fd2ab8474..be936e6e84 100644
--- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -1207,5 +1207,300 @@ def test_padded_conv():
)
+def test_padded_matmul_single_padded_input():
+ # fmt: off
+ @T.prim_func
+ def padded_matmul_single_padded_input_0(A: T.Buffer((1023, 4096),
"float16"), B: T.Buffer((4096, 1024), "float16"), C: T.Buffer((1023, 1024),
"float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ C_reindex_pad_shared = T.alloc_buffer((8, 32, 8, 2, 16, 16),
scope="shared")
+ C_reindex_pad_shared_wmma_accumulator = T.alloc_buffer((8, 32, 8, 2,
16, 16), scope="wmma.accumulator")
+ A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16",
scope="shared")
+ B_reindex_shared = T.alloc_buffer((4096, 1024), "float16",
scope="shared")
+ A_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((1024, 4096),
"float16", scope="wmma.matrix_a")
+ B_reindex_shared_wmma_matrix_b = T.alloc_buffer((4096, 1024),
"float16", scope="wmma.matrix_b")
+ for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
+ for ax0_0_1_ax1_0_1_fused in T.thread_binding(32,
thread="blockIdx.x"):
+ for ax0_0_2_ax1_0_2_fused in T.thread_binding(8,
thread="threadIdx.y"):
+ for ax2_0_0 in range(32):
+ for ax0_ax1_fused in range(65536):
+ with T.block("A_reindex_pad_shared"):
+ v0 = T.axis.spatial(1024,
ax0_0_1_ax1_0_1_fused // 16 * 512 + ax0_ax1_fused // 128)
+ v1 = T.axis.spatial(4096, ax2_0_0 * 128 +
ax0_ax1_fused % 128)
+ T.reads(A[v0, v1])
+ T.writes(A_reindex_pad_shared[v0, v1])
+ T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "meta_schedule.cooperative_fetch": 2})
+ A_reindex_pad_shared[v0, v1] =
T.if_then_else(v0 < 1023, A[v0, v1], T.float16(0.0))
+ for ax0_ax1_fused in range(8192):
+ with T.block("B_reindex_shared"):
+ v0 = T.axis.spatial(4096, ax2_0_0 * 128 +
ax0_ax1_fused // 64)
+ v1 = T.axis.spatial(1024,
ax0_0_1_ax1_0_1_fused % 16 * 64 + ax0_ax1_fused % 64)
+ T.reads(B[v0, v1])
+ T.writes(B_reindex_shared[v0, v1])
+ T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "meta_schedule.cooperative_fetch": 1})
+ B_reindex_shared[v0, v1] = B[v0, v1]
+ for ax2_0_1 in range(8):
+ for ax0_0, ax1_0 in T.grid(8, 1):
+ with
T.block("A_reindex_pad_shared_wmma.matrix_a_o"):
+ v0_o = T.axis.spatial(64,
ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0)
+ v1_o = T.axis.spatial(256, ax2_0_0 * 8 +
ax2_0_1 + ax1_0)
+ T.reads(A_reindex_pad_shared[v0_o *
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o *
16:v1_o * 16 + 16])
+
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_load_16x16x16_f16_a_shared"})
+ for ax0_1, ax1_1 in T.grid(16, 16):
+ with
T.block("A_reindex_pad_shared_wmma.matrix_a"):
+ v0_i, v1_i = T.axis.remap("SS",
[ax0_1, ax1_1])
+ T.reads(A_reindex_pad_shared[v0_o
* 16 + v0_i, v1_o * 16 + v1_i])
+
T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+
A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] =
A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+ for ax0_0, ax1_0 in T.grid(1, 2):
+ with
T.block("B_reindex_shared_wmma.matrix_b_o"):
+ v0_o = T.axis.spatial(256, ax2_0_0 * 8 +
ax2_0_1 + ax0_0)
+ v1_o = T.axis.spatial(64,
ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0)
+ T.reads(B_reindex_shared[v0_o * 16:v0_o *
16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o *
16:v1_o * 16 + 16])
+
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_load_16x16x16_f16_b_shared"})
+ for ax0_1, ax1_1 in T.grid(16, 16):
+ with
T.block("B_reindex_shared_wmma.matrix_b"):
+ v0_i, v1_i = T.axis.remap("SS",
[ax0_1, ax1_1])
+ T.reads(B_reindex_shared[v0_o * 16
+ v0_i, v1_o * 16 + v1_i])
+
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+
B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] =
B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+ for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in
T.grid(2, 1, 1, 4, 2):
+ with T.block("C_o"):
+ v0_o = T.axis.spatial(64,
ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0_3 * 4
+ ax0_0_4)
+ v1_o = T.axis.spatial(64,
ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3 * 2 +
ax1_0_4)
+ v2_o = T.axis.reduce(256, ax2_0_0 * 8 +
ax2_0_1 + ax2_0_2)
+
T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o *
16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16,
v1_o * 16:v1_o * 16 + 16])
+
T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8,
v1_o % 2, 0:16, 0:16])
+
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32",
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32",
"warp_execution": 1})
+ with T.init():
+ for ax0_1, ax1_1 in T.grid(16, 16):
+ with T.block("C_init"):
+ v0_i_init, v1_i_init =
T.axis.remap("SS", [ax0_1, ax1_1])
+ T.reads()
+
T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8,
v1_o % 2, v0_i_init, v1_i_init])
+
C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2,
v0_i_init, v1_i_init] = T.float32(0.0)
+ for ax0_1, ax1_1, ax2_1 in T.grid(16, 16,
16):
+ with T.block("C"):
+ v0_i, v1_i, v2_i =
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+
T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8,
v1_o % 2, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i,
v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 +
v1_i])
+
T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8,
v1_o % 2, v0_i, v1_i])
+
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+
C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2,
v0_i, v1_i] = C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o
% 8, v1_o % 2, v0_i, v1_i] + T.Cast("float32",
A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) *
T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 +
v1_i])
+ for ax2 in range(8):
+ for ax0_ax1_fused in T.thread_binding(8,
thread="threadIdx.y"):
+ for ax2_1, ax3 in T.grid(1, 2):
+ with
T.block("C_reindex_pad_shared_wmma.accumulator_o"):
+ v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused
// 16 * 4 + ax0_ax1_fused // 2)
+ v1_o = T.axis.spatial(32,
ax0_0_1_ax1_0_1_fused % 16 * 2 + ax0_ax1_fused % 2)
+ v2_o = T.axis.spatial(8, ax2 + ax2_1)
+ v3_o = T.axis.spatial(2, ax3)
+ v4_o = T.axis.spatial(1, 0)
+ v5_o = T.axis.spatial(1, 0)
+
T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16,
0:16])
+ T.writes(C_reindex_pad_shared[v0_o, v1_o,
v2_o, v3_o, 0:16, 0:16])
+ T.block_attr({"meta_schedule.auto_tensorize":
"wmma_store_16x16x16_f32_shared"})
+ for ax4, ax5 in T.grid(16, 16):
+ with
T.block("C_reindex_pad_shared_wmma.accumulator"):
+ v4_i, v5_i = T.axis.remap("SS", [ax4,
ax5])
+
T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i,
v5_i])
+ T.writes(C_reindex_pad_shared[v0_o,
v1_o, v2_o, v3_o, v4_i, v5_i])
+ C_reindex_pad_shared[v0_o, v1_o, v2_o,
v3_o, v4_i, v5_i] = C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o,
v3_o, v4_i, v5_i]
+ for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
+ with T.block("C_reindex_pad_shared"):
+ v0 = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 16
* 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024)
+ v1 = T.axis.spatial(32, ax0_0_1_ax1_0_1_fused % 16
* 2 + ax0_ax1_ax3_ax4_ax5_fused % 1024 // 512)
+ v2 = T.axis.spatial(8, ax2)
+ v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused %
512 // 256)
+ v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 256 // 16)
+ v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 16)
+ T.where(ax0_0_1_ax1_0_1_fused // 16 * 512 +
ax0_ax1_ax3_ax4_ax5_fused // 1024 * 128 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused
% 256 // 16 < 1023)
+ T.reads(C_reindex_pad_shared[v0, v1, v2, v3, v4,
v5])
+ T.writes(C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 +
v1 * 32])
+ T.block_attr({"meta_schedule.cooperative_fetch":
4})
+ C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 + v1 * 32]
= C_reindex_pad_shared[v0, v1, v2, v3, v4, v5]
+ # fmt: on
+
+ decision_0 = [
+ ("SamplePerfectTile", [1, 2, 4, 2, 4]),
+ ("SamplePerfectTile", [1, 16, 2, 1, 2]),
+ ("SamplePerfectTile", [32, 8, 1]),
+ ("SampleCategorical", 3),
+ ("SampleCategorical", 1),
+ ("SampleCategorical", 0),
+ ]
+ mod = te.create_prim_func(
+ te_workload.matmul(
+ n=1023,
+ m=1024,
+ k=4096,
+ in_dtype="float16",
+ out_dtype="float32",
+ )
+ )
+ actual = generate_design_space(
+ kind="cuda",
+ mod=mod,
+ target=tvm.target.Target("cuda --arch=sm_70"),
+ types=None,
+ sch_rules=[multi_level_tiling_tensor_core()]
+ + get_rules("cuda", ms.schedule_rule.AutoInline),
+ )
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[padded_matmul_single_padded_input_0],
+ expected_decisions=[decision_0],
+ )
+
+
+def test_padded_matmul_no_padded_output():
+ # fmt: off
+ @T.prim_func
+ def padded_matmul_no_padded_output_0(A: T.Buffer((1024, 4095), "float16"),
B: T.Buffer((4095, 1024), "float16"), C: T.Buffer((1024, 1024), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ C_reindex_shared = T.alloc_buffer((32, 16, 2, 4, 16, 16),
scope="shared")
+ C_reindex_shared_wmma_accumulator = T.alloc_buffer((32, 16, 2, 4, 16,
16), scope="wmma.accumulator")
+ A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16",
scope="shared")
+ B_reindex_pad_shared = T.alloc_buffer((4096, 1024), "float16",
scope="shared")
+ A_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((1024, 4096),
"float16", scope="wmma.matrix_a")
+ B_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((4096, 1024),
"float16", scope="wmma.matrix_b")
+ for ax0_0_0_ax1_0_0_fused in T.thread_binding(64, thread="blockIdx.y"):
+ for ax0_0_1_ax1_0_1_fused in T.thread_binding(2,
thread="blockIdx.x"):
+ for ax0_0_2_ax1_0_2_fused in T.thread_binding(4,
thread="threadIdx.y"):
+ for ax2_0_0 in range(128):
+ for ax0_ax1_fused in range(4096):
+ with T.block("A_reindex_pad_shared"):
+ v0 = T.axis.spatial(1024,
ax0_0_0_ax1_0_0_fused // 16 * 256 + ax0_0_1_ax1_0_1_fused * 128 + ax0_ax1_fused
// 32)
+ v1 = T.axis.spatial(4096, ax2_0_0 * 32 +
ax0_ax1_fused % 32)
+ T.reads(A[v0, v1])
+ T.writes(A_reindex_pad_shared[v0, v1])
+ T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "meta_schedule.cooperative_fetch": 8})
+ A_reindex_pad_shared[v0, v1] =
T.if_then_else(v1 < 4095, A[v0, v1], T.float16(0.0))
+ for ax0_ax1_fused in range(2048):
+ with T.block("B_reindex_pad_shared"):
+ v0 = T.axis.spatial(4096, ax2_0_0 * 32 +
ax0_ax1_fused // 64)
+ v1 = T.axis.spatial(1024,
ax0_0_0_ax1_0_0_fused % 16 * 64 + ax0_ax1_fused % 64)
+ T.reads(B[v0, v1])
+ T.writes(B_reindex_pad_shared[v0, v1])
+ T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "meta_schedule.cooperative_fetch": 1})
+ B_reindex_pad_shared[v0, v1] =
T.if_then_else(v0 < 4095, B[v0, v1], T.float16(0.0))
+ for ax2_0_1 in range(2):
+ for ax0_0, ax1_0 in T.grid(2, 1):
+ with
T.block("A_reindex_pad_shared_wmma.matrix_a_o"):
+ v0_o = T.axis.spatial(64,
ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 +
ax0_0_2_ax1_0_2_fused * 2 + ax0_0)
+ v1_o = T.axis.spatial(256, ax2_0_0 * 2 +
ax2_0_1 + ax1_0)
+ T.reads(A_reindex_pad_shared[v0_o *
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o *
16:v1_o * 16 + 16])
+
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_load_16x16x16_f16_a_shared"})
+ for ax0_1, ax1_1 in T.grid(16, 16):
+ with
T.block("A_reindex_pad_shared_wmma.matrix_a"):
+ v0_i, v1_i = T.axis.remap("SS",
[ax0_1, ax1_1])
+ T.reads(A_reindex_pad_shared[v0_o
* 16 + v0_i, v1_o * 16 + v1_i])
+
T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+
A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] =
A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+ for ax0_0, ax1_0 in T.grid(1, 4):
+ with
T.block("B_reindex_pad_shared_wmma.matrix_b_o"):
+ v0_o = T.axis.spatial(256, ax2_0_0 * 2 +
ax2_0_1 + ax0_0)
+ v1_o = T.axis.spatial(64,
ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0)
+ T.reads(B_reindex_pad_shared[v0_o *
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o *
16:v1_o * 16 + 16])
+
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_load_16x16x16_f16_b_shared"})
+ for ax0_1, ax1_1 in T.grid(16, 16):
+ with
T.block("B_reindex_pad_shared_wmma.matrix_b"):
+ v0_i, v1_i = T.axis.remap("SS",
[ax0_1, ax1_1])
+ T.reads(B_reindex_pad_shared[v0_o
* 16 + v0_i, v1_o * 16 + v1_i])
+
T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+
B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] =
B_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+ for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in
T.grid(2, 1, 1, 1, 4):
+ with T.block("C_o"):
+ v0_o = T.axis.spatial(64,
ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 +
ax0_0_2_ax1_0_2_fused * 2 + ax0_0_3 + ax0_0_4)
+ v1_o = T.axis.spatial(64,
ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0_3 * 4 + ax1_0_4)
+ v2_o = T.axis.reduce(256, ax2_0_0 * 2 +
ax2_0_1 + ax2_0_2)
+
T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o *
16:v2_o * 16 + 16], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 +
16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o
% 4, 0:16, 0:16])
+
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32",
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32",
"warp_execution": 1})
+ with T.init():
+ for ax0_1, ax1_1 in T.grid(16, 16):
+ with T.block("C_init"):
+ v0_i_init, v1_i_init =
T.axis.remap("SS", [ax0_1, ax1_1])
+ T.reads()
+
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o
% 4, v0_i_init, v1_i_init])
+
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4,
v0_i_init, v1_i_init] = T.float32(0.0)
+ for ax0_1, ax1_1, ax2_1 in T.grid(16, 16,
16):
+ with T.block("C"):
+ v0_i, v1_i, v2_i =
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+
T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o
% 4, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o *
16 + v2_i], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 +
v1_i])
+
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o
% 4, v0_i, v1_i])
+
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4,
v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2,
v1_o % 4, v0_i, v1_i] + T.Cast("float32",
A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) *
T.Cast("float32", B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o *
16 + v1_i])
+ for ax2 in range(2):
+ for ax0_ax1_fused in T.thread_binding(4,
thread="threadIdx.y"):
+ for ax2_1, ax3 in T.grid(1, 4):
+ with
T.block("C_reindex_shared_wmma.accumulator_o"):
+ v0_o = T.axis.spatial(32,
ax0_0_0_ax1_0_0_fused // 16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_fused)
+ v1_o = T.axis.spatial(16,
ax0_0_0_ax1_0_0_fused % 16)
+ v2_o = T.axis.spatial(2, ax2 + ax2_1)
+ v3_o = T.axis.spatial(4, ax3)
+ v4_o = T.axis.spatial(1, 0)
+ v5_o = T.axis.spatial(1, 0)
+
T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
+ T.writes(C_reindex_shared[v0_o, v1_o, v2_o,
v3_o, 0:16, 0:16])
+ T.block_attr({"meta_schedule.auto_tensorize":
"wmma_store_16x16x16_f32_shared"})
+ for ax4, ax5 in T.grid(16, 16):
+ with
T.block("C_reindex_shared_wmma.accumulator"):
+ v4_i, v5_i = T.axis.remap("SS", [ax4,
ax5])
+
T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
+ T.writes(C_reindex_shared[v0_o, v1_o,
v2_o, v3_o, v4_i, v5_i])
+ C_reindex_shared[v0_o, v1_o, v2_o,
v3_o, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o,
v4_i, v5_i]
+ for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
+ with T.block("C_reindex_shared"):
+ v0 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused //
16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024)
+ v1 = T.axis.spatial(16, ax0_0_0_ax1_0_0_fused % 16)
+ v2 = T.axis.spatial(2, ax2)
+ v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused %
1024 // 256)
+ v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 256 // 16)
+ v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 16)
+ T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
+ T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 +
v1 * 64])
+ T.block_attr({"meta_schedule.cooperative_fetch":
3})
+ C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64]
= C_reindex_shared[v0, v1, v2, v3, v4, v5]
+ # fmt: on
+
+ decision_0 = [
+ ("SamplePerfectTile", [4, 2, 4, 2, 1]),
+ ("SamplePerfectTile", [16, 1, 1, 1, 4]),
+ ("SamplePerfectTile", [128, 2, 1]),
+ ("SampleCategorical", 2),
+ ("SampleCategorical", 3),
+ ("SampleCategorical", 0),
+ ]
+ mod = te.create_prim_func(
+ te_workload.matmul(
+ n=1024,
+ m=1024,
+ k=4095,
+ in_dtype="float16",
+ out_dtype="float32",
+ )
+ )
+ actual = generate_design_space(
+ kind="cuda",
+ mod=mod,
+ target=tvm.target.Target("cuda --arch=sm_70"),
+ types=None,
+ sch_rules=[multi_level_tiling_tensor_core()]
+ + get_rules("cuda", ms.schedule_rule.AutoInline),
+ )
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[padded_matmul_no_padded_output_0],
+ expected_decisions=[decision_0],
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()