shaoyuyoung opened a new issue, #16891: URL: https://github.com/apache/tvm/issues/16891
TVM seems to have strict restrictions on the `MatMul` operator, meaning **it cannot multiply tensors with different shapes**. Look at this simple graph. In PyTorch the model is **correctly defined**, and the input and output shapes are exactly as shown below; the broadcasting rules are documented here: [https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul](https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul)

When I try to convert the `ONNX` model to `TVM`, I get an error indicating that the tensor shapes are inconsistent. However, when converting from `PyTorch` to `TVM`, everything is OK! I guess one possible reason is that `TorchScript` plays a role here while `ONNX` does not. Moreover, look at the last line of the error message: I wonder why **T.int64(1)** appears there. TVM's handling of `int64` shape values seems fairly fragile.

### Expected behavior

Compilation should pass, since both the PyTorch model and the exported ONNX model produce results.

### Actual behavior

**Compilation failure**

```
Traceback (most recent call last):
18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
17: tvm::transform::Pass::operator()(tvm::IRModule) const
16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
14: _ZN3tvm7runtime13PackedFun
13: tvm::runtime::TypedPackedFunc<tvm::relay::Function
(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
12: tvm::relay::DynamicToStatic(tvm::relay::Function, tvm::IRModule)
11: tvm::relay::DynamicToStaticMutator::PrepareInput(tvm::RelayExpr const&)
10: tvm::transform::Pass::operator()(tvm::IRModule) const
9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
8: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
7: tvm::transform::Pass::operator()(tvm::IRModule) const
6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
2: tvm::relay::TypeSolver::Solve()
1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter
const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
0: bool tvm::relay::BatchMatmulRel<tvm::relay::BatchMatmulAttrs>(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
File "/root/anaconda3/conda-bld/tvm-package_1701590675822/work/src/relay/op/nn/nn.h", line 212
InternalError: Check failed: (reporter->AssertEQ(xk, yk)) is false: BatchDot: shapes of x and y is inconsistent, x shape=[T.int64(1), 5, 5], y shape=[5, 5, 4]
```

### Environment

* Operating System: Ubuntu 18
* TVM: 0.15
* Torch: 2.1.1
* ONNX: 1.15.0

### Steps to reproduce

Here is the script:

```python
import torch
import torch.nn as nn
import tvm
from tvm import relay
import onnx


class DirectMatMulModel(nn.Module):
    def __init__(self):
        super(DirectMatMulModel, self).__init__()

    def forward(self, x1, x2, y1, y2):
        result1 = torch.matmul(x1, x2)
        result2 = torch.matmul(y1, y2)
        final_result = torch.matmul(result1, result2)
        return final_result


torch_model = DirectMatMulModel().eval()
x1 = torch.randn(5, 1)
x2 = torch.randn(1)
y1 = torch.randn(5, 4, 5)
y2 = torch.randn(5)

scripted_model = torch.jit.trace(torch_model, (x1, x2, y1, y2))
torch.onnx.export(torch_model, (x1, x2, y1, y2), "direct_matmul_model.onnx",
                  export_params=True, opset_version=12, do_constant_folding=True,
                  input_names=['x1', 'x2', 'y1', 'y2'], output_names=['output'])

onnx_model = onnx.load("direct_matmul_model.onnx")
onnx.checker.check_model(onnx_model)


def compile_onnx():
    mod_from_onnx, params_onnx = relay.frontend.from_onnx(
        onnx_model,
        shape={'x1': [5, 1], 'x2': [1], 'y1': [5, 4, 5], 'y2': [5]})
    with tvm.transform.PassContext(opt_level=4):
        executor = relay.build_module.create_executor(
            'graph', mod_from_onnx, tvm.cpu(), 'llvm', params_onnx
        ).evaluate()


def compile_torch():
    mod_from_torch, params_torch = relay.frontend.from_pytorch(
        scripted_model,
        input_infos=[('x1', [5, 1]), ('x2', [1]), ('y1', [5, 4, 5]), ('y2', [5])])
    with tvm.transform.PassContext(opt_level=4):
        executor = relay.build_module.create_executor(
            'graph', mod_from_torch, tvm.cpu(), 'llvm', params_torch
        ).evaluate()


try:
    compile_torch()
except Exception as e:
    print(f"torch fail\n {e}")

try:
    compile_onnx()
except Exception as e:
    print(f"onnx fail\n {e}")
```

### Triage

* needs-triage

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at: [email protected]
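As a side note on why these shapes are valid: under the `matmul` broadcasting rules that both PyTorch and NumPy implement, 1-D operands are promoted to matrices (a leading or trailing 1 is added and later removed) and batch dimensions broadcast. The model's intermediate and output shapes can therefore be checked outside TVM. A minimal sketch using NumPy, whose `matmul` follows the same promotion rules as `torch.matmul` for these cases:

```python
import numpy as np

# Same input shapes as in the repro script.
x1 = np.random.randn(5, 1)
x2 = np.random.randn(1)        # 1-D: (5,1) @ (1,) contracts the last axis -> (5,)
y1 = np.random.randn(5, 4, 5)
y2 = np.random.randn(5)        # 1-D: broadcast over the batch, (5,4,5) @ (5,) -> (5,4)

result1 = np.matmul(x1, x2)               # shape (5,)
result2 = np.matmul(y1, y2)               # shape (5, 4)
final_result = np.matmul(result1, result2)  # 1-D @ 2-D: (5,) @ (5,4) -> (4,)

print(result1.shape, result2.shape, final_result.shape)
```

So every step of the graph is well-typed under the documented broadcasting rules; the `BatchDot` shape check failing with ranks `[T.int64(1), 5, 5]` vs `[5, 5, 4]` suggests the ONNX frontend's lowering of the last `MatMul` to `batch_matmul` does not fully apply this 1-D promotion.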
