j-paulus opened a new issue #6414:
URL: https://github.com/apache/incubator-tvm/issues/6414


   When a PyTorch model contains strided slicing of an N-D array, the stride argument is ignored and a non-strided slice is returned. Because the stride is ignored, both the size and the contents of the result are wrong.
   
   Minimal example:
   ```
   import torch
   
   x = torch.randn(1, 4)
   y = x[:, 0::2]  # PyTorch gives shape (1, 2); the TVM-converted model gives (1, 4)
   ```
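For reference, the expected length of a strided slice along one axis follows Python's slice semantics. The sketch below computes the shape that `x[:, 0::2]` should have; the helper `sliced_len` is hypothetical and is not part of TVM or PyTorch.

```python
def sliced_len(dim_size, start=None, stop=None, step=None):
    """Number of elements that x[start:stop:step] selects along an axis of size dim_size."""
    # slice.indices() normalizes None and negative bounds against dim_size
    return len(range(*slice(start, stop, step).indices(dim_size)))


# x has shape (1, 4); x[:, 0::2] keeps axis 0 whole and takes every second element of axis 1
shape = (1, 4)
expected = (sliced_len(shape[0]), sliced_len(shape[1], 0, None, 2))
print(expected)  # (1, 2)
```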
   
   A small script triggering the bug:
   ```
   import torch
   from tvm import relay
   
   class TriggerBug(torch.nn.Module):
       def __init__(self):
           super(TriggerBug, self).__init__()
       def forward(self, x):
           return x[..., 0::2] + x[..., 1::2]
   
   #x_in = torch.randn(4)  # this would work
   x_in = torch.randn(1, 4)  # this doesn't work
   
   torch_model = TriggerBug()
   traced_model = torch.jit.trace(torch_model, (x_in,))
   
   mod, params = relay.frontend.from_pytorch(traced_model, [('x_in', x_in.shape)])
   
   ```
   The output is:
   
   ```
    mod, params = relay.frontend.from_pytorch(traced_model, [('x_in', x_in.shape)])
     File "/Users/name/opt/anaconda3/envs/env/lib/python3.8/site-packages/tvm-0.7.dev1-py3.8-macosx-10.9-x86_64.egg/tvm/relay/frontend/pytorch.py", line 2788, in from_pytorch
       mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
     File "/Users/name/opt/anaconda3/envs/env/lib/python3.8/site-packages/tvm-0.7.dev1-py3.8-macosx-10.9-x86_64.egg/tvm/ir/module.py", line 74, in __setitem__
       return self._add(var, val)
     File "/Users/name/opt/anaconda3/envs/env/lib/python3.8/site-packages/tvm-0.7.dev1-py3.8-macosx-10.9-x86_64.egg/tvm/ir/module.py", line 83, in _add
       _ffi_api.Module_Add(self, var, val, update)
     File "/Users/name/opt/anaconda3/envs/env/lib/python3.8/site-packages/tvm-0.7.dev1-py3.8-macosx-10.9-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 225, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     [bt] (8) 9   libffi.7.dylib                      0x000000010ebd7ead ffi_call_unix64 + 85
     [bt] (7) 8   libtvm.dylib                        0x000000012801c3c8 TVMFuncCall + 72
     [bt] (6) 7   libtvm.dylib                        0x000000012754260c std::__1::__function::__func<tvm::$_3, std::__1::allocator<tvm::$_3>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 492
     [bt] (5) 6   libtvm.dylib                        0x0000000127536437 tvm::IRModuleNode::Add(tvm::GlobalVar const&, tvm::BaseFunc const&, bool) + 183
     [bt] (4) 5   libtvm.dylib                        0x0000000127535edf tvm::RunTypeCheck(tvm::IRModule const&, tvm::GlobalVar const&, tvm::relay::Function) + 1103
     [bt] (3) 4   libtvm.dylib                        0x0000000127e773e0 tvm::relay::InferType(tvm::relay::Function const&, tvm::IRModule const&, tvm::GlobalVar const&) + 544
     [bt] (2) 3   libtvm.dylib                        0x0000000127e76587 tvm::relay::TypeInferencer::Infer(tvm::RelayExpr) + 119
     [bt] (1) 2   libtvm.dylib                        0x000000012752714c tvm::ErrorReporter::RenderErrors(tvm::IRModule const&, bool) + 5308
     [bt] (0) 1   libtvm.dylib                        0x000000012735447f dmlc::LogMessageFatal::~LogMessageFatal() + 111
     File "/Users/puu/code/python/tvm_fix/tvm/src/ir/error.cc", line 132
   TVMError:
   Error(s) have occurred. The program has been annotated with them:
   
   In `main`:
   #[version = "0.0.5"]
   fn (%x_in: Tensor[(1, 4), float32]) {
     %0 = strided_slice(%x_in, meta[relay.Constant][0], meta[relay.Constant][1], meta[relay.Constant][2], begin=[0, 0], end=[1, 4], strides=[2]);
     %1 = strided_slice(%x_in, meta[relay.Constant][3], meta[relay.Constant][4], meta[relay.Constant][5], begin=[0, 1], end=[1, 4], strides=[2]);
     add(%0, %1) Incompatible broadcast type TensorType([1, 4], float32) and TensorType([1, 3], float32);
   }
   ```
   
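   In the annotated IR, `begin` and `end` each have two entries but `strides=[2]` has only one, which suggests that the requested stride is not applied to the sliced (last) axis. The two slices therefore come out with shapes (1, 4) and (1, 3) instead of (1, 2) each, and the `add` fails to type-check.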
   
   If the tensor to be sliced is 1-D, the result is correct, but even a 2-D input fails.
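The shapes in the annotated IR can be reproduced by assuming that each entry of `strides` applies to the axis at the same position, with missing trailing strides defaulting to 1. The helper `strided_slice_shape` below is hypothetical: it only mimics the shape arithmetic under that assumption and is not TVM's implementation. It also shows why a 1-D input works: the single stride then lands on the only axis.

```python
def strided_slice_shape(shape, begin, end, strides):
    """Output shape of a slice where strides[i] applies to axis i."""
    # Assumption: axes without an explicit stride use stride 1.
    padded = list(strides) + [1] * (len(shape) - len(strides))
    out = []
    for dim, b, e, s in zip(shape, begin, end, padded):
        out.append(len(range(*slice(b, e, s).indices(dim))))
    return tuple(out)


# As in the error message: stride 2 lands on axis 0 (size 1), not on axis 1.
print(strided_slice_shape((1, 4), [0, 0], [1, 4], [2]))     # (1, 4)
print(strided_slice_shape((1, 4), [0, 1], [1, 4], [2]))     # (1, 3)
# What x[..., 0::2] and x[..., 1::2] require: stride 2 on the last axis.
print(strided_slice_shape((1, 4), [0, 0], [1, 4], [1, 2]))  # (1, 2)
print(strided_slice_shape((1, 4), [0, 1], [1, 4], [1, 2]))  # (1, 2)
# 1-D input: the single stride applies to the only axis, so the shape is right.
print(strided_slice_shape((4,), [1], [4], [2]))             # (2,)
```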
   
   #6316 seems related, but unfortunately it does not fix this issue.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.