ganler opened a new issue, #13029:
URL: https://github.com/apache/tvm/issues/13029

   ### Expected behavior
   
   TVM should successfully compile a model whose operators are supported.
   
   ### Actual behavior
   
   The compilation could fail when the model contains the 
[recently](https://github.com/apache/tvm/pull/12124) supported `trilu` operator.
   
   In the `Steps to reproduce` section, the minimal reproducer is derived 
from an ONNX model exported by PyTorch, which uses `int64` values as shape 
arguments. These get mixed with `int32` constants in TVM's frontend translator, 
causing the compilation to fail due to an int32-int64 mismatch in `check_op`:
   
   
https://github.com/apache/tvm/blob/bdcfa01eae3ffe8c6d39aa26d0d1e5b311d47efb/python/tvm/topi/transform.py#L1057
   
   A quick fix could just be aligning integer types of `row_index` and 
`col_index - k` before doing `check_op`.
   
   ### Environment
   
   `fa17da22c73fb9e95c27e4c28130835b628caf6b` on Ubuntu 20.04.
   
   ### Steps to reproduce
   
   Minimized reproducer:
   
   ```python
   import tvm
   from tvm import relay
   
   x1 = relay.var("x1", shape=[2, 1], dtype="float32")
   x2 = relay.var("x2", shape=(1, 1, 1, 1), dtype="float32")
   x3 = relay.var("x3", shape=(), dtype="int64")
   v0 = relay.broadcast_to(x1, shape=relay.const([2, 1], dtype="int64"))
   v2 = relay.divide(x2, v0)
   v3 = relay.trilu(v0, x3)
   
   f = relay.Function([x1, x2, x3], relay.Tuple([v2, v3]))
   relay.create_executor("graph", device=tvm.cpu(), target="llvm").evaluate(f)
   ```
   
   <details><summary>Log. Click to expand!</summary> 
   
   ```python
   """
   Traceback (most recent call last):
     File "test.py", line 12, in <module>
       relay.create_executor("graph", device=tvm.cpu(), 
target="llvm").evaluate(f)
    ...
     25: 
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode
 const*)
     24: 
tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode
 const*)
     23: _ZN3tvm5relay9
     22: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
     21: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     20: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr 
const&)>::VisitExpr(tvm::RelayExpr const&)
     19: tvm::NodeFunctor<tvm::RelayExpr (tvm::runtime::ObjectRef const&, 
tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr 
const&)>*)>::operator()(tvm::runtime::ObjectRef const&, 
tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const
     18: _ZZN3tvm5relay11ExprFunc
     17: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::TupleNode const*)
     16: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     15: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr 
const&)>::VisitExpr(tvm::RelayExpr const&)
     14: tvm::NodeFunctor<tvm::RelayExpr (tvm::runtime::ObjectRef const&, 
tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr 
const&)>*)>::operator()(tvm::runtime::ObjectRef const&, 
tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const
     13: _ZZN3tvm5relay11ExprFunc
     12: 
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode 
const*)
     11: 
tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode
 const*)
     10: tvm::relay::tec::TECompilerImpl::Lower(tvm::relay::tec::CCacheKey 
const&)
     9: 
tvm::relay::tec::TECompilerImpl::LowerInternal(tvm::relay::tec::CCacheKey 
const&, tvm::GlobalVarSupply)
     8: tvm::relay::tec::PrimFuncFor(tvm::relay::Function const&, tvm::Target 
const&, tvm::GlobalVarSupply)
     7: tvm::relay::tec::ScheduleBuilder::Create(tvm::relay::Function const&, 
tvm::GlobalVarSupply)
     6: tvm::relay::tec::LowerToTECompute::Lower(tvm::relay::Function const&)
     5: 
tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor,
 void> >::VisitExpr(tvm::RelayExpr const&)
     4: tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> 
(tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
     3: tvm::NodeFunctor<tvm::runtime::Array<tvm::te::Tensor, void> 
(tvm::runtime::ObjectRef const&, 
tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> 
(tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, 
tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> 
(tvm::RelayExpr const&)>*) const
     2: _ZZN3tvm5relay11ExprFunc
     1: tvm::relay::tec::LowerToTECompute::VisitExpr_(tvm::relay::CallNode 
const*)
     0: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::$_2>
 >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)
     File 
"/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", 
line 81, in cfun
       rv = local_pyfunc(*pyargs)
     File 
"/home/jiawei/dev/tvm-official-release/python/tvm/relay/backend/te_compiler.py",
 line 317, in lower_call
       best_impl, outputs = select_implementation(op, call.attrs, inputs, 
ret_type, target)
     File 
"/home/jiawei/dev/tvm-official-release/python/tvm/relay/backend/te_compiler.py",
 line 207, in select_implementation
       outs = impl.compute(attrs, inputs, out_type)
     File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/op/op.py", 
line 126, in compute
       return _OpImplementationCompute(self, attrs, inputs, out_type)
     File 
"/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", 
line 237, in __call__
       raise get_last_ffi_error()
     3: TVMFuncCall
     2: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::$_3>
 >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)
     1: tvm::relay::OpImplementation::Compute(tvm::Attrs const&, 
tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)
     0: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::$_2>
 >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)
     File 
"/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", 
line 81, in cfun
       rv = local_pyfunc(*pyargs)
     File 
"/home/jiawei/dev/tvm-official-release/python/tvm/relay/op/strategy/generic.py",
 line 1489, in _compute_trilu
       topi_compute(
     File "/home/jiawei/dev/tvm-official-release/python/tvm/topi/transform.py", 
line 1061, in trilu
       return te.compute(data.shape, _apply_trilu, name="trilu")
     File "/home/jiawei/dev/tvm-official-release/python/tvm/te/operation.py", 
line 132, in compute
       body = fcompute(*[v.var for v in dim_var])
     File "/home/jiawei/dev/tvm-official-release/python/tvm/topi/transform.py", 
line 1057, in _apply_trilu
       check_position = check_op(row_index, col_index - k)
     File "/home/jiawei/dev/tvm-official-release/python/tvm/tir/expr.py", line 
881, in __init__
       self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span)  # type: 
ignore
     File 
"/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/object.py", line 
145, in __init_handle_by_constructor__
       handle = __init_by_constructor__(fconstructor, args)
     File 
"/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", 
line 260, in __init_handle_by_constructor__
       raise get_last_ffi_error()
     2: TVMFuncCall
     1: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::tir::LE
 (tvm::PrimExpr, tvm::PrimExpr, 
tvm::Span)>::AssignTypedLambda<tvm::tir::$_51>(tvm::tir::$_51, 
std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> 
>)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> 
>::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)
     0: tvm::tir::LE::LE(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
     File "/home/jiawei/dev/tvm-official-release/src/tir/ir/expr.cc", line 459
   TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched 
types. int32 vs. int64
   """
   ```
   
   </details>
   
   ### Triage
   
   Please refer to the list of label tags linked above to find the relevant 
tags and add them here in a bullet format (example below).
   
   * needs-triage
   
   cc: @jwfromm 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to