jikechao opened a new issue, #15282:
URL: https://github.com/apache/tvm/issues/15282

   Calling `torch.nn.functional.instance_norm(args[0], use_input_stats=True)` on a `float64` input makes `relay.frontend.from_pytorch` crash during type inference with:

   **Error: tensor type `Tensor[(1), float64]` has 1 dimensions, while `float64` has 0 dimensions**
   
   ### Actual behavior
   ```
   Traceback (most recent call last):
     File "test.py", line 20, in <module>
       mod, params = relay.frontend.from_pytorch(trace, input_shapes)
     File "/workplace/software/tvm/tvm_/python/tvm/relay/frontend/pytorch.py", line 5002, in from_pytorch
       outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
     File "/workplace/software/tvm/tvm_/python/tvm/relay/frontend/pytorch.py", line 4263, in convert_operators
       self.record_output_type(relay_out)
     File "/workplace/software/tvm/tvm_/python/tvm/relay/frontend/pytorch.py", line 238, in record_output_type
       self.infer_type_with_prelude(output)
     File "/workplace/software/tvm/tvm_/python/tvm/relay/frontend/pytorch.py", line 174, in infer_type_with_prelude
       body = self.infer_type(val, self.prelude.mod)
     File "/workplace/software/tvm/tvm_/python/tvm/relay/frontend/pytorch.py", line 167, in infer_type
       new_mod = transform.InferType()(new_mod)
     File "/workplace/software/tvm/tvm_/python/tvm/ir/transform.py", line 160, in __call__
       return _ffi_transform_api.RunPass(self, mod)
     File "/workplace/software/tvm/tvm_/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     10: TVMFuncCall
     9: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::$_6>(tvm::transform::$_6, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     8: tvm::transform::Pass::operator()(tvm::IRModule) const
     7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     6: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     4: tvm::DiagnosticContext::Render()
     3: tvm::DiagnosticRenderer::Render(tvm::DiagnosticContext const&)
     2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::DiagnosticContext)>::AssignTypedLambda<tvm::TerminalRenderer(std::ostream&)::$_10>(tvm::TerminalRenderer(std::ostream&)::$_10)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     1: tvm::ReportAt(tvm::DiagnosticContext const&, std::ostream&, tvm::Span const&, tvm::Diagnostic const&)
     0: _ZN3tvm7runtime6detail
     File "/workplace/software/tvm/tvm_/src/ir/diagnostic.cc", line 264
   TVMError: The source maps are not populated for this module. Please use `tvm.relay.transform.AnnotateSpans` to attach source maps for error reporting.
   Error: tensor type `Tensor[(1), float64]` has 1 dimensions, while `float64` has 0 dimensions
   ```
   
   
   ### Steps to reproduce
   
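   ```python
   import torch
   from tvm import relay
   import tvm
   import numpy as np
   from torch.nn import Module
   
   # Input of shape [N, C, H, W] = [1, 1, 1, 2] in float64.
   input_data = torch.randn([1, 1, 1, 2], dtype=torch.float64)
   
   class instance_norm(Module):
       def forward(self, *args):
           # No running stats, weight or bias: normalize with the input's own statistics.
           return torch.nn.functional.instance_norm(args[0], use_input_stats=True)
   
   # The module has no parameters or buffers, so .float() has no effect; the input stays float64.
   m = instance_norm().float().eval()
   torch_outputs = m(input_data)
   
   trace = torch.jit.trace(m, input_data)
   input_shapes = [('input0', torch.Size([1, 1, 1, 2]))]
   
   # Fails inside Relay's InferType pass with the error shown above.
   mod, params = relay.frontend.from_pytorch(trace, input_shapes)
   ```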
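For comparison, the values a successful conversion would be expected to match can be worked out by hand. Below is a minimal pure-Python sketch, assuming the defaults of `torch.nn.functional.instance_norm` when only `use_input_stats=True` is given (no running stats, weight or bias; biased variance; `eps=1e-5`). `instance_norm_reference` is a hypothetical helper written for illustration, not a TVM or PyTorch API. Python floats are IEEE double precision, which matches the `float64` case in this report.

```python
import math


def instance_norm_reference(x, eps=1e-5):
    """Hypothetical reference for instance_norm(x, use_input_stats=True).

    x is a nested list of shape [N][C][H][W]; each (n, c) slice is
    normalized with its own mean and biased variance over H and W.
    """
    out = []
    for sample in x:  # iterate over N
        out_sample = []
        for channel in sample:  # iterate over C
            values = [v for row in channel for v in row]  # flatten H x W
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)  # biased
            denom = math.sqrt(var + eps)
            out_sample.append([[(v - mean) / denom for v in row] for row in channel])
        out.append(out_sample)
    return out


# Same [1, 1, 1, 2] shape as the repro, with fixed values instead of randn.
y = instance_norm_reference([[[[0.5, -1.5]]]])
print(y)  # two values of equal magnitude and opposite sign, just under 1
```

With two spatial values per channel, the normalized outputs are ±d/sqrt(d² + eps), where d is half their difference, so they sit near ±1 unless the two inputs are almost equal.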
   
   ### Triage
   
   
   * needs-triage
   * frontend:pytorch
   

