This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c79caf0c00 [Relax][ONNX] Complete ShapeExpr reshape handling in ONNX
frontend (#18956)
c79caf0c00 is described below
commit c79caf0c0040fcdea8850214b84fbae4e26b543a
Author: YinHanke <[email protected]>
AuthorDate: Tue Mar 31 01:13:58 2026 +0800
[Relax][ONNX] Complete ShapeExpr reshape handling in ONNX frontend (#18956)
## Summary
Complete `Reshape` handling for shape values in the Relax ONNX frontend.
## Changes
- keep `ShapeExpr -> Reshape([-1])` on the shape-specialized path
- materialize `ShapeExpr` to an `int64` tensor for other reshape targets
and apply regular tensor reshape semantics
- add frontend coverage for `Shape -> Reshape([-1])`
- add frontend coverage for reshaping shape outputs to non-`[-1]`
targets such as `[1, 3]` and `[3, 1]`
- extend symbolic shape deduction coverage to include the common `Shape
-> Reshape([-1]) -> Gather -> Unsqueeze` shape-construction pattern
## Validation
- `pytest -k 'test_symbolic_shape_deduction or test_reshape_shape_output
or test_reshape'`
This PR completes the `Reshape` limitation in the Relax ONNX frontend
operator work tracked in #18945.
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 15 ++++---
tests/python/relax/test_frontend_onnx.py | 57 +++++++++++++++++++++++--
2 files changed, 63 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 4af7115e5c..fbbcd68bc5 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1016,11 +1016,16 @@ class Reshape(OnnxOpConverter):
data = inputs[0]
new_shape = get_constant(inputs[1], params)
- if isinstance(data, relax.ShapeExpr) and isinstance(new_shape,
relax.Constant):
- new_shape = new_shape.data.numpy().tolist()
- if new_shape != [-1]:
- raise NotImplementedError("Need to fix this case")
- return data
+ if isinstance(data, relax.ShapeExpr):
+ # Preserve identity flatten for shape values to keep
shape-specialized
+ # handling in downstream shape-construction patterns.
+ if isinstance(new_shape, relax.Constant):
+ new_shape_values = new_shape.data.numpy().tolist()
+ if new_shape_values == [-1]:
+ return data
+
+ # Other reshape targets follow regular int64 tensor reshape
semantics.
+ data = bb.normalize(relax.op.shape_to_tensor(data))
if isinstance(data, relax.Constant) and isinstance(new_shape,
relax.Constant):
out = _np.reshape(data.data.numpy(),
new_shape.data.numpy().tolist())
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index c848ef91d6..b68110425a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -971,6 +971,38 @@ def test_reshape(in_shape, shape, out_shape):
check_correctness(model, inputs=input_values)
+@pytest.mark.parametrize(
+ "target_shape, output_shape",
+ [
+ ([-1], [3]),
+ ([1, 3], [1, 3]),
+ ([3, 1], [3, 1]),
+ ],
+)
+def test_reshape_shape_output(target_shape, output_shape):
+ shape_node = helper.make_node("Shape", ["data"], ["shape_out"])
+ reshape_node = helper.make_node("Reshape", ["shape_out", "target_shape"],
["reshaped"])
+
+ data_shape = [2, 3, 4]
+
+ graph = helper.make_graph(
+ [shape_node, reshape_node],
+ "reshape_shape_output",
+ inputs=[
+ helper.make_tensor_value_info("data", TensorProto.FLOAT,
data_shape),
+ ],
+ initializer=[
+ helper.make_tensor("target_shape", TensorProto.INT64,
[len(target_shape)], target_shape)
+ ],
+ outputs=[helper.make_tensor_value_info("reshaped", TensorProto.INT64,
output_shape)],
+ )
+ input_values = {
+ "data": np.random.randn(*data_shape).astype("float32"),
+ }
+ model = helper.make_model(graph, producer_name="reshape_shape_output")
+ check_correctness(model, inputs=input_values)
+
+
def test_transpose():
verify_unary("Transpose", [32, 32, 32], attrs={"perm": [1, 2, 0]})
@@ -3630,7 +3662,8 @@ def test_optional_get_element_empty_raises():
from_onnx(model, opset=18, keep_params_in_input=True)
-def test_symbolic_shape_deduction():
+@pytest.mark.parametrize("with_reshape_flatten", [False, True])
+def test_symbolic_shape_deduction(with_reshape_flatten):
index_node = helper.make_node(
"Constant",
inputs=[],
@@ -3638,7 +3671,17 @@ def test_symbolic_shape_deduction():
value=helper.make_tensor("indices", TensorProto.INT64, [], [0]),
)
shape_node = helper.make_node("Shape", ["data"], ["shape_output"])
- gather_node = helper.make_node("Gather", ["shape_output", "indices"],
["gather_output"])
+ nodes = [index_node, shape_node]
+ gather_input = "shape_output"
+
+ if with_reshape_flatten:
+ reshape_node = helper.make_node(
+ "Reshape", ["shape_output", "target_shape"], ["reshaped_shape"]
+ )
+ nodes.append(reshape_node)
+ gather_input = "reshaped_shape"
+
+ gather_node = helper.make_node("Gather", [gather_input, "indices"],
["gather_output"])
unsqueeze_node = helper.make_node("Unsqueeze", ["gather_output", "axes"],
["unsqueeze_output"])
constant_of_shape_node = helper.make_node(
"ConstantOfShape",
@@ -3646,13 +3689,19 @@ def test_symbolic_shape_deduction():
["output"],
value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]),
)
+ nodes.extend([gather_node, unsqueeze_node, constant_of_shape_node])
+
+ initializers = [helper.make_tensor("axes", TensorProto.INT64, [1],
vals=[0])]
+ if with_reshape_flatten:
+ initializers.append(helper.make_tensor("target_shape",
TensorProto.INT64, [1], vals=[-1]))
+
graph = helper.make_graph(
- [index_node, shape_node, gather_node, unsqueeze_node,
constant_of_shape_node],
+ nodes,
"test_shape_deduction",
inputs=[
helper.make_tensor_value_info("data", TensorProto.FLOAT, ["batch",
"seq"]),
],
- initializer=[helper.make_tensor("axes", TensorProto.INT64, [1],
vals=[0])],
+ initializer=initializers,
outputs=[helper.make_tensor_value_info("output", TensorProto.INT64,
[1])],
)
model = helper.make_model(graph, producer_name="test_shape_deduction")