zheng-da commented on a change in pull request #14090: Add GPU version of contrib.boolean_mask
URL: https://github.com/apache/incubator-mxnet/pull/14090#discussion_r255791426
 
 

 ##########
 File path: src/operator/contrib/boolean_mask.cc
 ##########
 @@ -75,9 +74,118 @@ bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& 
attrs,
   return true;
 }
 
+// Per-row CPU kernel for the boolean_mask forward pass.
+// Map(i, ...) processes input row i. It compares idx[i] with the preceding
+// entry (idx[i - 1], or 0 when i == 0). When the two differ, it copies the
+// col_size elements of row i of `data` into row idx[i - 1] (or row 0) of `out`.
+// NOTE(review): idx is presumably the inclusive prefix sum of the 0/1 mask,
+// so a copy happens exactly for the selected rows and the output rows stay
+// contiguous. Confirm this against the caller that builds idx.
+struct BooleanMaskForwardCPUKernel {
+  template<typename DType>
+  static void Map(int i,
+                  DType* out,
+                  const DType* data,
+                  const int32_t* idx,
+                  const size_t col_size) {
+    // i is row id already
+    // Destination row in `out` is the count stored just before this row.
+    int32_t prev = (i == 0) ? 0 : idx[i - 1];
+    int32_t curr = idx[i];
+    // idx advanced at row i, so row i is copied; otherwise it is skipped.
+    if (prev != curr) {
+      std::memcpy(out + prev * col_size, data + i * col_size, col_size * 
sizeof(DType));
+    }
+  }
+};
+
+// Per-row CPU kernel for the boolean_mask backward pass, mirroring the forward kernel.
+// Map(i, ...) processes input row i. When idx[i] differs from the preceding
+// entry (idx[i - 1], or 0 when i == 0), it copies row idx[i - 1] (or row 0)
+// of `ograd` (col_size elements) into row i of `igrad`.
+// Rows where idx does not advance are not written here.
+// NOTE(review): the caller presumably zero-fills igrad beforehand. Confirm this.
+struct BooleanMaskBackwardCPUKernel {
+  template<typename DType>
+  static void Map(int i,
+                  DType* igrad,
+                  const DType* ograd,
+                  const int32_t* idx,
+                  const size_t col_size) {
+    // i is row id already
+    // Source row in `ograd` is the count stored just before this row.
+    int32_t prev = (i == 0) ? 0 : idx[i - 1];
+    int32_t curr = idx[i];
+    // Only rows at which idx advances receive a gradient from ograd.
+    if (prev != curr) {
+      std::memcpy(igrad + i * col_size, ograd + prev * col_size, col_size * 
sizeof(DType));
+    }
+  }
+};
+
+template<>
+inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<NDArray> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<NDArray> &outputs) {
+  // TODO(@junrushao1994): This implementation is a proof-of-concept,
+  // hence very slow actually. Performance should be improved in the future.
 
 Review comment:
   Please remove the TODO comment here.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to