This is an automated email from the ASF dual-hosted git repository.

cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 18ff9ff89b [MetaSchedule]Add a testcase for padded conv2d in 
meta_schedule (#17171)
18ff9ff89b is described below

commit 18ff9ff89b4617d8925ef6afde233e8d1742a5bd
Author: YXY-0922 <[email protected]>
AuthorDate: Tue Jul 23 02:48:57 2024 +0800

    [MetaSchedule] Add a testcase for padded conv2d in meta_schedule (#17171)
    
    ### Bug Fix
    
    In the `TileWithTensorIntrin` function, when the `allow_padding` parameter 
is enabled, the original implementation inlined every consumer block of the 
padded block. This could incorrectly inline output blocks as well, corrupting 
block shapes and dependencies. To ensure the inlining is correct, only 
non-output consumer blocks should be inlined.
    ---------
    Co-authored-by: yuxiyue <[email protected]>
---
 src/tir/schedule/transform.cc                      |   4 +-
 .../test_meta_schedule_schedule_rule_mlt_tc.py     | 152 +++++++++++++++++++++
 2 files changed, 155 insertions(+), 1 deletion(-)

diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 8f912c59ea..fec214fa1f 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -340,7 +340,9 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& 
sch, const tir::Block
     }
     auto consumers = sch->GetConsumers(block_rv);
     for (const auto& consumer : consumers) {
-      sch->ComputeInline(consumer);
+      auto sref = sch->GetSRef(consumer);
+      if (!tir::IsOutputBlock(sch->state(), sref, 
tir::GetScopeRoot(sch->state(), sref, true)))
+        sch->ComputeInline(consumer);
     }
   }
   // Construct a mapping from tir loops back to LoopRVs
diff --git 
a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py 
b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
index df8607e551..1fd2ab8474 100644
--- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -1055,5 +1055,157 @@ def test_conv_1x1():
     )
 
 
