This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new de93c379b1 [FIX][ONNX][RELAX] Add support for dynamic ShapeExpr in Slice, Squeeze and Flatten (#17490)
de93c379b1 is described below

commit de93c379b199baf45fd80cc465141544efbe7303
Author: PatrikPerssonInceptron <[email protected]>
AuthorDate: Thu Oct 31 02:14:56 2024 +0100

    [FIX][ONNX][RELAX] Add support for dynamic ShapeExpr in Slice, Squeeze and Flatten (#17490)
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py |  59 ++++--
 tests/python/relax/test_frontend_onnx.py        | 229 +++++++++++++++++++++---
 2 files changed, 256 insertions(+), 32 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 6c9225070d..611f4348d5 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1199,14 +1199,29 @@ class Squeeze(OnnxOpConverter):
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
         axis = get_constant(inputs[1], params)
         if isinstance(axis, relax.Constant):
-            axis = [int(x) for x in axis.data.numpy()]
+            axis = tuple([int(x) for x in axis.data.numpy()])
+
         # If data is constant, perform computation directly.
-        if isinstance(inputs[0], relax.Constant):
-            out_data = _np.squeeze(inputs[0].data.numpy(), axis)
-            return relax.const(out_data, inputs[0].struct_info.dtype)
-        return relax.op.squeeze(inputs[0], axis)
+        if isinstance(data, relax.Constant):
+            if isinstance(axis, (tuple, type(None))):
+                out_data = _np.squeeze(data.data.numpy(), axis)
+            else:
+                raise NotImplementedError("Squeeze with symbolic axes not supported")
+
+            return relax.const(out_data, data.struct_info.dtype)
+
+        if isinstance(data, relax.ShapeExpr):
+            if axis == (0,):
+                return relax.PrimValue(data[0])
+            else:
+                raise NotImplementedError(
+                    "Squeeze with symbolic axes and non-zero axes is not supported."
+                )
+
+        return relax.op.squeeze(data, axis)
 
 
 class Constant(OnnxOpConverter):
@@ -1559,12 +1574,12 @@ class Split(OnnxOpConverter):
             splits_rank = splits.checked_type.ndim
         if splits is not None and splits_rank > 0:
             if isinstance(splits, relax.Constant):
-                splits = splits.data.asnumpy()
+                splits = splits.data.numpy()
                 indices = []
                 index = 0
                 for i in splits[:-1]:
                     index += i
-                    indices.append(index)
+                    indices.append(index.item())
             else:
                 raise ValueError("Dynamic Split not yet supported")
         # When splits isnt specified divide evenly over axis.
@@ -1611,11 +1626,16 @@ class Slice(OnnxOpConverter):
             steps = [1] * len(axes)
         # If input is a shape tensor, we can directly extract it.
         if isinstance(data, relax.ShapeExpr):
-            shape_data = [dim.value for dim in data]
+            shape_data = list(data)
             # Starts, ends, and steps must be 1-d for shape operation.
             assert all(len(i) == 1 for i in [starts, ends, steps])
             sliced_values = shape_data[starts[0] : ends[0] : steps[0]]
-            return relax.const(sliced_values, "int64")
+
+            if all([isinstance(val, (tir.IntImm, int)) for val in sliced_values]):
+                return relax.const([x.value for x in sliced_values], "int64")
+            else:
+                return relax.ShapeExpr(sliced_values)
+
         # If all `starts`, `ends`, and `steps` are constant, use strict mode
         # Otherwise, we assume the slice is inbound.
         assume_inbound = not all(
@@ -2237,8 +2257,24 @@ class Flatten(OnnxOpConverter):
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
         axis = attr.get("axis", 1)
-        data_shape = [i.value for i in inputs[0].struct_info.shape]
-        new_shape = (1, -1) if axis == 0 else (_np.prod(data_shape[0:axis]).astype("int64"), -1)
+        data_shape = list(inputs[0].struct_info.shape)
+
+        if axis == 0:
+            new_shape = (1, -1)
+        else:
+            shape_flags = [isinstance(x, tvm.script.tir.IntImm) for x in data_shape[0:axis]]
+
+            if all(shape_flags):
+                data_shape = [x.value for x in data_shape[0:axis]]
+                new_shape = (_np.prod(data_shape).astype("int64"), -1)
+            else:
+                batch_size = 1
+
+                for el in data_shape[0:axis]:
+                    batch_size = batch_size * el
+
+                new_shape = (batch_size, -1)
+
         return relax.op.reshape(inputs[0], new_shape)
 
 
@@ -3220,6 +3256,7 @@ class ONNXGraphImporter:
                 "Equal",
                 "Where",
                 "Cast",
+                "Squeeze",
             ]
             return_tuple_ops = [
                 "SequenceConstruct",
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 46373510b1..9faa441138 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -53,27 +53,33 @@ def generate_random_inputs(
         for dim in i.type.tensor_type.shape.dim:
             shape.append(dim.dim_value)
 
-        # Extract datatype for the input.
-        if i.type.tensor_type.elem_type:
-            dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[i.type.tensor_type.elem_type])
-        else:
-            dtype = "float32"
-
-        # Generate random inputs for each input.
-        if dtype == "bool":
-            # random_value = np.random.choice(a=[False, True], size=shape)
-            random_value = rg.choice(a=[False, True], size=shape)
-        elif dtype.startswith("int"):
-            # Keep non-zero values
-            random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
-            random_value[random_value <= 0] -= 1
-        else:
-            random_value = rg.standard_normal(size=shape).astype(dtype)
-        input_values[i.name] = random_value
+        input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type)
 
     return input_values
 
 
+def generate_random_value(shape, elem_type) -> np.ndarray:
+
+    # Extract datatype for the input.
+    if elem_type:
+        dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
+    else:
+        dtype = "float32"
+
+    # Generate random inputs for each input.
+    if dtype == "bool":
+        # random_value = np.random.choice(a=[False, True], size=shape)
+        random_value = rg.choice(a=[False, True], size=shape)
+    elif dtype.startswith("int"):
+        # Keep non-zero values
+        random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
+        random_value[random_value <= 0] -= 1
+    else:
+        random_value = rg.standard_normal(size=shape).astype(dtype)
+
+    return random_value
+
+
 def check_correctness(
     model: ModelProto,
     inputs: Optional[Dict[str, np.ndarray]] = None,
@@ -156,12 +162,14 @@ def check_correctness(
         elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
             shape_out = tvm.nd.array([int(i) for i in tvm_out])
             tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
+        elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
+            tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
         else:
             raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")
 
     # Check that number of outputs match.
     assert len(tvm_output) == len(ort_output), "Unequal number of outputs"
-    for (tvm_out, ort_out) in zip(tvm_output, ort_output):
+    for tvm_out, ort_out in zip(tvm_output, ort_output):
         # TODO Allow configurable tolerance.
         if ort_out is not None:
             _check_output(tvm_out, ort_out)
@@ -219,6 +227,31 @@ def verify_unary(
     check_correctness(model, opset=opset)
 
 
+def verify_unary_dynamic_shape(
+    op_name,
+    shape,
+    shape_instance,
+    attrs={},
+    domain=None,
+    input_dtype=TensorProto.FLOAT,
+    output_dtype=TensorProto.FLOAT,
+    opset=14,
+):
+    test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain)
+    graph = helper.make_graph(
+        [test_node],
+        "elemwise_test",
+        inputs=[
+            helper.make_tensor_value_info("x", input_dtype, shape),
+        ],
+        outputs=[helper.make_tensor_value_info("y", output_dtype, shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="elemwise_test")
+    inputs = {"x": generate_random_value(shape_instance, input_dtype)}
+    check_correctness(model, inputs, opset=opset)
+
+
 def verify_binary(
     op_name, shape_a, shape_b, shape_c, attrs={}, domain=None, dtype=TensorProto.FLOAT, opset=14
 ):
@@ -1013,6 +1046,87 @@ def test_squeeze(axis):
     check_correctness(model, opset=13)
 
 
+@pytest.mark.parametrize("axis", [[0, 2], None])
+def test_squeeze_constant(axis):
+    shape = [1, 32, 1, 32]
+    constant = make_constant_node(
+        "x", onnx.TensorProto.FLOAT, shape, rg.standard_normal(size=shape).astype("float32")
+    )
+    if axis:
+        squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+    else:
+        squeeze_node = helper.make_node("Squeeze", ["x"], ["y"])
+
+    initializer = (
+        [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if axis else None
+    )
+
+    graph = helper.make_graph(
+        [constant, squeeze_node],
+        "squeeze_test",
+        inputs=[],
+        initializer=initializer,
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])],
+    )
+
+    model = helper.make_model(graph, producer_name="squeeze_test")
+    check_correctness(model, opset=13)
+
+
+@pytest.mark.parametrize("axis", [[0]])
+@pytest.mark.parametrize("A", [8, 16, 32])
+@pytest.mark.parametrize("B", [8, 16, 32])
+def test_dynamic_squeeze(axis, A, B):
+
+    squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+    shape = [1, "A", "B"]
+
+    initializer = (
+        [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if axis else None
+    )
+
+    graph = helper.make_graph(
+        [squeeze_node],
+        "squeeze_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+        ],
+        initializer=initializer,
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, ["A", "B"])],
+    )
+
+    model = helper.make_model(graph, producer_name="squeeze_test")
+    inputs = {"x": rg.standard_normal(size=[1, A, B]).astype("float32")}
+    check_correctness(model, inputs, opset=13)
+
+
+@pytest.mark.parametrize("axis", [[0]])
+@pytest.mark.parametrize("A", [8, 16, 32])
+def test_dynamic_shape_squeeze(axis, A):
+
+    shape_node = helper.make_node("Shape", ["x"], ["y"])
+    squeeze_node = helper.make_node("Squeeze", ["y", "axes"], ["z"])
+    shape = ["A"]
+
+    initializer = (
+        [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if axis else None
+    )
+
+    graph = helper.make_graph(
+        [shape_node, squeeze_node],
+        "squeeze_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+        ],
+        initializer=initializer,
+        outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, [])],
+    )
+
+    model = helper.make_model(graph, producer_name="squeeze_test")
+    inputs = {"x": rg.standard_normal(size=[A]).astype("float32")}
+    check_correctness(model, inputs, opset=13)
+
+
 def test_const():
     shape = [32, 32]
     const_node = helper.make_node(
@@ -1548,6 +1662,68 @@ def test_slice():
     # )
 
 
+def test_slice_dynamic_shape():
+    def verify_slice(
+        data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None
+    ):
+        if isinstance(starts, list):
+            starts = np.array(starts, "int64")
+        if isinstance(ends, list):
+            ends = np.array(ends, "int64")
+        if isinstance(axes, list):
+            axes = np.array(axes, "int64")
+        if isinstance(steps, list):
+            steps = np.array(steps, "int64")
+
+        slice_inputs = ["y", "starts", "ends"]
+        initializer = [
+            helper.make_tensor("starts", TensorProto.INT64, starts.shape, starts),
+            helper.make_tensor("ends", TensorProto.INT64, ends.shape, ends),
+        ]
+
+        if axes is not None:
+            initializer.append(helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes))
+            slice_inputs.append("axes")
+        if steps is not None:
+            initializer.append(helper.make_tensor("steps", TensorProto.INT64, steps.shape, steps))
+            slice_inputs.append("steps")
+
+        shape_node = helper.make_node("Shape", inputs=["x"], outputs=["y"])
+        slice_node = helper.make_node("Slice", inputs=slice_inputs, outputs=["z"])
+
+        graph = helper.make_graph(
+            [shape_node, slice_node],
+            "slice_test",
+            inputs=[
+                helper.make_tensor_value_info("x", TensorProto.FLOAT, data_shape),
+            ],
+            outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, output_shape)],
+            initializer=initializer,
+        )
+
+        model = helper.make_model(graph, producer_name="slice_test")
+        inputs = {"x": rg.standard_normal(size=data_instance_shape).astype("float32")}
+        check_correctness(model, inputs)
+
+    verify_slice([20, 10, 5], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0])
+    verify_slice(["A", 10, 5], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0])
+    verify_slice(["A", "B", 5], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0])
+    verify_slice([20, 10, "C"], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0])
+    verify_slice(["A", "B", "C"], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0])
+
+    verify_slice([20, 10, 5], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0])
+    verify_slice(["A", 10, 5], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0])
+    verify_slice(["A", "B", 5], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0])
+    verify_slice([20, 10, "C"], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0])
+    verify_slice(["A", "B", "C"], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0])
+
+    verify_slice([20, 10, 5], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0])
+    verify_slice(["A", 10, 5], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0])
+    verify_slice(["A", "B", 5], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0])
+    verify_slice([20, 10, "C"], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0])
+    verify_slice(["A", "B", "C"], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0])
+
+
 # TODO Enable dynamism
 @pytest.mark.parametrize("dynamic", [False])
 def test_attention(dynamic):
