apeforest commented on a change in pull request #14779: [WIP] Fully connected, higher order grad URL: https://github.com/apache/incubator-mxnet/pull/14779#discussion_r305545160
########## File path: src/operator/nn/fully_connected-inl.h ########## @@ -249,6 +285,99 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs, } } + + +/// +// Inputs are: +// o_x_grad : head gradient for x_grad +// o_w_grad : head gradient for w_grad +// o_b_grad : if param.no_bias is false +// o_y : head gradient of y +// +// outputs are: +// o_y_grad : gradient of o_y +// x_grad_grad : o_y * o_w_grad +// w_grad_grad : o_y.T * o_x_grad +// +// For implementation details see this PR: https://github.com/apache/incubator-mxnet/pull/14779 + +/** + * Second order gradient for Fully Connected + * x_grad_grad = o_y * o_w_grad + * w_grad_grad = o_y.T * o_x_grad + * + * @tparam xpu + * @tparam DType + * @param attrs + * @param ctx + * @param inputs + * @param req + * @param outputs + */ +template<typename xpu, typename DType> +void FullyConnectedGradGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + using namespace std; + using namespace fullc; + Stream<xpu> *stream = ctx.get_stream<xpu>(); + const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed); + const size_t num_inputs = param.no_bias ? 
3U : 4U; + CHECK_EQ(inputs.size(), num_inputs); // o_x_grad, o_w_grad, o_y + CHECK_EQ(outputs.size(), 3U); + CHECK_EQ(req.size(), 3U); + + // inputs + Tensor<xpu, 2, DType> o_x_grad; + Tensor<xpu, 2, DType> o_w_grad; + Tensor<xpu, 2, DType> o_y; + Tensor<xpu, 2, DType> o_b_grad; + + // outputs + Tensor<xpu, 2, DType> o_y_grad; + Tensor<xpu, 2, DType> x_grad_grad; + Tensor<xpu, 2, DType> w_grad_grad; + size_t o_y_idx = std::numeric_limits<size_t>::max(); + if (param.no_bias) + o_y_idx = k_o_y; + else + o_y_idx = k_o_y_bias; + if (!param.flatten) { + o_x_grad = FlattenAs2DHead<xpu, DType>(inputs[k_o_x_grad], ctx); + o_w_grad = inputs[k_o_w_grad].get<xpu, 2, DType>(stream); + o_y = FlattenAs2DHead<xpu, DType>(inputs[o_y_idx], ctx); + } else { + o_x_grad = FlattenAs2DTail<xpu, DType>(inputs[k_o_x_grad], ctx); + o_w_grad = FlattenAs2DTail<xpu, DType>(inputs[k_o_w_grad], ctx); + o_y = inputs[o_y_idx].get<xpu, 2, DType>(stream); + o_y_grad = outputs[k_o_y_grad].get<xpu, 2, DType>(stream); + x_grad_grad = FlattenAs2DTail<xpu, DType>(outputs[k_x_grad_grad], ctx); + w_grad_grad = FlattenAs2DTail<xpu, DType>(outputs[k_w_grad_grad], ctx); + } + linalg_gemm(o_y, o_w_grad, x_grad_grad, false, false, stream); + linalg_gemm(o_y, o_x_grad, w_grad_grad, true, false, stream); + if (! param.no_bias) { + // TODO(larroy) + } +} + + +template<typename xpu> +void FullyConnectedGradGradDtypeDispatch( Review comment: nit: DType? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services
