This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 506a0bbc3f [Relax][PyTorch] Add decomposed operator support for AdaptiveAvgPool (#18437)
506a0bbc3f is described below

commit 506a0bbc3f37bbee4bca5ce45972eefb6dc0288c
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Nov 14 01:25:59 2025 +0800

    [Relax][PyTorch] Add decomposed operator support for AdaptiveAvgPool (#18437)
    
    * Add decomposed operator support for AdaptiveAvgPool
    
    * Refactor avg_pool1d tests
---
 .../frontend/torch/exported_program_translator.py  |   3 +
 .../relax/test_frontend_from_exported_program.py   | 158 +++++++++++----------
 2 files changed, 88 insertions(+), 73 deletions(-)
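For readers following along, here is a minimal sketch (not part of the commit) of how the new path is exercised: passing run_ep_decomposition=True to from_exported_program runs the ExportedProgram decompositions before translation, so an AdaptiveAvgPool1d module is imported as the expand_dims + adaptive_avg_pool2d + squeeze pattern the updated tests expect. The Pool1d module, output size, and input shape below are illustrative assumptions only.

    import torch
    from torch.export import export

    from tvm.relax.frontend.torch import from_exported_program


    class Pool1d(torch.nn.Module):
        # Illustrative module; mirrors the AdaptiveAvgPool1d cases in the tests below.
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool1d(5)

        def forward(self, x):
            return self.pool(x)


    ep = export(Pool1d(), args=(torch.randn(1, 3, 10),))
    # run_ep_decomposition=True decomposes the exported graph before translation.
    mod = from_exported_program(ep, run_ep_decomposition=True)
    mod.show()  # per expected1 below: expand_dims -> adaptive_avg_pool2d -> squeeze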

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0d4abb0336..a6da21ada8 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -950,6 +950,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             # linear algebra
             "linalg_vector_norm.default": self._norm,
             # neural network
+            "_adaptive_avg_pool1d.default": self._adaptive_avg_pool1d,
+            "_adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
+            "_adaptive_avg_pool3d.default": self._adaptive_avg_pool3d,
             "_native_batch_norm_legit_functional.default": 
self._batch_norm_legit_functional,
             "_native_batch_norm_legit_no_training.default": 
self._batch_norm_legit_no_training,
             "batch_norm.default": self._batch_norm_legit_no_training,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 8f308e59b7..774a50db0e 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1632,16 +1632,18 @@ def test_adaptive_avgpool1d():
             input_1: R.Tensor((1, 3, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.adaptive_avg_pool1d(
-                    input_1, output_size=[5], layout="NCW"
+                lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input_1, axis=[-2])
+                lv1: R.Tensor((1, 3, 1, 5), dtype="float32") = R.nn.adaptive_avg_pool2d(
+                    lv, output_size=[1, 5], layout="NCHW"
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv,)
+                lv2: R.Tensor((1, 3, 5), dtype="float32") = R.squeeze(lv1, axis=[-2])
+                gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv2,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
-    verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1)
-    verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1)
+    verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1, run_ep_decomposition=True)
 
 
 def test_adaptive_avgpool2d():
@@ -1673,8 +1675,8 @@ def test_adaptive_avgpool2d():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1)
-    verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)
+    verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1, run_ep_decomposition=True)
 
 
 def test_adaptive_avgpool3d():
@@ -1705,8 +1707,8 @@ def test_adaptive_avgpool3d():
             return gv
 
     example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
-    verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1)
-    verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1)
+    verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1, run_ep_decomposition=True)
 
 
 def test_addmm():
@@ -1781,21 +1783,23 @@ def test_avg_pool1d():
     class expected1:
         @R.function
         def main(
-            input_1: R.Tensor((1, 3, 10), dtype="float32")
+            input: R.Tensor((1, 3, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 10), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10), dtype="float32") = R.nn.avg_pool1d(
-                    input_1,
-                    pool_size=[1],
-                    strides=[1],
-                    dilation=[1],
-                    padding=[0, 0],
+                lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input, axis=[-2])
+                lv1: R.Tensor((1, 3, 1, 10), dtype="float32") = R.nn.avg_pool2d(
+                    lv,
+                    pool_size=[1, 1],
+                    strides=[1, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
                     ceil_mode=False,
-                    count_include_pad=True,
-                    layout="NCW",
-                    out_layout="NCW",
+                    count_include_pad=False,
+                    layout="NCHW",
+                    out_layout="NCHW",
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 10), dtype="float32")) = (lv,)
+                lv2: R.Tensor((1, 3, 10), dtype="float32") = R.squeeze(lv1, axis=[-2])
+                gv: R.Tuple(R.Tensor((1, 3, 10), dtype="float32")) = (lv2,)
                 R.output(gv)
             return gv
 
@@ -1816,20 +1820,24 @@ def test_avg_pool1d():
     @tvm.script.ir_module
     class expected2:
         @R.function
-        def main(input_1: R.Tensor((1, 3, 10), dtype="float32")):
+        def main(
+            input: R.Tensor((1, 3, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 6), dtype="float32")):
             with R.dataflow():
-                lv = R.nn.avg_pool1d(
-                    input_1,
-                    pool_size=[3],
-                    strides=[2],
-                    dilation=[1],
-                    padding=[1, 1],
+                lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input, axis=[-2])
+                lv1: R.Tensor((1, 3, 1, 6), dtype="float32") = R.nn.avg_pool2d(
+                    lv,
+                    pool_size=[1, 3],
+                    strides=[1, 2],
+                    dilation=[1, 1],
+                    padding=[0, 1, 0, 1],
                     ceil_mode=True,
-                    count_include_pad=True,
-                    layout="NCW",
-                    out_layout="NCW",
+                    count_include_pad=False,
+                    layout="NCHW",
+                    out_layout="NCHW",
                 )
-                gv = (lv,)
+                lv2: R.Tensor((1, 3, 6), dtype="float32") = R.squeeze(lv1, axis=[-2])
+                gv: R.Tuple(R.Tensor((1, 3, 6), dtype="float32")) = (lv2,)
                 R.output(gv)
             return gv
 
@@ -1840,28 +1848,32 @@ def test_avg_pool1d():
     @tvm.script.ir_module
     class expected3:
         @R.function
-        def main(input_1: R.Tensor((1, 3, 10), dtype="float32")):
+        def main(
+            input: R.Tensor((1, 3, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")):
             with R.dataflow():
-                lv = R.nn.avg_pool1d(
-                    input_1,
-                    pool_size=[2],
-                    strides=[2],
-                    dilation=[1],
-                    padding=[0, 0],
+                lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input, axis=[-2])
+                lv1: R.Tensor((1, 3, 1, 5), dtype="float32") = R.nn.avg_pool2d(
+                    lv,
+                    pool_size=[1, 2],
+                    strides=[1, 2],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
                     ceil_mode=False,
-                    count_include_pad=True,
-                    layout="NCW",
-                    out_layout="NCW",
+                    count_include_pad=False,
+                    layout="NCHW",
+                    out_layout="NCHW",
                 )
-                gv = (lv,)
+                lv2: R.Tensor((1, 3, 5), dtype="float32") = R.squeeze(lv1, axis=[-2])
+                gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv2,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
-    verify_model(AvgPool1d1(), example_args, {}, expected1)
-    verify_model(AvgPool1d2(), example_args, {}, expected2)
-    verify_model(AvgPool1d3(), example_args, {}, expected2)
-    verify_model(AvgPool1d4(), example_args, {}, expected3)
+    verify_model(AvgPool1d1(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(AvgPool1d2(), example_args, {}, expected2, run_ep_decomposition=True)
+    verify_model(AvgPool1d3(), example_args, {}, expected2, run_ep_decomposition=True)
+    verify_model(AvgPool1d4(), example_args, {}, expected3, run_ep_decomposition=True)
 
 
 def test_avg_pool2d():
@@ -1951,10 +1963,10 @@ def test_avg_pool2d():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(AvgPool2d1(), example_args, {}, expected1)
-    verify_model(AvgPool2d2(), example_args, {}, expected2)
-    verify_model(AvgPool2d3(), example_args, {}, expected2)
-    verify_model(AvgPool2d4(), example_args, {}, expected3)
+    verify_model(AvgPool2d1(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(AvgPool2d2(), example_args, {}, expected2, run_ep_decomposition=True)
+    verify_model(AvgPool2d3(), example_args, {}, expected2, run_ep_decomposition=True)
+    verify_model(AvgPool2d4(), example_args, {}, expected3, run_ep_decomposition=True)
 
 
 def test_avg_pool3d():
@@ -2047,10 +2059,10 @@ def test_avg_pool3d():
             return gv
 
     example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
-    verify_model(AvgPool3d1(), example_args, {}, expected1)
-    verify_model(AvgPool3d2(), example_args, {}, expected2)
-    verify_model(AvgPool3d3(), example_args, {}, expected2)
-    verify_model(AvgPool3d4(), example_args, {}, expected3)
+    verify_model(AvgPool3d1(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(AvgPool3d2(), example_args, {}, expected2, run_ep_decomposition=True)
+    verify_model(AvgPool3d3(), example_args, {}, expected2, run_ep_decomposition=True)
+    verify_model(AvgPool3d4(), example_args, {}, expected3, run_ep_decomposition=True)
 
 
 def test_baddbmm():
@@ -2284,15 +2296,15 @@ def test_conv_transpose1d():
 
     model = ConvTranspose1d1()
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
     model = ConvTranspose1d1Func()
     binding = {"w1": model.weight.detach().numpy(), "w2": 
model.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
     model = ConvTranspose1d2()
     binding = {"w1": model.conv.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2)
+    verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
 
 
 def test_conv_transpose2d():
@@ -2378,15 +2390,15 @@ def test_conv_transpose2d():
 
     model = ConvTranspose2d1()
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
     model = ConvTranspose2d1Func()
     binding = {"w1": model.weight.detach().numpy(), "w2": 
model.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
     model = ConvTranspose2d2()
     binding = {"w1": model.conv.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2)
+    verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
 
 
 def test_conv1d():
@@ -2470,15 +2482,15 @@ def test_conv1d():
 
     model = Conv1D1()
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
     model = Conv1D1Func()
     binding = {"w1": model.weight.detach().numpy(), "w2": 
model.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
     model = Conv1D2()
     binding = {"w1": model.conv.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2)
+    verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
 
 
 def test_conv2d():
@@ -2562,15 +2574,15 @@ def test_conv2d():
 
     model = Conv2D1()
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
     model = Conv2D1Func()
     binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
     model = Conv2D2()
     binding = {"w1": model.conv.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2)
+    verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
 
 
 def test_conv3d():
@@ -2654,15 +2666,15 @@ def test_conv3d():
 
     model = Conv3D1()
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
     model = Conv3D1Func()
     binding = {"w1": model.weight.detach().numpy(), "w2": 
model.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
     model = Conv3D2()
     binding = {"w1": model.conv.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2)
+    verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
 
 
 def test_pad():
@@ -6523,7 +6535,7 @@ def test_lstm():
     with torch.no_grad():
         pytorch_output = model(x)
     exported_program = export(model, args=(x,))
-    mod = from_exported_program(exported_program)
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
     target = tvm.target.Target("llvm")
     ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, tvm.cpu())
@@ -6559,7 +6571,7 @@ def test_lstm():
     with torch.no_grad():
         pytorch_output2 = model2(x2)
     exported_program2 = export(model2, args=(x2,))
-    mod2 = from_exported_program(exported_program2)
+    mod2 = from_exported_program(exported_program2, run_ep_decomposition=True)
     ex2 = relax.build(mod2, target)
     vm2 = relax.VirtualMachine(ex2, tvm.cpu())
     x2_tvm = tvm.runtime.tensor(x2.numpy())
@@ -6616,7 +6628,7 @@ def test_gru():
     with torch.no_grad():
         pytorch_output = model(x)
     exported_program = export(model, args=(x,))
-    mod = from_exported_program(exported_program)
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
     target = tvm.target.Target("llvm")
     ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, tvm.cpu())
@@ -6652,7 +6664,7 @@ def test_gru():
     with torch.no_grad():
         pytorch_output2 = model2(x2)
     exported_program2 = export(model2, args=(x2,))
-    mod2 = from_exported_program(exported_program2)
+    mod2 = from_exported_program(exported_program2, run_ep_decomposition=True)
     ex2 = relax.build(mod2, target)
     vm2 = relax.VirtualMachine(ex2, tvm.cpu())
     x2_tvm = tvm.runtime.tensor(x2.numpy())
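And a hedged end-to-end check in the same style as the LSTM/GRU tests above (relax.build plus the Relax VirtualMachine); the module, shapes, and tolerances are assumptions for illustration, not part of the commit.

    import numpy as np
    import torch
    from torch.export import export

    import tvm
    from tvm import relax
    from tvm.relax.frontend.torch import from_exported_program

    pool = torch.nn.AdaptiveAvgPool1d(5)  # illustrative module with no parameters
    x = torch.randn(1, 3, 10, dtype=torch.float32)
    with torch.no_grad():
        torch_out = pool(x).numpy()

    ep = export(pool, args=(x,))
    mod = from_exported_program(ep, run_ep_decomposition=True)

    ex = relax.build(mod, tvm.target.Target("llvm"))
    vm = relax.VirtualMachine(ex, tvm.cpu())
    tvm_out = vm["main"](tvm.runtime.tensor(x.numpy()))

    # The imported function returns a tuple, so compare its first element.
    np.testing.assert_allclose(tvm_out[0].numpy(), torch_out, rtol=1e-5, atol=1e-5)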
