This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0594994c7d [ONNX] Fix interpreting auto_pad parameters in ConvTranspose operator (#16001)
0594994c7d is described below

commit 0594994c7d064156612b353454c22118003c6650
Author: padreofthegame <[email protected]>
AuthorDate: Mon Apr 8 23:10:18 2024 +0200

    [ONNX] Fix interpreting auto_pad parameters in ConvTranspose operator (#16001)
    
    [ONNX] Fix interpretation of auto_pad parameters SAME_UPPER and SAME_LOWER in ConvTranspose operator
---
 python/tvm/relay/frontend/onnx.py          | 18 ++++++++++--
 tests/python/frontend/onnx/test_forward.py | 46 +++++++++++++++++++++++++++++-
 2 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py 
b/python/tvm/relay/frontend/onnx.py
index 17329cfb15..a5e98b38b3 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -826,6 +826,15 @@ class Conv(OnnxOpConverter):
         return out
 
 
+def is_ort_version_greater_than(ver):
+    """Return True if the installed onnxruntime version is strictly newer than ``ver``."""
+    import onnxruntime as ort
+    # Only (major, minor, patch) are compared; "+local" and extra ".devN" parts are ignored.
+    current = tuple(int(v) for v in ort.__version__.split("+")[0].split(".")[:3])
+    target = tuple(int(v) for v in ver.split(".")[:3])
+    return current > target
+
+
 class ConvTranspose(OnnxOpConverter):
     """Operator converter for ConvTranspose."""
 
@@ -963,12 +972,15 @@ class ConvTranspose(OnnxOpConverter):
                         )
                 left = [p // 2 for p in total_pad]
                 right = [total_pad[i] - left[i] for i in range(kndim)]
+
                 if "output_shape" in attr and "auto_pad" not in attr:
                     pad = right + left
-                elif "LOWER" in attr["auto_pad"]:
-                    pad = left + right
-                else:
+                elif ("LOWER" in attr["auto_pad"] and 
is_ort_version_greater_than("1.12.1")) or (
+                    ("UPPER" in attr["auto_pad"] and not 
is_ort_version_greater_than("1.12.1"))
+                ):
                     pad = right + left
+                else:
+                    pad = left + right
                 attr["pads"] = pad
             elif attr["auto_pad"] == "VALID":
                 attr["pads"] = tuple([0 for i in range(ndim - 2)])
diff --git a/tests/python/frontend/onnx/test_forward.py 
b/tests/python/frontend/onnx/test_forward.py
index 4bfa497034..7774c66233 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3404,6 +3404,36 @@ def test_convtranspose(target, dev):
                 auto_pad="SAME_LOWER",
             )
 
+            verify_convtranspose_with_output_shape(
+                (1, 1) + repeat(32, dims),
+                (1, 2) + repeat(4, dims),
+                repeat(num, dims),
+                repeat(4, dims),
+                repeat(2, dims),
+                repeat(1, dims),
+                auto_pad="SAME_UPPER",
+            )
+
+    verify_convtranspose_with_output_shape(
+        (1, 1, 3, 3),
+        (1, 2, 3, 3),
+        (6, 6),
+        (3, 3),
+        (2, 2),
+        (1, 1),
+        auto_pad="SAME_UPPER",
+    )
+
+    verify_convtranspose_with_output_shape(
+        (1, 1, 3, 3),
+        (1, 2, 3, 3),
+        (6, 6),
+        (3, 3),
+        (2, 2),
+        (1, 1),
+        auto_pad="SAME_LOWER",
+    )
+
 
 @tvm.testing.parametrize_targets
 def test_unsqueeze_constant(target, dev):
@@ -5634,7 +5664,6 @@ unsupported_onnx_tests = [
     "test_cast_DOUBLE_to_FLOAT16",
     "test_castlike_DOUBLE_to_FLOAT16",
     "test_castlike_DOUBLE_to_FLOAT16_expanded",
-    "test_convtranspose_autopad_same",
     "test_convtranspose_dilations",
     "test_cumsum_1d",
     "test_cumsum_1d_exclusive",
@@ -5766,6 +5795,15 @@ def _load_proto(proto_filename, target_list, 
model_type_proto):
             )
 
 
+def is_ort_version_lower_than(ver):
+    """Return True if the installed onnxruntime version is strictly older than ``ver``."""
+    import onnxruntime as ort
+    # Only (major, minor, patch) are compared; "+local" and extra ".devN" parts are ignored.
+    current = tuple(int(v) for v in ort.__version__.split("+")[0].split(".")[:3])
+    target = tuple(int(v) for v in ver.split(".")[:3])
+    return current < target
+
+
 @pytest.mark.parametrize("onnx_test", onnx_test_folders)
 @tvm.testing.parametrize_targets
 def test_onnx_nodes(target, dev, onnx_test):
@@ -5782,6 +5820,12 @@ def test_onnx_nodes(target, dev, onnx_test):
     if onnx_test in target_specific_skips:
         pytest.skip(f"Onnx test '{onnx_test}' not yet supported by TVM on 
{target_kind} targets")
 
+    if is_ort_version_lower_than("1.13.1") and onnx_test == 
"test_convtranspose_autopad_same":
+        pytest.skip(
+            f"Onnx test '{onnx_test}' expected to fail for onnxruntime version 
lower than 1.13.1 "
+            "due to different interpretation of auto_pad parameters SAME_UPPER 
and SAME_LOWER."
+        )
+
     test_dir = os.path.join(onnx_test_node_dir, onnx_test)
 
     atol = 1e-5

Reply via email to