mshr-h opened a new issue, #16126:
URL: https://github.com/apache/tvm/issues/16126

   It seems that `aten::new_zeros` support in the PyTorch frontend is broken (related: #14747).
   
   ### Steps to reproduce
   
   ```python
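   from tvm import relay
   import torch
   
   # Disable autograd so tracing records an inference-only graph.
   torch.set_grad_enabled(False)
   
   
   class NewZeros1(torch.nn.Module):
     def forward(self, x):
       # No dtype argument: the result should inherit x's dtype (float32).
       return x.new_zeros((2, 3))
   
   
   # An empty float32 tensor is enough to trace the module.
   input_data = torch.tensor((), dtype=torch.float)
   module = NewZeros1().float().eval()
   scripted_module = torch.jit.trace(module, input_data)
   input_infos = [('x', (input_data.shape, 'float32'))]
   # Fails inside relay.frontend.from_pytorch (traceback below).
   mod, params = relay.frontend.from_pytorch(scripted_module, input_infos)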
   ```
   
   ### Actual behavior
   
   ```
   Traceback (most recent call last):
     File "/home/ubuntu/workspace/sandbox/tvm_/frontend/new_zeros.py", line 17, in <module>
       mod, params = relay.frontend.from_pytorch(scripted_module, input_infos)
     File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 5183, in from_pytorch
       outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
     File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 4402, in convert_operators
       relay_out = relay_op(
     File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 851, in new_zeros
       return self.full_impl(data, 0, dtype)
     File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 755, in full_impl
       out = _op.full(fill_value, size, dtype=dtype)
     File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/op/transform.py", line 535, in full
       return _make.full(fill_value, shape, dtype)
     File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
       raise_last_ffi_error()
     File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
       raise py_err
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     2: _ZN3tvm7runtime13PackedFuncObj
     1: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
     0: tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::runtime::DataType<tvm::runtime::DataType>() const
     3: _ZN3tvm7runtime13PackedFuncObj
     2: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
     1: tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::runtime::DataType<tvm::runtime::DataType>() const
     0: tvm::runtime::TVMArgValue::operator DLDataType() const
     File "/home/ubuntu/workspace/sandbox/.dep/tvm/include/tvm/runtime/packed_func.h", line 777
   TVMError: In function relay.op._make.full(0: RelayExpr, 1: Array<IntImm>, 2: DataType) -> RelayExpr: error while converting argument 2: [19:47:16] /home/ubuntu/workspace/sandbox/.dep/tvm/include/tvm/runtime/packed_func.h:2210: InternalError: Check failed: type_code_ == kTVMDataType (8 vs. 5) : expected DLDataType but got Object
   ```
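The last line of the traceback shows that argument 2 of `relay.op._make.full` (the `dtype`) arrived as a generic object: the check reports type code 8 where `kTVMDataType` (5) was expected. Because the repro never passes `dtype=`, the converter has to fall back to the input tensor's dtype, so one plausible reading (an assumption, not confirmed here) is that the fallback forwards a whole inferred-type object instead of its dtype string. The pure-Python sketch below only models that failure mode; `TensorTypeStub`, `make_full`, and `resolve_dtype` are hypothetical stand-ins, not TVM APIs.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class TensorTypeStub:
    """Hypothetical stand-in for an inferred tensor type (shape + dtype)."""
    shape: Tuple[int, ...]
    dtype: str


def make_full(fill_value, shape, dtype):
    """Hypothetical model of the FFI boundary: dtype must be a plain string."""
    if not isinstance(dtype, str):
        raise TypeError(f"expected DLDataType but got {type(dtype).__name__}")
    return {"op": "full", "fill_value": fill_value, "shape": tuple(shape), "dtype": dtype}


def resolve_dtype(requested: Optional[str], input_type: TensorTypeStub) -> str:
    """new_zeros semantics: use the explicit dtype if given, else the input's dtype."""
    if requested is not None:
        return requested
    # Hand back the dtype *string*, not the whole type object.
    return input_type.dtype


# Mirrors the repro: an empty float32 input tensor.
x_type = TensorTypeStub(shape=(0,), dtype="float32")

# Assumed buggy pattern: forwarding the type object itself fails at the boundary.
try:
    make_full(0, (2, 3), x_type)
except TypeError as err:
    print(err)  # expected DLDataType but got TensorTypeStub

# Sketched fix: extract the dtype string before building the op.
print(make_full(0, (2, 3), resolve_dtype(None, x_type)))
```

If the real converter does behave this way, the same idea (pass only the dtype of the inferred type) would apply there, but that should be checked against the actual `new_zeros` converter in `pytorch.py`.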
   
   ### Environment
   
   * TVM: c8ef902a752e71dc92fb7fdfd7c3fcc4178a3455
   * LLVM=ON
   * CUDA=OFF
   
   ### Triage
   
   Please refer to the list of label tags [here](https://github.com/apache/tvm/wiki/Issue-Triage-Labels) to find the relevant tags and add them below in a bullet format (example below).
   
   * frontend:pytorch
   

