cjolivier01 commented on a change in pull request #9882: Add force_deterministic option for sparse embedding
URL: https://github.com/apache/incubator-mxnet/pull/9882#discussion_r170700191
 
 

 ##########
 File path: src/operator/tensor/indexing_op.cu
 ##########
 @@ -103,13 +162,125 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
   }
 }
 
+inline void SparseEmbeddingOpBackwardDeterministicRspImpl(const OpContext& ctx,
+                                                          const TBlob& ograd,
+                                                          const TBlob& data,
+                                                          const OpReqType req,
+                                                          const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace expr;
+  using namespace rowsparse;
+  using nnvm::dim_t;
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
+                          << "weight gradient calculation with req != write";
+
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  dim_t num_rows = output.shape()[0];
+  dim_t row_length = output.shape()[1];
+  dim_t data_size = static_cast<dim_t>(data.shape_.Size());
+  if (data_size == 0) {
+    FillZerosRspImpl(s, output);
+    return;
+  }
+
+  MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, {
+        // temp resource declarations
+        dim_t* lookup_table = NULL;
 
 Review comment:
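   Can this huge chunk of code be pulled out into a template function so that
it's steppable in the debugger? (Presumably the concern is that a body passed
as an argument to the nested MSHADOW_TYPE_SWITCH macros is expanded by the
preprocessor onto a single source line, so a debugger cannot step through it
statement by statement.)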

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to