This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bb40e39969 [Relax][Frontend][Onnx] Add auto_pad support for conv (#17536)
bb40e39969 is described below

commit bb40e39969ec49dbb5fa9b4b18908aceda695ca2
Author: Honglin Zhu <[email protected]>
AuthorDate: Thu Nov 21 14:50:00 2024 +0800

    [Relax][Frontend][Onnx] Add auto_pad support for conv (#17536)
    
    * add auto_pad support for conv
    
    * Update test_frontend_onnx.py
    
    * Update onnx_frontend.py
    
    * add common.py
    
    * reformat common.py
    
    * reformat common.py
    
    * combine test into test_conv
---
 python/tvm/relax/frontend/common.py             | 74 +++++++++++++++++++++++++
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 32 ++++++++++-
 tests/python/relax/test_frontend_onnx.py        | 64 ++++++++++++++++-----
 3 files changed, 154 insertions(+), 16 deletions(-)

diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py
index bbd0c55aac..ba2960c159 100644
--- a/python/tvm/relax/frontend/common.py
+++ b/python/tvm/relax/frontend/common.py
@@ -17,8 +17,10 @@
 # pylint: disable=invalid-name
 """Commons for Relax frontend."""
 from typing import Dict, List, Tuple
+import numpy as _np
 
 import tvm
+from tvm import topi
 
 
 def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]:
@@ -53,3 +55,75 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.n
         else:
             detached_mod[gv] = func
     return detached_mod, params_dict
+
+
+def autopad(
+    bb,
+    data,
+    strides,
+    kernel_shape,
+    dilations=(1, 1),
+    pad_type="constant",
+    deconv=False,
+    mode="SAME_UPPER",
+    pad_value=0.0,
+):
+    """
+    Perform autopadding with dynamic input shapes
+    """
+    # get attributes as constants
+    strides = _np.array(strides)
+    dilated_kernel_shape = _np.array(
+        [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)]
+    )
+    # get input shape
+    ndim = data.struct_info.ndim
+    data_shape = list(data.struct_info.shape)
+    shape = data_shape[2:ndim]
+
+    # set up integer constants
+    zero = 0
+    one = 1
+    two = 2
+
+    # Calculate total padding
+    mod = shape % strides
+
+    left = _np.maximum(dilated_kernel_shape - strides, zero)
+    right = _np.maximum(dilated_kernel_shape - mod, zero)
+
+    total_pad = _np.where(_np.equal(mod, zero), left, right)
+    if deconv:
+        total_pad = _np.array(kernel_shape) - one - total_pad
+
+    # split total padding into before and after
+    pad_before = _np.floor_divide(total_pad, two)
+    pad_after = total_pad - pad_before
+
+    # combine
+    if "LOWER" in mode:
+        pad = _np.concatenate(
+            [_np.reshape(pad_after, [-1, 1]), _np.reshape(pad_before, [-1, 1])], axis=1
+        )
+    else:
+        pad = _np.concatenate(
+            [_np.reshape(pad_before, [-1, 1]), _np.reshape(pad_after, [-1, 1])], axis=1
+        )
+
+    # pad N and C with zeros
+    pad = _np.concatenate([_np.zeros([2, 2], dtype="int64"), pad], axis=0)
+
+    if pad_type not in ["constant", "edge", "reflect"]:
+        raise tvm.error.OpAttributeInvalid(
+            "Value " + pad_type + ' in attribute "mode" is invalid for operator Pad.'
+        )
+
+    if pad_type == "constant":
+        return bb.emit_te(topi.nn.pad, data, pad[:, 0].tolist(), pad[:, 1].tolist(), pad_value)
+    elif pad_type == "reflect":
+        return bb.emit_te(
+            topi.nn.mirror_pad, data, pad[:, 0].tolist(), pad[:, 1].tolist(), "REFLECT"
+        )
+    else:
+        # TODO(gigiblender) Support edge mode.
+        raise NotImplementedError("Pad mode {} not implemented".format(pad_type))
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index b64e87822a..9e0f5a060c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -49,6 +49,8 @@ from tvm.ir import IRModule
 from tvm.ir.supply import NameSupply
 from tvm.tir.generic import cast
 
+from ..common import autopad
+
 
 def get_type(elem_type: Union[str, int]) -> str:
     """Converts onnx integer datatype to numpy datatype"""
@@ -1208,11 +1210,15 @@ class Conv(OnnxOpConverter):
 
     @classmethod
     def _impl_v11(cls, bb, inputs, attr, params):
+        data = inputs[0]
         if hasattr(inputs[0].struct_info, "ndim"):
             ndim = inputs[0].struct_info.ndim
         else:
             ndim = len(inputs[0].struct_info.shape)
 
