apeforest commented on a change in pull request #14779: Fully connected, higher order grad URL: https://github.com/apache/incubator-mxnet/pull/14779#discussion_r307593378
########## File path: src/operator/nn/fully_connected-inl.h ########## @@ -249,6 +285,114 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs, } } + + +/// +// Inputs are: +// o_x_grad : head gradient for x_grad +// o_w_grad : head gradient for w_grad +// o_b_grad : if param.no_bias is false +// o_y : head gradient of y +// +// outputs are: +// o_y_grad : gradient of o_y +// x_grad_grad : o_y * o_w_grad +// w_grad_grad : o_y.T * o_x_grad +// b_grad_grad: if param.no_bias is false +// +// For implementation details see this PR: https://github.com/apache/incubator-mxnet/pull/14779 + +/** + * Second order gradient for Fully Connected + * x_grad_grad = o_y * o_w_grad + * w_grad_grad = o_y.T * o_x_grad + * + * @tparam xpu + * @tparam DType + * @param attrs + * @param ctx + * @param inputs + * @param req + * @param outputs + */ +template<typename xpu, typename DType> +void FullyConnectedGradGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + using namespace std; + using namespace fullc; + Stream<xpu> *stream = ctx.get_stream<xpu>(); + const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed); + const size_t num_inputs = param.no_bias ? 
3U : 4U; + // outputs are: o_x_grad, o_w_grad, o_y || o_x_grad, o_w_grad, o_b_grad, o_y + const size_t num_outputs = 3U; + CHECK_EQ(inputs.size(), num_inputs); + CHECK_EQ(outputs.size(), num_outputs); + CHECK_EQ(req.size(), num_outputs); + + // inputs + Tensor<xpu, 2, DType> o_x_grad; + Tensor<xpu, 2, DType> o_w_grad; + Tensor<xpu, 2, DType> o_y; + // unused + // Tensor<xpu, 1, DType> o_b_grad; + + // outputs + Tensor<xpu, 2, DType> o_y_grad; + TBlob o_y_grad_blob = outputs[kOyGrad]; + Tensor<xpu, 2, DType> x_grad_grad; + Tensor<xpu, 2, DType> w_grad_grad; + Tensor<xpu, 1, DType> b_grad_grad; + size_t o_y_idx = std::numeric_limits<size_t>::max(); + if (param.no_bias) + o_y_idx = kOy; + else + o_y_idx = kOyBias; + if (!param.flatten) { + o_x_grad = FlattenAs2DHead<xpu, DType>(inputs[kOxGrad], ctx); + o_w_grad = inputs[kOwGrad].get<xpu, 2, DType>(stream); + o_y = FlattenAs2DHead<xpu, DType>(inputs[o_y_idx], ctx); + x_grad_grad = FlattenAs2DHead<xpu, DType>(outputs[kXGradGrad], ctx); + w_grad_grad = FlattenAs2DHead<xpu, DType>(outputs[kWGradGrad], ctx); + } else { + o_x_grad = FlattenAs2DTail<xpu, DType>(inputs[kOxGrad], ctx); + o_w_grad = FlattenAs2DTail<xpu, DType>(inputs[kOwGrad], ctx); + o_y = inputs[o_y_idx].get<xpu, 2, DType>(stream); + x_grad_grad = FlattenAs2DTail<xpu, DType>(outputs[kXGradGrad], ctx); + w_grad_grad = FlattenAs2DTail<xpu, DType>(outputs[kWGradGrad], ctx); + } + linalg_gemm(o_y, o_w_grad, x_grad_grad, false, false, stream); + linalg_gemm(o_y, o_x_grad, w_grad_grad, true, false, stream); + // 3rd order not supported + Fill(stream, o_y_grad_blob, kWriteTo, static_cast<DType>(0)); + /* TODO(larroy) bias is not supported yet as there's no bias input to backward. Bias grad grad is Review comment: @sxjscience Could you please review if this is correct? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at: us...@infra.apache.org

With regards,
Apache Git Services