+def test_padded_conv():
+    # fmt: off
+    @T.prim_func
+    def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: 
T.Buffer((7, 7, 3, 64), "float16"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), 
"float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        conv2d_nhwc_reindex_shared = T.alloc_buffer((56, 2, 14, 2, 16, 16), 
scope="shared")
+        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((56, 2, 
14, 2, 16, 16), scope="wmma.accumulator")
+        PadInput_reindex_pad_shared = T.alloc_buffer((12544, 160), "float16", 
scope="shared")
+        weight_reindex_pad_shared = T.alloc_buffer((160, 64), "float16", 
scope="shared")
+        PadInput_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((12544, 
160), "float16", scope="wmma.matrix_a")
+        weight_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((160, 64), 
"float16", scope="wmma.matrix_b")
+        for ax0_0_0_ax1_0_0_fused in T.thread_binding(14, thread="blockIdx.y"):
+            for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, 
thread="blockIdx.x"):
+                for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, 
thread="threadIdx.y"):
+                    for ax2_0_0 in range(10):
+                        for ax0_ax1_fused in range(28672):
+                            with T.block("PadInput_reindex_pad_shared"):
+                                v0 = T.axis.spatial(12544, 
ax0_0_0_ax1_0_0_fused // 2 * 1792 + ax0_ax1_fused // 16)
+                                v1 = T.axis.spatial(160, ax2_0_0 * 16 + 
ax0_ax1_fused % 16)
+                                T.reads(inputs[0, v0 // 112 * 2 + v1 // 21 - 
3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3])
+                                T.writes(PadInput_reindex_pad_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 4})
+                                PadInput_reindex_pad_shared[v0, v1] = 
T.if_then_else(v1 < 147, T.if_then_else(3 <= v0 // 112 * 2 + v1 // 21 and v0 // 
112 * 2 + v1 // 21 < 227 and 3 <= v0 % 112 * 2 + v1 % 21 // 3 and v0 % 112 * 2 
+ v1 % 21 // 3 < 227, inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 
% 21 // 3 - 3, v1 % 3], T.float16(0)), T.float16(0))
+                        for ax0_ax1_fused in range(512):
+                            with T.block("weight_reindex_pad_shared"):
+                                v0 = T.axis.spatial(160, ax2_0_0 * 16 + 
ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused 
% 2 * 32 + ax0_ax1_fused % 32)
+                                T.reads(weight[v0 // 21, v0 % 21 // 3, v0 % 3, 
v1])
+                                T.writes(weight_reindex_pad_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 2})
+                                weight_reindex_pad_shared[v0, v1] = 
T.if_then_else(v0 < 147, weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1], 
T.float16(0))
+                        for ax2_0_1 in range(1):
+                            for ax0_0, ax1_0 in T.grid(14, 1):
+                                with 
T.block("PadInput_reindex_pad_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(784, 
ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0)
+                                    v1_o = T.axis.spatial(10, ax2_0_0 + ax1_0)
+                                    T.reads(PadInput_reindex_pad_shared[v0_o * 
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, 
v1_o * 16:v1_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_a_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with 
T.block("PadInput_reindex_pad_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
+                                            
T.reads(PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            
T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 
+ v1_i])
+                                            
PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(1, 2):
+                                with 
T.block("weight_reindex_pad_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(10, ax2_0_0 + ax0_0)
+                                    v1_o = T.axis.spatial(4, 
ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0)
+                                    T.reads(weight_reindex_pad_shared[v0_o * 
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o 
* 16:v1_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_b_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with 
T.block("weight_reindex_pad_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
+                                            
T.reads(weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            
T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + 
v1_i])
+                                            
weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in 
T.grid(7, 2, 1, 2, 1):
+                                with T.block("conv2d_nhwc_o"):
+                                    v0_o = T.axis.spatial(784, 
ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0_3 * 2 + 
ax0_0_4)
+                                    v1_o = T.axis.spatial(4, 
ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0_3 + ax1_0_4)
+                                    v2_o = T.axis.reduce(10, ax2_0_0 + ax2_0_1 
+ ax2_0_2)
+                                    
T.reads(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 
16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, 
v0_o % 14, v1_o % 2, 0:16, 0:16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", 
"warp_execution": 1})
+                                    with T.init():
+                                        for ax0_1, ax1_1 in T.grid(16, 16):
+                                            with T.block("conv2d_nhwc_init"):
+                                                v0_i_init, v1_i_init = 
T.axis.remap("SS", [ax0_1, ax1_1])
+                                                T.reads()
+                                                
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, 
v0_o % 14, v1_o % 2, v0_i_init, v1_i_init])
+                                                
conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, 
v1_o % 2, v0_i_init, v1_i_init] = T.float32(0)
+                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 
16):
+                                        with T.block("conv2d_nhwc"):
+                                            v0_i, v1_i, v2_i = 
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                            
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o 
% 14, v1_o % 2, v0_i, v1_i], PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 
16 + v0_i, v2_o * 16 + v2_i], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 
+ v2_i, v1_o * 16 + v1_i])
+                                            
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, 
v0_o % 14, v1_o % 2, v0_i, v1_i])
+                                            
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            
conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, 
v1_o % 2, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, 
v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] + T.Cast("float32", 
PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) 
* T.Cast("float32", weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, 
v1_o * 16 + v1_i])
+                for ax2 in range(14):
+                    for ax0_ax1_fused in T.thread_binding(8, 
thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 2):
+                            with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
+                                v0_o = T.axis.spatial(56, 
ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_fused)
+                                v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused 
% 2)
+                                v2_o = T.axis.spatial(14, ax2 + ax2_1)
+                                v3_o = T.axis.spatial(2, ax3)
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 
0:16, 0:16])
+                                T.writes(conv2d_nhwc_reindex_shared[v0_o, 
v1_o, v2_o, v3_o, 0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_store_16x16x16_f32_shared"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, 
ax5])
+                                        
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 
v4_i, v5_i])
+                                        
T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
+                                        conv2d_nhwc_reindex_shared[v0_o, v1_o, 
v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, 
v1_o, v2_o, v3_o, v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
+                        with T.block("conv2d_nhwc_reindex_shared"):
+                            v0 = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 
* 8 + ax0_ax1_ax3_ax4_ax5_fused // 512)
+                            v1 = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2)
+                            v2 = T.axis.spatial(14, ax2)
+                            v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 
512 // 256)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
+                            T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 
v4, v5])
+                            T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) 
// 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 
3})
+                            conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, 
(v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32] = 
conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
+    # fmt: on
+
+    decision_0 = [
+        ("SamplePerfectTile", [7, 1, 8, 7, 2]),
+        ("SamplePerfectTile", [2, 1, 1, 2, 1]),
+        ("SamplePerfectTile", [10, 1, 1]),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 1),
+    ]
+    mod = te.create_prim_func(
+        te_workload.conv2d_nhwc(
+            1,
+            224,
+            224,
+            3,
+            64,
+            7,
+            2,
+            3,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    actual = generate_design_space(
+        kind="cuda",
+        mod=mod,
+        target=tvm.target.Target("cuda --arch=sm_70"),
+        types=None,
+        sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")]
+        + get_rules("cuda", ms.schedule_rule.AutoInline),
+    )
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[padded_conv2d_0],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to