mkroening opened a new issue #9885:
URL: https://github.com/apache/tvm/issues/9885
### Expected behavior
No issue when running
```console
tvmc compile --target llvm --output out.tar model.onnx
```
when the ONNX model contains a `ConvTranspose` node with `group != 1` (the same
model runs fine in ONNX Runtime).
Model:
```onnx
ir_version: 8
graph {
node {
input: "A"
input: "B"
output: "C"
op_type: "ConvTranspose"
attribute {
name: "group"
i: 2
type: INT
}
}
name: "test-model"
input {
name: "A"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
output {
name: "C"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
}
opset_import {
version: 15
}
```
### Actual behavior
```
# [..]
The Relay type checker is unable to show the following types match.
In particular dimension 1 conflicts: 0 does not match 1.
The Relay type checker is unable to show the following types match.
In particular `Tensor[(2, 1, 1, 1), float32]` does not match `Tensor[(2, 0,
1, 1), float32]`
Traceback (most recent call last):
File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/mkroening/Development/tvm/python/tvm/driver/tvmc/__main__.py",
line 24, in <module>
tvmc.main.main()
File "/home/mkroening/Development/tvm/python/tvm/driver/tvmc/main.py",
line 94, in main
sys.exit(_main(sys.argv[1:]))
File "/home/mkroening/Development/tvm/python/tvm/driver/tvmc/main.py",
line 87, in _main
return args.func(args)
File "/home/mkroening/Development/tvm/python/tvm/driver/tvmc/compiler.py",
line 141, in drive_compile
compile_model(
File "/home/mkroening/Development/tvm/python/tvm/driver/tvmc/compiler.py",
line 271, in compile_model
graph_module = relay.build(mod, target=tvm_target, params=params)
File "/home/mkroening/Development/tvm/python/tvm/relay/build_module.py",
line 369, in build
executor_config, runtime_mod, params = bld_mod.build(
File "/home/mkroening/Development/tvm/python/tvm/relay/build_module.py",
line 177, in build
self._build(mod, target, target_host, executor, mod_name)
File "tvm/_ffi/_cython/./packed_func.pxi", line 323, in
tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 267, in
tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./base.pxi", line 163, in tvm._ffi._cy3.core.CALL
tvm.error.DiagnosticError: Traceback (most recent call last):
12: TVMFuncCall
11: std::_Function_handler<void (tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*),
tvm::relay::backend::RelayBuildModule::GetFunction(std::string const&,
tvm::runtime::ObjectPtr<tvm::runtime::Object>
const&)::{lambda(tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*)#3}>::_M_invoke(std::_Any_data const&,
tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
10: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule,
std::unordered_map<std::string, tvm::runtime::NDArray, std::hash<std::string>,
std::equal_to<std::string>, std::allocator<std::pair<std::string const,
tvm::runtime::NDArray> > > const&, tvm::runtime::String)
9: tvm::relay::backend::RelayBuildModule::OptimizeImpl(tvm::IRModule,
std::unordered_map<std::string, tvm::runtime::NDArray, std::hash<std::string>,
std::equal_to<std::string>, std::allocator<std::pair<std::string const,
tvm::runtime::NDArray> > > const&)
8: tvm::transform::Pass::operator()(tvm::IRModule) const
7: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
6: tvm::transform::SequentialNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
5: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
4: tvm::transform::SequentialNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
3: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
2: tvm::transform::ModulePassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
1: std::_Function_handler<void (tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule
(tvm::IRModule,
tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule,
tvm::transform::PassContext
const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule,
tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&,
tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
0: tvm::DiagnosticContext::Render()
File "/home/mkroening/Development/tvm/src/ir/diagnostic.cc", line 105
DiagnosticError: one or more error diagnostics were emitted, please check
diagnostic render for output.
```
### Environment
Operating System: Ubuntu 20.04.3 LTS
TVM version: 4babd36481b7108bf50df5c3b256c95c0d9c3291
### Steps to reproduce
Create `model.onnx` with:
```python
import onnx
from onnx import helper
from onnx import TensorProto
n = 1
c_in = 2
h_in = 1
w_in = 1
c_out = 2
h_out = 1
w_out = 1
groups = 2
kernel_size = [1, 1]
A = helper.make_tensor_value_info(
'A', TensorProto.FLOAT, [n, c_in, h_in, w_in])
B = helper.make_tensor_value_info(
'B', TensorProto.FLOAT, [c_out, int(c_in / groups), kernel_size[0],
kernel_size[1]])
C = helper.make_tensor_value_info(
'C', TensorProto.FLOAT, [n, c_out, h_out, w_out])
node_def = helper.make_node(
'ConvTranspose', # name
['A', 'B'], # inputs
['C'], # outputs
group=groups,
)
graph_def = helper.make_graph(
[node_def], # nodes
'test-model', # name
[A, B], # inputs
[C], # outputs
)
model_def = helper.make_model(graph_def)
print('The model is:\n{}'.format(model_def))
onnx.checker.check_model(model_def)
print('The model is checked!')
onnx.save(model_def, 'model.onnx')
```
Thanks a lot for your help! :)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]