This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6e25c88  Add oneDNN support for "where" operator (#20862)
6e25c88 is described below

commit 6e25c884bab56727ef0e9b81b444875d492d0587
Author: bgawrych <[email protected]>
AuthorDate: Fri Mar 4 09:26:20 2022 +0100

    Add oneDNN support for "where" operator (#20862)
    
    * Where operator enabled in oneDNN
    
    * Fix bug & refactor
    
    * fix sanity
    
    * apply review
    
    * Fix get_broadcastable_shape function
    
    * Apply review
    
    * Remove unused variable
    
    * Apply suggestions from code review
    
    Co-authored-by: bartekkuncer <[email protected]>
---
 src/operator/nn/dnnl/dnnl_ops-inl.h       |   7 +
 src/operator/nn/dnnl/dnnl_where-inl.h     |  73 ++++++++++
 src/operator/nn/dnnl/dnnl_where.cc        | 224 ++++++++++++++++++++++++++++++
 src/operator/numpy/np_where_forward_op.cc |  48 ++++++-
 4 files changed, 349 insertions(+), 3 deletions(-)

diff --git a/src/operator/nn/dnnl/dnnl_ops-inl.h 
b/src/operator/nn/dnnl/dnnl_ops-inl.h
index 40e9449..06ed1e0 100644
--- a/src/operator/nn/dnnl/dnnl_ops-inl.h
+++ b/src/operator/nn/dnnl/dnnl_ops-inl.h
@@ -210,6 +210,13 @@ void DNNLReshapeForward(const nnvm::NodeAttrs& attrs,
                         const NDArray& input,
                         const OpReqType& req,
                         const NDArray& output);
+
+/*!
+ * \brief oneDNN forward implementation of the "where" operator.
+ * \param inputs condition, x and y tensors, in that order
+ * \param req write request for the single output
+ * \param outputs single-element vector holding the output tensor
+ * \note ctx.requested[0] is used as temporary workspace.
+ */
+void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs);
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/dnnl/dnnl_where-inl.h 
b/src/operator/nn/dnnl/dnnl_where-inl.h
new file mode 100644
index 0000000..bfda684
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_where-inl.h
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_where-inl.h
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+#include <memory>
+#include <unordered_map>
+#include <vector>
+#include "dnnl_base-inl.h"
+#include "dnnl_ops-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief oneDNN forward pass of the "where" operator, built from binary primitives
+ *        (compare condition with zero, mask both branches, add the masked results).
+ */
+class DNNLWhereFwd {
+ public:
+  /*!
+   * \brief Non-owning (reference) view of the operator's tensors. The vectors passed to the
+   *        constructor must outlive this object.
+   */
+  struct Tensors {
+    /*! \brief Binds inputs[0..2] as condition/left/right and outputs[0] as output. */
+    Tensors(const std::vector<NDArray>& inputs, const std::vector<NDArray>& outputs);
+    const NDArray& condition;
+    const NDArray& left;
+    const NDArray& right;
+    const NDArray& output;
+  };
+
+  /*!
+   * \brief Returns a forward object from a thread-local cache keyed on the signatures of all
+   *        four tensors, constructing and caching a new one on a miss.
+   */
+  static DNNLWhereFwd GetCached(const Tensors& tensors);
+
+  /*! \brief Creates the primitive descriptors and primitives for the given tensors. */
+  explicit DNNLWhereFwd(const Tensors& tensors);
+
+  /*!
+   * \brief Runs the primitives on \p tensors; req[0] controls how the output is written and
+   *        ctx.requested[0] supplies the temporary workspace.
+   */
+  void Execute(const Tensors& tensors,
+               const std::vector<OpReqType>& req,
+               const OpContext& ctx) const;
+
+ private:
+  // Descriptors and primitives for the steps of Execute: condition == 0, condition != 0,
+  // left * mask, right * mask, and the sum of the two masked tensors.
+  dnnl::binary::primitive_desc binary_eq_zero_pd;
+  dnnl::binary::primitive_desc binary_ne_zero_pd;
+  dnnl::binary::primitive_desc binary_mul_l_pd;
+  dnnl::binary::primitive_desc binary_mul_r_pd;
+  dnnl::binary::primitive_desc binary_sum_pd;
+  dnnl::binary binary_eq_zero;
+  dnnl::binary binary_ne_zero;
+  dnnl::binary binary_mul_l;
+  dnnl::binary binary_mul_r;
+  dnnl::binary binary_sum;
+};
+
+/*!
+ * \brief Returns true when every input has a dtype handled by the oneDNN path
+ *        (float32, bfloat16, int8, uint8) and a non-empty shape with at least one dimension.
+ */
+bool SupportDNNLWhere(const std::vector<NDArray>& inputs);
+
+}  // namespace op
+}  // namespace mxnet
+#endif
+#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_where.cc 
b/src/operator/nn/dnnl/dnnl_where.cc
new file mode 100644
index 0000000..c2335b9
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_where.cc
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_where.cc
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <algorithm>
+#include <set>
+#include <unordered_set>
+#include "dnnl_where-inl.h"
+#include "operator/operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Check whether the oneDNN implementation can handle the given inputs.
+ * \param inputs condition, x and y tensors of the where operator
+ * \return true if every input has a supported dtype (float32, bfloat16, int8, uint8)
+ *         and a non-empty shape with at least one dimension
+ */
+bool SupportDNNLWhere(const std::vector<NDArray>& inputs) {
+  static const std::set<int> supported_dtypes = {
+      mshadow::kFloat32, mshadow::kBfloat16, mshadow::kInt8, mshadow::kUint8};
+  // Range-based algorithm: no signed/unsigned index comparison. Size() is unsigned,
+  // so "non-empty" is written as != 0.
+  return std::all_of(inputs.begin(), inputs.end(), [](const NDArray& input) {
+    return supported_dtypes.count(input.dtype()) != 0 && input.shape().Size() != 0 &&
+           input.shape().ndim() > 0;
+  });
+}
+
+/*!
+ * \brief oneDNN entry point of the where operator: fetches (or builds) the cached
+ *        forward object for these tensors and executes it.
+ */
+void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  // Bind the temporary-memory manager to the requested resource.
+  TmpMemMgr::Get()->Init(ctx.requested[0]);
+  const DNNLWhereFwd::Tensors where_tensors(inputs, outputs);
+  DNNLWhereFwd::GetCached(where_tensors).Execute(where_tensors, req, ctx);
+}
+
+/*!
+ * \brief Bind inputs[0], inputs[1], inputs[2] as condition, left, right and outputs[0] as
+ *        output. Members are references, so both vectors must outlive this object.
+ */
+DNNLWhereFwd::Tensors::Tensors(const std::vector<NDArray>& inputs,
+                               const std::vector<NDArray>& outputs)
+    : condition(inputs[0]), left(inputs[1]), right(inputs[2]), output(outputs[0]) {}
+
+/*!
+ * \brief Look up a forward object in the thread-local cache, building and inserting one
+ *        when no entry matches the signatures of the given tensors.
+ */
+DNNLWhereFwd DNNLWhereFwd::GetCached(const Tensors& tensors) {
+  using where_op_fwd_map = std::unordered_map<OpSignature, DNNLWhereFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local where_op_fwd_map cached_fwds;
+#else
+  static MX_THREAD_LOCAL where_op_fwd_map cached_fwds;
+#endif
+
+  // The key covers condition, left, right and output, in that order.
+  OpSignature signature;
+  const NDArray* const keyed_arrays[] = {
+      &tensors.condition, &tensors.left, &tensors.right, &tensors.output};
+  for (const NDArray* array : keyed_arrays) {
+    signature.AddSign(*array);
+  }
+
+  auto entry = cached_fwds.find(signature);
+  if (entry != cached_fwds.end()) {
+    return entry->second;
+  }
+  const DNNLWhereFwd created(tensors);
+  return AddToCache(&cached_fwds, signature, created)->second;
+}
+
+/*!
+ * \brief Align the number of input dimensions to the output by prepending ones to the shape.
+ *        oneDNN requires shapes to have the same number of dimensions even if they are
+ *        broadcastable.
+ * \param in_shape input shape which should be broadcastable with output; it must not have
+ *        more dimensions than out_shape
+ * \param out_shape output shape to whose number of dimensions the input should be aligned
+ * \return input shape extended with leading ones to match the number of dimensions of output
+ */
+static mxnet::TShape GetBroadcastableShape(const mxnet::TShape& in_shape,
+                                           const mxnet::TShape& out_shape) {
+  if (in_shape == out_shape) {
+    return in_shape;
+  }
+
+  // With more input than output dimensions lack_dims would be negative and the loop below
+  // would write broadcastable_in_shape out of bounds.
+  CHECK_LE(in_shape.ndim(), out_shape.ndim())
+      << "Shape " << in_shape << " cannot be broadcast to " << out_shape;
+
+  mxnet::TShape broadcastable_in_shape(out_shape.ndim(), 1);
+  const int lack_dims = out_shape.ndim() - in_shape.ndim();
+  // Leading lack_dims entries stay 1; the input dims fill the trailing positions.
+  for (int i = lack_dims; i < out_shape.ndim(); ++i) {
+    broadcastable_in_shape[i] = in_shape[i - lack_dims];
+  }
+  return broadcastable_in_shape;
+}
+
+/*!
+ * \brief Build the oneDNN primitive descriptors and primitives for the given tensors.
+ *        Every operand is described with the output's number of dimensions so that
+ *        oneDNN can broadcast them.
+ */
+DNNLWhereFwd::DNNLWhereFwd(const Tensors& tensors) {
+  const auto cpu_engine = CpuEngine::Get()->get_engine();
+
+  // Bind by reference; copying the NDArray handles is unnecessary.
+  const NDArray& cnd = tensors.condition;
+  const NDArray& lhs = tensors.left;
+  const NDArray& rhs = tensors.right;
+  const NDArray& out = tensors.output;
+
+  const auto cnd_shape = GetBroadcastableShape(cnd.shape(), out.shape());
+  const auto lhs_shape = GetBroadcastableShape(lhs.shape(), out.shape());
+  const auto rhs_shape = GetBroadcastableShape(rhs.shape(), out.shape());
+
+  const auto cnd_dtype = get_dnnl_type(cnd.dtype());
+  const auto inp_dtype = get_dnnl_type(lhs.dtype());
+  const auto def_ft = static_cast<dnnl::memory::format_tag>(GetDefaultFormat(lhs_shape.ndim()));
+
+  const dnnl::memory::dims cnd_dims(cnd_shape.begin(), cnd_shape.end());
+  const dnnl::memory::dims lhs_dims(lhs_shape.begin(), lhs_shape.end());
+  const dnnl::memory::dims rhs_dims(rhs_shape.begin(), rhs_shape.end());
+  const dnnl::memory::dims out_dims(out.shape().begin(), out.shape().end());
+  const dnnl::memory::dims scalar_dims(cnd_shape.ndim(), 1);  // broadcastable scalar
+
+  auto cnd_md    = dnnl::memory::desc(cnd_dims, cnd_dtype, def_ft);
+  auto lhs_md    = dnnl::memory::desc(lhs_dims, inp_dtype, def_ft);
+  auto rhs_md    = dnnl::memory::desc(rhs_dims, inp_dtype, def_ft);
+  auto out_md    = dnnl::memory::desc(out_dims, inp_dtype, def_ft);
+  auto scalar_md = dnnl::memory::desc(scalar_dims, cnd_dtype, def_ft);
+
+  binary_ne_zero_pd = dnnl::binary::primitive_desc(
+      dnnl::binary::desc(dnnl::algorithm::binary_ne, cnd_md, scalar_md, cnd_md), cpu_engine);
+  binary_eq_zero_pd = dnnl::binary::primitive_desc(
+      dnnl::binary::desc(dnnl::algorithm::binary_eq, cnd_md, scalar_md, cnd_md), cpu_engine);
+
+  // Shape of the result of a broadcasting binary op on two equally-ranked, broadcast-compatible
+  // operands: the elementwise maximum of their dims.
+  const auto broadcast_dims = [](const dnnl::memory::dims& a, const dnnl::memory::dims& b) {
+    dnnl::memory::dims result(a.size());
+    std::transform(a.begin(), a.end(), b.begin(), result.begin(),
+                   [](dnnl::memory::dim x, dnnl::memory::dim y) { return std::max(x, y); });
+    return result;
+  };
+
+  // Each masked tensor takes the broadcast shape of its two operands. Picking the operand with
+  // more elements is not enough when the operands broadcast along different axes,
+  // e.g. (1, 3) and (3, 1) must produce (3, 3).
+  const auto lmask_dims = broadcast_dims(lhs_dims, cnd_dims);
+  auto lmask_md         = dnnl::memory::desc(lmask_dims, inp_dtype, def_ft);
+  binary_mul_l_pd       = dnnl::binary::primitive_desc(
+      dnnl::binary::desc(dnnl::algorithm::binary_mul, lhs_md, cnd_md, lmask_md), cpu_engine);
+
+  const auto rmask_dims = broadcast_dims(rhs_dims, cnd_dims);
+  auto rmask_md         = dnnl::memory::desc(rmask_dims, inp_dtype, def_ft);
+  binary_mul_r_pd       = dnnl::binary::primitive_desc(
+      dnnl::binary::desc(dnnl::algorithm::binary_mul, rhs_md, cnd_md, rmask_md), cpu_engine);
+
+  binary_sum_pd = dnnl::binary::primitive_desc(
+      dnnl::binary::desc(dnnl::algorithm::binary_add, lmask_md, rmask_md, out_md), cpu_engine);
+
+  binary_ne_zero = dnnl::binary(binary_ne_zero_pd);
+  binary_eq_zero = dnnl::binary(binary_eq_zero_pd);
+  binary_mul_l   = dnnl::binary(binary_mul_l_pd);
+  binary_mul_r   = dnnl::binary(binary_mul_r_pd);
+  binary_sum     = dnnl::binary(binary_sum_pd);
+}
+
+/*!
+ * \brief
+ * Execute where operator by oneDNN primitives.
+ * 1. Create tensor cnd_lhs = (condition != 0) ==> convert all non-zero values to 1 and zeros to 0
+ * 2. Create tensor cnd_rhs = (condition == 0) ==> convert 0 to 1 and all other values to 0
+ * 3. Mask lhs tensor by cnd_lhs => masked_lhs = lhs * cnd_lhs
+ * 4. Mask rhs tensor by cnd_rhs => masked_rhs = rhs * cnd_rhs
+ * 5. output = masked_lhs + masked_rhs
+ * \param tensors condition, left, right and output arrays
+ * \param req write request for the output (req[0])
+ * \param ctx operator context; ctx.requested[0] provides the temporary workspace
+ * \note NOTE(review): unselected elements are multiplied by 0 rather than skipped, so an inf or
+ *       NaN in the unselected operand yields NaN in the output (0 * inf = NaN), unlike a true
+ *       element-wise select. Confirm this is acceptable for the float dtypes.
+ */
+void DNNLWhereFwd::Execute(const Tensors& tensors,
+                           const std::vector<OpReqType>& req,
+                           const OpContext& ctx) const {
+  const auto& cpu_engine = CpuEngine::Get()->get_engine();
+  const auto& cpu_stream = ctx.get_stream<cpu>();
+
+  const auto& cnd_tensor = tensors.condition.GetDNNLDataReorder(binary_eq_zero_pd.src0_desc());
+  const auto& lhs_tensor = tensors.left.GetDNNLDataReorder(binary_mul_l_pd.src0_desc());
+  const auto& rhs_tensor = tensors.right.GetDNNLDataReorder(binary_mul_r_pd.src0_desc());
+
+  mxnet::dnnl_output_t out_mem = CreateDNNLMem(tensors.output, binary_sum_pd.dst_desc(), req[0]);
+
+  // Each intermediate tensor holds at most output.Size() elements of the wider of the two
+  // dtypes. size_t keeps the byte counts from overflowing for large outputs.
+  const size_t dtype_size =
+      std::max(GetTypeSize(tensors.condition.dtype()), GetTypeSize(tensors.left.dtype()));
+  const size_t offset_size = tensors.output.shape().Size() * dtype_size;
+
+  // Allocate the workspace for the 4 intermediate tensors in bytes. The untyped
+  // get_space<cpu>() returns a real_t tensor, i.e. it would allocate sizeof(real_t) bytes per
+  // requested element and over-allocate accordingly.
+  mshadow::Tensor<cpu, 1, char> tmp_workspace = ctx.requested[0].get_space_typed<cpu, 1, char>(
+      mshadow::Shape1(4 * offset_size), cpu_stream);
+  char* workspace_ptr = tmp_workspace.dptr_;
+
+  dnnl::memory cnd_lhs(binary_ne_zero_pd.dst_desc(), cpu_engine, workspace_ptr);
+  dnnl::memory cnd_rhs(binary_eq_zero_pd.dst_desc(), cpu_engine, workspace_ptr + offset_size);
+  dnnl::memory masked_lhs(binary_mul_l_pd.dst_desc(), cpu_engine, workspace_ptr + 2 * offset_size);
+  dnnl::memory masked_rhs(binary_mul_r_pd.dst_desc(), cpu_engine, workspace_ptr + 3 * offset_size);
+
+  // Zero scalar compared with the condition; it is read with the condition's dtype. A double is
+  // wide enough for every supported dtype and its all-zero bit pattern is 0 in each of them.
+  double zero{0};
+  dnnl::memory zero_scalar(binary_eq_zero_pd.src1_desc(), cpu_engine, &zero);
+
+  // Step 1: cnd_lhs = (condition != 0)
+  DNNLStream::Get()->RegisterPrimArgs(
+      binary_ne_zero,
+      {{DNNL_ARG_SRC_0, *cnd_tensor}, {DNNL_ARG_SRC_1, zero_scalar}, {DNNL_ARG_DST, cnd_lhs}});
+
+  // Step 2: cnd_rhs = (condition == 0)
+  DNNLStream::Get()->RegisterPrimArgs(
+      binary_eq_zero,
+      {{DNNL_ARG_SRC_0, *cnd_tensor}, {DNNL_ARG_SRC_1, zero_scalar}, {DNNL_ARG_DST, cnd_rhs}});
+
+  // Step 3: masked_lhs = lhs * cnd_lhs
+  DNNLStream::Get()->RegisterPrimArgs(
+      binary_mul_l,
+      {{DNNL_ARG_SRC_0, *lhs_tensor}, {DNNL_ARG_SRC_1, cnd_lhs}, {DNNL_ARG_DST, masked_lhs}});
+
+  // Step 4: masked_rhs = rhs * cnd_rhs
+  DNNLStream::Get()->RegisterPrimArgs(
+      binary_mul_r,
+      {{DNNL_ARG_SRC_0, *rhs_tensor}, {DNNL_ARG_SRC_1, cnd_rhs}, {DNNL_ARG_DST, masked_rhs}});
+
+  // Step 5: output = masked_lhs + masked_rhs
+  DNNLStream::Get()->RegisterPrimArgs(binary_sum,
+                                      {{DNNL_ARG_SRC_0, masked_lhs},
+                                       {DNNL_ARG_SRC_1, masked_rhs},
+                                       {DNNL_ARG_DST, *out_mem.second}});
+
+  CommitOutput(tensors.output, out_mem);
+  DNNLStream::Get()->Submit();
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif
diff --git a/src/operator/numpy/np_where_forward_op.cc 
b/src/operator/numpy/np_where_forward_op.cc
index bef9b19..6caa58d 100644
--- a/src/operator/numpy/np_where_forward_op.cc
+++ b/src/operator/numpy/np_where_forward_op.cc
@@ -23,6 +23,7 @@
  */
 
 #include "np_where_op-inl.h"
+#include "../nn/dnnl/dnnl_where-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -89,6 +90,39 @@ inline bool NumpyWhereScalarOpType(const nnvm::NodeAttrs& 
attrs,
 DMLC_REGISTER_PARAMETER(NumpyWhereScalarParam);
 DMLC_REGISTER_PARAMETER(NumpyWhereScalar2Param);
 
+#if MXNET_USE_ONEDNN == 1
+/*!
+ * \brief FComputeEx<cpu> for _npi_where: runs the oneDNN kernel when the inputs are supported,
+ *        otherwise falls back to NumpyWhereOpForward<cpu>.
+ */
+static void WhereForwardEx(const nnvm::NodeAttrs& attrs,
+                           const OpContext& op_ctx,
+                           const std::vector<NDArray>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<NDArray>& outputs) {
+  CHECK(!inputs.empty());
+  // Nothing to write: skip both the oneDNN and the fallback path.
+  if (req[0] == kNullOp) {
+    return;
+  }
+
+  // Inputs the oneDNN kernel cannot handle go through the default implementation.
+  if (!SupportDNNLWhere(inputs)) {
+    FallBackCompute(NumpyWhereOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
+    return;
+  }
+
+  // DNNL_OPCHECK_* presumably compare against the reference kernel when op checking is
+  // enabled — TODO confirm in dnnl_base-inl.h.
+  DNNL_OPCHECK_INIT(/*is backward*/ false, outputs.size(), inputs, outputs);
+  DNNLRun(DNNLWhereForward, attrs, op_ctx, inputs, req, outputs);
+  DNNL_OPCHECK_RUN(NumpyWhereOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
+}
+
+/*!
+ * \brief Storage-type inference for _npi_where; delegates to DNNLStorageType with
+ *        oneDNN support enabled.
+ */
+inline static bool WhereInferStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int>* in_attrs,
+                                         std::vector<int>* out_attrs) {
+  constexpr bool kSupportDNNL = true;
+  return DNNLStorageType(attrs, dev_mask, kSupportDNNL, dispatch_mode, in_attrs, out_attrs);
+}
+#endif  // MXNET_USE_ONEDNN == 1
+
 NNVM_REGISTER_OP(_npi_where)
     .set_num_inputs(3)
     .set_num_outputs(1)
@@ -103,11 +137,19 @@ NNVM_REGISTER_OP(_npi_where)
                                       return std::vector<std::pair<int, int> 
>{{1, 0}, {2, 0}};
                                     })
     .set_attr<FCompute>("FCompute<cpu>", NumpyWhereOpForward<cpu>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& n) {
+                                  return 
std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
+    .set_attr<FComputeEx>("FComputeEx<cpu>", WhereForwardEx)
+    .set_attr<bool>("TIsDNNL", true)
+    .set_attr<FInferStorageType>("FInferStorageType", WhereInferStorageType)
+#endif
     .set_attr<nnvm::FGradient>(
         "FGradient",
-        // Use the following lambda function instead of ElemwiseGradUseIn
-        // for best efficiency. grad[condition] = 0; to calculate grad[x] and 
grad[y]
-        // we need only condition from input.
+        // Use the following lambda function instead of ElemwiseGradUseIn for 
best efficiency.
+        // grad[condition] = 0; to calculate grad[x] and grad[y] we need only 
condition from input.
         [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& 
ograds) {
           std::vector<nnvm::NodeEntry> ret;
           // make zero grad node for grad[condition]

Reply via email to