jwfromm commented on code in PR #13602:
URL: https://github.com/apache/tvm/pull/13602#discussion_r1046562472


##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -5565,6 +5581,66 @@ def _impl_v11(cls, inputs, attr, params):
         return _op.concatenate(inputs[0], axis=axis)
 
 
+class SplitToSequence(OnnxOpConverter):
+    """Operator converter for split to sequence op."""
+
+    @classmethod
+    def _impl_v11(cls, inputs, attr, params):
+        axis = attr.get("axis", 0)
+        keepdims = attr.get("keepdims", 1)
+
+        input_tensor = inputs[0]
+        input_shape = infer_shape(input_tensor)
+        split = inputs[1]
+
+        # If split is not provided, we split all values along axis.
+        if split is None:
+            output = _op.split(input_tensor, input_shape[axis], axis=axis)
+            # If keepdims is 0, then we need to squeeze off the axis.
+            if keepdims == 0:
+                output = [_op.squeeze(tensor_slice, axis=[axis]) for 
tensor_slice in output]
+            return _expr.Tuple(list(output))
+
+        # Otherwise, split based on provided split value.
+        else:
+            # For now we only support constant valued split.
+            assert isinstance(
+                split, _expr.Constant
+            ), "Only constant split supported for SplitToSequence"
+            split = split.data.numpy()
+            if len(split.shape) == 1 and split.shape[0] > 1:
+                # If split is a 1D tensor, it must be converted to indices for 
relay compatibility.
+                split = np.cumsum(split)
+                # Remove final invalid index.
+                split = split[:-1]
+            else:
+                # Otherwise get split as an integer.
+                split = int(split)
+
+            output = _op.split(input_tensor, split, axis=axis)
+
+            # If keepdims is set to 0 remove split axis. Note that this is
+            # an inconsistency with the onnx spec but is needed for pytorch 
compatibility.
+            if keepdims == 0:

Review Comment:
   I actually prefer comparing to 0 in this case since it has a meaning
distinct from None.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to