This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bd51b5b592 [Unity] [Bugfix] Fix KeyError: 'padding' in _avg_pool2d
implementation (#15734)
bd51b5b592 is described below
commit bd51b5b592c7dbba4eec21086464b847971a7c86
Author: Thrsu <[email protected]>
AuthorDate: Fri Sep 15 01:34:16 2023 +0800
[Unity] [Bugfix] Fix KeyError: 'padding' in _avg_pool2d implementation
(#15734)
This PR fixes a bug in the `_avg_pool2d` converter of the Relax PyTorch FX importer. The bug causes a
KeyError when converting the model in the code snippet below. The error message is:
```
Traceback (most recent call last):
  ...
    mod = from_fx(fx_model, input_info)
  File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1449, in from_fx
    return TorchFXImporter().from_fx(
  File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1336, in from_fx
    self.env[node] = self.convert_map[func_name](node)
  File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 763, in _avg_pool2d
    stride = node.args[2] if nargs > 2 else node.kwargs["stride"]
KeyError: 'stride'
```
The following code reproduces the bug:
```python
import torch
from torch import fx
from torch.nn import Module
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx
# Minimal reproducer: avg_pool2d is called with only the kernel size passed
# positionally, so `stride`, `padding` and `ceil_mode` appear neither in the
# traced node's positional args nor in its kwargs.
input_tensor = torch.randn([1, 1, 3, 3], dtype=torch.float32)
kernel_size = (2, 1)


class AvgPool2d(Module):
    def forward(self, x):
        # stride / padding / ceil_mode are intentionally omitted so the
        # importer has to fall back to PyTorch's defaults for them.
        return torch.nn.functional.avg_pool2d(x, kernel_size, divisor_override=2)


model = AvgPool2d().float()
input_data = [input_tensor]
# One (shape, dtype-string) pair per model input, as expected by from_fx.
input_info = [(list(inp.shape), str(inp.dtype)) for inp in input_data]
fx_model: torch.fx.GraphModule = fx.symbolic_trace(model)
with torch.no_grad():
    mod = from_fx(fx_model, input_info)
```
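The converter assumed that any of `stride`, `padding` and `ceil_mode` not passed positionally would be present in `node.kwargs`. The snippet above omits all three and relies on PyTorch's defaults. As a result, the dictionary lookup raises a KeyError, which surfaces first for "stride"; the same failure would occur for "padding" and "ceil_mode".

To resolve this, the code now checks whether each key exists in `node.kwargs`. If it does not, the code falls back to PyTorch's defaults:
- `stride=None`, which is then replaced by the kernel size
- `padding=0`
- `ceil_mode=False`

I have tested these changes by running the provided code snippet, and the KeyError is no longer thrown. I have also added a regression test (`AvgPool2d4`) to `tests/python/relax/test_frontend_from_fx.py`. The converter now handles calls that omit any of these optional arguments.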
---
python/tvm/relax/frontend/torch/fx_translator.py | 21 ++++++++++++++++++---
tests/python/relax/test_frontend_from_fx.py | 24 ++++++++++++++++++++++++
2 files changed, 42 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index b5cee77d11..a5c2a68cd8 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -859,9 +859,24 @@ class TorchFXImporter:
else:
nargs = len(node.args)
kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"]
- stride = node.args[2] if nargs > 2 else node.kwargs["stride"]
- padding = node.args[3] if nargs > 3 else node.kwargs["padding"]
- ceil_mode = node.args[4] if nargs > 4 else node.kwargs["ceil_mode"]
+ if nargs > 2:
+ stride = node.args[2]
+ elif "stride" in node.kwargs.keys():
+ stride = node.kwargs["stride"]
+ else:
+ stride = None
+ if nargs > 3:
+ padding = node.args[3]
+ elif "padding" in node.kwargs.keys():
+ padding = node.kwargs["padding"]
+ else:
+ padding = 0
+ if nargs > 4:
+ ceil_mode = node.args[4]
+ elif "ceil_mode" in node.kwargs.keys():
+ ceil_mode = node.kwargs["ceil_mode"]
+ else:
+ ceil_mode = False
stride = kernel if stride is None else stride
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index ec312767b4..36ef25b025 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -823,9 +823,33 @@ def test_avgpool2d():
R.output(gv)
return gv
+ class AvgPool2d4(Module):
+ def forward(self, input):
+ return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1],
divisor_override=2)
+
+ @tvm.script.ir_module
+ class expected3:
+ @R.function
+ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv = R.nn.avg_pool2d(
+ input_1,
+ pool_size=[2, 1],
+ strides=[2, 1],
+ dilation=[1, 1],
+ padding=[0, 0, 0, 0],
+ ceil_mode=False,
+ layout="NCHW",
+ out_layout="NCHW",
+ )
+ gv = lv
+ R.output(gv)
+ return gv
+
verify_model(AvgPool2d(), input_info, {}, expected1)
verify_model(AvgPool2d2(), input_info, {}, expected2)
verify_model(AvgPool2d3(), input_info, {}, expected2)
+ verify_model(AvgPool2d4(), input_info, {}, expected3)
def test_adaptive_avgpool2d():