wllvcxz opened a new issue, #14137:
URL: https://github.com/apache/tvm/issues/14137
I found that when tuning the fp16 tensorcore `dense_add` kernel, the tuning
fails on some shapes and the reported error is non-deterministic.
For example, when the workload is `N=1, M=1000, K=512`, the tuning fails,
however when `N=2`, the tuning succeeds.
There are two kinds of reported errors. From my observation, the
following error is reported more frequently:
<details>
<summary>Click me</summary>
```
2023-02-27 14:11:46 [INFO] Logging directory: /tmp/tmp71o3_ldv/logs
2023-02-27 14:11:46 [INFO] LocalBuilder: max_workers = 11
2023-02-27 14:11:47 [INFO] LocalRunner: max_workers = 1
2023-02-27 14:11:48 [INFO] [task_scheduler.cc:159] Initializing Task #0:
"main"
2023-02-27 14:11:48 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task
#0: "main"
Traceback (most recent call last):
File "bug_tune_dense_add.py", line 507, in <module>
test_tune_tir_matmul_cuda_tensor_core()
File "bug_tune_dense_add.py", line 195, in
test_tune_tir_matmul_cuda_tensor_core
database = tune_tir(
File
"/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tir_integration.py",
line 104, in tune_tir
return tune_tasks(
File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tune.py",
line 117, in tune_tasks
task_scheduler.tune(
File
"/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/task_scheduler/task_scheduler.py",
line 132, in tune
_ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member
File
"/mnt/disk5/wll/code/metaschedule/python/tvm/_ffi/_ctypes/packed_func.py", line
237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
6: TVMFuncCall
5: _ZN3tvm7runtime13PackedF
4: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler,
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>,
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler,
tvm::meta_schedule::TaskSchedulerNode, void,
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>,
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void
(tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_s
chedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int,
int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler,
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>,
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler,
tvm::meta_schedule::TaskSchedulerNode, void,
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>,
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule:
:Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void
(tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext,
void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler,
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>,
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}, std::__cxx11::basi
c_string<char, std::char_traits<char>, std::allocator<char>
>)::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const,
tvm::runtime::TVMRetValue) const
3:
tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext,
void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
2: tvm::meta_schedule::ReplayTraceNode::GenerateMeasureCandidates()
1: tvm::meta_schedule::ReplayTraceNode::State::GenerateMeasureCandidates()
0: tvm::support::parallel_for_dynamic(int, int, int, std::function<void
(int, int)> const&)
File "/mnt/disk5/wll/code/metaschedule/src/support/parallel_for.cc", line
128
RuntimeError: parallel_for_dynamic error with ScheduleError: (not rendered)
```
</details>
and the following error is reported with a lower frequency:
<details>
<summary>Click me</summary>
```
2023-02-27 14:20:13 [INFO] Logging directory: /tmp/tmputfxvrl5/logs
2023-02-27 14:20:13 [INFO] LocalBuilder: max_workers = 11
2023-02-27 14:20:14 [INFO] LocalRunner: max_workers = 1
2023-02-27 14:20:15 [INFO] [task_scheduler.cc:159] Initializing Task #0:
"main"
Traceback (most recent call last):
File "bug_tune_dense_add.py", line 507, in <module>
test_tune_tir_matmul_cuda_tensor_core()
File "bug_tune_dense_add.py", line 195, in
test_tune_tir_matmul_cuda_tensor_core
database = tune_tir(
File
"/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tir_integration.py",
line 104, in tune_tir
return tune_tasks(
File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tune.py",
line 117, in tune_tasks
task_scheduler.tune(
File
"/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/task_scheduler/task_scheduler.py",
line 132, in tune
_ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member
File
"/mnt/disk5/wll/code/metaschedule/python/tvm/_ffi/_ctypes/packed_func.py", line
237, in __call__
raise get_last_ffi_error()
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
9: TVMFuncCall
8: _ZN3tvm7runtime13PackedF
7: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler,
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>,
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler,
tvm::meta_schedule::TaskSchedulerNode, void,
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>,
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void
(tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_s
chedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int,
int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler,
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>,
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler,
tvm::meta_schedule::TaskSchedulerNode, void,
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>,
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule:
:Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void
(tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext,
void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler,
tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>,
tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}, std::__cxx11::basi
c_string<char, std::char_traits<char>, std::allocator<char>
>)::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const,
tvm::runtime::TVMRetValue) const
6:
tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext,
void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int,
tvm::meta_schedule::Builder, tvm::meta_schedule::Runner,
tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>,
tvm::runtime::Optional<tvm::meta_schedule::Database>,
tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
5:
tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule
const&)
4:
tvm::meta_schedule::MultiLevelTilingTensorCoreNode::Apply(tvm::tir::Schedule
const&, tvm::tir::BlockRV const&)
3:
tvm::meta_schedule::MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<tvm::meta_schedule::State,
std::allocator<tvm::meta_schedule::State> >)
2:
tvm::meta_schedule::MultiLevelTilingNode::AddReadReuse(tvm::meta_schedule::State)
const
1: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV const&,
tvm::tir::LoopRV const&, bool, int)
0: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&,
tvm::tir::LoopRV const&, bool, int)
ScheduleError: An error occurred in the schedule primitive 'compute-at'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(p0_handle: T.handle, p1_handle: T.handle, p2_handle: T.handle,
T_add_handle: T.handle):
T.func_attr({"global_symbol": "main", "layout_free_buffers": [1],
"tir.noalias": True})
p0 = T.match_buffer(p0_handle, (T.int64(8), T.int64(512)), "float16")
p1 = T.match_buffer(p1_handle, (T.int64(1000), T.int64(512)),
"float16")
p2 = T.match_buffer(p2_handle, (T.int64(8), T.int64(1000)),
"float16")
T_add = T.match_buffer(T_add_handle, (T.int64(8), T.int64(1000)),
"float16")
# tir.Block#0
with T.block("root"):
^^^^^^^^^^^^^^^^^^^^^
T.reads()
^^^^^^^^^
T.writes()
^^^^^^^^^^
T_matmul_NT = T.alloc_buffer((T.int64(8), T.int64(1000)),
"float16")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p0_reindex = T.alloc_buffer((T.int64(16), T.int64(512)),
"float16")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p1_reindex = T.alloc_buffer((T.int64(1008), T.int64(512)),
"float16")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT_reindex_shared_dyn = T.alloc_buffer((T.int64(16),
T.int64(1008)), "float16", scope="shared.dyn")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT_reindex_shared_dyn_wmma_accumulator =
T.alloc_buffer((T.int64(16), T.int64(1008)), "float16",
scope="wmma.accumulator")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p0_reindex_shared_dyn = T.alloc_buffer((T.int64(16),
T.int64(512)), "float16", scope="shared.dyn")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1 in range(T.int64(512)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("p0_reindex_reindex"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(T.int64(16), ax0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(T.int64(512), ax1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(p0[v0, v1])
^^^^^^^^^^^^^^^^^^^
T.writes(p0_reindex[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p0_reindex[v0, v1] = T.if_then_else(v0 < T.int64(8),
p0[v0, v1], T.float16(0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0 in range(T.int64(1008)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1 in range(T.int64(512)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("p1_reindex_reindex"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(T.int64(1008), ax0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(T.int64(512), ax1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(p1[v0, v1])
^^^^^^^^^^^^^^^^^^^
T.writes(p1_reindex[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p1_reindex[v0, v1] = T.if_then_else(v0 <
T.int64(1000), p1[v0, v1], T.float16(0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1 in range(T.int64(512)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("p0_reindex_shared.dyn"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(T.int64(16), ax0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(T.int64(512), ax1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(p0_reindex[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(p0_reindex_shared_dyn[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p0_reindex_shared_dyn[v0, v1] = p0_reindex[v0, v1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0_0_ax1_0_0_fused in T.thread_binding(T.int64(1),
thread="blockIdx.y"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0_1_ax1_0_1_fused in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0_2_ax1_0_2_fused in
T.thread_binding(T.int64(3), thread="threadIdx.y"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax2_0_0 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax2_0_1 in range(T.int64(32)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0_3 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_0_3 in range(T.int64(21)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax2_0_2 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0_4 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_0_4 in
range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with
T.block("T_matmul_NT_o"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0_o =
T.axis.spatial(T.int64(1), ax0_0_4 + ax0_0_3)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1_o =
T.axis.spatial(T.int64(63), ax1_0_4 + ax0_0_2_ax1_0_2_fused * T.int64(21) +
ax1_0_3)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v2_o =
T.axis.reduce(T.int64(32), ax2_0_0 * T.int64(32) + ax2_0_1 + ax2_0_2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(p0_reindex_shared_dyn[T.int64(0):T.int64(16), v2_o * T.int64(16):v2_o *
T.int64(16) + T.int64(16)], p1_reindex[v1_o * T.int64(16):v1_o * T.int64(16) +
T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[T.int64(0):T.int64(16),
v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_sync_16x16x16_f16f16f16_trans", "meta_schedule.auto_tensorize_init":
"wmma_fill_16x16x16_f16", "meta_schedule.thread_extent_high_inclusive": 1024,
"meta_schedule.thread_extent_low_inclusive": 1, "warp_execution": 1})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.init():
^^^^^^^^^^^^^^
for ax0_1 in
range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_1 in
range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with
T.block("T_matmul_NT_init"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0_i_init = T.axis.spatial(T.int64(16), ax0_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1_i_init = T.axis.spatial(T.int64(16), ax1_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads()
^^^^^^^^^
T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i_init, v1_o *
T.int64(16) + v1_i_init])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i_init, v1_o * T.int64(16) +
v1_i_init] = T.float16(0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_1 in
range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_1 in
range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax2_1 in
range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with
T.block("T_matmul_NT"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0_i
= T.axis.spatial(T.int64(16), ax0_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1_i
= T.axis.spatial(T.int64(16), ax1_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v2_i
= T.axis.reduce(T.int64(16), ax2_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o *
T.int64(16) + v1_i], p0_reindex_shared_dyn[v0_i, v2_o * T.int64(16) + v2_i],
p1_reindex[v1_o * T.int64(16) + v1_i, v2_o * T.int64(16) + v2_i])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o *
T.int64(16) + v1_i])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) +
v1_i] = T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o *
T.int64(16) + v1_i] + p0_reindex_shared_dyn[v0_i, v2_o * T.int64(16) + v2_i] *
p1_reindex[v1_o * T.int64(16) + v1_i, v2_o * T.int64(16) + v2_i]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_0 in range(T.int64(21)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with
T.block("T_matmul_NT_reindex_shared.dyn_wmma.accumulator_o"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0_o = T.axis.spatial(T.int64(1), ax0_0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1_o = T.axis.spatial(T.int64(63),
ax0_0_2_ax1_0_2_fused * T.int64(21) + ax1_0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[T.int64(0):T.int64(16),
v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_matmul_NT_reindex_shared_dyn[T.int64(0):T.int64(16), v1_o *
T.int64(16):v1_o * T.int64(16) + T.int64(16)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_store_16x16x16_f16_shared_dyn"})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_1 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_1 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with
T.block("T_matmul_NT_reindex_shared.dyn_wmma.accumulator"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0_i =
T.axis.spatial(T.int64(16), ax0_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1_i =
T.axis.spatial(T.int64(16), ax1_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o *
T.int64(16) + v1_i])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_matmul_NT_reindex_shared_dyn[v0_i, v1_o * T.int64(16) + v1_i])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT_reindex_shared_dyn[v0_i, v1_o * T.int64(16) + v1_i] =
T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_ax1_fused in range(T.int64(16128)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("T_matmul_NT_reindex_shared.dyn"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(T.int64(16), ax0_ax1_fused
// T.int64(1008))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(T.int64(1008), ax0_ax1_fused
% T.int64(1008))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.where(ax0_ax1_fused < T.int64(8056) and
ax0_ax1_fused % T.int64(1008) < T.int64(1000))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(T_matmul_NT_reindex_shared_dyn[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_matmul_NT[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.block_attr({"meta_schedule.cooperative_fetch":
1})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT[v0, v1] =
T_matmul_NT_reindex_shared_dyn[v0, v1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0 in range(T.int64(8)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1 in range(T.int64(1000)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("T_add"):
^^^^^^^^^^^^^^^^^^^^^^
v_ax0 = T.axis.spatial(T.int64(8), ax0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v_ax1 = T.axis.spatial(T.int64(1000), ax1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_add[v_ax0, v_ax1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] +
p2[v_ax0, v_ax1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: The scope tir.Block#0 is not a stage pipeline.
Definition of a scope that is a stage pipeline:
- The region cover property holds for every of its child blocks
- No write-after-read dependency or opaque dependency,
- only read-after-write and write-after-write are allowed
- All the statements in the scope are schedulable statements, i.e. Block and
For
```
</details>
I tried different `N`, and found that when `N=2, 4, 8, 12, 17, 18, 24` the
tuning still fails, but when `N=16, 32` it succeeds. I guess it may be because
of the alignment requirement of `m16n16k16` tensor core.
### Expected behavior
The tuning succeeds
### Environment
* OS: Ubuntu 20.04.3 LTS
* TVM version: main branch
* GPU: nvidia-a100
### Steps to reproduce
```python3
import tempfile
import os
import numpy as np
import tvm
import tvm.tir.tensor_intrin
from tvm import meta_schedule as ms
from tvm.meta_schedule import tune_tir
from tvm.meta_schedule.database import JSONDatabase
from tvm.target import Target
from tvm.tir import Schedule
from tvm.ir.transform import PassContext
from tvm.meta_schedule.testing import te_workload
from tvm import tir
from tvm.script import ir as I
from tvm.script import tir as T
@I.ir_module
class Module0:
@T.prim_func
def main(p0: T.Buffer((T.int64(1), T.int64(512)), "float16"), p1:
T.Buffer((T.int64(1000), T.int64(512)), "float16"), p2: T.Buffer((T.int64(1),
T.int64(1000)), "float16"), T_add: T.Buffer((T.int64(1), T.int64(1000)),
"float16")):
T.func_attr({"global_symbol": "main", "layout_free_buffers": [1],
"tir.noalias": True})
# with T.block("root"):
T_matmul_NT = T.alloc_buffer((T.int64(1), T.int64(1000)), "float16")
for i, j, k in T.grid(T.int64(1), T.int64(1000), T.int64(512)):
with T.block("T_matmul_NT"):
v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
T.reads(p0[v_i, v_k], p1[v_j, v_k])
T.writes(T_matmul_NT[v_i, v_j])
with T.init():
T_matmul_NT[v_i, v_j] = T.float16(0)
T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + p0[v_i, v_k]
* p1[v_j, v_k]
for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0,
v_ax1]
@I.ir_module
class Module1:
@T.prim_func
def main(p0: T.Buffer((T.int64(16), T.int64(512)), "float16"), p1:
T.Buffer((T.int64(1000), T.int64(512)), "float16"), p2: T.Buffer((T.int64(16),
T.int64(1000)), "float16"), T_add: T.Buffer((T.int64(16), T.int64(1000)),
"float16")):
T.func_attr({"global_symbol": "main", "layout_free_buffers": [1],
"tir.noalias": True})
# with T.block("root"):
T_matmul_NT = T.alloc_buffer((T.int64(16), T.int64(1000)), "float16")
for i, j, k in T.grid(T.int64(16), T.int64(1000), T.int64(512)):
with T.block("T_matmul_NT"):
v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
T.reads(p0[v_i, v_k], p1[v_j, v_k])
T.writes(T_matmul_NT[v_i, v_j])
with T.init():
T_matmul_NT[v_i, v_j] = T.float16(0)
T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + p0[v_i, v_k]
* p1[v_j, v_k]
for ax0, ax1 in T.grid(T.int64(16), T.int64(1000)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0,
v_ax1]
def tune_dense_add_cuda_tensor_core():
target = Target("nvidia/nvidia-a100")
with tempfile.TemporaryDirectory() as work_dir:
database = ms.database.Database.create(kind="json",
work_dir=work_dir)
mod = Module0 # failed
# mod = Module1 # success
database = tune_tir(
mod=mod,
target=target,
work_dir=work_dir,
num_trials_per_iter=10,
max_trials_global=10,
strategy="replay-trace",
# strategy="evolutionary",
database=database,
)
sch = ms.tir_integration.compile_tir(database, mod, target)
if sch is None:
print("No valid schedule found!")
else:
from tvm.contrib import nvcc
import numpy as np
ctx = tvm.cuda()
if nvcc.have_tensorcore(ctx.compute_version):
with tvm.transform.PassContext(config={"tir.use_async_copy":
1}):
func = tvm.build(sch.mod["main"], [], "cuda")
# print(func.imported_modules[0].get_source())
# print(sch.mod.script())
# print(sch.trace)
if __name__ == "__main__":
tune_dense_add_cuda_tensor_core()
```
### Triage
* tune:meta_schedule
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]