wllvcxz opened a new issue, #14112:
URL: https://github.com/apache/tvm/issues/14112
I tuned an int8 conv2d kernel with MetaSchedule and saved its
schedule (`apply_trace`), generated by `print(sch.trace)`. When I run the
saved `apply_trace` again, I get the following error:
```
7: TVMFuncCall
6: _ZN3tvm7runtime13PackedF
5: tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule,
tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&,
tvm::runtime::Optional<tvm::tir::IndexMap>
const&)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule,
tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&,
tvm::runtime::Optional<tvm::tir::IndexMap>
const&)#16}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&,
int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap>
const&)#16}, std::__cxx11::basic_string<char, std::char_traits<char>,
std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const,
tvm::runtime::TVMRetValue) const [clone .isra.0]
4: tvm::tir::TracedScheduleNode::TransformLayout(tvm::tir::BlockRV const&,
int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&,
tvm::runtime::Optional<tvm::tir::IndexMap> const&)
3: tvm::tir::ConcreteScheduleNode::TransformLayout(tvm::tir::BlockRV
const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&,
tvm::runtime::Optional<tvm::tir::IndexMap> const&)
2: tvm::tir::TransformLayout(tvm::tir::ScheduleState, tvm::tir::StmtSRef
const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&,
tvm::runtime::Optional<tvm::tir::IndexMap> const&)
1:
tvm::tir::IndexMap::NonSurjectiveInverse(tvm::runtime::Array<tvm::Range, void>)
const
0: tvm::tir::IndexMapInverseImpl(tvm::tir::IndexMap const&,
tvm::runtime::Array<tvm::Range, void> const&, tvm::arith::IterMapLevel)
File "/home/wll/code/metaschedule/src/tir/ir/index_map.cc", line 95
TVMError: Check failed: (padded_iter_map->errors.empty()) is false: Could
not parse mapping as sum of iterators. Error: IterMapExpr or subclasses should
only result from calls in IterMapRewriter using DirectMutate. Indirect return
occurred in v_nn
```
The error occurs on this line: `sch.transform_block_layout(block=b3, index_map=lambda v_nn,
v_yy, v_xx, v_ff: (v_nn * T.int64(64) + v_yy * T.int64(8) + v_xx, v_ff,))`.
It seems that there is a problem with the generated `apply_trace`.
### Environment
OS: Ubuntu 20.04.3 LTS
TVM version: main branch
GPU: nvidia-a100
### Steps to reproduce
```python3
import logging
import tempfile
import os
import numpy as np
import tvm
import tvm.tir.tensor_intrin
from tvm import tir
from tvm import relay
from tvm import meta_schedule as ms
from tvm.meta_schedule import tune_tir
from tvm.meta_schedule.database import JSONDatabase
from tvm.target import Target
from tvm.tir import Schedule
from tvm.ir.transform import PassContext
from tvm.script import ir as I
from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(p0: T.Buffer((T.int64(1), T.int64(16), T.int64(16),
T.int64(128)), "int8"), p1: T.Buffer((T.int64(128), T.int64(3), T.int64(3),
T.int64(128)), "int8"), conv2d_nhwc: T.Buffer((T.int64(1), T.int64(8),
T.int64(8), T.int64(128)), "int32")):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# with T.block("root"):
pad_temp = T.alloc_buffer((T.int64(1), T.int64(18), T.int64(18),
T.int64(128)), "int8")
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(18), T.int64(18),
T.int64(128)):
with T.block("pad_temp"):
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
T.reads(p0[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3])
T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1)
<= v_i1 and v_i1 < T.int64(17) and T.int64(1) <= v_i2 and v_i2 < T.int64(17),
p0[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], T.int8(0))
for nn, yy, xx, ff, ry, rx, rc in T.grid(T.int64(1), T.int64(8),
T.int64(8), T.int64(128), T.int64(3), T.int64(3), T.int64(128)):
with T.block("conv2d_nhwc"):
v_nn, v_yy, v_xx, v_ff, v_ry, v_rx, v_rc =
T.axis.remap("SSSSRRR", [nn, yy, xx, ff, ry, rx, rc])
T.reads(pad_temp[v_nn, v_yy * T.int64(2) + v_ry, v_xx *
T.int64(2) + v_rx, v_rc], p1[v_ff, v_ry, v_rx, v_rc])
T.writes(conv2d_nhwc[v_nn, v_yy, v_xx, v_ff])
with T.init():
conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] = 0
conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] = conv2d_nhwc[v_nn,
v_yy, v_xx, v_ff] + T.Cast("int32", pad_temp[v_nn, v_yy * T.int64(2) + v_ry,
v_xx * T.int64(2) + v_rx, v_rc]) * T.Cast("int32", p1[v_ff, v_ry, v_rx, v_rc])
def tune_conv2d_add():
target = Target("nvidia/nvidia-a100")
with tempfile.TemporaryDirectory() as work_dir:
database = ms.database.Database.create(kind="json",
work_dir=work_dir)
mod = Module
database = tune_tir(
mod=mod,
target=target,
work_dir=work_dir,
num_trials_per_iter=32,
max_trials_global=32,
strategy="replay-trace",
# strategy="evolutionary",
database=database,
)
sch = ms.tir_integration.compile_tir(database, mod, target)
if sch is None:
print("No valid schedule found!")
else:
from tvm.contrib import nvcc
ctx = tvm.cuda()
if nvcc.have_tensorcore(ctx.compute_version):
with tvm.transform.PassContext(config={"tir.use_async_copy":
1}):
func = tvm.build(sch.mod["main"], [], "cuda")
# print(func.imported_modules[0].get_source())
# print(sch.mod.script())
print(sch.trace)
# generated by tune_conv2d_add
def apply_trace(sch: tir.Schedule) -> None:
b0 = sch.get_block(name="pad_temp", func_name="main")
b1 = sch.get_block(name="conv2d_nhwc", func_name="main")
b2 = sch.get_block(name="root", func_name="main")
sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure",
ann_val="SSSRRSRS")
b3 = sch.reindex(block=b1, buffer=("write", 0))
b4 = sch.reindex(block=b1, buffer=("read", 0))
b5 = sch.reindex(block=b1, buffer=("read", 1))
sch.transform_layout(block=b1, buffer=("read", 0), index_map=lambda
v_nn, v_yy, v_xx, v_ry, v_rx, v_rc: (v_nn * T.int64(64) + v_yy * T.int64(8) +
v_xx, v_ry * T.int64(384) + v_rx * T.int64(128) + v_rc,), pad_value=None)
sch.transform_layout(block=b1, buffer=("read", 1), index_map=lambda
v_ff, v_ry, v_rx, v_rc: (v_ff, v_ry * T.int64(384) + v_rx * T.int64(128) +
v_rc,), pad_value=None)
sch.transform_layout(block=b1, buffer=("write", 0), index_map=lambda
v_nn, v_yy, v_xx, v_ff: (v_nn * T.int64(64) + v_yy * T.int64(8) + v_xx, v_ff,),
pad_value=None)
sch.transform_block_layout(block=b3, index_map=lambda v_nn, v_yy, v_xx,
v_ff: (v_nn * T.int64(64) + v_yy * T.int64(8) + v_xx, v_ff,))
sch.transform_block_layout(block=b4, index_map=lambda v_nn, v_yy, v_xx,
v_ry, v_rx, v_rc: (v_nn * T.int64(64) + v_yy * T.int64(8) + v_xx, v_ry *
T.int64(384) + v_rx * T.int64(128) + v_rc,))
sch.transform_block_layout(block=b5, index_map=lambda v_ff, v_ry, v_rx,
v_rc: (v_ff, v_ry * T.int64(384) + v_rx * T.int64(128) + v_rc,))
sch.transform_block_layout(block=b1, index_map=lambda v_nn, v_yy, v_xx,
v_ff, v_ry, v_rx, v_rc: (v_nn * T.int64(64) + v_yy * T.int64(8) + v_xx, v_ff,
v_ry * T.int64(384) + v_rx * T.int64(128) + v_rc,))
l6, l7, l8 = sch.get_loops(block=b1)
l9, l10 = sch.split(loop=l8, factors=[None, 16],
preserve_unit_iters=True)
l11, l12 = sch.split(loop=l7, factors=[None, 16],
preserve_unit_iters=True)
l13, l14 = sch.split(loop=l6, factors=[None, 16],
preserve_unit_iters=True)
l15, l16, l17, l18, l19, l20 = sch.get_loops(block=b1)
sch.reorder(l17, l19, l14, l12, l10)
b21 = sch.blockize(loop=l14, preserve_unit_iters=True)
sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize",
ann_val="wmma_sync_16x16x16_s8s8s32_trans")
sch.annotate(block_or_loop=b21,
ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_s32")
sch.annotate(block_or_loop=b21, ann_key="warp_execution", ann_val=1)
l22, l23, l24 = sch.get_loops(block=b21)
v25, v26, v27, v28, v29 = sch.sample_perfect_tile(loop=l22, n=5,
max_innermost_factor=4, decision=[2, 1, 1, 1, 2])
l30, l31, l32, l33, l34 = sch.split(loop=l22, factors=[v25, v26, v27,
v28, v29], preserve_unit_iters=True)
v35, v36, v37, v38, v39 = sch.sample_perfect_tile(loop=l23, n=5,
max_innermost_factor=4, decision=[2, 4, 1, 1, 1])
l40, l41, l42, l43, l44 = sch.split(loop=l23, factors=[v35, v36, v37,
v38, v39], preserve_unit_iters=True)
v45, v46, v47 = sch.sample_perfect_tile(loop=l24, n=3,
max_innermost_factor=4, decision=[18, 1, 4])
l48, l49, l50 = sch.split(loop=l24, factors=[v45, v46, v47],
preserve_unit_iters=True)
sch.reorder(l30, l40, l31, l41, l32, l42, l48, l49, l33, l43, l50, l34,
l44)
l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
sch.bind(loop=l51, thread_axis="blockIdx.y")
l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
sch.bind(loop=l52, thread_axis="blockIdx.x")
l53 = sch.fuse(l32, l42, preserve_unit_iters=True)
sch.bind(loop=l53, thread_axis="threadIdx.y")
sch.annotate(block_or_loop=b21,
ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=1)
sch.annotate(block_or_loop=b21,
ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)
b54 = sch.cache_write(block=b21, write_buffer_index=0,
storage_scope="shared.dyn")
sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True,
index=-1)
b55 = sch.cache_write(block=b21, write_buffer_index=0,
storage_scope="wmma.accumulator")
sch.reverse_compute_at(block=b55, loop=l53, preserve_unit_loops=True,
index=-1)
l56, l57, l58, l59 = sch.get_loops(block=b54)
l60 = sch.fuse(l58, l59, preserve_unit_iters=True)
v61 = sch.sample_categorical(candidates=[1, 2, 3, 4, 8, 16],
probs=[0.16666666666666666, 0.16666666666666666, 0.16666666666666666,
0.16666666666666666, 0.16666666666666666, 0.16666666666666666], decision=0)
sch.annotate(block_or_loop=b54,
ann_key="meta_schedule.cooperative_fetch", ann_val=v61)
sch.reverse_compute_inline(block=b3)
l62, l63, l64, l65, l66 = sch.get_loops(block=b55)
l67, l68 = sch.split(loop=l66, factors=[None, 16],
preserve_unit_iters=True)
l69, l70 = sch.split(loop=l65, factors=[None, 16],
preserve_unit_iters=True)
l71, l72, l73, l74, l75, l76, l77 = sch.get_loops(block=b55)
sch.reorder(l76, l70, l68)
b78 = sch.blockize(loop=l70, preserve_unit_iters=True)
sch.annotate(block_or_loop=b78, ann_key="meta_schedule.auto_tensorize",
ann_val="wmma_store_16x16x16_s32_shared_dyn")
b79 = sch.cache_read(block=b21, read_buffer_index=0,
storage_scope="shared.dyn", consumer_blocks=[b21])
sch.compute_at(block=b79, loop=l48, preserve_unit_loops=True, index=-1)
l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
v87 = sch.sample_categorical(candidates=[1, 2, 3, 4, 8, 16],
probs=[0.16666666666666666, 0.16666666666666666, 0.16666666666666666,
0.16666666666666666, 0.16666666666666666, 0.16666666666666666], decision=3)
sch.annotate(block_or_loop=b79,
ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
b88 = sch.cache_read(block=b21, read_buffer_index=1,
storage_scope="shared.dyn", consumer_blocks=[b21])
sch.compute_at(block=b88, loop=l48, preserve_unit_loops=True, index=-1)
l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b88)
l95 = sch.fuse(l93, l94, preserve_unit_iters=True)
v96 = sch.sample_categorical(candidates=[1, 2, 3, 4, 8, 16],
probs=[0.16666666666666666, 0.16666666666666666, 0.16666666666666666,
0.16666666666666666, 0.16666666666666666, 0.16666666666666666], decision=3)
sch.annotate(block_or_loop=b88,
ann_key="meta_schedule.cooperative_fetch", ann_val=v96)
b97 = sch.cache_read(block=b21, read_buffer_index=0,
storage_scope="wmma.matrix_a")
sch.compute_at(block=b97, loop=l49, preserve_unit_loops=True, index=-1)
l98, l99, l100, l101, l102, l103, l104 = sch.get_loops(block=b97)
l105, l106 = sch.split(loop=l104, factors=[None, 16],
preserve_unit_iters=True)
l107, l108 = sch.split(loop=l103, factors=[None, 16],
preserve_unit_iters=True)
l109, l110, l111, l112, l113, l114, l115, l116, l117 =
sch.get_loops(block=b97)
sch.reorder(l116, l108, l106)
b118 = sch.blockize(loop=l108, preserve_unit_iters=True)
sch.annotate(block_or_loop=b118, ann_key="meta_schedule.auto_tensorize",
ann_val="wmma_load_16x16x16_s8_a_shared_dyn")
b119 = sch.cache_read(block=b21, read_buffer_index=1,
storage_scope="wmma.matrix_b")
sch.compute_at(block=b119, loop=l49, preserve_unit_loops=True, index=-1)
l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b119)
l127, l128 = sch.split(loop=l126, factors=[None, 16],
preserve_unit_iters=True)
l129, l130 = sch.split(loop=l125, factors=[None, 16],
preserve_unit_iters=True)
l131, l132, l133, l134, l135, l136, l137, l138, l139 =
sch.get_loops(block=b119)
sch.reorder(l138, l130, l128)
b140 = sch.blockize(loop=l130, preserve_unit_iters=True)
sch.annotate(block_or_loop=b140, ann_key="meta_schedule.auto_tensorize",
ann_val="wmma_load_16x16x16_s8_b_trans_shared_dyn")
sch.compute_inline(block=b4)
sch.compute_inline(block=b5)
sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32,
offset=16)
sch.storage_align(block=b88, buffer_index=0, axis=-2, factor=32,
offset=16)
sch.compute_inline(block=b0)
v141 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024],
probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001,
0.20000000000000001, 0.20000000000000001], decision=4)
sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit",
ann_val=v141)
sch.enter_postproc()
sch.unannotate(block_or_loop=b54,
ann_key="meta_schedule.cooperative_fetch")
l142, l143, l144 = sch.get_loops(block=b54)
l145, l146, l147 = sch.split(loop=l144, factors=[None, 1, 32],
preserve_unit_iters=True)
sch.bind(loop=l147, thread_axis="threadIdx.x")
sch.bind(loop=l146, thread_axis="threadIdx.y")
sch.unannotate(block_or_loop=b79,
ann_key="meta_schedule.cooperative_fetch")
l148, l149, l150, l151, l152 = sch.get_loops(block=b79)
l153, l154, l155, l156 = sch.split(loop=l152, factors=[None, 1, 32, 4],
preserve_unit_iters=True)
sch.vectorize(loop=l156)
sch.bind(loop=l155, thread_axis="threadIdx.x")
sch.bind(loop=l154, thread_axis="threadIdx.y")
sch.unannotate(block_or_loop=b88,
ann_key="meta_schedule.cooperative_fetch")
l157, l158, l159, l160, l161 = sch.get_loops(block=b88)
l162, l163, l164, l165 = sch.split(loop=l161, factors=[None, 1, 32, 4],
preserve_unit_iters=True)
sch.vectorize(loop=l165)
sch.bind(loop=l164, thread_axis="threadIdx.x")
sch.bind(loop=l163, thread_axis="threadIdx.y")
b166 = sch.get_block(name="root", func_name="main")
sch.unannotate(block_or_loop=b166,
ann_key="meta_schedule.unroll_explicit")
b167, b168, b169, b170, b171, b172, b173 = sch.get_child_blocks(b166)
l174, l175, l176, l177, l178, l179, l180, l181 =
sch.get_loops(block=b167)
sch.annotate(block_or_loop=l174, ann_key="pragma_auto_unroll_max_step",
ann_val=1024)
sch.annotate(block_or_loop=l174, ann_key="pragma_unroll_explicit",
ann_val=1)
l182, l183, l184, l185, l186, l187, l188, l189 =
sch.get_loops(block=b168)
sch.annotate(block_or_loop=l182, ann_key="pragma_auto_unroll_max_step",
ann_val=1024)
sch.annotate(block_or_loop=l182, ann_key="pragma_unroll_explicit",
ann_val=1)
l190, l191, l192, l193, l194, l195, l196 = sch.get_loops(block=b169)
sch.annotate(block_or_loop=l190, ann_key="pragma_auto_unroll_max_step",
ann_val=1024)
sch.annotate(block_or_loop=l190, ann_key="pragma_unroll_explicit",
ann_val=1)
l197, l198, l199, l200, l201, l202, l203 = sch.get_loops(block=b170)
sch.annotate(block_or_loop=l197, ann_key="pragma_auto_unroll_max_step",
ann_val=1024)
sch.annotate(block_or_loop=l197, ann_key="pragma_unroll_explicit",
ann_val=1)
l204, l205, l206, l207, l208, l209, l210, l211, l212, l213 =
sch.get_loops(block=b171)
sch.annotate(block_or_loop=l204, ann_key="pragma_auto_unroll_max_step",
ann_val=1024)
sch.annotate(block_or_loop=l204, ann_key="pragma_unroll_explicit",
ann_val=1)
l214, l215, l216, l217, l218 = sch.get_loops(block=b172)
sch.annotate(block_or_loop=l214, ann_key="pragma_auto_unroll_max_step",
ann_val=1024)
sch.annotate(block_or_loop=l214, ann_key="pragma_unroll_explicit",
ann_val=1)
l219, l220, l221, l222, l223 = sch.get_loops(block=b173)
sch.annotate(block_or_loop=l219, ann_key="pragma_auto_unroll_max_step",
ann_val=1024)
sch.annotate(block_or_loop=l219, ann_key="pragma_unroll_explicit",
ann_val=1)
b224 = sch.get_block(name="conv2d_nhwc_o", func_name="main")
l225, l226, l227, l228, l229, l230, l231, l232, l233, l234 =
sch.get_loops(block=b224)
b235 = sch.decompose_reduction(block=b224, loop=l228)
sch.unannotate(block_or_loop=b235,
ann_key="meta_schedule.auto_tensorize")
sch.annotate(block_or_loop=b235, ann_key="meta_schedule.auto_tensorize",
ann_val="wmma_fill_16x16x16_s32")
sch.unannotate(block_or_loop=b224,
ann_key="meta_schedule.auto_tensorize_init")
sch.unannotate(block_or_loop=b235,
ann_key="meta_schedule.auto_tensorize_init")
b236 = sch.get_block(name="conv2d_nhwc_o_init", func_name="main")
sch.unannotate(block_or_loop=b236,
ann_key="meta_schedule.auto_tensorize")
sch.tensorize(block_or_loop=b236,
tensor_intrin="wmma_fill_16x16x16_s32", preserve_unit_iters=True)
b237 = sch.get_block(name="pad_temp_reindex_shared.dyn_wmma.matrix_a_o",
func_name="main")
sch.unannotate(block_or_loop=b237,
ann_key="meta_schedule.auto_tensorize")
sch.tensorize(block_or_loop=b237,
tensor_intrin="wmma_load_16x16x16_s8_a_shared_dyn", preserve_unit_iters=True)
b238 = sch.get_block(name="p1_reindex_shared.dyn_wmma.matrix_b_o",
func_name="main")
sch.unannotate(block_or_loop=b238,
ann_key="meta_schedule.auto_tensorize")
sch.tensorize(block_or_loop=b238,
tensor_intrin="wmma_load_16x16x16_s8_b_trans_shared_dyn",
preserve_unit_iters=True)
b239 = sch.get_block(name="conv2d_nhwc_o_update", func_name="main")
sch.unannotate(block_or_loop=b239,
ann_key="meta_schedule.auto_tensorize")
sch.tensorize(block_or_loop=b239,
tensor_intrin="wmma_sync_16x16x16_s8s8s32_trans", preserve_unit_iters=True)
b240 =
sch.get_block(name="conv2d_nhwc_reindex_shared.dyn_wmma.accumulator_o",
func_name="main")
sch.unannotate(block_or_loop=b240,
ann_key="meta_schedule.auto_tensorize")
sch.tensorize(block_or_loop=b240,
tensor_intrin="wmma_store_16x16x16_s32_shared_dyn", preserve_unit_iters=True)
def run_apply_trace():
mod = Module
sch = tvm.tir.Schedule(mod)
apply_trace(sch)
if __name__ == "__main__":
# tune_conv2d_add()
run_apply_trace()
```
### Triage
* tune:meta_schedule
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]