apeforest commented on a change in pull request #16735: Use single-bit for mask in dropout operator
URL: https://github.com/apache/incubator-mxnet/pull/16735#discussion_r360796360
 
 

 ##########
 File path: src/operator/nn/dropout-inl.h
 ##########
 @@ -187,28 +188,101 @@ class DropoutOp {
                                     const index_t N,
                                     const index_t step,
                                     DType *dropout_out,
-                                    DType *mask_out,
+                                    uint8_t *mask_out,
                                     const DType *input_data,
                                     const real_t pkeep) {
       RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
         const real_t rand_num = static_cast<real_t>(genImpl.uniform());
-        mask_out[i] = mshadow_op::threshold_eq::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep);
-        dropout_out[i] = input_data[i] * mask_out[i];
-      });
+        // mask_out is set per bit position
+        // therefore bitwise shift need to be performed here
+        auto maskIdx = i / 8;
+        auto maskOffset = i % 8;
+        bool maskVal = mshadow_op::threshold_eq::Map<real_t>(rand_num, pkeep);
+        if (maskVal) {
+          // set bit
+          mask_out[maskIdx] |= 1U << maskOffset;
+        } else {
+          // clear bit
+          mask_out[maskIdx] &= ~(1U << maskOffset);
+        }
+
+        // TODO (lnyuan): seems we can set dropout to zero if maskVal is False
+        // however doing this would break one unit test when pkeep is 0, expecting nan
+        // not sure why
+        dropout_out[i] = maskVal * input_data[i] * (1.0f / pkeep);
+      })
+    }
+  };
+
+  struct DropoutBackwardKernel {
+    MSHADOW_XINLINE static void Map(index_t i,
+                                    OpReqType req,
+                                    DType *igrad,
+                                    DType *ograd,
+                                    const uint8_t *mask,
+                                    const real_t pkeep) {
+      auto maskIdx = i / 8;
+      uint8_t maskOffset = i % 8;
+      bool maskVal = (mask[maskIdx] >> maskOffset) & 1U;
+      KERNEL_ASSIGN(igrad[i], req, maskVal * ograd[i] * (1 / pkeep));
     }
   };
+
   struct BernoulliKernel {
     /*! \brief Bernoulli kernel for generating mask */
     MSHADOW_XINLINE static void Map(index_t id,
                                     RandGenerator<xpu, DType> gen,
                                     const index_t N,
                                     const index_t step,
-                                    DType *mask_out,
+                                    DType *dropout_out,
+                                    uint8_t *mask_out,
                                     const real_t pkeep) {
       RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
         const real_t rand_num = static_cast<real_t>(genImpl.uniform());
-        mask_out[i] = mshadow_op::threshold::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep);
-      });
+        // mask_out is set per bit position
+        // therefore bitwise shift need to be performed here
+        auto maskIdx = i / 8;
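
   (Editor's note: for readers skimming the diff, the bit-packing pattern it introduces — one mask bit per element, eight elements per `uint8_t` — can be summarized in a minimal standalone C++ sketch. This is an illustration only, not the MXNet kernels; the helper names `set_mask_bit`/`get_mask_bit` are invented for clarity.)

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Write the keep/drop decision for element i into a packed mask
// (one bit per element instead of one DType per element).
inline void set_mask_bit(uint8_t *mask, std::size_t i, bool keep) {
  const std::size_t byte_idx = i / 8;  // byte that holds element i
  const unsigned bit_offset = i % 8;   // bit position inside that byte
  if (keep) {
    mask[byte_idx] |= static_cast<uint8_t>(1U << bit_offset);     // set bit
  } else {
    mask[byte_idx] &= static_cast<uint8_t>(~(1U << bit_offset));  // clear bit
  }
}

// Read the decision back, as the backward kernel does.
inline bool get_mask_bit(const uint8_t *mask, std::size_t i) {
  return (mask[i / 8] >> (i % 8)) & 1U;
}

int main() {
  const std::size_t n = 20;
  std::vector<uint8_t> mask((n + 7) / 8, 0);     // ceil(n / 8) bytes
  set_mask_bit(mask.data(), 3, true);
  const bool kept = get_mask_bit(mask.data(), 3);      // expect true
  set_mask_bit(mask.data(), 3, false);
  const bool dropped = !get_mask_bit(mask.data(), 3);  // expect true
  return (kept && dropped) ? 0 : 1;
}
```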
 
 Review comment:
   After more careful checking, the race condition actually cannot happen in this implementation. The reason is that the `step` value passed to `Kernel<OP, xpu>::Launch()` is always a multiple of 8, so each thread sequentially processes a contiguous multiple of 8 elements in `RNG_KERNEL_LOOP`, and no two threads ever write to the same byte of the packed mask.
   
   Nonetheless, this invariant is not obvious from the implementation and depends on RandGenerator<xpu>::kMinNumRandomPerThread (currently 64). I will try to make it more explicit in the code to improve readability and make it less bug-prone.
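
   (Editor's note: a minimal sketch of that argument, assuming each thread owns a contiguous range of `step` elements; the constants below are illustrative, not taken from the PR.)

```cpp
#include <cassert>
#include <cstddef>

// If every thread owns a contiguous range of `step` elements and `step` is a
// multiple of 8, the mask bytes written by neighbouring threads never overlap,
// so there is no read-modify-write race on the packed mask.
int main() {
  const std::size_t step = 64;   // e.g. kMinNumRandomPerThread; illustrative value
  assert(step % 8 == 0);         // the invariant the comment above relies on
  const std::size_t num_threads = 4;
  for (std::size_t t = 0; t + 1 < num_threads; ++t) {
    const std::size_t last_byte_of_t   = (t * step + step - 1) / 8;
    const std::size_t first_byte_of_t1 = ((t + 1) * step) / 8;
    assert(last_byte_of_t < first_byte_of_t1);  // byte ranges are disjoint
  }
  return 0;
}
```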
