This is an automated email from the ASF dual-hosted git repository.
cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 18ff9ff89b [MetaSchedule] Add a test case for padded conv2d in
meta_schedule (#17171)
18ff9ff89b is described below
commit 18ff9ff89b4617d8925ef6afde233e8d1742a5bd
Author: YXY-0922 <[email protected]>
AuthorDate: Tue Jul 23 02:48:57 2024 +0800
[MetaSchedule] Add a test case for padded conv2d in meta_schedule (#17171)
### Bug Fix
In the `TileWithTensorIntrin` function, when the `allow_padding` parameter
is enabled, the original implementation inlined all consumer blocks. Since an
output block must never be inlined, this could incorrectly inline output
consumer blocks, corrupting block shapes and dependencies. To guarantee a
correct inlining operation, only non-output consumer blocks are now inlined.
---------
Co-authored-by: yuxiyue <[email protected]>
---
src/tir/schedule/transform.cc | 4 +-
.../test_meta_schedule_schedule_rule_mlt_tc.py | 152 +++++++++++++++++++++
2 files changed, 155 insertions(+), 1 deletion(-)
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 8f912c59ea..fec214fa1f 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -340,7 +340,9 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule&
sch, const tir::Block
}
auto consumers = sch->GetConsumers(block_rv);
for (const auto& consumer : consumers) {
- sch->ComputeInline(consumer);
+ auto sref = sch->GetSRef(consumer);
+ if (!tir::IsOutputBlock(sch->state(), sref,
tir::GetScopeRoot(sch->state(), sref, true)))
+ sch->ComputeInline(consumer);
}
}
// Construct a mapping from tir loops back to LoopRVs
diff --git
a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
index df8607e551..1fd2ab8474 100644
--- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -1055,5 +1055,157 @@ def test_conv_1x1():
)
+def test_padded_conv():
+ # fmt: off
+ @T.prim_func
+ def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight:
T.Buffer((7, 7, 3, 64), "float16"), conv2d_nhwc: T.Buffer((1, 112, 112, 64),
"float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ conv2d_nhwc_reindex_shared = T.alloc_buffer((56, 2, 14, 2, 16, 16),
scope="shared")
+ conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((56, 2,
14, 2, 16, 16), scope="wmma.accumulator")
+ PadInput_reindex_pad_shared = T.alloc_buffer((12544, 160), "float16",
scope="shared")
+ weight_reindex_pad_shared = T.alloc_buffer((160, 64), "float16",
scope="shared")
+ PadInput_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((12544,
160), "float16", scope="wmma.matrix_a")
+ weight_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((160, 64),
"float16", scope="wmma.matrix_b")
+ for ax0_0_0_ax1_0_0_fused in T.thread_binding(14, thread="blockIdx.y"):
+ for ax0_0_1_ax1_0_1_fused in T.thread_binding(1,
thread="blockIdx.x"):
+ for ax0_0_2_ax1_0_2_fused in T.thread_binding(8,
thread="threadIdx.y"):
+ for ax2_0_0 in range(10):
+ for ax0_ax1_fused in range(28672):
+ with T.block("PadInput_reindex_pad_shared"):
+ v0 = T.axis.spatial(12544,
ax0_0_0_ax1_0_0_fused // 2 * 1792 + ax0_ax1_fused // 16)
+ v1 = T.axis.spatial(160, ax2_0_0 * 16 +
ax0_ax1_fused % 16)
+ T.reads(inputs[0, v0 // 112 * 2 + v1 // 21 -
3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3])
+ T.writes(PadInput_reindex_pad_shared[v0, v1])
+ T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "meta_schedule.cooperative_fetch": 4})
+ PadInput_reindex_pad_shared[v0, v1] =
T.if_then_else(v1 < 147, T.if_then_else(3 <= v0 // 112 * 2 + v1 // 21 and v0 //
112 * 2 + v1 // 21 < 227 and 3 <= v0 % 112 * 2 + v1 % 21 // 3 and v0 % 112 * 2
+ v1 % 21 // 3 < 227, inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1
% 21 // 3 - 3, v1 % 3], T.float16(0)), T.float16(0))
+ for ax0_ax1_fused in range(512):
+ with T.block("weight_reindex_pad_shared"):
+ v0 = T.axis.spatial(160, ax2_0_0 * 16 +
ax0_ax1_fused // 32)
+ v1 = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused
% 2 * 32 + ax0_ax1_fused % 32)
+ T.reads(weight[v0 // 21, v0 % 21 // 3, v0 % 3,
v1])
+ T.writes(weight_reindex_pad_shared[v0, v1])
+ T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "meta_schedule.cooperative_fetch": 2})
+ weight_reindex_pad_shared[v0, v1] =
T.if_then_else(v0 < 147, weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1],
T.float16(0))
+ for ax2_0_1 in range(1):
+ for ax0_0, ax1_0 in T.grid(14, 1):
+ with
T.block("PadInput_reindex_pad_shared_wmma.matrix_a_o"):
+ v0_o = T.axis.spatial(784,
ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0)
+ v1_o = T.axis.spatial(10, ax2_0_0 + ax1_0)
+ T.reads(PadInput_reindex_pad_shared[v0_o *
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16,
v1_o * 16:v1_o * 16 + 16])
+
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_load_16x16x16_f16_a_shared"})
+ for ax0_1, ax1_1 in T.grid(16, 16):
+ with
T.block("PadInput_reindex_pad_shared_wmma.matrix_a"):
+ v0_i, v1_i = T.axis.remap("SS",
[ax0_1, ax1_1])
+
T.reads(PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+
T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16
+ v1_i])
+
PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] =
PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+ for ax0_0, ax1_0 in T.grid(1, 2):
+ with
T.block("weight_reindex_pad_shared_wmma.matrix_b_o"):
+ v0_o = T.axis.spatial(10, ax2_0_0 + ax0_0)
+ v1_o = T.axis.spatial(4,
ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0)
+ T.reads(weight_reindex_pad_shared[v0_o *
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o
* 16:v1_o * 16 + 16])
+
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_load_16x16x16_f16_b_shared"})
+ for ax0_1, ax1_1 in T.grid(16, 16):
+ with
T.block("weight_reindex_pad_shared_wmma.matrix_b"):
+ v0_i, v1_i = T.axis.remap("SS",
[ax0_1, ax1_1])
+
T.reads(weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+
T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 +
v1_i])
+
weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] =
weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+ for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in
T.grid(7, 2, 1, 2, 1):
+ with T.block("conv2d_nhwc_o"):
+ v0_o = T.axis.spatial(784,
ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0_3 * 2 +
ax0_0_4)
+ v1_o = T.axis.spatial(4,
ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0_3 + ax1_0_4)
+ v2_o = T.axis.reduce(10, ax2_0_0 + ax2_0_1
+ ax2_0_2)
+
T.reads(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16,
v2_o * 16:v2_o * 16 + 16], weight_reindex_pad_shared_wmma_matrix_b[v2_o *
16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2,
v0_o % 14, v1_o % 2, 0:16, 0:16])
+
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32",
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32",
"warp_execution": 1})
+ with T.init():
+ for ax0_1, ax1_1 in T.grid(16, 16):
+ with T.block("conv2d_nhwc_init"):
+ v0_i_init, v1_i_init =
T.axis.remap("SS", [ax0_1, ax1_1])
+ T.reads()
+
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2,
v0_o % 14, v1_o % 2, v0_i_init, v1_i_init])
+
conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14,
v1_o % 2, v0_i_init, v1_i_init] = T.float32(0)
+ for ax0_1, ax1_1, ax2_1 in T.grid(16, 16,
16):
+ with T.block("conv2d_nhwc"):
+ v0_i, v1_i, v2_i =
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o
% 14, v1_o % 2, v0_i, v1_i], PadInput_reindex_pad_shared_wmma_matrix_a[v0_o *
16 + v0_i, v2_o * 16 + v2_i], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16
+ v2_i, v1_o * 16 + v1_i])
+
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2,
v0_o % 14, v1_o % 2, v0_i, v1_i])
+
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+
conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14,
v1_o % 2, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14,
v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] + T.Cast("float32",
PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i])
* T.Cast("float32", weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i,
v1_o * 16 + v1_i])
+ for ax2 in range(14):
+ for ax0_ax1_fused in T.thread_binding(8,
thread="threadIdx.y"):
+ for ax2_1, ax3 in T.grid(1, 2):
+ with
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
+ v0_o = T.axis.spatial(56,
ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_fused)
+ v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused
% 2)
+ v2_o = T.axis.spatial(14, ax2 + ax2_1)
+ v3_o = T.axis.spatial(2, ax3)
+ v4_o = T.axis.spatial(1, 0)
+ v5_o = T.axis.spatial(1, 0)
+
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o,
0:16, 0:16])
+ T.writes(conv2d_nhwc_reindex_shared[v0_o,
v1_o, v2_o, v3_o, 0:16, 0:16])
+ T.block_attr({"meta_schedule.auto_tensorize":
"wmma_store_16x16x16_f32_shared"})
+ for ax4, ax5 in T.grid(16, 16):
+ with
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
+ v4_i, v5_i = T.axis.remap("SS", [ax4,
ax5])
+
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o,
v4_i, v5_i])
+
T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
+ conv2d_nhwc_reindex_shared[v0_o, v1_o,
v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o,
v1_o, v2_o, v3_o, v4_i, v5_i]
+ for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
+ with T.block("conv2d_nhwc_reindex_shared"):
+ v0 = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2
* 8 + ax0_ax1_ax3_ax4_ax5_fused // 512)
+ v1 = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2)
+ v2 = T.axis.spatial(14, ax2)
+ v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused %
512 // 256)
+ v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 256 // 16)
+ v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 16)
+ T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3,
v4, v5])
+ T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224)
// 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32])
+ T.block_attr({"meta_schedule.cooperative_fetch":
3})
+ conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112,
(v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32] =
conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
+ # fmt: on
+
+ decision_0 = [
+ ("SamplePerfectTile", [7, 1, 8, 7, 2]),
+ ("SamplePerfectTile", [2, 1, 1, 2, 1]),
+ ("SamplePerfectTile", [10, 1, 1]),
+ ("SampleCategorical", 2),
+ ("SampleCategorical", 2),
+ ("SampleCategorical", 1),
+ ]
+ mod = te.create_prim_func(
+ te_workload.conv2d_nhwc(
+ 1,
+ 224,
+ 224,
+ 3,
+ 64,
+ 7,
+ 2,
+ 3,
+ in_dtype="float16",
+ out_dtype="float32",
+ )
+ )
+ actual = generate_design_space(
+ kind="cuda",
+ mod=mod,
+ target=tvm.target.Target("cuda --arch=sm_70"),
+ types=None,
+ sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")]
+ + get_rules("cuda", ms.schedule_rule.AutoInline),
+ )
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[padded_conv2d_0],
+ expected_decisions=[decision_0],
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()