This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new de93c379b1 [FIX][ONNX][RELAX] Add support for dynamic ShapeExpr in Slice, Squeeze and Flatten (#17490)
de93c379b1 is described below
commit de93c379b199baf45fd80cc465141544efbe7303
Author: PatrikPerssonInceptron <[email protected]>
AuthorDate: Thu Oct 31 02:14:56 2024 +0100
[FIX][ONNX][RELAX] Add support for dynamic ShapeExpr in Slice, Squeeze and Flatten (#17490)
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 59 ++++--
tests/python/relax/test_frontend_onnx.py | 229 +++++++++++++++++++++---
2 files changed, 256 insertions(+), 32 deletions(-)
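
For readers skimming the patch: a minimal sketch of the import pattern this
change enables, mirroring the new test_dynamic_shape_squeeze added below
(from_onnx is the public Relax ONNX entry point; the graph and opset values
are illustrative):

    from onnx import helper, TensorProto
    from tvm.relax.frontend.onnx import from_onnx

    # Shape -> Squeeze over axis 0 on an input with symbolic dim "A".
    # The Squeeze output now lowers to a relax.PrimValue instead of
    # failing on the dynamic dimension.
    shape_node = helper.make_node("Shape", ["x"], ["y"])
    squeeze_node = helper.make_node("Squeeze", ["y", "axes"], ["z"])
    graph = helper.make_graph(
        [shape_node, squeeze_node],
        "dynamic_squeeze",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, ["A"])],
        initializer=[helper.make_tensor("axes", TensorProto.INT64, [1], [0])],
        outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, [])],
    )
    model = helper.make_model(graph, producer_name="dynamic_squeeze")
    mod = from_onnx(model, opset=13)  # Relax IRModule carrying symbolic "A"
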
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 6c9225070d..611f4348d5 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1199,14 +1199,29 @@ class Squeeze(OnnxOpConverter):
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
+ data = inputs[0]
axis = get_constant(inputs[1], params)
if isinstance(axis, relax.Constant):
- axis = [int(x) for x in axis.data.numpy()]
+ axis = tuple([int(x) for x in axis.data.numpy()])
+
# If data is constant, perform computation directly.
- if isinstance(inputs[0], relax.Constant):
- out_data = _np.squeeze(inputs[0].data.numpy(), axis)
- return relax.const(out_data, inputs[0].struct_info.dtype)
- return relax.op.squeeze(inputs[0], axis)
+ if isinstance(data, relax.Constant):
+ if isinstance(axis, (tuple, type(None))):
+ out_data = _np.squeeze(data.data.numpy(), axis)
+ else:
+ raise NotImplementedError("Squeeze with symbolic axes not supported")
+
+ return relax.const(out_data, data.struct_info.dtype)
+
+ if isinstance(data, relax.ShapeExpr):
+ if axis == (0,):
+ return relax.PrimValue(data[0])
+ else:
+ raise NotImplementedError(
+ "Squeeze with symbolic axes and non-zero axes is not
supported."
+ )
+
+ return relax.op.squeeze(data, axis)
class Constant(OnnxOpConverter):
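
A standalone note on the axis handling above: np.squeeze documents its axis
parameter as None, an int, or a tuple of ints, which is why the converter
folds the Constant axes into a tuple before the constant-fold path. A quick
illustration (shapes are illustrative):

    import numpy as np

    x = np.zeros((1, 32, 1, 32), dtype="float32")
    print(np.squeeze(x, (0, 2)).shape)  # (32, 32): only the listed unit dims
    print(np.squeeze(x).shape)          # (32, 32): all unit dims removed
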
@@ -1559,12 +1574,12 @@ class Split(OnnxOpConverter):
splits_rank = splits.checked_type.ndim
if splits is not None and splits_rank > 0:
if isinstance(splits, relax.Constant):
- splits = splits.data.asnumpy()
+ splits = splits.data.numpy()
indices = []
index = 0
for i in splits[:-1]:
index += i
- indices.append(index)
+ indices.append(index.item())
else:
raise ValueError("Dynamic Split not yet supported")
# When splits isnt specified divide evenly over axis.
@@ -1611,11 +1626,16 @@ class Slice(OnnxOpConverter):
steps = [1] * len(axes)
# If input is a shape tensor, we can directly extract it.
if isinstance(data, relax.ShapeExpr):
- shape_data = [dim.value for dim in data]
+ shape_data = list(data)
# Starts, ends, and steps must be 1-d for shape operation.
assert all(len(i) == 1 for i in [starts, ends, steps])
sliced_values = shape_data[starts[0] : ends[0] : steps[0]]
- return relax.const(sliced_values, "int64")
+
+ if all([isinstance(val, (tir.IntImm, int)) for val in sliced_values]):
+ return relax.const([x.value for x in sliced_values], "int64")
+ else:
+ return relax.ShapeExpr(sliced_values)
+
# If all `starts`, `ends`, and `steps` are constant, use strict mode
# Otherwise, we assume the slice is inbound.
assume_inbound = not all(
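
In isolation, the ShapeExpr branch added above behaves roughly like this
sketch (names are illustrative): slicing a shape returns a plain constant
only when every selected dimension is already static, and otherwise stays a
ShapeExpr that keeps the symbolic dims alive:

    from tvm import relax, tir

    A = tir.Var("A", "int64")
    shape = relax.ShapeExpr([A, 10, 5])
    sliced = list(shape)[1:3]  # both selected dims static here -> [10, 5]
    if all(isinstance(v, (tir.IntImm, int)) for v in sliced):
        out = relax.const([int(v) for v in sliced], "int64")  # static slice
    else:
        out = relax.ShapeExpr(sliced)  # symbolic dims survive, e.g. [A, 10]
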
@@ -2237,8 +2257,24 @@ class Flatten(OnnxOpConverter):
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
axis = attr.get("axis", 1)
- data_shape = [i.value for i in inputs[0].struct_info.shape]
- new_shape = (1, -1) if axis == 0 else (_np.prod(data_shape[0:axis]).astype("int64"), -1)
+ data_shape = list(inputs[0].struct_info.shape)
+
+ if axis == 0:
+ new_shape = (1, -1)
+ else:
+ shape_flags = [isinstance(x, tvm.script.tir.IntImm) for x in data_shape[0:axis]]
+
+ if all(shape_flags):
+ data_shape = [x.value for x in data_shape[0:axis]]
+ new_shape = (_np.prod(data_shape).astype("int64"), -1)
+ else:
+ batch_size = 1
+
+ for el in data_shape[0:axis]:
+ batch_size = batch_size * el
+
+ new_shape = (batch_size, -1)
+
return relax.op.reshape(inputs[0], new_shape)
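
The else branch above keeps the batch size symbolic by multiplying tir
PrimExprs rather than numpy integers. A rough standalone sketch (names are
illustrative; tir overloads * so int and PrimExpr factors compose):

    from tvm import tir

    A = tir.Var("A", "int64")
    leading = [1, A, 32]  # dims before the flatten axis
    batch_size = 1
    for dim in leading:
        batch_size = batch_size * dim  # becomes a tir.PrimExpr once A enters
    # batch_size is now a symbolic product equivalent to 32 * A, and
    # reshape((batch_size, -1)) defers its concrete value to runtime.
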
@@ -3220,6 +3256,7 @@ class ONNXGraphImporter:
"Equal",
"Where",
"Cast",
+ "Squeeze",
]
return_tuple_ops = [
"SequenceConstruct",
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 46373510b1..9faa441138 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -53,27 +53,33 @@ def generate_random_inputs(
for dim in i.type.tensor_type.shape.dim:
shape.append(dim.dim_value)
- # Extract datatype for the input.
- if i.type.tensor_type.elem_type:
- dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[i.type.tensor_type.elem_type])
- else:
- dtype = "float32"
-
- # Generate random inputs for each input.
- if dtype == "bool":
- # random_value = np.random.choice(a=[False, True], size=shape)
- random_value = rg.choice(a=[False, True], size=shape)
- elif dtype.startswith("int"):
- # Keep non-zero values
- random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
- random_value[random_value <= 0] -= 1
- else:
- random_value = rg.standard_normal(size=shape).astype(dtype)
- input_values[i.name] = random_value
+ input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type)
return input_values
+def generate_random_value(shape, elem_type) -> np.ndarray:
+
+ # Extract datatype for the input.
+ if elem_type:
+ dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
+ else:
+ dtype = "float32"
+
+ # Generate random inputs for each input.
+ if dtype == "bool":
+ # random_value = np.random.choice(a=[False, True], size=shape)
+ random_value = rg.choice(a=[False, True], size=shape)
+ elif dtype.startswith("int"):
+ # Keep non-zero values
+ random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
+ random_value[random_value <= 0] -= 1
+ else:
+ random_value = rg.standard_normal(size=shape).astype(dtype)
+
+ return random_value
+
+
def check_correctness(
model: ModelProto,
inputs: Optional[Dict[str, np.ndarray]] = None,
@@ -156,12 +162,14 @@ def check_correctness(
elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
shape_out = tvm.nd.array([int(i) for i in tvm_out])
tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
+ elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
+ tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
else:
raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")
# Check that number of outputs match.
assert len(tvm_output) == len(ort_output), "Unequal number of outputs"
- for (tvm_out, ort_out) in zip(tvm_output, ort_output):
+ for tvm_out, ort_out in zip(tvm_output, ort_output):
# TODO Allow configurable tolerance.
if ort_out is not None:
_check_output(tvm_out, ort_out)
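
The new scalar branch in _check_output above covers outputs that arrive as
plain Python scalars (e.g. the PrimValue produced by the dynamic Squeeze
path), while ONNX Runtime returns a 0-d ndarray. A minimal sketch of the
comparison it performs (values are illustrative):

    import numpy as np
    import tvm.testing

    tvm_out = 20                           # scalar from the Relax module
    ort_out = np.array(20, dtype="int64")  # 0-d array from ONNX Runtime
    tvm.testing.assert_allclose(np.array(tvm_out), ort_out)
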
@@ -219,6 +227,31 @@ def verify_unary(
check_correctness(model, opset=opset)
+def verify_unary_dynamic_shape(
+ op_name,
+ shape,
+ shape_instance,
+ attrs={},
+ domain=None,
+ input_dtype=TensorProto.FLOAT,
+ output_dtype=TensorProto.FLOAT,
+ opset=14,
+):
+ test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain)
+ graph = helper.make_graph(
+ [test_node],
+ "elemwise_test",
+ inputs=[
+ helper.make_tensor_value_info("x", input_dtype, shape),
+ ],
+ outputs=[helper.make_tensor_value_info("y", output_dtype, shape)],
+ )
+
+ model = helper.make_model(graph, producer_name="elemwise_test")
+ inputs = {"x": generate_random_value(shape_instance, input_dtype)}
+ check_correctness(model, inputs, opset=opset)
+
+
def verify_binary(
op_name, shape_a, shape_b, shape_c, attrs={}, domain=None, dtype=TensorProto.FLOAT, opset=14
):
@@ -1013,6 +1046,87 @@ def test_squeeze(axis):
check_correctness(model, opset=13)
+@pytest.mark.parametrize("axis", [[0, 2], None])
+def test_squeeze_constant(axis):
+ shape = [1, 32, 1, 32]
+ constant = make_constant_node(
+ "x", onnx.TensorProto.FLOAT, shape,
rg.standard_normal(size=shape).astype("float32")
+ )
+ if axis:
+ squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+ else:
+ squeeze_node = helper.make_node("Squeeze", ["x"], ["y"])
+
+ initializer = (
+ [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if
axis else None
+ )
+
+ graph = helper.make_graph(
+ [constant, squeeze_node],
+ "squeeze_test",
+ inputs=[],
+ initializer=initializer,
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32,
32])],
+ )
+
+ model = helper.make_model(graph, producer_name="squeeze_test")
+ check_correctness(model, opset=13)
+
+
+@pytest.mark.parametrize("axis", [[0]])
+@pytest.mark.parametrize("A", [8, 16, 32])
+@pytest.mark.parametrize("B", [8, 16, 32])
+def test_dynamic_squeeze(axis, A, B):
+
+ squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+ shape = [1, "A", "B"]
+
+ initializer = (
+ [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if
axis else None
+ )
+
+ graph = helper.make_graph(
+ [squeeze_node],
+ "squeeze_test",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+ ],
+ initializer=initializer,
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, ["A",
"B"])],
+ )
+
+ model = helper.make_model(graph, producer_name="squeeze_test")
+ inputs = {"x": rg.standard_normal(size=[1, A, B]).astype("float32")}
+ check_correctness(model, inputs, opset=13)
+
+
+@pytest.mark.parametrize("axis", [[0]])
+@pytest.mark.parametrize("A", [8, 16, 32])
+def test_dynamic_shape_squeeze(axis, A):
+
+ shape_node = helper.make_node("Shape", ["x"], ["y"])
+ squeeze_node = helper.make_node("Squeeze", ["y", "axes"], ["z"])
+ shape = ["A"]
+
+ initializer = (
+ [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if
axis else None
+ )
+
+ graph = helper.make_graph(
+ [shape_node, squeeze_node],
+ "squeeze_test",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+ ],
+ initializer=initializer,
+ outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, [])],
+ )
+
+ model = helper.make_model(graph, producer_name="squeeze_test")
+ inputs = {"x": rg.standard_normal(size=[A]).astype("float32")}
+ check_correctness(model, inputs, opset=13)
+
+
def test_const():
shape = [32, 32]
const_node = helper.make_node(
@@ -1548,6 +1662,68 @@ def test_slice():
# )
+def test_slice_dynamic_shape():
+ def verify_slice(
+ data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None
+ ):
+ if isinstance(starts, list):
+ starts = np.array(starts, "int64")
+ if isinstance(ends, list):
+ ends = np.array(ends, "int64")
+ if isinstance(axes, list):
+ axes = np.array(axes, "int64")
+ if isinstance(steps, list):
+ steps = np.array(steps, "int64")
+
+ slice_inputs = ["y", "starts", "ends"]
+ initializer = [
+ helper.make_tensor("starts", TensorProto.INT64, starts.shape,
starts),
+ helper.make_tensor("ends", TensorProto.INT64, ends.shape, ends),
+ ]
+
+ if axes is not None:
+ initializer.append(helper.make_tensor("axes", TensorProto.INT64,
axes.shape, axes))
+ slice_inputs.append("axes")
+ if steps is not None:
+ initializer.append(helper.make_tensor("steps", TensorProto.INT64,
steps.shape, steps))
+ slice_inputs.append("steps")
+
+ shape_node = helper.make_node("Shape", inputs=["x"], outputs=["y"])
+ slice_node = helper.make_node("Slice", inputs=slice_inputs,
outputs=["z"])
+
+ graph = helper.make_graph(
+ [shape_node, slice_node],
+ "slice_test",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT,
data_shape),
+ ],
+ outputs=[helper.make_tensor_value_info("z", TensorProto.INT64,
output_shape)],
+ initializer=initializer,
+ )
+
+ model = helper.make_model(graph, producer_name="slice_test")
+ inputs = {"x":
rg.standard_normal(size=data_instance_shape).astype("float32")}
+ check_correctness(model, inputs)
+
+ verify_slice([20, 10, 5], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0])
+ verify_slice(["A", 10, 5], [20, 10, 5], [2], starts=[0], ends=[2],
axes=[0])
+ verify_slice(["A", "B", 5], [20, 10, 5], [2], starts=[0], ends=[2],
axes=[0])
+ verify_slice([20, 10, "C"], [20, 10, 5], [2], starts=[0], ends=[2],
axes=[0])
+ verify_slice(["A", "B", "C"], [20, 10, 5], [2], starts=[0], ends=[2],
axes=[0])
+
+ verify_slice([20, 10, 5], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0])
+ verify_slice(["A", 10, 5], [20, 10, 5], [1], starts=[1], ends=[2],
axes=[0])
+ verify_slice(["A", "B", 5], [20, 10, 5], [1], starts=[1], ends=[2],
axes=[0])
+ verify_slice([20, 10, "C"], [20, 10, 5], [1], starts=[1], ends=[2],
axes=[0])
+ verify_slice(["A", "B", "C"], [20, 10, 5], [1], starts=[1], ends=[2],
axes=[0])
+
+ verify_slice([20, 10, 5], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0])
+ verify_slice(["A", 10, 5], [20, 10, 5], [2], starts=[1], ends=[3],
axes=[0])
+ verify_slice(["A", "B", 5], [20, 10, 5], [2], starts=[1], ends=[3],
axes=[0])
+ verify_slice([20, 10, "C"], [20, 10, 5], [2], starts=[1], ends=[3],
axes=[0])
+ verify_slice(["A", "B", "C"], [20, 10, 5], [2], starts=[1], ends=[3],
axes=[0])
+
+
# TODO Enable dynamism
@pytest.mark.parametrize("dynamic", [False])
def test_attention(dynamic):
@@ -1795,12 +1971,15 @@ def test_split(fp_arith, dynamic):
)
]
+ split_constant = None
if pass_split:
if opset >= 13:
np_split = np.array(split).astype(np.int64)
- initializer.append(
- helper.make_tensor("split", TensorProto.INT64,
list(np_split.shape), np_split)
+ split_constant = make_constant_node(
+ "split", onnx.TensorProto.INT64, list(np_split.shape),
np_split
)
+ input_names.append("split")
+
node = helper.make_node(
"Split",
inputs=input_names,
@@ -1812,8 +1991,10 @@ def test_split(fp_arith, dynamic):
split_attr = helper.make_attribute("split", split)
node.attribute.append(split_attr)
+ nodes = [split_constant, node] if split_constant else [node]
+
graph = helper.make_graph(
- [node],
+ nodes,
"split_test",
inputs=inputs,
initializer=initializer,
@@ -2226,6 +2407,12 @@ def test_flatten():
verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 2})
+def test_flatten_dynamic():
+ verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32],
attrs={"axis": 0})
+ verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32],
attrs={"axis": -1})
+ verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32],
attrs={"axis": 2})
+
+
def test_onehot():
one_hot_node = helper.make_node("OneHot", ["indices", "depth", "values"],
["y"], axis=1)
graph = helper.make_graph(