This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a493c5707c [ONNX][FRONTEND][Fix] Update Resize to accept ShapeExpr (#18209)
a493c5707c is described below
commit a493c5707c99cc41e02bfa09245fe998d27e9c93
Author: Balint Cristian <[email protected]>
AuthorDate: Sat Aug 16 23:24:59 2025 +0300
[ONNX][FRONTEND][Fix] Update Resize to accept ShapeExpr (#18209)
[ONNX][FRONTEND] Update Resize to accept ShapeExpr
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 19 +++++++++++++------
1 file changed, 13 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index e16b109ab5..ee80436a8a 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2143,18 +2143,24 @@ class Resize(OnnxOpConverter):
# Convert scales to sizes if needed.
if scales is not None:
- assert isinstance(scales, relax.Constant), "Only constant scales currently supported."
- scales = scales.data.numpy()
+ if isinstance(scales, relax.Constant):
+ scales = scales.data.numpy()
+ elif isinstance(scales, relax.expr.ShapeExpr):
+ scales = [int(val.value) for val in scales.values]
+ else:
+ assert f"Type {type(scales)} for scale is currently unsupported."
sizes = []
for i, dim in enumerate(x.struct_info.shape):
sizes.append(cast(scales[i] * dim, "int64"))
sizes = sizes[2:]
else:
- assert isinstance(
- sizes, relax.Constant
- ), "Only constant output size currently supported."
- sizes = sizes.data.numpy().astype("int64").tolist()[2:]
+ if isinstance(sizes, relax.Constant):
+ sizes = sizes.data.numpy().astype("int64").tolist()[2:]
+ elif isinstance(sizes, relax.expr.ShapeExpr):
+ sizes = [int(val.value) for val in sizes.values][2:]
+ else:
+ assert f"Type {type(size)} for size is currently unsupported."
return relax.op.image.resize2d(
x,
@@ -3751,6 +3757,7 @@ class ONNXGraphImporter:
# convert it to a tensor.
shape_compatible_ops = [
"Reshape",
+ "Resize",
"ConstantOfShape",
"Gather",
"Slice",