wllvcxz opened a new issue, #14137:
URL: https://github.com/apache/tvm/issues/14137

   I found that tuning the fp16 tensorcore `dense_add` kernel fails on some shapes, and the reported error is non-deterministic.
   
   For example, with the workload `N=1, M=1000, K=512` the tuning fails; however, with `N=16` it succeeds.
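   
   For reference, the `dense_add` workload here is a dense layer (matmul against a transposed weight) followed by a bias add. A minimal TE sketch of the computation, written only as an illustration with assumed names (the actual repro below uses the TVMScript modules), would be roughly:
   
   ```python
   import tvm
   from tvm import te
   
   def dense_add(N, M, K, dtype="float16"):
       # x @ w.T + b in fp16, matching the T_matmul_NT / T_add blocks in the repro
       x = te.placeholder((N, K), dtype=dtype, name="p0")
       w = te.placeholder((M, K), dtype=dtype, name="p1")
       b = te.placeholder((N, M), dtype=dtype, name="p2")
       k = te.reduce_axis((0, K), name="k")
       matmul = te.compute((N, M), lambda i, j: te.sum(x[i, k] * w[j, k], axis=k), name="T_matmul_NT")
       out = te.compute((N, M), lambda i, j: matmul[i, j] + b[i, j], name="T_add")
       return [x, w, b, out]
   ```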
   
   There are two kinds of reported errors. From my observation, the following one is reported more frequently:
   
   <details>
     <summary>Click me</summary>
     
   ```
   2023-02-27 14:11:46 [INFO] Logging directory: /tmp/tmp71o3_ldv/logs
   2023-02-27 14:11:46 [INFO] LocalBuilder: max_workers = 11
   2023-02-27 14:11:47 [INFO] LocalRunner: max_workers = 1
   2023-02-27 14:11:48 [INFO] [task_scheduler.cc:159] Initializing Task #0: 
"main"
   2023-02-27 14:11:48 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task 
#0: "main"
   Traceback (most recent call last):
     File "bug_tune_dense_add.py", line 507, in <module>
       test_tune_tir_matmul_cuda_tensor_core()
     File "bug_tune_dense_add.py", line 195, in 
test_tune_tir_matmul_cuda_tensor_core
       database = tune_tir(
     File 
"/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tir_integration.py", 
line 104, in tune_tir
       return tune_tasks(
     File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tune.py", 
line 117, in tune_tasks
       task_scheduler.tune(
     File 
"/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/task_scheduler/task_scheduler.py",
 line 132, in tune
       _ffi_api.TaskSchedulerTune(  # type: ignore # pylint: disable=no-member
     File 
"/mnt/disk5/wll/code/metaschedule/python/tvm/_ffi/_ctypes/packed_func.py", line 
237, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     6: TVMFuncCall
     5: _ZN3tvm7runtime13PackedF
     4: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler, 
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, 
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler,
 tvm::meta_schedule::TaskSchedulerNode, void, 
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, 
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void 
(tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_s
 chedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, 
int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler,
 tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, 
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler,
 tvm::meta_schedule::TaskSchedulerNode, void, 
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, 
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule:
 :Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void 
(tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext,
 void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler,
 tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, 
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}, std::__cxx11::basi
 c_string<char, std::char_traits<char>, std::allocator<char> 
>)::{lambda(tvm::runtime::TVMArgs const&, 
tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, 
tvm::runtime::TVMRetValue) const
     3: 
tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext,
 void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
     2: tvm::meta_schedule::ReplayTraceNode::GenerateMeasureCandidates()
     1: tvm::meta_schedule::ReplayTraceNode::State::GenerateMeasureCandidates()
     0: tvm::support::parallel_for_dynamic(int, int, int, std::function<void 
(int, int)> const&)
     File "/mnt/disk5/wll/code/metaschedule/src/support/parallel_for.cc", line 
128
   RuntimeError: parallel_for_dynamic error with ScheduleError: (not rendered)
   ```
   </details>
   
   and the following error is reported less frequently:
   
   <details>
     <summary>Click me</summary>
     
   ```
   2023-02-27 14:20:13 [INFO] Logging directory: /tmp/tmputfxvrl5/logs
   2023-02-27 14:20:13 [INFO] LocalBuilder: max_workers = 11
   2023-02-27 14:20:14 [INFO] LocalRunner: max_workers = 1
   2023-02-27 14:20:15 [INFO] [task_scheduler.cc:159] Initializing Task #0: 
"main"
   Traceback (most recent call last):
     File "bug_tune_dense_add.py", line 507, in <module>
       test_tune_tir_matmul_cuda_tensor_core()
     File "bug_tune_dense_add.py", line 195, in 
test_tune_tir_matmul_cuda_tensor_core
       database = tune_tir(
     File 
"/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tir_integration.py", 
line 104, in tune_tir
       return tune_tasks(
     File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tune.py", 
line 117, in tune_tasks
       task_scheduler.tune(
     File 
"/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/task_scheduler/task_scheduler.py",
 line 132, in tune
       _ffi_api.TaskSchedulerTune(  # type: ignore # pylint: disable=no-member
     File 
"/mnt/disk5/wll/code/metaschedule/python/tvm/_ffi/_ctypes/packed_func.py", line 
237, in __call__
       raise get_last_ffi_error()
   tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
     9: TVMFuncCall
     8: _ZN3tvm7runtime13PackedF
     7: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler, 
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, 
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler,
 tvm::meta_schedule::TaskSchedulerNode, void, 
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, 
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void 
(tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_s
 chedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, 
int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler,
 tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, 
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler,
 tvm::meta_schedule::TaskSchedulerNode, void, 
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, 
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule:
 :Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void 
(tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext,
 void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler,
 tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, 
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}, std::__cxx11::basi
 c_string<char, std::char_traits<char>, std::allocator<char> 
>)::{lambda(tvm::runtime::TVMArgs const&, 
tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, 
tvm::runtime::TVMRetValue) const
     6: 
tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext,
 void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, 
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, 
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, 
tvm::runtime::Optional<tvm::meta_schedule::Database>, 
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
     5: 
tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule 
const&)
     4: 
tvm::meta_schedule::MultiLevelTilingTensorCoreNode::Apply(tvm::tir::Schedule 
const&, tvm::tir::BlockRV const&)
     3: 
tvm::meta_schedule::MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<tvm::meta_schedule::State,
 std::allocator<tvm::meta_schedule::State> >)
     2: 
tvm::meta_schedule::MultiLevelTilingNode::AddReadReuse(tvm::meta_schedule::State)
 const
     1: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV const&, 
tvm::tir::LoopRV const&, bool, int)
     0: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&, 
tvm::tir::LoopRV const&, bool, int)
   ScheduleError: An error occurred in the schedule primitive 'compute-at'.
   The IR with diagnostic is:
   # from tvm.script import ir as I
   # from tvm.script import tir as T
   
   @I.ir_module
   class Module:
       @T.prim_func
       def main(p0_handle: T.handle, p1_handle: T.handle, p2_handle: T.handle, 
T_add_handle: T.handle):
           T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], 
"tir.noalias": True})
           p0 = T.match_buffer(p0_handle, (T.int64(8), T.int64(512)), "float16")
           p1 = T.match_buffer(p1_handle, (T.int64(1000), T.int64(512)), 
"float16")
           p2 = T.match_buffer(p2_handle, (T.int64(8), T.int64(1000)), 
"float16")
           T_add = T.match_buffer(T_add_handle, (T.int64(8), T.int64(1000)), 
"float16")
           # tir.Block#0
           with T.block("root"):
           ^^^^^^^^^^^^^^^^^^^^^
               T.reads()
               ^^^^^^^^^
               T.writes()
               ^^^^^^^^^^
               T_matmul_NT = T.alloc_buffer((T.int64(8), T.int64(1000)), 
"float16")
               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
               p0_reindex = T.alloc_buffer((T.int64(16), T.int64(512)), 
"float16")
               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
               p1_reindex = T.alloc_buffer((T.int64(1008), T.int64(512)), 
"float16")
               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
               T_matmul_NT_reindex_shared_dyn = T.alloc_buffer((T.int64(16), 
T.int64(1008)), "float16", scope="shared.dyn")
               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
               T_matmul_NT_reindex_shared_dyn_wmma_accumulator = 
T.alloc_buffer((T.int64(16), T.int64(1008)), "float16", 
scope="wmma.accumulator")
               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
               p0_reindex_shared_dyn = T.alloc_buffer((T.int64(16), 
T.int64(512)), "float16", scope="shared.dyn")
               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
               for ax0 in range(T.int64(16)):
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                   for ax1 in range(T.int64(512)):
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                       with T.block("p0_reindex_reindex"):
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           v0 = T.axis.spatial(T.int64(16), ax0)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           v1 = T.axis.spatial(T.int64(512), ax1)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           T.reads(p0[v0, v1])
                           ^^^^^^^^^^^^^^^^^^^
                           T.writes(p0_reindex[v0, v1])
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           p0_reindex[v0, v1] = T.if_then_else(v0 < T.int64(8), 
p0[v0, v1], T.float16(0))
                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
               for ax0 in range(T.int64(1008)):
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                   for ax1 in range(T.int64(512)):
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                       with T.block("p1_reindex_reindex"):
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           v0 = T.axis.spatial(T.int64(1008), ax0)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           v1 = T.axis.spatial(T.int64(512), ax1)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           T.reads(p1[v0, v1])
                           ^^^^^^^^^^^^^^^^^^^
                           T.writes(p1_reindex[v0, v1])
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           p1_reindex[v0, v1] = T.if_then_else(v0 < 
T.int64(1000), p1[v0, v1], T.float16(0))
                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
               for ax0 in range(T.int64(16)):
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                   for ax1 in range(T.int64(512)):
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                       with T.block("p0_reindex_shared.dyn"):
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           v0 = T.axis.spatial(T.int64(16), ax0)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           v1 = T.axis.spatial(T.int64(512), ax1)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           T.reads(p0_reindex[v0, v1])
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           T.writes(p0_reindex_shared_dyn[v0, v1])
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           p0_reindex_shared_dyn[v0, v1] = p0_reindex[v0, v1]
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
               for ax0_0_0_ax1_0_0_fused in T.thread_binding(T.int64(1), 
thread="blockIdx.y"):
               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                   for ax0_0_1_ax1_0_1_fused in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
                   
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                       for ax0_0_2_ax1_0_2_fused in 
T.thread_binding(T.int64(3), thread="threadIdx.y"):
                       
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           for ax2_0_0 in range(T.int64(1)):
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                               for ax2_0_1 in range(T.int64(32)):
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   for ax0_0_3 in range(T.int64(1)):
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                       for ax1_0_3 in range(T.int64(21)):
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                           for ax2_0_2 in range(T.int64(1)):
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                               for ax0_0_4 in range(T.int64(1)):
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                   for ax1_0_4 in 
range(T.int64(1)):
                                                   
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                       with 
T.block("T_matmul_NT_o"):
                                                       
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                           v0_o = 
T.axis.spatial(T.int64(1), ax0_0_4 + ax0_0_3)
                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                           v1_o = 
T.axis.spatial(T.int64(63), ax1_0_4 + ax0_0_2_ax1_0_2_fused * T.int64(21) + 
ax1_0_3)
                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                           v2_o = 
T.axis.reduce(T.int64(32), ax2_0_0 * T.int64(32) + ax2_0_1 + ax2_0_2)
                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                           
T.reads(p0_reindex_shared_dyn[T.int64(0):T.int64(16), v2_o * T.int64(16):v2_o * 
T.int64(16) + T.int64(16)], p1_reindex[v1_o * T.int64(16):v1_o * T.int64(16) + 
T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)])
                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                           
T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[T.int64(0):T.int64(16),
 v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)])
                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                           
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_sync_16x16x16_f16f16f16_trans", "meta_schedule.auto_tensorize_init": 
"wmma_fill_16x16x16_f16", "meta_schedule.thread_extent_high_inclusive": 1024, 
"meta_schedule.thread_extent_low_inclusive": 1, "warp_execution": 1})
                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                           with T.init():
                                                           ^^^^^^^^^^^^^^
                                                               for ax0_1 in 
range(T.int64(16)):
                                                               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                   for ax1_1 in 
range(T.int64(16)):
                                                                   
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                       with 
T.block("T_matmul_NT_init"):
                                                                       
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                           
v0_i_init = T.axis.spatial(T.int64(16), ax0_1)
                                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                           
v1_i_init = T.axis.spatial(T.int64(16), ax1_1)
                                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                           
T.reads()
                                                                           
^^^^^^^^^
                                                                           
T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i_init, v1_o * 
T.int64(16) + v1_i_init])
                                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                           
T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i_init, v1_o * T.int64(16) + 
v1_i_init] = T.float16(0)
                                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                           for ax0_1 in 
range(T.int64(16)):
                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                               for ax1_1 in 
range(T.int64(16)):
                                                               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                   for ax2_1 in 
range(T.int64(16)):
                                                                   
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                       with 
T.block("T_matmul_NT"):
                                                                       
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                           v0_i 
= T.axis.spatial(T.int64(16), ax0_1)
                                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                           v1_i 
= T.axis.spatial(T.int64(16), ax1_1)
                                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                           v2_i 
= T.axis.reduce(T.int64(16), ax2_1)
                                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                           
T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * 
T.int64(16) + v1_i], p0_reindex_shared_dyn[v0_i, v2_o * T.int64(16) + v2_i], 
p1_reindex[v1_o * T.int64(16) + v1_i, v2_o * T.int64(16) + v2_i])
                                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                           
T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * 
T.int64(16) + v1_i])
                                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                           
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                           
T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + 
v1_i] = T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * 
T.int64(16) + v1_i] + p0_reindex_shared_dyn[v0_i, v2_o * T.int64(16) + v2_i] * 
p1_reindex[v1_o * T.int64(16) + v1_i, v2_o * T.int64(16) + v2_i]
                                                                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           for ax0_0 in range(T.int64(1)):
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                               for ax1_0 in range(T.int64(21)):
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   with 
T.block("T_matmul_NT_reindex_shared.dyn_wmma.accumulator_o"):
                                   
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                       v0_o = T.axis.spatial(T.int64(1), ax0_0)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                       v1_o = T.axis.spatial(T.int64(63), 
ax0_0_2_ax1_0_2_fused * T.int64(21) + ax1_0)
                                       
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                       
T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[T.int64(0):T.int64(16), 
v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)])
                                       
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                       
T.writes(T_matmul_NT_reindex_shared_dyn[T.int64(0):T.int64(16), v1_o * 
T.int64(16):v1_o * T.int64(16) + T.int64(16)])
                                       
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                       
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_store_16x16x16_f16_shared_dyn"})
                                       
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                       for ax0_1 in range(T.int64(16)):
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                           for ax1_1 in range(T.int64(16)):
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                               with 
T.block("T_matmul_NT_reindex_shared.dyn_wmma.accumulator"):
                                               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                   v0_i = 
T.axis.spatial(T.int64(16), ax0_1)
                                                   
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                   v1_i = 
T.axis.spatial(T.int64(16), ax1_1)
                                                   
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                   
T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * 
T.int64(16) + v1_i])
                                                   
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                   
T.writes(T_matmul_NT_reindex_shared_dyn[v0_i, v1_o * T.int64(16) + v1_i])
                                                   
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                   
T_matmul_NT_reindex_shared_dyn[v0_i, v1_o * T.int64(16) + v1_i] = 
T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i]
                                                   
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                       for ax0_ax1_fused in range(T.int64(16128)):
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           with T.block("T_matmul_NT_reindex_shared.dyn"):
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                               v0 = T.axis.spatial(T.int64(16), ax0_ax1_fused 
// T.int64(1008))
                               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                               v1 = T.axis.spatial(T.int64(1008), ax0_ax1_fused 
% T.int64(1008))
                               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                               T.where(ax0_ax1_fused < T.int64(8056) and 
ax0_ax1_fused % T.int64(1008) < T.int64(1000))
                               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                               T.reads(T_matmul_NT_reindex_shared_dyn[v0, v1])
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                               T.writes(T_matmul_NT[v0, v1])
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                               T.block_attr({"meta_schedule.cooperative_fetch": 
1})
                               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                               T_matmul_NT[v0, v1] = 
T_matmul_NT_reindex_shared_dyn[v0, v1]
                               
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
               for ax0 in range(T.int64(8)):
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                   for ax1 in range(T.int64(1000)):
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                       with T.block("T_add"):
                       ^^^^^^^^^^^^^^^^^^^^^^
                           v_ax0 = T.axis.spatial(T.int64(8), ax0)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           v_ax1 = T.axis.spatial(T.int64(1000), ax1)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           T.writes(T_add[v_ax0, v_ax1])
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + 
p2[v_ax0, v_ax1]
   
                           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   Error message: The scope tir.Block#0 is not a stage pipeline.
   Definition of a scope that is a stage pipeline:
   - The region cover property holds for every of its child blocks
   - No write-after-read dependency or opaque dependency,
   - only read-after-write and write-after-write are allowed
   - All the statements in the scope are schedulable statements, i.e. Block and 
For
   
   ```
   
   </details>
   
   I tried different values of `N` and found that the tuning still fails for `N=2, 4, 8, 12, 17, 18, 24`, but succeeds for `N=16, 32`. I guess this is due to the alignment requirement of the `m16n16k16` tensor core intrinsic.
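   
   If the alignment guess is right, a possible workaround (only a rough sketch, which I have not verified against this issue) is to pad `N` up to the next multiple of 16 before constructing the workload, so that the `m16n16k16` wmma intrinsic tiles the batch dimension exactly, and then slice the extra rows off at runtime:
   
   ```python
   def pad_to_multiple(n: int, multiple: int = 16) -> int:
       # e.g. pad_to_multiple(1) == 16, pad_to_multiple(17) == 32
       return ((n + multiple - 1) // multiple) * multiple
   
   # Hypothetical usage: build the module with the padded batch size, then keep
   # only the first N rows of the output after running the kernel.
   N = 1
   padded_N = pad_to_multiple(N)  # 16, one of the shapes that tunes successfully
   ```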
   
   
   ### Expected behavior
   
   The tuning should succeed.
   
   ### Environment
   
   * OS: Ubuntu 20.04.3 LTS
   * TVM version: main branch
   * GPU: nvidia-a100
   
   ### Steps to reproduce
   
   ```python3
   import tempfile
   import os
   import numpy as np
   
   import tvm
   import tvm.tir.tensor_intrin
   from tvm import meta_schedule as ms
   from tvm.meta_schedule import tune_tir
   from tvm.meta_schedule.database import JSONDatabase
   from tvm.target import Target
   from tvm.tir import Schedule
   from tvm.ir.transform import PassContext
   from tvm.meta_schedule.testing import te_workload
   from tvm import tir
   from tvm.script import ir as I
   from tvm.script import tir as T
   
   
   @I.ir_module
   class Module0:
       @T.prim_func
       def main(p0: T.Buffer((T.int64(1), T.int64(512)), "float16"), p1: T.Buffer((T.int64(1000), T.int64(512)), "float16"), p2: T.Buffer((T.int64(1), T.int64(1000)), "float16"), T_add: T.Buffer((T.int64(1), T.int64(1000)), "float16")):
           T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], 
"tir.noalias": True})
           # with T.block("root"):
           T_matmul_NT = T.alloc_buffer((T.int64(1), T.int64(1000)), "float16")
           for i, j, k in T.grid(T.int64(1), T.int64(1000), T.int64(512)):
               with T.block("T_matmul_NT"):
                   v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                   T.reads(p0[v_i, v_k], p1[v_j, v_k])
                   T.writes(T_matmul_NT[v_i, v_j])
                   with T.init():
                       T_matmul_NT[v_i, v_j] = T.float16(0)
                   T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + p0[v_i, v_k] * p1[v_j, v_k]
           for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
               with T.block("T_add"):
                   v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
                   T.writes(T_add[v_ax0, v_ax1])
                   T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0, v_ax1]
   
   
   @I.ir_module
   class Module1:
       @T.prim_func
       def main(p0: T.Buffer((T.int64(16), T.int64(512)), "float16"), p1: T.Buffer((T.int64(1000), T.int64(512)), "float16"), p2: T.Buffer((T.int64(16), T.int64(1000)), "float16"), T_add: T.Buffer((T.int64(16), T.int64(1000)), "float16")):
           T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], 
"tir.noalias": True})
           # with T.block("root"):
           T_matmul_NT = T.alloc_buffer((T.int64(16), T.int64(1000)), "float16")
           for i, j, k in T.grid(T.int64(16), T.int64(1000), T.int64(512)):
               with T.block("T_matmul_NT"):
                   v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                   T.reads(p0[v_i, v_k], p1[v_j, v_k])
                   T.writes(T_matmul_NT[v_i, v_j])
                   with T.init():
                       T_matmul_NT[v_i, v_j] = T.float16(0)
                   T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + p0[v_i, v_k] * p1[v_j, v_k]
           for ax0, ax1 in T.grid(T.int64(16), T.int64(1000)):
               with T.block("T_add"):
                   v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
                   T.writes(T_add[v_ax0, v_ax1])
                   T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0, v_ax1]
   
   
   def tune_dense_add_cuda_tensor_core():
       target = Target("nvidia/nvidia-a100")
       with tempfile.TemporaryDirectory() as work_dir:
           database = ms.database.Database.create(kind="json", work_dir=work_dir)
           mod = Module0   # failed
           # mod = Module1   # success
           database = tune_tir(
               mod=mod,
               target=target,
               work_dir=work_dir,
               num_trials_per_iter=10,
               max_trials_global=10,
               strategy="replay-trace",
               # strategy="evolutionary",
               database=database,
           )
           sch = ms.tir_integration.compile_tir(database, mod, target)
           if sch is None:
               print("No valid schedule found!")
           else:
               from tvm.contrib import nvcc
               import numpy as np
               ctx = tvm.cuda()
               if nvcc.have_tensorcore(ctx.compute_version):
                   with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
                       func = tvm.build(sch.mod["main"], [], "cuda")
                       # print(func.imported_modules[0].get_source())
                       # print(sch.mod.script())
                       # print(sch.trace)
   
   
   if __name__ == "__main__":
       tune_dense_add_cuda_tensor_core()
   
   ```
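   
   In case it helps with verification, below is a rough sketch (not part of the original script; it assumes tuning succeeded, `func` was built as above, and uses the `Module1` shapes) of how the kernel output could be checked numerically against numpy:
   
   ```python
   import numpy as np
   import tvm
   
   N, M, K = 16, 1000, 512  # Module1 shapes
   dev = tvm.cuda()
   p0 = tvm.nd.array(np.random.uniform(size=(N, K)).astype("float16"), dev)
   p1 = tvm.nd.array(np.random.uniform(size=(M, K)).astype("float16"), dev)
   p2 = tvm.nd.array(np.random.uniform(size=(N, M)).astype("float16"), dev)
   out = tvm.nd.array(np.zeros((N, M), dtype="float16"), dev)
   func(p0, p1, p2, out)  # `func` built by tvm.build(...) in the script above
   ref = p0.numpy().astype("float32") @ p1.numpy().astype("float32").T + p2.numpy().astype("float32")
   np.testing.assert_allclose(out.numpy().astype("float32"), ref, rtol=1e-2, atol=1e-2)
   ```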
   
   
   ### Triage
   * tune:meta_schedule
   

