tkonolige opened a new issue, #12907:
URL: https://github.com/apache/tvm/issues/12907

   Running this script:
   
   ```
   import tvm
   import tvm.meta_schedule
   from tvm.script import tir as T
   
   @T.prim_func
   def matmul(
       A: T.Buffer[(512, 512), "float32"],
       B: T.Buffer[(512, 512), "float32"],
       C: T.Buffer[(512, 512), "float32"],
       )->None:
       for i in range(512):
           for j in range(512):
               for k in range(512):
                   with T.block("update"):
                       C[i, j] = C[i, j] + A[i, k] * B[j, k]
   
   s = tvm.meta_schedule.tune_tir(
       matmul, "llvm --num-cores 1", tvm.meta_schedule.TuneConfig(1, 1), "tmp"
   )
   ```
   
   repeatedly outputs the following warning:
   
   ```
   tvm/src/meta_schedule/search_strategy/../utils.h:321: Warning: ThreadedTraceApply::Apply failed with error ScheduleError: (not rendered)
   ```
   
   After changing `utils.h` to print the full error, I get the following output:
   
   ```
   [09:27:38] /home/tristan/octoml/tvm/src/meta_schedule/search_strategy/../utils.h:321: Warning: ThreadedTraceApply::Apply failed with error ScheduleError: An error occurred in the schedule primitive 'parallel'.
   
   The IR with diagnostic is:
   # from tvm.script import tir as T
   @tvm.script.ir_module
   class Module:
       @T.prim_func
       def main(A: T.Buffer[(512, 512), "float32"], B: T.Buffer[(512, 512), "float32"], C: T.Buffer[(512, 512), "float32"]) -> None:
           # function attr dict
           T.func_attr({"tir.noalias": True, "global_symbol": "main"})
           # body
           # with T.block("root")
           # tir.For#0
           for i_fused in T.serial(512):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
               for j, k in T.grid(512, 512):
                   # tir.Block#1
                   with T.block("update"):
                   ^^^^^^^^^^^^^^^^^^^^^^^
                       T.reads(C[i_fused, j], A[i_fused, k], B[j, k])
                       T.writes(C[i_fused, j])
                       C[i_fused, j] = C[i_fused, j] + A[i_fused, k] * B[j, k]
   
   Error message: The queried subtree root tir.For#0 in SRef tree does not have compact dataflow, because its child block tir.Block#1 on SRef tree is neither a local complete block nor a local reduction block.
   It violates condition #3 as a local complete block.
   Definition of a local complete block:
   1) All block vars are data parallel
   2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
   3) No overlap between the buffers the block reads and writes
   It violates condition #1 as a local reduction block.
   Definition of a reduction block:
   1) The block has the `init` statement
   2) All the block bindings are quasi-affine expressions
   3) All block vars are either data parallel block vars or reduction block vars
   4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
   5) The reduction block vars are not used to index the output buffers
   
   Stack trace:
     0: tvm::tir::ConcreteScheduleNode::Parallel(tvm::tir::LoopRV const&)
           at /home/tristan/octoml/tvm/src/tir/schedule/concrete_schedule.cc:505
     1: tvm::tir::TracedScheduleNode::Parallel(tvm::tir::LoopRV const&)
           at /home/tristan/octoml/tvm/src/tir/schedule/traced_schedule.cc:244
     2: tvm::tir::RewriteParallel(tvm::tir::Schedule const&, unsigned long, tvm::runtime::Array<tvm::tir::LoopRV, void>*)
           at /home/tristan/octoml/tvm/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc:323
     3: tvm::meta_schedule::RewriteParallelVectorizeUnrollNode::Apply(tvm::tir::Schedule const&)
           at /home/tristan/octoml/tvm/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc:369
     4: tvm::meta_schedule::ThreadedTraceApply::Apply(tvm::IRModule const&, tvm::tir::Trace const&, long*)
           at /home/tristan/octoml/tvm/src/meta_schedule/search_strategy/../utils.h:315
     5: tvm::meta_schedule::EvolutionarySearchNode::State::SampleInitPopulation(int)::{lambda(int, int)#1}::operator()(int, int) const
           at /home/tristan/octoml/tvm/src/meta_schedule/search_strategy/evolutionary_search.cc:498
     6: std::function<void (int, int)>::operator()(int, int) const
           at /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:590
     7: tvm::support::parallel_for_dynamic(int, int, int, std::function<void (int, int)> const&)::$_1::operator()(int) const
           at /home/tristan/octoml/tvm/src/support/parallel_for.cc:113
     8: tvm::support::parallel_for_dynamic(int, int, int, std::function<void (int, int)> const&)
           at /home/tristan/octoml/tvm/src/support/parallel_for.cc:123
     9: tvm::meta_schedule::EvolutionarySearchNode::State::SampleInitPopulation(int)
           at /home/tristan/octoml/tvm/src/meta_schedule/search_strategy/evolutionary_search.cc:502
     10: tvm::meta_schedule::EvolutionarySearchNode::State::GenerateMeasureCandidates()
           at /home/tristan/octoml/tvm/src/meta_schedule/search_strategy/evolutionary_search.cc:693
     11: tvm::meta_schedule::EvolutionarySearchNode::GenerateMeasureCandidates()
           at /home/tristan/octoml/tvm/src/meta_schedule/search_strategy/evolutionary_search.cc:426
     12: tvm::meta_schedule::TaskSchedulerNode::Tune()
           at /home/tristan/octoml/tvm/src/meta_schedule/task_scheduler/task_scheduler.cc:66
     13: tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}::operator()(tvm::meta_schedule::TaskScheduler) const
           at /home/tristan/octoml/tvm/include/tvm/runtime/registry.h:245
     14: void tvm::runtime::detail::unpack_call_dispatcher<void, 0, 1, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>::run<tvm::runtime::TVMMovableArgValueWithContext_>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > (*)(), tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMMovableArgValueWithContext_&&)
           at /home/tristan/octoml/tvm/include/tvm/runtime/packed_func.h:1659
     15: void tvm::runtime::detail::unpack_call_dispatcher<void, 1, 0, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>::run<>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > (*)(), tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)
           at /home/tristan/octoml/tvm/include/tvm/runtime/packed_func.h:1631
     16: void tvm::runtime::detail::unpack_call<void, 1, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)
           at /home/tristan/octoml/tvm/include/tvm/runtime/packed_func.h:1671
     17: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
           at /home/tristan/octoml/tvm/include/tvm/runtime/packed_func.h:1731
     18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
           at /home/tristan/octoml/tvm/include/tvm/runtime/packed_func.h:1213
     19: tvm::runtime::PackedFuncObj::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
           at /home/tristan/octoml/tvm/include/tvm/runtime/packed_func.h:1217
     20: TVMFuncCall
           at /home/tristan/octoml/tvm/src/runtime/c_runtime_api.cc:477
   ```
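   The two definitions quoted in the error message can be checked mechanically. Below is a minimal pure-Python sketch, not TVM code: `BlockFacts`, `first_complete_violation`, `first_reduction_violation`, and `has_compact_dataflow` are hypothetical names, and the block is reduced to a few facts read off the printed IR (the `update` block binds no block vars, has no `init` statement, and both reads and writes `C`). Assuming, as the message suggests, that each definition reports its first failing condition, the sketch reproduces "condition #3" for the local complete block and "condition #1" for the reduction block.

   ```python
   from dataclasses import dataclass
   from typing import List, Optional, Set


   # Hypothetical, simplified stand-in for a TIR block. This is NOT a TVM data
   # structure; it only records the facts the two definitions in the error check.
   @dataclass
   class BlockFacts:
       iter_types: List[str]           # per block var: "S" (data parallel), "R" (reduction)
       reads: Set[str]                 # names of buffers the block reads
       writes: Set[str]                # names of buffers the block writes
       has_init: bool                  # block has an `init` statement
       only_writer: bool = True        # "Local Dominant" condition
       quasi_affine_bindings: bool = True
       reduction_vars_index_output: bool = False


   def first_complete_violation(b: BlockFacts) -> Optional[int]:
       """First violated local-complete-block condition (1-based), or None."""
       checks = [
           all(t == "S" for t in b.iter_types),  # 1) all block vars data parallel
           b.only_writer,                         # 2) local dominant
           not (b.reads & b.writes),              # 3) no read/write buffer overlap
       ]
       return next((i + 1 for i, ok in enumerate(checks) if not ok), None)


   def first_reduction_violation(b: BlockFacts) -> Optional[int]:
       """First violated reduction-block condition (1-based), or None."""
       checks = [
           b.has_init,                                  # 1) has an `init` statement
           b.quasi_affine_bindings,                     # 2) quasi-affine bindings
           all(t in ("S", "R") for t in b.iter_types),  # 3) only S/R block vars
           b.only_writer,                               # 4) local dominant
           not b.reduction_vars_index_output,           # 5) R vars don't index output
       ]
       return next((i + 1 for i, ok in enumerate(checks) if not ok), None)


   def has_compact_dataflow(b: BlockFacts) -> bool:
       # A child block must be a local complete block OR a local reduction block.
       return first_complete_violation(b) is None or first_reduction_violation(b) is None


   # The "update" block as printed in the IR above.
   update = BlockFacts(iter_types=[], reads={"A", "B", "C"}, writes={"C"}, has_init=False)
   print(first_complete_violation(update))   # 3
   print(first_reduction_violation(update))  # 1
   print(has_compact_dataflow(update))       # False

   # Hypothetical variant with spatial, spatial, reduction block vars and an init.
   update_reduction = BlockFacts(
       iter_types=["S", "S", "R"], reads={"A", "B", "C"}, writes={"C"}, has_init=True
   )
   print(has_compact_dataflow(update_reduction))  # True
   ```

   Under these assumptions, a block that reads and writes the same buffer can only pass as a reduction block, which requires an `init` statement and explicit reduction block vars. The sketch models only the quoted definitions; it says nothing about whether `RewriteParallel` should attempt `parallel` on such a block in the first place.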
   
   I've bisected this problem to 779dc51e1332f417fa4c304b595ce76891dfc33a, which is the commit that introduced `tune_tir`, so this error appears to have been present since MetaSchedule was first added.
   
   @junrushao @zxybazh 

