yzhliu closed pull request #9931: Add axes support to Dropout for variational dropout in NLP
URL: https://github.com/apache/incubator-mxnet/pull/9931
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


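For context on what the change enables: the new "axes" parameter makes Dropout sample one Bernoulli mask with the listed axes collapsed to size 1 and broadcast it over those axes, which is how variational dropout shares a single mask across time steps in NLP models. Below is a minimal usage sketch, assuming the auto-generated Python binding picks up the new "axes" argument of the Dropout operator; the shapes and the mode='always' flag are only for illustration:

import mxnet as mx

# (batch, seq_len, hidden) input
x = mx.nd.ones((2, 4, 3))

# Share one dropout mask across the sequence axis (axis 1): the mask has
# shape (2, 1, 3), so the same units are dropped at every time step.
y = mx.nd.Dropout(x, p=0.5, axes=(1,), mode='always')
print(y)  # kept entries are scaled by 1/(1-p) = 2.0
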
diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index cff35a3cef7..b57ab45891e 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -21,7 +21,7 @@
  * Copyright (c) 2015 by Contributors
  * \file dropout-inl.h
  * \brief
- * \author Bing Xu, Da Zheng
+ * \author Bing Xu, Da Zheng, Hang Zhang
 */
 
 #ifndef MXNET_OPERATOR_NN_DROPOUT_INL_H_
@@ -37,6 +37,7 @@
 #include "../mxnet_op.h"
 #include "../mshadow_op.h"
 #include "../random/sampler.h"
+#include "../tensor/elemwise_binary_broadcast_op.h"
 
 #if defined(USE_MKL) && defined(_OPENMP)
 #include <omp.h>
@@ -55,9 +56,12 @@ enum DropoutOpMode {kTraining, kAlways};
 namespace mxnet {
 namespace op {
 
+const int MAX_DIM = 5;
+
 struct DropoutParam : public dmlc::Parameter<DropoutParam> {
   float p;
   int mode;
+  TShape axes;
   DMLC_DECLARE_PARAMETER(DropoutParam) {
     DMLC_DECLARE_FIELD(p).set_default(0.5)
     .set_range(0, 1)
@@ -67,6 +71,8 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
     .add_enum("always", dropout::kAlways)
     .set_default(dropout::kTraining)
    .describe("Whether to only turn on dropout during training or to also turn on for inference.");
+    DMLC_DECLARE_FIELD(axes).set_default(TShape())
+    .describe("Axes for variational dropout kernel.");
   }
 };  // struct DropoutParam
 
@@ -205,10 +211,25 @@ class DropoutOp {
       });
     }
   };
+  struct BernoulliKernel {
+    /*! \brief Bernoulli kernel for generating mask */
+    MSHADOW_XINLINE static void Map(int id,
+                                    RandGenerator<xpu, DType> gen,
+                                    const int N,
+                                    const int step,
+                                    DType *mask_out,
+                                    const real_t pkeep) {
+      RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
+        const real_t rand_num = static_cast<real_t>(genImpl.uniform());
+        mask_out[i] = mshadow_op::threshold::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep);
+      });
+    }
+  };
 
   void Init(const DropoutParam &param) {
     this->pkeep_ = 1.0f - param.p;
     this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
+    this->axes_ = param.axes;
   }
 
   void Forward(const OpContext &ctx,
@@ -225,14 +246,46 @@ class DropoutOp {
       if (ctx.is_train || this->mode_ == dropout::kAlways) {
        RandGenerator<xpu, DType> *pgen = ctx.requested[0].get_parallel_random<xpu, DType>();
         CHECK_NOTNULL(pgen);
-        if (!MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
+        if (this->axes_.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
           const TBlob &mask = out_data[dropout::kMask];
           CHECK(req[dropout::kOut] != kAddTo);
-          LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
+          if (this->axes_.ndim() == 0) {
+            // standard case for dropout
+            LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
                                         out.dptr<DType>(),
                                         mask.dptr<DType>(),
                                         in_data[dropout::kData].dptr<DType>(),
                                         this->pkeep_);
+            return;
+          }
+          // initialize the mask
+          LaunchRNG<BernoulliKernel, xpu>(s, pgen, out.Size(),
+                                          mask.dptr<DType>(),
+                                          this->pkeep_);
+          // broadcast mul
+          TShape new_lshape, new_rshape, new_oshape;
+          int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_,
+                                                 mask.shape_, out.shape_,
+                                                 &new_lshape, &new_rshape, &new_oshape);
+          if (!ndim) {
+            MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, {
+              mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+                s, out.Size(), out.dptr<DType>(), in_data[dropout::kData].dptr<DType>(),
+                mask.dptr<DType>());
+            });
+          } else {
+            BROADCAST_NDIM_SWITCH(ndim, NDim, {
+              mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+              mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+              mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+              mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType,
+                               mshadow_op::mul>, xpu>::
+              template LaunchEx(s, new_oshape.Size(), req[dropout::kOut],
+              lstride, rstride, oshape,
+              in_data[dropout::kData].dptr<DType>(),
+              mask.dptr<DType>(), out.dptr<DType>());
+            });
+          }
         }
       } else {
         const TBlob& data = in_data[dropout::kData];
@@ -257,15 +310,40 @@ class DropoutOp {
     using namespace mshadow::expr;
     Stream<xpu> *s = ctx.get_stream<xpu>();
     if (ctx.is_train || mode_ == dropout::kAlways) {
-      if (!MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
+      if (this->axes_.ndim() != 0 || !MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
         const TBlob &gdata = in_grad[dropout::kData];
         const TBlob &grad = out_grad[dropout::kOut];
         const TBlob &mask = out_data[dropout::kMask];
-        CHECK_EQ(grad.Size(), mask.Size());
-        MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
-            s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
-        });
+        if (this->axes_.ndim() == 0) {
+          // standard case for dropout
+          CHECK_EQ(grad.Size(), mask.Size());
+          MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+              s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
+          });
+          return;
+        }
+        // broadcast mul
+        TShape new_lshape, new_rshape, new_oshape;
+        int ndim = BinaryBroadcastShapeCompact(grad.shape_,
+                                               mask.shape_, gdata.shape_,
+                                               &new_lshape, &new_rshape, &new_oshape);
+        if (!ndim) {
+          MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+              s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
+          });
+        } else {
+          BROADCAST_NDIM_SWITCH(ndim, NDim, {
+            mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+            mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+            mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType,
+                             mshadow_op::mul>, xpu>::
+            template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+            grad.dptr<DType>(), mask.dptr<DType>(), gdata.dptr<DType>());
+          });
+        }
       }
     } else {
       const TBlob& gdata = in_grad[dropout::kData];
@@ -286,6 +364,7 @@ class DropoutOp {
   real_t pkeep_;
   /*! \brief Dropout mode */
   dropout::DropoutOpMode mode_;
+  TShape axes_;
 };  // class DropoutOp
 
 template<typename xpu>
diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc
index dd5f1e58fbe..3021e0105b4 100644
--- a/src/operator/nn/dropout.cc
+++ b/src/operator/nn/dropout.cc
@@ -21,7 +21,7 @@
  * Copyright (c) 2015 by Contributors
  * \file dropout.cc
  * \brief
- * \author Bing Xu, Da Zheng
+ * \author Bing Xu, Da Zheng, Hang Zhang
 */
 
 #include "./dropout-inl.h"
@@ -93,10 +93,14 @@ Example::
       std::vector<TShape> *in_shape, std::vector<TShape> *out_shape){
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 1U);
-  const TShape &dshape = in_shape->at(0);
+  const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
+  TShape dshape(in_shape->at(0));
   if (dshape.ndim() == 0) return false;
   out_shape->clear();
   out_shape->push_back(dshape);
+  for (index_t i = 0; i < param.axes.ndim(); ++i) {
+    dshape[param.axes[i]] = 1;
+  }
   out_shape->push_back(dshape);
   return true;
 })
diff --git a/src/operator/nn/dropout.cu b/src/operator/nn/dropout.cu
index e655278822a..832490b08f1 100644
--- a/src/operator/nn/dropout.cu
+++ b/src/operator/nn/dropout.cu
@@ -21,7 +21,7 @@
  * Copyright (c) 2015 by Contributors
  * \file dropout.cc
  * \brief
- * \author Bing Xu, Da Zheng
+ * \author Bing Xu, Da Zheng, Hang Zhang
 */
 
 #include "./dropout-inl.h"


 

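For readers tracing the new code path above: when "axes" is non-empty, the forward pass fills the mask blob (whose shape has those axes set to 1 by the modified shape inference in dropout.cc) with BernoulliKernel, pre-scaling kept entries by 1/pkeep, and then applies a broadcast multiply against the input; the backward pass broadcast-multiplies the incoming gradient by the stored mask. A rough NumPy sketch of that behaviour, assuming a 3-D input and axes=(1,); the helper names are illustrative, not part of the PR:

import numpy as np

def variational_dropout_fwd(x, p=0.5, axes=(1,), rng=np.random):
    # One mask shared along `axes`, mirroring the new forward path.
    pkeep = 1.0 - p
    mask_shape = list(x.shape)
    for ax in axes:                      # shape inference in dropout.cc
        mask_shape[ax] = 1
    # BernoulliKernel: keep with probability pkeep, pre-scaled by 1/pkeep
    mask = (rng.uniform(size=mask_shape) < pkeep) / pkeep
    return x * mask, mask                # broadcast multiply

def variational_dropout_bwd(grad_out, mask):
    # Backward path: broadcast multiply of the output gradient by the mask.
    return grad_out * mask

x = np.ones((2, 4, 3))
y, mask = variational_dropout_fwd(x)
dx = variational_dropout_bwd(np.ones_like(y), mask)
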
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
