This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7d84b59 [FEATURE] Integrate oneDNN support for add, subtract,
multiply, divide. (#20713)
7d84b59 is described below
commit 7d84b598459985a59f7601f638cf8707389609c9
Author: AdamGrabowski <[email protected]>
AuthorDate: Tue Jan 18 13:00:41 2022 +0100
[FEATURE] Integrate oneDNN support for add, subtract, multiply, divide.
(#20713)
* Integrate oneDNN support for binary elementwise operators.
* Delete template xpu for BinaryOperatorComputeExCPU function
* Fix binary operators StorageType function.
* Fix SupportDNNLBinary function.
* Fix test_operator, DNNLAlgorithm structure, DNNLData and
DNNLBinaryOpForward condition.
* Fix test cases, add oneDNN runtime flag to dispatch, remove node attrs,
rename pointers
* Fix sanity
---
src/operator/nn/dnnl/dnnl_base-inl.h | 1 +
src/operator/nn/dnnl/dnnl_binary-inl.h | 86 ++++++++++++++++++++++
src/operator/nn/dnnl/dnnl_binary.cc | 78 ++++++++++++++++++++
src/operator/numpy/np_elemwise_broadcast_op.h | 47 ++++++++++++
src/operator/numpy/np_elemwise_broadcast_op_add.cc | 4 +
src/operator/numpy/np_elemwise_broadcast_op_mul.cc | 4 +
src/operator/numpy/np_elemwise_broadcast_op_sub.cc | 4 +
src/operator/numpy/np_true_divide.cc | 14 ++++
src/operator/tensor/elemwise_binary_broadcast_op.h | 41 +++++++++++
.../tensor/elemwise_binary_broadcast_op_basic.cc | 79 ++++++++++++++++++--
tests/python/unittest/test_operator.py | 29 +++++---
11 files changed, 370 insertions(+), 17 deletions(-)
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h
b/src/operator/nn/dnnl/dnnl_base-inl.h
index 5344989..3b0bda9 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -199,6 +199,7 @@ bool SupportDNNLLayerNorm(const LayerNormParam& param,
const std::vector<NDArray
bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
bool SupportDNNLSplit(const NDArray& input);
bool SupportDNNLStack(const std::vector<NDArray>& inputs);
+bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
} // namespace op
static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/dnnl/dnnl_binary-inl.h
b/src/operator/nn/dnnl/dnnl_binary-inl.h
new file mode 100644
index 0000000..2cf63aa
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_binary-inl.h
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_binary-inl.h
+ * \author: Adam Grabowski, [email protected]
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+#include "./dnnl_base-inl.h"
+#include "./dnnl_ops-inl.h"
+#include <vector>
+
+#include "../../tensor/elemwise_binary_broadcast_op.h"
+
+namespace mxnet {
+namespace op {
+
+using binary_fwd_t = dnnl::binary;
+using binary_fwd_pd_t = dnnl::binary::primitive_desc;
+
+class DNNLBinaryOpFwd {
+ public:
+ template <dnnl::algorithm alg>
+ static DNNLBinaryOpFwd& GetBinaryOpForward(const std::vector<NDArray>&
inputs,
+ const std::vector<NDArray>&
outputs);
+ DNNLBinaryOpFwd(const dnnl::algorithm alg,
+ const std::vector<NDArray>& inputs,
+ const std::vector<NDArray>& outputs);
+
+ void Execute(const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
+ private:
+ std::shared_ptr<binary_fwd_t> fwd;
+ std::shared_ptr<binary_fwd_pd_t> fwd_pd;
+};
+
+template <dnnl::algorithm alg>
+DNNLBinaryOpFwd& DNNLBinaryOpFwd::GetBinaryOpForward(const
std::vector<NDArray>& inputs,
+ const
std::vector<NDArray>& outputs) {
+ using binary_op_fwd_map = std::unordered_map<OpSignature, DNNLBinaryOpFwd,
OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local binary_op_fwd_map fwds;
+#else
+ static MX_THREAD_LOCAL binary_op_fwd_map fwds;
+#endif
+ OpSignature key;
+ key.AddSign(static_cast<int>(alg));
+ key.AddSign(inputs[0]);
+ key.AddSign(inputs[1]);
+ key.AddSign(outputs[0]);
+
+ auto it = fwds.find(key);
+ if (it == fwds.end()) {
+ const DNNLBinaryOpFwd fwd(alg, inputs, outputs);
+ it = AddToCache(&fwds, key, fwd);
+ }
+ return it->second;
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_ONEDNN == 1
+#endif // MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_binary.cc
b/src/operator/nn/dnnl/dnnl_binary.cc
new file mode 100644
index 0000000..b4d526c
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_binary.cc
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_binary.cc
+ * \author: Adam Grabowski, [email protected]
+ */
+
+#if MXNET_USE_ONEDNN == 1
+#include "./dnnl_binary-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DNNLBinaryOpFwd::DNNLBinaryOpFwd(const dnnl::algorithm alg,
+ const std::vector<NDArray>& inputs,
+ const std::vector<NDArray>& outputs) {
+ auto src0_desc = inputs[0].GetDNNLData()->get_desc();
+ auto src1_desc = inputs[1].GetDNNLData()->get_desc();
+ auto dst_desc = outputs[0].GetDNNLData()->get_desc();
+
+ dnnl::binary::desc fwd_desc(alg, src0_desc, src1_desc, dst_desc);
+ fwd_pd = std::make_shared<binary_fwd_pd_t>(fwd_desc,
mxnet::CpuEngine::Get()->get_engine());
+ fwd = std::make_shared<binary_fwd_t>(*fwd_pd);
+}
+
+void DNNLBinaryOpFwd::Execute(const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ auto engine = mxnet::CpuEngine::Get()->get_engine();
+ auto src0 = inputs[0].GetDNNLData();
+ auto src1 = inputs[1].GetDNNLData();
+ dnnl_output_t out_mem;
+ if (outputs[0].GetDNNLData()->get_data_handle() ==
inputs[1].GetDNNLData()->get_data_handle())
+ out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0],
&inputs[1]);
+ else
+ out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0],
&inputs[0]);
+
+ dnnl_args_map_t args = {
+ {DNNL_ARG_SRC_0, *src0},
+ {DNNL_ARG_SRC_1, *src1},
+ {DNNL_ARG_DST, *out_mem.second},
+ };
+
+ DNNLStream::Get()->RegisterPrimArgs(*fwd, args);
+ CommitOutput(outputs[0], out_mem);
+ DNNLStream::Get()->Submit();
+}
+
+bool SupportDNNLBinary(const std::vector<NDArray>& inputs) {
+ auto dtype = inputs[0].dtype();
+ auto ndim_0 = inputs[0].shape().ndim();
+ auto ndim_1 = inputs[1].shape().ndim();
+ return ndim_0 >= 1 && ndim_0 <= 6 && ndim_1 >= 1 && ndim_1 <= 6 &&
+ inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
+ dtype == mshadow::kFloat32 && dtype == inputs[1].dtype();
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h
b/src/operator/numpy/np_elemwise_broadcast_op.h
index fa329bf..3d28ffc 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -851,6 +851,53 @@ void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
}
}
+#if MXNET_USE_ONEDNN == 1
+inline bool NumpyBinaryBroadcastStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 2);
+ CHECK_EQ(out_attrs->size(), 1);
+
+ return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
out_attrs);
+}
+
+void NumpyDivideBroadcastComputeCPU(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs);
+
+template <typename OP>
+void NumpyBinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<mxnet::NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<mxnet::NDArray>&
outputs) {
+ if (SupportDNNLBinary(inputs)) {
+ const dnnl::algorithm alg = DNNLAlgorithm<OP>::value;
+ DNNLRun(DNNLBinaryOpForward<alg>, attrs, ctx, inputs, req, outputs);
+ return;
+ }
+ using namespace op::mshadow_op;
+ std::vector<mxnet::TBlob> in_data = {inputs[0].data(), inputs[1].data()};
+ std::vector<mxnet::TBlob> out_data = {outputs[0].data()};
+ if (std::is_same<OP, plus>::value) {
+ NumpyBinaryBroadcastComputeWithBool<cpu, OP, mixed_plus, mixed_plus>(
+ attrs, ctx, in_data, req, out_data);
+ } else if (std::is_same<OP, minus>::value) {
+ NumpyBinaryBroadcastCompute<cpu, OP, mixed_minus, mixed_rminus>(
+ attrs, ctx, in_data, req, out_data);
+ } else if (std::is_same<OP, mul>::value) {
+ NumpyBinaryBroadcastComputeWithBool<cpu, OP, mixed_mul, mixed_mul>(
+ attrs, ctx, in_data, req, out_data);
+ } else if (std::is_same<OP, div>::value) {
+ NumpyDivideBroadcastComputeCPU(attrs, ctx, in_data, req, out_data);
+ }
+}
+#endif // MXNET_USE_ONEDNN
+
#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_add.cc
b/src/operator/numpy/np_elemwise_broadcast_op_add.cc
index 50a79ab..69fc12b 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_add.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_add.cc
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
op::mshadow_op::plus,
op::mshadow_op::mixed_plus,
op::mshadow_op::mixed_plus>)
+#if MXNET_USE_ONEDNN == 1
+ .set_attr<FComputeEx>("FComputeEx<cpu>",
NumpyBinaryOperatorComputeExCPU<op::mshadow_op::plus>)
+ .set_attr<FInferStorageType>("FInferStorageType",
NumpyBinaryBroadcastStorageType)
+#endif // MXNET_USE_ONEDNN
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseIn{"_backward_npi_broadcast_add"});
NNVM_REGISTER_OP(_backward_npi_broadcast_add)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_mul.cc
b/src/operator/numpy/np_elemwise_broadcast_op_mul.cc
index 3e627c8..b450b81 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_mul.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_mul.cc
@@ -33,6 +33,10 @@
MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
op::mshadow_op::mul,
op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>)
+#if MXNET_USE_ONEDNN == 1
+ .set_attr<FComputeEx>("FComputeEx<cpu>",
NumpyBinaryOperatorComputeExCPU<op::mshadow_op::mul>)
+ .set_attr<FInferStorageType>("FInferStorageType",
NumpyBinaryBroadcastStorageType)
+#endif // MXNET_USE_ONEDNN
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});
NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_sub.cc
b/src/operator/numpy/np_elemwise_broadcast_op_sub.cc
index 5f3ba76..018b7a7 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_sub.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_sub.cc
@@ -33,6 +33,10 @@
MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
op::mshadow_op::minus,
op::mshadow_op::mixed_minus,
op::mshadow_op::mixed_rminus>)
+#if MXNET_USE_ONEDNN == 1
+ .set_attr<FComputeEx>("FComputeEx<cpu>",
NumpyBinaryOperatorComputeExCPU<op::mshadow_op::minus>)
+ .set_attr<FInferStorageType>("FInferStorageType",
NumpyBinaryBroadcastStorageType)
+#endif // MXNET_USE_ONEDNN
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseIn{"_backward_npi_broadcast_sub"});
NNVM_REGISTER_OP(_backward_npi_broadcast_sub)
diff --git a/src/operator/numpy/np_true_divide.cc
b/src/operator/numpy/np_true_divide.cc
index 639379d..3ef93c9 100644
--- a/src/operator/numpy/np_true_divide.cc
+++ b/src/operator/numpy/np_true_divide.cc
@@ -61,6 +61,16 @@ bool TrueDivideType(const nnvm::NodeAttrs& attrs,
return true;
}
+#if MXNET_USE_ONEDNN == 1
+void NumpyDivideBroadcastComputeCPU(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ TrueDivideBroadcastCompute<cpu>(attrs, ctx, inputs, req, outputs);
+}
+#endif // MXNET_USE_ONEDNN
+
NNVM_REGISTER_OP(_npi_true_divide)
.set_num_inputs(2)
.set_num_outputs(1)
@@ -79,6 +89,10 @@ NNVM_REGISTER_OP(_npi_true_divide)
return
std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", TrueDivideBroadcastCompute<cpu>)
+#if MXNET_USE_ONEDNN == 1
+ .set_attr<FComputeEx>("FComputeEx<cpu>",
NumpyBinaryOperatorComputeExCPU<op::mshadow_op::div>)
+ .set_attr<FInferStorageType>("FInferStorageType",
NumpyBinaryBroadcastStorageType)
+#endif // MXNET_USE_ONEDNN
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseIn{"_backward_npi_broadcast_div"})
.add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
.add_argument("rhs", "NDArray-or-Symbol", "Divisor array");
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h
b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 20d874d..1c4d84d 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -91,8 +91,14 @@ inline bool BinaryBroadcastMulStorageType(const
nnvm::NodeAttrs& attrs,
int& out_stype = out_attrs->at(0);
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+#if MXNET_USE_ONEDNN == 1
+ if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
+ dispatched = storage_type_assign(
+ &out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFComputeEx);
+#else
dispatched =
storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFCompute);
+#endif // MXNET_USE_ONEDNN == 1
}
if (!dispatched && lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage)
{
dispatched =
@@ -116,8 +122,14 @@ inline bool BinaryBroadcastAddStorageType(const
nnvm::NodeAttrs& attrs,
int& out_stype = out_attrs->at(0);
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+#if MXNET_USE_ONEDNN == 1
+ if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
+ dispatched = storage_type_assign(
+ &out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFComputeEx);
+#else
dispatched =
storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFCompute);
+#endif // MXNET_USE_ONEDNN == 1
}
if (!dispatched && ((lhs_stype == kCSRStorage && rhs_stype ==
kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype ==
kCSRStorage))) {
@@ -788,6 +800,35 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs&
attrs,
}
}
+#if MXNET_USE_ONEDNN == 1
+template <dnnl::algorithm alg>
+void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
+// template struct converting op::mshadow_op to dnnl::algorithm
+template <typename OP>
+struct DNNLAlgorithm {};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::plus> {
+ static const dnnl::algorithm value = dnnl::algorithm::binary_add;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::minus> {
+ static const dnnl::algorithm value = dnnl::algorithm::binary_sub;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::mul> {
+ static const dnnl::algorithm value = dnnl::algorithm::binary_mul;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::div> {
+ static const dnnl::algorithm value = dnnl::algorithm::binary_div;
+};
+#endif // MXNET_USE_ONEDNN == 1
+
#define MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(name)
\
NNVM_REGISTER_OP(name)
\
.set_num_inputs(2)
\
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index 9d0f107..cc66a1e 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -24,9 +24,76 @@
#include "./elemwise_unary_op.h"
#include "./elemwise_binary_op-inl.h"
#include "./elemwise_binary_broadcast_op.h"
+#if MXNET_USE_ONEDNN == 1
+#include "../nn/dnnl/dnnl_binary-inl.h"
+#endif // MXNET_USE_ONEDNN == 1
namespace mxnet {
namespace op {
+
+#if MXNET_USE_ONEDNN == 1
+template <dnnl::algorithm alg>
+void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ mxnet::TShape new_lshape, new_rshape, new_oshape;
+ int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
+ inputs[1].shape(),
+ outputs[0].shape(),
+ &new_lshape,
+ &new_rshape,
+ &new_oshape);
+ std::vector<NDArray> new_inputs;
+ std::vector<NDArray> new_outputs;
+ if (ndim_diff) {
+ new_inputs = {inputs[0].Reshape(new_lshape),
inputs[1].Reshape(new_rshape)};
+ new_outputs = {outputs[0].Reshape(new_oshape)};
+ } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
+ // BinaryBroadcastShapeCompact function doesn't reshape tensors of size
(1,1,...,1)
+ // into shape (1). It is mandatory for oneDNN primitive to have this
reshape done.
+ mxnet::TShape one_shape = mxnet::TShape(1, 1);
+ new_inputs = {inputs[0].Reshape(one_shape),
inputs[1].Reshape(one_shape)};
+ new_outputs = {outputs[0].Reshape(one_shape)};
+ } else {
+ new_inputs = {inputs[0], inputs[1]};
+ new_outputs = {outputs[0]};
+ }
+
+ DNNLBinaryOpFwd& fwd = DNNLBinaryOpFwd::GetBinaryOpForward<alg>(new_inputs,
new_outputs);
+ fwd.Execute(new_inputs, req, new_outputs);
+}
+#endif
+
+template <typename OP>
+static void BinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+#if MXNET_USE_ONEDNN == 1
+ if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
+ if (SupportDNNLBinary(inputs)) {
+ const dnnl::algorithm alg = DNNLAlgorithm<OP>::value;
+ DNNLRun(DNNLBinaryOpForward<alg>, attrs, ctx, inputs, req, outputs);
+ } else {
+ std::vector<mxnet::TBlob> in_data = {inputs[0].data(),
inputs[1].data()};
+ std::vector<mxnet::TBlob> out_data = {outputs[0].data()};
+ BinaryBroadcastCompute<cpu, OP>(attrs, ctx, in_data, req, out_data);
+ }
+ return;
+ }
+#endif // MXNET_USE_ONEDNN == 1
+ if (std::is_same<OP, op::mshadow_op::plus>::value ||
+ std::is_same<OP, op::mshadow_op::minus>::value) {
+ BinaryBroadcastComputeDenseEx<cpu, OP>(attrs, ctx, inputs, req, outputs);
+ } else if (std::is_same<OP, op::mshadow_op::mul>::value ||
+ std::is_same<OP, op::mshadow_op::div>::value) {
+ BinaryBroadcastComputeSparseEx<cpu, OP>(attrs, ctx, inputs, req, outputs);
+ }
+}
+
MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_add)
MXNET_ADD_SPARSE_OP_ALIAS(broadcast_add)
MXNET_ADD_SPARSE_OP_ALIAS(broadcast_plus)
@@ -56,8 +123,7 @@ Supported sparse operations:
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu,
op::mshadow_op::plus>)
- .set_attr<FComputeEx>("FComputeEx<cpu>",
- BinaryBroadcastComputeDenseEx<cpu,
op::mshadow_op::plus>)
+ .set_attr<FComputeEx>("FComputeEx<cpu>",
BinaryOperatorComputeExCPU<op::mshadow_op::plus>)
.set_attr<FInferStorageType>("FInferStorageType",
BinaryBroadcastAddStorageType)
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseNone{"_backward_broadcast_add"});
@@ -106,8 +172,7 @@ Supported sparse operations:
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu,
op::mshadow_op::minus>)
- .set_attr<FComputeEx>("FComputeEx<cpu>",
- BinaryBroadcastComputeDenseEx<cpu,
op::mshadow_op::minus>)
+ .set_attr<FComputeEx>("FComputeEx<cpu>",
BinaryOperatorComputeExCPU<op::mshadow_op::minus>)
.set_attr<FInferStorageType>("FInferStorageType",
BinaryBroadcastAddStorageType)
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseNone{"_backward_broadcast_sub"});
@@ -148,8 +213,7 @@ Supported sparse operations:
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu,
op::mshadow_op::mul>)
- .set_attr<FComputeEx>("FComputeEx<cpu>",
- BinaryBroadcastComputeSparseEx<cpu,
op::mshadow_op::mul>)
+ .set_attr<FComputeEx>("FComputeEx<cpu>",
BinaryOperatorComputeExCPU<op::mshadow_op::mul>)
.set_attr<FInferStorageType>("FInferStorageType",
BinaryBroadcastMulStorageType)
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseIn{"_backward_broadcast_mul"});
@@ -189,8 +253,7 @@ Supported sparse operations:
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu,
op::mshadow_op::div>)
- .set_attr<FComputeEx>("FComputeEx<cpu>",
- BinaryBroadcastComputeSparseEx<cpu,
op::mshadow_op::div>)
+ .set_attr<FComputeEx>("FComputeEx<cpu>",
BinaryOperatorComputeExCPU<op::mshadow_op::div>)
.set_attr<FInferStorageType>("FInferStorageType",
BinaryBroadcastMulStorageType)
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseIn{"_backward_broadcast_div"});
diff --git a/tests/python/unittest/test_operator.py
b/tests/python/unittest/test_operator.py
index 7203212..5f29031 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -927,9 +927,9 @@ def test_sign():
assert_almost_equal(out, npout)
out_grad = mx.nd.empty(shape)
- out_grad[:] = 2;
+ out_grad[:] = 2
npout_grad = out_grad.asnumpy()
- npout_grad = 0;
+ npout_grad = 0
exe_test.backward(out_grad)
assert_almost_equal(arr_grad, npout_grad)
@@ -1076,7 +1076,7 @@ def test_abs():
assert_almost_equal(out, npout)
out_grad = mx.nd.empty(shape)
- out_grad[:] = 2;
+ out_grad[:] = 2
npout_grad = out_grad.asnumpy()
npout_grad = npout_grad * np.sign(data_tmp)
exe_test.backward(out_grad)
@@ -1915,7 +1915,12 @@ def gen_broadcast_data(idx):
[[1, 1, 65, 2, 22], [1, 1, 65, 1, 1]],
[[1, 24, 103, 17, 18], [1, 24, 1, 1, 1]],
[[1, 1, 1, 1, 2], [1, 24, 194, 50, 1]],
- [[1, 1, 107, 84, 9], [1, 1, 1, 1, 1]]])
+ [[1, 1, 107, 84, 9], [1, 1, 1, 1, 1]],
+ [[8, 1, 6, 1], [7, 1, 5]], [[5, 4], [1]],
+ [[256, 256, 3], [3]], [[5, 4], [4]],
+ [[15, 3, 5], [3, 5]], [[15, 3, 5], [1, 5]],
+ [[15, 3, 5], [3, 1]], [[1,1,1,1], [1,1]],
+ [[15,3], [4, 1, 3]], [[7, 1, 5], [8, 1, 6, 1]]])
if idx < binary_op_data_shape.shape[0]:
l_shape = binary_op_data_shape[idx][0]
r_shape = binary_op_data_shape[idx][1]
@@ -1939,7 +1944,7 @@ def gen_broadcast_data(idx):
def gen_broadcast_data_int(idx):
- d = gen_broadcast_data(idx);
+ d = gen_broadcast_data(idx)
return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)]
@@ -1951,7 +1956,7 @@ def gen_binary_data(dummy):
def gen_binary_data_int(dummy):
- d = gen_binary_data(dummy);
+ d = gen_binary_data(dummy)
return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)]
@@ -2012,10 +2017,16 @@ def check_binary_op_backward(symbol, baseline,
gen_data, rtol=1e-3, atol=1e-5):
if shape == x.shape:
return x
keepdims_shape = list(x.shape)
+ # calculate difference between output and input ndims
+ # to include cases where inputs' ndims are not equal
+ ndim_diff = len(x.shape) - len(shape)
+ for i in range(ndim_diff):
+ keepdims_shape[i] = 1
+ x = np.sum(x, axis=i).reshape(keepdims_shape)
for i in range(len(shape)):
- if x.shape[i] != shape[i]:
- keepdims_shape[i] = 1
- x = np.sum(x, axis=i).reshape(keepdims_shape)
+ if x.shape[ndim_diff + i] != shape[i]:
+ keepdims_shape[ndim_diff + i] = 1
+ x = np.sum(x, axis=ndim_diff + i).reshape(keepdims_shape)
return x
baseline_grad1, baseline_grad2 = baseline(out, d[0], d[1])