anirudhacharya commented on a change in pull request #10605: [MXNET-310] [ONNX-MXNet] API to import ONNX models into Gluon.
URL: https://github.com/apache/incubator-mxnet/pull/10605#discussion_r192206073
 
 

 ##########
 File path: python/mxnet/contrib/onnx/_import/op_translations.py
 ##########
 @@ -43,32 +43,42 @@ def add(attrs, inputs, cls):
     """Adding two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        op_value = translation_utils._fix_bias_shape('broadcast_add', inputs, cls)
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_add', inputs,
+                                                    broadcast_axis, cls)
         return op_value, new_attr, inputs
-    return 'elemwise_add', new_attr, inputs
+    return 'broadcast_add', new_attr, inputs
 
 def subtract(attrs, inputs, cls):
     """Subtracting two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        return 'broadcast_sub', new_attr, inputs
-    return 'elemwise_sub', new_attr, inputs
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_sub', inputs,
+                                                    broadcast_axis, cls)
+        return op_value, new_attr, inputs
+    return 'broadcast_sub', new_attr, inputs
 
 
 def multiply(attrs, inputs, cls):
     """Multiply two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        op_value = translation_utils._fix_bias_shape('broadcast_mul', inputs, cls)
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_mul', inputs,
+                                                    broadcast_axis, cls)
         return op_value, new_attr, inputs
-    return 'elemwise_mul', new_attr, inputs
+    return 'broadcast_mul', new_attr, inputs
 
 def divide(attrs, inputs, cls):
     """Divide two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        return 'broadcast_div', new_attr, inputs
-    return 'elemwise_div', new_attr, inputs
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_div', inputs,
+                                                    broadcast_axis, cls)
 
 Review comment:
   broadcast_axis comes from ONNX's axis attribute on operators that support broadcasting: https://github.com/onnx/onnx/blob/master/docs/Changelog.md#attributes-103
   
   With opset version 6, broadcasting (1,1) on (4,5) is not permitted. When broadcasting (5,) on (4,5), broadcast_axis is 1; when broadcasting (4,) on (4,5), broadcast_axis is 0.
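The axis rule above can be sketched in plain Python, assuming (as in the two examples) that B's shape must match a contiguous run of A's dimensions starting at axis. `legacy_broadcast_shape` is a hypothetical helper written for illustration; it is not MXNet's `translation_utils._fix_broadcast`, whose body is not shown in this diff. It computes the shape B would need so that ordinary trailing-dimension (NumPy-style) broadcasting gives the same result as the axis-based rule.

```python
def legacy_broadcast_shape(a_shape, b_shape, axis):
    """Hypothetical helper: return the shape B must be reshaped to so that
    trailing-dimension broadcasting matches axis-based broadcasting, where
    B's dimensions line up with A's dimensions starting at `axis`.
    Only the contiguous-subset case is handled here."""
    a_shape, b_shape = tuple(a_shape), tuple(b_shape)
    end = axis + len(b_shape)
    if axis < 0 or end > len(a_shape) or a_shape[axis:end] != b_shape:
        raise ValueError(
            "B's shape %s is not a contiguous subset of A's shape %s at axis %d"
            % (b_shape, a_shape, axis))
    # Append trailing singleton dims after B's matched dimensions. Leading
    # singleton dims are implied by trailing-dimension broadcasting, so they
    # are not added.
    return b_shape + (1,) * (len(a_shape) - end)


# The two cases from the review comment:
print(legacy_broadcast_shape((4, 5), (5,), 1))  # (5,)   -> no reshape needed
print(legacy_broadcast_shape((4, 5), (4,), 0))  # (4, 1) -> B must become a column
```

In the (4,) case, simply appending a trailing 1 is what turns an axis-0 match into something trailing-aligned broadcasting can handle; the (5,) case already lines up with the last dimension of A.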
   
   With opset version 7, ONNX is updating its broadcasting rules to align with NumPy's broadcasting rules. Once that change is applied consistently in the ONNX repo, we will update the translation code in MXNet as well.
