This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e6262bf985 [ONNX] Support SequenceLength op (#13863)
e6262bf985 is described below
commit e6262bf9855a0c6f546f097910b48f955f2749cf
Author: Valery Chernov <[email protected]>
AuthorDate: Mon Jan 30 13:29:04 2023 +0400
[ONNX] Support SequenceLength op (#13863)
* add SequenceLength op
* add SequenceLength test
* graph fix
---------
Co-authored-by: Valery Chernov <[email protected]>
---
python/tvm/relay/frontend/onnx.py | 10 ++++++++++
tests/python/frontend/onnx/test_forward.py | 21 +++++++++++++++++++--
2 files changed, 29 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 7b35d4a481..6e0c7cc2dd 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -6148,6 +6148,15 @@ class SequenceConstruct(OnnxOpConverter):
return _expr.Tuple(inputs)
+class SequenceLength(OnnxOpConverter):
+ """Operator converter for ONNX SequenceLength: returns the sequence length as a scalar int64 constant."""
+
+ @classmethod
+ def _impl_v11(cls, inputs, attr, params):
+ # Length is fixed at import time via len(); assumes inputs[0] is statically sized (e.g. SequenceConstruct's _expr.Tuple) - TODO confirm other producers.
+ return _expr.const(len(inputs[0]), dtype="int64")
+
+
class SequenceInsert(OnnxOpConverter):
"""Operator converter for sequence insert op."""
@@ -6483,6 +6492,7 @@ def _get_convert_map(opset):
"LinearRegressor": LinearRegressor.get_converter(opset),
# Sequence operators
"SequenceConstruct": SequenceConstruct.get_converter(opset),
+ "SequenceLength": SequenceLength.get_converter(opset),
"SequenceInsert": SequenceInsert.get_converter(opset),
"ConcatFromSequence": ConcatFromSequence.get_converter(opset),
"SplitToSequence": SplitToSequence.get_converter(opset),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 4b17cfbbb3..6a780a632f 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -7760,10 +7760,16 @@ def test_sequence(target, dev):
"SplitToSequence", inputs=["concat_sequence"],
outputs=["split_sequence"], axis=axis
)
+ # Test tensor extraction from sequence
at_node = helper.make_node(
"SequenceAt", inputs=["split_sequence", "position"],
outputs=["output"]
)
+ # Test sequence length
+ length_node = helper.make_node(
+ "SequenceLength", inputs=["split_sequence"], outputs=["output_2"]
+ )
+
if new_axis is not None:
new_axis_attr = helper.make_attribute("new_axis", new_axis)
concat_node.attribute.append(new_axis_attr)
@@ -7781,9 +7787,20 @@ def test_sequence(target, dev):
output_shape[axis] = num_tensors + 1
else:
output_shape[axis] = (num_tensors + 1) * output_shape[axis]
- graph_outputs = [helper.make_tensor_value_info("output",
TensorProto.FLOAT, output_shape)]
+ graph_outputs = [
+ helper.make_tensor_value_info("output", TensorProto.FLOAT,
output_shape),
+ helper.make_tensor_value_info("output_2", TensorProto.INT64, []),
+ ]
- graph_nodes = [position_node, construct_node, insert_node,
concat_node, split_node, at_node]
+ graph_nodes = [
+ position_node,
+ construct_node,
+ insert_node,
+ concat_node,
+ split_node,
+ at_node,
+ length_node,
+ ]
graph = helper.make_graph(
graph_nodes,