mozga-intel commented on a change in pull request #20621:
URL: https://github.com/apache/incubator-mxnet/pull/20621#discussion_r722715010



##########
File path: src/operator/tensor/matrix_op.cc
##########
@@ -930,7 +930,38 @@ NNVM_REGISTER_OP(_backward_reverse)
                                 })
     .set_attr<FCompute>("FCompute<cpu>", ReverseOpForward<cpu>);
 
+#if MXNET_USE_ONEDNN == 1
+static void StackForwardEx(const nnvm::NodeAttrs& attrs,
+                           const OpContext& op_ctx,
+                           const std::vector<NDArray>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<NDArray>& outputs) {
+  CHECK(!inputs.empty());
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp)

Review comment:
     To be consistent with the rest of this file, please add braces to this if().

##########
File path: src/operator/nn/mkldnn/mkldnn_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_stack.cc
+ */
+
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_concat-inl.h"
+#include "./mkldnn_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportMKLDNNStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)
+    return false;
+
+  int src_dtype = inputs[0].dtype();
+  for (auto& arr : inputs) {
+    if (arr.dtype() != src_dtype) {
+      return false;
+    }
+    // DO not support zero-size tensors.
+    if (arr.shape().Size() == 0)
+      return false;
+    int ndim = arr.shape().ndim();
+    if (ndim <= 0)
+      return false;

Review comment:
       The same here.

##########
File path: src/operator/nn/dnnl/dnnl_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_stack.cc
+ */
+
+#include "./dnnl_base-inl.h"
+#include "./dnnl_concat-inl.h"
+#include "./dnnl_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)

Review comment:
       {}

##########
File path: src/operator/nn/dnnl/dnnl_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_stack.cc
+ */
+
+#include "./dnnl_base-inl.h"
+#include "./dnnl_concat-inl.h"
+#include "./dnnl_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)
+    return false;
+
+  int src_dtype = inputs[0].dtype();
+  for (auto& arr : inputs) {
+    if (arr.dtype() != src_dtype) {
+      return false;
+    }
+    // DO not support zero-size tensors.
+    if (arr.shape().Size() == 0)

Review comment:
       {} - to be consistent.

##########
File path: src/operator/nn/mkldnn/mkldnn_concat.cc
##########
@@ -65,7 +65,7 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs,
   TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
   const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
   const int num_in_data    = param.num_args;
-  const int concat_dim     = param.dim;
+  const int concat_dim     = CheckAxis(param.dim, out_data[0].shape().ndim());

Review comment:
     The name of this function suggests that it only checks an axis, but it looks
like it is actually used to get the axis value. How about renaming this function?

##########
File path: src/operator/nn/dnnl/dnnl_concat-inl.h
##########
@@ -52,13 +52,18 @@ class DNNLConcatFwd {
 
 static DNNLConcatFwd& GetConcatForward(int concat_dim,
                                        const std::vector<NDArray>& in_data,
-                                       const std::vector<dnnl::memory::desc>& 
data_md) {
+                                       const std::vector<dnnl::memory::desc>& 
data_md,
+                                       int cache_dim = -1) {

Review comment:
       What does the cache_dim tell me?

##########
File path: src/operator/nn/mkldnn/mkldnn_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_stack.cc
+ */
+
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_concat-inl.h"
+#include "./mkldnn_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportMKLDNNStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)
+    return false;
+
+  int src_dtype = inputs[0].dtype();
+  for (auto& arr : inputs) {
+    if (arr.dtype() != src_dtype) {
+      return false;
+    }
+    // DO not support zero-size tensors.
+    if (arr.shape().Size() == 0)

Review comment:
       Please use {} for a given "if"

##########
File path: src/operator/nn/dnnl/dnnl_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_stack.cc
+ */
+
+#include "./dnnl_base-inl.h"
+#include "./dnnl_concat-inl.h"
+#include "./dnnl_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)
+    return false;
+
+  int src_dtype = inputs[0].dtype();
+  for (auto& arr : inputs) {
+    if (arr.dtype() != src_dtype) {
+      return false;
+    }
+    // DO not support zero-size tensors.
+    if (arr.shape().Size() == 0)
+      return false;
+    int ndim = arr.shape().ndim();
+    if (ndim <= 0)
+      return false;
+  }
+  return true;
+}
+
+void DNNLStackForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& in_data,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& out_data) {
+  TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
+
+  const StackParam& param = dmlc::get<StackParam>(attrs.parsed);
+  const int axis          = CheckAxis(param.axis, out_data[0].shape().ndim());
+  const auto oshape       = out_data[0].shape();
+  const int src_dtype     = in_data[0].dtype();
+  const int dst_dtype     = out_data[0].dtype();
+  int leading             = 1;

Review comment:
     Is the name of this variable complete? Might it be better to use a more
descriptive leading/trailing_XXXXX name?

##########
File path: src/operator/nn/dnnl/dnnl_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_stack.cc
+ */
+
+#include "./dnnl_base-inl.h"
+#include "./dnnl_concat-inl.h"
+#include "./dnnl_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)
+    return false;
+
+  int src_dtype = inputs[0].dtype();
+  for (auto& arr : inputs) {
+    if (arr.dtype() != src_dtype) {
+      return false;
+    }
+    // DO not support zero-size tensors.
+    if (arr.shape().Size() == 0)
+      return false;
+    int ndim = arr.shape().ndim();
+    if (ndim <= 0)

Review comment:
       {} - to be consistent.

##########
File path: src/operator/nn/dnnl/dnnl_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_stack.cc
+ */
+
+#include "./dnnl_base-inl.h"
+#include "./dnnl_concat-inl.h"
+#include "./dnnl_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)
+    return false;
+
+  int src_dtype = inputs[0].dtype();
+  for (auto& arr : inputs) {
+    if (arr.dtype() != src_dtype) {
+      return false;
+    }
+    // DO not support zero-size tensors.
+    if (arr.shape().Size() == 0)
+      return false;
+    int ndim = arr.shape().ndim();
+    if (ndim <= 0)
+      return false;
+  }
+  return true;
+}
+
+void DNNLStackForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& in_data,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& out_data) {
+  TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
+
+  const StackParam& param = dmlc::get<StackParam>(attrs.parsed);
+  const int axis          = CheckAxis(param.axis, out_data[0].shape().ndim());
+  const auto oshape       = out_data[0].shape();
+  const int src_dtype     = in_data[0].dtype();
+  const int dst_dtype     = out_data[0].dtype();
+  int leading             = 1;
+  int trailing            = 1;
+
+  for (int i = 0; i < axis; ++i) {
+    leading *= oshape[i];
+  }
+  for (int i = axis + 1; i < oshape.ndim(); ++i) {
+    trailing *= oshape[i];
+  }
+  int mid = oshape[axis];
+
+  std::vector<dnnl::memory::desc> data_md;
+  std::vector<dnnl::memory> data_mem;
+  dnnl::memory::desc in_md(
+      {leading, 1, trailing}, get_dnnl_type(src_dtype), 
dnnl::memory::format_tag::abc);
+  dnnl::memory::desc out_md(
+      {leading, mid, trailing}, get_dnnl_type(dst_dtype), 
dnnl::memory::format_tag::any);
+
+  const int num_in_data = in_data.size();
+  data_md.reserve(num_in_data);
+  data_mem.reserve(num_in_data);
+
+  MSHADOW_TYPE_SWITCH(src_dtype, DType, {
+    for (int i = 0; i < num_in_data; i++) {
+      NDArray tmp = in_data[i].IsDNNLData() ? in_data[i].Reorder2Default() : 
in_data[i];
+      dnnl::memory tmp_mem(in_md, CpuEngine::Get()->get_engine(), 
tmp.data().dptr<DType>());
+      data_mem.emplace_back(tmp_mem);
+      data_md.emplace_back(in_md);
+    }
+  });
+
+  auto& fwd = GetConcatForward(1, in_data, data_md, axis);

Review comment:
       Please assign "1" to a named variable to improve the readability of this line.

##########
File path: src/operator/nn/dnnl/dnnl_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_stack.cc
+ */
+
+#include "./dnnl_base-inl.h"
+#include "./dnnl_concat-inl.h"
+#include "./dnnl_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)
+    return false;
+
+  int src_dtype = inputs[0].dtype();
+  for (auto& arr : inputs) {
+    if (arr.dtype() != src_dtype) {
+      return false;
+    }
+    // DO not support zero-size tensors.
+    if (arr.shape().Size() == 0)
+      return false;
+    int ndim = arr.shape().ndim();
+    if (ndim <= 0)
+      return false;
+  }
+  return true;
+}
+
+void DNNLStackForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& in_data,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& out_data) {
+  TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
+
+  const StackParam& param = dmlc::get<StackParam>(attrs.parsed);
+  const int axis          = CheckAxis(param.axis, out_data[0].shape().ndim());
+  const auto oshape       = out_data[0].shape();

Review comment:
     Might the shape of this out_data be volatile? If not, please use
the explicit type instead of auto.

##########
File path: src/operator/nn/dnnl/dnnl_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_stack.cc
+ */
+
+#include "./dnnl_base-inl.h"
+#include "./dnnl_concat-inl.h"
+#include "./dnnl_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)
+    return false;
+
+  int src_dtype = inputs[0].dtype();
+  for (auto& arr : inputs) {
+    if (arr.dtype() != src_dtype) {
+      return false;
+    }
+    // DO not support zero-size tensors.
+    if (arr.shape().Size() == 0)
+      return false;
+    int ndim = arr.shape().ndim();
+    if (ndim <= 0)
+      return false;
+  }
+  return true;
+}
+
+void DNNLStackForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& in_data,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& out_data) {
+  TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
+
+  const StackParam& param = dmlc::get<StackParam>(attrs.parsed);
+  const int axis          = CheckAxis(param.axis, out_data[0].shape().ndim());
+  const auto oshape       = out_data[0].shape();
+  const int src_dtype     = in_data[0].dtype();
+  const int dst_dtype     = out_data[0].dtype();
+  int leading             = 1;
+  int trailing            = 1;
+
+  for (int i = 0; i < axis; ++i) {
+    leading *= oshape[i];
+  }
+  for (int i = axis + 1; i < oshape.ndim(); ++i) {
+    trailing *= oshape[i];
+  }
+  int mid = oshape[axis];

Review comment:
     Does "mid" refer to the middle of the axis? Also, please declare it at the
beginning of this function (line 67).

##########
File path: src/operator/tensor/matrix_op.cc
##########
@@ -930,7 +930,38 @@ NNVM_REGISTER_OP(_backward_reverse)
                                 })
     .set_attr<FCompute>("FCompute<cpu>", ReverseOpForward<cpu>);
 
