anirudhacharya commented on a change in pull request #10605: [MXNET-310] [ONNX-MXNet] API to import ONNX models into Gluon.
URL: https://github.com/apache/incubator-mxnet/pull/10605#discussion_r192206073
##########
File path: python/mxnet/contrib/onnx/_import/op_translations.py
##########
@@ -43,32 +43,42 @@ def add(attrs, inputs, cls):
     """Adding two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        op_value = translation_utils._fix_bias_shape('broadcast_add', inputs, cls)
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_add', inputs,
+                                                    broadcast_axis, cls)
         return op_value, new_attr, inputs
-    return 'elemwise_add', new_attr, inputs
+    return 'broadcast_add', new_attr, inputs
 def subtract(attrs, inputs, cls):
     """Subtracting two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        return 'broadcast_sub', new_attr, inputs
-    return 'elemwise_sub', new_attr, inputs
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_sub', inputs,
+                                                    broadcast_axis, cls)
+        return op_value, new_attr, inputs
+    return 'broadcast_sub', new_attr, inputs
 def multiply(attrs, inputs, cls):
     """Multiply two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        op_value = translation_utils._fix_bias_shape('broadcast_mul', inputs, cls)
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_mul', inputs,
+                                                    broadcast_axis, cls)
         return op_value, new_attr, inputs
-    return 'elemwise_mul', new_attr, inputs
+    return 'broadcast_mul', new_attr, inputs
 def divide(attrs, inputs, cls):
     """Divide two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        return 'broadcast_div', new_attr, inputs
-    return 'elemwise_div', new_attr, inputs
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_div', inputs,
+                                                    broadcast_axis, cls)
Review comment:
broadcast_axis comes from the axis attribute that ONNX defines on operators
that support broadcasting:
https://github.com/onnx/onnx/blob/master/docs/Changelog.md#attributes-103
Under opset version 6, broadcasting (1,1) onto (4,5) is not permitted, because
(1,1) does not match any contiguous run of the dimensions of (4,5). If we
broadcast (5,) onto (4,5), broadcast_axis is 1; if we broadcast (4,) onto (4,5),
broadcast_axis is 0.
In opset version 7, ONNX is updating its broadcasting rules to align with
NumPy's. Once that change is applied consistently in the ONNX repo, we will
update the translation code in MXNet as well.
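The sketch below makes the axis semantics concrete using NumPy. It assumes the legacy rule from opset 6 and earlier. Under that rule, B's shape must equal a contiguous run of A's dimensions starting at axis. When axis is absent, B is aligned with A's trailing dimensions. The helper legacy_broadcast_shape is hypothetical and is not the code behind translation_utils._fix_broadcast. It only pads B's shape with 1s so that NumPy's trailing-dimension broadcasting gives the same result.

```python
import numpy as np


def legacy_broadcast_shape(a_shape, b_shape, axis=None):
    """Hypothetical helper: return a shape for B that NumPy can broadcast
    against A, following the ONNX opset <= 6 broadcast/axis rule (a sketch,
    not MXNet's implementation).

    B's dimensions must equal a contiguous run of A's dimensions starting at
    `axis`; when `axis` is omitted, B is aligned with A's trailing dimensions.
    """
    if axis is None:
        axis = len(a_shape) - len(b_shape)  # default: suffix matching
    end = axis + len(b_shape)
    if axis < 0 or end > len(a_shape) or tuple(a_shape[axis:end]) != tuple(b_shape):
        raise ValueError(
            f"cannot broadcast {tuple(b_shape)} onto {tuple(a_shape)} at axis={axis}")
    # Pad with 1s: leading 1s are optional for NumPy, but the trailing 1s are
    # what let a (4,) tensor line up with axis 0 of a (4, 5) tensor.
    return (1,) * axis + tuple(b_shape) + (1,) * (len(a_shape) - end)


a = np.zeros((4, 5))
cols = np.arange(5.0)  # shape (5,): broadcast along axis=1
rows = np.arange(4.0)  # shape (4,): broadcast along axis=0

print(legacy_broadcast_shape(a.shape, cols.shape, axis=1))  # (1, 5)
print(legacy_broadcast_shape(a.shape, rows.shape, axis=0))  # (4, 1)

# NumPy alone aligns trailing dimensions, so (4,) on (4, 5) fails without a reshape.
try:
    a + rows
except ValueError as err:
    print("plain NumPy:", err)
print((a + rows.reshape(legacy_broadcast_shape(a.shape, rows.shape, axis=0))).shape)  # (4, 5)

# (1, 1) on (4, 5) is rejected by the legacy rule but is valid under NumPy-style rules.
try:
    legacy_broadcast_shape(a.shape, (1, 1), axis=0)
except ValueError as err:
    print("legacy rule:", err)
print((a + np.ones((1, 1))).shape)  # (4, 5)
```

The last two cases show the difference between the two rule sets. A (1,1) operand is rejected by the axis-based rule, but NumPy-style broadcasting accepts it directly.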
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
With regards,
Apache Git Services