Thrsu opened a new pull request, #15933:
URL: https://github.com/apache/tvm/pull/15933
This PR fixes a bug in how TVM's PyTorch (FX) frontend handles the `interpolate`
operator. When reading the interpolation mode, the frontend looked up the
`method` keyword instead of `mode` (falling back to the default value when the
keyword was not found), so a user-specified `mode` such as `'bilinear'` was
ignored. As a result, `interpolate` produced incorrect results.
The bug can be reproduced with the following script:
```python
import torch
from torch import fx
from torch.nn import Module
import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.torch import from_fx
# Random 4-D float32 sample (batch, channels, height, width) used as the
# input for both the PyTorch reference run and the TVM build.
input_data = torch.randn(1, 2, 4, 4, dtype=torch.float32)
class interpolate(Module):
    """Minimal module that upsamples its input 2x with bilinear interpolation.

    Used to reproduce the frontend bug: ``mode`` is supplied as a keyword
    argument, so a frontend that reads it under a different keyword ignores it.
    """

    def forward(self, input):
        # The parameter keeps the name ``input`` so the placeholder name in the
        # FX-traced graph is unchanged. ``mode`` is passed by keyword on purpose:
        # that keyword lookup is what the PR fixes.
        return torch.nn.functional.interpolate(
            input,
            size=None,
            scale_factor=2.0,
            mode="bilinear",
            align_corners=False,
        )
# Trace the module with torch.fx, import it into Relax, and compare the TVM
# result against PyTorch's own output for the same input.
model = interpolate().float()
input_data = [input_data]
input_names = [f"input{idx}" for idx, _ in enumerate(input_data)]
input_info = list(
    zip(
        [list(inp.shape) for inp in input_data],
        [str(inp.dtype) for inp in input_data],
    )
)
fx_model: torch.fx.GraphModule = fx.symbolic_trace(model)
with torch.no_grad():
    mod = from_fx(fx_model, input_info)

# Feed the PyTorch reference tensors that live on the same device as the model.
# Previously the model was moved to CUDA but still called with the CPU tensors,
# which fails with a device mismatch on machines that have a GPU.
torch_input = input_data
if torch.cuda.is_available():
    model = model.cuda()
    torch_input = [inp.cuda() for inp in input_data]
with torch.no_grad():
    torch_outputs = model(*[inp.clone() for inp in torch_input])
torch_outputs = (torch_outputs.cpu().numpy(),)

# The TVM side always runs on CPU (llvm target), so it uses host copies.
compiled_input = dict(
    zip(input_names, [inp.clone().cpu().numpy() for inp in input_data])
)
tvm_input = {}
for name, inp in compiled_input.items():
    tvm_input[name] = tvm.nd.array(inp)
target = tvm.target.Target("llvm", host="llvm")
mod = relax.transform.LegalizeOps()(mod)
ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
tvm_outputs = vm["main"](*[tvm_input[name] for name in input_names])
tvm_outputs = [tvm_outputs]
for torch_output, tvm_output in zip(torch_outputs, tvm_outputs):
    tvm.testing.assert_allclose(
        torch_output, tvm_output.numpy(), rtol=1e-5, atol=1e-5
    )
```
Running it fails with the following traceback:
```
Traceback (most recent call last):
  ...
    tvm.testing.assert_allclose(torch_output, output, rtol=1e-5, atol=1e-5)
  File "/workplace/software/tvm/tvm/python/tvm/testing/utils.py", line 120, in assert_allclose
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)
  File "/workplace/software/miniconda3/envs/tflite/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 1527, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/workplace/software/miniconda3/envs/tflite/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 844, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-05, atol=1e-05
Mismatched elements: 120 / 128 (93.8%)
Max absolute difference: 1.2100129
Max relative difference: 39.08069
x: array([[[[-0.407674, -0.087851, 0.551796, 0.303579, -0.832501,
-0.630166, 0.910585, 1.680961],
[-0.152558, -0.054653, 0.141159, -0.0371 , -0.589427,...
y: array([[[[-0.407674, -0.407674, 0.871619, 0.871619, -1.400542,
-1.400542, 1.680961, 1.680961],
[-0.407674, -0.407674, 0.871619, 0.871619, -1.400542,...
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]