This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cc8afdb0e3 Add support for `torch.nn.functional.max_pool2d` (#17189)
cc8afdb0e3 is described below
commit cc8afdb0e3be52a3aa162ff14a81b11a793dca6b
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Wed Jul 24 22:36:19 2024 +0900
Add support for `torch.nn.functional.max_pool2d` (#17189)
* add a testcase for call_function
* add maxpool2d to call_function
---
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
tests/python/relax/test_frontend_from_fx.py | 8 ++++++++
2 files changed, 9 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index e6b39c3eee..093f3ae4cf 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1476,6 +1476,7 @@ class TorchFXImporter:
"getitem": self._getitem,
"contiguous": lambda node: self.env[node.args[0]],
"to": self._to,
+ "max_pool2d": self._max_pool2d,
"avg_pool2d": self._avg_pool2d,
"adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
"layer_norm": self._layer_norm,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index b4ac3fa60c..1a2cc5da62 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -796,6 +796,13 @@ def test_maxpool2d():
def forward(self, input):
return self.pool(input)
+ class MaxPool2d_functional(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return torch.nn.functional.max_pool2d(input, kernel_size=[1, 1])
+
@tvm.script.ir_module
class expected1:
@R.function
@@ -876,6 +883,7 @@ def test_maxpool2d():
return gv
verify_model(MaxPool2d(), input_info, {}, expected1)
+ verify_model(MaxPool2d_functional(), input_info, {}, expected1)
verify_model(MaxPool2d2(), input_info, {}, expected2)
verify_model(MaxPool2d3(), input_info, {}, expected3)