This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7d84b59  [FEATURE] Integrate oneDNN support for add, subtract, multiply, divide. (#20713)
7d84b59 is described below

commit 7d84b598459985a59f7601f638cf8707389609c9
Author: AdamGrabowski <[email protected]>
AuthorDate: Tue Jan 18 13:00:41 2022 +0100

    [FEATURE] Integrate oneDNN support for add, subtract, multiply, divide. (#20713)
    
    * Integrate oneDNN support for binary elementwise operators.
    
    * Delete template xpu for BinaryOperatorComputeExCPU function
    
    * Fix binary operators StorageType function.
    
    * Fix SupportDNNLBinary function.
    
    * Fix test_operator, DNNLAlgorithm structure, DNNLData and DNNLBinaryOpForward condition.
    
    * Fix test cases, add oneDNN runtime flag to dispatch, remove node attrs, rename pointers
    
    * Fix sanity
---
 src/operator/nn/dnnl/dnnl_base-inl.h               |  1 +
 src/operator/nn/dnnl/dnnl_binary-inl.h             | 86 ++++++++++++++++++++++
 src/operator/nn/dnnl/dnnl_binary.cc                | 78 ++++++++++++++++++++
 src/operator/numpy/np_elemwise_broadcast_op.h      | 47 ++++++++++++
 src/operator/numpy/np_elemwise_broadcast_op_add.cc |  4 +
 src/operator/numpy/np_elemwise_broadcast_op_mul.cc |  4 +
 src/operator/numpy/np_elemwise_broadcast_op_sub.cc |  4 +
 src/operator/numpy/np_true_divide.cc               | 14 ++++
 src/operator/tensor/elemwise_binary_broadcast_op.h | 41 +++++++++++
 .../tensor/elemwise_binary_broadcast_op_basic.cc   | 79 ++++++++++++++++++--
 tests/python/unittest/test_operator.py             | 29 +++++---
 11 files changed, 370 insertions(+), 17 deletions(-)
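
For context (not part of the commit itself), a minimal Python sketch of the calls this change
routes through oneDNN on CPU. It assumes an MXNet build with MXNET_USE_ONEDNN=1 and inputs that
satisfy SupportDNNLBinary; names and shapes are illustrative only:

    import mxnet as mx

    a = mx.np.ones((2, 3, 4), dtype='float32')
    b = mx.np.ones((3, 4), dtype='float32')    # broadcastable against `a`

    c = a + b   # _npi_add          -> dnnl::algorithm::binary_add
    d = a - b   # _npi_subtract     -> dnnl::algorithm::binary_sub
    e = a * b   # _npi_multiply     -> dnnl::algorithm::binary_mul
    f = a / b   # _npi_true_divide  -> dnnl::algorithm::binary_div
    # The classic broadcast_add/sub/mul/div NDArray operators are wired to the same path.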

diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h
index 5344989..3b0bda9 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -199,6 +199,7 @@ bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray
 bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
 bool SupportDNNLSplit(const NDArray& input);
 bool SupportDNNLStack(const std::vector<NDArray>& inputs);
+bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/dnnl/dnnl_binary-inl.h b/src/operator/nn/dnnl/dnnl_binary-inl.h
new file mode 100644
index 0000000..2cf63aa
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_binary-inl.h
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_binary-inl.h
+ * \author: Adam Grabowski, [email protected]
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+#include "./dnnl_base-inl.h"
+#include "./dnnl_ops-inl.h"
+#include <vector>
+
+#include "../../tensor/elemwise_binary_broadcast_op.h"
+
+namespace mxnet {
+namespace op {
+
+using binary_fwd_t    = dnnl::binary;
+using binary_fwd_pd_t = dnnl::binary::primitive_desc;
+
+class DNNLBinaryOpFwd {
+ public:
+  template <dnnl::algorithm alg>
+  static DNNLBinaryOpFwd& GetBinaryOpForward(const std::vector<NDArray>& inputs,
+                                             const std::vector<NDArray>& outputs);
+  DNNLBinaryOpFwd(const dnnl::algorithm alg,
+                  const std::vector<NDArray>& inputs,
+                  const std::vector<NDArray>& outputs);
+
+  void Execute(const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs);
+
+ private:
+  std::shared_ptr<binary_fwd_t> fwd;
+  std::shared_ptr<binary_fwd_pd_t> fwd_pd;
+};
+
+template <dnnl::algorithm alg>
+DNNLBinaryOpFwd& DNNLBinaryOpFwd::GetBinaryOpForward(const std::vector<NDArray>& inputs,
+                                                     const std::vector<NDArray>& outputs) {
+  using binary_op_fwd_map = std::unordered_map<OpSignature, DNNLBinaryOpFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local binary_op_fwd_map fwds;
+#else
+  static MX_THREAD_LOCAL binary_op_fwd_map fwds;
+#endif
+  OpSignature key;
+  key.AddSign(static_cast<int>(alg));
+  key.AddSign(inputs[0]);
+  key.AddSign(inputs[1]);
+  key.AddSign(outputs[0]);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    const DNNLBinaryOpFwd fwd(alg, inputs, outputs);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_binary.cc b/src/operator/nn/dnnl/dnnl_binary.cc
new file mode 100644
index 0000000..b4d526c
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_binary.cc
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_binary.cc
+ * \author: Adam Grabowski, [email protected]
+ */
+
+#if MXNET_USE_ONEDNN == 1
+#include "./dnnl_binary-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DNNLBinaryOpFwd::DNNLBinaryOpFwd(const dnnl::algorithm alg,
+                                 const std::vector<NDArray>& inputs,
+                                 const std::vector<NDArray>& outputs) {
+  auto src0_desc = inputs[0].GetDNNLData()->get_desc();
+  auto src1_desc = inputs[1].GetDNNLData()->get_desc();
+  auto dst_desc  = outputs[0].GetDNNLData()->get_desc();
+
+  dnnl::binary::desc fwd_desc(alg, src0_desc, src1_desc, dst_desc);
+  fwd_pd = std::make_shared<binary_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
+  fwd    = std::make_shared<binary_fwd_t>(*fwd_pd);
+}
+
+void DNNLBinaryOpFwd::Execute(const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+  auto engine = mxnet::CpuEngine::Get()->get_engine();
+  auto src0   = inputs[0].GetDNNLData();
+  auto src1   = inputs[1].GetDNNLData();
+  dnnl_output_t out_mem;
+  if (outputs[0].GetDNNLData()->get_data_handle() == inputs[1].GetDNNLData()->get_data_handle())
+    out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[1]);
+  else
+    out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]);
+
+  dnnl_args_map_t args = {
+      {DNNL_ARG_SRC_0, *src0},
+      {DNNL_ARG_SRC_1, *src1},
+      {DNNL_ARG_DST, *out_mem.second},
+  };
+
+  DNNLStream::Get()->RegisterPrimArgs(*fwd, args);
+  CommitOutput(outputs[0], out_mem);
+  DNNLStream::Get()->Submit();
+}
+
+bool SupportDNNLBinary(const std::vector<NDArray>& inputs) {
+  auto dtype  = inputs[0].dtype();
+  auto ndim_0 = inputs[0].shape().ndim();
+  auto ndim_1 = inputs[1].shape().ndim();
+  return ndim_0 >= 1 && ndim_0 <= 6 && ndim_1 >= 1 && ndim_1 <= 6 &&
+         inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
+         dtype == mshadow::kFloat32 && dtype == inputs[1].dtype();
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_ONEDNN == 1
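
As a rough illustration of the SupportDNNLBinary check above (again, not part of the patch):
only non-empty float32 tensors with 1 to 6 dimensions and matching dtypes take the oneDNN
primitive; anything else falls back to the existing CPU kernels. Assuming a oneDNN-enabled
CPU build:

    import mxnet as mx

    f32 = mx.np.ones((4, 5), dtype='float32')
    f64 = mx.np.ones((4, 5), dtype='float64')

    _ = f32 + f32   # eligible: float32, ndim within [1, 6] -> oneDNN binary primitive
    _ = f64 + f64   # not eligible (float64)                -> existing broadcast kernel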
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h
index fa329bf..3d28ffc 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -851,6 +851,53 @@ void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
   }
 }
 
+#if MXNET_USE_ONEDNN == 1
+inline bool NumpyBinaryBroadcastStorageType(const nnvm::NodeAttrs& attrs,
+                                            const int dev_mask,
+                                            DispatchMode* dispatch_mode,
+                                            std::vector<int>* in_attrs,
+                                            std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2);
+  CHECK_EQ(out_attrs->size(), 1);
+
+  return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+
+void NumpyDivideBroadcastComputeCPU(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<TBlob>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<TBlob>& outputs);
+
+template <typename OP>
+void NumpyBinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                     const OpContext& ctx,
+                                     const std::vector<mxnet::NDArray>& inputs,
+                                     const std::vector<OpReqType>& req,
+                                     const std::vector<mxnet::NDArray>& outputs) {
+  if (SupportDNNLBinary(inputs)) {
+    const dnnl::algorithm alg = DNNLAlgorithm<OP>::value;
+    DNNLRun(DNNLBinaryOpForward<alg>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  using namespace op::mshadow_op;
+  std::vector<mxnet::TBlob> in_data  = {inputs[0].data(), inputs[1].data()};
+  std::vector<mxnet::TBlob> out_data = {outputs[0].data()};
+  if (std::is_same<OP, plus>::value) {
+    NumpyBinaryBroadcastComputeWithBool<cpu, OP, mixed_plus, mixed_plus>(
+        attrs, ctx, in_data, req, out_data);
+  } else if (std::is_same<OP, minus>::value) {
+    NumpyBinaryBroadcastCompute<cpu, OP, mixed_minus, mixed_rminus>(
+        attrs, ctx, in_data, req, out_data);
+  } else if (std::is_same<OP, mul>::value) {
+    NumpyBinaryBroadcastComputeWithBool<cpu, OP, mixed_mul, mixed_mul>(
+        attrs, ctx, in_data, req, out_data);
+  } else if (std::is_same<OP, div>::value) {
+    NumpyDivideBroadcastComputeCPU(attrs, ctx, in_data, req, out_data);
+  }
+}
+#endif  // MXNET_USE_ONEDNN
+
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name)                        \
   NNVM_REGISTER_OP(name)                                                      \
       .set_num_inputs(1)                                                      \
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_add.cc b/src/operator/numpy/np_elemwise_broadcast_op_add.cc
index 50a79ab..69fc12b 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_add.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_add.cc
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
                                                              op::mshadow_op::plus,
                                                              op::mshadow_op::mixed_plus,
                                                              op::mshadow_op::mixed_plus>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::plus>)
+    .set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
+#endif  // MXNET_USE_ONEDNN
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_add"});
 
 NNVM_REGISTER_OP(_backward_npi_broadcast_add)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_mul.cc b/src/operator/numpy/np_elemwise_broadcast_op_mul.cc
index 3e627c8..b450b81 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_mul.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_mul.cc
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
                                                              op::mshadow_op::mul,
                                                              op::mshadow_op::mixed_mul,
                                                              op::mshadow_op::mixed_mul>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::mul>)
+    .set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
+#endif  // MXNET_USE_ONEDNN
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});
 
 NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_sub.cc b/src/operator/numpy/np_elemwise_broadcast_op_sub.cc
index 5f3ba76..018b7a7 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_sub.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_sub.cc
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
                                                      op::mshadow_op::minus,
                                                      op::mshadow_op::mixed_minus,
                                                      op::mshadow_op::mixed_rminus>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::minus>)
+    .set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
+#endif  // MXNET_USE_ONEDNN
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_sub"});
 
 NNVM_REGISTER_OP(_backward_npi_broadcast_sub)
diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc
index 639379d..3ef93c9 100644
--- a/src/operator/numpy/np_true_divide.cc
+++ b/src/operator/numpy/np_true_divide.cc
@@ -61,6 +61,16 @@ bool TrueDivideType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+#if MXNET_USE_ONEDNN == 1
+void NumpyDivideBroadcastComputeCPU(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<TBlob>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<TBlob>& outputs) {
+  TrueDivideBroadcastCompute<cpu>(attrs, ctx, inputs, req, outputs);
+}
+#endif  // MXNET_USE_ONEDNN
+
 NNVM_REGISTER_OP(_npi_true_divide)
     .set_num_inputs(2)
     .set_num_outputs(1)
@@ -79,6 +89,10 @@ NNVM_REGISTER_OP(_npi_true_divide)
                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
                                })
     .set_attr<FCompute>("FCompute<cpu>", TrueDivideBroadcastCompute<cpu>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::div>)
+    .set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
+#endif  // MXNET_USE_ONEDNN
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"})
     .add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
     .add_argument("rhs", "NDArray-or-Symbol", "Divisor array");
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 20d874d..1c4d84d 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -91,8 +91,14 @@ inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
   int& out_stype      = out_attrs->at(0);
   bool dispatched     = false;
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+#if MXNET_USE_ONEDNN == 1
+    if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
+      dispatched = storage_type_assign(
+          &out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
+#else
     dispatched =
         storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+#endif  // MXNET_USE_ONEDNN == 1
   }
   if (!dispatched && lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) {
     dispatched =
@@ -116,8 +122,14 @@ inline bool BinaryBroadcastAddStorageType(const nnvm::NodeAttrs& attrs,
   int& out_stype      = out_attrs->at(0);
   bool dispatched     = false;
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+#if MXNET_USE_ONEDNN == 1
+    if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
+      dispatched = storage_type_assign(
+          &out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
+#else
     dispatched =
         storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+#endif  // MXNET_USE_ONEDNN == 1
   }
   if (!dispatched && ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
                       (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
@@ -788,6 +800,35 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
   }
 }
 
+#if MXNET_USE_ONEDNN == 1
+template <dnnl::algorithm alg>
+void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs);
+
+// template struct converting op::mshadow_op to dnnl::algorithm
+template <typename OP>
+struct DNNLAlgorithm {};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::plus> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_add;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::minus> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_sub;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::mul> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_mul;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::div> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_div;
+};
+#endif  // MXNET_USE_ONEDNN == 1
+
 #define MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(name)                                           \
   NNVM_REGISTER_OP(name)                                                                          \
       .set_num_inputs(2)                                                                          \
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index 9d0f107..cc66a1e 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -24,9 +24,76 @@
 #include "./elemwise_unary_op.h"
 #include "./elemwise_binary_op-inl.h"
 #include "./elemwise_binary_broadcast_op.h"
+#if MXNET_USE_ONEDNN == 1
+#include "../nn/dnnl/dnnl_binary-inl.h"
+#endif  // MXNET_USE_ONEDNN == 1
 
 namespace mxnet {
 namespace op {
+
+#if MXNET_USE_ONEDNN == 1
+template <dnnl::algorithm alg>
+void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs) {
+  mxnet::TShape new_lshape, new_rshape, new_oshape;
+  int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
+                                              inputs[1].shape(),
+                                              outputs[0].shape(),
+                                              &new_lshape,
+                                              &new_rshape,
+                                              &new_oshape);
+  std::vector<NDArray> new_inputs;
+  std::vector<NDArray> new_outputs;
+  if (ndim_diff) {
+    new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
+    new_outputs = {outputs[0].Reshape(new_oshape)};
+  } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
+    // BinaryBroadcastShapeCompact doesn't reshape tensors of size (1,1,...,1) into shape (1),
+    // which the oneDNN primitive requires, so that reshape is done here.
+    mxnet::TShape one_shape = mxnet::TShape(1, 1);
+    new_inputs              = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
+    new_outputs             = {outputs[0].Reshape(one_shape)};
+  } else {
+    new_inputs  = {inputs[0], inputs[1]};
+    new_outputs = {outputs[0]};
+  }
+
+  DNNLBinaryOpFwd& fwd = DNNLBinaryOpFwd::GetBinaryOpForward<alg>(new_inputs, new_outputs);
+  fwd.Execute(new_inputs, req, new_outputs);
+}
+#endif
+
+template <typename OP>
+static void BinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                       const OpContext& ctx,
+                                       const std::vector<NDArray>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<NDArray>& outputs) {
+#if MXNET_USE_ONEDNN == 1
+  if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
+    if (SupportDNNLBinary(inputs)) {
+      const dnnl::algorithm alg = DNNLAlgorithm<OP>::value;
+      DNNLRun(DNNLBinaryOpForward<alg>, attrs, ctx, inputs, req, outputs);
+    } else {
+      std::vector<mxnet::TBlob> in_data  = {inputs[0].data(), inputs[1].data()};
+      std::vector<mxnet::TBlob> out_data = {outputs[0].data()};
+      BinaryBroadcastCompute<cpu, OP>(attrs, ctx, in_data, req, out_data);
+    }
+    return;
+  }
+#endif  // MXNET_USE_ONEDNN == 1
+  if (std::is_same<OP, op::mshadow_op::plus>::value ||
+      std::is_same<OP, op::mshadow_op::minus>::value) {
+    BinaryBroadcastComputeDenseEx<cpu, OP>(attrs, ctx, inputs, req, outputs);
+  } else if (std::is_same<OP, op::mshadow_op::mul>::value ||
+             std::is_same<OP, op::mshadow_op::div>::value) {
+    BinaryBroadcastComputeSparseEx<cpu, OP>(attrs, ctx, inputs, req, outputs);
+  }
+}
+
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_add)
 MXNET_ADD_SPARSE_OP_ALIAS(broadcast_add)
 MXNET_ADD_SPARSE_OP_ALIAS(broadcast_plus)
@@ -56,8 +123,7 @@ Supported sparse operations:
 
 )code" ADD_FILELINE)
     .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>)
-    .set_attr<FComputeEx>("FComputeEx<cpu>",
-                          BinaryBroadcastComputeDenseEx<cpu, op::mshadow_op::plus>)
+    .set_attr<FComputeEx>("FComputeEx<cpu>", BinaryOperatorComputeExCPU<op::mshadow_op::plus>)
     .set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastAddStorageType)
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
 
@@ -106,8 +172,7 @@ Supported sparse operations:
 
 )code" ADD_FILELINE)
     .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
-    .set_attr<FComputeEx>("FComputeEx<cpu>",
-                          BinaryBroadcastComputeDenseEx<cpu, op::mshadow_op::minus>)
+    .set_attr<FComputeEx>("FComputeEx<cpu>", BinaryOperatorComputeExCPU<op::mshadow_op::minus>)
     .set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastAddStorageType)
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
 
@@ -148,8 +213,7 @@ Supported sparse operations:
 
 )code" ADD_FILELINE)
     .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
-    .set_attr<FComputeEx>("FComputeEx<cpu>",
-                          BinaryBroadcastComputeSparseEx<cpu, op::mshadow_op::mul>)
+    .set_attr<FComputeEx>("FComputeEx<cpu>", BinaryOperatorComputeExCPU<op::mshadow_op::mul>)
     .set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
 
@@ -189,8 +253,7 @@ Supported sparse operations:
 
 )code" ADD_FILELINE)
     .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::div>)
-    .set_attr<FComputeEx>("FComputeEx<cpu>",
-                          BinaryBroadcastComputeSparseEx<cpu, op::mshadow_op::div>)
+    .set_attr<FComputeEx>("FComputeEx<cpu>", BinaryOperatorComputeExCPU<op::mshadow_op::div>)
     .set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"});
 
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7203212..5f29031 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -927,9 +927,9 @@ def test_sign():
     assert_almost_equal(out, npout)
 
     out_grad = mx.nd.empty(shape)
-    out_grad[:] = 2;
+    out_grad[:] = 2
     npout_grad = out_grad.asnumpy()
-    npout_grad = 0;
+    npout_grad = 0
     exe_test.backward(out_grad)
     assert_almost_equal(arr_grad, npout_grad)
 
@@ -1076,7 +1076,7 @@ def test_abs():
     assert_almost_equal(out, npout)
 
     out_grad = mx.nd.empty(shape)
-    out_grad[:] = 2;
+    out_grad[:] = 2
     npout_grad = out_grad.asnumpy()
     npout_grad = npout_grad * np.sign(data_tmp)
     exe_test.backward(out_grad)
@@ -1915,7 +1915,12 @@ def gen_broadcast_data(idx):
         [[1, 1, 65, 2, 22], [1, 1, 65, 1, 1]],
         [[1, 24, 103, 17, 18], [1, 24, 1, 1, 1]],
         [[1, 1, 1, 1, 2], [1, 24, 194, 50, 1]],
-        [[1, 1, 107, 84, 9], [1, 1, 1, 1, 1]]])
+        [[1, 1, 107, 84, 9], [1, 1, 1, 1, 1]],
+        [[8, 1, 6, 1], [7, 1, 5]], [[5, 4], [1]],
+        [[256, 256, 3], [3]], [[5, 4], [4]],
+        [[15, 3, 5], [3, 5]], [[15, 3, 5], [1, 5]],
+        [[15, 3, 5], [3, 1]], [[1,1,1,1], [1,1]],
+        [[15,3], [4, 1, 3]], [[7, 1, 5], [8, 1, 6, 1]]])
     if idx < binary_op_data_shape.shape[0]:
         l_shape = binary_op_data_shape[idx][0]
         r_shape = binary_op_data_shape[idx][1]
@@ -1939,7 +1944,7 @@ def gen_broadcast_data(idx):
 
 
 def gen_broadcast_data_int(idx):
-    d = gen_broadcast_data(idx);
+    d = gen_broadcast_data(idx)
     return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)]
 
 
@@ -1951,7 +1956,7 @@ def gen_binary_data(dummy):
 
 
 def gen_binary_data_int(dummy):
-    d = gen_binary_data(dummy);
+    d = gen_binary_data(dummy)
     return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)]
 
 
@@ -2012,10 +2017,16 @@ def check_binary_op_backward(symbol, baseline, gen_data, rtol=1e-3, atol=1e-5):
             if shape == x.shape:
                 return x
             keepdims_shape = list(x.shape)
+            # calculate difference between output and input ndims
+            # to include cases where inputs' ndims are not equal
+            ndim_diff = len(x.shape) - len(shape)
+            for i in range(ndim_diff):
+                keepdims_shape[i] = 1
+                x = np.sum(x, axis=i).reshape(keepdims_shape)
             for i in range(len(shape)):
-                if x.shape[i] != shape[i]:
-                    keepdims_shape[i] = 1
-                    x = np.sum(x, axis=i).reshape(keepdims_shape)
+                if x.shape[ndim_diff + i] != shape[i]:
+                    keepdims_shape[ndim_diff + i] = 1
+                    x = np.sum(x, axis=ndim_diff + i).reshape(keepdims_shape)
             return x
 
         baseline_grad1, baseline_grad2 = baseline(out, d[0], d[1])
