fPecc opened a new issue, #11558:
URL: https://github.com/apache/tvm/issues/11558

   See [this post](https://discuss.tvm.apache.org/t/check-failed-it-rmap-end-when-using-relay-build-in-model-with-tflite-detection-postprocess-with-minimal-example/12883) for an explanation of the problem.
   
   ### Expected behavior
   
   The build succeeds.
   
   ### Actual behavior
   
   Error message:
   
   ```
   Traceback (most recent call last):
     File "test.py", line 32, in <module>
       mod = relay.build(mod, executor=EXECUTOR, target=TARGET,runtime=RUNTIME)
     File "/local_disk/local_sw/tvm_main/tvm/python/tvm/relay/build_module.py", 
line 416, in build
       graph_json, runtime_mod, params = bld_mod.build(
     File "/local_disk/local_sw/tvm_main/tvm/python/tvm/relay/build_module.py", 
line 154, in build
       self._build(mod, raw_targets, executor, runtime, workspace_memory_pools, 
mod_name)
     File 
"/local_disk/local_sw/tvm_main/tvm/python/tvm/_ffi/_ctypes/packed_func.py", 
line 237, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     40: TVMFuncCall
     39: 
tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char,
 std::char_traits<char>, std::allocator<char> > const&, 
tvm::runtime::ObjectPtr<tvm::runtime::Object> 
const&)::{lambda(tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*) const
     38: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, 
tvm::runtime::String const&)
     37: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::AOTExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char,
 std::char_traits<char>, std::allocator<char> > const&, 
tvm::runtime::ObjectPtr<tvm::runtime::Object> 
const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> 
>::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)
     36: tvm::relay::backend::AOTExecutorCodegen::Codegen(tvm::IRModule, 
tvm::relay::Function, tvm::runtime::String)
     35: tvm::transform::Pass::operator()(tvm::IRModule) const
     34: tvm::transform::Pass::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
     33: tvm::transform::SequentialNode::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
     32: tvm::transform::Pass::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
     31: tvm::transform::ModulePassNode::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
     30: 
_ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay3tec11LowerTEPassENS0_6StringESt8functionIFvNS_8BaseFuncEEENS_17CompilationConfigEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SL_SP_
     29: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String 
const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
     28: tvm::transform::Pass::operator()(tvm::IRModule) const
     27: tvm::transform::Pass::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
     26: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
     25: 
_ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_5relay8FunctionES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_3tec15LowerTensorExprERKNS0_6StringENSD_10TECompilerESt8functionIFvNS_8BaseFuncEEENS_17CompilationConfigEEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SP_ST_
     24: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     23: 
_ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
     22: 
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode
 const*)
     21: 
tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode
 const*)
     20: _ZN3tvm5relay9transform22DeviceAwareExprMutator21DeviceAwareVisit
     19: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
     18: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     17: 
_ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
     16: 
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::LetNode 
const*)
     15: 
tvm::relay::tec::LowerTensorExprMutator::PreVisitLetBinding_(tvm::relay::Var 
const&, tvm::RelayExpr const&)
     14: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     13: 
_ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
     12: 
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode 
const*)
     11: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     10: 
_ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
     9: 
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode 
const*)
     8: 
tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode
 const*)
     7: 
tvm::relay::tec::LowerTensorExprMutator::MakeLoweredCall(tvm::relay::Function, 
tvm::runtime::Array<tvm::RelayExpr, void>, tvm::Span, tvm::Target)
     6: tvm::relay::tec::TECompilerImpl::Lower(tvm::relay::tec::CCacheKey 
const&, tvm::runtime::String)
     5: 
tvm::relay::tec::TECompilerImpl::LowerInternal(tvm::relay::tec::CCacheKey 
const&, std::function<tvm::runtime::String (tvm::runtime::String)>)
     4: tvm::LowerSchedule(tvm::te::Schedule, 
tvm::runtime::Array<tvm::te::Tensor, void> const&, 
std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > 
const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, 
std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, 
std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, 
bool)
     3: tvm::LowerSchedule(tvm::te::Schedule, 
tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, 
std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > 
const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, 
std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, 
std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, 
bool)
     2: tvm::ScheduleToModule(tvm::te::Schedule, 
tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, 
std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > 
const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, 
std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, 
std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&)
     1: tvm::te::InferBound(tvm::te::Schedule const&)
     0: tvm::te::InferRootBound(tvm::te::Stage const&, tvm::te::GraphContext 
const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, 
std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, 
std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > >*)
     File "/local_disk/local_sw/tvm_main/tvm/src/te/schedule/bound.cc", line 144
   TVMError: 
   ---------------------------------------------------------------
   An error occurred during the execution of TVM.
   For more information, please see: https://tvm.apache.org/docs/errors.html
   ---------------------------------------------------------------
     Check failed: (it != rmap->end()) is false:
   ```
   
   ### Environment
   
   Commit 017d410bd18fd3e272ea49ea9e11955c3128bb72 (June 2).
   
   ### Steps to reproduce
   
   The following script reproduces the issue:
   
   ```
   from tvm.relay.op.vision import non_max_suppression
   from tvm import te
   import numpy as np
   from tvm import topi
   import tvm
   from tvm import relay
   
   x0 = relay.var("x0", relay.ty.TensorType((1, relay.Any(), 6), "float32"))
   x1 = relay.var("x1", relay.ty.TensorType((1,), "int32"))
   x2 = relay.var("x2", relay.ty.TensorType((1, relay.Any()), "int32"))
   x3 = relay.var("x3", relay.ty.TensorType((), "int32"))
   z = relay.vision.non_max_suppression(
       x0,
       x1,
       x2,
       x3,
       iou_threshold=0.5,
       force_suppress=True,
       top_k=2,
       return_indices=True,
       invalid_to_bottom=False,
   )
   z = z.astuple()
   func = relay.Function([x0, x1, x2, x3], z)
   mod = tvm.IRModule()
   mod["main"] = func
   
   print(mod["main"])
   RUNTIME = tvm.relay.backend.Runtime("crt", {"system-lib": True})
   TARGET = tvm.target.target.Target({"kind": "c"})
   EXECUTOR = tvm.relay.backend.Executor("aot",options={'interface-api': 
'packed','unpacked-api': 0, 'link-params': True})
   mod = relay.build(mod, executor=EXECUTOR, target=TARGET,runtime=RUNTIME)
   
   print(mod)
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to