This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e6262bf985 [ONNX] Support SequenceLength op (#13863)
e6262bf985 is described below

commit e6262bf9855a0c6f546f097910b48f955f2749cf
Author: Valery Chernov <[email protected]>
AuthorDate: Mon Jan 30 13:29:04 2023 +0400

    [ONNX] Support SequenceLength op (#13863)
    
    * add SequenceLength op
    
    * add SequenceLength test
    
    * graph fix
    
    ---------
    
    Co-authored-by: Valery Chernov <[email protected]>
---
 python/tvm/relay/frontend/onnx.py          | 10 ++++++++++
 tests/python/frontend/onnx/test_forward.py | 21 +++++++++++++++++++--
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 7b35d4a481..6e0c7cc2dd 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -6148,6 +6148,15 @@ class SequenceConstruct(OnnxOpConverter):
         return _expr.Tuple(inputs)
 
 
+class SequenceLength(OnnxOpConverter):
+    """Operator converter for sequence length op."""
+
+    @classmethod
+    def _impl_v11(cls, inputs, attr, params):
+        # Get length of input sequence
+        return _expr.const(len(inputs[0]), dtype="int64")
+
+
 class SequenceInsert(OnnxOpConverter):
     """Operator converter for sequence insert op."""
 
@@ -6483,6 +6492,7 @@ def _get_convert_map(opset):
         "LinearRegressor": LinearRegressor.get_converter(opset),
         # Sequence operators
         "SequenceConstruct": SequenceConstruct.get_converter(opset),
+        "SequenceLength": SequenceLength.get_converter(opset),
         "SequenceInsert": SequenceInsert.get_converter(opset),
         "ConcatFromSequence": ConcatFromSequence.get_converter(opset),
         "SplitToSequence": SplitToSequence.get_converter(opset),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 4b17cfbbb3..6a780a632f 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -7760,10 +7760,16 @@ def test_sequence(target, dev):
             "SplitToSequence", inputs=["concat_sequence"], 
outputs=["split_sequence"], axis=axis
         )
 
+        # Test tensor extraction from sequence
         at_node = helper.make_node(
             "SequenceAt", inputs=["split_sequence", "position"], 
outputs=["output"]
         )
 
+        # Test sequence length
+        length_node = helper.make_node(
+            "SequenceLength", inputs=["split_sequence"], outputs=["output_2"]
+        )
+
         if new_axis is not None:
             new_axis_attr = helper.make_attribute("new_axis", new_axis)
             concat_node.attribute.append(new_axis_attr)
@@ -7781,9 +7787,20 @@ def test_sequence(target, dev):
             output_shape[axis] = num_tensors + 1
         else:
             output_shape[axis] = (num_tensors + 1) * output_shape[axis]
-        graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)]
+        graph_outputs = [
+            helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape),
+            helper.make_tensor_value_info("output_2", TensorProto.INT64, []),
+        ]
 
-        graph_nodes = [position_node, construct_node, insert_node, concat_node, split_node, at_node]
+        graph_nodes = [
+            position_node,
+            construct_node,
+            insert_node,
+            concat_node,
+            split_node,
+            at_node,
+            length_node,
+        ]
 
         graph = helper.make_graph(
             graph_nodes,

Reply via email to