This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6e25c88 Add oneDNN support for "where" operator (#20862)
6e25c88 is described below
commit 6e25c884bab56727ef0e9b81b444875d492d0587
Author: bgawrych <[email protected]>
AuthorDate: Fri Mar 4 09:26:20 2022 +0100
Add oneDNN support for "where" operator (#20862)
* Where operator enabled in oneDNN
* Fix bug & refactor
* fix sanity
* apply review
* Fix get_broadcastable_shape function
* Apply review
* Remove unused variable
* Apply suggestions from code review
Co-authored-by: bartekkuncer <[email protected]>
---
src/operator/nn/dnnl/dnnl_ops-inl.h | 7 +
src/operator/nn/dnnl/dnnl_where-inl.h | 73 ++++++++++
src/operator/nn/dnnl/dnnl_where.cc | 224 ++++++++++++++++++++++++++++++
src/operator/numpy/np_where_forward_op.cc | 48 ++++++-
4 files changed, 349 insertions(+), 3 deletions(-)
diff --git a/src/operator/nn/dnnl/dnnl_ops-inl.h b/src/operator/nn/dnnl/dnnl_ops-inl.h
index 40e9449..06ed1e0 100644
--- a/src/operator/nn/dnnl/dnnl_ops-inl.h
+++ b/src/operator/nn/dnnl/dnnl_ops-inl.h
@@ -210,6 +210,13 @@ void DNNLReshapeForward(const nnvm::NodeAttrs& attrs,
const NDArray& input,
const OpReqType& req,
const NDArray& output);
+
+void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
} // namespace op
} // namespace mxnet
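For reference, the numpy-style "where" operator selects elementwise between two broadcastable inputs based on a condition tensor. A minimal C++ sketch of the reference semantics for equally shaped, flattened tensors (illustration only, hypothetical helper name, not part of the patch):

    #include <cstddef>
    #include <vector>

    // out[i] = lhs[i] if condition[i] != 0, otherwise rhs[i].
    std::vector<float> WhereReference(const std::vector<float>& condition,
                                      const std::vector<float>& lhs,
                                      const std::vector<float>& rhs) {
      std::vector<float> out(condition.size());
      for (std::size_t i = 0; i < condition.size(); ++i) {
        out[i] = (condition[i] != 0.0f) ? lhs[i] : rhs[i];
      }
      return out;
    }

DNNLWhereForward declared above is the oneDNN-backed entry point for this computation; it is wired into the _npi_where operator further down in this patch.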
diff --git a/src/operator/nn/dnnl/dnnl_where-inl.h b/src/operator/nn/dnnl/dnnl_where-inl.h
new file mode 100644
index 0000000..bfda684
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_where-inl.h
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_where-inl.h
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+#include <memory>
+#include <unordered_map>
+#include <vector>
+#include "dnnl_base-inl.h"
+#include "dnnl_ops-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class DNNLWhereFwd {
+ public:
+ struct Tensors {
+ Tensors(const std::vector<NDArray>& inputs, const std::vector<NDArray>& outputs);
+ const NDArray& condition;
+ const NDArray& left;
+ const NDArray& right;
+ const NDArray& output;
+ };
+
+ static DNNLWhereFwd GetCached(const Tensors& tensors);
+
+ explicit DNNLWhereFwd(const Tensors& tensors);
+
+ void Execute(const Tensors& tensors,
+ const std::vector<OpReqType>& req,
+ const OpContext& ctx) const;
+
+ private:
+ dnnl::binary::primitive_desc binary_eq_zero_pd;
+ dnnl::binary::primitive_desc binary_ne_zero_pd;
+ dnnl::binary::primitive_desc binary_mul_l_pd;
+ dnnl::binary::primitive_desc binary_mul_r_pd;
+ dnnl::binary::primitive_desc binary_sum_pd;
+ dnnl::binary binary_eq_zero;
+ dnnl::binary binary_ne_zero;
+ dnnl::binary binary_mul_l;
+ dnnl::binary binary_mul_r;
+ dnnl::binary binary_sum;
+};
+
+bool SupportDNNLWhere(const std::vector<NDArray>& inputs);
+
+} // namespace op
+} // namespace mxnet
+#endif
+#endif // MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_
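DNNLWhereFwd follows the cached-primitive pattern used by other oneDNN operators in MXNet: primitives are built once per unique combination of shapes and dtypes (the OpSignature) and reused from a thread-local cache on later calls. A hedged, generic sketch of that pattern with hypothetical names (the real cache lives in DNNLWhereFwd::GetCached in dnnl_where.cc below):

    #include <string>
    #include <unordered_map>

    // Hypothetical stand-in for a bundle of oneDNN primitives.
    struct CachedPrimitive {
      explicit CachedPrimitive(const std::string& signature) : key(signature) {}
      std::string key;
    };

    // Build once per unique signature, reuse afterwards; one cache per thread.
    const CachedPrimitive& GetCachedPrimitive(const std::string& signature) {
      static thread_local std::unordered_map<std::string, CachedPrimitive> cache;
      auto it = cache.find(signature);
      if (it == cache.end()) {
        it = cache.emplace(signature, CachedPrimitive(signature)).first;
      }
      return it->second;
    }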
diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc
new file mode 100644
index 0000000..c2335b9
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_where.cc
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_where.cc
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <algorithm>
+#include <set>
+#include <unordered_set>
+#include "dnnl_where-inl.h"
+#include "operator/operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLWhere(const std::vector<NDArray>& inputs) {
+ static const std::set<int> supported_dtypes = {
+ mshadow::kFloat32, mshadow::kBfloat16, mshadow::kInt8, mshadow::kUint8};
+ for (int i = 0; i < inputs.size(); ++i) {
+ if (!supported_dtypes.count(inputs[i].dtype()) || inputs[i].shape().Size() <= 0 ||
+ inputs[i].shape().ndim() <= 0) {
+ return false;
+ }
+ }
+ return true;
+}
+
+void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ TmpMemMgr::Get()->Init(ctx.requested[0]);
+ const auto tensors = DNNLWhereFwd::Tensors(inputs, outputs);
+ const auto fwd = DNNLWhereFwd::GetCached(tensors);
+ fwd.Execute(tensors, req, ctx);
+}
+
+DNNLWhereFwd::Tensors::Tensors(const std::vector<NDArray>& inputs,
+ const std::vector<NDArray>& outputs)
+ : condition(inputs[0]), left(inputs[1]), right(inputs[2]), output(outputs[0]) {}
+
+DNNLWhereFwd DNNLWhereFwd::GetCached(const Tensors& tensors) {
+ using where_op_fwd_map = std::unordered_map<OpSignature, DNNLWhereFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local where_op_fwd_map fwds;
+#else
+ static MX_THREAD_LOCAL where_op_fwd_map fwds;
+#endif
+
+ OpSignature key;
+ key.AddSign(tensors.condition);
+ key.AddSign(tensors.left);
+ key.AddSign(tensors.right);
+ key.AddSign(tensors.output);
+
+ auto it = fwds.find(key);
+ if (it == fwds.end()) {
+ DNNLWhereFwd fwd(tensors);
+ it = AddToCache(&fwds, key, fwd);
+ }
+ return it->second;
+}
+
+/*!
+ * \brief Align number of input dimensions to output. It is done by prepending shape with ones.
+ * oneDNN requires shapes to have the same number of dimensions even if they are broadcastable.
+ * \param in_shape input shape which should be broadcastable with output
+ * \param out_shape output shape to which the number of dimensions of input should be aligned
+ * \return input shape extended with ones to match the number of dimensions of output
+ */
+static mxnet::TShape GetBroadcastableShape(const mxnet::TShape& in_shape,
+ const mxnet::TShape& out_shape) {
+ if (in_shape == out_shape) {
+ return in_shape;
+ }
+
+ mxnet::TShape broadcastable_in_shape(out_shape.ndim(), 1);
+ const int lack_dims = out_shape.ndim() - in_shape.ndim();
+ for (int i = lack_dims; i < out_shape.ndim(); ++i) {
+ broadcastable_in_shape[i] = in_shape[i - lack_dims];
+ }
+ return broadcastable_in_shape;
+}
+
+DNNLWhereFwd::DNNLWhereFwd(const Tensors& tensors) {
+ const auto cpu_engine = CpuEngine::Get()->get_engine();
+
+ const auto cnd = tensors.condition;
+ const auto lhs = tensors.left;
+ const auto rhs = tensors.right;
+ const auto out = tensors.output;
+
+ const auto cnd_shape = GetBroadcastableShape(cnd.shape(), out.shape());
+ const auto lhs_shape = GetBroadcastableShape(lhs.shape(), out.shape());
+ const auto rhs_shape = GetBroadcastableShape(rhs.shape(), out.shape());
+
+ const auto& cnd_dtype = get_dnnl_type(cnd.dtype());
+ const auto& inp_dtype = get_dnnl_type(lhs.dtype());
+ const auto& def_ft = static_cast<dnnl::memory::format_tag>(GetDefaultFormat(lhs_shape.ndim()));
+
+ const auto& cnd_dims = dnnl::memory::dims(cnd_shape.begin(), cnd_shape.end());
+ const auto& lhs_dims = dnnl::memory::dims(lhs_shape.begin(), lhs_shape.end());
+ const auto& rhs_dims = dnnl::memory::dims(rhs_shape.begin(), rhs_shape.end());
+ const auto& out_dims = dnnl::memory::dims(out.shape().begin(), out.shape().end());
+ const auto& scalar_dims = dnnl::memory::dims(cnd_shape.ndim(), 1);  // broadcastable scalar
+
+ auto cnd_md = dnnl::memory::desc(cnd_dims, cnd_dtype, def_ft);
+ auto lhs_md = dnnl::memory::desc(lhs_dims, inp_dtype, def_ft);
+ auto rhs_md = dnnl::memory::desc(rhs_dims, inp_dtype, def_ft);
+ auto out_md = dnnl::memory::desc(out_dims, inp_dtype, def_ft);
+ auto scalar_md = dnnl::memory::desc(scalar_dims, cnd_dtype, def_ft);
+
+ binary_ne_zero_pd = dnnl::binary::primitive_desc(
+ dnnl::binary::desc(dnnl::algorithm::binary_ne, cnd_md, scalar_md, cnd_md), cpu_engine);
+ binary_eq_zero_pd = dnnl::binary::primitive_desc(
+ dnnl::binary::desc(dnnl::algorithm::binary_eq, cnd_md, scalar_md, cnd_md), cpu_engine);
+
+ // if broadcast is needed output must be larger in size
+ auto lmask_dim = lhs_shape.Size() > cnd_shape.Size() ? lhs_dims : cnd_dims;
+ auto lmask_md = dnnl::memory::desc(lmask_dim, inp_dtype, def_ft);
+ binary_mul_l_pd = dnnl::binary::primitive_desc(
+ dnnl::binary::desc(dnnl::algorithm::binary_mul, lhs_md, cnd_md, lmask_md), cpu_engine);
+
+ auto rmask_dim = rhs_shape.Size() > cnd_shape.Size() ? rhs_dims : cnd_dims;
+ auto rmask_md = dnnl::memory::desc(rmask_dim, inp_dtype, def_ft);
+ binary_mul_r_pd = dnnl::binary::primitive_desc(
+ dnnl::binary::desc(dnnl::algorithm::binary_mul, rhs_md, cnd_md, rmask_md), cpu_engine);
+
+ binary_sum_pd = dnnl::binary::primitive_desc(
+ dnnl::binary::desc(dnnl::algorithm::binary_add, lmask_md, rmask_md, out_md), cpu_engine);
+
+ binary_ne_zero = dnnl::binary(binary_ne_zero_pd);
+ binary_eq_zero = dnnl::binary(binary_eq_zero_pd);
+ binary_mul_l = dnnl::binary(binary_mul_l_pd);
+ binary_mul_r = dnnl::binary(binary_mul_r_pd);
+ binary_sum = dnnl::binary(binary_sum_pd);
+}
+
+/*!
+ * \brief
+ * Execute the where operator using oneDNN primitives:
+ * 1. Create tensor cnd_lhs = (condition != 0) ==> convert all non-zero values to 1
+ * 2. Create tensor cnd_rhs = (condition == 0) ==> convert 0 to 1 and all other values to 0
+ * 3. Mask lhs tensor by cnd_lhs => masked_lhs = lhs * cnd_lhs
+ * 4. Mask rhs tensor by cnd_rhs => masked_rhs = rhs * cnd_rhs
+ * 5. output = masked_lhs + masked_rhs
+ */
+void DNNLWhereFwd::Execute(const Tensors& tensors,
+ const std::vector<OpReqType>& req,
+ const OpContext& ctx) const {
+ const auto& cpu_engine = CpuEngine::Get()->get_engine();
+ const auto& cpu_stream = ctx.get_stream<cpu>();
+
+ const auto& cnd_tensor = tensors.condition.GetDNNLDataReorder(binary_eq_zero_pd.src0_desc());
+ const auto& lhs_tensor = tensors.left.GetDNNLDataReorder(binary_mul_l_pd.src0_desc());
+ const auto& rhs_tensor = tensors.right.GetDNNLDataReorder(binary_mul_r_pd.src0_desc());
+
+ mxnet::dnnl_output_t out_mem = CreateDNNLMem(tensors.output, binary_sum_pd.dst_desc(), req[0]);
+
+ const int dtype_size =
+ std::max(GetTypeSize(tensors.condition.dtype()), GetTypeSize(tensors.left.dtype()));
+
+ // allocate temporary memory for 4 additional tensors
+ mshadow::Tensor<cpu, 1> tmp_workspace = ctx.requested[0].get_space<cpu>(
+ mshadow::Shape1(tensors.output.shape().Size() * 4 * dtype_size), cpu_stream);
+ char* workspace_ptr = reinterpret_cast<char*>(tmp_workspace.dptr_);
+ const int offset_size = tensors.output.shape().Size() * dtype_size;
+
+ dnnl::memory cnd_lhs(binary_ne_zero_pd.dst_desc(), cpu_engine, workspace_ptr);
+ dnnl::memory cnd_rhs(binary_eq_zero_pd.dst_desc(), cpu_engine, workspace_ptr + offset_size);
+ dnnl::memory masked_lhs(binary_mul_l_pd.dst_desc(), cpu_engine, workspace_ptr + 2 * offset_size);
+ dnnl::memory masked_rhs(binary_mul_r_pd.dst_desc(), cpu_engine, workspace_ptr + 3 * offset_size);
+
+ double zero{0};
+ dnnl::memory zero_scalar(binary_eq_zero_pd.src1_desc(), cpu_engine, &zero);
+
+ DNNLStream::Get()->RegisterPrimArgs(
+ binary_ne_zero,
+ {{DNNL_ARG_SRC_0, *cnd_tensor}, {DNNL_ARG_SRC_1, zero_scalar}, {DNNL_ARG_DST, cnd_lhs}});
+
+ DNNLStream::Get()->RegisterPrimArgs(
+ binary_eq_zero,
+ {{DNNL_ARG_SRC_0, *cnd_tensor}, {DNNL_ARG_SRC_1, zero_scalar}, {DNNL_ARG_DST, cnd_rhs}});
+
+ DNNLStream::Get()->RegisterPrimArgs(
+ binary_mul_l,
+ {{DNNL_ARG_SRC_0, *lhs_tensor}, {DNNL_ARG_SRC_1, cnd_lhs}, {DNNL_ARG_DST, masked_lhs}});
+
+ DNNLStream::Get()->RegisterPrimArgs(
+ binary_mul_r,
+ {{DNNL_ARG_SRC_0, *rhs_tensor}, {DNNL_ARG_SRC_1, cnd_rhs}, {DNNL_ARG_DST, masked_rhs}});
+
+ DNNLStream::Get()->RegisterPrimArgs(binary_sum,
+ {{DNNL_ARG_SRC_0, masked_lhs},
+ {DNNL_ARG_SRC_1, masked_rhs},
+ {DNNL_ARG_DST, *out_mem.second}});
+
+ CommitOutput(tensors.output, out_mem);
+ DNNLStream::Get()->Submit();
+}
+
+} // namespace op
+} // namespace mxnet
+#endif
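To make the decomposition used above concrete: the kernel never branches on the condition; it builds two 0/1 masks and combines the masked operands with an add. A dependency-free C++ sketch of the same arithmetic on flat arrays, plus the shape alignment performed by GetBroadcastableShape (helper names below are hypothetical; broadcasting itself is left to oneDNN in the real kernel):

    #include <cstddef>
    #include <vector>

    // Mirrors the primitive chain in DNNLWhereFwd:
    //   cnd_lhs = (condition != 0), cnd_rhs = (condition == 0),
    //   masked_lhs = lhs * cnd_lhs, masked_rhs = rhs * cnd_rhs, output = masked_lhs + masked_rhs.
    std::vector<float> WhereByMasking(const std::vector<float>& condition,
                                      const std::vector<float>& lhs,
                                      const std::vector<float>& rhs) {
      const std::size_t n = condition.size();
      std::vector<float> out(n);
      for (std::size_t i = 0; i < n; ++i) {
        const float cnd_lhs = (condition[i] != 0.0f) ? 1.0f : 0.0f;  // binary_ne with zero scalar
        const float cnd_rhs = (condition[i] == 0.0f) ? 1.0f : 0.0f;  // binary_eq with zero scalar
        out[i] = lhs[i] * cnd_lhs + rhs[i] * cnd_rhs;                // binary_mul x2 + binary_add
      }
      return out;
    }

    // Shape alignment analogous to GetBroadcastableShape: prepend 1s so the input has
    // as many dimensions as the output, e.g. (3, 4) aligned to (2, 3, 4) becomes (1, 3, 4).
    std::vector<int> AlignShape(const std::vector<int>& in_shape,
                                const std::vector<int>& out_shape) {
      std::vector<int> aligned(out_shape.size(), 1);
      const std::size_t lack_dims = out_shape.size() - in_shape.size();
      for (std::size_t i = lack_dims; i < out_shape.size(); ++i) {
        aligned[i] = in_shape[i - lack_dims];
      }
      return aligned;
    }

The sum of the two masked tensors reproduces the elementwise select because exactly one of the masks is 1 at each position.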
diff --git a/src/operator/numpy/np_where_forward_op.cc b/src/operator/numpy/np_where_forward_op.cc
index bef9b19..6caa58d 100644
--- a/src/operator/numpy/np_where_forward_op.cc
+++ b/src/operator/numpy/np_where_forward_op.cc
@@ -23,6 +23,7 @@
*/
#include "np_where_op-inl.h"
+#include "../nn/dnnl/dnnl_where-inl.h"
namespace mxnet {
namespace op {
@@ -89,6 +90,39 @@ inline bool NumpyWhereScalarOpType(const nnvm::NodeAttrs& attrs,
DMLC_REGISTER_PARAMETER(NumpyWhereScalarParam);
DMLC_REGISTER_PARAMETER(NumpyWhereScalar2Param);
+#if MXNET_USE_ONEDNN == 1
+static void WhereForwardEx(const nnvm::NodeAttrs& attrs,
+ const OpContext& op_ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ CHECK(!inputs.empty());
+ if (req[0] == kNullOp) {
+ return;
+ }
+ if (SupportDNNLWhere(inputs)) {
+ DNNL_OPCHECK_INIT(/*is backward*/ false, outputs.size(), inputs, outputs);
+ DNNLRun(DNNLWhereForward, attrs, op_ctx, inputs, req, outputs);
+ DNNL_OPCHECK_RUN(NumpyWhereOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
+ } else {
+ FallBackCompute(NumpyWhereOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
+ }
+}
+
+inline static bool WhereInferStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ return DNNLStorageType(attrs,
+ dev_mask,
+ /*support onednn*/ true,
+ dispatch_mode,
+ in_attrs,
+ out_attrs);
+}
+#endif // MXNET_USE_ONEDNN == 1
+
NNVM_REGISTER_OP(_npi_where)
.set_num_inputs(3)
.set_num_outputs(1)
@@ -103,11 +137,19 @@ NNVM_REGISTER_OP(_npi_where)
return std::vector<std::pair<int, int> >{{1, 0}, {2, 0}};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyWhereOpForward<cpu>)
+#if MXNET_USE_ONEDNN == 1
+ .set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& n) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
+ .set_attr<FComputeEx>("FComputeEx<cpu>", WhereForwardEx)
+ .set_attr<bool>("TIsDNNL", true)
+ .set_attr<FInferStorageType>("FInferStorageType", WhereInferStorageType)
+#endif
.set_attr<nnvm::FGradient>(
"FGradient",
- // Use the following lambda function instead of ElemwiseGradUseIn
- // for best efficiency. grad[condition] = 0; to calculate grad[x] and grad[y]
- // we need only condition from input.
+ // Use the following lambda function instead of ElemwiseGradUseIn for best efficiency.
+ // grad[condition] = 0; to calculate grad[x] and grad[y] we need only condition from input.
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> ret;
// make zero grad node for grad[condition]