This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new bd51b5b592 [Unity] [Bugfix] Fix KeyError: 'padding' in _avg_pool2d 
implementation (#15734)
bd51b5b592 is described below

commit bd51b5b592c7dbba4eec21086464b847971a7c86
Author: Thrsu <[email protected]>
AuthorDate: Fri Sep 15 01:34:16 2023 +0800

    [Unity] [Bugfix] Fix KeyError: 'padding' in _avg_pool2d implementation 
(#15734)
    
    This PR fixes a bug in the `_avg_pool2d` converter of the Relax PyTorch FX
    frontend (fx_translator.py). The bug causes a KeyError when running the code
    snippet below. The error message is:
    ```
    Traceback (most recent call last):
        ...
        mod = from_fx(fx_model, input_info)
      File 
"/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", 
line 1449, in from_fx
        return TorchFXImporter().from_fx(
      File 
"/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", 
line 1336, in from_fx
        self.env[node] = self.convert_map[func_name](node)
      File 
"/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", 
line 763, in _avg_pool2d
        stride = node.args[2] if nargs > 2 else node.kwargs["stride"]
    KeyError: 'stride'
    ```
    The following code reproduces the bug:
    ```python
    import torch
    from torch import fx
    from torch.nn import Module
    import tvm
    from tvm import relax
    from tvm.relax.frontend.torch import from_fx
    
    input_data = torch.randn([1, 1, 3, 3], dtype=torch.float32)
    para_1 = (2, 1)
    class avg_pool2d(Module):
        def forward(self, input):
            return torch.nn.functional.avg_pool2d(input, 
para_1,divisor_override=2,)
    model = avg_pool2d().float()
    input_data = [input_data]
    input_info = list(zip([list(inp.shape) for inp in input_data], 
[str(inp.dtype) for inp in input_data]))
    fx_model : torch.fx.GraphModule = fx.symbolic_trace(model)
    with torch.no_grad():
        mod = from_fx(fx_model, input_info)
    ```
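    The issue arises because `_avg_pool2d` reads the optional "stride",
    "padding", and "ceil_mode" arguments from `node.kwargs` without first checking
    that they are present. A call that omits any of them therefore raises a
    KeyError. In the snippet above, "stride" is the first missing key; when
    stride is supplied, the same failure happens on "padding". To resolve this,
    the code now checks for each key before accessing it. When a key is absent,
    it falls back to PyTorch's defaults: stride defaults to the kernel size,
    padding to 0, and ceil_mode to False.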
    
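    I have tested these changes by running the code snippet above, and the
    KeyError is no longer raised. I also added a regression test case
    (`AvgPool2d4`, which passes only `kernel_size` and `divisor_override`) to
    tests/python/relax/test_frontend_from_fx.py. With this fix, the converter
    handles calls where any of the "stride", "padding", or "ceil_mode" arguments
    are omitted.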
---
 python/tvm/relax/frontend/torch/fx_translator.py | 21 ++++++++++++++++++---
 tests/python/relax/test_frontend_from_fx.py      | 24 ++++++++++++++++++++++++
 2 files changed, 42 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index b5cee77d11..a5c2a68cd8 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -859,9 +859,24 @@ class TorchFXImporter:
         else:
             nargs = len(node.args)
             kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"]
-            stride = node.args[2] if nargs > 2 else node.kwargs["stride"]
-            padding = node.args[3] if nargs > 3 else node.kwargs["padding"]
-            ceil_mode = node.args[4] if nargs > 4 else node.kwargs["ceil_mode"]
+            if nargs > 2:
+                stride = node.args[2]
+            elif "stride" in node.kwargs.keys():
+                stride = node.kwargs["stride"]
+            else:
+                stride = None
+            if nargs > 3:
+                padding = node.args[3]
+            elif "padding" in node.kwargs.keys():
+                padding = node.kwargs["padding"]
+            else:
+                padding = 0
+            if nargs > 4:
+                ceil_mode = node.args[4]
+            elif "ceil_mode" in node.kwargs.keys():
+                ceil_mode = node.kwargs["ceil_mode"]
+            else:
+                ceil_mode = False
 
         stride = kernel if stride is None else stride
 
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index ec312767b4..36ef25b025 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -823,9 +823,33 @@ def test_avgpool2d():
                 R.output(gv)
             return gv
 
+    class AvgPool2d4(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1], 
divisor_override=2)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool2d(
+                    input_1,
+                    pool_size=[2, 1],
+                    strides=[2, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    ceil_mode=False,
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                gv = lv
+                R.output(gv)
+            return gv
+
     verify_model(AvgPool2d(), input_info, {}, expected1)
     verify_model(AvgPool2d2(), input_info, {}, expected2)
     verify_model(AvgPool2d3(), input_info, {}, expected2)
+    verify_model(AvgPool2d4(), input_info, {}, expected3)
 
 
 def test_adaptive_avgpool2d():

Reply via email to