tsu-bin opened a new pull request, #17012:
URL: https://github.com/apache/tvm/pull/17012
The script below can be used to reproduce the issue. You may need to run it
several times, because sample_perfect_tile may sometimes hide the issue
depending on the sampled decision.
`
in_type="float16"
out_type="float16"
BS = 100
MM = 32
NN = 32
KK = 32
def create_batch_matmul(
b: int = BS, m: int = MM, n: int = NN, k: int = KK, in_dtype: str =
in_type, out_dtype: str = out_type
) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
A = te.placeholder((b, m, k), name="A", dtype=in_dtype)
B = te.placeholder((b, n, k), name="B", dtype=in_dtype)
C = topi.nn.batch_matmul(A, B)
return (A, B, C)
space = ms.space_generator.PostOrderApply(
sch_rules="cuda-tensorcore",
postprocs="cuda-tensorcore",
)
database = ms.tune_tir(
mod=te.create_prim_func( create_batch_matmul () ),
target=tvm.target.Target("cuda -arch=sm_89
-max_shared_memory_per_block=49152 -max_threads_per_block=1024"),
max_trials_global = 3,
space=space,
work_dir="./xb_demo/debug_batch_matmul/meta_db/test_batch_matmul0/",
)
`
The error log looks like the following.
`
3: tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*,
tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt
const&)#1}::operator()(tvm::tir::Stmt const&) const
at /hostShare/tools/tvm_all/tvm-dev/src/tir/ir/stmt_functor.cc:210
2: tvm::tir::ThreadBindingUnifier::VisitStmt_(tvm::tir::ForNode const*)
at
/hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:60
1: tvm::tir::ThreadBindingUnifier::VisitStmt_(tvm::tir::ForNode const*)
at
/hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:63
0: tvm::tir::Stmt
tvm::tir::ThreadBindingUnifier::UnifyThreadBindingImpl<tvm::tir::ForNode>(tvm::tir::ForNode
const*, tvm::tir::Var const&, tvm::tir::IterVar const&, tvm::Range const&)
at
/hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:112
File "/hostShare/tools/tvm_all/tvm-dev/src/support/parallel_for.cc", line
139
RuntimeError: parallel_for_dynamic error with [22:30:41]
/hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:112:
Check failed: (ana.CanProveEqual(dom->extent, new_iter_var->dom->extent)) is
false: ValueError: All loops that are bound to `threadIdx.y` should have the
same extent. However, there are two loops with extent 12 and 4, which are not
equal
`
I have also pasted the trace below; you can replay it line by line and print
the loop extents to verify the inconsistency.
`
b0 = sch.get_block(name="T_batch_matmul_NT", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure",
ann_val="SSSRRSRS")
b1 = sch.reindex(block=b0, buffer=("write", 0))
b2 = sch.reindex(block=b0, buffer=("read", 0))
b3 = sch.reindex(block=b0, buffer=("read", 1))
sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda v_b,
v_i, v_k: (v_b, v_i, v_k,), pad_value=None, assume_injective_transform=True)
sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda v_b,
v_j, v_k: (v_b, v_k, v_j,), pad_value=None, assume_injective_transform=True)
sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda v_b,
v_i, v_j: (v_b, v_i, v_j,), pad_value=None, assume_injective_transform=True)
sch.transform_block_layout(block=b1, index_map=lambda v_b, v_i, v_j: (v_b,
v_i, v_j,))
sch.transform_block_layout(block=b2, index_map=lambda v_b, v_i, v_k: (v_b,
v_i, v_k,))
sch.transform_block_layout(block=b3, index_map=lambda v_b, v_j, v_k: (v_b,
v_k, v_j,))
sch.transform_block_layout(block=b0, index_map=lambda v_b, v_i, v_j, v_k:
(v_b, v_i, v_j, v_k,))
l4, l5, l6, l7 = sch.get_loops(block=b0)
l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True,
disable_predication=False)
l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True,
disable_predication=False)
l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True,
disable_predication=False)
l14, l15, l16, l17, l18, l19, l20 = sch.get_loops(block=b0)
sch.reorder(l17, l19, l13, l11, l9)
b21 = sch.blockize(target=l13, preserve_unit_iters=True)
sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize",
ann_val="wmma_sync_16x16x16_f16f16f16")
sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize_init",
ann_val="wmma_fill_16x16x16_f16")
sch.annotate(block_or_loop=b21, ann_key="warp_execution", ann_val=1)
l22, l23, l24, l25 = sch.get_loops(block=b21)
v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l22, n=5,
max_innermost_factor=4, decision=[11, 35, 3, 1, 1])
l31, l32, l33, l34, l35 = sch.split(loop=l22, factors=[v26, v27, v28, v29,
v30], preserve_unit_iters=True, disable_predication=False)
v36, v37, v38, v39, v40 = sch.sample_perfect_tile(loop=l23, n=5,
max_innermost_factor=4, decision=[1, 2, 1, 2, 1])
l41, l42, l43, l44, l45 = sch.split(loop=l23, factors=[v36, v37, v38, v39,
v40], preserve_unit_iters=True, disable_predication=False)
v46, v47, v48, v49, v50 = sch.sample_perfect_tile(loop=l24, n=5,
max_innermost_factor=4, decision=[1, 1, 4, 1, 1])
l51, l52, l53, l54, l55 = sch.split(loop=l24, factors=[v46, v47, v48, v49,
v50], preserve_unit_iters=True, disable_predication=False)
v56, v57, v58 = sch.sample_perfect_tile(loop=l25, n=3,
max_innermost_factor=4, decision=[1, 2, 2])
l59, l60, l61 = sch.split(loop=l25, factors=[v56, v57, v58],
preserve_unit_iters=True, disable_predication=False)
sch.reorder(l31, l41, l51, l32, l42, l52, l33, l43, l53, l59, l60, l34, l44,
l54, l61, l35, l45, l55)
l62 = sch.fuse(l31, l41, l51, preserve_unit_iters=True)
sch.bind(loop=l62, thread_axis="blockIdx.y")
l63 = sch.fuse(l32, l42, l52, preserve_unit_iters=True)
sch.bind(loop=l63, thread_axis="blockIdx.x")
l64 = sch.fuse(l33, l43, l53, preserve_unit_iters=True)
sch.bind(loop=l64, thread_axis="threadIdx.y")
sch.annotate(block_or_loop=b21,
ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)
sch.annotate(block_or_loop=b21,
ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)
sch.transform_layout(block=b21, buffer=("write", 0), index_map=lambda i0,
i1, i2: (i0, i1 // 16 // (v39 * v40), i2 // 16 // (v49 * v50), i1 // 16 % (v39
* v40), i2 // 16 % (v49 * v50), i1 % 16, i2 % 16,), pad_value=None,
assume_injective_transform=True)
b65 = sch.cache_write(block=b21, write_buffer_index=0,
storage_scope="shared")
sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True,
index=-1)
b66 = sch.cache_write(block=b21, write_buffer_index=0,
storage_scope="wmma.accumulator")
l67, l68, l69, l70, l71, l72, l73, l74, l75 = sch.get_loops(block=b65)
sch.reorder(l72, l70, l71, l73)
sch.compute_at(block=b66, loop=l72, preserve_unit_loops=True, index=-1)
l76, l77, l78, l79, l80, l81, l82, l83, l84, l85, l86 =
sch.get_loops(block=b66)
l87 = sch.fuse(l81, l82, preserve_unit_iters=True)
sch.bind(loop=l87, thread_axis="threadIdx.y")
sch.reverse_compute_inline(block=b1)
l88, l89, l90, l91, l92, l93, l94, l95, l96, l97 = sch.get_loops(block=b66)
b98 = sch.blockize(target=l96, preserve_unit_iters=True)
sch.annotate(block_or_loop=b98, ann_key="meta_schedule.auto_tensorize",
ann_val="wmma_store_16x16x16_f16_shared")
l99, l100, l101, l102, l103, l104, l105, l106, l107 =
sch.get_loops(block=b65)
l108 = sch.fuse(l103, l104, l105, l106, l107, preserve_unit_iters=True)
v109 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25,
0.25, 0.25], decision=2)
sch.annotate(block_or_loop=b65, ann_key="meta_schedule.cooperative_fetch",
ann_val=v109)
b110 = sch.cache_read(block=b21, read_buffer_index=0,
storage_scope="shared", consumer_blocks=[b21])
sch.compute_at(block=b110, loop=l59, preserve_unit_loops=True, index=-1)
l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110)
l118 = sch.fuse(l115, l116, l117, preserve_unit_iters=True)
v119 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25,
0.25, 0.25], decision=1)
sch.annotate(block_or_loop=b110, ann_key="meta_schedule.cooperative_fetch",
ann_val=v119)
b120 = sch.cache_read(block=b21, read_buffer_index=1,
storage_scope="shared", consumer_blocks=[b21])
sch.compute_at(block=b120, loop=l59, preserve_unit_loops=True, index=-1)
l121, l122, l123, l124, l125, l126, l127 = sch.get_loops(block=b120)
l128 = sch.fuse(l125, l126, l127, preserve_unit_iters=True)
v129 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25,
0.25, 0.25], decision=2)
sch.annotate(block_or_loop=b120, ann_key="meta_schedule.cooperative_fetch",
ann_val=v129)
b130 = sch.cache_read(block=b21, read_buffer_index=0,
storage_scope="wmma.matrix_a")
sch.compute_at(block=b130, loop=l60, preserve_unit_loops=True, index=-1)
l131, l132, l133, l134, l135, l136, l137, l138 = sch.get_loops(block=b130)
l139, l140 = sch.split(loop=l138, factors=[None, 16],
preserve_unit_iters=True, disable_predication=False)
l141, l142 = sch.split(loop=l137, factors=[None, 16],
preserve_unit_iters=True, disable_predication=False)
l143, l144, l145, l146, l147, l148, l149, l150, l151, l152 =
sch.get_loops(block=b130)
sch.reorder(l151, l142, l140)
b153 = sch.blockize(target=l142, preserve_unit_iters=True)
sch.annotate(block_or_loop=b153, ann_key="meta_schedule.auto_tensorize",
ann_val="wmma_load_16x16x16_f16_a_shared")
b154 = sch.cache_read(block=b21, read_buffer_index=1,
storage_scope="wmma.matrix_b")
sch.compute_at(block=b154, loop=l60, preserve_unit_loops=True, index=-1)
l155, l156, l157, l158, l159, l160, l161, l162 = sch.get_loops(block=b154)
l163, l164 = sch.split(loop=l162, factors=[None, 16],
preserve_unit_iters=True, disable_predication=False)
l165, l166 = sch.split(loop=l161, factors=[None, 16],
preserve_unit_iters=True, disable_predication=False)
l167, l168, l169, l170, l171, l172, l173, l174, l175, l176 =
sch.get_loops(block=b154)
sch.reorder(l175, l166, l164)
b177 = sch.blockize(target=l166, preserve_unit_iters=True)
sch.annotate(block_or_loop=b177, ann_key="meta_schedule.auto_tensorize",
ann_val="wmma_load_16x16x16_f16_b_shared")
b178, = sch.get_producers(block=b110)
sch.compute_inline(block=b178)
sch.storage_align(block=b110, buffer_index=0, axis=-2, factor=32, offset=8)
b179, = sch.get_producers(block=b120)
sch.compute_inline(block=b179)
sch.storage_align(block=b120, buffer_index=0, axis=-2, factor=32, offset=8)
`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]