This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5099068ffe [Relax][PyTorch] Add `count_include_pad` support to `avg_pool2d` in PyTorch frontend (#18487)
5099068ffe is described below

commit 5099068ffe20bc07cc20c839caea707963e5a491
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Nov 23 14:14:19 2025 +0900

    [Relax][PyTorch] Add `count_include_pad` support to `avg_pool2d` in PyTorch frontend (#18487)
    
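    Add support for the `count_include_pad` argument of `avg_pool2d` in the
    PyTorch frontend.
    Note that `count_include_pad` defaults to True in PyTorch but to False in
    Relax. The frontend therefore now forwards the PyTorch value (True unless
    the model sets it otherwise) to `R.nn.avg_pool2d` explicitly instead of
    silently falling back to the Relax default.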
    
    cc @tlopex
---
 .../frontend/torch/base_fx_graph_translator.py     |  5 ++-
 src/contrib/msc/framework/tvm/relax_opcode.cc      |  1 +
 .../relax/test_frontend_from_exported_program.py   | 40 +++++++++++++++++++---
 tests/python/relax/test_frontend_from_fx.py        |  3 ++
 4 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 5ca79344ba..4165086808 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -653,6 +653,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         stride: Optional[Union[int, Tuple[int, int]]] = None,
         padding: Optional[int] = 0,
         ceil_mode: Optional[bool] = False,
+        count_include_pad: Optional[bool] = True,
     ) -> relax.Var:
         # Expand to 4D by adding batch dim if input is 3D
         x_ndim = x.struct_info.ndim
@@ -667,6 +668,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
                 strides=stride,
                 padding=padding,
                 ceil_mode=ceil_mode,
+                count_include_pad=count_include_pad,
                 layout="NCHW",
             )
         )
@@ -682,7 +684,8 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         stride = args[2] if len(args) > 2 else kwargs.get("stride", None)
         padding = args[3] if len(args) > 3 else kwargs.get("padding", 0)
         ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", 
False)
-        return self._avg_pool2d_impl(x, kernel_size, stride, padding, 
ceil_mode)
+        count_include_pad = args[5] if len(args) > 5 else 
kwargs.get("count_include_pad", True)
+        return self._avg_pool2d_impl(x, kernel_size, stride, padding, 
ceil_mode, count_include_pad)
 
     def _avg_pool3d_impl(
         self,
diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc
index 54d55721ac..da2cdfba59 100644
--- a/src/contrib/msc/framework/tvm/relax_opcode.cc
+++ b/src/contrib/msc/framework/tvm/relax_opcode.cc
@@ -507,6 +507,7 @@ class RelaxPool2dCodeGen : public RelaxOpCode {
         .op_list_arg<int>("strides")
         .op_list_arg<int>("padding")
         .op_list_arg<int>("dilation")
+        .op_arg<bool>("count_include_pad")
         .op_arg<bool>("ceil_mode")
         .op_str_arg("layout")
         .op_str_arg("out_layout");
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index c4851973ea..a61da359d3 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1910,7 +1910,7 @@ def test_avg_pool1d():
                     dilation=[1, 1],
                     padding=[0, 0, 0, 0],
                     ceil_mode=False,
-                    count_include_pad=False,
+                    count_include_pad=True,
                     layout="NCHW",
                     out_layout="NCHW",
                 )
@@ -1948,7 +1948,7 @@ def test_avg_pool1d():
                     dilation=[1, 1],
                     padding=[0, 1, 0, 1],
                     ceil_mode=True,
-                    count_include_pad=False,
+                    count_include_pad=True,
                     layout="NCHW",
                     out_layout="NCHW",
                 )
@@ -1976,7 +1976,7 @@ def test_avg_pool1d():
                     dilation=[1, 1],
                     padding=[0, 0, 0, 0],
                     ceil_mode=False,
-                    count_include_pad=False,
+                    count_include_pad=True,
                     layout="NCHW",
                     out_layout="NCHW",
                 )
@@ -2015,6 +2015,7 @@ def test_avg_pool2d():
                     strides=[1, 1],
                     dilation=[1, 1],
                     padding=[0, 0, 0, 0],
+                    count_include_pad=True,
                     layout="NCHW",
                     out_layout="NCHW",
                 )
@@ -2048,6 +2049,7 @@ def test_avg_pool2d():
                     dilation=[1, 1],
                     padding=[2, 2, 2, 2],
                     ceil_mode=True,
+                    count_include_pad=True,
                     layout="NCHW",
                     out_layout="NCHW",
                 )
@@ -2060,7 +2062,7 @@ def test_avg_pool2d():
             return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1], 
divisor_override=2)
 
     @tvm.script.ir_module
-    class expected3:
+    class expected4:
         @R.function
         def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
             with R.dataflow():
@@ -2071,6 +2073,33 @@ def test_avg_pool2d():
                     dilation=[1, 1],
                     padding=[0, 0, 0, 0],
                     ceil_mode=False,
+                    count_include_pad=True,
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    class AvgPool2d5(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool2d(
+                input, kernel_size=[2, 1], divisor_override=2, 
count_include_pad=False
+            )
+
+    @tvm.script.ir_module
+    class expected5:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool2d(
+                    input_1,
+                    pool_size=[2, 1],
+                    strides=[2, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    ceil_mode=False,
+                    count_include_pad=False,
                     layout="NCHW",
                     out_layout="NCHW",
                 )
@@ -2082,7 +2111,8 @@ def test_avg_pool2d():
     verify_model(AvgPool2d1(), example_args, {}, expected1)
     verify_model(AvgPool2d2(), example_args, {}, expected2)
     verify_model(AvgPool2d3(), example_args, {}, expected2)
-    verify_model(AvgPool2d4(), example_args, {}, expected3)
+    verify_model(AvgPool2d4(), example_args, {}, expected4)
+    verify_model(AvgPool2d5(), example_args, {}, expected5)
 
 
 def test_avg_pool3d():
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index d377bb7574..031a855fb9 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1434,6 +1434,7 @@ def test_avgpool2d():
                     strides=[1, 1],
                     dilation=[1, 1],
                     padding=[0, 0, 0, 0],
+                    count_include_pad=True,
                     layout="NCHW",
                     out_layout="NCHW",
                 )
@@ -1467,6 +1468,7 @@ def test_avgpool2d():
                     dilation=[1, 1],
                     padding=[2, 2, 2, 2],
                     ceil_mode=True,
+                    count_include_pad=True,
                     layout="NCHW",
                     out_layout="NCHW",
                 )
@@ -1490,6 +1492,7 @@ def test_avgpool2d():
                     dilation=[1, 1],
                     padding=[0, 0, 0, 0],
                     ceil_mode=False,
+                    count_include_pad=True,
                     layout="NCHW",
                     out_layout="NCHW",
                 )

Reply via email to