This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new accd582d3a [BugFix][TIR][Schedule] TileWithTensorIntrin skip ComputeInline if buffer not padded by PadEinsum (#17440)
accd582d3a is described below

commit accd582d3a006b6c3473187e1c155fa535343d8a
Author: Yongqi <[email protected]>
AuthorDate: Sat Oct 5 15:32:31 2024 +0800

    [BugFix][TIR][Schedule] TileWithTensorIntrin skip ComputeInline if buffer not padded by PadEinsum (#17440)
    
    [BugFix][TIR][Schedule] TileWithTensorIntrin skip ComputeInline if buffer not padded by PadEinsum
---
 src/tir/schedule/transform.cc                      |  63 ++++-
 .../test_meta_schedule_schedule_rule_mlt_tc.py     | 295 +++++++++++++++++++++
 2 files changed, 346 insertions(+), 12 deletions(-)

diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index fec214fa1f..c644fbecdf 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -326,23 +326,62 @@ Optional<LoopRV> TileWithTensorIntrin(const 
tir::Schedule& sch, const tir::Block
   if (!opt_tensorize_info) return NullOpt;
   const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
   if (info->block_iter_paddings.defined()) {
+    // We have to track whether each producer or consumer is padded.
+    // To do so, we first record all the Block's.
+    std::unordered_set<const StmtSRefNode*> original_producers, 
original_consumers;
+    {
+      for (const auto& p : GetProducers(sch->state(), sch->GetSRef(block_rv)))
+        original_producers.insert(p.get());
+      for (const auto& c : GetConsumers(sch->state(), sch->GetSRef(block_rv)))
+        original_consumers.insert(c.get());
+    }
+
+    // Pad. Maybe we can make PadEinsum return the changes it made, to avoid 
bookkeeping?
     sch->PadEinsum(block_rv, info->block_iter_paddings.value());
+
+    // Now we need to find out all the padded Block's.
+    Array<BlockRV> inlined_producers, inlined_consumers;
+    for (const auto& producer : sch->GetProducers(block_rv)) {
+      // PadEinsum will not modify the producer if it does not need padding.
+      if (original_producers.count(sch->GetSRef(producer).get())) {
+        // Producer not padded. No inlining.
+        continue;
+      }
+      auto the_original_producers = sch->GetProducers(producer);
+      if (the_original_producers.empty()) {
+        // The original producer is input.
+        continue;
+      }
+      ICHECK_EQ(the_original_producers.size(), 1u);
+      auto the_original_producer = the_original_producers[0];
+      
ICHECK(original_producers.count(sch->GetSRef(the_original_producer).get()));
+      inlined_producers.push_back(the_original_producer);
+    }
+    for (const auto& consumer : sch->GetConsumers(block_rv)) {
+      // PadEinsum will not modify the consumer if it does not need padding.
+      if (original_consumers.count(sch->GetSRef(consumer).get())) {
+        // Consumer not padded. No inlining.
+        continue;
+      }
+      auto the_original_consumers = sch->GetConsumers(consumer);
+      if (the_original_consumers.empty()) {
+        // The original consumer is output.
+        continue;
+      }
+      ICHECK_EQ(the_original_consumers.size(), 1u);
+      auto the_original_consumer = the_original_consumers[0];
+      
ICHECK(original_consumers.count(sch->GetSRef(the_original_consumer).get()));
+      inlined_consumers.push_back(consumer);
+    }
+
     // Inline the producer and consumer padding blocks
-    auto producers = sch->GetProducers(block_rv);
-    for (const auto& producer : producers) {
-      auto original_producers = sch->GetProducers(producer);
-      // NOTICE: there may not all producers padded.
+    for (const auto& the_original_producer : inlined_producers) {
       // Inline the original producer into the padding block. This ensures 
that the new producer
       // has the padded shape.
-      if (original_producers.size() == 1u) {
-        sch->ComputeInline(original_producers[0]);
-      }
+      sch->ComputeInline(the_original_producer);
     }
-    auto consumers = sch->GetConsumers(block_rv);
-    for (const auto& consumer : consumers) {
-      auto sref = sch->GetSRef(consumer);
-      if (!tir::IsOutputBlock(sch->state(), sref, 
tir::GetScopeRoot(sch->state(), sref, true)))
-        sch->ComputeInline(consumer);
+    for (const auto& consumer : inlined_consumers) {
+      sch->ComputeInline(consumer);
     }
   }
   // Construct a mapping from tir loops back to LoopRVs
diff --git 
a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py 
b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
index 1fd2ab8474..be936e6e84 100644
--- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -1207,5 +1207,300 @@ def test_padded_conv():
     )
 
 
+def test_padded_matmul_single_padded_input():
+    # fmt: off
+    @T.prim_func
+    def padded_matmul_single_padded_input_0(A: T.Buffer((1023, 4096), 
"float16"), B: T.Buffer((4096, 1024), "float16"), C: T.Buffer((1023, 1024), 
"float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        C_reindex_pad_shared = T.alloc_buffer((8, 32, 8, 2, 16, 16), 
scope="shared")
+        C_reindex_pad_shared_wmma_accumulator = T.alloc_buffer((8, 32, 8, 2, 
16, 16), scope="wmma.accumulator")
+        A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16", 
scope="shared")
+        B_reindex_shared = T.alloc_buffer((4096, 1024), "float16", 
scope="shared")
+        A_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((1024, 4096), 
"float16", scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer((4096, 1024), 
"float16", scope="wmma.matrix_b")
+        for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
+            for ax0_0_1_ax1_0_1_fused in T.thread_binding(32, 
thread="blockIdx.x"):
+                for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, 
thread="threadIdx.y"):
+                    for ax2_0_0 in range(32):
+                        for ax0_ax1_fused in range(65536):
+                            with T.block("A_reindex_pad_shared"):
+                                v0 = T.axis.spatial(1024, 
ax0_0_1_ax1_0_1_fused // 16 * 512 + ax0_ax1_fused // 128)
+                                v1 = T.axis.spatial(4096, ax2_0_0 * 128 + 
ax0_ax1_fused % 128)
+                                T.reads(A[v0, v1])
+                                T.writes(A_reindex_pad_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 2})
+                                A_reindex_pad_shared[v0, v1] = 
T.if_then_else(v0 < 1023, A[v0, v1], T.float16(0.0))
+                        for ax0_ax1_fused in range(8192):
+                            with T.block("B_reindex_shared"):
+                                v0 = T.axis.spatial(4096, ax2_0_0 * 128 + 
ax0_ax1_fused // 64)
+                                v1 = T.axis.spatial(1024, 
ax0_0_1_ax1_0_1_fused % 16 * 64 + ax0_ax1_fused % 64)
+                                T.reads(B[v0, v1])
+                                T.writes(B_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 1})
+                                B_reindex_shared[v0, v1] = B[v0, v1]
+                        for ax2_0_1 in range(8):
+                            for ax0_0, ax1_0 in T.grid(8, 1):
+                                with 
T.block("A_reindex_pad_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(64, 
ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0)
+                                    v1_o = T.axis.spatial(256, ax2_0_0 * 8 + 
ax2_0_1 + ax1_0)
+                                    T.reads(A_reindex_pad_shared[v0_o * 
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_a_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with 
T.block("A_reindex_pad_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
+                                            T.reads(A_reindex_pad_shared[v0_o 
* 16 + v0_i, v1_o * 16 + v1_i])
+                                            
T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            
A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(1, 2):
+                                with 
T.block("B_reindex_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(256, ax2_0_0 * 8 + 
ax2_0_1 + ax0_0)
+                                    v1_o = T.axis.spatial(64, 
ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0)
+                                    T.reads(B_reindex_shared[v0_o * 16:v0_o * 
16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_b_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with 
T.block("B_reindex_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
+                                            T.reads(B_reindex_shared[v0_o * 16 
+ v0_i, v1_o * 16 + v1_i])
+                                            
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            
B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in 
T.grid(2, 1, 1, 4, 2):
+                                with T.block("C_o"):
+                                    v0_o = T.axis.spatial(64, 
ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0_3 * 4 
+ ax0_0_4)
+                                    v1_o = T.axis.spatial(64, 
ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3 * 2 + 
ax1_0_4)
+                                    v2_o = T.axis.reduce(256, ax2_0_0 * 8 + 
ax2_0_1 + ax2_0_2)
+                                    
T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, 
v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, 
v1_o % 2, 0:16, 0:16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", 
"warp_execution": 1})
+                                    with T.init():
+                                        for ax0_1, ax1_1 in T.grid(16, 16):
+                                            with T.block("C_init"):
+                                                v0_i_init, v1_i_init = 
T.axis.remap("SS", [ax0_1, ax1_1])
+                                                T.reads()
+                                                
T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, 
v1_o % 2, v0_i_init, v1_i_init])
+                                                
C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, 
v0_i_init, v1_i_init] = T.float32(0.0)
+                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 
16):
+                                        with T.block("C"):
+                                            v0_i, v1_i, v2_i = 
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                            
T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, 
v1_o % 2, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, 
v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + 
v1_i])
+                                            
T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, 
v1_o % 2, v0_i, v1_i])
+                                            
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            
C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, 
v0_i, v1_i] = C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o 
% 8, v1_o % 2, v0_i, v1_i] + T.Cast("float32", 
A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * 
T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + 
v1_i])
+                for ax2 in range(8):
+                    for ax0_ax1_fused in T.thread_binding(8, 
thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 2):
+                            with 
T.block("C_reindex_pad_shared_wmma.accumulator_o"):
+                                v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused 
// 16 * 4 + ax0_ax1_fused // 2)
+                                v1_o = T.axis.spatial(32, 
ax0_0_1_ax1_0_1_fused % 16 * 2 + ax0_ax1_fused % 2)
+                                v2_o = T.axis.spatial(8, ax2 + ax2_1)
+                                v3_o = T.axis.spatial(2, ax3)
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                
T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 
0:16])
+                                T.writes(C_reindex_pad_shared[v0_o, v1_o, 
v2_o, v3_o, 0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_store_16x16x16_f32_shared"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with 
T.block("C_reindex_pad_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, 
ax5])
+                                        
T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, 
v5_i])
+                                        T.writes(C_reindex_pad_shared[v0_o, 
v1_o, v2_o, v3_o, v4_i, v5_i])
+                                        C_reindex_pad_shared[v0_o, v1_o, v2_o, 
v3_o, v4_i, v5_i] = C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, 
v3_o, v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
+                        with T.block("C_reindex_pad_shared"):
+                            v0 = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 16 
* 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024)
+                            v1 = T.axis.spatial(32, ax0_0_1_ax1_0_1_fused % 16 
* 2 + ax0_ax1_ax3_ax4_ax5_fused % 1024 // 512)
+                            v2 = T.axis.spatial(8, ax2)
+                            v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 
512 // 256)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
+                            T.where(ax0_0_1_ax1_0_1_fused // 16 * 512 + 
ax0_ax1_ax3_ax4_ax5_fused // 1024 * 128 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16 < 1023)
+                            T.reads(C_reindex_pad_shared[v0, v1, v2, v3, v4, 
v5])
+                            T.writes(C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 + 
v1 * 32])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 
4})
+                            C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 + v1 * 32] 
= C_reindex_pad_shared[v0, v1, v2, v3, v4, v5]
+    # fmt: on
+
+    decision_0 = [
+        ("SamplePerfectTile", [1, 2, 4, 2, 4]),
+        ("SamplePerfectTile", [1, 16, 2, 1, 2]),
+        ("SamplePerfectTile", [32, 8, 1]),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 1),
+        ("SampleCategorical", 0),
+    ]
+    mod = te.create_prim_func(
+        te_workload.matmul(
+            n=1023,
+            m=1024,
+            k=4096,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    actual = generate_design_space(
+        kind="cuda",
+        mod=mod,
+        target=tvm.target.Target("cuda --arch=sm_70"),
+        types=None,
+        sch_rules=[multi_level_tiling_tensor_core()]
+        + get_rules("cuda", ms.schedule_rule.AutoInline),
+    )
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[padded_matmul_single_padded_input_0],
+        expected_decisions=[decision_0],
+    )
+
+
+def test_padded_matmul_no_padded_output():
+    # fmt: off
+    @T.prim_func
+    def padded_matmul_no_padded_output_0(A: T.Buffer((1024, 4095), "float16"), 
B: T.Buffer((4095, 1024), "float16"), C: T.Buffer((1024, 1024), "float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        C_reindex_shared = T.alloc_buffer((32, 16, 2, 4, 16, 16), 
scope="shared")
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer((32, 16, 2, 4, 16, 
16), scope="wmma.accumulator")
+        A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16", 
scope="shared")
+        B_reindex_pad_shared = T.alloc_buffer((4096, 1024), "float16", 
scope="shared")
+        A_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((1024, 4096), 
"float16", scope="wmma.matrix_a")
+        B_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((4096, 1024), 
"float16", scope="wmma.matrix_b")
+        for ax0_0_0_ax1_0_0_fused in T.thread_binding(64, thread="blockIdx.y"):
+            for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, 
thread="blockIdx.x"):
+                for ax0_0_2_ax1_0_2_fused in T.thread_binding(4, 
thread="threadIdx.y"):
+                    for ax2_0_0 in range(128):
+                        for ax0_ax1_fused in range(4096):
+                            with T.block("A_reindex_pad_shared"):
+                                v0 = T.axis.spatial(1024, 
ax0_0_0_ax1_0_0_fused // 16 * 256 + ax0_0_1_ax1_0_1_fused * 128 + ax0_ax1_fused 
// 32)
+                                v1 = T.axis.spatial(4096, ax2_0_0 * 32 + 
ax0_ax1_fused % 32)
+                                T.reads(A[v0, v1])
+                                T.writes(A_reindex_pad_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 8})
+                                A_reindex_pad_shared[v0, v1] = 
T.if_then_else(v1 < 4095, A[v0, v1], T.float16(0.0))
+                        for ax0_ax1_fused in range(2048):
+                            with T.block("B_reindex_pad_shared"):
+                                v0 = T.axis.spatial(4096, ax2_0_0 * 32 + 
ax0_ax1_fused // 64)
+                                v1 = T.axis.spatial(1024, 
ax0_0_0_ax1_0_0_fused % 16 * 64 + ax0_ax1_fused % 64)
+                                T.reads(B[v0, v1])
+                                T.writes(B_reindex_pad_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 1})
+                                B_reindex_pad_shared[v0, v1] = 
T.if_then_else(v0 < 4095, B[v0, v1], T.float16(0.0))
+                        for ax2_0_1 in range(2):
+                            for ax0_0, ax1_0 in T.grid(2, 1):
+                                with 
T.block("A_reindex_pad_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(64, 
ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 + 
ax0_0_2_ax1_0_2_fused * 2 + ax0_0)
+                                    v1_o = T.axis.spatial(256, ax2_0_0 * 2 + 
ax2_0_1 + ax1_0)
+                                    T.reads(A_reindex_pad_shared[v0_o * 
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_a_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with 
T.block("A_reindex_pad_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
+                                            T.reads(A_reindex_pad_shared[v0_o 
* 16 + v0_i, v1_o * 16 + v1_i])
+                                            
T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            
A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(1, 4):
+                                with 
T.block("B_reindex_pad_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(256, ax2_0_0 * 2 + 
ax2_0_1 + ax0_0)
+                                    v1_o = T.axis.spatial(64, 
ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0)
+                                    T.reads(B_reindex_pad_shared[v0_o * 
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_b_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with 
T.block("B_reindex_pad_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
+                                            T.reads(B_reindex_pad_shared[v0_o 
* 16 + v0_i, v1_o * 16 + v1_i])
+                                            
T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            
B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
B_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in 
T.grid(2, 1, 1, 1, 4):
+                                with T.block("C_o"):
+                                    v0_o = T.axis.spatial(64, 
ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 + 
ax0_0_2_ax1_0_2_fused * 2 + ax0_0_3 + ax0_0_4)
+                                    v1_o = T.axis.spatial(64, 
ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0_3 * 4 + ax1_0_4)
+                                    v2_o = T.axis.reduce(256, ax2_0_0 * 2 + 
ax2_0_1 + ax2_0_2)
+                                    
T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 
16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o 
% 4, 0:16, 0:16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", 
"warp_execution": 1})
+                                    with T.init():
+                                        for ax0_1, ax1_1 in T.grid(16, 16):
+                                            with T.block("C_init"):
+                                                v0_i_init, v1_i_init = 
T.axis.remap("SS", [ax0_1, ax1_1])
+                                                T.reads()
+                                                
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o 
% 4, v0_i_init, v1_i_init])
+                                                
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 
v0_i_init, v1_i_init] = T.float32(0.0)
+                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 
16):
+                                        with T.block("C"):
+                                            v0_i, v1_i, v2_i = 
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                            
T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o 
% 4, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 
16 + v2_i], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + 
v1_i])
+                                            
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o 
% 4, v0_i, v1_i])
+                                            
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 
v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, 
v1_o % 4, v0_i, v1_i] + T.Cast("float32", 
A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * 
T.Cast("float32", B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 
16 + v1_i])
+                for ax2 in range(2):
+                    for ax0_ax1_fused in T.thread_binding(4, 
thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 4):
+                            with 
T.block("C_reindex_shared_wmma.accumulator_o"):
+                                v0_o = T.axis.spatial(32, 
ax0_0_0_ax1_0_0_fused // 16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_fused)
+                                v1_o = T.axis.spatial(16, 
ax0_0_0_ax1_0_0_fused % 16)
+                                v2_o = T.axis.spatial(2, ax2 + ax2_1)
+                                v3_o = T.axis.spatial(4, ax3)
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                
T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
+                                T.writes(C_reindex_shared[v0_o, v1_o, v2_o, 
v3_o, 0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_store_16x16x16_f32_shared"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with 
T.block("C_reindex_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, 
ax5])
+                                        
T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
+                                        T.writes(C_reindex_shared[v0_o, v1_o, 
v2_o, v3_o, v4_i, v5_i])
+                                        C_reindex_shared[v0_o, v1_o, v2_o, 
v3_o, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 
v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
+                        with T.block("C_reindex_shared"):
+                            v0 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused // 
16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024)
+                            v1 = T.axis.spatial(16, ax0_0_0_ax1_0_0_fused % 16)
+                            v2 = T.axis.spatial(2, ax2)
+                            v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 
1024 // 256)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
+                            T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
+                            T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + 
v1 * 64])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 
3})
+                            C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64] 
= C_reindex_shared[v0, v1, v2, v3, v4, v5]
+    # fmt: on
+
+    decision_0 = [
+        ("SamplePerfectTile", [4, 2, 4, 2, 1]),
+        ("SamplePerfectTile", [16, 1, 1, 1, 4]),
+        ("SamplePerfectTile", [128, 2, 1]),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 0),
+    ]
+    mod = te.create_prim_func(
+        te_workload.matmul(
+            n=1024,
+            m=1024,
+            k=4095,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    actual = generate_design_space(
+        kind="cuda",
+        mod=mod,
+        target=tvm.target.Target("cuda --arch=sm_70"),
+        types=None,
+        sch_rules=[multi_level_tiling_tensor_core()]
+        + get_rules("cuda", ms.schedule_rule.AutoInline),
+    )
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[padded_matmul_no_padded_output_0],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to