This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 83c56bd7d7 [Relax] Add support for ONNX LPPool (#17540)
83c56bd7d7 is described below

commit 83c56bd7d764f415882c2149e3234827fc015bf4
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Nov 21 19:38:27 2024 +0800

    [Relax] Add support for ONNX LPPool (#17540)
    
    adding support for ONNX LPPool and refactoring frontend tests
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py |  22 ++-
 tests/python/relax/test_frontend_onnx.py        | 221 +++++-------------------
 2 files changed, 62 insertions(+), 181 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 9e0f5a060c..f5083caf82 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2284,6 +2284,7 @@ class Pool(OnnxOpConverter):
         kernel_shape = attr.get("kernel_shape")
         pads = attr.get("pads", 0)
         strides = attr.get("strides", [1] * (ndim - 2))
+        count_include_pad = attr.get("count_include_pad", False)
 
         assert len(kernel_shape) in [1, 2, 3], "Currently only 1D/2D/3D/ 
pooling is supported."
 
@@ -2328,7 +2329,7 @@ class Pool(OnnxOpConverter):
             pads = tuple([val for pair in zip(*pads) for val in pair])
 
         op = getattr(relax.op.nn, cls.name + str(len(kernel_shape)) + "d")
-        return op(data, kernel_shape, strides, pads, dilations, ceil_mode)
+        return op(data, kernel_shape, strides, pads, dilations, ceil_mode, 
count_include_pad)
 
     @classmethod
     def _get_input_spatial_shape(cls, tensor):
@@ -2348,6 +2349,23 @@ class AveragePool(Pool):
     name = "avg_pool"
 
 
+class LpPool(OnnxOpConverter):
+    """Converts an onnx LpPool node, y = (sum_window |x|^p)^(1/p), into Relax."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        dtype = inputs[0].struct_info.dtype
+        p = attr.get("p", 2.0)
+        reci_p = relax.const(1.0 / p, dtype=dtype)
+        # |x|^p (Lp norm needs abs for odd/non-integer p); emitted to get struct_info
+        data = bb.emit(relax.op.power(relax.op.abs(inputs[0]), relax.const(p, dtype=dtype)))
+        attr = {**attr, "count_include_pad": True}  # copy; fixed divisor => mean*k == sum
+        avg_pool = AveragePool._impl_v1(bb, [data], attr, params)
+        kernels = attr["kernel_shape"]
+        out = avg_pool * relax.const(_np.prod(kernels).astype(dtype))  # mean -> window sum
+        return relax.op.power(out, reci_p)
+
+
 class GlobalAveragePool(OnnxOpConverter):
     """Converts an onnx GlobalAveragePool node into an equivalent Relax 
expression."""
 
@@ -3202,7 +3220,7 @@ def _get_convert_map():
         "Tile": Tile,
         "AveragePool": AveragePool,
         "MaxPool": MaxPool,
-        # "LpPool": LpPool,
+        "LpPool": LpPool,
         "GlobalAveragePool": GlobalAveragePool,
         "GlobalMaxPool": GlobalMaxPool,
         "GlobalLpPool": GlobalLpPool,
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 4cd4704ac0..16445a7914 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2306,194 +2306,57 @@ def test_batch_norm():
     check_correctness(model, opset=15)
 
 
-def test_maxpool_and_averagepool():
-    for pool_name in ["MaxPool", "AveragePool"]:
[email protected]("pool_name", ["MaxPool", "AveragePool", "LpPool"])
[email protected](
+    "shape, auto_pad, kernel_shape, strides, pads",
+    [
         # Pool1D
-        verify_unary(
-            pool_name,
-            [1, 1, 32],
-            dict(
-                auto_pad="NOTSET",
-                kernel_shape=[3],
-                pads=[1, 1],
-                strides=[1],
-            ),
-        )
+        ([1, 1, 32], "NOTSET", [3], [1], [1, 1]),
         # Pool1D with stride
-        verify_unary(
-            pool_name,
-            [1, 1, 32],
-            dict(
-                auto_pad="NOTSET",
-                kernel_shape=[3],
-                pads=[1, 2],
-                strides=[2],
-            ),
-        )
+        ([1, 1, 32], "NOTSET", [3], [2], [1, 1]),
         # Pool1D with stride and autopadding
-        verify_unary(
-            pool_name,
-            [1, 1, 32],
-            dict(
-                auto_pad="SAME_UPPER",
-                kernel_shape=[7],
-                pads=None,
-                strides=[2],
-            ),
-        )
-        verify_unary(
-            pool_name,
-            [1, 1, 32],
-            dict(
-                auto_pad="SAME_LOWER",
-                kernel_shape=[4],
-                pads=None,
-                strides=[4],
-            ),
-        )
-        verify_unary(
-            pool_name,
-            [1, 1, 32],
-            dict(
-                auto_pad="VALID",
-                kernel_shape=[5],
-                pads=None,
-                strides=[5],
-            ),
-        )
-        verify_unary(
-            pool_name,
-            [1, 1, 32],
-            dict(
-                auto_pad="SAME_UPPER",
-                kernel_shape=[3],
-                pads=None,
-            ),
-        )
+        ([1, 1, 32], "SAME_UPPER", [7], [2], None),
+        ([1, 1, 32], "SAME_LOWER", [4], [4], None),
+        ([1, 1, 32], "VALID", [5], [5], None),
+        ([1, 1, 32], "SAME_UPPER", [3], [1], None),
         # Pool2D
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32],
-            dict(
-                auto_pad="NOTSET",
-                kernel_shape=[3, 3],
-                pads=[1, 1, 1, 1],
-                strides=[1, 1],
-            ),
-        )
+        ([1, 1, 32, 32], "NOTSET", [3, 3], [1, 1], [1, 1, 1, 1]),
         # Pool2D with stride
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32],
-            dict(
-                auto_pad="NOTSET",
-                kernel_shape=[3, 3],
-                pads=[1, 1, 1, 1],
-                strides=[2, 2],
-            ),
-        )
+        ([1, 1, 32, 32], "NOTSET", [3, 3], [2, 2], [1, 1, 1, 1]),
         # Pool2D with stride and autopadding
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32],
-            dict(
-                auto_pad="SAME_UPPER",
-                kernel_shape=[3, 7],
-                pads=None,
-                strides=[3, 2],
-            ),
-        )
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32],
-            dict(
-                auto_pad="SAME_LOWER",
-                kernel_shape=[3, 3],
-                pads=None,
-                strides=[2, 2],
-            ),
-        )
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32],
-            dict(
-                auto_pad="VALID",
-                kernel_shape=[3, 3],
-                pads=None,
-                strides=[2, 2],
-            ),
-        )
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32],
-            dict(
-                auto_pad="SAME_UPPER",
-                kernel_shape=[3, 3],
-                pads=None,
-            ),
-        )
+        ([1, 1, 32, 32], "SAME_UPPER", [3, 7], [3, 2], None),
+        ([1, 1, 32, 32], "SAME_LOWER", [3, 3], [2, 2], None),
+        ([1, 1, 32, 32], "VALID", [3, 3], [2, 2], None),
+        ([1, 1, 32, 32], "SAME_UPPER", [3, 3], [1, 1], None),
         # Pool3D
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32, 32],
-            dict(
-                auto_pad="NOTSET",
-                kernel_shape=[3, 3, 4],
-                pads=[1, 2, 1, 1, 2, 2],
-                strides=[1, 1, 1],
-            ),
-        )
+        ([1, 1, 32, 32, 32], "NOTSET", [3, 3, 4], [1, 1, 1], [1, 2, 1, 1, 2, 
2]),
         # Pool3D with stride
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32, 32],
-            dict(
-                auto_pad="NOTSET",
-                kernel_shape=[3, 4, 3],
-                pads=[1, 1, 1, 1, 1, 2],
-                strides=[2, 2, 3],
-            ),
-        )
+        ([1, 1, 32, 32, 32], "NOTSET", [3, 4, 3], [2, 2, 3], [1, 1, 1, 1, 1, 
2]),
         # Pool3D with stride and autopadding
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32, 32],
-            dict(
-                auto_pad="SAME_UPPER",
-                kernel_shape=[4, 3, 3],
-                pads=None,
-                strides=[3, 2, 2],
-            ),
-        )
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32, 32],
-            dict(
-                auto_pad="SAME_LOWER",
-                kernel_shape=[3, 3, 4],
-                pads=None,
-                strides=[2, 2, 2],
-            ),
-        )
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32, 32],
-            dict(
-                auto_pad="VALID",
-                kernel_shape=[3, 3, 5],
-                pads=None,
-                strides=[2, 2, 3],
-            ),
-        )
-        verify_unary(
-            pool_name,
-            [1, 1, 32, 32, 32],
-            dict(
-                auto_pad="SAME_UPPER",
-                kernel_shape=[3, 3, 5],
-                pads=None,
-            ),
-        )
+        ([1, 1, 32, 32, 32], "SAME_UPPER", [4, 3, 3], [3, 2, 2], None),
+        ([1, 1, 32, 32, 32], "SAME_LOWER", [3, 3, 4], [2, 2, 2], None),
+        ([1, 1, 32, 32, 32], "VALID", [3, 3, 5], [2, 2, 3], None),
+        ([1, 1, 32, 32, 32], "SAME_UPPER", [3, 3, 5], [1, 1, 1], None),
+    ],
+)
+def test_pool(
+    pool_name: str,
+    shape: List[int],
+    auto_pad: str,
+    kernel_shape: List[int],
+    strides: List[int],
+    pads: List[int],
+):
+    verify_unary(
+        pool_name,
+        shape,
+        attrs={
+            "kernel_shape": kernel_shape,
+            "strides": strides,
+            "pads": pads,
+            "auto_pad": auto_pad,
+        },
+    )
 
 
 def test_global_average_pool():

Reply via email to