tsu-bin opened a new pull request, #17012:
URL: https://github.com/apache/tvm/pull/17012

   The script below can be used to reproduce the issue. You may need to run it 
several times, because sample_perfect_tile may sometimes hide the issue 
depending on the sampled decision.
   `
   in_type="float16"
   out_type="float16"
   BS = 100
   MM = 32
   NN = 32
   KK = 32
   
   def create_batch_matmul(
       b: int = BS, m: int = MM, n: int = NN, k: int = KK, in_dtype: str = 
in_type, out_dtype: str = out_type
   ) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
       A = te.placeholder((b, m, k), name="A", dtype=in_dtype)
       B = te.placeholder((b, n, k), name="B", dtype=in_dtype)
       C = topi.nn.batch_matmul(A, B)
       return (A, B, C)
   
   space = ms.space_generator.PostOrderApply(
           sch_rules="cuda-tensorcore",
           postprocs="cuda-tensorcore",
       )
   database = ms.tune_tir(
       mod=te.create_prim_func( create_batch_matmul () ),
       target=tvm.target.Target("cuda -arch=sm_89 
-max_shared_memory_per_block=49152 -max_threads_per_block=1024"),
       max_trials_global = 3,
       space=space,
       work_dir="./xb_demo/debug_batch_matmul/meta_db/test_batch_matmul0/",
   )
   `
   
   The error log looks like the following.
   `
     3: tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, 
tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt 
const&)#1}::operator()(tvm::tir::Stmt const&) const
           at /hostShare/tools/tvm_all/tvm-dev/src/tir/ir/stmt_functor.cc:210
     2: tvm::tir::ThreadBindingUnifier::VisitStmt_(tvm::tir::ForNode const*)
           at 
/hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:60
     1: tvm::tir::ThreadBindingUnifier::VisitStmt_(tvm::tir::ForNode const*)
           at 
/hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:63
     0: tvm::tir::Stmt 
tvm::tir::ThreadBindingUnifier::UnifyThreadBindingImpl<tvm::tir::ForNode>(tvm::tir::ForNode
 const*, tvm::tir::Var const&, tvm::tir::IterVar const&, tvm::Range const&)
           at 
/hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:112
     File "/hostShare/tools/tvm_all/tvm-dev/src/support/parallel_for.cc", line 
139
   RuntimeError: parallel_for_dynamic error with [22:30:41] 
/hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:112:
 Check failed: (ana.CanProveEqual(dom->extent, new_iter_var->dom->extent)) is 
false: ValueError: All loops that are bound to `threadIdx.y` should have the 
same extent. However, there are two loops with extent 12 and 4, which are not 
equal
   `
   I have also pasted the trace; you can replay it line by line and print out the 
