j-paulus opened a new issue #6441:
URL: https://github.com/apache/incubator-tvm/issues/6441


   If a strided slice is used in a model, the stride argument is ignored by the ONNX frontend and the result is wrong.
   
   I encountered the problem when trying to compile an ONNX model created by PyTorch conversion. A similar problem was present in the PyTorch frontend (#6414) and was fixed by #6418.
   
   Possibly related issue #6316.
   
   Code to reproduce the problem:
   
   ```python
   import torch
   import tvm
   from tvm import relay
   import onnx
   
   class TriggerBug(torch.nn.Module):
       def __init__(self):
           super(TriggerBug, self).__init__()
   
       def forward(self, x):
           return x[..., 0::2] + x[..., 1::2]
   
   x_in = torch.randn(1, 4)
   torch_model = TriggerBug()
   onnx_name = 'strided_slice.onnx'
   example_output = torch_model(x_in)
   # convert to ONNX
   torch.onnx.export(torch_model, (x_in,), onnx_name,
                     verbose=True,
                     example_outputs=example_output,
                     input_names=['x'],
                     output_names=['y'],
                     opset_version=10,
                     enable_onnx_checker=True)
   
   onnx_model = onnx.load(onnx_name)
   target = 'llvm'
   shape_dict = {'x': x_in.shape}
   mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
   
   with tvm.transform.PassContext(opt_level=1):
       intrp = relay.build_module.create_executor('graph', mod, tvm.cpu(0), target)
   
   ```
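   For reference, the expected semantics of the `forward` pass can be checked with plain NumPy, independent of TVM (a minimal sketch, not using the model above):

   ```python
   import numpy as np

   # Same shape as the repro input: (1, 4)
   x = np.arange(4, dtype=np.float32).reshape(1, 4)  # [[0, 1, 2, 3]]

   even = x[..., 0::2]  # indices 0 and 2 on the last axis -> shape (1, 2)
   odd = x[..., 1::2]   # indices 1 and 3 on the last axis -> shape (1, 2)
   y = even + odd       # [[0+1, 2+3]] = [[1, 5]]

   print(y.shape)  # (1, 2)
   ```

   Both slices have shape (1, 2), so the add broadcasts trivially; this is the result TVM should reproduce.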
   
   The traceback:
   
   >   mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
   >   File 
"/Users/name/opt/anaconda3/envs/tvm/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.9-x86_64.egg/tvm/relay/frontend/onnx.py",
 line 2456, in from_onnx
   >     mod, params = g.from_onnx(graph, opset)
   >   File 
"/Users/name/opt/anaconda3/envs/tvm/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.9-x86_64.egg/tvm/relay/frontend/onnx.py",
 line 2302, in from_onnx
   >     return IRModule.from_expr(func), self._params
   >   File 
"/Users/name/opt/anaconda3/envs/tvm/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.9-x86_64.egg/tvm/ir/module.py",
 line 236, in from_expr
   >     return _ffi_api.Module_FromExpr(expr, funcs, defs)
   >   File 
"/Users/name/opt/anaconda3/envs/tvm/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.9-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py",
 line 225, in __call__
   >     raise get_last_ffi_error()
   > tvm._ffi.base.TVMError: Traceback (most recent call last):
   >   [bt] (8) 9   libtvm.dylib                        0x0000000122684df8 
TVMFuncCall + 72
   >   [bt] (7) 8   libtvm.dylib                        0x0000000121b8e452 
std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<tvm::IRModule 
(tvm::RelayExpr, tvm::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>, 
tvm::Map<tvm::GlobalTypeVar, tvm::TypeData, void, 
void>)>::AssignTypedLambda<tvm::$_9>(tvm::$_9)::'lambda'(tvm::runtime::TVMArgs 
const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void 
tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::RelayExpr, 
tvm::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>, 
tvm::Map<tvm::GlobalTypeVar, tvm::TypeData, void, 
void>)>::AssignTypedLambda<tvm::$_9>(tvm::$_9)::'lambda'(tvm::runtime::TVMArgs 
const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, 
tvm::runtime::TVMRetValue*&&) + 610
   >   [bt] (6) 7   libtvm.dylib                        0x0000000121b7f810 
tvm::IRModule::FromExpr(tvm::RelayExpr const&, tvm::Map<tvm::GlobalVar, 
tvm::BaseFunc, void, void> const&, tvm::Map<tvm::GlobalTypeVar, tvm::TypeData, 
void, void> const&) + 1040
   >   [bt] (5) 6   libtvm.dylib                        0x0000000121b7ca47 
tvm::IRModuleNode::Add(tvm::GlobalVar const&, tvm::BaseFunc const&, bool) + 183
   >   [bt] (4) 5   libtvm.dylib                        0x0000000121b7c4ef 
tvm::RunTypeCheck(tvm::IRModule const&, tvm::GlobalVar const&, 
tvm::relay::Function) + 1103
   >   [bt] (3) 4   libtvm.dylib                        0x00000001224dca20 
tvm::relay::InferType(tvm::relay::Function const&, tvm::IRModule const&, 
tvm::GlobalVar const&) + 544
   >   [bt] (2) 3   libtvm.dylib                        0x00000001224dbbc7 
tvm::relay::TypeInferencer::Infer(tvm::RelayExpr) + 119
   >   [bt] (1) 2   libtvm.dylib                        0x0000000121b6d87c 
tvm::ErrorReporter::RenderErrors(tvm::IRModule const&, bool) + 5308
   >   [bt] (0) 1   libtvm.dylib                        0x00000001219917bf 
dmlc::LogMessageFatal::~LogMessageFatal() + 111
   >   File "/Users/name/code/python/tvm/src/ir/error.cc", line 132
   > TVMError: 
   > Error(s) have occurred. The program has been annotated with them:
   > 
   > In `main`: 
   > #[version = "0.0.5"]
   > fn (%x: Tensor[(1, 4), float32]) {
   >   %0 = strided_slice(%x, begin=[0, 0], end=[2147483647, 
9223372036854775807], strides=[1]);
   >   %1 = strided_slice(%x, begin=[0, 1], end=[2147483647, 
9223372036854775807], strides=[1]);
   >   add(%0, %1) Incompatible broadcast type TensorType([1, 4], float32) and 
TensorType([1, 3], float32); 
   > }
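   Note that `strides=[1]` above has lost the step of 2 on the last axis, which is exactly what makes type inference fail: the two slices come out as (1, 4) and (1, 3) instead of (1, 2) and (1, 2). A NumPy sketch of the difference (using the same begin/end values Relay printed):

   ```python
   import numpy as np

   x = np.zeros((1, 4), dtype=np.float32)

   # What the Relay graph actually computes: the stride of 2 was dropped.
   bad0 = x[0:2147483647:1, 0:9223372036854775807:1]  # shape (1, 4)
   bad1 = x[0:2147483647:1, 1:9223372036854775807:1]  # shape (1, 3)
   # (1, 4) and (1, 3) cannot be broadcast together -> the reported type error.

   # What the ONNX graph asked for: step 2 on axis 1.
   good0 = x[:, 0::2]  # shape (1, 2)
   good1 = x[:, 1::2]  # shape (1, 2)
   ```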
   
   The intermediate ONNX graph is:
   > graph(%x : Float(1:4, 4:1, requires_grad=0, device=cpu)):
   >   %1 : Tensor = onnx::Constant[value={1}]()
   >   %2 : Tensor = onnx::Constant[value={0}]()
   >   %3 : Tensor = onnx::Constant[value={9223372036854775807}]()
   >   %4 : Tensor = onnx::Constant[value={2}]()
   >   %5 : Float(1:4, 2:2, requires_grad=0, device=cpu) = onnx::Slice(%x, %2, 
%3, %1, %4)
   >   %6 : Tensor = onnx::Constant[value={1}]()
   >   %7 : Tensor = onnx::Constant[value={1}]()
   >   %8 : Tensor = onnx::Constant[value={9223372036854775807}]()
   >   %9 : Tensor = onnx::Constant[value={2}]()
   >   %10 : Float(1:4, 2:2, requires_grad=0, device=cpu) = onnx::Slice(%x, %7, 
%8, %6, %9)
   >   %y : Float(1:2, 2:1, requires_grad=0, device=cpu) = onnx::Add(%5, %10)
   >   return (%y)
   
   Here the stride length is correctly present.
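   In opset 10, `Slice` takes its parameters as inputs in the order (data, starts, ends, axes, steps), so the first slice above, `onnx::Slice(%x, %2, %3, %1, %4)`, decodes as starts={0}, ends={9223372036854775807}, axes={1}, steps={2}. A minimal NumPy emulation of that decoding (a sketch of the Slice-10 semantics, not TVM code):

   ```python
   import numpy as np

   def onnx_slice(data, starts, ends, axes, steps):
       # Emulate ONNX Slice-10: apply (start, end, step) per listed axis,
       # leaving all other axes untouched.
       slices = [slice(None)] * data.ndim
       for start, end, axis, step in zip(starts, ends, axes, steps):
           slices[axis] = slice(start, end, step)
       return data[tuple(slices)]

   x = np.random.randn(1, 4).astype(np.float32)
   # First Slice in the graph: starts=%2, ends=%3, axes=%1, steps=%4
   out = onnx_slice(x, [0], [9223372036854775807], [1], [2])
   print(out.shape)  # (1, 2)
   ```

   The `steps` input (here `%4 = {2}`) is what the frontend appears to drop when building `strided_slice`.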
   
   Versions:
   - pytorch: 1.7.0.dev20200908
   - TVM: 0.7.dev1 git revision 84fa626
   - onnx: 1.7.0
   
   In case you are wondering why I am going this route via ONNX instead of using the PyTorch frontend directly: the compilation of my real model from PyTorch does not currently work, but I have verified that the converted ONNX version works. I was hoping that the ONNX frontend could then compile the full model.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]

