This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2616275  Add mkldnn OP for slice (#13730)
2616275 is described below

commit 26162752b98c840aaabbafa13a75822705ac78b3
Author: zhiyuan-huang <[email protected]>
AuthorDate: Wed Jan 16 23:08:54 2019 +0800

    Add mkldnn OP for slice (#13730)
    
    * add mkldnn slice
    
    * fix lint
    
    * fix lint
    
    * mv SliceEx to matrix_op.cc
    
    * fix lint
    
    * optimize dispatch_mode
    
    * retrigger ci
    
    * fix indent
---
 src/operator/nn/mkldnn/mkldnn_slice-inl.h |  66 +++++++++++++++++++
 src/operator/nn/mkldnn/mkldnn_slice.cc    | 104 ++++++++++++++++++++++++++++++
 src/operator/tensor/matrix_op-inl.h       |  37 ++++++-----
 src/operator/tensor/matrix_op.cc          |  30 ++++++++-
 src/operator/tensor/slice-inl.h           |  71 ++++++++++++++++++++
 5 files changed, 292 insertions(+), 16 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_slice-inl.h b/src/operator/nn/mkldnn/mkldnn_slice-inl.h
new file mode 100644
index 0000000..f41db01
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_slice-inl.h
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_slice-inl.h
+ * \brief
+ * \author Zhiyuan Huang
+*/
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <utility>
+#include "../../operator_common.h"
+#include "../../tensor/slice-inl.h"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Forward state for the MKLDNN implementation of `slice`.
+ *
+ * Holds a reorder primitive together with the memory objects it reads
+ * (a view of the input selected by the slice) and writes (the output).
+ * Instances are cached by GetSliceForward(); data buffers are rebound on
+ * each call through SetNewMem().
+ */
+class MKLDNNSliceFwd {
+ public:
+  MKLDNNSliceFwd(const SliceParam &param,
+                 const NDArray &in,
+                 const NDArray &out);
+  // Point the cached source/destination memory objects at new data buffers.
+  void SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output);
+  // Returns the reorder primitive itself (despite the name, not a primitive_desc).
+  const mkldnn::reorder &GetPd() const;
+
+ private:
+  std::shared_ptr<mkldnn::memory> data_;  // source memory of the reorder
+  std::shared_ptr<mkldnn::memory> out_;   // destination memory of the reorder
+  std::shared_ptr<mkldnn::reorder> fwd_;  // the reorder that performs the slice
+};
+
+// Cache key for MKLDNNSliceFwd instances, built from the slice parameters.
+using MKLDNNSliceSignature = ParamOpSign<SliceParam>;
+
+// Fetch the cached slice primitive for this param/is_train/input/output
+// combination, creating it on first use.
+MKLDNNSliceFwd &GetSliceForward(const SliceParam &param,
+                                const bool is_train,
+                                const NDArray &in_data,
+                                const NDArray &out_data);
+
+// Execute `slice` on `in` through MKLDNN, writing into `out` as dictated by `req`.
+void MKLDNNSlice(const SliceParam &param,
+                 const OpContext& ctx,
+                 const NDArray &in,
+                 OpReqType req,
+                 const NDArray &out);
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc
new file mode 100644
index 0000000..f3c8a14
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_slice.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_slice.cc
+ * \brief
+ * \author Zhiyuan Huang
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_slice-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// Builds the reorder that implements slice: an MKLDNN view selects the block
+// of `in` starting at the offsets given by `param.begin` with the extent of
+// `out`'s shape, and a reorder copies that view into memory laid out like
+// `out`. The memory objects are created with null handles; real buffers are
+// bound later through SetNewMem().
+MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam &param,
+                               const NDArray &in,
+                               const NDArray &out) {
+  const TShape ishape = in.shape();
+  const TShape oshape = out.shape();
+  const uint32_t N = ishape.ndim();
+  mkldnn::memory::dims dims(N);
+  mkldnn::memory::dims offsets(N);
+  for (uint32_t i = 0; i < N; ++i) {
+    // Guard against `param.begin` having fewer entries than `in` has axes:
+    // a missing entry is treated like None (offset 0) instead of reading
+    // past the end of the tuple.
+    int s = 0;
+    if (i < param.begin.ndim() && param.begin[i]) {
+      s = *param.begin[i];
+      if (s < 0) s += ishape[i];  // negative begin counts from the end
+    }
+    dims[i] = oshape[i];
+    offsets[i] = s;
+  }
+  auto in_mem_pd = in.GetMKLDNNData()->get_primitive_desc();
+  auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc();
+  auto view_pd = mkldnn::view::primitive_desc(in_mem_pd, dims, offsets);
+  auto reorder_pd = reorder::primitive_desc(view_pd.dst_primitive_desc(), out_mem_pd);
+  this->data_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr);
+  // The destination is described by the output's own layout, i.e. the same
+  // primitive_desc reorder_pd was built to write (the view's layout may differ).
+  this->out_ = std::make_shared<mkldnn::memory>(out_mem_pd, nullptr);
+  this->fwd_ = std::make_shared<mkldnn::reorder>(reorder_pd, *this->data_, *this->out_);
+}
+
+// Rebinds the cached source and destination memory objects to the data
+// buffers of `input` and `output`; their descriptors are left untouched.
+void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input,
+                               const mkldnn::memory &output) {
+  void *const src_handle = input.get_data_handle();
+  void *const dst_handle = output.get_data_handle();
+  data_->set_data_handle(src_handle);
+  out_->set_data_handle(dst_handle);
+}
+
+// Accessor for the cached reorder primitive. Despite the name, this returns
+// the primitive itself, not its primitive_desc.
+const mkldnn::reorder &MKLDNNSliceFwd::GetPd() const {
+  const mkldnn::reorder &prim = *this->fwd_;
+  return prim;
+}
+
+// Returns the slice primitive for this (param, is_train, input, output)
+// signature from a thread-local cache, constructing and inserting it on a miss.
+MKLDNNSliceFwd &GetSliceForward(const SliceParam &param, const bool is_train,
+                                const NDArray &in_data, const NDArray &out_data) {
+  typedef std::unordered_map<MKLDNNSliceSignature, MKLDNNSliceFwd, OpHash> fwd_cache_t;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local fwd_cache_t fwds;
+#else
+  static MX_THREAD_LOCAL fwd_cache_t fwds;
+#endif
+  MKLDNNSliceSignature sig(param);
+  sig.AddSign(is_train);
+  sig.AddSign(in_data);
+  sig.AddSign(out_data);
+
+  auto hit = fwds.find(sig);
+  if (hit != fwds.end()) return hit->second;
+
+  MKLDNNSliceFwd fresh(param, in_data, out_data);
+  return AddToCache(&fwds, sig, fresh)->second;
+}
+
+// Runs slice through MKLDNN: fetches the cached reorder, binds it to the
+// current input and a destination obtained via CreateMKLDNNMem (honouring
+// `req`), registers it, commits the output and submits the stream.
+void MKLDNNSlice(const SliceParam &param, const OpContext& ctx,
+                 const NDArray &in, OpReqType req, const NDArray &out) {
+  MKLDNNSliceFwd &slice_fwd = GetSliceForward(param, ctx.is_train, in, out);
+  auto src_mem = in.GetMKLDNNData();
+  auto dst_pd = out.GetMKLDNNData()->get_primitive_desc();
+  auto dst = CreateMKLDNNMem(out, dst_pd, req);
+  slice_fwd.SetNewMem(*src_mem, *dst.second);
+  auto stream = MKLDNNStream::Get();
+  stream->RegisterPrim(slice_fwd.GetPd());
+  CommitOutput(out, dst);
+  stream->Submit();
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 3b229cf..8b575ca 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -37,6 +37,7 @@
 #include "broadcast_reduce_op.h"
 #include "./init_op.h"
 #include "../../common/static_array.h"
+#include "./slice-inl.h"
 
 #if MXNET_USE_CUDA
 #include <thrust/device_vector.h>
@@ -398,19 +399,15 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-struct SliceParam : public dmlc::Parameter<SliceParam> {
-  nnvm::Tuple<dmlc::optional<int>> begin, end;
-  nnvm::Tuple<dmlc::optional<int>> step;
-  DMLC_DECLARE_PARAMETER(SliceParam) {
-    DMLC_DECLARE_FIELD(begin)
-    .describe("starting indices for the slice operation, supports negative 
indices.");
-    DMLC_DECLARE_FIELD(end)
-    .describe("ending indices for the slice operation, supports negative 
indices.");
-    DMLC_DECLARE_FIELD(step)
-    .set_default(nnvm::Tuple<dmlc::optional<int>>())
-    .describe("step for the slice operation, supports negative values.");
+// Whether `slice` with these parameters can take the MKLDNN path.
+// That path (a view plus a reorder) only handles unit strides, so every
+// `step` entry must be unset (None) or exactly 1; an empty `step` tuple
+// also qualifies.
+inline bool SupportMKLDNNSlice(const SliceParam& param) {
+  if (param.step.ndim() == 0U) return true;
+  for (uint32_t i = 0; i < param.step.ndim(); ++i) {
+    if (param.step[i].has_value() && param.step[i].value() != 1)
+      return false;
   }
-};
+  return true;
+}
 
 inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
@@ -432,9 +429,19 @@ inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
       && (!param.step[0].has_value() || param.step[0].value() == 1)) {
     trivial_step = true;
   }
-  if (!dispatched && in_stype == kDefaultStorage) {
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFCompute);
+
+  if (in_stype == kDefaultStorage) {
+#if MXNET_USE_MKLDNN == 1
+    if (dev_mask == Context::kCPU && MKLDNNEnvSet()
+        && SupportMKLDNNSlice(param)) {
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                       dispatch_mode, dispatch_ex);
+    }
+#endif
+    if (!dispatched) {
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                       dispatch_mode, DispatchMode::kFCompute);
+    }
   }
 
   if (!dispatched && in_stype == kCSRStorage && trivial_step) {
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index db8efa4..ed8912f 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -27,6 +27,7 @@
 #include "./elemwise_unary_op.h"
 #include "../nn/mkldnn/mkldnn_ops-inl.h"
 #include "../nn/mkldnn/mkldnn_base-inl.h"
+#include "../nn/mkldnn/mkldnn_slice-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -420,6 +421,30 @@ will return a new array with shape ``(2,1,3,4)``.
 .add_argument("data", "NDArray-or-Symbol", "Source input")
 .add_arguments(ExpandDimParam::__FIELDS__());
 
+// CPU FComputeEx for `slice`. CSR input is handled by SliceCsrImpl. In MKLDNN
+// builds, dense input goes to MKLDNNSlice when the array is supported by
+// MKLDNN, and otherwise falls back to the plain SliceOpForward kernel. Any
+// other storage type is a fatal error.
+void SliceExCPU(const nnvm::NodeAttrs& attrs,
+                const OpContext& ctx,
+                const std::vector<NDArray>& inputs,
+                const std::vector<OpReqType>& req,
+                const std::vector<NDArray>& outputs) {
+  // Unsigned literals keep the comparison with size() free of sign mismatch.
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
+  const auto in_stype = inputs[0].storage_type();
+  if (in_stype == kCSRStorage) {
+    SliceCsrImpl<cpu>(param, ctx, inputs[0], req[0], outputs[0]);
+#if MXNET_USE_MKLDNN == 1
+  } else if (in_stype == kDefaultStorage) {
+    if (SupportMKLDNN(inputs[0])) {
+      MKLDNNSlice(param, ctx, inputs[0], req[0], outputs[0]);
+    } else {
+      FallBackCompute(SliceOpForward<cpu>, attrs, ctx, inputs, req, outputs);
+    }
+#endif
+  } else {
+    LOG(FATAL) << "Slice not implemented for storage type " << in_stype;
+  }
+}
+
 NNVM_REGISTER_OP(slice)
 MXNET_ADD_SPARSE_OP_ALIAS(slice)
 .add_alias("crop")
@@ -478,7 +503,10 @@ Example::
 .set_attr<FInferStorageType>("FInferStorageType", SliceForwardInferStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice"})
 .set_attr<FCompute>("FCompute<cpu>", SliceOpForward<cpu>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", SliceEx<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", SliceExCPU)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+#endif
 .add_argument("data", "NDArray-or-Symbol", "Source input")
 .add_arguments(SliceParam::__FIELDS__());
 
diff --git a/src/operator/tensor/slice-inl.h b/src/operator/tensor/slice-inl.h
new file mode 100644
index 0000000..4e94cbe
--- /dev/null
+++ b/src/operator/tensor/slice-inl.h
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file slice-inl.h
+ * \brief
+ * \author Zhiyuan Huang
+*/
+
+#ifndef MXNET_OPERATOR_TENSOR_SLICE_INL_H_
+#define MXNET_OPERATOR_TENSOR_SLICE_INL_H_
+
+#include <utility>
+#include <vector>
+#include <string>
+
+namespace mxnet {
+namespace op {
+
+// Parameters of the `slice` operator: per-axis `begin`, `end` and `step`
+// tuples whose entries are optional ints (None when unset); negative values
+// are accepted.
+struct SliceParam : public dmlc::Parameter<SliceParam> {
+  nnvm::Tuple<dmlc::optional<int>> begin;
+  nnvm::Tuple<dmlc::optional<int>> end;
+  nnvm::Tuple<dmlc::optional<int>> step;
+  DMLC_DECLARE_PARAMETER(SliceParam) {
+    DMLC_DECLARE_FIELD(begin)
+    .describe("starting indices for the slice operation, supports negative indices.");
+    DMLC_DECLARE_FIELD(end)
+    .describe("ending indices for the slice operation, supports negative indices.");
+    DMLC_DECLARE_FIELD(step)
+    .set_default(nnvm::Tuple<dmlc::optional<int>>())
+    .describe("step for the slice operation, supports negative values.");
+  }
+  // Field-wise equality over begin, end and step.
+  bool operator==(const SliceParam& other) const {
+    return begin == other.begin && end == other.end && step == other.step;
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+namespace std {
+template<>
+struct hash<mxnet::op::SliceParam> {
+  // Combines the hashes of begin, end and step, the same fields compared by
+  // SliceParam::operator==, so equal params hash equally. Declared const as
+  // the Hash requirements demand, so const hasher objects can invoke it.
+  size_t operator()(const mxnet::op::SliceParam& val) const {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.begin);
+    ret = dmlc::HashCombine(ret, val.end);
+    ret = dmlc::HashCombine(ret, val.step);
+    return ret;
+  }
+};
+}  // namespace std
+
+#endif  // MXNET_OPERATOR_TENSOR_SLICE_INL_H_

Reply via email to