fengzi0 opened a new issue, #15155:
URL: https://github.com/apache/tvm/issues/15155
When I load a PyTorch FX quantized model into TVM with the code below:
```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping, quantize_fx
import tvm
from tvm import relay

qconfig_mapping = get_default_qconfig_mapping()

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 256)

    def forward(self, input):
        x = self.linear(input)
        return x

mm = MyModule()
input = torch.randn((1, 900, 256))
# example_inputs must be a tuple
mm_prepared = quantize_fx.prepare_fx(mm, qconfig_mapping, (input,))
r = mm_prepared(input)  # run once to calibrate the observers
mm_quantize = quantize_fx.convert_fx(mm_prepared)
script_mm = torch.jit.trace(mm_quantize, (input,))
input_shapes_mm = [("input", tuple(input.shape))]
mod, params = relay.frontend.from_pytorch(script_mm, input_shapes_mm)
```
this gives:
```
The Relay type checker is unable to show the following types match:
  Tensor[(900), int32]
  Tensor[(256), int32]
In particular:
  dimension 0 conflicts: 900 does not match 256.
The Relay type checker is unable to show the following types match.
In particular `Tensor[(256), int32]` does not match `Tensor[(900), int32]`
The Relay type checker is unable to show the following types match:
  Tensor[(900), float32]
  Tensor[(256), float32]
In particular:
  dimension 0 conflicts: 900 does not match 256.
The Relay type checker is unable to show the following types match.
In particular `Tensor[(256), float32]` does not match `Tensor[(900), float32]`
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 4649, in from_pytorch
    outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 4025, in convert_operators
    self.record_output_type(relay_out)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 220, in record_output_type
    self.infer_type_with_prelude(output)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 168, in infer_type_with_prelude
    body = self.infer_type(val, self.prelude.mod)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 161, in infer_type
    new_mod = transform.InferType()(new_mod)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 262, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 251, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
tvm.error.DiagnosticError: Traceback (most recent call last):
  6: TVMFuncCall
  5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::string)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  4: tvm::transform::Pass::operator()(tvm::IRModule) const
  3: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::DiagnosticContext::Render()
  File "/workspace/tvm/src/ir/diagnostic.cc", line 105
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.
```
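For what it's worth, the conflict in the log looks like a NumPy-style broadcast mismatch: a per-channel tensor of shape `(256,)` (likely the bias or quantization parameters) is being checked against an axis of size 900 from the `(1, 900, 256)` input. A small standalone sketch of trailing-dimension broadcast checking (my own illustration, not TVM code) shows why these shapes cannot unify:

```python
def broadcastable(a, b):
    # NumPy-style rule: compare trailing dimensions; each pair must be
    # equal or contain a 1 for the two shapes to broadcast together.
    for x, y in zip(reversed(a), reversed(b)):
        if x != y and 1 not in (x, y):
            return False
    return True

print(broadcastable((900,), (256,)))        # False: 900 conflicts with 256
print(broadcastable((1, 900, 256), (256,))) # True: (256,) broadcasts over the last axis
```

So the converter seems to be applying the `(256,)` tensor against dimension 0 of a `(900, ...)` intermediate instead of against the channel axis.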
version info:
```
>>> torch.__version__
'2.0.0+cu117'
>>> tvm.__version__
'0.11.1'
```
Is there something obvious that I should fix? Thanks!
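For reference, one workaround I am considering (a hypothetical sketch on my side, not verified to address the root cause in the TVM frontend) is to flatten the 3D input to 2D before the `Linear` so the traced graph only ever sees a `(N, 256)` activation, then restore the original shape afterwards:

```python
import torch

class MyModule2D(torch.nn.Module):
    """Same Linear(256, 256), but fed a 2D view of the 3D input."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 256)

    def forward(self, input):
        b, s, c = input.shape
        # Collapse batch and sequence dims so the matmul path is 2D.
        x = self.linear(input.reshape(b * s, c))
        return x.reshape(b, s, -1)

mm = MyModule2D()
out = mm(torch.randn(1, 900, 256))
print(tuple(out.shape))  # (1, 900, 256)
```

This wrapper could then go through the same `prepare_fx` / `convert_fx` / `torch.jit.trace` flow as above; I have not confirmed whether the quantized 2D path avoids the type-checker error.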