This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2b4b58d217 [Relax][Onnx][Resize] Handle non-4D input tensors (#18666)
2b4b58d217 is described below
commit 2b4b58d217c9615908d4e206f6c9e0152103f1e0
Author: YinHanke <[email protected]>
AuthorDate: Fri Jan 16 23:51:05 2026 +0800
[Relax][Onnx][Resize] Handle non-4D input tensors (#18666)
### Motivation
The ONNX Resize operator supports resizing N-D tensors per the ONNX specification.
However, the current Relax ONNX frontend only supports 4D inputs and raises an
assertion error for valid non-4D models.
This PR extends the Relax ONNX Resize converter beyond the 4D-only restriction
to also accept 3D and 5D inputs, aligning behavior with the ONNX specification
and ONNX Runtime for the supported ranks (3D, 4D, and 5D).
### Changes
- Remove the 4D-only assertion in the Relax ONNX Resize converter
- Preserve the existing 4D resize path without behavior changes
- Support 3D and 5D Resize by lowering to the existing TOPI `resize1d` and
  `resize3d` implementations (emitted through `bb.emit_te`)
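- Ensure Resize attributes are handled correctly for non-4D cases: a constant
  ROI of length `2 * ndims` is trimmed to its spatial-axis entries, and the
  default ROI is sized as `2 * (ndims - 2)` zeros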
### Testing
- Verified outputs against ONNX Runtime
- Added an ONNX Resize test (`test_resize_nd_sizes`) covering 3D, 4D, and 5D inputs
Fixes: [[Bug] Resize N-D import failure: TVM only supports 4D resize2d, but ONNX Resize supports N-D tensors](https://github.com/apache/tvm/issues/18608)
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 66 +++++++++++++++++++------
tests/python/relax/test_frontend_onnx.py | 35 +++++++++++++
2 files changed, 85 insertions(+), 16 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 1479d6f239..6e8c43f671 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2235,7 +2235,10 @@ class Resize(OnnxOpConverter):
# Adapt attributes to fit TVM definition.
if mode == "nearest":
- mode = "nearest_neighbor"
+ relax_mode = "nearest_neighbor"
+ else:
+ relax_mode = mode
+ topi_mode = relax_mode
# Unpack inputs.
x = inputs[0]
@@ -2243,7 +2246,7 @@ class Resize(OnnxOpConverter):
scales = get_constant(inputs[2], params)
sizes = get_constant(inputs[3], params)
ndims = len(x.struct_info.shape)
- assert ndims == 4, "Only resize2d is currently supported."
+ assert ndims in (3, 4, 5), "Only resize1d/resize2d/resize3d are
supported."
assert (
scales is None or sizes is None
@@ -2253,6 +2256,8 @@ class Resize(OnnxOpConverter):
if roi is not None:
if isinstance(roi, relax.Constant):
roi = roi.data.numpy().tolist()
+ if len(roi) == 2 * ndims:
+ roi = roi[2:ndims] + roi[ndims + 2 : 2 * ndims]
else:
roi = relax.op.concat(
[
@@ -2262,9 +2267,9 @@ class Resize(OnnxOpConverter):
axis=0,
)
# TODO The backend C++ func resize2d does not support dynamic
ROI for now.
- raise NotImplementedError("Dynamic ROI is not supported in
resize2d for now.")
+ raise NotImplementedError("Dynamic ROI is not supported in
resize for now.")
else:
- roi = [0.0] * 4
+ roi = [0.0] * (2 * (ndims - 2))
# Convert scales to sizes if needed.
if scales is not None:
@@ -2287,18 +2292,47 @@ class Resize(OnnxOpConverter):
else:
assert f"Type {type(size)} for size is currently unsupported."
- return relax.op.image.resize2d(
- x,
- size=relax.ShapeExpr(sizes),
- roi=roi,
- layout="NCHW",
- method=mode,
- coordinate_transformation_mode=coord_mode,
- rounding_method=rounding_method,
- cubic_alpha=cubic_coeff_a,
- cubic_exclude=exclude_outside,
- extrapolation_value=extrapolation_value,
- )
+ if ndims == 3:
+ return bb.emit_te(
+ topi.image.resize1d,
+ x,
+ roi,
+ sizes,
+ "NCW",
+ topi_mode,
+ coord_mode,
+ rounding_method,
+ cubic_coeff_a,
+ exclude_outside,
+ extrapolation_value,
+ )
+ elif ndims == 4:
+ return relax.op.image.resize2d(
+ x,
+ size=relax.ShapeExpr(sizes),
+ roi=roi,
+ layout="NCHW",
+ method=relax_mode,
+ coordinate_transformation_mode=coord_mode,
+ rounding_method=rounding_method,
+ cubic_alpha=cubic_coeff_a,
+ cubic_exclude=exclude_outside,
+ extrapolation_value=extrapolation_value,
+ )
+ else: # ndims == 5
+ return bb.emit_te(
+ topi.image.resize3d,
+ x,
+ roi,
+ sizes,
+ "NCDHW",
+ topi_mode,
+ coord_mode,
+ rounding_method,
+ cubic_coeff_a,
+ exclude_outside,
+ extrapolation_value,
+ )
class Einsum(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index f967b3c4c6..6f5c7da5ef 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2681,6 +2681,41 @@ def test_resize(with_roi, roi_list):
check_correctness(model)
+def test_resize_nd_sizes():
+ cases = [
+ ("resize1d", [1, 1, 4], [1, 1, 7]),
+ ("resize2d", [1, 1, 4, 5], [1, 1, 6, 7]),
+ ("resize3d", [1, 1, 3, 4, 5], [1, 1, 4, 6, 7]),
+ ]
+
+ for name, input_shape, sizes in cases:
+ resize_node = helper.make_node(
+ "Resize",
+ ["X", "", "", "sizes"],
+ ["Y"],
+ mode="nearest",
+ coordinate_transformation_mode="asymmetric",
+ nearest_mode="floor",
+ )
+
+ graph = helper.make_graph(
+ [resize_node],
+ name,
+ inputs=[
+ helper.make_tensor_value_info("X", TensorProto.FLOAT,
input_shape),
+ ],
+ initializer=[
+ helper.make_tensor("sizes", TensorProto.INT64, [len(sizes)],
sizes),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, sizes),
+ ],
+ )
+
+ model = helper.make_model(graph, producer_name=name)
+ check_correctness(model, opset=18)
+
+
def test_einsum():
eqn = "ij->i"
einsum_node = helper.make_node("Einsum", ["x"], ["y"], equation=eqn)