samskalicky commented on a change in pull request #12430: [MXNET-882] Support
for N-d arrays added to diag op.
URL: https://github.com/apache/incubator-mxnet/pull/12430#discussion_r215712026
##########
File path: src/operator/tensor/diag_op-inl.h
##########
@@ -159,21 +274,7 @@ void DiagOpForward(const nnvm::NodeAttrs& attrs,
const TShape& oshape = outputs[0].shape_;
const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
- if (ishape.ndim() == 2) {
- MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
- MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
- Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
- in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
- });
- });
- } else {
- MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
- MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
- Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
- in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
- });
- });
- }
+ DiagOpProcess<xpu, false>(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req);
Review comment:
I like that you put all the diag setup work in the DiagOpProcess function,
but now DiagOpForward doesn't actually do anything except call
DiagOpProcess, which seems like an unnecessary extra function call. What do
you think about inlining DiagOpProcess into DiagOpForward? Or is there some
benefit to doing it this way that I'm missing?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services