shaoyuyoung opened a new issue, #16891:
URL: https://github.com/apache/tvm/issues/16891

   TVM seems to place strict restrictions on the `MatMul` operator: **it cannot handle operands of different ranks** (for example, a 1-D tensor multiplied with a 2-D or 3-D one).
   
   Look at this simple graph. In PyTorch, the model is **correctly defined**, and the input and output shapes are exactly as shown below.
   The evidence is in the `torch.matmul` documentation: [https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul](https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul)
   
![image](https://github.com/apache/tvm/assets/100203773/b4c812e8-5fb5-4b33-90cf-380a03e26b6c)
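For reference, ONNX specifies `MatMul` as a matrix product that behaves like `numpy.matmul`, which uses the same rank-promotion rules as `torch.matmul`. A 1-D first operand gets a leading 1, a 1-D second operand gets a trailing 1, leading batch dimensions broadcast, and the added dimensions are dropped from the result. The sketch below applies those rules to the shapes in this graph. `matmul_result_shape` is a hypothetical helper written for this illustration, not an API of TVM, ONNX, or PyTorch.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def matmul_result_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of matmul(a, b) under numpy/torch rules."""
    if len(a) == 0 or len(b) == 0:
        raise ValueError("matmul does not accept 0-D operands")
    a_is_vec, b_is_vec = len(a) == 1, len(b) == 1
    # A 1-D first operand is treated as a row vector: prepend a 1.
    a2 = (1, *a) if a_is_vec else tuple(a)
    # A 1-D second operand is treated as a column vector: append a 1.
    b2 = (*b, 1) if b_is_vec else tuple(b)
    # The contracted (K) dimensions must agree.
    if a2[-1] != b2[-2]:
        raise ValueError(f"K mismatch: {a2[-1]} vs {b2[-2]}")
    # Leading (batch) dimensions broadcast, aligned from the right.
    batch = []
    for da, db in zip_longest(reversed(a2[:-2]), reversed(b2[:-2]), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"batch dims {da} and {db} do not broadcast")
        batch.append(db if da == 1 else da)
    batch.reverse()
    # Keep M and N only if they were not added by the 1-D promotion above.
    out = list(batch)
    if not a_is_vec:
        out.append(a2[-2])
    if not b_is_vec:
        out.append(b2[-1])
    return tuple(out)


# Shapes taken from the reproduction script in this issue.
r1 = matmul_result_shape((5, 1), (1,))     # x1 @ x2
r2 = matmul_result_shape((5, 4, 5), (5,))  # y1 @ y2
print(r1, r2, matmul_result_shape(r1, r2))  # (5,) (5, 4) (4,)
```

Under these rules every `MatMul` in the graph is valid and the final output has shape `(4,)`. The same rules also accept the pair of shapes reported in the error message, `[1, 5, 5]` and `[5, 5, 4]`, giving `(5, 5, 4)`.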
   
   
   When I try to convert the model from `ONNX` to `TVM`, I get an error saying that the tensor shapes are inconsistent. However, when converting from `PyTorch` to `TVM`, everything works!
   
   I guess one possible reason is that `TorchScript` plays a role here, whereas `ONNX` does not.
    
   Moreover, look at the last line of the error message. I wonder why **T.int64(1)** is used here; TVM's handling of `int64` seems rather fragile.
   
   
![image](https://github.com/apache/tvm/assets/100203773/e62e8462-212b-4d0b-ac5f-826be20e4557)
   
   ### Expected behavior
   Compilation should succeed, since the same model produces results in both ONNX and PyTorch.
   
   ### Actual behavior
   **Compilation failure**
   ```
   Traceback (most recent call last):
     18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
     17: tvm::transform::Pass::operator()(tvm::IRModule) const
     16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     14: _ZN3tvm7runtime13PackedFun
     13: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
     12: tvm::relay::DynamicToStatic(tvm::relay::Function, tvm::IRModule)
     11: tvm::relay::DynamicToStaticMutator::PrepareInput(tvm::RelayExpr const&)
     10: tvm::transform::Pass::operator()(tvm::IRModule) const
     9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     8: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     7: tvm::transform::Pass::operator()(tvm::IRModule) const
     6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
     2: tvm::relay::TypeSolver::Solve()
     1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     0: bool tvm::relay::BatchMatmulRel<tvm::relay::BatchMatmulAttrs>(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
     File "/root/anaconda3/conda-bld/tvm-package_1701590675822/work/src/relay/op/nn/nn.h", line 212
   InternalError: Check failed: (reporter->AssertEQ(xk, yk)) is false: BatchDot: shapes of x and y is inconsistent,  x shape=[T.int64(1), 5, 5], y shape=[5, 5, 4]
   
   ```
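The last line of the log comes from the reduction-axis (K) check in `BatchMatmulRel`. Relay's `nn.batch_matmul` documentation states that it uses the NT layout (`transpose_a=False, transpose_b=True`) by default, so the second operand is read as `[batch, N, K]` rather than `[batch, K, N]`. The sketch below is a hypothetical Python restatement of which axes are compared; `batch_matmul_k` is illustrative only and is not a TVM API. Under the NT reading, the reported shapes give K=5 for x but K=4 for y, whereas reading y as `[batch, K, N]` makes both 5. Which layout the ONNX frontend actually requested is not visible in the log, so this only suggests that the reported shapes are consistent with a layout mismatch rather than with an invalid model.

```python
from typing import Sequence, Tuple


def batch_matmul_k(
    x_shape: Sequence[int],
    y_shape: Sequence[int],
    transpose_a: bool = False,
    transpose_b: bool = True,
) -> Tuple[int, int]:
    """Hypothetical restatement: return (xk, yk), the reduction axes being compared.

    x is [B, M, K], or [B, K, M] if transpose_a.
    y is [B, K, N], or [B, N, K] if transpose_b (the documented NT default).
    """
    if len(x_shape) != 3 or len(y_shape) != 3:
        raise ValueError("batch_matmul expects 3-D operands")
    xk = x_shape[1] if transpose_a else x_shape[2]
    yk = y_shape[2] if transpose_b else y_shape[1]
    return xk, yk


# Shapes copied from the error message, reading T.int64(1) as 1 (an assumption).
x_shape, y_shape = (1, 5, 5), (5, 5, 4)
print(batch_matmul_k(x_shape, y_shape))                     # (5, 4): NT layout, mismatch
print(batch_matmul_k(x_shape, y_shape, transpose_b=False))  # (5, 5): NN layout, consistent
```

The `T.int64(1)` entry sits on the batch axis, which is not one of the two axes compared here.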
   
   
   ### Environment
   
   - Operating System: Ubuntu 18
   - TVM: 0.15
   - Torch: 2.1.1
   - ONNX: 1.15.0
   
   ### Steps to reproduce
   Here is the script:
   ```python
   import torch
   import torch.nn as nn
   import tvm
   from tvm import relay
   import onnx
   
   class DirectMatMulModel(nn.Module):
       def __init__(self):
           super(DirectMatMulModel, self).__init__()
   
       def forward(self, x1, x2, y1, y2):
           result1 = torch.matmul(x1, x2)  # (5, 1) @ (1,) -> (5,)
           result2 = torch.matmul(y1, y2)  # (5, 4, 5) @ (5,) -> (5, 4)
           final_result = torch.matmul(result1, result2)  # (5,) @ (5, 4) -> (4,)
           return final_result
   
   
   torch_model = DirectMatMulModel().eval()
   
   x1 = torch.randn(5, 1)
   x2 = torch.randn(1)
   y1 = torch.randn(5, 4, 5)
   y2 = torch.randn(5)
   
   scripted_model = torch.jit.trace(torch_model, (x1, x2, y1, y2))
   
   torch.onnx.export(torch_model,
                         (x1, x2, y1, y2),
                         "direct_matmul_model.onnx",
                         export_params=True,
                         opset_version=12,
                         do_constant_folding=True,
                         input_names=['x1', 'x2', 'y1', 'y2'],
                         output_names=['output'])
   
   onnx_model = onnx.load("direct_matmul_model.onnx")
   onnx.checker.check_model(onnx_model)
   
   def compile_onnx():
       mod_from_onnx, params_onnx = relay.frontend.from_onnx(
           onnx_model, shape={'x1': [5, 1], 'x2': [1], 'y1': [5, 4, 5], 'y2': [5]}
       )
       with tvm.transform.PassContext(opt_level=4):
           executor = relay.build_module.create_executor(
               'graph', mod_from_onnx, tvm.cpu(), 'llvm', params_onnx
           ).evaluate()
   
   def compile_torch():
       mod_from_torch, params_torch = relay.frontend.from_pytorch(
           scripted_model,
           input_infos=[('x1', [5, 1]), ('x2', [1]), ('y1', [5, 4, 5]), ('y2', [5])],
       )
       with tvm.transform.PassContext(opt_level=4):
           executor = relay.build_module.create_executor(
               'graph', mod_from_torch, tvm.cpu(), 'llvm', params_torch
           ).evaluate()
   
   try:
       compile_torch()
   except Exception as e:
       print(f"torch fail\n {e}")
   
   try:
       compile_onnx()
   except Exception as e:
       print(f"onnx fail\n {e}")
   ```
   ### Triage
   
   
   * needs-triage
   

