This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5099068ffe [Relax][PyTorch] Add `count_include_pad` support to
`avg_pool2d` in PyTorch frontend (#18487)
5099068ffe is described below
commit 5099068ffe20bc07cc20c839caea707963e5a491
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Nov 23 14:14:19 2025 +0900
[Relax][PyTorch] Add `count_include_pad` support to `avg_pool2d` in PyTorch
frontend (#18487)
As per the title.
Note that `count_include_pad` defaults to True in PyTorch but to False in
Relax, so the frontend now forwards PyTorch's default (True) explicitly.
cc @tlopex
---
.../frontend/torch/base_fx_graph_translator.py | 5 ++-
src/contrib/msc/framework/tvm/relax_opcode.cc | 1 +
.../relax/test_frontend_from_exported_program.py | 40 +++++++++++++++++++---
tests/python/relax/test_frontend_from_fx.py | 3 ++
4 files changed, 43 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 5ca79344ba..4165086808 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -653,6 +653,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[int] = 0,
ceil_mode: Optional[bool] = False,
+ count_include_pad: Optional[bool] = True,
) -> relax.Var:
# Expand to 4D by adding batch dim if input is 3D
x_ndim = x.struct_info.ndim
@@ -667,6 +668,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
strides=stride,
padding=padding,
ceil_mode=ceil_mode,
+ count_include_pad=count_include_pad,
layout="NCHW",
)
)
@@ -682,7 +684,8 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
stride = args[2] if len(args) > 2 else kwargs.get("stride", None)
padding = args[3] if len(args) > 3 else kwargs.get("padding", 0)
ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode",
False)
- return self._avg_pool2d_impl(x, kernel_size, stride, padding,
ceil_mode)
+ count_include_pad = args[5] if len(args) > 5 else
kwargs.get("count_include_pad", True)
+ return self._avg_pool2d_impl(x, kernel_size, stride, padding,
ceil_mode, count_include_pad)
def _avg_pool3d_impl(
self,
diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc
b/src/contrib/msc/framework/tvm/relax_opcode.cc
index 54d55721ac..da2cdfba59 100644
--- a/src/contrib/msc/framework/tvm/relax_opcode.cc
+++ b/src/contrib/msc/framework/tvm/relax_opcode.cc
@@ -507,6 +507,7 @@ class RelaxPool2dCodeGen : public RelaxOpCode {
.op_list_arg<int>("strides")
.op_list_arg<int>("padding")
.op_list_arg<int>("dilation")
+ .op_arg<bool>("count_include_pad")
.op_arg<bool>("ceil_mode")
.op_str_arg("layout")
.op_str_arg("out_layout");
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index c4851973ea..a61da359d3 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1910,7 +1910,7 @@ def test_avg_pool1d():
dilation=[1, 1],
padding=[0, 0, 0, 0],
ceil_mode=False,
- count_include_pad=False,
+ count_include_pad=True,
layout="NCHW",
out_layout="NCHW",
)
@@ -1948,7 +1948,7 @@ def test_avg_pool1d():
dilation=[1, 1],
padding=[0, 1, 0, 1],
ceil_mode=True,
- count_include_pad=False,
+ count_include_pad=True,
layout="NCHW",
out_layout="NCHW",
)
@@ -1976,7 +1976,7 @@ def test_avg_pool1d():
dilation=[1, 1],
padding=[0, 0, 0, 0],
ceil_mode=False,
- count_include_pad=False,
+ count_include_pad=True,
layout="NCHW",
out_layout="NCHW",
)
@@ -2015,6 +2015,7 @@ def test_avg_pool2d():
strides=[1, 1],
dilation=[1, 1],
padding=[0, 0, 0, 0],
+ count_include_pad=True,
layout="NCHW",
out_layout="NCHW",
)
@@ -2048,6 +2049,7 @@ def test_avg_pool2d():
dilation=[1, 1],
padding=[2, 2, 2, 2],
ceil_mode=True,
+ count_include_pad=True,
layout="NCHW",
out_layout="NCHW",
)
@@ -2060,7 +2062,7 @@ def test_avg_pool2d():
return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1],
divisor_override=2)
@tvm.script.ir_module
- class expected3:
+ class expected4:
@R.function
def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
@@ -2071,6 +2073,33 @@ def test_avg_pool2d():
dilation=[1, 1],
padding=[0, 0, 0, 0],
ceil_mode=False,
+ count_include_pad=True,
+ layout="NCHW",
+ out_layout="NCHW",
+ )
+ gv = (lv,)
+ R.output(gv)
+ return gv
+
+ class AvgPool2d5(Module):
+ def forward(self, input):
+ return torch.nn.functional.avg_pool2d(
+ input, kernel_size=[2, 1], divisor_override=2,
count_include_pad=False
+ )
+
+ @tvm.script.ir_module
+ class expected5:
+ @R.function
+ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv = R.nn.avg_pool2d(
+ input_1,
+ pool_size=[2, 1],
+ strides=[2, 1],
+ dilation=[1, 1],
+ padding=[0, 0, 0, 0],
+ ceil_mode=False,
+ count_include_pad=False,
layout="NCHW",
out_layout="NCHW",
)
@@ -2082,7 +2111,8 @@ def test_avg_pool2d():
verify_model(AvgPool2d1(), example_args, {}, expected1)
verify_model(AvgPool2d2(), example_args, {}, expected2)
verify_model(AvgPool2d3(), example_args, {}, expected2)
- verify_model(AvgPool2d4(), example_args, {}, expected3)
+ verify_model(AvgPool2d4(), example_args, {}, expected4)
+ verify_model(AvgPool2d5(), example_args, {}, expected5)
def test_avg_pool3d():
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index d377bb7574..031a855fb9 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1434,6 +1434,7 @@ def test_avgpool2d():
strides=[1, 1],
dilation=[1, 1],
padding=[0, 0, 0, 0],
+ count_include_pad=True,
layout="NCHW",
out_layout="NCHW",
)
@@ -1467,6 +1468,7 @@ def test_avgpool2d():
dilation=[1, 1],
padding=[2, 2, 2, 2],
ceil_mode=True,
+ count_include_pad=True,
layout="NCHW",
out_layout="NCHW",
)
@@ -1490,6 +1492,7 @@ def test_avgpool2d():
dilation=[1, 1],
padding=[0, 0, 0, 0],
ceil_mode=False,
+ count_include_pad=True,
layout="NCHW",
out_layout="NCHW",
)