This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ef2be51265 Refactor SupportDNNL functions (#21032)
ef2be51265 is described below

commit ef2be5126592bb73bf3e42c29722b93b89b250dc
Author: AdamGrabowski <[email protected]>
AuthorDate: Thu Jun 23 09:41:21 2022 +0200

    Refactor SupportDNNL functions (#21032)
    
    * Refactor SupportDNNL functions
    
    * Fix SupportDNNL for quantized reshape
    
    * Update src/operator/nn/dnnl/dnnl_softmax_output.cc
    
    Co-authored-by: bgawrych <[email protected]>
    
    * Update src/operator/nn/dnnl/dnnl_reduce.cc
    
    Co-authored-by: bgawrych <[email protected]>
---
 src/ndarray/ndarray.cc                             |   4 +-
 src/operator/contrib/adaptive_avg_pooling.cc       |  22 ++--
 src/operator/contrib/batch_norm_relu.cc            |  19 +--
 src/operator/nn/batch_norm.cc                      |  19 +--
 src/operator/nn/concat.cc                          |  11 +-
 src/operator/nn/dnnl/dnnl_act.cc                   |  17 +--
 src/operator/nn/dnnl/dnnl_base-inl.h               | 142 ++++++++++++++-------
 src/operator/nn/dnnl/dnnl_batch_dot.cc             |   9 +-
 src/operator/nn/dnnl/dnnl_binary.cc                |  11 +-
 src/operator/nn/dnnl/dnnl_convolution.cc           |   3 +-
 src/operator/nn/dnnl/dnnl_deconvolution.cc         |   4 +-
 src/operator/nn/dnnl/dnnl_dot.cc                   |  11 +-
 src/operator/nn/dnnl/dnnl_eltwise.cc               |  10 +-
 src/operator/nn/dnnl/dnnl_layer_norm.cc            |  12 +-
 src/operator/nn/dnnl/dnnl_log_softmax.cc           |  16 +--
 src/operator/nn/dnnl/dnnl_masked_softmax.cc        |   7 +-
 src/operator/nn/dnnl/dnnl_pooling-inl.h            |  14 +-
 src/operator/nn/dnnl/dnnl_reduce.cc                |  11 +-
 src/operator/nn/dnnl/dnnl_reshape.cc               |   8 +-
 src/operator/nn/dnnl/dnnl_rnn-inl.h                |   1 +
 src/operator/nn/dnnl/dnnl_softmax.cc               |  22 +---
 src/operator/nn/dnnl/dnnl_softmax_output.cc        |   6 +-
 src/operator/nn/dnnl/dnnl_split.cc                 |   6 -
 src/operator/nn/dnnl/dnnl_stack.cc                 |  22 +---
 src/operator/nn/dnnl/dnnl_transpose-inl.h          |   2 -
 src/operator/nn/dnnl/dnnl_transpose.cc             |  10 --
 src/operator/nn/dnnl/dnnl_where.cc                 |  11 +-
 src/operator/nn/fully_connected.cc                 |   7 +-
 src/operator/nn/log_softmax.cc                     |   4 +-
 src/operator/nn/lrn.cc                             |   4 +-
 src/operator/nn/masked_softmax.cc                  |   3 +-
 src/operator/nn/softmax.cc                         |   4 +-
 src/operator/numpy/np_broadcast_reduce_op_value.h  |   1 -
 src/operator/numpy/np_dot_forward.cc               |   2 +-
 src/operator/numpy/np_matrix_op.cc                 |   2 +-
 .../quantization/dnnl/dnnl_quantized_flatten.cc    |   2 +-
 .../quantization/dnnl/dnnl_quantized_reshape.cc    |   2 +-
 .../quantization/dnnl/dnnl_quantized_transpose.cc  |   8 +-
 src/operator/softmax_output.cc                     |   2 +-
 src/operator/subgraph/dnnl/dnnl_conv_property.h    |   2 +-
 src/operator/subgraph/dnnl/dnnl_fc_property.h      |   2 +-
 src/operator/tensor/dot.cc                         |   4 +-
 src/operator/tensor/elemwise_binary_op_basic.cc    |  11 +-
 src/operator/tensor/elemwise_unary_op.h            |   2 +-
 src/operator/tensor/matrix_op-inl.h                |   1 +
 src/operator/tensor/matrix_op.cc                   |  18 ++-
 46 files changed, 230 insertions(+), 281 deletions(-)

diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 605e705ae7..902880fb1d 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1326,8 +1326,8 @@ inline void CopyFromToDnsImpl(const NDArray& from, const 
NDArray& to, RunContext
     TBlob tmp = to.data();
     ndarray::Copy<from_xpu, to_xpu>(from.data(), &tmp, from.ctx(), to.ctx(), 
ctx);
 #if MXNET_USE_ONEDNN == 1
-  } else if (SupportDNNL(from.dtype(), from.shape()) && 
SupportDNNL(to.dtype(), to.shape()) &&
-             from.ctx().dev_mask() == cpu::kDevMask && to.ctx().dev_mask() == 
cpu::kDevMask) {
+  } else if (SupportDNNL(from) && SupportDNNL(to) && from.ctx().dev_mask() == 
cpu::kDevMask &&
+             to.ctx().dev_mask() == cpu::kDevMask) {
     // If we copy data directly, we need to make sure both NDArrays are 
supported
     // by DNNL.
     auto from_mem = from.GetDNNLData();
diff --git a/src/operator/contrib/adaptive_avg_pooling.cc 
b/src/operator/contrib/adaptive_avg_pooling.cc
index c2614d9456..69d43e0060 100644
--- a/src/operator/contrib/adaptive_avg_pooling.cc
+++ b/src/operator/contrib/adaptive_avg_pooling.cc
@@ -172,10 +172,11 @@ void AdaptiveAvgPoolUpdateOutput(mshadow::Stream<cpu>* s,
 }
 
 #if MXNET_USE_ONEDNN == 1
-bool SupportDNNLAveragePooling(const NDArray& in_data, const NDArray& 
out_data) {
-  for (int64_t idx = 2; idx < in_data.shape().ndim(); ++idx) {
-    const int s1 = in_data.shape()[idx];
-    const int s2 = out_data.shape()[idx];
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_pooling.html
+bool SupportDNNLAveragePooling(const NDArray& input, const NDArray& output) {
+  for (int64_t idx = 2; idx < input.shape().ndim(); ++idx) {
+    const int s1 = input.shape()[idx];
+    const int s2 = output.shape()[idx];
     if (s2 == 0) {
       return false;
     }
@@ -183,17 +184,18 @@ bool SupportDNNLAveragePooling(const NDArray& in_data, 
const NDArray& out_data)
       return false;
     }
   }
-  const int IH         = in_data.shape()[2];
-  const int IW         = in_data.shape()[3];
-  const int OH         = out_data.shape()[2];
-  const int OW         = out_data.shape()[3];
+  const int IH         = input.shape()[2];
+  const int IW         = input.shape()[3];
+  const int OH         = output.shape()[2];
+  const int OW         = output.shape()[3];
   const int strides_H  = ((IH << 1) / OH) - (IH / OH);
   const int strides_W  = ((IW << 1) / OW) - (IW / OW);
   const int kernel_H   = DIV_ROUND_UP((IH << 1) / OH, 1) - (IH / OH);
   const int kernel_W   = DIV_ROUND_UP((IW << 1) / OW, 1) - (IW / OW);
   const int pad_l_top  = (strides_H * (OH - 1) + kernel_H - IH) / 2;
   const int pad_l_left = (strides_W * (OW - 1) + kernel_W - IW) / 2;
-  return pad_l_top == 0 && pad_l_left == 0;
+
+  return SupportDNNL<3, 5, DNNLTypeMode::AllTypes>(input) && pad_l_top == 0 && 
pad_l_left == 0;
 }
 
 void AdaptiveAvgPoolOpBackwardExCPU(const nnvm::NodeAttrs& attrs,
@@ -236,7 +238,7 @@ void AdaptiveAvgPoolComputeExCPU(const nnvm::NodeAttrs& 
attrs,
   oneDNN doesn't support adaptive pooling.
   Fallback is needed when padding is not equal 0;
   */
-  if (SupportDNNL(inputs[0]) && SupportDNNLAveragePooling(inputs[0], 
outputs[0])) {
+  if (SupportDNNLAveragePooling(inputs[0], outputs[0])) {
     DNNL_OPCHECK_INIT(false, 1, inputs, outputs);
     DNNLRun(DNNLPoolingCompute, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(PoolingCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/contrib/batch_norm_relu.cc 
b/src/operator/contrib/batch_norm_relu.cc
index 3d0cbb1493..a0f158f42b 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -130,16 +130,9 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& 
attrs,
 }
 
 #if MXNET_USE_ONEDNN == 1
-static inline bool SupportDNNLBNReLU(const NDArray& input, const 
BatchNormParam& param) {
-  if (mxnet::op::batchnorm::disable_mkl)
-    return false;
-  const mxnet::TShape shape = input.shape();
-  const int ndim            = shape.ndim();
-  if (ndim == 0 || shape.Size() == 0)
-    return false;
-  const int dtype = input.dtype();
-  return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) &&
-         SupportStorageDNNL(input.storage_type());
+// Support for 
https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_batch_normalization.html
+static inline bool SupportDNNLBNReLU(const NDArray& input) {
+  return SupportDNNL<2, 12, DNNLTypeMode::FloatTypes>(input) && 
!mxnet::op::batchnorm::disable_mkl;
 }
 
 void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
@@ -148,8 +141,7 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs& 
attrs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 5U);
-  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  if (SupportDNNLBNReLU(inputs[0], param)) {
+  if (SupportDNNLBNReLU(inputs[0])) {
     CHECK_GT(outputs.size(), 3U);
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
@@ -165,8 +157,7 @@ void BatchNormWithReLUGradComputeExCPU(const 
nnvm::NodeAttrs& attrs,
                                        const std::vector<NDArray>& inputs,
                                        const std::vector<OpReqType>& req,
                                        const std::vector<NDArray>& outputs) {
-  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  if (SupportDNNLBNReLU(inputs[0], param)) {
+  if (SupportDNNLBNReLU(inputs[0])) {
     CHECK_EQ(inputs.size(), 9U);
     DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     DNNLRun(DNNLBatchNormBackward<float, /*fuse_relu*/ true>, attrs, ctx, 
inputs, req, outputs);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index fa422f156c..dc09ebeb22 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -462,16 +462,9 @@ static bool BNChangeLayout(nnvm::NodeAttrs* attrs,
 }
 
 #if MXNET_USE_ONEDNN == 1
-static inline bool SupportDNNLBN(const NDArray& input, const BatchNormParam& 
param) {
-  if (mxnet::op::batchnorm::disable_mkl)
-    return false;
-  const mxnet::TShape shape = input.shape();
-  const int ndim            = shape.ndim();
-  if (ndim == 0 || shape.Size() == 0)
-    return false;
-  const int dtype = input.dtype();
-  return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) &&
-         SupportStorageDNNL(input.storage_type());
+// Support for 
https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_batch_normalization.html
+static inline bool SupportDNNLBN(const NDArray& input) {
+  return SupportDNNL<DNNLTypeMode::FloatTypes>(input) && 
!mxnet::op::batchnorm::disable_mkl;
 }
 
 void BatchNormComputeExCPU(const nnvm::NodeAttrs& attrs,
@@ -480,8 +473,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs& attrs,
                            const std::vector<OpReqType>& req,
                            const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 5U);
-  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  if (SupportDNNLBN(inputs[0], param)) {
+  if (SupportDNNLBN(inputs[0])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
       DNNLRun(DNNLBatchNormForward<DTYPE, /*fuse_relu*/ false>, attrs, ctx, 
inputs, req, outputs);
@@ -497,8 +489,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const std::vector<NDArray>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<NDArray>& outputs) {
-  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  if (SupportDNNLBN(inputs[0], param)) {
+  if (SupportDNNLBN(inputs[0])) {
     DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     DNNLRun(DNNLBatchNormBackward<float, /*fuse_relu*/ false>, attrs, ctx, 
inputs, req, outputs);
     DNNL_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, 
outputs);
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 70b8aeb9f8..6fd1dc18e4 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -248,24 +248,21 @@ inline static bool BackwardConcatStorageType(const 
nnvm::NodeAttrs& attrs,
   return storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, 
wanted_mode);
 }
 #if MXNET_USE_ONEDNN == 1
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_concat.html
 bool SupportDNNLConcat(const std::vector<NDArray>& arrs) {
   for (auto& arr : arrs) {
-    if (arr.IsView())
-      return false;
-    if (!(arr.dtype() == mshadow::kFloat32 || arr.dtype() == 
mshadow::kBfloat16))
-      return false;
-    // Do not support zero-size tensors.
-    if (arr.shape().Size() == 0 || arr.shape().ndim() == 0)
+    if (arr.IsView() || !SupportDNNL<2, 12, AllTypes>(arr))
       return false;
     int ndim             = arr.shape().ndim();
     const int dnnl_ndims = arr.GetDNNLData()->get_desc().data.ndims;
-    if ((ndim != 2 && ndim != 4) || ndim != dnnl_ndims) {
+    if (ndim != dnnl_ndims) {
       return false;
     }
   }
   return true;
 }
 #endif  // MXNET_USE_ONEDNN == 1
+
 static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const OpContext& op_ctx,
                                const std::vector<NDArray>& inputs,
diff --git a/src/operator/nn/dnnl/dnnl_act.cc b/src/operator/nn/dnnl/dnnl_act.cc
index f3ee79a3f6..e1fb642209 100644
--- a/src/operator/nn/dnnl/dnnl_act.cc
+++ b/src/operator/nn/dnnl/dnnl_act.cc
@@ -48,12 +48,9 @@ bool SupportDNNLAct(const ActivationParam& param) {
          param.act_type == activation::kSoftReLU || param.act_type == 
activation::kTanh;
 }
 
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_eltwise.html
 bool SupportDNNLAct(const ActivationParam& param, const NDArray& input) {
-  // DNNL Activation supports 1d, 2d, 3d, 4d and 5d data layout
-  if ((input.shape().ndim() < 1) || (input.shape().ndim() > 5) ||
-      !(input.dtype() == mshadow::kFloat32 || input.dtype() == 
mshadow::kBfloat16))
-    return false;
-  return SupportDNNLAct(param);
+  return SupportDNNL<DNNLTypeMode::FloatTypes>(input) && SupportDNNLAct(param);
 }
 
 bool SupportDNNLLeakyRelu(const LeakyReLUParam& param) {
@@ -61,15 +58,13 @@ bool SupportDNNLLeakyRelu(const LeakyReLUParam& param) {
          param.act_type == leakyrelu::kGELU;
 }
 
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_eltwise.html
 bool SupportDNNLLeakyRelu(const LeakyReLUParam& param, const NDArray& input) {
-  // DNNL Activation supports 1d, 2d, 3d, 4d and 5d data layout
-  if ((input.shape().ndim() < 1) || (input.shape().ndim() > 5) ||
-      !(input.dtype() == mshadow::kFloat32 || input.dtype() == 
mshadow::kBfloat16))
-    return false;
-  return SupportDNNLLeakyRelu(param);
+  return SupportDNNL<DNNLTypeMode::FloatTypes>(input) && 
SupportDNNLLeakyRelu(param);
 }
 
-bool SupportQuantizedDNNLAct(const ActivationParam& param) {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_eltwise.html
+bool SupportDNNLQuantizedAct(const ActivationParam& param) {
   // Although it is the same as SupportDNNLAct i left it here, so when new 
activations
   // will be introduced it will be easier to handle.
   return SupportDNNLAct(param);
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h 
b/src/operator/nn/dnnl/dnnl_base-inl.h
index 8161d62c24..e8cf9c44b3 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -35,6 +35,8 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include <map>
+#include <set>
 
 #include "dnnl.hpp"
 #include "mxnet/graph_attr_types.h"
@@ -114,36 +116,84 @@ struct data_type_enum<uint8_t> {
   enum { type = static_cast<unsigned int>(dnnl::memory::data_type::u8) };
 };
 
-static inline bool SupportDNNLArray(int dtype, const mxnet::TShape& shape) {
-  int ndim     = shape.ndim();
-  bool support = ndim == 1 || ndim == 2 || ndim == 4;
-  support      = support &&
-            (dtype == mshadow::kFloat32 || dtype == mshadow::kInt32 || dtype 
== mshadow::kInt8 ||
-             dtype == mshadow::kUint8 || dtype == mshadow::kBfloat16);
-  return support;
+// SupportDNNL variant that is used to check tensor's dimensions.
+template <int MinNdim, int MaxNdim>
+static inline bool SupportDNNLShape(const mxnet::TShape& shape) {
+  const int ndim = shape.ndim();
+  return ndim >= MinNdim && ndim <= MaxNdim && shape.Size() != 0;
 }
 
-static inline bool SupportStorageDNNL(int stype) {
-  return stype == kDefaultStorage;
+// SupportDNNL variant that is used to check defined type combinations.
+// If any other combination is necessary, a new variant should be implemented.
+enum DNNLTypeMode { AllTypes, NoInt32, IntTypes, FloatTypes, ByteTypes };
+
+// Mapping of DNNLTypeMode variant into set of desired type combinations.
+// clang-format off
+static std::map<DNNLTypeMode, std::set<int>> DNNLTypeModeMap = {
+    {DNNLTypeMode::AllTypes,
+      {mshadow::kBfloat16, mshadow::kFloat32, mshadow::kInt32, mshadow::kInt8, 
mshadow::kUint8}},
+    {DNNLTypeMode::NoInt32,
+      {mshadow::kBfloat16, mshadow::kFloat32, mshadow::kInt8, 
mshadow::kUint8}},
+    {DNNLTypeMode::IntTypes,
+      {mshadow::kInt32, mshadow::kInt8, mshadow::kUint8}},
+    {DNNLTypeMode::FloatTypes,
+      {mshadow::kBfloat16, mshadow::kFloat32}},
+    {DNNLTypeMode::ByteTypes,
+      {mshadow::kInt8, mshadow::kUint8}}};
+// clang-format on
+
+// Returns true when `dtype` is one of the types registered for `TypeMode` in DNNLTypeModeMap.
+// `static` gives this header template internal linkage, matching the TU-local map it reads
+// (an external-linkage inline template referring to a per-TU `static` object is an ODR hazard).
+// find() is used instead of operator[] so the shared table is never mutated; map::operator[]
+// is a non-const, potentially inserting call and is not race-free under concurrent use.
+template <DNNLTypeMode TypeMode>
+static inline bool SupportDNNLType(int dtype) {
+  const auto it = DNNLTypeModeMap.find(TypeMode);
+  return it != DNNLTypeModeMap.end() && it->second.count(dtype) != 0;
 }
 
-static inline bool SupportDNNL(int dtype, const mxnet::TShape& shape) {
-  int ndim = shape.ndim();
-  if (ndim == 0 || shape.Size() == 0) {
-    // DNNL currently does not support 0-dim Tensor and 0-size Tensor
-    return false;
-  }
-  return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) &&
-         (ndim == 1 || ndim == 2 || ndim == 4);
+// SupportDNNL variants:
+// Default variant used to check widest dimension possible
+// and all possible data types for given tensor.
+static inline bool SupportDNNL(const NDArray& tensor) {
+  return SupportDNNLType<AllTypes>(tensor.dtype()) && SupportDNNLShape<1, 
12>(tensor.shape());
 }
 
-static inline bool IsDNNLType(int dtype) {
-  return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 || dtype == 
mshadow::kUint8 ||
-         dtype == mshadow::kBfloat16;
+// Variant checking the default ndim range [1, 12] with the possibility to check
+// any implemented type combination.
+template <DNNLTypeMode TypeMode>
+static inline bool SupportDNNL(const NDArray& tensor) {
+  return SupportDNNLType<TypeMode>(tensor.dtype()) && SupportDNNLShape<1, 
12>(tensor.shape());
 }
 
-static inline bool SupportDNNL(const NDArray& input) {
-  return SupportDNNL(input.dtype(), input.shape()) && 
SupportStorageDNNL(input.storage_type());
+// Variant with the possibility to check an arbitrary ndim range and type combination.
+template <int MinNdim, int MaxNdim, DNNLTypeMode TypeMode>
+static inline bool SupportDNNL(const NDArray& tensor) {
+  return SupportDNNLType<TypeMode>(tensor.dtype()) &&
+         SupportDNNLShape<MinNdim, MaxNdim>(tensor.shape());
+}
+
+// Variant checking multiple inputs at the same time,
+// with the possibility to check their types independently
+// or to check that those types are the same.
+enum DNNLTensorsDtypes { AllSame = 0, Mixed = 1 };
+
+// Checks a whole list of tensors.
+//   MixedTensors == Mixed:   every tensor must independently have a dtype allowed by TypeMode
+//                            and an ndim in [MinNdim, MaxNdim] with non-zero size.
+//   MixedTensors == AllSame: the first tensor's dtype must be allowed by TypeMode, every tensor
+//                            must share exactly that dtype, and every shape must be supported.
+// An empty list is rejected (there is nothing for oneDNN to run on, and inputs[0] is needed).
+template <int MinNdim, int MaxNdim, DNNLTypeMode TypeMode, DNNLTensorsDtypes MixedTensors>
+static inline bool SupportDNNL(const std::vector<NDArray>& inputs) {
+  if (inputs.empty())
+    return false;
+  if (MixedTensors == DNNLTensorsDtypes::Mixed) {
+    // Dtypes may differ between tensors, so validate each one on its own.
+    for (const NDArray& input : inputs) {
+      if (!SupportDNNL<MinNdim, MaxNdim, TypeMode>(input))
+        return false;
+    }
+    return true;
+  }
+  const int dtype = inputs[0].dtype();
+  if (!SupportDNNLType<TypeMode>(dtype))
+    return false;
+  for (const NDArray& input : inputs) {
+    // Reject on either a dtype mismatch or an unsupported shape.
+    if (input.dtype() != dtype || !SupportDNNLShape<MinNdim, MaxNdim>(input.shape()))
+      return false;
+  }
+  return true;
+}
+
+template <DNNLTypeMode TypeMode, DNNLTensorsDtypes MixedTensors>
+static inline bool SupportDNNL(const std::vector<NDArray>& inputs) {
+  return SupportDNNL<1, 12, TypeMode, MixedTensors>(inputs);
 }
 
 static inline bool SupportDNNLQuantize(const int out_type) {
@@ -180,37 +230,39 @@ void* AlignMem(void* mem, size_t size, size_t alignment, 
size_t* space);
 
 namespace op {
 struct ActivationParam;
-struct LeakyReLUParam;
 struct ConvolutionParam;
 struct DeconvolutionParam;
-struct SoftmaxParam;
+struct LayerNormParam;
+struct LeakyReLUParam;
 struct MaskedSoftmaxParam;
-struct SoftmaxOutputParam;
 struct ReshapeParam;
-struct LayerNormParam;
+struct SliceParam;
+struct SoftmaxParam;
+struct SoftmaxOutputParam;
 bool SupportDNNLAct(const ActivationParam& param);
 bool SupportDNNLAct(const ActivationParam& param, const NDArray& input);
-bool SupportDNNLLeakyRelu(const LeakyReLUParam& param);
-bool SupportDNNLLeakyRelu(const LeakyReLUParam& param, const NDArray& input);
-bool SupportQuantizedDNNLAct(const ActivationParam& param);
+bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs);
+bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
+bool SupportDNNLConcat(const std::vector<NDArray>& arrs);
 bool SupportDNNLConv(const ConvolutionParam& params, const NDArray& input);
 bool SupportDNNLDeconv(const DeconvolutionParam& params, const NDArray& input);
-bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& input, const 
NDArray& output);
-bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& input, 
const NDArray& output);
-bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param,
-                              const std::vector<NDArray>& input,
-                              const NDArray& output);
-bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param);
-bool SupportDNNLTranspose(const NDArray& data);
-bool SupportDNNLDot(const std::vector<NDArray>& inputs, const NDArray& output);
-bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& 
output);
+bool SupportDNNLDot(const std::vector<NDArray>& inputs);
+bool SupportDNNLEltwise(const NDArray& input);
+bool SupportDNNLFC(const NDArray& input);
 bool SupportDNNLLayerNorm(const LayerNormParam& param, const 
std::vector<NDArray>& inputs);
-bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
-bool SupportDNNLSplit(const NDArray& input);
-bool SupportDNNLStack(const std::vector<NDArray>& inputs);
-bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
-bool SupportDNNLEltwise(const NDArray& input, const NDArray& output);
+bool SupportDNNLLeakyRelu(const LeakyReLUParam& param);
+bool SupportDNNLLeakyRelu(const LeakyReLUParam& param, const NDArray& input);
+bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& input);
+bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param, const 
std::vector<NDArray>& input);
 bool SupportDNNLPower(const NDArray& input);
+bool SupportDNNLQuantizedAct(const ActivationParam& param);
+bool SupportDNNLReshape(const NDArray& input);
+bool SupportDNNLSlice(const SliceParam& param, const NDArray& input, const 
NDArray& output);
+bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& input);
+bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param, const NDArray& 
input);
+bool SupportDNNLStack(const std::vector<NDArray>& inputs);
+bool SupportDNNLSum(const std::vector<NDArray>& inputs);
+
 void DNNLMemorySum(const dnnl::memory& arr1, const dnnl::memory& arr2, const 
dnnl::memory& out);
 }  // namespace op
 
@@ -445,9 +497,9 @@ class TmpMemMgr {
 
 typedef std::unordered_map<int, dnnl::memory> dnnl_args_map_t;
 class DNNLStream {
-  std::vector<std::pair<dnnl::primitive, dnnl_args_map_t> > net_prim_args;
+  std::vector<std::pair<dnnl::primitive, dnnl_args_map_t>> net_prim_args;
   // Here we hold all memory related to the operators in the stream.
-  std::vector<std::shared_ptr<const dnnl::memory> > mem_holder;
+  std::vector<std::shared_ptr<const dnnl::memory>> mem_holder;
   dnnl::stream s;
 
  public:
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot.cc 
b/src/operator/nn/dnnl/dnnl_batch_dot.cc
index 4534feafd1..a40d55621c 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot.cc
+++ b/src/operator/nn/dnnl/dnnl_batch_dot.cc
@@ -32,11 +32,10 @@ namespace op {
 
 DMLC_REGISTER_PARAMETER(DNNLDotParam);
 
-bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& 
output) {
-  return inputs[DotIn::lhs].shape().Size() != 0 && 
inputs[DotIn::rhs].shape().Size() != 0 &&
-         output.shape().Size() != 0 &&
-         (inputs[DotIn::lhs].dtype() == mshadow::kFloat32 ||
-          inputs[DotIn::lhs].dtype() == mshadow::kBfloat16);
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_matmul.html
+bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs) {
+  return SupportDNNL<2, 12, DNNLTypeMode::FloatTypes>(inputs[DotIn::lhs]) &&
+         SupportDNNL<2, 12, DNNLTypeMode::FloatTypes>(inputs[DotIn::rhs]);
 }
 
 DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DNNLDotParam& param,
diff --git a/src/operator/nn/dnnl/dnnl_binary.cc 
b/src/operator/nn/dnnl/dnnl_binary.cc
index 2ec74f41a9..75c4805fb7 100644
--- a/src/operator/nn/dnnl/dnnl_binary.cc
+++ b/src/operator/nn/dnnl/dnnl_binary.cc
@@ -63,15 +63,10 @@ void DNNLBinaryOpFwd::Execute(const std::vector<NDArray>& 
inputs,
   DNNLStream::Get()->Submit();
 }
 
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_binary.html
 bool SupportDNNLBinary(const std::vector<NDArray>& inputs) {
-  auto dtype_0 = inputs[0].dtype();
-  auto dtype_1 = inputs[1].dtype();
-  auto ndim_0  = inputs[0].shape().ndim();
-  auto ndim_1  = inputs[1].shape().ndim();
-  return ndim_0 >= 1 && ndim_0 <= 6 && ndim_1 >= 1 && ndim_1 <= 6 &&
-         inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
-         (dtype_0 == mshadow::kFloat32 || dtype_0 == mshadow::kBfloat16) &&
-         (dtype_1 == mshadow::kFloat32 || dtype_1 == mshadow::kBfloat16);
+  return SupportDNNL<DNNLTypeMode::FloatTypes>(inputs[1]) &&
+         SupportDNNL<DNNLTypeMode::FloatTypes>(inputs[0]);
 }
 
 }  // namespace op
diff --git a/src/operator/nn/dnnl/dnnl_convolution.cc 
b/src/operator/nn/dnnl/dnnl_convolution.cc
index ca6effb791..a38a9377fe 100644
--- a/src/operator/nn/dnnl/dnnl_convolution.cc
+++ b/src/operator/nn/dnnl/dnnl_convolution.cc
@@ -36,10 +36,11 @@ namespace op {
 
 DMLC_REGISTER_PARAMETER(DNNLConvParam);
 
+// Support for 
https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_convolution.html
 bool SupportDNNLConv(const ConvolutionParam& params, const NDArray& input) {
   if (params.kernel.ndim() > 3 || params.kernel.ndim() == 0)
     return false;
-  return IsDNNLType(input.dtype()) && input.shape().ndim() >= 3 && 
input.shape().ndim() <= 5;
+  return SupportDNNL<3, 5, DNNLTypeMode::AllTypes>(input);
 }
 
 std::shared_ptr<dnnl::convolution_forward::primitive_desc> GetConvFwdImpl(
diff --git a/src/operator/nn/dnnl/dnnl_deconvolution.cc 
b/src/operator/nn/dnnl/dnnl_deconvolution.cc
index 79e3229963..f0252a7384 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution.cc
+++ b/src/operator/nn/dnnl/dnnl_deconvolution.cc
@@ -29,10 +29,10 @@
 namespace mxnet {
 namespace op {
 
+// Support for 
https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_convolution.html
 bool SupportDNNLDeconv(const DeconvolutionParam& params, const NDArray& input) 
{
   return params.kernel.ndim() >= 1 && params.kernel.ndim() <= 3 &&
-         input.shape().ndim() == (params.kernel.ndim() + 2) &&
-         (input.dtype() == mshadow::kFloat32 || input.dtype() == 
mshadow::kBfloat16);
+         SupportDNNL<3, 5, DNNLTypeMode::FloatTypes>(input);
 }
 
 DNNLDeconvFwd::Tensors::Tensors(const bool no_bias,
diff --git a/src/operator/nn/dnnl/dnnl_dot.cc b/src/operator/nn/dnnl/dnnl_dot.cc
index 39786462e7..30468ce22a 100644
--- a/src/operator/nn/dnnl/dnnl_dot.cc
+++ b/src/operator/nn/dnnl/dnnl_dot.cc
@@ -31,15 +31,14 @@
 namespace mxnet {
 namespace op {
 
-bool SupportDNNLDot(const std::vector<NDArray>& inputs, const NDArray& output) 
{
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_matmul.html
+bool SupportDNNLDot(const std::vector<NDArray>& inputs) {
 #if MXNET_USE_BLAS_MKL == 1
   return false;
 #endif
-  return inputs[DotIn::lhs].shape().Size() > 1 && 
inputs[DotIn::rhs].shape().Size() > 1 &&
-         inputs[DotIn::lhs].shape().ndim() > 0 && 
inputs[DotIn::rhs].shape().ndim() > 0 &&
-         output.shape().Size() != 0 && output.shape().ndim() > 0 && 
output.shape().ndim() <= 12 &&
-         (inputs[DotIn::lhs].dtype() == mshadow::kFloat32 ||
-          inputs[DotIn::lhs].dtype() == mshadow::kBfloat16);
+  // Remove cases where ndim of inputs is equal to 1, because output will be 
scalar in this case
+  return SupportDNNL<2, 12, DNNLTypeMode::FloatTypes>(inputs[DotIn::lhs]) &&
+         SupportDNNL<2, 12, DNNLTypeMode::FloatTypes>(inputs[DotIn::rhs]);
 }
 
 DNNLDotFwd& DNNLDotFwd::GetCached(const DotParam& param,
diff --git a/src/operator/nn/dnnl/dnnl_eltwise.cc 
b/src/operator/nn/dnnl/dnnl_eltwise.cc
index eb51beb526..2fc2edae10 100644
--- a/src/operator/nn/dnnl/dnnl_eltwise.cc
+++ b/src/operator/nn/dnnl/dnnl_eltwise.cc
@@ -28,13 +28,9 @@
 namespace mxnet {
 namespace op {
 
-bool SupportDNNLEltwise(const NDArray& input, const NDArray& output) {
-  auto checkTensor = [](const NDArray& tensor) {
-    return (tensor.dtype() == mshadow::kFloat32 || tensor.dtype() == 
mshadow::kBfloat16) &&
-           tensor.shape().ndim() > 0 && tensor.shape().ndim() <= 12 && 
tensor.shape().Size() > 0 &&
-           SupportStorageDNNL(tensor.storage_type());
-  };
-  return checkTensor(input) && checkTensor(output);
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_eltwise.html
+bool SupportDNNLEltwise(const NDArray& input) {
+  return SupportDNNL<DNNLTypeMode::FloatTypes>(input);
 }
 
 DNNLEltwiseFwd& DNNLEltwiseFwd::GetCached(const NDArray& input,
diff --git a/src/operator/nn/dnnl/dnnl_layer_norm.cc 
b/src/operator/nn/dnnl/dnnl_layer_norm.cc
index 65eb0a3e59..66c3c44895 100644
--- a/src/operator/nn/dnnl/dnnl_layer_norm.cc
+++ b/src/operator/nn/dnnl/dnnl_layer_norm.cc
@@ -29,6 +29,7 @@
 namespace mxnet {
 namespace op {
 
+// Support for 
https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_layer_normalization.html
 bool SupportDNNLLayerNorm(const LayerNormParam& param, const 
std::vector<NDArray>& inputs) {
   const mxnet::TShape& shape = inputs[layernorm::kData].shape();
 
@@ -40,13 +41,10 @@ bool SupportDNNLLayerNorm(const LayerNormParam& param, 
const std::vector<NDArray
     return shape.Size() / shape[0] >= shapeLimit && shape[0] >= shapeLimit;
   };
 
-  return (ShapeBetterForDNNL(shape) &&
-          (GetRealAxis(param.axis, shape.ndim()) == shape.ndim() - 1) && 
(shape.ndim() >= 2) &&
-          (shape.ndim() <= 5) &&
-          (inputs[layernorm::kData].dtype() == mshadow::kFloat32 ||
-           inputs[layernorm::kData].dtype() == mshadow::kBfloat16) &&
-          inputs[layernorm::kGamma].dtype() == mshadow::kFloat32 &&
-          inputs[layernorm::kBeta].dtype() == mshadow::kFloat32);
+  return (ShapeBetterForDNNL(shape) && GetRealAxis(param.axis, shape.ndim()) 
== shape.ndim() - 1) &&
+         SupportDNNL<2, 5, DNNLTypeMode::FloatTypes>(inputs[layernorm::kData]) 
&&
+         inputs[layernorm::kGamma].dtype() == mshadow::kFloat32 &&
+         inputs[layernorm::kBeta].dtype() == mshadow::kFloat32;
 }
 
 void DNNLLayerNormForward(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/dnnl/dnnl_log_softmax.cc 
b/src/operator/nn/dnnl/dnnl_log_softmax.cc
index 1559ee347b..e694363836 100644
--- a/src/operator/nn/dnnl/dnnl_log_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_log_softmax.cc
@@ -51,21 +51,15 @@ static dnnl::logsoftmax_backward::primitive_desc 
GetLogSoftmaxBwdPd(
   return dnnl::logsoftmax_backward::primitive_desc(desc, cpu_engine, 
hint_fwd_pd);
 }
 
-bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& data, 
const NDArray& output) {
+// Support for 
https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_logsoftmax.html
+bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& data) {
   const int ndim      = data.shape().ndim();
-  const int in_dtype  = data.dtype();
-  const int out_dtype = output.dtype();
-  const int axis      = CheckAxis(param.axis, ndim);
+  const int out_dtype = param.dtype.has_value() ? param.dtype.value() : 
data.dtype();
   // DNNL does not support temperature argument in their log_softmax function
   // now. Need update this once they start to support it.
   // Currently, DNNL shows bad performance when log_softmax is not performed 
on the last dimension
-  if (data.shape().Size() == 0 || data.shape().ndim() == 0 || 
param.temperature.has_value() ||
-      in_dtype != mshadow::kFloat32 || in_dtype != out_dtype || axis != (ndim 
- 1)) {
-    return false;
-  }
-
-  // only supports ndim = 1, 2, 3, 4 for now
-  return (ndim >= 1 && ndim <= 4);
+  return !param.temperature.has_value() && CheckAxis(param.axis, ndim) == 
(ndim - 1) &&
+         SupportDNNL<1, 4, DNNLTypeMode::FloatTypes>(data) && out_dtype == 
data.dtype();
 }
 
 class DNNLLogSoftmaxFwd {
diff --git a/src/operator/nn/dnnl/dnnl_masked_softmax.cc 
b/src/operator/nn/dnnl/dnnl_masked_softmax.cc
index a2a9c5835d..d47b818a59 100644
--- a/src/operator/nn/dnnl/dnnl_masked_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_masked_softmax.cc
@@ -28,16 +28,15 @@
 namespace mxnet {
 namespace op {
 
-bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param,
-                              const std::vector<NDArray>& inputs,
-                              const NDArray& output) {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_softmax.html
+bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param, const 
std::vector<NDArray>& inputs) {
   CHECK_EQ(inputs.size(), 2);
   const auto mask = inputs[1];
   SoftmaxParam softmax_param;
   softmax_param.axis        = param.axis;
   softmax_param.dtype       = inputs[0].dtype();
   softmax_param.temperature = param.temperature;
-  return mask.dtype() == mshadow::kBool && SupportDNNLSoftmax(softmax_param, 
inputs[0], output);
+  return mask.dtype() == mshadow::kBool && SupportDNNLSoftmax(softmax_param, 
inputs[0]);
 }
 
 inline static dnnl::memory::dims GetOneDNNDims(const NDArray& arr) {
diff --git a/src/operator/nn/dnnl/dnnl_pooling-inl.h 
b/src/operator/nn/dnnl/dnnl_pooling-inl.h
index 77ad9baf74..6efc8210da 100644
--- a/src/operator/nn/dnnl/dnnl_pooling-inl.h
+++ b/src/operator/nn/dnnl/dnnl_pooling-inl.h
@@ -127,25 +127,19 @@ inline bool SupportDNNLPooling(const PoolingParam& param) 
{
            param.layout.value() == mshadow::kNCDHW));
 }
 
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_pooling.html
 inline bool SupportDNNLPooling(const PoolingParam& param, const NDArray& 
input) {
-  const auto dshape = input.shape();
-  const auto ndim   = dshape.ndim();
-  const auto dtype  = input.dtype();
-
-  if (!(SupportStorageDNNL(input.storage_type()) && (ndim == 3 || ndim == 4 || 
ndim == 5) &&
-        (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16)))
-    return false;
-
-  if (!SupportDNNLPooling(param))
+  if (!SupportDNNL<3, 5, DNNLTypeMode::FloatTypes>(input) || 
!SupportDNNLPooling(param))
     return false;
 
   if (param.pooling_convention == pool_enum::kValid) {
     return true;
   } else {
+    const auto dshape = input.shape();
     if (param.pool_type == pool_enum::kAvgPooling) {
       // dnnl works differently when padding is asymmetric, so let's skip this 
case.
       bool is_symmetric = true;
-      switch (ndim) {
+      switch (dshape.ndim()) {
         case 5:
           is_symmetric =
               is_symmetric &&
diff --git a/src/operator/nn/dnnl/dnnl_reduce.cc 
b/src/operator/nn/dnnl/dnnl_reduce.cc
index a9be0af295..36aeec15b3 100644
--- a/src/operator/nn/dnnl/dnnl_reduce.cc
+++ b/src/operator/nn/dnnl/dnnl_reduce.cc
@@ -83,12 +83,10 @@ mxnet::Tuple<int> CanonicalizeAndSortAxes(const NDArray& 
input,
   return axes;
 }
 
+// Support for 
https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_reduction.html
 bool SupportDNNLReduceImpl(const NumpyReduceAxesParam& param,
                            const NDArray& input,
                            const NDArray& output) {
-  int in_ndim          = input.shape().ndim();
-  int out_size         = output.shape().Size();
-  int in_size          = input.shape().Size();
   bool param_supported = true;
   if (param.axis.has_value()) {
     auto axes    = CanonicalizeAndSortAxes(input, param, param.axis.value());
@@ -111,10 +109,9 @@ bool SupportDNNLReduceImpl(const NumpyReduceAxesParam& 
param,
   }
   // initial value not supported by oneDNN
   param_supported = param_supported && !param.initial.has_value();
-  return param_supported &&
-         (input.dtype() == mshadow::kFloat32 || input.dtype() == 
mshadow::kBfloat16) &&
-         (output.dtype() == mshadow::kFloat32 || output.dtype() == 
mshadow::kBfloat16) &&
-         in_ndim >= 1 && out_size > 0 && in_size > 1;
+  // oneDNN does not support reduction of tensors with size equal to 1
+  return param_supported && input.shape().Size() > 1 &&
+         SupportDNNL<DNNLTypeMode::FloatTypes>(input);
 }
 
 void DNNLReduceForwardImpl(const NumpyReduceAxesParam& param,
diff --git a/src/operator/nn/dnnl/dnnl_reshape.cc 
b/src/operator/nn/dnnl/dnnl_reshape.cc
index d1270a3a8c..5a2d9c1173 100644
--- a/src/operator/nn/dnnl/dnnl_reshape.cc
+++ b/src/operator/nn/dnnl/dnnl_reshape.cc
@@ -31,11 +31,9 @@
 namespace mxnet {
 namespace op {
 
-bool SupportDNNLReshape(const NDArray& input, const NDArray& output) {
-  const int input_ndims  = input.shape().ndim();
-  const int output_ndims = output.shape().ndim();
-  return input.shape().Size() > 0 && input_ndims >= 1 && input_ndims <= 6 && 
output_ndims >= 1 &&
-         output_ndims <= 6 && IsDNNLType(input.dtype());
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_reorder.html
+bool SupportDNNLReshape(const NDArray& input) {
+  return SupportDNNL(input) && input.shape().Size() != 1;
 }
 
 DNNLReshapeFwd::DNNLReshapeFwd(const OpReqType& req, const NDArray& input, 
const NDArray& output) {
diff --git a/src/operator/nn/dnnl/dnnl_rnn-inl.h 
b/src/operator/nn/dnnl/dnnl_rnn-inl.h
index 6165dfaeb4..b14a657cad 100644
--- a/src/operator/nn/dnnl/dnnl_rnn-inl.h
+++ b/src/operator/nn/dnnl/dnnl_rnn-inl.h
@@ -558,6 +558,7 @@ class DNNLRnnOp {
             const std::vector<NDArray>& outputs);
 };
 
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_rnn.html
 inline bool SupportDNNLRnn(const int input_dtype) {
   if (input_dtype == mshadow::kFloat32 && dmlc::GetEnv("MXNET_USE_ONEDNN_RNN", 
1)) {
     return true;
diff --git a/src/operator/nn/dnnl/dnnl_softmax.cc b/src/operator/nn/dnnl/dnnl_softmax.cc
index 73321a1854..317c5790d0 100644
--- a/src/operator/nn/dnnl/dnnl_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_softmax.cc
@@ -29,23 +29,14 @@
 namespace mxnet {
 namespace op {
 
-bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& data, const NDArray& output) {
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_softmax.html
+bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& data) {
   const int ndim      = data.shape().ndim();
-  const int in_size   = data.shape().Size();
-  const int in_dtype  = data.dtype();
-  const int out_dtype = output.dtype();
-  const int axis      = CheckAxis(param.axis, ndim);
-
-  if (param.temperature.has_value() && param.temperature.value() == 0.0) {
-    return false;
-  }
-
-  if (in_dtype != mshadow::kFloat32 || in_dtype != out_dtype || axis != (ndim - 1)) {
-    return false;
-  }
-
-  // Supports ndim up to 6
-  return (ndim >= 1 && ndim <= 6 && in_size != 0);
+  const int out_dtype = param.dtype.has_value() ? param.dtype.value() : data.dtype();
+  // Only a zero temperature is rejected (same rule as the pre-refactor check).
+  return !(param.temperature.has_value() && param.temperature.value() == 0.0) &&
+         CheckAxis(param.axis, ndim) == (ndim - 1) && SupportDNNL<DNNLTypeMode::NoInt32>(data) &&
+         out_dtype == data.dtype();
 }
 
 void DNNLSoftmaxForward(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/dnnl/dnnl_softmax_output.cc 
b/src/operator/nn/dnnl/dnnl_softmax_output.cc
index ba79effc5b..2f1e0c9161 100644
--- a/src/operator/nn/dnnl/dnnl_softmax_output.cc
+++ b/src/operator/nn/dnnl/dnnl_softmax_output.cc
@@ -90,9 +90,9 @@ static DNNLSoftmaxOutputFwd& GetSoftmaxOutputForward(const 
SoftmaxOutputParam& p
   return it->second;
 }
 
-//  This is only used for forward. For backward ,need double check 
compatibility
-bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param) {
-  return param.multi_output ? false : true;
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_softmax.html
+bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param, const NDArray& 
input) {
+  return SupportDNNL(input) && !param.multi_output;
 }
 
 void DNNLSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/dnnl/dnnl_split.cc 
b/src/operator/nn/dnnl/dnnl_split.cc
index e13b45a259..ff2bec3af2 100644
--- a/src/operator/nn/dnnl/dnnl_split.cc
+++ b/src/operator/nn/dnnl/dnnl_split.cc
@@ -29,12 +29,6 @@
 namespace mxnet {
 namespace op {
 
-bool SupportDNNLSplit(const NDArray& input) {
-  static const std::set<int> supported_dtypes = {
-      mshadow::kFloat32, mshadow::kBfloat16, mshadow::kInt32, mshadow::kInt8, 
mshadow::kUint8};
-  return supported_dtypes.count(input.dtype());
-}
-
 void DNNLSplitForward(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<NDArray>& inputs,
diff --git a/src/operator/nn/dnnl/dnnl_stack.cc 
b/src/operator/nn/dnnl/dnnl_stack.cc
index 981a4f87c6..3aca884121 100644
--- a/src/operator/nn/dnnl/dnnl_stack.cc
+++ b/src/operator/nn/dnnl/dnnl_stack.cc
@@ -32,27 +32,9 @@
 namespace mxnet {
 namespace op {
 
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_reorder.html
 bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
-  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16) {
-    return false;
-  }
-
-  int src_dtype = inputs[0].dtype();
-  for (const auto& arr : inputs) {
-    if (arr.dtype() != src_dtype) {
-      return false;
-    }
-    // Do not support zero-size tensors.
-    if (arr.shape().Size() == 0) {
-      return false;
-    }
-
-    int ndim = arr.shape().ndim();
-    if (ndim <= 0) {
-      return false;
-    }
-  }
-  return true;
+  return SupportDNNL<DNNLTypeMode::FloatTypes, 
DNNLTensorsDtypes::AllSame>(inputs);
 }
 
 void DNNLStackForward(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/dnnl/dnnl_transpose-inl.h 
b/src/operator/nn/dnnl/dnnl_transpose-inl.h
index aa6e071da7..ef3601a01a 100644
--- a/src/operator/nn/dnnl/dnnl_transpose-inl.h
+++ b/src/operator/nn/dnnl/dnnl_transpose-inl.h
@@ -33,8 +33,6 @@
 namespace mxnet {
 namespace op {
 
-bool SupportDNNLTranspose(const NDArray& data);
-
 class DNNLTransposeFwd {
  public:
   std::shared_ptr<dnnl::memory> data_;
diff --git a/src/operator/nn/dnnl/dnnl_transpose.cc 
b/src/operator/nn/dnnl/dnnl_transpose.cc
index 342a79682e..2894fc76f2 100644
--- a/src/operator/nn/dnnl/dnnl_transpose.cc
+++ b/src/operator/nn/dnnl/dnnl_transpose.cc
@@ -31,16 +31,6 @@
 namespace mxnet {
 namespace op {
 
-bool SupportDNNLTranspose(const NDArray& data) {
-  auto data_ndim = data.shape().ndim();
-
-  if (data_ndim > 4 || data_ndim == 0 || data.shape().Size() == 0 ||
-      !(data.dtype() == mshadow::kFloat32 || data.dtype() == 
mshadow::kBfloat16))
-    return false;
-
-  return true;
-}
-
 typedef ParamOpSign<NumpyTransposeParam> DNNLTransposeSignature;
 
 DNNLTransposeFwd::DNNLTransposeFwd(const NumpyTransposeParam& param, const 
NDArray& data) {
diff --git a/src/operator/nn/dnnl/dnnl_where.cc 
b/src/operator/nn/dnnl/dnnl_where.cc
index 7d5ca4ad6b..8e225b471e 100644
--- a/src/operator/nn/dnnl/dnnl_where.cc
+++ b/src/operator/nn/dnnl/dnnl_where.cc
@@ -32,16 +32,9 @@
 namespace mxnet {
 namespace op {
 
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_binary.html
 bool SupportDNNLWhere(const std::vector<NDArray>& inputs) {
-  static const std::set<int> supported_dtypes = {
-      mshadow::kFloat32, mshadow::kBfloat16, mshadow::kInt8, mshadow::kUint8};
-  for (int i = 0; i < inputs.size(); ++i) {
-    if (!supported_dtypes.count(inputs[i].dtype()) || inputs[i].shape().Size() 
<= 0 ||
-        inputs[i].shape().ndim() <= 0) {
-      return false;
-    }
-  }
-  return true;
+  return SupportDNNL<DNNLTypeMode::NoInt32, DNNLTensorsDtypes::Mixed>(inputs);
 }
 
 void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/fully_connected.cc 
b/src/operator/nn/fully_connected.cc
index 61e7ca5ea9..63c0be518f 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -31,11 +31,12 @@
 namespace mxnet {
 namespace op {
 
+#if MXNET_USE_ONEDNN == 1
+// Support for 
https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_inner_product.html
 bool SupportDNNLFC(const NDArray& input) {
-  int ndim = input.shape().ndim();
-  return (input.dtype() == mshadow::kFloat32 || input.dtype() == 
mshadow::kBfloat16) &&
-         (ndim >= 1 && ndim <= 4) && input.storage_type() == kDefaultStorage;
+  return SupportDNNL<2, 5, DNNLTypeMode::FloatTypes>(input);
 }
+#endif
 
 static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
                                 mxnet::ShapeVector* in_shape,
diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index e63f09c71f..9e36a28a7a 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -40,7 +40,7 @@ static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& 
attrs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  if (SupportDNNLLogSoftmax(param, inputs[0], outputs[0])) {
+  if (SupportDNNLLogSoftmax(param, inputs[0])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNLRun(DNNLLogSoftmaxForward, attrs, ctx, inputs[0], req[0], outputs[0]);
     auto fn = SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>;
@@ -56,7 +56,7 @@ static void LogSoftmaxGradComputeExCPU(const nnvm::NodeAttrs& 
attrs,
                                        const std::vector<OpReqType>& req,
                                        const std::vector<NDArray>& outputs) {
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  if (SupportDNNLLogSoftmax(param, inputs[1], outputs[0])) {
+  if (SupportDNNLLogSoftmax(param, inputs[1])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNLRun(DNNLLogSoftmaxBackward, attrs, ctx, inputs, req, outputs);
     auto fn = SoftmaxGradCompute<cpu, op::mshadow_op::left, 
mxnet_op::log_softmax_bwd>;
diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc
index c121be2720..f67456ab81 100644
--- a/src/operator/nn/lrn.cc
+++ b/src/operator/nn/lrn.cc
@@ -107,7 +107,7 @@ void LRNComputeExCPU(const nnvm::NodeAttrs& attrs,
                      const std::vector<NDArray>& inputs,
                      const std::vector<OpReqType>& req,
                      const std::vector<NDArray>& outputs) {
-  if (SupportDNNL(inputs[0])) {
+  if (SupportDNNL<2, 5, DNNLTypeMode::FloatTypes>(inputs[0])) {
     // We only need to test one output array.
     DNNL_OPCHECK_INIT(false, 1, inputs, outputs);
     DNNLRun(DNNLLRNForward, attrs, ctx, inputs[0], req[0], outputs[0]);
@@ -124,7 +124,7 @@ void LRNGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs) {
-  if (SupportDNNL(inputs[0])) {
+  if (SupportDNNL<2, 5, DNNLTypeMode::FloatTypes>(inputs[0])) {
     DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     DNNLRun(DNNLLRNBackward, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(LRNGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/nn/masked_softmax.cc 
b/src/operator/nn/masked_softmax.cc
index ee8c902833..fa17ec9f9f 100644
--- a/src/operator/nn/masked_softmax.cc
+++ b/src/operator/nn/masked_softmax.cc
@@ -43,7 +43,7 @@ static void MaskedSoftmaxComputeExCPU(const nnvm::NodeAttrs& 
attrs,
   if (inputs[0].shape().Size() == 0U)
     return;
   const MaskedSoftmaxParam& param = 
nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
-  if (SupportDNNLMaskedSoftmax(param, inputs, outputs[0])) {
+  if (SupportDNNLMaskedSoftmax(param, inputs)) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
 
     DNNLRun(DNNLMaskedSoftmaxForward, attrs, ctx, inputs, req, outputs);
@@ -61,7 +61,6 @@ inline static bool MaskedSoftmaxStorageType(const 
nnvm::NodeAttrs& attrs,
                                             DispatchMode* dispatch_mode,
                                             std::vector<int>* in_attrs,
                                             std::vector<int>* out_attrs) {
-  const MaskedSoftmaxParam& param = 
nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 2U);
   CHECK_EQ(out_attrs->size(), 1U);
 
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 0fb4e338cb..e1e696d7fc 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -41,7 +41,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  if (SupportDNNLSoftmax(param, inputs[0], outputs[0])) {
+  if (SupportDNNLSoftmax(param, inputs[0])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNLRun(DNNLSoftmaxForward, attrs, ctx, inputs[0], req[0], outputs[0]);
     auto fn = SoftmaxCompute<cpu, mxnet_op::softmax_fwd>;
@@ -57,7 +57,7 @@ static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs& 
attrs,
                                     const std::vector<OpReqType>& req,
                                     const std::vector<NDArray>& outputs) {
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  if (SupportDNNLSoftmax(param, inputs[1], outputs[0])) {
+  if (SupportDNNLSoftmax(param, inputs[1])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNLRun(DNNLSoftmaxBackward, attrs, ctx, inputs, req, outputs);
     auto fn = SoftmaxGradCompute<cpu, op::mshadow_op::mul, 
mxnet_op::softmax_bwd>;
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.h 
b/src/operator/numpy/np_broadcast_reduce_op_value.h
index b438a92872..e38fc32db9 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.h
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.h
@@ -199,7 +199,6 @@ static void DNNLReduceEx(const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
-  const NumpyReduceAxesParam& param = 
nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
 
   if (SupportDNNLReduce<NumpyReduceAxesParam>(attrs, inputs[0], outputs[0])) {
     DNNLRun(DNNLReduceForward<NumpyReduceAxesParam, reduction_alg>,
diff --git a/src/operator/numpy/np_dot_forward.cc 
b/src/operator/numpy/np_dot_forward.cc
index 870486022c..0253a16792 100644
--- a/src/operator/numpy/np_dot_forward.cc
+++ b/src/operator/numpy/np_dot_forward.cc
@@ -110,7 +110,7 @@ static void NumpyDotComputeExCPU(const nnvm::NodeAttrs& 
attrs,
                                  const std::vector<NDArray>& inputs,
                                  const std::vector<OpReqType>& req,
                                  const std::vector<NDArray>& outputs) {
-  if (SupportDNNLDot(inputs, outputs[0])) {
+  if (SupportDNNLDot(inputs)) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNLRun(DNNLDotForward<true>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(NumpyDotForward<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/numpy/np_matrix_op.cc 
b/src/operator/numpy/np_matrix_op.cc
index 946527562a..7a3404df42 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -59,7 +59,7 @@ static void NumpyTransposeComputeExCPU(const nnvm::NodeAttrs& 
attrs,
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
 
-  if (SupportDNNLTranspose(inputs[0]) && req[0] == kWriteTo) {
+  if (SupportDNNL(inputs[0]) && req[0] == kWriteTo) {
     DNNLRun(DNNLTransposeForward<NumpyTransposeParam>, attrs, ctx, inputs[0], 
req[0], outputs[0]);
     return;
   }
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc b/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
index 0b14a96777..ab6083586d 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
@@ -44,7 +44,8 @@ static void DNNLQuantizedFlattenForward(const nnvm::NodeAttrs& attrs,
                                          const std::vector<NDArray>& inputs,
                                          const std::vector<OpReqType>& req,
                                          const std::vector<NDArray>& outputs) {
-  if (SupportDNNLReshape(inputs[0], outputs[0])) {
+  // Runs DNNLReshapeForward, so use the reshape-specific check (also rejects size-1 tensors).
+  if (SupportDNNLReshape(inputs[0])) {
     DNNLRun(DNNLReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
   } else {
     FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc 
b/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
index 344bda3f5e..57b987dfc6 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
@@ -37,7 +37,7 @@ static void DNNLQuantizedReshapeForward(const 
nnvm::NodeAttrs& attrs,
   CHECK(inputs[0].dtype() == mshadow::kUint8 || inputs[0].dtype() == 
mshadow::kInt8)
       << "dnnl_quantized_reshape op only supports uint8 and int8 as input 
type";
 
-  if (SupportDNNLReshape(inputs[0], outputs[0])) {
+  if (SupportDNNLReshape(inputs[0])) {
     OpReqType reqType;
     if (inputs[0].GetDNNLData()->get_data_handle() != 
outputs[0].GetDNNLData()->get_data_handle())
       reqType = kWriteTo;
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_transpose.cc 
b/src/operator/quantization/dnnl/dnnl_quantized_transpose.cc
index 0a8e2aed1f..6be64c3254 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_transpose.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_transpose.cc
@@ -40,13 +40,9 @@ inline static bool QuantizedTransposeStorageType(const 
nnvm::NodeAttrs& attrs,
   return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, 
out_attrs);
 }
 
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_reorder.html
 bool SupportDNNLQuantizedTranspose(const NDArray& data) {
-  auto data_ndim = data.shape().ndim();
-
-  if (data_ndim > 4 || data_ndim == 0 || data.shape().Size() == 0)
-    return false;
-
-  return true;
+  return SupportDNNL<DNNLTypeMode::ByteTypes>(data);
 }
 typedef void (*TransposeFallbackFunAny)(const nnvm::NodeAttrs&,
                                         const OpContext&,
diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc
index 52a9888ea8..cb0715f508 100644
--- a/src/operator/softmax_output.cc
+++ b/src/operator/softmax_output.cc
@@ -153,7 +153,7 @@ void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 2U);
   const SoftmaxOutputParam& param = 
nnvm::get<SoftmaxOutputParam>(attrs.parsed);
-  if (SupportDNNL(inputs[0]) && !ctx.is_train && 
SupportDNNLSoftmaxOutput(param)) {
+  if (SupportDNNLSoftmaxOutput(param, inputs[0]) && !ctx.is_train) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNLRun(DNNLSoftmaxOutputForward, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(SoftmaxOutputCompute<cpu>, attrs, ctx, inputs, req, 
outputs);
diff --git a/src/operator/subgraph/dnnl/dnnl_conv_property.h 
b/src/operator/subgraph/dnnl/dnnl_conv_property.h
index 3d814c0598..7d747994b3 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_conv_property.h
@@ -114,7 +114,7 @@ class SgDNNLConvSelector : public SubgraphSelector {
       default:
         if ((!disable_conv_act_) && node_name == "Activation") {
           const ActivationParam& param = 
nnvm::get<ActivationParam>(new_node.attrs.parsed);
-          if ((quantize_ && SupportQuantizedDNNLAct(param)) ||
+          if ((quantize_ && SupportDNNLQuantizedAct(param)) ||
               (!quantize_ && SupportDNNLAct(param))) {
             matched_list_.push_back(&new_node);
             // not support conv+relu+sum yet.
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_property.h 
b/src/operator/subgraph/dnnl/dnnl_fc_property.h
index 902444354f..5b567ee652 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_property.h
@@ -94,7 +94,7 @@ class SgDNNLFCSelector : public SubgraphSelector {
         // Currently, For INT8 FC fusion, only supports 
relu/bounded_relu(clip)/abs.
         if (new_node.op() == Op::Get("Activation")) {
           const ActivationParam& param = 
nnvm::get<ActivationParam>(new_node.attrs.parsed);
-          if ((quantized_ && SupportQuantizedDNNLAct(param)) ||
+          if ((quantized_ && SupportDNNLQuantizedAct(param)) ||
               (!quantized_ && SupportDNNLAct(param))) {
             matched_list_.push_back(&new_node);
             status_ = kSuccess;
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index 0d62cac808..60b55fb5f4 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -124,7 +124,7 @@ void DotForwardExDNNL(const nnvm::NodeAttrs& attrs,
                       const std::vector<NDArray>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<NDArray>& outputs) {
-  if (SupportDNNLDot(inputs, outputs[DotOut::out])) {
+  if (SupportDNNLDot(inputs)) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNLRun(DNNLDotForward<false>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(DotForward_<cpu>, attrs, ctx, inputs, req, outputs);
@@ -138,7 +138,7 @@ static void BatchDotComputeExCPU(const nnvm::NodeAttrs& 
attrs,
                                  const std::vector<NDArray>& inputs,
                                  const std::vector<OpReqType>& req,
                                  const std::vector<NDArray>& outputs) {
-  if (SupportDNNLBatchDot(inputs, outputs[DotOut::out])) {
+  if (SupportDNNLBatchDot(inputs)) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNLRun(DNNLBatchDotForward<false>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc 
b/src/operator/tensor/elemwise_binary_op_basic.cc
index ad29575cb3..2f03f90672 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -30,11 +30,12 @@
 namespace mxnet {
 namespace op {
 
-bool SupportDNNLSum(const NDArray& input) {
-  int ndim = input.shape().ndim();
-  return (input.dtype() == mshadow::kFloat32 || input.dtype() == 
mshadow::kBfloat16) &&
-         (ndim >= 1 && ndim <= 4) && input.storage_type() == kDefaultStorage;
+#if MXNET_USE_ONEDNN == 1
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_eltwise.html
+bool SupportDNNLSum(const std::vector<NDArray>& inputs) {
+  return SupportDNNL(inputs[0]) && SupportDNNL(inputs[1]);
 }
+#endif
 
 static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
@@ -44,7 +45,7 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
 #if MXNET_USE_ONEDNN == 1
-  if (SupportDNNLSum(inputs[0]) && SupportDNNLSum(inputs[1])) {
+  if (SupportDNNLSum(inputs) && common::ContainsOnlyStorage(inputs, 
kDefaultStorage)) {
     DNNLRun(DNNLSumForward, attrs, ctx, inputs, req, outputs);
     return;
   } else if (inputs[0].storage_type() == kDefaultStorage &&
diff --git a/src/operator/tensor/elemwise_unary_op.h 
b/src/operator/tensor/elemwise_unary_op.h
index bc6011a97d..b675acf9f9 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -511,7 +511,7 @@ inline void EltwiseComputeExCPU(const nnvm::NodeAttrs& 
attrs,
                                 const std::vector<mxnet::NDArray>& outputs) {
   auto fallBackFunction =
       computeMixed ? UnaryOp::ComputeMixedType<cpu, OP> : 
UnaryOp::Compute<cpu, OP>;
-  if (SupportDNNLEltwise(inputs[0], outputs[0])) {
+  if (SupportDNNLEltwise(inputs[0])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNLRun(
         DNNLEltwiseForward<DNNLAlgorithm<OP>::value>, attrs, ctx, inputs[0], 
req[0], outputs[0]);
diff --git a/src/operator/tensor/matrix_op-inl.h 
b/src/operator/tensor/matrix_op-inl.h
index 00bfe9bd51..2831789fe4 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -677,6 +677,7 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
 }
 
 // Currently DNNL only supports step = 1 or step has no value
+// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_reorder.html
 inline bool SupportDNNLSlice(const SliceParam& param) {
   if (param.step.ndim() == 0U)
     return true;
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 15f131cfcd..303011f058 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -127,7 +127,7 @@ void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
   // DNNL support the data type or the shape. Then convert
   // it to the output format and shape
 
-  if (SupportDNNLReshape(inputs[0], outputs[0])) {
+  if (SupportDNNLReshape(inputs[0])) {
     DNNLRun(DNNLReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
   } else {
     FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, 
outputs);
@@ -236,7 +236,7 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
   // If inputs are supposed to be in DNNL format and
   // DNNL support the data type or the shape. Then convert
   // it to the output format and shape
-  if (SupportDNNLReshape(inputs[0], outputs[0])) {
+  if (SupportDNNLReshape(inputs[0])) {
     DNNLRun(DNNLReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
   } else {
     FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, 
outputs);
@@ -317,7 +317,7 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& 
attrs,
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
 
-  if (SupportDNNLTranspose(inputs[0]) && req[0] == kWriteTo) {
+  if (SupportDNNL(inputs[0]) && req[0] == kWriteTo) {
     DNNLRun(DNNLTransposeForward<TransposeParam>, attrs, ctx, inputs[0], 
req[0], outputs[0]);
     return;
   }
@@ -422,7 +422,7 @@ static void ExpandDimEx(const nnvm::NodeAttrs& attrs,
   // If inputs are supposed to be in DNNL format and
   // DNNL support the data type or the shape. Then convert
   // it to the output format and shape
-  if (SupportDNNLReshape(inputs[0], outputs[0])) {
+  if (SupportDNNLReshape(inputs[0])) {
     DNNLRun(DNNLReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
   } else {
     FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, 
outputs);
@@ -473,6 +473,12 @@ will return a new array with shape ``(2,1,3,4)``.
     .add_argument("data", "NDArray-or-Symbol", "Source input")
     .add_arguments(ExpandDimParam::__FIELDS__());
 
+#if MXNET_USE_ONEDNN == 1
+bool SupportDNNLSlice(const SliceParam& param, const NDArray& input, const 
NDArray& output) {
+  return SupportDNNLSlice(param) && SupportDNNL(input) && SupportDNNL(output);
+}
+#endif
+
 void SliceExCPU(const nnvm::NodeAttrs& attrs,
                 const OpContext& ctx,
                 const std::vector<NDArray>& inputs,
@@ -486,7 +492,7 @@ void SliceExCPU(const nnvm::NodeAttrs& attrs,
     SliceCsrImpl<cpu>(param, ctx, inputs[0], req[0], outputs[0]);
 #if MXNET_USE_ONEDNN == 1
   } else if (in_stype == kDefaultStorage) {
-    if (SupportDNNL(inputs[0])) {
+    if (SupportDNNLSlice(param, inputs[0], outputs[0])) {
       DNNLRun(DNNLSlice, attrs, ctx, inputs[0], req[0], outputs[0]);
     } else {
       FallBackCompute(SliceOpForward<cpu>, attrs, ctx, inputs, req, outputs);
@@ -1185,7 +1191,7 @@ static void SplitForwardEx(const nnvm::NodeAttrs& attrs,
                            const std::vector<OpReqType>& req,
                            const std::vector<NDArray>& outputs) {
   CHECK(!inputs.empty());
-  if (SupportDNNLSplit(inputs[0])) {
+  if (SupportDNNL(inputs[0])) {
     DNNL_OPCHECK_INIT(/*is backward*/ false, outputs.size(), inputs, outputs);
     DNNLRun(DNNLSplitForward, attrs, op_ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(SplitOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);

Reply via email to