sxjscience commented on a change in pull request #19460:
URL: https://github.com/apache/incubator-mxnet/pull/19460#discussion_r530527126
##########
File path: src/operator/nn/softmax-inl.h
##########
@@ -907,6 +1632,56 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
});
}
+template<typename xpu, typename OP1, typename OP2, bool negate = false>
+void MaskedSoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mxnet_op;
+ // set zeros in mask gradient
+ if (req[1] != kNullOp) {
+ mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+ ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<bool>());
+ }
+ if (req[0] == kNullOp) return;
+ const MaskedSoftmaxParam& param =
nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
+ int axis = CheckAxis(param.axis, inputs[0].ndim());
+ const double scale = param.scale_factor.has_value() ?
+ param.scale_factor.value() : 1.0;
+ const double temperature = param.temperature.has_value() ?
+ param.temperature.value() : 1.0;
+
+ bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
+ MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+ MXNET_NDIM_SWITCH(inputs[0].ndim(), ndim, {
+ OType* ograd_ptr = inputs[0].dptr<OType>();
+ OType* out_ptr = inputs[3].dptr<OType>();
+ bool* mask_ptr = inputs[2].dptr<bool>();
+ DType* grad_data = outputs[0].dptr<DType>();
+ if (safe_acc) {
+ MaskedSoftmaxGrad<OP1, OP2, Req, negate, AType>(
+ ctx.get_stream<xpu>(), out_ptr,
+ ograd_ptr, grad_data, mask_ptr,
+ inputs[1].shape_.get<ndim>(), inputs[2].shape_.get<ndim>(),
+ axis, static_cast<DType>(scale),
+ static_cast<DType>(temperature), ctx);
+ } else {
+ MaskedSoftmaxGrad<OP1, OP2, Req, negate, DType>(
+ ctx.get_stream<xpu>(), out_ptr,
+ ograd_ptr, grad_data, mask_ptr,
+ inputs[1].shape_.get<ndim>(), inputs[2].shape_.get<ndim>(),
Review comment:
`inputs[1]` is only used to provide the shape, so it can be safely
removed.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]