samskalicky commented on a change in pull request #12430: [MXNET-882] Support for N-d arrays added to diag op.
URL: https://github.com/apache/incubator-mxnet/pull/12430#discussion_r215712026
 
 

 ##########
 File path: src/operator/tensor/diag_op-inl.h
 ##########
 @@ -159,21 +274,7 @@ void DiagOpForward(const nnvm::NodeAttrs& attrs,
   const TShape& oshape = outputs[0].shape_;
   const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
 
-  if (ishape.ndim() == 2) {
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
-                            in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
-      });
-    });
-  } else {
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
-                            in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
-      });
-    });
-  }
+  DiagOpProcess<xpu, false>(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req);
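For context on what the two removed branches compute: judging by the kernel names and the shapes passed in, the 2-D branch (`diag`) reads the k-th diagonal out of an `ishape[0] x ishape[1]` matrix, and the other branch (`diag_gen`) writes a 1-D input onto the k-th diagonal of an `oshape[0] x oshape[1]` matrix, the same semantics as NumPy's `diag`. Below is a minimal serial sketch of that behaviour for float data, assuming row-major storage. `ExtractDiag` and `MakeDiag` are hypothetical names, this is not the MXNet kernel code, and it covers only the 1-D/2-D cases the removed code handled, not the N-d support the PR adds.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <vector>

// Hypothetical serial stand-ins for the removed 'diag' / 'diag_gen' branches.
// Each loop body computes one output element from its flat index i, the way
// the body of a Kernel<...>::Launch(s, out_size, ...) call would.

// 2-D input (rows x cols, row-major) -> 1-D output holding diagonal k.
std::vector<float> ExtractDiag(const std::vector<float>& in, int rows, int cols, int k) {
  // Diagonal k starts at (0, k) when k >= 0, or at (-k, 0) when k < 0.
  int r0 = k < 0 ? -k : 0;
  int c0 = k > 0 ? k : 0;
  int len = std::max(0, std::min(rows - r0, cols - c0));
  std::vector<float> out(static_cast<std::size_t>(len));
  for (int i = 0; i < len; ++i) {
    out[i] = in[static_cast<std::size_t>(r0 + i) * cols + (c0 + i)];
  }
  return out;
}

// 1-D input of length n -> square (n + |k|) x (n + |k|) output, zero except
// on diagonal k, which holds the input values.
std::vector<float> MakeDiag(const std::vector<float>& in, int k) {
  int n = static_cast<int>(in.size());
  int m = n + std::abs(k);
  std::vector<float> out(static_cast<std::size_t>(m) * m);
  for (int i = 0; i < m * m; ++i) {
    int r = i / m, c = i % m;
    // On diagonal k, the input index is the row (k >= 0) or the column (k < 0).
    out[i] = (c - r == k) ? in[k >= 0 ? r : c] : 0.0f;
  }
  return out;
}
```

For a 3x3 input holding 0..8, `ExtractDiag(in, 3, 3, 1)` yields `{1, 5}` and `ExtractDiag(in, 3, 3, -1)` yields `{3, 7}`; `MakeDiag({1, 2}, 1)` yields a 3x3 matrix with 1 at (0, 1) and 2 at (1, 2).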
 
 Review comment:
  I like that you put all the diag setup work in the DiagOpProcess function,
but now DiagOpForward is left as a function that doesn't actually do anything
except call DiagOpProcess, which looks like an unnecessary extra function
call. What do you think about inlining DiagOpProcess into DiagOpForward? Or is
there some benefit to doing it this way that I'm missing?
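One possible benefit, sketched under assumptions: the `false` template argument in the diff suggests DiagOpProcess is shared with a backward pass that passes `true` (an assumption, not confirmed by this hunk). In that case the thin forward and backward entry points keep the fixed signatures the operator registry expects, while the shared logic lives in one templated helper. The self-contained sketch below illustrates that pattern with hypothetical names (`DiagProcess`, `DiagForward`, `DiagBackward`) and plain `std::vector` buffers instead of MXNet types; for brevity it fixes the forward direction to "extract the main diagonal of an n x n matrix", whereas the real op picks the direction from the input rank.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical sketch, not MXNet code: one templated helper holds the shared
// logic, and 'back' selects the direction at compile time.
//   back == false: extract the main diagonal of an n x n matrix (forward).
//   back == true:  scatter a length-n gradient onto the diagonal of an
//                  n x n zero matrix (backward of the extraction).
template <bool back>
void DiagProcess(const std::vector<float>& src, std::vector<float>& dst, std::size_t n) {
  if (!back) {
    dst.assign(n, 0.0f);
    for (std::size_t i = 0; i < n; ++i) dst[i] = src[i * n + i];
  } else {
    dst.assign(n * n, 0.0f);
    for (std::size_t i = 0; i < n; ++i) dst[i * n + i] = src[i];
  }
}

// Thin entry points with fixed signatures that a registry could point at.
// The helper is a template defined in the same header, so an optimizing
// compiler can usually inline these forwarding calls.
void DiagForward(const std::vector<float>& in, std::vector<float>& out, std::size_t n) {
  DiagProcess<false>(in, out, n);
}

void DiagBackward(const std::vector<float>& grad_out, std::vector<float>& grad_in, std::size_t n) {
  DiagProcess<true>(grad_out, grad_in, n);
}
```

With this layout, inlining the helper into the forward function would mean duplicating its body in the backward function, which is the duplication the diff removes; if no backward variant shares the helper, the reviewer's suggestion to inline it costs nothing.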

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure.


With regards,
Apache Git Services
