This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new ef2be51265 Refactor SupportDNNL functions (#21032)
ef2be51265 is described below
commit ef2be5126592bb73bf3e42c29722b93b89b250dc
Author: AdamGrabowski <[email protected]>
AuthorDate: Thu Jun 23 09:41:21 2022 +0200
Refactor SupportDNNL functions (#21032)
* Refactor SupportDNNL functions
* Fix SupportDNNL for quantized reshape
* Update src/operator/nn/dnnl/dnnl_softmax_output.cc
Co-authored-by: bgawrych <[email protected]>
* Update src/operator/nn/dnnl/dnnl_reduce.cc
Co-authored-by: bgawrych <[email protected]>
---
src/ndarray/ndarray.cc | 4 +-
src/operator/contrib/adaptive_avg_pooling.cc | 22 ++--
src/operator/contrib/batch_norm_relu.cc | 19 +--
src/operator/nn/batch_norm.cc | 19 +--
src/operator/nn/concat.cc | 11 +-
src/operator/nn/dnnl/dnnl_act.cc | 17 +--
src/operator/nn/dnnl/dnnl_base-inl.h | 142 ++++++++++++++-------
src/operator/nn/dnnl/dnnl_batch_dot.cc | 9 +-
src/operator/nn/dnnl/dnnl_binary.cc | 11 +-
src/operator/nn/dnnl/dnnl_convolution.cc | 3 +-
src/operator/nn/dnnl/dnnl_deconvolution.cc | 4 +-
src/operator/nn/dnnl/dnnl_dot.cc | 11 +-
src/operator/nn/dnnl/dnnl_eltwise.cc | 10 +-
src/operator/nn/dnnl/dnnl_layer_norm.cc | 12 +-
src/operator/nn/dnnl/dnnl_log_softmax.cc | 16 +--
src/operator/nn/dnnl/dnnl_masked_softmax.cc | 7 +-
src/operator/nn/dnnl/dnnl_pooling-inl.h | 14 +-
src/operator/nn/dnnl/dnnl_reduce.cc | 11 +-
src/operator/nn/dnnl/dnnl_reshape.cc | 8 +-
src/operator/nn/dnnl/dnnl_rnn-inl.h | 1 +
src/operator/nn/dnnl/dnnl_softmax.cc | 22 +---
src/operator/nn/dnnl/dnnl_softmax_output.cc | 6 +-
src/operator/nn/dnnl/dnnl_split.cc | 6 -
src/operator/nn/dnnl/dnnl_stack.cc | 22 +---
src/operator/nn/dnnl/dnnl_transpose-inl.h | 2 -
src/operator/nn/dnnl/dnnl_transpose.cc | 10 --
src/operator/nn/dnnl/dnnl_where.cc | 11 +-
src/operator/nn/fully_connected.cc | 7 +-
src/operator/nn/log_softmax.cc | 4 +-
src/operator/nn/lrn.cc | 4 +-
src/operator/nn/masked_softmax.cc | 3 +-
src/operator/nn/softmax.cc | 4 +-
src/operator/numpy/np_broadcast_reduce_op_value.h | 1 -
src/operator/numpy/np_dot_forward.cc | 2 +-
src/operator/numpy/np_matrix_op.cc | 2 +-
.../quantization/dnnl/dnnl_quantized_flatten.cc | 2 +-
.../quantization/dnnl/dnnl_quantized_reshape.cc | 2 +-
.../quantization/dnnl/dnnl_quantized_transpose.cc | 8 +-
src/operator/softmax_output.cc | 2 +-
src/operator/subgraph/dnnl/dnnl_conv_property.h | 2 +-
src/operator/subgraph/dnnl/dnnl_fc_property.h | 2 +-
src/operator/tensor/dot.cc | 4 +-
src/operator/tensor/elemwise_binary_op_basic.cc | 11 +-
src/operator/tensor/elemwise_unary_op.h | 2 +-
src/operator/tensor/matrix_op-inl.h | 1 +
src/operator/tensor/matrix_op.cc | 18 ++-
46 files changed, 230 insertions(+), 281 deletions(-)
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 605e705ae7..902880fb1d 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1326,8 +1326,8 @@ inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext
TBlob tmp = to.data();
ndarray::Copy<from_xpu, to_xpu>(from.data(), &tmp, from.ctx(), to.ctx(), ctx);
#if MXNET_USE_ONEDNN == 1
- } else if (SupportDNNL(from.dtype(), from.shape()) && SupportDNNL(to.dtype(), to.shape()) &&
- from.ctx().dev_mask() == cpu::kDevMask && to.ctx().dev_mask() == cpu::kDevMask) {
+ } else if (SupportDNNL(from) && SupportDNNL(to) && from.ctx().dev_mask() == cpu::kDevMask &&
+ to.ctx().dev_mask() == cpu::kDevMask) {
// If we copy data directly, we need to make sure both NDArrays are supported
// by DNNL.
auto from_mem = from.GetDNNLData();
diff --git a/src/operator/contrib/adaptive_avg_pooling.cc b/src/operator/contrib/adaptive_avg_pooling.cc
index c2614d9456..69d43e0060 100644
--- a/src/operator/contrib/adaptive_avg_pooling.cc
+++ b/src/operator/contrib/adaptive_avg_pooling.cc
@@ -172,10 +172,11 @@ void AdaptiveAvgPoolUpdateOutput(mshadow::Stream<cpu>* s,
}
#if MXNET_USE_ONEDNN == 1
-bool SupportDNNLAveragePooling(const NDArray& in_data, const NDArray& out_data) {
- for (int64_t idx = 2; idx < in_data.shape().ndim(); ++idx) {
- const int s1 = in_data.shape()[idx];
- const int s2 = out_data.shape()[idx];
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_pooling.html
+bool SupportDNNLAveragePooling(const NDArray& input, const NDArray& output) {
+ for (int64_t idx = 2; idx < input.shape().ndim(); ++idx) {
+ const int s1 = input.shape()[idx];
+ const int s2 = output.shape()[idx];
if (s2 == 0) {
return false;
}
@@ -183,17 +184,18 @@ bool SupportDNNLAveragePooling(const NDArray& in_data, const NDArray& out_data)
return false;
}
}
- const int IH = in_data.shape()[2];
- const int IW = in_data.shape()[3];
- const int OH = out_data.shape()[2];
- const int OW = out_data.shape()[3];
+ const int IH = input.shape()[2];
+ const int IW = input.shape()[3];
+ const int OH = output.shape()[2];
+ const int OW = output.shape()[3];
const int strides_H = ((IH << 1) / OH) - (IH / OH);
const int strides_W = ((IW << 1) / OW) - (IW / OW);
const int kernel_H = DIV_ROUND_UP((IH << 1) / OH, 1) - (IH / OH);
const int kernel_W = DIV_ROUND_UP((IW << 1) / OW, 1) - (IW / OW);
const int pad_l_top = (strides_H * (OH - 1) + kernel_H - IH) / 2;
const int pad_l_left = (strides_W * (OW - 1) + kernel_W - IW) / 2;
- return pad_l_top == 0 && pad_l_left == 0;
+
+ return SupportDNNL<3, 5, DNNLTypeMode::AllTypes>(input) && pad_l_top == 0 && pad_l_left == 0;
}
void AdaptiveAvgPoolOpBackwardExCPU(const nnvm::NodeAttrs& attrs,
@@ -236,7 +238,7 @@ void AdaptiveAvgPoolComputeExCPU(const nnvm::NodeAttrs& attrs,
oneDNN doesn't support adaptive pooling.
Fallback is needed when padding is not equal 0;
*/
- if (SupportDNNL(inputs[0]) && SupportDNNLAveragePooling(inputs[0], outputs[0])) {
+ if (SupportDNNLAveragePooling(inputs[0], outputs[0])) {
DNNL_OPCHECK_INIT(false, 1, inputs, outputs);
DNNLRun(DNNLPoolingCompute, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(PoolingCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc
index 3d0cbb1493..a0f158f42b 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -130,16 +130,9 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
}
#if MXNET_USE_ONEDNN == 1
-static inline bool SupportDNNLBNReLU(const NDArray& input, const BatchNormParam& param) {
- if (mxnet::op::batchnorm::disable_mkl)
- return false;
- const mxnet::TShape shape = input.shape();
- const int ndim = shape.ndim();
- if (ndim == 0 || shape.Size() == 0)
- return false;
- const int dtype = input.dtype();
- return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) && SupportStorageDNNL(input.storage_type());
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_batch_normalization.html
+static inline bool SupportDNNLBNReLU(const NDArray& input) {
+ return SupportDNNL<2, 12, DNNLTypeMode::FloatTypes>(input) && !mxnet::op::batchnorm::disable_mkl;
}
void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
@@ -148,8 +141,7 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 5U);
- const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
- if (SupportDNNLBNReLU(inputs[0], param)) {
+ if (SupportDNNLBNReLU(inputs[0])) {
CHECK_GT(outputs.size(), 3U);
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
@@ -165,8 +157,7 @@ void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
- if (SupportDNNLBNReLU(inputs[0], param)) {
+ if (SupportDNNLBNReLU(inputs[0])) {
CHECK_EQ(inputs.size(), 9U);
DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
DNNLRun(DNNLBatchNormBackward<float, /*fuse_relu*/ true>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index fa422f156c..dc09ebeb22 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -462,16 +462,9 @@ static bool BNChangeLayout(nnvm::NodeAttrs* attrs,
}
#if MXNET_USE_ONEDNN == 1
-static inline bool SupportDNNLBN(const NDArray& input, const BatchNormParam& param) {
- if (mxnet::op::batchnorm::disable_mkl)
- return false;
- const mxnet::TShape shape = input.shape();
- const int ndim = shape.ndim();
- if (ndim == 0 || shape.Size() == 0)
- return false;
- const int dtype = input.dtype();
- return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) && SupportStorageDNNL(input.storage_type());
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_batch_normalization.html
+static inline bool SupportDNNLBN(const NDArray& input) {
+ return SupportDNNL<DNNLTypeMode::FloatTypes>(input) && !mxnet::op::batchnorm::disable_mkl;
}
void BatchNormComputeExCPU(const nnvm::NodeAttrs& attrs,
@@ -480,8 +473,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 5U);
- const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
- if (SupportDNNLBN(inputs[0], param)) {
+ if (SupportDNNLBN(inputs[0])) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
DNNLRun(DNNLBatchNormForward<DTYPE, /*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
@@ -497,8 +489,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
- if (SupportDNNLBN(inputs[0], param)) {
+ if (SupportDNNLBN(inputs[0])) {
DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
DNNLRun(DNNLBatchNormBackward<float, /*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 70b8aeb9f8..6fd1dc18e4 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -248,24 +248,21 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
return storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, wanted_mode);
}
#if MXNET_USE_ONEDNN == 1
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_concat.html
bool SupportDNNLConcat(const std::vector<NDArray>& arrs) {
for (auto& arr : arrs) {
- if (arr.IsView())
- return false;
- if (!(arr.dtype() == mshadow::kFloat32 || arr.dtype() == mshadow::kBfloat16))
- return false;
- // Do not support zero-size tensors.
- if (arr.shape().Size() == 0 || arr.shape().ndim() == 0)
+ if (arr.IsView() || !SupportDNNL<2, 12, AllTypes>(arr))
return false;
int ndim = arr.shape().ndim();
const int dnnl_ndims = arr.GetDNNLData()->get_desc().data.ndims;
- if ((ndim != 2 && ndim != 4) || ndim != dnnl_ndims) {
+ if (ndim != dnnl_ndims) {
return false;
}
}
return true;
}
#endif // MXNET_USE_ONEDNN == 1
+
static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& op_ctx,
const std::vector<NDArray>& inputs,
diff --git a/src/operator/nn/dnnl/dnnl_act.cc b/src/operator/nn/dnnl/dnnl_act.cc
index f3ee79a3f6..e1fb642209 100644
--- a/src/operator/nn/dnnl/dnnl_act.cc
+++ b/src/operator/nn/dnnl/dnnl_act.cc
@@ -48,12 +48,9 @@ bool SupportDNNLAct(const ActivationParam& param) {
param.act_type == activation::kSoftReLU || param.act_type == activation::kTanh;
}
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_eltwise.html
bool SupportDNNLAct(const ActivationParam& param, const NDArray& input) {
- // DNNL Activation supports 1d, 2d, 3d, 4d and 5d data layout
- if ((input.shape().ndim() < 1) || (input.shape().ndim() > 5) ||
- !(input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16))
- return false;
- return SupportDNNLAct(param);
+ return SupportDNNL<DNNLTypeMode::FloatTypes>(input) && SupportDNNLAct(param);
}
bool SupportDNNLLeakyRelu(const LeakyReLUParam& param) {
@@ -61,15 +58,13 @@ bool SupportDNNLLeakyRelu(const LeakyReLUParam& param) {
param.act_type == leakyrelu::kGELU;
}
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_eltwise.html
bool SupportDNNLLeakyRelu(const LeakyReLUParam& param, const NDArray& input) {
- // DNNL Activation supports 1d, 2d, 3d, 4d and 5d data layout
- if ((input.shape().ndim() < 1) || (input.shape().ndim() > 5) ||
- !(input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16))
- return false;
- return SupportDNNLLeakyRelu(param);
+ return SupportDNNL<DNNLTypeMode::FloatTypes>(input) && SupportDNNLLeakyRelu(param);
}
-bool SupportQuantizedDNNLAct(const ActivationParam& param) {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_eltwise.html
+bool SupportDNNLQuantizedAct(const ActivationParam& param) {
// Although it is the same as SupportDNNLAct i left it here, so when new activations
// will be introduced it will be easier to handle.
return SupportDNNLAct(param);
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h
index 8161d62c24..e8cf9c44b3 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -35,6 +35,8 @@
#include <unordered_map>
#include <utility>
#include <vector>
+#include <map>
+#include <set>
#include "dnnl.hpp"
#include "mxnet/graph_attr_types.h"
@@ -114,36 +116,84 @@ struct data_type_enum<uint8_t> {
enum { type = static_cast<unsigned int>(dnnl::memory::data_type::u8) };
};
-static inline bool SupportDNNLArray(int dtype, const mxnet::TShape& shape) {
- int ndim = shape.ndim();
- bool support = ndim == 1 || ndim == 2 || ndim == 4;
- support = support &&
- (dtype == mshadow::kFloat32 || dtype == mshadow::kInt32 || dtype == mshadow::kInt8 ||
- dtype == mshadow::kUint8 || dtype == mshadow::kBfloat16);
- return support;
+// SupportDNNL variant that is used to check tensor's dimensions.
+template <int MinNdim, int MaxNdim>
+static inline bool SupportDNNLShape(const mxnet::TShape& shape) {
+ const int ndim = shape.ndim();
+ return ndim >= MinNdim && ndim <= MaxNdim && shape.Size() != 0;
}
-static inline bool SupportStorageDNNL(int stype) {
- return stype == kDefaultStorage;
+// SupportDNNL variant that is used to check defined type combinations.
+// If any other combination is necessary new variant should be implemented.
+enum DNNLTypeMode { AllTypes, NoInt32, IntTypes, FloatTypes, ByteTypes };
+
+// Mapping of DNNLTypeMode variant into set of desired type combinations.
+// clang-format off
+static std::map<DNNLTypeMode, std::set<int>> DNNLTypeModeMap = {
+ {DNNLTypeMode::AllTypes,
+ {mshadow::kBfloat16, mshadow::kFloat32, mshadow::kInt32, mshadow::kInt8, mshadow::kUint8}},
+ {DNNLTypeMode::NoInt32,
+ {mshadow::kBfloat16, mshadow::kFloat32, mshadow::kInt8, mshadow::kUint8}},
+ {DNNLTypeMode::IntTypes,
+ {mshadow::kInt32, mshadow::kInt8, mshadow::kUint8}},
+ {DNNLTypeMode::FloatTypes,
+ {mshadow::kBfloat16, mshadow::kFloat32}},
+ {DNNLTypeMode::ByteTypes,
+ {mshadow::kInt8, mshadow::kUint8}}};
+// clang-format on
+
+template <DNNLTypeMode TypeMode>
+inline bool SupportDNNLType(int dtype) {
+ return DNNLTypeModeMap[TypeMode].count(dtype);
}
-static inline bool SupportDNNL(int dtype, const mxnet::TShape& shape) {
- int ndim = shape.ndim();
- if (ndim == 0 || shape.Size() == 0) {
- // DNNL currently does not support 0-dim Tensor and 0-size Tensor
- return false;
- }
- return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) &&
- (ndim == 1 || ndim == 2 || ndim == 4);
+// SupportDNNL variants:
+// Default variant used to check widest dimension possible
+// and all possible data types for given tensor.
+static inline bool SupportDNNL(const NDArray& tensor) {
+ return SupportDNNLType<AllTypes>(tensor.dtype()) && SupportDNNLShape<1, 12>(tensor.shape());
}
-static inline bool IsDNNLType(int dtype) {
- return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 || dtype == mshadow::kUint8 ||
- dtype == mshadow::kBfloat16;
+// Variant checking default shape (1,12) but gives possibility to check
+// any implemented type combination.
+template <DNNLTypeMode TypeMode>
+static inline bool SupportDNNL(const NDArray& tensor) {
+ return SupportDNNLType<TypeMode>(tensor.dtype()) && SupportDNNLShape<1, 12>(tensor.shape());
}
-static inline bool SupportDNNL(const NDArray& input) {
- return SupportDNNL(input.dtype(), input.shape()) && SupportStorageDNNL(input.storage_type());
+// Variant with possibility to check arbitrary shapes and type combination.
+template <int MinNdim, int MaxNdim, DNNLTypeMode TypeMode>
+static inline bool SupportDNNL(const NDArray& tensor) {
+ return SupportDNNLType<TypeMode>(tensor.dtype()) &&
+ SupportDNNLShape<MinNdim, MaxNdim>(tensor.shape());
+}
+
+// Variant checking multiple inputs at the same time
+// with possibility to check their types independently
+// or check if those types are the same.
+enum DNNLTensorsDtypes { AllSame = 0, Mixed = 1 };
+
+template <int MinNdim, int MaxNdim, DNNLTypeMode TypeMode, DNNLTensorsDtypes MixedTensors>
+static inline bool SupportDNNL(const std::vector<NDArray>& inputs) {
+ int dtype = MixedTensors ? -1 : inputs[0].dtype();
+ if (!SupportDNNLType<TypeMode>(dtype)) {
+ return false;
+ }
+ for (NDArray input : inputs) {
+ if (dtype == -1) {
+ if (!SupportDNNL<MinNdim, MaxNdim, TypeMode>(input))
+ return false;
+ } else {
+ if (input.dtype() != dtype && !SupportDNNLShape<MinNdim, MaxNdim>(input.shape()))
+ return false;
+ }
+ }
+ return true;
+}
+
+template <DNNLTypeMode TypeMode, DNNLTensorsDtypes MixedTensors>
+static inline bool SupportDNNL(const std::vector<NDArray>& inputs) {
+ return SupportDNNL<1, 12, TypeMode, MixedTensors>(inputs);
}
static inline bool SupportDNNLQuantize(const int out_type) {
@@ -180,37 +230,39 @@ void* AlignMem(void* mem, size_t size, size_t alignment, size_t* space);
namespace op {
struct ActivationParam;
-struct LeakyReLUParam;
struct ConvolutionParam;
struct DeconvolutionParam;
-struct SoftmaxParam;
+struct LayerNormParam;
+struct LeakyReLUParam;
struct MaskedSoftmaxParam;
-struct SoftmaxOutputParam;
struct ReshapeParam;
-struct LayerNormParam;
+struct SliceParam;
+struct SoftmaxParam;
+struct SoftmaxOutputParam;
bool SupportDNNLAct(const ActivationParam& param);
bool SupportDNNLAct(const ActivationParam& param, const NDArray& input);
-bool SupportDNNLLeakyRelu(const LeakyReLUParam& param);
-bool SupportDNNLLeakyRelu(const LeakyReLUParam& param, const NDArray& input);
-bool SupportQuantizedDNNLAct(const ActivationParam& param);
+bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs);
+bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
+bool SupportDNNLConcat(const std::vector<NDArray>& arrs);
bool SupportDNNLConv(const ConvolutionParam& params, const NDArray& input);
bool SupportDNNLDeconv(const DeconvolutionParam& params, const NDArray& input);
-bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& input, const NDArray& output);
-bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& input, const NDArray& output);
-bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param,
- const std::vector<NDArray>& input,
- const NDArray& output);
-bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param);
-bool SupportDNNLTranspose(const NDArray& data);
-bool SupportDNNLDot(const std::vector<NDArray>& inputs, const NDArray& output);
-bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
+bool SupportDNNLDot(const std::vector<NDArray>& inputs);
+bool SupportDNNLEltwise(const NDArray& input);
+bool SupportDNNLFC(const NDArray& input);
bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs);
-bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
-bool SupportDNNLSplit(const NDArray& input);
-bool SupportDNNLStack(const std::vector<NDArray>& inputs);
-bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
-bool SupportDNNLEltwise(const NDArray& input, const NDArray& output);
+bool SupportDNNLLeakyRelu(const LeakyReLUParam& param);
+bool SupportDNNLLeakyRelu(const LeakyReLUParam& param, const NDArray& input);
+bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& input);
+bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param, const std::vector<NDArray>& input);
bool SupportDNNLPower(const NDArray& input);
+bool SupportDNNLQuantizedAct(const ActivationParam& param);
+bool SupportDNNLReshape(const NDArray& input);
+bool SupportDNNLSlice(const SliceParam& param, const NDArray& input, const NDArray& output);
+bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& input);
+bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param, const NDArray& input);
+bool SupportDNNLStack(const std::vector<NDArray>& inputs);
+bool SupportDNNLSum(const std::vector<NDArray>& inputs);
+
void DNNLMemorySum(const dnnl::memory& arr1, const dnnl::memory& arr2, const dnnl::memory& out);
} // namespace op
@@ -445,9 +497,9 @@ class TmpMemMgr {
typedef std::unordered_map<int, dnnl::memory> dnnl_args_map_t;
class DNNLStream {
- std::vector<std::pair<dnnl::primitive, dnnl_args_map_t> > net_prim_args;
+ std::vector<std::pair<dnnl::primitive, dnnl_args_map_t>> net_prim_args;
// Here we hold all memory related to the operators in the stream.
- std::vector<std::shared_ptr<const dnnl::memory> > mem_holder;
+ std::vector<std::shared_ptr<const dnnl::memory>> mem_holder;
dnnl::stream s;
public:
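To illustrate the dispatch pattern introduced in dnnl_base-inl.h above, here is a minimal standalone sketch. It is not taken from the MXNet sources: a stand-in Tensor struct replaces mxnet::NDArray and the helper names (SupportType, SupportShape, Support) are hypothetical, but the composition of a dtype-set check with an ndim-range check mirrors the new SupportDNNL<MinNdim, MaxNdim, TypeMode> templates.
#include <cassert>
#include <cstdint>
#include <map>
#include <set>
#include <vector>
// Standalone sketch only: Tensor is a stand-in for mxnet::NDArray.
enum DType { kFloat32, kBfloat16, kInt32, kInt8, kUint8 };
enum TypeMode { AllTypes, NoInt32, IntTypes, FloatTypes, ByteTypes };
struct Tensor {
  DType dtype;
  std::vector<int64_t> shape;
  int64_t Size() const {
    if (shape.empty()) return 0;
    int64_t s = 1;
    for (int64_t d : shape) s *= d;
    return s;
  }
};
// Each TypeMode names the set of dtypes an operator accepts.
static const std::map<TypeMode, std::set<int>> kTypeModeMap = {
    {AllTypes, {kBfloat16, kFloat32, kInt32, kInt8, kUint8}},
    {NoInt32, {kBfloat16, kFloat32, kInt8, kUint8}},
    {IntTypes, {kInt32, kInt8, kUint8}},
    {FloatTypes, {kBfloat16, kFloat32}},
    {ByteTypes, {kInt8, kUint8}}};
template <TypeMode Mode>
bool SupportType(int dtype) {
  return kTypeModeMap.at(Mode).count(dtype) != 0;
}
template <int MinNdim, int MaxNdim>
bool SupportShape(const Tensor& t) {
  const int ndim = static_cast<int>(t.shape.size());
  return ndim >= MinNdim && ndim <= MaxNdim && t.Size() != 0;
}
// Mirrors SupportDNNL<MinNdim, MaxNdim, TypeMode>(const NDArray&):
// a dtype-set test combined with an ndim-range / non-empty test.
template <int MinNdim, int MaxNdim, TypeMode Mode>
bool Support(const Tensor& t) {
  return SupportType<Mode>(t.dtype) && SupportShape<MinNdim, MaxNdim>(t);
}
int main() {
  Tensor conv_input{kFloat32, {8, 3, 224, 224}};  // 4-D NCHW tensor
  Tensor empty{kFloat32, {}};                     // 0-dim tensor is rejected
  assert((Support<3, 5, AllTypes>(conv_input)));  // analogous to SupportDNNLConv
  assert((!Support<1, 12, FloatTypes>(empty)));
  return 0;
}
In the refactored operators below this collapses the per-operator dtype/ndim boilerplate into a single call, e.g. SupportDNNL<3, 5, DNNLTypeMode::FloatTypes>(input) in dnnl_pooling-inl.h.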
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot.cc b/src/operator/nn/dnnl/dnnl_batch_dot.cc
index 4534feafd1..a40d55621c 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot.cc
+++ b/src/operator/nn/dnnl/dnnl_batch_dot.cc
@@ -32,11 +32,10 @@ namespace op {
DMLC_REGISTER_PARAMETER(DNNLDotParam);
-bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output) {
- return inputs[DotIn::lhs].shape().Size() != 0 && inputs[DotIn::rhs].shape().Size() != 0 &&
- output.shape().Size() != 0 &&
- (inputs[DotIn::lhs].dtype() == mshadow::kFloat32 ||
- inputs[DotIn::lhs].dtype() == mshadow::kBfloat16);
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_matmul.html
+bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs) {
+ return SupportDNNL<2, 12, DNNLTypeMode::FloatTypes>(inputs[DotIn::lhs]) &&
+ SupportDNNL<2, 12, DNNLTypeMode::FloatTypes>(inputs[DotIn::rhs]);
}
DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DNNLDotParam& param,
diff --git a/src/operator/nn/dnnl/dnnl_binary.cc b/src/operator/nn/dnnl/dnnl_binary.cc
index 2ec74f41a9..75c4805fb7 100644
--- a/src/operator/nn/dnnl/dnnl_binary.cc
+++ b/src/operator/nn/dnnl/dnnl_binary.cc
@@ -63,15 +63,10 @@ void DNNLBinaryOpFwd::Execute(const std::vector<NDArray>& inputs,
DNNLStream::Get()->Submit();
}
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_binary.html
bool SupportDNNLBinary(const std::vector<NDArray>& inputs) {
- auto dtype_0 = inputs[0].dtype();
- auto dtype_1 = inputs[1].dtype();
- auto ndim_0 = inputs[0].shape().ndim();
- auto ndim_1 = inputs[1].shape().ndim();
- return ndim_0 >= 1 && ndim_0 <= 6 && ndim_1 >= 1 && ndim_1 <= 6 &&
- inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
- (dtype_0 == mshadow::kFloat32 || dtype_0 == mshadow::kBfloat16) &&
- (dtype_1 == mshadow::kFloat32 || dtype_1 == mshadow::kBfloat16);
+ return SupportDNNL<DNNLTypeMode::FloatTypes>(inputs[1]) &&
+ SupportDNNL<DNNLTypeMode::FloatTypes>(inputs[0]);
}
} // namespace op
diff --git a/src/operator/nn/dnnl/dnnl_convolution.cc b/src/operator/nn/dnnl/dnnl_convolution.cc
index ca6effb791..a38a9377fe 100644
--- a/src/operator/nn/dnnl/dnnl_convolution.cc
+++ b/src/operator/nn/dnnl/dnnl_convolution.cc
@@ -36,10 +36,11 @@ namespace op {
DMLC_REGISTER_PARAMETER(DNNLConvParam);
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_convolution.html
bool SupportDNNLConv(const ConvolutionParam& params, const NDArray& input) {
if (params.kernel.ndim() > 3 || params.kernel.ndim() == 0)
return false;
- return IsDNNLType(input.dtype()) && input.shape().ndim() >= 3 && input.shape().ndim() <= 5;
+ return SupportDNNL<3, 5, DNNLTypeMode::AllTypes>(input);
}
std::shared_ptr<dnnl::convolution_forward::primitive_desc> GetConvFwdImpl(
diff --git a/src/operator/nn/dnnl/dnnl_deconvolution.cc b/src/operator/nn/dnnl/dnnl_deconvolution.cc
index 79e3229963..f0252a7384 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution.cc
+++ b/src/operator/nn/dnnl/dnnl_deconvolution.cc
@@ -29,10 +29,10 @@
namespace mxnet {
namespace op {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_convolution.html
bool SupportDNNLDeconv(const DeconvolutionParam& params, const NDArray& input) {
return params.kernel.ndim() >= 1 && params.kernel.ndim() <= 3 &&
- input.shape().ndim() == (params.kernel.ndim() + 2) &&
- (input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16);
+ SupportDNNL<3, 5, DNNLTypeMode::FloatTypes>(input);
}
DNNLDeconvFwd::Tensors::Tensors(const bool no_bias,
diff --git a/src/operator/nn/dnnl/dnnl_dot.cc b/src/operator/nn/dnnl/dnnl_dot.cc
index 39786462e7..30468ce22a 100644
--- a/src/operator/nn/dnnl/dnnl_dot.cc
+++ b/src/operator/nn/dnnl/dnnl_dot.cc
@@ -31,15 +31,14 @@
namespace mxnet {
namespace op {
-bool SupportDNNLDot(const std::vector<NDArray>& inputs, const NDArray& output) {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_matmul.html
+bool SupportDNNLDot(const std::vector<NDArray>& inputs) {
#if MXNET_USE_BLAS_MKL == 1
return false;
#endif
- return inputs[DotIn::lhs].shape().Size() > 1 && inputs[DotIn::rhs].shape().Size() > 1 &&
- inputs[DotIn::lhs].shape().ndim() > 0 && inputs[DotIn::rhs].shape().ndim() > 0 &&
- output.shape().Size() != 0 && output.shape().ndim() > 0 && output.shape().ndim() <= 12 &&
- (inputs[DotIn::lhs].dtype() == mshadow::kFloat32 ||
- inputs[DotIn::lhs].dtype() == mshadow::kBfloat16);
+ // Remove cases where ndim of inputs is equal to 1, because output will be scalar in this case
+ return SupportDNNL<2, 12, DNNLTypeMode::FloatTypes>(inputs[DotIn::lhs]) &&
+ SupportDNNL<2, 12, DNNLTypeMode::FloatTypes>(inputs[DotIn::rhs]);
}
DNNLDotFwd& DNNLDotFwd::GetCached(const DotParam& param,
diff --git a/src/operator/nn/dnnl/dnnl_eltwise.cc b/src/operator/nn/dnnl/dnnl_eltwise.cc
index eb51beb526..2fc2edae10 100644
--- a/src/operator/nn/dnnl/dnnl_eltwise.cc
+++ b/src/operator/nn/dnnl/dnnl_eltwise.cc
@@ -28,13 +28,9 @@
namespace mxnet {
namespace op {
-bool SupportDNNLEltwise(const NDArray& input, const NDArray& output) {
- auto checkTensor = [](const NDArray& tensor) {
- return (tensor.dtype() == mshadow::kFloat32 || tensor.dtype() == mshadow::kBfloat16) &&
- tensor.shape().ndim() > 0 && tensor.shape().ndim() <= 12 && tensor.shape().Size() > 0 &&
- SupportStorageDNNL(tensor.storage_type());
- };
- return checkTensor(input) && checkTensor(output);
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_eltwise.html
+bool SupportDNNLEltwise(const NDArray& input) {
+ return SupportDNNL<DNNLTypeMode::FloatTypes>(input);
}
DNNLEltwiseFwd& DNNLEltwiseFwd::GetCached(const NDArray& input,
diff --git a/src/operator/nn/dnnl/dnnl_layer_norm.cc b/src/operator/nn/dnnl/dnnl_layer_norm.cc
index 65eb0a3e59..66c3c44895 100644
--- a/src/operator/nn/dnnl/dnnl_layer_norm.cc
+++ b/src/operator/nn/dnnl/dnnl_layer_norm.cc
@@ -29,6 +29,7 @@
namespace mxnet {
namespace op {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_layer_normalization.html
bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs) {
const mxnet::TShape& shape = inputs[layernorm::kData].shape();
@@ -40,13 +41,10 @@ bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray
return shape.Size() / shape[0] >= shapeLimit && shape[0] >= shapeLimit;
};
- return (ShapeBetterForDNNL(shape) &&
- (GetRealAxis(param.axis, shape.ndim()) == shape.ndim() - 1) && (shape.ndim() >= 2) &&
- (shape.ndim() <= 5) &&
- (inputs[layernorm::kData].dtype() == mshadow::kFloat32 ||
- inputs[layernorm::kData].dtype() == mshadow::kBfloat16) &&
- inputs[layernorm::kGamma].dtype() == mshadow::kFloat32 &&
- inputs[layernorm::kBeta].dtype() == mshadow::kFloat32);
+ return (ShapeBetterForDNNL(shape) && GetRealAxis(param.axis, shape.ndim()) == shape.ndim() - 1) &&
+ SupportDNNL<2, 5, DNNLTypeMode::FloatTypes>(inputs[layernorm::kData]) &&
+ inputs[layernorm::kGamma].dtype() == mshadow::kFloat32 &&
+ inputs[layernorm::kBeta].dtype() == mshadow::kFloat32;
}
void DNNLLayerNormForward(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/dnnl/dnnl_log_softmax.cc b/src/operator/nn/dnnl/dnnl_log_softmax.cc
index 1559ee347b..e694363836 100644
--- a/src/operator/nn/dnnl/dnnl_log_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_log_softmax.cc
@@ -51,21 +51,15 @@ static dnnl::logsoftmax_backward::primitive_desc GetLogSoftmaxBwdPd(
return dnnl::logsoftmax_backward::primitive_desc(desc, cpu_engine, hint_fwd_pd);
}
-bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& data, const NDArray& output) {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_logsoftmax.html
+bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& data) {
const int ndim = data.shape().ndim();
- const int in_dtype = data.dtype();
- const int out_dtype = output.dtype();
- const int axis = CheckAxis(param.axis, ndim);
+ const int out_dtype = param.dtype.has_value() ? param.dtype.value() : data.dtype();
// DNNL does not support temperature argument in their log_softmax function
// now. Need update this once they start to support it.
// Currently, DNNL shows bad performance when log_softmax is not performed on the last dimension
- if (data.shape().Size() == 0 || data.shape().ndim() == 0 || param.temperature.has_value() ||
- in_dtype != mshadow::kFloat32 || in_dtype != out_dtype || axis != (ndim - 1)) {
- return false;
- }
-
- // only supports ndim = 1, 2, 3, 4 for now
- return (ndim >= 1 && ndim <= 4);
+ return !param.temperature.has_value() && CheckAxis(param.axis, ndim) == (ndim - 1) &&
+ SupportDNNL<1, 4, DNNLTypeMode::FloatTypes>(data) && out_dtype == data.dtype();
}
class DNNLLogSoftmaxFwd {
diff --git a/src/operator/nn/dnnl/dnnl_masked_softmax.cc b/src/operator/nn/dnnl/dnnl_masked_softmax.cc
index a2a9c5835d..d47b818a59 100644
--- a/src/operator/nn/dnnl/dnnl_masked_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_masked_softmax.cc
@@ -28,16 +28,15 @@
namespace mxnet {
namespace op {
-bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param,
- const std::vector<NDArray>& inputs,
- const NDArray& output) {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_softmax.html
+bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param, const std::vector<NDArray>& inputs) {
CHECK_EQ(inputs.size(), 2);
const auto mask = inputs[1];
SoftmaxParam softmax_param;
softmax_param.axis = param.axis;
softmax_param.dtype = inputs[0].dtype();
softmax_param.temperature = param.temperature;
- return mask.dtype() == mshadow::kBool && SupportDNNLSoftmax(softmax_param, inputs[0], output);
+ return mask.dtype() == mshadow::kBool && SupportDNNLSoftmax(softmax_param, inputs[0]);
}
inline static dnnl::memory::dims GetOneDNNDims(const NDArray& arr) {
diff --git a/src/operator/nn/dnnl/dnnl_pooling-inl.h b/src/operator/nn/dnnl/dnnl_pooling-inl.h
index 77ad9baf74..6efc8210da 100644
--- a/src/operator/nn/dnnl/dnnl_pooling-inl.h
+++ b/src/operator/nn/dnnl/dnnl_pooling-inl.h
@@ -127,25 +127,19 @@ inline bool SupportDNNLPooling(const PoolingParam& param) {
param.layout.value() == mshadow::kNCDHW));
}
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_pooling.html
inline bool SupportDNNLPooling(const PoolingParam& param, const NDArray& input) {
- const auto dshape = input.shape();
- const auto ndim = dshape.ndim();
- const auto dtype = input.dtype();
-
- if (!(SupportStorageDNNL(input.storage_type()) && (ndim == 3 || ndim == 4 || ndim == 5) &&
- (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16)))
- return false;
-
- if (!SupportDNNLPooling(param))
+ if (!SupportDNNL<3, 5, DNNLTypeMode::FloatTypes>(input) || !SupportDNNLPooling(param))
return false;
if (param.pooling_convention == pool_enum::kValid) {
return true;
} else {
+ const auto dshape = input.shape();
if (param.pool_type == pool_enum::kAvgPooling) {
// dnnl works differently when padding is asymmetric, so let's skip this case.
bool is_symmetric = true;
- switch (ndim) {
+ switch (dshape.ndim()) {
case 5:
is_symmetric =
is_symmetric &&
diff --git a/src/operator/nn/dnnl/dnnl_reduce.cc b/src/operator/nn/dnnl/dnnl_reduce.cc
index a9be0af295..36aeec15b3 100644
--- a/src/operator/nn/dnnl/dnnl_reduce.cc
+++ b/src/operator/nn/dnnl/dnnl_reduce.cc
@@ -83,12 +83,10 @@ mxnet::Tuple<int> CanonicalizeAndSortAxes(const NDArray& input,
return axes;
}
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_reduction.html
bool SupportDNNLReduceImpl(const NumpyReduceAxesParam& param,
const NDArray& input,
const NDArray& output) {
- int in_ndim = input.shape().ndim();
- int out_size = output.shape().Size();
- int in_size = input.shape().Size();
bool param_supported = true;
if (param.axis.has_value()) {
auto axes = CanonicalizeAndSortAxes(input, param, param.axis.value());
@@ -111,10 +109,9 @@ bool SupportDNNLReduceImpl(const NumpyReduceAxesParam& param,
}
// initial value not supported by oneDNN
param_supported = param_supported && !param.initial.has_value();
- return param_supported &&
- (input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16) &&
- (output.dtype() == mshadow::kFloat32 || output.dtype() == mshadow::kBfloat16) &&
- in_ndim >= 1 && out_size > 0 && in_size > 1;
+ // oneDNN does not support reduction of tensors with size equal to 1
+ return param_supported && input.shape().Size() > 1 &&
+ SupportDNNL<DNNLTypeMode::FloatTypes>(input);
}
void DNNLReduceForwardImpl(const NumpyReduceAxesParam& param,
diff --git a/src/operator/nn/dnnl/dnnl_reshape.cc b/src/operator/nn/dnnl/dnnl_reshape.cc
index d1270a3a8c..5a2d9c1173 100644
--- a/src/operator/nn/dnnl/dnnl_reshape.cc
+++ b/src/operator/nn/dnnl/dnnl_reshape.cc
@@ -31,11 +31,9 @@
namespace mxnet {
namespace op {
-bool SupportDNNLReshape(const NDArray& input, const NDArray& output) {
- const int input_ndims = input.shape().ndim();
- const int output_ndims = output.shape().ndim();
- return input.shape().Size() > 0 && input_ndims >= 1 && input_ndims <= 6 && output_ndims >= 1 &&
- output_ndims <= 6 && IsDNNLType(input.dtype());
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_reorder.html
+bool SupportDNNLReshape(const NDArray& input) {
+ return SupportDNNL(input) && input.shape().Size() != 1;
}
DNNLReshapeFwd::DNNLReshapeFwd(const OpReqType& req, const NDArray& input, const NDArray& output) {
diff --git a/src/operator/nn/dnnl/dnnl_rnn-inl.h b/src/operator/nn/dnnl/dnnl_rnn-inl.h
index 6165dfaeb4..b14a657cad 100644
--- a/src/operator/nn/dnnl/dnnl_rnn-inl.h
+++ b/src/operator/nn/dnnl/dnnl_rnn-inl.h
@@ -558,6 +558,7 @@ class DNNLRnnOp {
const std::vector<NDArray>& outputs);
};
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_rnn.html
inline bool SupportDNNLRnn(const int input_dtype) {
if (input_dtype == mshadow::kFloat32 && dmlc::GetEnv("MXNET_USE_ONEDNN_RNN", 1)) {
return true;
diff --git a/src/operator/nn/dnnl/dnnl_softmax.cc b/src/operator/nn/dnnl/dnnl_softmax.cc
index 73321a1854..317c5790d0 100644
--- a/src/operator/nn/dnnl/dnnl_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_softmax.cc
@@ -29,23 +29,13 @@
namespace mxnet {
namespace op {
-bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& data, const NDArray& output) {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_softmax.html
+bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& data) {
const int ndim = data.shape().ndim();
- const int in_size = data.shape().Size();
- const int in_dtype = data.dtype();
- const int out_dtype = output.dtype();
- const int axis = CheckAxis(param.axis, ndim);
-
- if (param.temperature.has_value() && param.temperature.value() == 0.0) {
- return false;
- }
-
- if (in_dtype != mshadow::kFloat32 || in_dtype != out_dtype || axis != (ndim - 1)) {
- return false;
- }
-
- // Supports ndim up to 6
- return (ndim >= 1 && ndim <= 6 && in_size != 0);
+ const int out_dtype = param.dtype.has_value() ? param.dtype.value() : data.dtype();
+ return !(param.temperature.has_value() && param.temperature.value() != 0.0) &&
+ CheckAxis(param.axis, ndim) == (ndim - 1) && SupportDNNL<DNNLTypeMode::NoInt32>(data) &&
+ out_dtype == data.dtype();
}
void DNNLSoftmaxForward(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/dnnl/dnnl_softmax_output.cc b/src/operator/nn/dnnl/dnnl_softmax_output.cc
index ba79effc5b..2f1e0c9161 100644
--- a/src/operator/nn/dnnl/dnnl_softmax_output.cc
+++ b/src/operator/nn/dnnl/dnnl_softmax_output.cc
@@ -90,9 +90,9 @@ static DNNLSoftmaxOutputFwd& GetSoftmaxOutputForward(const SoftmaxOutputParam& p
return it->second;
}
-// This is only used for forward. For backward ,need double check compatibility
-bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param) {
- return param.multi_output ? false : true;
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_softmax.html
+bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param, const NDArray& input) {
+ return SupportDNNL(input) && !param.multi_output;
}
void DNNLSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/dnnl/dnnl_split.cc b/src/operator/nn/dnnl/dnnl_split.cc
index e13b45a259..ff2bec3af2 100644
--- a/src/operator/nn/dnnl/dnnl_split.cc
+++ b/src/operator/nn/dnnl/dnnl_split.cc
@@ -29,12 +29,6 @@
namespace mxnet {
namespace op {
-bool SupportDNNLSplit(const NDArray& input) {
- static const std::set<int> supported_dtypes = {
- mshadow::kFloat32, mshadow::kBfloat16, mshadow::kInt32, mshadow::kInt8, mshadow::kUint8};
- return supported_dtypes.count(input.dtype());
-}
-
void DNNLSplitForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
diff --git a/src/operator/nn/dnnl/dnnl_stack.cc b/src/operator/nn/dnnl/dnnl_stack.cc
index 981a4f87c6..3aca884121 100644
--- a/src/operator/nn/dnnl/dnnl_stack.cc
+++ b/src/operator/nn/dnnl/dnnl_stack.cc
@@ -32,27 +32,9 @@
namespace mxnet {
namespace op {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_reorder.html
bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
- if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != mshadow::kBfloat16) {
- return false;
- }
-
- int src_dtype = inputs[0].dtype();
- for (const auto& arr : inputs) {
- if (arr.dtype() != src_dtype) {
- return false;
- }
- // Do not support zero-size tensors.
- if (arr.shape().Size() == 0) {
- return false;
- }
-
- int ndim = arr.shape().ndim();
- if (ndim <= 0) {
- return false;
- }
- }
- return true;
+ return SupportDNNL<DNNLTypeMode::FloatTypes, DNNLTensorsDtypes::AllSame>(inputs);
}
void DNNLStackForward(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/dnnl/dnnl_transpose-inl.h b/src/operator/nn/dnnl/dnnl_transpose-inl.h
index aa6e071da7..ef3601a01a 100644
--- a/src/operator/nn/dnnl/dnnl_transpose-inl.h
+++ b/src/operator/nn/dnnl/dnnl_transpose-inl.h
@@ -33,8 +33,6 @@
namespace mxnet {
namespace op {
-bool SupportDNNLTranspose(const NDArray& data);
-
class DNNLTransposeFwd {
public:
std::shared_ptr<dnnl::memory> data_;
diff --git a/src/operator/nn/dnnl/dnnl_transpose.cc b/src/operator/nn/dnnl/dnnl_transpose.cc
index 342a79682e..2894fc76f2 100644
--- a/src/operator/nn/dnnl/dnnl_transpose.cc
+++ b/src/operator/nn/dnnl/dnnl_transpose.cc
@@ -31,16 +31,6 @@
namespace mxnet {
namespace op {
-bool SupportDNNLTranspose(const NDArray& data) {
- auto data_ndim = data.shape().ndim();
-
- if (data_ndim > 4 || data_ndim == 0 || data.shape().Size() == 0 ||
- !(data.dtype() == mshadow::kFloat32 || data.dtype() == mshadow::kBfloat16))
- return false;
-
- return true;
-}
-
typedef ParamOpSign<NumpyTransposeParam> DNNLTransposeSignature;
DNNLTransposeFwd::DNNLTransposeFwd(const NumpyTransposeParam& param, const NDArray& data) {
diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc
index 7d5ca4ad6b..8e225b471e 100644
--- a/src/operator/nn/dnnl/dnnl_where.cc
+++ b/src/operator/nn/dnnl/dnnl_where.cc
@@ -32,16 +32,9 @@
namespace mxnet {
namespace op {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_binary.html
bool SupportDNNLWhere(const std::vector<NDArray>& inputs) {
- static const std::set<int> supported_dtypes = {
- mshadow::kFloat32, mshadow::kBfloat16, mshadow::kInt8, mshadow::kUint8};
- for (int i = 0; i < inputs.size(); ++i) {
- if (!supported_dtypes.count(inputs[i].dtype()) || inputs[i].shape().Size() <= 0 ||
- inputs[i].shape().ndim() <= 0) {
- return false;
- }
- }
- return true;
+ return SupportDNNL<DNNLTypeMode::NoInt32, DNNLTensorsDtypes::Mixed>(inputs);
}
void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 61e7ca5ea9..63c0be518f 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -31,11 +31,12 @@
namespace mxnet {
namespace op {
+#if MXNET_USE_ONEDNN == 1
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_inner_product.html
bool SupportDNNLFC(const NDArray& input) {
- int ndim = input.shape().ndim();
- return (input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16) &&
- (ndim >= 1 && ndim <= 4) && input.storage_type() == kDefaultStorage;
+ return SupportDNNL<2, 5, DNNLTypeMode::FloatTypes>(input);
}
+#endif
static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index e63f09c71f..9e36a28a7a 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -40,7 +40,7 @@ static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
- if (SupportDNNLLogSoftmax(param, inputs[0], outputs[0])) {
+ if (SupportDNNLLogSoftmax(param, inputs[0])) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(DNNLLogSoftmaxForward, attrs, ctx, inputs[0], req[0], outputs[0]);
auto fn = SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>;
@@ -56,7 +56,7 @@ static void LogSoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
- if (SupportDNNLLogSoftmax(param, inputs[1], outputs[0])) {
+ if (SupportDNNLLogSoftmax(param, inputs[1])) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(DNNLLogSoftmaxBackward, attrs, ctx, inputs, req, outputs);
auto fn = SoftmaxGradCompute<cpu, op::mshadow_op::left, mxnet_op::log_softmax_bwd>;
diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc
index c121be2720..f67456ab81 100644
--- a/src/operator/nn/lrn.cc
+++ b/src/operator/nn/lrn.cc
@@ -107,7 +107,7 @@ void LRNComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- if (SupportDNNL(inputs[0])) {
+ if (SupportDNNL<2, 5, DNNLTypeMode::FloatTypes>(inputs[0])) {
// We only need to test one output array.
DNNL_OPCHECK_INIT(false, 1, inputs, outputs);
DNNLRun(DNNLLRNForward, attrs, ctx, inputs[0], req[0], outputs[0]);
@@ -124,7 +124,7 @@ void LRNGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- if (SupportDNNL(inputs[0])) {
+ if (SupportDNNL<2, 5, DNNLTypeMode::FloatTypes>(inputs[0])) {
DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
DNNLRun(DNNLLRNBackward, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(LRNGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/nn/masked_softmax.cc b/src/operator/nn/masked_softmax.cc
index ee8c902833..fa17ec9f9f 100644
--- a/src/operator/nn/masked_softmax.cc
+++ b/src/operator/nn/masked_softmax.cc
@@ -43,7 +43,7 @@ static void MaskedSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
if (inputs[0].shape().Size() == 0U)
return;
const MaskedSoftmaxParam& param = nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
- if (SupportDNNLMaskedSoftmax(param, inputs, outputs[0])) {
+ if (SupportDNNLMaskedSoftmax(param, inputs)) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(DNNLMaskedSoftmaxForward, attrs, ctx, inputs, req, outputs);
@@ -61,7 +61,6 @@ inline static bool MaskedSoftmaxStorageType(const nnvm::NodeAttrs& attrs,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
- const MaskedSoftmaxParam& param = nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 0fb4e338cb..e1e696d7fc 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -41,7 +41,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
- if (SupportDNNLSoftmax(param, inputs[0], outputs[0])) {
+ if (SupportDNNLSoftmax(param, inputs[0])) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(DNNLSoftmaxForward, attrs, ctx, inputs[0], req[0], outputs[0]);
auto fn = SoftmaxCompute<cpu, mxnet_op::softmax_fwd>;
@@ -57,7 +57,7 @@ static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs&
attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
- if (SupportDNNLSoftmax(param, inputs[1], outputs[0])) {
+ if (SupportDNNLSoftmax(param, inputs[1])) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(DNNLSoftmaxBackward, attrs, ctx, inputs, req, outputs);
auto fn = SoftmaxGradCompute<cpu, op::mshadow_op::mul,
mxnet_op::softmax_bwd>;
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.h b/src/operator/numpy/np_broadcast_reduce_op_value.h
index b438a92872..e38fc32db9 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.h
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.h
@@ -199,7 +199,6 @@ static void DNNLReduceEx(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
- const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
if (SupportDNNLReduce<NumpyReduceAxesParam>(attrs, inputs[0], outputs[0])) {
DNNLRun(DNNLReduceForward<NumpyReduceAxesParam, reduction_alg>,
diff --git a/src/operator/numpy/np_dot_forward.cc b/src/operator/numpy/np_dot_forward.cc
index 870486022c..0253a16792 100644
--- a/src/operator/numpy/np_dot_forward.cc
+++ b/src/operator/numpy/np_dot_forward.cc
@@ -110,7 +110,7 @@ static void NumpyDotComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- if (SupportDNNLDot(inputs, outputs[0])) {
+ if (SupportDNNLDot(inputs)) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(DNNLDotForward<true>, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(NumpyDotForward<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 946527562a..7a3404df42 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -59,7 +59,7 @@ static void NumpyTransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
- if (SupportDNNLTranspose(inputs[0]) && req[0] == kWriteTo) {
+ if (SupportDNNL(inputs[0]) && req[0] == kWriteTo) {
DNNLRun(DNNLTransposeForward<NumpyTransposeParam>, attrs, ctx, inputs[0], req[0], outputs[0]);
return;
}
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc b/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
index 0b14a96777..ab6083586d 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
@@ -44,7 +44,7 @@ static void DNNLQuantizedFlattenForward(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- if (SupportDNNLReshape(inputs[0], outputs[0])) {
+ if (SupportDNNL(inputs[0])) {
DNNLRun(DNNLReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
} else {
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc b/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
index 344bda3f5e..57b987dfc6 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
@@ -37,7 +37,7 @@ static void DNNLQuantizedReshapeForward(const nnvm::NodeAttrs& attrs,
CHECK(inputs[0].dtype() == mshadow::kUint8 || inputs[0].dtype() == mshadow::kInt8)
<< "dnnl_quantized_reshape op only supports uint8 and int8 as input type";
- if (SupportDNNLReshape(inputs[0], outputs[0])) {
+ if (SupportDNNLReshape(inputs[0])) {
OpReqType reqType;
if (inputs[0].GetDNNLData()->get_data_handle() !=
outputs[0].GetDNNLData()->get_data_handle())
reqType = kWriteTo;
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_transpose.cc b/src/operator/quantization/dnnl/dnnl_quantized_transpose.cc
index 0a8e2aed1f..6be64c3254 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_transpose.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_transpose.cc
@@ -40,13 +40,9 @@ inline static bool QuantizedTransposeStorageType(const nnvm::NodeAttrs& attrs,
return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_reorder.html
bool SupportDNNLQuantizedTranspose(const NDArray& data) {
- auto data_ndim = data.shape().ndim();
-
- if (data_ndim > 4 || data_ndim == 0 || data.shape().Size() == 0)
- return false;
-
- return true;
+ return SupportDNNL<DNNLTypeMode::ByteTypes>(data);
}
typedef void (*TransposeFallbackFunAny)(const nnvm::NodeAttrs&,
const OpContext&,
diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc
index 52a9888ea8..cb0715f508 100644
--- a/src/operator/softmax_output.cc
+++ b/src/operator/softmax_output.cc
@@ -153,7 +153,7 @@ void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 2U);
const SoftmaxOutputParam& param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
- if (SupportDNNL(inputs[0]) && !ctx.is_train && SupportDNNLSoftmaxOutput(param)) {
+ if (SupportDNNLSoftmaxOutput(param, inputs[0]) && !ctx.is_train) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(DNNLSoftmaxOutputForward, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(SoftmaxOutputCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/subgraph/dnnl/dnnl_conv_property.h b/src/operator/subgraph/dnnl/dnnl_conv_property.h
index 3d814c0598..7d747994b3 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_conv_property.h
@@ -114,7 +114,7 @@ class SgDNNLConvSelector : public SubgraphSelector {
default:
if ((!disable_conv_act_) && node_name == "Activation") {
const ActivationParam& param = nnvm::get<ActivationParam>(new_node.attrs.parsed);
- if ((quantize_ && SupportQuantizedDNNLAct(param)) ||
+ if ((quantize_ && SupportDNNLQuantizedAct(param)) ||
(!quantize_ && SupportDNNLAct(param))) {
matched_list_.push_back(&new_node);
// not support conv+relu+sum yet.
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_property.h b/src/operator/subgraph/dnnl/dnnl_fc_property.h
index 902444354f..5b567ee652 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_property.h
@@ -94,7 +94,7 @@ class SgDNNLFCSelector : public SubgraphSelector {
// Currently, For INT8 FC fusion, only supports relu/bounded_relu(clip)/abs.
if (new_node.op() == Op::Get("Activation")) {
const ActivationParam& param = nnvm::get<ActivationParam>(new_node.attrs.parsed);
- if ((quantized_ && SupportQuantizedDNNLAct(param)) ||
+ if ((quantized_ && SupportDNNLQuantizedAct(param)) ||
(!quantized_ && SupportDNNLAct(param))) {
matched_list_.push_back(&new_node);
status_ = kSuccess;
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index 0d62cac808..60b55fb5f4 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -124,7 +124,7 @@ void DotForwardExDNNL(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- if (SupportDNNLDot(inputs, outputs[DotOut::out])) {
+ if (SupportDNNLDot(inputs)) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(DNNLDotForward<false>, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(DotForward_<cpu>, attrs, ctx, inputs, req, outputs);
@@ -138,7 +138,7 @@ static void BatchDotComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- if (SupportDNNLBatchDot(inputs, outputs[DotOut::out])) {
+ if (SupportDNNLBatchDot(inputs)) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(DNNLBatchDotForward<false>, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc
index ad29575cb3..2f03f90672 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -30,11 +30,12 @@
namespace mxnet {
namespace op {
-bool SupportDNNLSum(const NDArray& input) {
- int ndim = input.shape().ndim();
- return (input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16) &&
- (ndim >= 1 && ndim <= 4) && input.storage_type() == kDefaultStorage;
+#if MXNET_USE_ONEDNN == 1
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_eltwise.html
+bool SupportDNNLSum(const std::vector<NDArray>& inputs) {
+ return SupportDNNL(inputs[0]) && SupportDNNL(inputs[1]);
}
+#endif
static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -44,7 +45,7 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
#if MXNET_USE_ONEDNN == 1
- if (SupportDNNLSum(inputs[0]) && SupportDNNLSum(inputs[1])) {
+ if (SupportDNNLSum(inputs) && common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
DNNLRun(DNNLSumForward, attrs, ctx, inputs, req, outputs);
return;
} else if (inputs[0].storage_type() == kDefaultStorage &&
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index bc6011a97d..b675acf9f9 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -511,7 +511,7 @@ inline void EltwiseComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<mxnet::NDArray>& outputs) {
auto fallBackFunction =
computeMixed ? UnaryOp::ComputeMixedType<cpu, OP> : UnaryOp::Compute<cpu, OP>;
- if (SupportDNNLEltwise(inputs[0], outputs[0])) {
+ if (SupportDNNLEltwise(inputs[0])) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(
DNNLEltwiseForward<DNNLAlgorithm<OP>::value>, attrs, ctx, inputs[0], req[0], outputs[0]);
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 00bfe9bd51..2831789fe4 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -677,6 +677,7 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
}
// Currently DNNL only supports step = 1 or step has no value
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_reorder.html
inline bool SupportDNNLSlice(const SliceParam& param) {
if (param.step.ndim() == 0U)
return true;
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 15f131cfcd..303011f058 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -127,7 +127,7 @@ void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
// If inputs are supposed to be in DNNL format and
// DNNL support the data type or the shape. Then convert
// it to the output format and shape
- if (SupportDNNLReshape(inputs[0], outputs[0])) {
+ if (SupportDNNLReshape(inputs[0])) {
DNNLRun(DNNLReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
} else {
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
@@ -236,7 +236,7 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
// If inputs are supposed to be in DNNL format and
// DNNL support the data type or the shape. Then convert
// it to the output format and shape
- if (SupportDNNLReshape(inputs[0], outputs[0])) {
+ if (SupportDNNLReshape(inputs[0])) {
DNNLRun(DNNLReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
} else {
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
@@ -317,7 +317,7 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
- if (SupportDNNLTranspose(inputs[0]) && req[0] == kWriteTo) {
+ if (SupportDNNL(inputs[0]) && req[0] == kWriteTo) {
DNNLRun(DNNLTransposeForward<TransposeParam>, attrs, ctx, inputs[0], req[0], outputs[0]);
return;
}
@@ -422,7 +422,7 @@ static void ExpandDimEx(const nnvm::NodeAttrs& attrs,
// If inputs are supposed to be in DNNL format and
// DNNL support the data type or the shape. Then convert
// it to the output format and shape
- if (SupportDNNLReshape(inputs[0], outputs[0])) {
+ if (SupportDNNLReshape(inputs[0])) {
DNNLRun(DNNLReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
} else {
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
@@ -473,6 +473,12 @@ will return a new array with shape ``(2,1,3,4)``.
.add_argument("data", "NDArray-or-Symbol", "Source input")
.add_arguments(ExpandDimParam::__FIELDS__());
+#if MXNET_USE_ONEDNN == 1
+bool SupportDNNLSlice(const SliceParam& param, const NDArray& input, const NDArray& output) {
+ return SupportDNNLSlice(param) && SupportDNNL(input) && SupportDNNL(output);
+}
+#endif
+
void SliceExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
@@ -486,7 +492,7 @@ void SliceExCPU(const nnvm::NodeAttrs& attrs,
SliceCsrImpl<cpu>(param, ctx, inputs[0], req[0], outputs[0]);
#if MXNET_USE_ONEDNN == 1
} else if (in_stype == kDefaultStorage) {
- if (SupportDNNL(inputs[0])) {
+ if (SupportDNNLSlice(param, inputs[0], outputs[0])) {
DNNLRun(DNNLSlice, attrs, ctx, inputs[0], req[0], outputs[0]);
} else {
FallBackCompute(SliceOpForward<cpu>, attrs, ctx, inputs, req, outputs);
@@ -1185,7 +1191,7 @@ static void SplitForwardEx(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK(!inputs.empty());
- if (SupportDNNLSplit(inputs[0])) {
+ if (SupportDNNL(inputs[0])) {
DNNL_OPCHECK_INIT(/*is backward*/ false, outputs.size(), inputs, outputs);
DNNLRun(DNNLSplitForward, attrs, op_ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(SplitOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);