zheng-da commented on a change in pull request #8302: Refactor operators & MKLDNN
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r162501065
 
 

 ##########
 File path: src/operator/tensor/matrix_op.cc
 ##########
 @@ -180,6 +182,51 @@ If the argument `reverse` is set to 1, then the special values are inferred from
 .add_argument("data", "NDArray-or-Symbol", "Input data to reshape.")
 .add_arguments(ReshapeParam::__FIELDS__());
 
+static void FlattenEx(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+#if MXNET_USE_MKLDNN == 1
+  const auto in_stype = inputs[0].storage_type();
+  const auto out_stype = outputs[0].storage_type();
+  if (inputs[0].IsMKLDNNData()) {
+    MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]);
+    // If the output is a special MKLDNN layout and the number of dimensions
+    // is larger than 2, we should use the default layout.
+    if (outputs[0].IsMKLDNNData() && inputs[0].shape().ndim() > 2)
+      const_cast<NDArray &>(outputs[0]).Reorder2Default();
+    return;
+  } else {
+    // This happens if inputs are supposed to be in MKLDNN format
+    // but MKLDNN doesn't support the data type or the shape. We're
+    // forced to convert it to the default format.
+    FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
 
 Review comment:
   `FallBackCompute` is defined in src/operator/nn/mkldnn/mkldnn_base-inl.h.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
