This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cc8afdb0e3 Add support for `torch.nn.functional.max_pool2d` (#17189)
cc8afdb0e3 is described below

commit cc8afdb0e3be52a3aa162ff14a81b11a793dca6b
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Wed Jul 24 22:36:19 2024 +0900

    Add support for `torch.nn.functional.max_pool2d` (#17189)
    
    * add a testcase for call_function
    
    * add maxpool2d to call_function
---
 python/tvm/relax/frontend/torch/fx_translator.py | 1 +
 tests/python/relax/test_frontend_from_fx.py      | 8 ++++++++
 2 files changed, 9 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index e6b39c3eee..093f3ae4cf 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1476,6 +1476,7 @@ class TorchFXImporter:
             "getitem": self._getitem,
             "contiguous": lambda node: self.env[node.args[0]],
             "to": self._to,
+            "max_pool2d": self._max_pool2d,
             "avg_pool2d": self._avg_pool2d,
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
             "layer_norm": self._layer_norm,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index b4ac3fa60c..1a2cc5da62 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -796,6 +796,13 @@ def test_maxpool2d():
         def forward(self, input):
             return self.pool(input)
 
+    class MaxPool2d_functional(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, input):
+            return torch.nn.functional.max_pool2d(input, kernel_size=[1, 1])
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -876,6 +883,7 @@ def test_maxpool2d():
             return gv
 
     verify_model(MaxPool2d(), input_info, {}, expected1)
+    verify_model(MaxPool2d_functional(), input_info, {}, expected1)
     verify_model(MaxPool2d2(), input_info, {}, expected2)
     verify_model(MaxPool2d3(), input_info, {}, expected3)
 

Reply via email to