@@ -1795,12 +1971,15 @@ def test_split(fp_arith, dynamic):
             )
         ]
 
+        split_constant = None
         if pass_split:
             if opset >= 13:
                 np_split = np.array(split).astype(np.int64)
-                initializer.append(
-                    helper.make_tensor("split", TensorProto.INT64, list(np_split.shape), np_split)
+                split_constant = make_constant_node(
+                    "split", onnx.TensorProto.INT64, list(np_split.shape), np_split
                 )
+                input_names.append("split")
+
         node = helper.make_node(
             "Split",
             inputs=input_names,
@@ -1812,8 +1991,10 @@ def test_split(fp_arith, dynamic):
             split_attr = helper.make_attribute("split", split)
             node.attribute.append(split_attr)
 
+        nodes = [split_constant, node] if split_constant else [node]
+
         graph = helper.make_graph(
-            [node],
+            nodes,
             "split_test",
             inputs=inputs,
             initializer=initializer,
@@ -2226,6 +2407,12 @@ def test_flatten():
     verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 2})
 
 
+def test_flatten_dynamic():
+    verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 0})
+    verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": -1})
+    verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 2})
+
+
 def test_onehot():
     one_hot_node = helper.make_node("OneHot", ["indices", "depth", "values"], ["y"], axis=1)
     graph = helper.make_graph(

Reply via email to