Zha0q1 commented on a change in pull request #19677:
URL: https://github.com/apache/incubator-mxnet/pull/19677#discussion_r543703650
##########
File path: python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
##########
@@ -1645,30 +1645,37 @@ def convert_cast(node, **kwargs):
@mx_op.register("slice_axis")
def convert_slice_axis(node, **kwargs):
+ from onnx.helper import make_node
"""Map MXNet's slice_axis operator attributes to onnx's Slice operator
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)
- axes = int(attrs.get("axis"))
- starts = int(attrs.get("begin"))
- ends = attrs.get("end", None)
- if not ends or ends == 'None':
+ axis = int(attrs.get("axis"))
+ begin = int(attrs.get("begin"))
+ end = attrs.get("end", None)
+
+ nodes = []
+ create_tensor([axis], name+'_axis',kwargs["initializer"])
+ create_tensor([begin], name+'_begin',kwargs["initializer"])
+ if not end or end == 'None':
# ONNX doesn't support None for ends. Since ends=None depicts
# length of dimension, passing dimension in this case.
- in_shape = kwargs['in_shape'][0]
- ends = in_shape[axes]
+ create_tensor([axis+1], name+"_axis_plus_1", kwargs["initializer"])
+ nodes += [
+ make_node('Shape', [input_nodes[0]], [name+"_data_shape"]),
+ make_node('Slice', [name+'_data_shape', name+'_axis',
name+'_axis_plus_1'],
Review comment:
Yes, this is covered by a test case too.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]