This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c79caf0c00 [Relax][ONNX] Complete ShapeExpr reshape handling in ONNX frontend (#18956)
c79caf0c00 is described below

commit c79caf0c0040fcdea8850214b84fbae4e26b543a
Author: YinHanke <[email protected]>
AuthorDate: Tue Mar 31 01:13:58 2026 +0800

    [Relax][ONNX] Complete ShapeExpr reshape handling in ONNX frontend (#18956)
    
    ## Summary
    
    Complete `Reshape` handling for shape values in the Relax ONNX frontend.
    
    ## Changes
    
    - keep `ShapeExpr -> Reshape([-1])` on the shape-specialized path
    - materialize `ShapeExpr` to an `int64` tensor for other reshape targets
    and apply regular tensor reshape semantics
    - add frontend coverage for `Shape -> Reshape([-1])`
    - add frontend coverage for reshaping shape outputs to non-`[-1]`
    targets such as `[1, 3]` and `[3, 1]`
    - extend symbolic shape deduction coverage to include the common `Shape
    -> Reshape([-1]) -> Gather -> Unsqueeze` shape-construction pattern
    
    ## Validation
    
    - `pytest -k 'test_symbolic_shape_deduction or test_reshape_shape_output
    or test_reshape'`
    
    This PR resolves the `Reshape` limitation in the Relax ONNX frontend
    operator work tracked in #18945.
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 15 ++++---
 tests/python/relax/test_frontend_onnx.py        | 57 +++++++++++++++++++++++--
 2 files changed, 63 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 4af7115e5c..fbbcd68bc5 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1016,11 +1016,16 @@ class Reshape(OnnxOpConverter):
         data = inputs[0]
         new_shape = get_constant(inputs[1], params)
 
-        if isinstance(data, relax.ShapeExpr) and isinstance(new_shape, relax.Constant):
-            new_shape = new_shape.data.numpy().tolist()
-            if new_shape != [-1]:
-                raise NotImplementedError("Need to fix this case")
-            return data
+        if isinstance(data, relax.ShapeExpr):
+            # Preserve identity flatten for shape values to keep shape-specialized
+            # handling in downstream shape-construction patterns.
+            if isinstance(new_shape, relax.Constant):
+                new_shape_values = new_shape.data.numpy().tolist()
+                if new_shape_values == [-1]:
+                    return data
+
+            # Other reshape targets follow regular int64 tensor reshape semantics.
+            data = bb.normalize(relax.op.shape_to_tensor(data))
 
         if isinstance(data, relax.Constant) and isinstance(new_shape, relax.Constant):
             out = _np.reshape(data.data.numpy(), new_shape.data.numpy().tolist())
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index c848ef91d6..b68110425a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -971,6 +971,38 @@ def test_reshape(in_shape, shape, out_shape):
     check_correctness(model, inputs=input_values)
 
 
+@pytest.mark.parametrize(
+    "target_shape, output_shape",
+    [
+        ([-1], [3]),
+        ([1, 3], [1, 3]),
+        ([3, 1], [3, 1]),
+    ],
+)
+def test_reshape_shape_output(target_shape, output_shape):
+    shape_node = helper.make_node("Shape", ["data"], ["shape_out"])
+    reshape_node = helper.make_node("Reshape", ["shape_out", "target_shape"], ["reshaped"])
+
+    data_shape = [2, 3, 4]
+
+    graph = helper.make_graph(
+        [shape_node, reshape_node],
+        "reshape_shape_output",
+        inputs=[
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
+        ],
+        initializer=[
+            helper.make_tensor("target_shape", TensorProto.INT64, [len(target_shape)], target_shape)
+        ],
+        outputs=[helper.make_tensor_value_info("reshaped", TensorProto.INT64, output_shape)],
+    )
+    input_values = {
+        "data": np.random.randn(*data_shape).astype("float32"),
+    }
+    model = helper.make_model(graph, producer_name="reshape_shape_output")
+    check_correctness(model, inputs=input_values)
+
+
 def test_transpose():
     verify_unary("Transpose", [32, 32, 32], attrs={"perm": [1, 2, 0]})
 
@@ -3630,7 +3662,8 @@ def test_optional_get_element_empty_raises():
         from_onnx(model, opset=18, keep_params_in_input=True)
 
 
-def test_symbolic_shape_deduction():
+@pytest.mark.parametrize("with_reshape_flatten", [False, True])
+def test_symbolic_shape_deduction(with_reshape_flatten):
     index_node = helper.make_node(
         "Constant",
         inputs=[],
@@ -3638,7 +3671,17 @@ def test_symbolic_shape_deduction():
         value=helper.make_tensor("indices", TensorProto.INT64, [], [0]),
     )
     shape_node = helper.make_node("Shape", ["data"], ["shape_output"])
-    gather_node = helper.make_node("Gather", ["shape_output", "indices"], ["gather_output"])
+    nodes = [index_node, shape_node]
+    gather_input = "shape_output"
+
+    if with_reshape_flatten:
+        reshape_node = helper.make_node(
+            "Reshape", ["shape_output", "target_shape"], ["reshaped_shape"]
+        )
+        nodes.append(reshape_node)
+        gather_input = "reshaped_shape"
+
+    gather_node = helper.make_node("Gather", [gather_input, "indices"], ["gather_output"])
     unsqueeze_node = helper.make_node("Unsqueeze", ["gather_output", "axes"], ["unsqueeze_output"])
     constant_of_shape_node = helper.make_node(
         "ConstantOfShape",
@@ -3646,13 +3689,19 @@ def test_symbolic_shape_deduction():
         ["output"],
         value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]),
     )
+    nodes.extend([gather_node, unsqueeze_node, constant_of_shape_node])
+
+    initializers = [helper.make_tensor("axes", TensorProto.INT64, [1], vals=[0])]
+    if with_reshape_flatten:
+        initializers.append(helper.make_tensor("target_shape", TensorProto.INT64, [1], vals=[-1]))
+
     graph = helper.make_graph(
-        [index_node, shape_node, gather_node, unsqueeze_node, constant_of_shape_node],
+        nodes,
         "test_shape_deduction",
         inputs=[
             helper.make_tensor_value_info("data", TensorProto.FLOAT, ["batch", "seq"]),
         ],
-        initializer=[helper.make_tensor("axes", TensorProto.INT64, [1], vals=[0])],
+        initializer=initializers,
         outputs=[helper.make_tensor_value_info("output", TensorProto.INT64, [1])],
     )
     model = helper.make_model(graph, producer_name="test_shape_deduction")

Reply via email to