loop extents to verify the inconsistency.
   `
   b0 = sch.get_block(name="T_batch_matmul_NT", func_name="main")
   sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", 
ann_val="SSSRRSRS")
   b1 = sch.reindex(block=b0, buffer=("write", 0))
   b2 = sch.reindex(block=b0, buffer=("read", 0))
   b3 = sch.reindex(block=b0, buffer=("read", 1))
   sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda v_b, 
v_i, v_k: (v_b, v_i, v_k,), pad_value=None, assume_injective_transform=True)
   sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda v_b, 
v_j, v_k: (v_b, v_k, v_j,), pad_value=None, assume_injective_transform=True)
   sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda v_b, 
v_i, v_j: (v_b, v_i, v_j,), pad_value=None, assume_injective_transform=True)
   sch.transform_block_layout(block=b1, index_map=lambda v_b, v_i, v_j: (v_b, 
v_i, v_j,))
   sch.transform_block_layout(block=b2, index_map=lambda v_b, v_i, v_k: (v_b, 
v_i, v_k,))
   sch.transform_block_layout(block=b3, index_map=lambda v_b, v_j, v_k: (v_b, 
v_k, v_j,))
   sch.transform_block_layout(block=b0, index_map=lambda v_b, v_i, v_j, v_k: 
(v_b, v_i, v_j, v_k,))
   l4, l5, l6, l7 = sch.get_loops(block=b0)
   l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True, 
disable_predication=False)
   l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True, 
disable_predication=False)
   l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True, 
disable_predication=False)
   l14, l15, l16, l17, l18, l19, l20 = sch.get_loops(block=b0)
   sch.reorder(l17, l19, l13, l11, l9)
   b21 = sch.blockize(target=l13, preserve_unit_iters=True)
   sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize", 
ann_val="wmma_sync_16x16x16_f16f16f16")
   sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize_init", 
ann_val="wmma_fill_16x16x16_f16")
   sch.annotate(block_or_loop=b21, ann_key="warp_execution", ann_val=1)
   l22, l23, l24, l25 = sch.get_loops(block=b21)
   v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l22, n=5, 
max_innermost_factor=4, decision=[11, 35, 3, 1, 1])
   l31, l32, l33, l34, l35 = sch.split(loop=l22, factors=[v26, v27, v28, v29, 
v30], preserve_unit_iters=True, disable_predication=False)
   v36, v37, v38, v39, v40 = sch.sample_perfect_tile(loop=l23, n=5, 
max_innermost_factor=4, decision=[1, 2, 1, 2, 1])
   l41, l42, l43, l44, l45 = sch.split(loop=l23, factors=[v36, v37, v38, v39, 
v40], preserve_unit_iters=True, disable_predication=False)
   v46, v47, v48, v49, v50 = sch.sample_perfect_tile(loop=l24, n=5, 
max_innermost_factor=4, decision=[1, 1, 4, 1, 1])
   l51, l52, l53, l54, l55 = sch.split(loop=l24, factors=[v46, v47, v48, v49, 
v50], preserve_unit_iters=True, disable_predication=False)
   v56, v57, v58 = sch.sample_perfect_tile(loop=l25, n=3, 
max_innermost_factor=4, decision=[1, 2, 2])
   l59, l60, l61 = sch.split(loop=l25, factors=[v56, v57, v58], 
preserve_unit_iters=True, disable_predication=False)
   sch.reorder(l31, l41, l51, l32, l42, l52, l33, l43, l53, l59, l60, l34, l44, 
l54, l61, l35, l45, l55)
   l62 = sch.fuse(l31, l41, l51, preserve_unit_iters=True)
   sch.bind(loop=l62, thread_axis="blockIdx.y")
   l63 = sch.fuse(l32, l42, l52, preserve_unit_iters=True)
   sch.bind(loop=l63, thread_axis="blockIdx.x")
   l64 = sch.fuse(l33, l43, l53, preserve_unit_iters=True)
   sch.bind(loop=l64, thread_axis="threadIdx.y")
   sch.annotate(block_or_loop=b21, 
ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)
   sch.annotate(block_or_loop=b21, 
ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)
   sch.transform_layout(block=b21, buffer=("write", 0), index_map=lambda i0, 
i1, i2: (i0, i1 // 16 // (v39 * v40), i2 // 16 // (v49 * v50), i1 // 16 % (v39 
* v40), i2 // 16 % (v49 * v50), i1 % 16, i2 % 16,), pad_value=None, 
assume_injective_transform=True)
   b65 = sch.cache_write(block=b21, write_buffer_index=0, 
storage_scope="shared")
   sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True, 
index=-1)
   b66 = sch.cache_write(block=b21, write_buffer_index=0, 
storage_scope="wmma.accumulator")
   l67, l68, l69, l70, l71, l72, l73, l74, l75 = sch.get_loops(block=b65)
   sch.reorder(l72, l70, l71, l73)
   sch.compute_at(block=b66, loop=l72, preserve_unit_loops=True, index=-1)
   l76, l77, l78, l79, l80, l81, l82, l83, l84, l85, l86 = 
sch.get_loops(block=b66)
   l87 = sch.fuse(l81, l82, preserve_unit_iters=True)
   sch.bind(loop=l87, thread_axis="threadIdx.y")
   sch.reverse_compute_inline(block=b1)
   l88, l89, l90, l91, l92, l93, l94, l95, l96, l97 = sch.get_loops(block=b66)
   b98 = sch.blockize(target=l96, preserve_unit_iters=True)
   sch.annotate(block_or_loop=b98, ann_key="meta_schedule.auto_tensorize", 
ann_val="wmma_store_16x16x16_f16_shared")
   l99, l100, l101, l102, l103, l104, l105, l106, l107 = 
sch.get_loops(block=b65)
   l108 = sch.fuse(l103, l104, l105, l106, l107, preserve_unit_iters=True)
   v109 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 
0.25, 0.25], decision=2)
   sch.annotate(block_or_loop=b65, ann_key="meta_schedule.cooperative_fetch", 
ann_val=v109)
   b110 = sch.cache_read(block=b21, read_buffer_index=0, 
storage_scope="shared", consumer_blocks=[b21])
   sch.compute_at(block=b110, loop=l59, preserve_unit_loops=True, index=-1)
   l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110)
   l118 = sch.fuse(l115, l116, l117, preserve_unit_iters=True)
   v119 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 
0.25, 0.25], decision=1)
   sch.annotate(block_or_loop=b110, ann_key="meta_schedule.cooperative_fetch", 
ann_val=v119)
   b120 = sch.cache_read(block=b21, read_buffer_index=1, 
storage_scope="shared", consumer_blocks=[b21])
   sch.compute_at(block=b120, loop=l59, preserve_unit_loops=True, index=-1)
   l121, l122, l123, l124, l125, l126, l127 = sch.get_loops(block=b120)
   l128 = sch.fuse(l125, l126, l127, preserve_unit_iters=True)
   v129 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 
0.25, 0.25], decision=2)
   sch.annotate(block_or_loop=b120, ann_key="meta_schedule.cooperative_fetch", 
ann_val=v129)
   b130 = sch.cache_read(block=b21, read_buffer_index=0, 
storage_scope="wmma.matrix_a")
   sch.compute_at(block=b130, loop=l60, preserve_unit_loops=True, index=-1)
   l131, l132, l133, l134, l135, l136, l137, l138 = sch.get_loops(block=b130)
   l139, l140 = sch.split(loop=l138, factors=[None, 16], 
preserve_unit_iters=True, disable_predication=False)
   l141, l142 = sch.split(loop=l137, factors=[None, 16], 
preserve_unit_iters=True, disable_predication=False)
   l143, l144, l145, l146, l147, l148, l149, l150, l151, l152 = 
sch.get_loops(block=b130)
   sch.reorder(l151, l142, l140)
   b153 = sch.blockize(target=l142, preserve_unit_iters=True)
   sch.annotate(block_or_loop=b153, ann_key="meta_schedule.auto_tensorize", 
ann_val="wmma_load_16x16x16_f16_a_shared")
   b154 = sch.cache_read(block=b21, read_buffer_index=1, 
storage_scope="wmma.matrix_b")
   sch.compute_at(block=b154, loop=l60, preserve_unit_loops=True, index=-1)
   l155, l156, l157, l158, l159, l160, l161, l162 = sch.get_loops(block=b154)
   l163, l164 = sch.split(loop=l162, factors=[None, 16], 
preserve_unit_iters=True, disable_predication=False)
   l165, l166 = sch.split(loop=l161, factors=[None, 16], 
preserve_unit_iters=True, disable_predication=False)
   l167, l168, l169, l170, l171, l172, l173, l174, l175, l176 = 
sch.get_loops(block=b154)
   sch.reorder(l175, l166, l164)
   b177 = sch.blockize(target=l166, preserve_unit_iters=True)
   sch.annotate(block_or_loop=b177, ann_key="meta_schedule.auto_tensorize", 
ann_val="wmma_load_16x16x16_f16_b_shared")
   b178, = sch.get_producers(block=b110)
   sch.compute_inline(block=b178)
   sch.storage_align(block=b110, buffer_index=0, axis=-2, factor=32, offset=8)
   b179, = sch.get_producers(block=b120)
   sch.compute_inline(block=b179)
   sch.storage_align(block=b120, buffer_index=0, axis=-2, factor=32, offset=8)
   
   `
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to