+#if MXNET_USE_ONEDNN == 1
+static void StackForwardEx(const nnvm::NodeAttrs& attrs,
+                           const OpContext& op_ctx,
+                           const std::vector<NDArray>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<NDArray>& outputs) {
+  CHECK(!inputs.empty());
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp)
+    return;
+
+  if (SupportDNNLStack(inputs)) {
+    DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    DNNLRun(DNNLStackForward, attrs, op_ctx, inputs, req, outputs);
+    DNNL_OPCHECK_RUN(StackOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
+  } else {
+    FallBackCompute(StackOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
+  }
+}
+
+inline static bool StackInferStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int>* in_attrs,
+                                         std::vector<int>* out_attrs) {
+  return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, 
out_attrs);

Review comment:
       ```suggestion
     return DNNLStorageType(attrs, dev_mask, /*support_dnnl*/ true, 
dispatch_mode, in_attrs, out_attrs);
   ```

##########
File path: src/operator/nn/dnnl/dnnl_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_stack.cc
+ */
+
+#include "./dnnl_base-inl.h"
+#include "./dnnl_concat-inl.h"
+#include "./dnnl_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)
+    return false;
+
+  int src_dtype = inputs[0].dtype();
+  for (auto& arr : inputs) {
+    if (arr.dtype() != src_dtype) {
+      return false;
+    }
+    // DO not support zero-size tensors.
+    if (arr.shape().Size() == 0)
+      return false;
+    int ndim = arr.shape().ndim();
+    if (ndim <= 0)
+      return false;
+  }
+  return true;
+}
+
+void DNNLStackForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& in_data,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& out_data) {
+  TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
+
+  const StackParam& param = dmlc::get<StackParam>(attrs.parsed);
+  const int axis          = CheckAxis(param.axis, out_data[0].shape().ndim());
+  const auto oshape       = out_data[0].shape();
+  const int src_dtype     = in_data[0].dtype();
+  const int dst_dtype     = out_data[0].dtype();
+  int leading             = 1;
+  int trailing            = 1;
+
+  for (int i = 0; i < axis; ++i) {
+    leading *= oshape[i];
+  }
+  for (int i = axis + 1; i < oshape.ndim(); ++i) {
+    trailing *= oshape[i];
+  }
+  int mid = oshape[axis];
+
+  std::vector<dnnl::memory::desc> data_md;
+  std::vector<dnnl::memory> data_mem;
+  dnnl::memory::desc in_md(
+      {leading, 1, trailing}, get_dnnl_type(src_dtype), 
dnnl::memory::format_tag::abc);

Review comment:
       1 -> please give this value a name, e.g. "mid_in" or similar.

##########
File path: src/operator/nn/dnnl/dnnl_concat-inl.h
##########
@@ -52,13 +52,18 @@ class DNNLConcatFwd {
 
 static DNNLConcatFwd& GetConcatForward(int concat_dim,
                                        const std::vector<NDArray>& in_data,
-                                       const std::vector<dnnl::memory::desc>& 
data_md) {
+                                       const std::vector<dnnl::memory::desc>& 
data_md,
+                                       int cache_dim = -1) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<OpSignature, DNNLConcatFwd, OpHash> 
fwds;
 #else
   static MX_THREAD_LOCAL std::unordered_map<OpSignature, DNNLConcatFwd, 
OpHash> fwds;
 #endif
+  if (cache_dim == -1) {
+    cache_dim = concat_dim;

Review comment:
     Could you please tell me what the advantage is of assigning the same value
here (i.e. cache_dim = concat_dim when cache_dim == -1)?

##########
File path: src/operator/tensor/matrix_op.cc
##########
@@ -930,7 +930,38 @@ NNVM_REGISTER_OP(_backward_reverse)
                                 })
     .set_attr<FCompute>("FCompute<cpu>", ReverseOpForward<cpu>);
 
+#if MXNET_USE_ONEDNN == 1
+static void StackForwardEx(const nnvm::NodeAttrs& attrs,
+                           const OpContext& op_ctx,
+                           const std::vector<NDArray>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<NDArray>& outputs) {
+  CHECK(!inputs.empty());
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp)
+    return;
+
+  if (SupportDNNLStack(inputs)) {
+    DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);

Review comment:
       It would be better here to state what the "false" argument means.
   ```suggestion
       DNNL_OPCHECK_INIT(/*is backward*/ false, outputs.size(), inputs, 
outputs);
   ```

##########
File path: src/operator/nn/mkldnn/mkldnn_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_stack.cc
+ */
+
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_concat-inl.h"
+#include "./mkldnn_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportMKLDNNStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)
+    return false;
+
+  int src_dtype = inputs[0].dtype();
+  for (auto& arr : inputs) {
+    if (arr.dtype() != src_dtype) {
+      return false;
+    }
+    // DO not support zero-size tensors.
+    if (arr.shape().Size() == 0)
+      return false;
+    int ndim = arr.shape().ndim();
+    if (ndim <= 0)
+      return false;
+  }
+  return true;
+}
+
+void MKLDNNStackForward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<NDArray>& in_data,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<NDArray>& out_data) {
+  TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
+
+  const StackParam& param = dmlc::get<StackParam>(attrs.parsed);
+  const int axis          = CheckAxis(param.axis, out_data[0].shape().ndim());
+  const auto oshape       = out_data[0].shape();
+  const int src_dtype     = in_data[0].dtype();
+  const int dst_dtype     = out_data[0].dtype();
+  int leading             = 1;

Review comment:
       What do the names leading/trailing symbolize?

##########
File path: src/operator/nn/dnnl/dnnl_stack.cc
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_stack.cc
+ */
+
+#include "./dnnl_base-inl.h"
+#include "./dnnl_concat-inl.h"
+#include "./dnnl_ops-inl.h"
+
+#include "../../tensor/matrix_op-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
+  if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != 
mshadow::kBfloat16)
+    return false;
+
+  int src_dtype = inputs[0].dtype();
+  for (auto& arr : inputs) {

Review comment:
       const auto&




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to