cjolivier01 commented on a change in pull request #9882: Add force_deterministic option for sparse embedding
URL: https://github.com/apache/incubator-mxnet/pull/9882#discussion_r170700191

##########
File path: src/operator/tensor/indexing_op.cu
##########
@@ -103,13 +162,125 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
   }
 }
 
+inline void SparseEmbeddingOpBackwardDeterministicRspImpl(const OpContext& ctx,
+                                                          const TBlob& ograd,
+                                                          const TBlob& data,
+                                                          const OpReqType req,
+                                                          const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace expr;
+  using namespace rowsparse;
+  using nnvm::dim_t;
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
+                          << "weight gradient calculation with req != write";
+
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  dim_t num_rows = output.shape()[0];
+  dim_t row_length = output.shape()[1];
+  dim_t data_size = static_cast<dim_t>(data.shape_.Size());
+  if (data_size == 0) {
+    FillZerosRspImpl(s, output);
+    return;
+  }
+
+  MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, {
+        // temp resource declarations
+        dim_t* lookup_table = NULL;

Review comment:
Can this huge chunk of code be pulled out into a template function so that it's steppable in the debugger?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

With regards,
Apache Git Services