yzhliu closed pull request #9931: Add axes support to Dropout for variational dropout in NLP URL: https://github.com/apache/incubator-mxnet/pull/9931
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index cff35a3cef7..b57ab45891e 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file dropout-inl.h * \brief - * \author Bing Xu, Da Zheng + * \author Bing Xu, Da Zheng, Hang Zhang */ #ifndef MXNET_OPERATOR_NN_DROPOUT_INL_H_ @@ -37,6 +37,7 @@ #include "../mxnet_op.h" #include "../mshadow_op.h" #include "../random/sampler.h" +#include "../tensor/elemwise_binary_broadcast_op.h" #if defined(USE_MKL) && defined(_OPENMP) #include <omp.h> @@ -55,9 +56,12 @@ enum DropoutOpMode {kTraining, kAlways}; namespace mxnet { namespace op { +const int MAX_DIM = 5; + struct DropoutParam : public dmlc::Parameter<DropoutParam> { float p; int mode; + TShape axes; DMLC_DECLARE_PARAMETER(DropoutParam) { DMLC_DECLARE_FIELD(p).set_default(0.5) .set_range(0, 1) @@ -67,6 +71,8 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> { .add_enum("always", dropout::kAlways) .set_default(dropout::kTraining) .describe("Whether to only turn on dropout during training or to also turn on for inference."); + DMLC_DECLARE_FIELD(axes).set_default(TShape()) + .describe("Axes for variational dropout kernel."); } }; // struct DropoutParam @@ -205,10 +211,25 @@ class DropoutOp { }); } }; + struct BernoulliKernel { + /*! 
\brief Bernoulli kernel for generating mask */ + MSHADOW_XINLINE static void Map(int id, + RandGenerator<xpu, DType> gen, + const int N, + const int step, + DType *mask_out, + const real_t pkeep) { + RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, { + const real_t rand_num = static_cast<real_t>(genImpl.uniform()); + mask_out[i] = mshadow_op::threshold::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep); + }); + } + }; void Init(const DropoutParam &param) { this->pkeep_ = 1.0f - param.p; this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode); + this->axes_ = param.axes; } void Forward(const OpContext &ctx, @@ -225,14 +246,46 @@ class DropoutOp { if (ctx.is_train || this->mode_ == dropout::kAlways) { RandGenerator<xpu, DType> *pgen = ctx.requested[0].get_parallel_random<xpu, DType>(); CHECK_NOTNULL(pgen); - if (!MKLForward(s, pgen, this->pkeep_, in_data, out_data)) { + if (this->axes_.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, in_data, out_data)) { const TBlob &mask = out_data[dropout::kMask]; CHECK(req[dropout::kOut] != kAddTo); - LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(), + if (this->axes_.ndim() == 0) { + // standard case for dropout + LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(), out.dptr<DType>(), mask.dptr<DType>(), in_data[dropout::kData].dptr<DType>(), this->pkeep_); + return; + } + // initialize the mask + LaunchRNG<BernoulliKernel, xpu>(s, pgen, out.Size(), + mask.dptr<DType>(), + this->pkeep_); + // broadcast mul + TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_, + mask.shape_, out.shape_, + &new_lshape, &new_rshape, &new_oshape); + if (!ndim) { + MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, { + mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch( + s, out.Size(), out.dptr<DType>(), in_data[dropout::kData].dptr<DType>(), + mask.dptr<DType>()); + }); + } else { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape<NDim> oshape = 
new_oshape.get<NDim>(); + mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>()); + mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>()); + mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, + mshadow_op::mul>, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[dropout::kOut], + lstride, rstride, oshape, + in_data[dropout::kData].dptr<DType>(), + mask.dptr<DType>(), out.dptr<DType>()); + }); + } } } else { const TBlob& data = in_data[dropout::kData]; @@ -257,15 +310,40 @@ class DropoutOp { using namespace mshadow::expr; Stream<xpu> *s = ctx.get_stream<xpu>(); if (ctx.is_train || mode_ == dropout::kAlways) { - if (!MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) { + if (this->axes_.ndim() != 0 || !MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) { const TBlob &gdata = in_grad[dropout::kData]; const TBlob &grad = out_grad[dropout::kOut]; const TBlob &mask = out_data[dropout::kMask]; - CHECK_EQ(grad.Size(), mask.Size()); - MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, { - mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch( - s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>()); - }); + if (this->axes_.ndim() == 0) { + // standard case for dropout + CHECK_EQ(grad.Size(), mask.Size()); + MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, { + mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch( + s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>()); + }); + return; + } + // broardcast mul + TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(grad.shape_, + mask.shape_, gdata.shape_, + &new_lshape, &new_rshape, &new_oshape); + if (!ndim) { + MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, { + mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch( + s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>()); + }); + } 
else { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape<NDim> oshape = new_oshape.get<NDim>(); + mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>()); + mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>()); + mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, + mshadow_op::mul>, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + grad.dptr<DType>(), mask.dptr<DType>(), gdata.dptr<DType>()); + }); + } } } else { const TBlob& gdata = in_grad[dropout::kData]; @@ -286,6 +364,7 @@ class DropoutOp { real_t pkeep_; /*! \brief Dropout mode */ dropout::DropoutOpMode mode_; + TShape axes_; }; // class DropoutOp template<typename xpu> diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index dd5f1e58fbe..3021e0105b4 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file dropout.cc * \brief - * \author Bing Xu, Da Zheng + * \author Bing Xu, Da Zheng, Hang Zhang */ #include "./dropout-inl.h" @@ -93,10 +93,14 @@ Example:: std::vector<TShape> *in_shape, std::vector<TShape> *out_shape){ using namespace mshadow; CHECK_EQ(in_shape->size(), 1U); - const TShape &dshape = in_shape->at(0); + const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed); + TShape dshape(in_shape->at(0)); if (dshape.ndim() == 0) return false; out_shape->clear(); out_shape->push_back(dshape); + for (index_t i = 0; i < param.axes.ndim(); ++i) { + dshape[param.axes[i]] = 1; + } out_shape->push_back(dshape); return true; }) diff --git a/src/operator/nn/dropout.cu b/src/operator/nn/dropout.cu index e655278822a..832490b08f1 100644 --- a/src/operator/nn/dropout.cu +++ b/src/operator/nn/dropout.cu @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file dropout.cc * \brief - * \author Bing Xu, Da Zheng + * \author Bing Xu, Da Zheng, Hang Zhang */ #include "./dropout-inl.h" 
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services