bartekkuncer commented on code in PR #21132:
URL: https://github.com/apache/incubator-mxnet/pull/21132#discussion_r954778934


##########
src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:
##########
@@ -38,31 +39,39 @@ void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs) {
-  mxnet::TShape new_lshape, new_rshape, new_oshape;
-  int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
-                                              inputs[1].shape(),
-                                              outputs[0].shape(),
-                                              &new_lshape,
-                                              &new_rshape,
-                                              &new_oshape);
-  std::vector<NDArray> new_inputs;
-  std::vector<NDArray> new_outputs;
-  if (ndim_diff) {
-    new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
-    new_outputs = {outputs[0].Reshape(new_oshape)};
-  } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
-    // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
-    // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.
-    mxnet::TShape one_shape = mxnet::TShape(1, 1);
-    new_inputs              = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
-    new_outputs             = {outputs[0].Reshape(one_shape)};
+  // We can use more efficient sum kernel when there is no broadcast - when shapes are the same
+  const bool same_shape = (inputs[0].shape() == inputs[1].shape());
+
+  if (same_shape && alg == dnnl::algorithm::binary_add) {
+    DNNLSumFwd& fwd = DNNLSumFwd::GetCached(inputs, outputs);
+    fwd.Execute(ctx, inputs, req, outputs);
   } else {
-    new_inputs  = {inputs[0], inputs[1]};
-    new_outputs = {outputs[0]};
-  }
+    mxnet::TShape new_lshape, new_rshape, new_oshape;
+    int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
+                                                inputs[1].shape(),
+                                                outputs[0].shape(),
+                                                &new_lshape,
+                                                &new_rshape,
+                                                &new_oshape);
+    std::vector<NDArray> new_inputs;
+    std::vector<NDArray> new_outputs;
+    if (ndim_diff) {
+      new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
+      new_outputs = {outputs[0].Reshape(new_oshape)};
+    } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
+      // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)

Review Comment:
   ```suggestion
         // BinaryBroadcastShapeCompact function does not reshape tensors of size (1,1,...,1)
   ```



##########
src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:
##########
@@ -38,31 +39,39 @@ void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs) {
-  mxnet::TShape new_lshape, new_rshape, new_oshape;
-  int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
-                                              inputs[1].shape(),
-                                              outputs[0].shape(),
-                                              &new_lshape,
-                                              &new_rshape,
-                                              &new_oshape);
-  std::vector<NDArray> new_inputs;
-  std::vector<NDArray> new_outputs;
-  if (ndim_diff) {
-    new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
-    new_outputs = {outputs[0].Reshape(new_oshape)};
-  } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
-    // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
-    // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.
-    mxnet::TShape one_shape = mxnet::TShape(1, 1);
-    new_inputs              = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
-    new_outputs             = {outputs[0].Reshape(one_shape)};
+  // We can use more efficient sum kernel when there is no broadcast - when shapes are the same

Review Comment:
   ```suggestion
     // We can use more efficient sum kernel, when there is no broadcast - when shapes are the same.
   ```



##########
src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:
##########
@@ -38,31 +39,39 @@ void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs) {
-  mxnet::TShape new_lshape, new_rshape, new_oshape;
-  int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
-                                              inputs[1].shape(),
-                                              outputs[0].shape(),
-                                              &new_lshape,
-                                              &new_rshape,
-                                              &new_oshape);
-  std::vector<NDArray> new_inputs;
-  std::vector<NDArray> new_outputs;
-  if (ndim_diff) {
-    new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
-    new_outputs = {outputs[0].Reshape(new_oshape)};
-  } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
-    // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
-    // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.
-    mxnet::TShape one_shape = mxnet::TShape(1, 1);
-    new_inputs              = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
-    new_outputs             = {outputs[0].Reshape(one_shape)};
+  // We can use more efficient sum kernel when there is no broadcast - when shapes are the same
+  const bool same_shape = (inputs[0].shape() == inputs[1].shape());

Review Comment:
   ```suggestion
     const bool same_shape = inputs[0].shape() == inputs[1].shape();
   ```



##########
tests/python/unittest/test_operator.py:
##########
@@ -9414,6 +9414,24 @@ def test_elementwise_ops_on_misaligned_input():
     mx.nd.waitall()
     assert a[3].asscalar() == 4.0
 
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
+@pytest.mark.parametrize('ndim', [1, 2, 3, 4, 5])
+@pytest.mark.parametrize('max_dim_size', [1, 2, 3, 4, 5])
+def test_broadcast_ops_on_input_with_the_same_shape(dtype, ndim, max_dim_size):
+    shape = list(rand_shape_nd(ndim, dim=max_dim_size))
+    shape_size = np.product(shape)
+    a = np.random.uniform(low=-100, high=100, size=5000)

Review Comment:
   Please change it to 3125 and add a simple comment that it is 5^5 (5**5).



##########
src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:
##########
@@ -38,31 +39,39 @@ void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs) {
-  mxnet::TShape new_lshape, new_rshape, new_oshape;
-  int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
-                                              inputs[1].shape(),
-                                              outputs[0].shape(),
-                                              &new_lshape,
-                                              &new_rshape,
-                                              &new_oshape);
-  std::vector<NDArray> new_inputs;
-  std::vector<NDArray> new_outputs;
-  if (ndim_diff) {
-    new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
-    new_outputs = {outputs[0].Reshape(new_oshape)};
-  } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
-    // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
-    // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.
-    mxnet::TShape one_shape = mxnet::TShape(1, 1);
-    new_inputs              = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
-    new_outputs             = {outputs[0].Reshape(one_shape)};
+  // We can use more efficient sum kernel when there is no broadcast - when shapes are the same
+  const bool same_shape = (inputs[0].shape() == inputs[1].shape());

Review Comment:
   Maybe put these shapes in variables as they are referred to multiple times?
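
   For illustration, a rough sketch of what that could look like (a hypothetical refactor; the names `lhs_shape`/`rhs_shape`/`out_shape` are only placeholders, not part of the PR):
   ```cpp
     // Hypothetical: fetch each shape once and reuse it below instead of
     // calling inputs[i].shape() / outputs[0].shape() at every use site.
     const mxnet::TShape& lhs_shape = inputs[0].shape();
     const mxnet::TShape& rhs_shape = inputs[1].shape();
     const mxnet::TShape& out_shape = outputs[0].shape();

     const bool same_shape = lhs_shape == rhs_shape;

     if (same_shape && alg == dnnl::algorithm::binary_add) {
       DNNLSumFwd& fwd = DNNLSumFwd::GetCached(inputs, outputs);
       fwd.Execute(ctx, inputs, req, outputs);
     } else {
       mxnet::TShape new_lshape, new_rshape, new_oshape;
       const int ndim_diff = BinaryBroadcastShapeCompact(
           lhs_shape, rhs_shape, out_shape, &new_lshape, &new_rshape, &new_oshape);
       // ... rest unchanged ...
     }
   ```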



##########
src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:
##########
@@ -38,31 +39,39 @@ void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs) {
-  mxnet::TShape new_lshape, new_rshape, new_oshape;
-  int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
-                                              inputs[1].shape(),
-                                              outputs[0].shape(),
-                                              &new_lshape,
-                                              &new_rshape,
-                                              &new_oshape);
-  std::vector<NDArray> new_inputs;
-  std::vector<NDArray> new_outputs;
-  if (ndim_diff) {
-    new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
-    new_outputs = {outputs[0].Reshape(new_oshape)};
-  } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
-    // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
-    // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.
-    mxnet::TShape one_shape = mxnet::TShape(1, 1);
-    new_inputs              = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
-    new_outputs             = {outputs[0].Reshape(one_shape)};
+  // We can use more efficient sum kernel when there is no broadcast - when shapes are the same
+  const bool same_shape = (inputs[0].shape() == inputs[1].shape());
+
+  if (same_shape && alg == dnnl::algorithm::binary_add) {
+    DNNLSumFwd& fwd = DNNLSumFwd::GetCached(inputs, outputs);
+    fwd.Execute(ctx, inputs, req, outputs);
   } else {
-    new_inputs  = {inputs[0], inputs[1]};
-    new_outputs = {outputs[0]};
-  }
+    mxnet::TShape new_lshape, new_rshape, new_oshape;
+    int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
+                                                inputs[1].shape(),
+                                                outputs[0].shape(),
+                                                &new_lshape,
+                                                &new_rshape,
+                                                &new_oshape);
+    std::vector<NDArray> new_inputs;
+    std::vector<NDArray> new_outputs;
+    if (ndim_diff) {
+      new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
+      new_outputs = {outputs[0].Reshape(new_oshape)};
+    } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
+      // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
+      // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.

Review Comment:
   So is it BinaryBroadcastShapeCompact that does not support this reshape, or oneDNN?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
