This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0594994c7d [ONNX] Fix interpreting auto_pad parameters in
ConvTranspose operator (#16001)
0594994c7d is described below
commit 0594994c7d064156612b353454c22118003c6650
Author: padreofthegame <[email protected]>
AuthorDate: Mon Apr 8 23:10:18 2024 +0200
[ONNX] Fix interpreting auto_pad parameters in ConvTranspose operator
(#16001)
[ONNX] Fix interpretation of auto_pad parameters SAME_UPPER and SAME_LOWER in
ConvTranspose operator
---
python/tvm/relay/frontend/onnx.py | 18 ++++++++++--
tests/python/frontend/onnx/test_forward.py | 46 +++++++++++++++++++++++++++++-
2 files changed, 60 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py
b/python/tvm/relay/frontend/onnx.py
index 17329cfb15..a5e98b38b3 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -826,6 +826,15 @@ class Conv(OnnxOpConverter):
return out
def _parse_version_triple(text):
    """Parse the leading ``major.minor.patch`` numbers of a version string.

    Local/pre-release suffixes such as ``+cu118`` or ``.dev20240118`` are
    ignored and missing components default to 0. For example,
    ``"1.16.3+cu118"`` yields ``(1, 16, 3)`` and ``"1.17"`` yields ``(1, 17, 0)``.

    Parameters
    ----------
    text : str
        Version string to parse.

    Returns
    -------
    tuple of int
        Three-element ``(major, minor, patch)`` tuple.

    Raises
    ------
    ValueError
        If ``text`` does not start with a number.
    """
    import re

    numbers = []
    for piece in text.split(".")[:3]:
        # Take only the leading digits so "3+cu118" -> 3 and "0rc1" -> 0.
        match = re.match(r"\d+", piece)
        if match is None:
            break
        numbers.append(int(match.group()))
    if not numbers:
        raise ValueError(f"Cannot parse version string {text!r}")
    return tuple(numbers + [0] * (3 - len(numbers)))


def is_ort_version_greater_than(ver, ort_version=None):
    """Return True if the onnxruntime version is strictly greater than ``ver``.

    Parameters
    ----------
    ver : str
        Reference version, e.g. ``"1.12.1"``.
    ort_version : str, optional
        Version to compare instead of the installed ``onnxruntime.__version__``.
        When omitted, onnxruntime is imported (raising ImportError if it is not
        installed).

    Returns
    -------
    bool
        True when the onnxruntime version compares greater than ``ver`` on its
        ``(major, minor, patch)`` numbers; any suffixes are ignored.
    """
    if ort_version is None:
        import onnxruntime as ort

        ort_version = ort.__version__
    return _parse_version_triple(ort_version) > _parse_version_triple(ver)
+
+
class ConvTranspose(OnnxOpConverter):
"""Operator converter for ConvTranspose."""
@@ -963,12 +972,15 @@ class ConvTranspose(OnnxOpConverter):
)
left = [p // 2 for p in total_pad]
right = [total_pad[i] - left[i] for i in range(kndim)]
+
if "output_shape" in attr and "auto_pad" not in attr:
pad = right + left
- elif "LOWER" in attr["auto_pad"]:
- pad = left + right
- else:
+ elif ("LOWER" in attr["auto_pad"] and
is_ort_version_greater_than("1.12.1")) or (
+ ("UPPER" in attr["auto_pad"] and not
is_ort_version_greater_than("1.12.1"))
+ ):
pad = right + left
+ else:
+ pad = left + right
attr["pads"] = pad
elif attr["auto_pad"] == "VALID":
attr["pads"] = tuple([0 for i in range(ndim - 2)])
diff --git a/tests/python/frontend/onnx/test_forward.py
b/tests/python/frontend/onnx/test_forward.py
index 4bfa497034..7774c66233 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3404,6 +3404,36 @@ def test_convtranspose(target, dev):
auto_pad="SAME_LOWER",
)
+ verify_convtranspose_with_output_shape(
+ (1, 1) + repeat(32, dims),
+ (1, 2) + repeat(4, dims),
+ repeat(num, dims),
+ repeat(4, dims),
+ repeat(2, dims),
+ repeat(1, dims),
+ auto_pad="SAME_UPPER",
+ )
+
+ verify_convtranspose_with_output_shape(
+ (1, 1, 3, 3),
+ (1, 2, 3, 3),
+ (6, 6),
+ (3, 3),
+ (2, 2),
+ (1, 1),
+ auto_pad="SAME_UPPER",
+ )
+
+ verify_convtranspose_with_output_shape(
+ (1, 1, 3, 3),
+ (1, 2, 3, 3),
+ (6, 6),
+ (3, 3),
+ (2, 2),
+ (1, 1),
+ auto_pad="SAME_LOWER",
+ )
+
@tvm.testing.parametrize_targets
def test_unsqueeze_constant(target, dev):
@@ -5634,7 +5664,6 @@ unsupported_onnx_tests = [
"test_cast_DOUBLE_to_FLOAT16",
"test_castlike_DOUBLE_to_FLOAT16",
"test_castlike_DOUBLE_to_FLOAT16_expanded",
- "test_convtranspose_autopad_same",
"test_convtranspose_dilations",
"test_cumsum_1d",
"test_cumsum_1d_exclusive",
@@ -5766,6 +5795,15 @@ def _load_proto(proto_filename, target_list,
model_type_proto):
)
def _parse_version_triple(text):
    """Parse the leading ``major.minor.patch`` numbers of a version string.

    Local/pre-release suffixes such as ``+cu118`` or ``.dev20240118`` are
    ignored and missing components default to 0. For example,
    ``"1.16.3+cu118"`` yields ``(1, 16, 3)`` and ``"1.17"`` yields ``(1, 17, 0)``.

    Parameters
    ----------
    text : str
        Version string to parse.

    Returns
    -------
    tuple of int
        Three-element ``(major, minor, patch)`` tuple.

    Raises
    ------
    ValueError
        If ``text`` does not start with a number.
    """
    import re

    numbers = []
    for piece in text.split(".")[:3]:
        # Take only the leading digits so "3+cu118" -> 3 and "0rc1" -> 0.
        match = re.match(r"\d+", piece)
        if match is None:
            break
        numbers.append(int(match.group()))
    if not numbers:
        raise ValueError(f"Cannot parse version string {text!r}")
    return tuple(numbers + [0] * (3 - len(numbers)))


def is_ort_version_lower_than(ver, ort_version=None):
    """Return True if the onnxruntime version is strictly lower than ``ver``.

    Parameters
    ----------
    ver : str
        Reference version, e.g. ``"1.13.1"``.
    ort_version : str, optional
        Version to compare instead of the installed ``onnxruntime.__version__``.
        When omitted, onnxruntime is imported (raising ImportError if it is not
        installed).

    Returns
    -------
    bool
        True when the onnxruntime version compares lower than ``ver`` on its
        ``(major, minor, patch)`` numbers; any suffixes are ignored.
    """
    if ort_version is None:
        import onnxruntime as ort

        ort_version = ort.__version__
    return _parse_version_triple(ort_version) < _parse_version_triple(ver)
+
+
@pytest.mark.parametrize("onnx_test", onnx_test_folders)
@tvm.testing.parametrize_targets
def test_onnx_nodes(target, dev, onnx_test):
@@ -5782,6 +5820,12 @@ def test_onnx_nodes(target, dev, onnx_test):
if onnx_test in target_specific_skips:
pytest.skip(f"Onnx test '{onnx_test}' not yet supported by TVM on
{target_kind} targets")
+ if is_ort_version_lower_than("1.13.1") and onnx_test ==
"test_convtranspose_autopad_same":
+ pytest.skip(
+ f"Onnx test '{onnx_test}' expected to fail for onnxruntime version
lower than 1.13.1 "
+ "due to different interpretation of auto_pad parameters SAME_UPPER
and SAME_LOWER."
+ )
+
test_dir = os.path.join(onnx_test_node_dir, onnx_test)
atol = 1e-5