+        if "kernel_shape" not in attr:
+            attr["kernel_shape"] = inputs[1].struct_info.shape.values[2:]
+
         if ndim == 3:
             op = relax.op.nn.conv1d
             data_layout = "NCW"
@@ -1228,9 +1234,33 @@ class Conv(OnnxOpConverter):
         else:
             raise NotImplementedError("Ndim > 5 not supported for convolution.")
 
+        if "auto_pad" in attr:
+            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
+            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
+                data = autopad(
+                    bb,
+                    inputs[0],
+                    attr.get("strides", [1] * (ndim - 2)),
+                    attr["kernel_shape"],
+                    attr.get("dilations", [1] * (ndim - 2)),
+                    mode=attr["auto_pad"],
+                    deconv=False,
+                )
+            elif attr["auto_pad"] == "VALID":
+                attr["pads"] = [0 for _ in range(ndim - 2)]
+            elif attr["auto_pad"] == "NOTSET":
+                pass
+            else:
+                msg = (
+                    f'Value {attr["auto_pad"]} in attribute "auto_pad" of operator Conv '
+                    f"is invalid."
+                )
+                raise tvm.error.OpAttributeInvalid(msg)
+            attr.pop("auto_pad")
+
         conv_out = bb.normalize(
             op(
-                data=inputs[0],
+                data=data,
                 weight=inputs[1],
                 strides=attr.get("strides", 1),
                 padding=attr.get("pads", 0),
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 89f08e5af9..4cd4704ac0 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -980,23 +980,57 @@ def test_shrink():
 @pytest.mark.parametrize("dilation", [1, 2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("pad", [0, 2])
-def test_conv(stride: int, dilation: int, pad: int, bias: bool):
[email protected]("auto_pad", ["SAME_UPPER", "SAME_LOWER", "VALID"])
+def test_conv(stride: int, dilation: int, pad: int, bias: bool, auto_pad: str):
     def _verify_conv(input_shape, weight_shape):
         nd = len(weight_shape) - 2
-        output_shape = [input_shape[0], weight_shape[0]] + [
-            (input_shape[i] + 2 * pad - dilation * (weight_shape[i] - 1) - 1) // stride + 1
-            for i in range(2, len(input_shape))
-        ]
-        bias_shape = [output_shape[1]]
-        conv_node = helper.make_node(
-            "Conv",
-            inputs=["x", "w"] + (["b"] if bias else []),
-            outputs=["y"],
-            strides=[stride] * nd,
-            dilations=[dilation] * nd,
-            pads=[pad] * nd * 2,
-            group=input_shape[1] // weight_shape[1],
-        )
+        if auto_pad == "VALID":
+            output_shape = [input_shape[0], weight_shape[0]] + [
+                (input_shape[i] - dilation * (weight_shape[i] - 1) - 1) // stride + 1
+                for i in range(2, len(input_shape))
+            ]
+            bias_shape = [output_shape[1]]
+            conv_node = helper.make_node(
+                "Conv",
+                inputs=["x", "w"] + (["b"] if bias else []),
+                outputs=["y"],
+                strides=[stride] * nd,
+                dilations=[dilation] * nd,
+                auto_pad=auto_pad,
+                group=input_shape[1] // weight_shape[1],
+            )
+        elif auto_pad in ("SAME_UPPER", "SAME_LOWER"):
+            if dilation == 2:
+                # auto_pad = "SAME" and dilation = 2 is not supported in ONNX
+                return
+            output_shape = [input_shape[0], weight_shape[0]] + [
+                (input_shape[i] + stride - 1) // stride for i in range(2, len(input_shape))
+            ]
+            bias_shape = [output_shape[1]]
+            conv_node = helper.make_node(
+                "Conv",
+                inputs=["x", "w"] + (["b"] if bias else []),
+                outputs=["y"],
+                strides=[stride] * nd,
+                dilations=[dilation] * nd,
+                auto_pad=auto_pad,
+                group=input_shape[1] // weight_shape[1],
+            )
+        else:
+            output_shape = [input_shape[0], weight_shape[0]] + [
+                (input_shape[i] + 2 * pad - dilation * (weight_shape[i] - 1) - 1) // stride + 1
+                for i in range(2, len(input_shape))
+            ]
+            bias_shape = [output_shape[1]]
+            conv_node = helper.make_node(
+                "Conv",
+                inputs=["x", "w"] + (["b"] if bias else []),
+                outputs=["y"],
+                strides=[stride] * nd,
+                dilations=[dilation] * nd,
+                pads=[pad] * nd * 2,
+                group=input_shape[1] // weight_shape[1],
+            )
         graph = helper.make_graph(
             [conv_node],
             "conv_test",

Reply via email to