TaoLv commented on a change in pull request #17170: add mkldnn softmax backward 
URL: https://github.com/apache/incubator-mxnet/pull/17170#discussion_r361568985
 
 

 ##########
 File path: src/operator/nn/mkldnn/mkldnn_softmax.cc
 ##########
 @@ -86,6 +98,33 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs,
   stream->Submit();
 }
 
+void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs,
+                          const OpContext &ctx,
+                          const std::vector<NDArray> &in_data,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray> &out_data) {
+  if (req[0] == kNullOp) return;
+  CHECK_EQ(in_data.size(), 2U);
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  int axis = CheckAxis(param.axis, in_data[1].shape().ndim());
+  auto diff_mem = in_data[0].GetMKLDNNData();
+  auto data_mem = in_data[1].GetMKLDNNData();
+  auto fwd_pd = GetSoftmaxFwdPd(ctx.is_train, axis, *data_mem);
+  auto bwd_pd = GetSoftmaxBwdPd(*diff_mem, *data_mem, axis, fwd_pd);
+
+  auto out_mem = CreateMKLDNNMem(out_data[0], bwd_pd.diff_src_desc(), req[0]);
+  MKLDNNStream *stream = MKLDNNStream::Get();
+  mkldnn_args_map_t args = {
+    { MKLDNN_ARG_DST, *data_mem },
+    { MKLDNN_ARG_DIFF_DST, *diff_mem },
+    { MKLDNN_ARG_DIFF_SRC, *out_mem.second },
+  };
+
+  stream->RegisterPrimArgs(bwd_pd, args);
 
 Review comment:
  I mean that here you need to pass a primitive, not a primitive descriptor
(e.g. construct `mkldnn::softmax_backward(bwd_pd)` and pass that). Please
check the definition of RegisterPrimArgs.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to