This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 87d69fa  Masked softmax (& log-softmax) (#19460)
87d69fa is described below

commit 87d69faa8fd98a7565e2e3bb16aa4225ad614e4b
Author: Moises Hernandez <[email protected]>
AuthorDate: Thu Dec 3 20:12:25 2020 -0800

    Masked softmax (& log-softmax) (#19460)
    
    * Add masked-softmax operator. Includes scaling factor and optional
      normalization
    
    * Backwards and Log-Softmax
    
    * fix lint
    
    * Including mask broadcasting
    
    * change name masked_log_softmax
    
    * Include CPU implementation
    
    * Mask as boolean dtype, and vectorize loading
    
    * Custom BWD node removing mask gradient
    
    * Remove OType (use DType) and fix type inferring
    
    * fix lint
    
    * update AMP list
    
    * compare all grads in masked_log_softmax
    
    * remove unused variable
---
 python/mxnet/amp/lists/symbol_bf16.py  |   2 +
 python/mxnet/amp/lists/symbol_fp16.py  |   2 +
 src/operator/nn/log_softmax.cc         |  57 +++
 src/operator/nn/log_softmax.cu         |   7 +
 src/operator/nn/softmax-inl.h          | 742 +++++++++++++++++++++++++++++++++
 src/operator/nn/softmax.cc             |  61 +++
 src/operator/nn/softmax.cu             |   5 +
 tests/python/unittest/test_operator.py |  88 ++++
 8 files changed, 964 insertions(+)

diff --git a/python/mxnet/amp/lists/symbol_bf16.py 
b/python/mxnet/amp/lists/symbol_bf16.py
index 306635a..e7f14fa 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -478,6 +478,8 @@ FP32_FUNCS = [
     'softmax',
     'Softmax',
     'log_softmax',
+    'masked_softmax',
+    'masked_log_softmax',
     'InstanceNorm',
     'LayerNorm',
     'GroupNorm',
diff --git a/python/mxnet/amp/lists/symbol_fp16.py 
b/python/mxnet/amp/lists/symbol_fp16.py
index 4b54448..7242a70 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -576,6 +576,8 @@ FP32_FUNCS = [
     # Neural network
     'softmax',
     'log_softmax',
+    'masked_softmax',
+    'masked_log_softmax',
     'InstanceNorm',
     'LayerNorm',
     'GroupNorm',
diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index 28ae8cf..6aae7e9 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -159,5 +159,62 @@ NNVM_REGISTER_OP(_backward_log_softmax)
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, mshadow_op::left,
                                                         
mxnet_op::log_softmax_bwd>);
 
+// Forward masked log-softmax operator (also exposed as _npx_masked_log_softmax).
+// Takes two inputs, `data` and `mask`, and produces one output. The gradient
+// w.r.t. `data` is delegated to _backward_masked_log_softmax; the gradient
+// w.r.t. `mask` is a zeros_like node.
+NNVM_REGISTER_OP(masked_log_softmax)
+.add_alias("_npx_masked_log_softmax")
+.describe(R"code(Computes the masked log softmax of the input.
+This is equivalent to computing masked softmax followed by log.)code")
+.set_attr_parser(ParamParser<MaskedSoftmaxParam>)
+// Use the FListInputNames functor type so the template argument matches the
+// attribute key being set.
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::string>{"data", "mask"};
+  })
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+    // Backward node inputs, in order: output gradient, mask, forward output.
+    auto data_grad = MakeNode("_backward_masked_log_softmax", n->attrs.name + "_backward_data",
+                              {ograds[0], n->inputs[1], nnvm::NodeEntry(n, 0, 0)},
+                              &n->attrs.dict, &n);
+    // The mask receives no real gradient; emit zeros shaped like the mask.
+    auto mask_grad = MakeNode("zeros_like", n->attrs.name + "_backward_mask",
+                              {n->inputs[1]}, nullptr, &n);
+    std::vector<nnvm::NodeEntry> ret;
+    ret.emplace_back(data_grad);
+    ret.emplace_back(mask_grad);
+    return ret;
+  })
+.set_attr<nnvm::FInferType>("FInferType", MaskedSoftmaxOpType)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<mxnet::FInferShape>("FInferShape", MaskedSoftmaxOpShape)
+// The compute function requests temporary workspace.
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+// Input 0 (data) may share storage with output 0.
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
+.add_argument("mask", "NDArray-or-Symbol", "Mask to apply.")
+.add_arguments(MaskedSoftmaxParam::__FIELDS__());
+
+// Backward operator for masked_log_softmax. Consumes three inputs
+// (ograd, mask, output) and produces a single output.
+NNVM_REGISTER_OP(_backward_masked_log_softmax)
+.set_num_inputs(3)
+.set_num_outputs(1)
+// Use the FListInputNames functor type so the template argument matches the
+// attribute key being set.
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::string>{"ograd", "mask", "output"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", MaskedSoftmaxGradOpShape)
+.set_attr<nnvm::FInferType>("FInferType", MaskedSoftmaxGradOpType)
+// The compute function requests temporary workspace.
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", MaskedSoftmaxGradOpInplaceOption)
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
+.set_attr_parser(ParamParser<MaskedSoftmaxParam>)
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxGradCompute<cpu, mshadow_op::left,
+                                                              mxnet_op::log_softmax_bwd>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/log_softmax.cu b/src/operator/nn/log_softmax.cu
index 8bff277..2a54cd3 100644
--- a/src/operator/nn/log_softmax.cu
+++ b/src/operator/nn/log_softmax.cu
@@ -35,5 +35,12 @@ NNVM_REGISTER_OP(_backward_log_softmax)
 .set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, mshadow_op::left,
                                                         
mxnet_op::log_softmax_bwd>);
 
+// Attach the GPU forward compute function to the masked_log_softmax operator.
+NNVM_REGISTER_OP(masked_log_softmax)
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, 
mxnet_op::log_softmax_fwd>);
+
+// Attach the GPU backward compute function to _backward_masked_log_softmax.
+NNVM_REGISTER_OP(_backward_masked_log_softmax)
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxGradCompute<gpu, 
mshadow_op::left,
+                                                              
mxnet_op::log_softmax_bwd>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 7806eaf..b53b8a4 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -36,6 +36,8 @@
 #include "../tensor/broadcast_reduce_op.h"
 #include "../../common/cuda/utils.h"
 
+using mshadow::red::limits::MinValue;
+
 namespace mxnet {
 namespace op {
 namespace mxnet_op {
@@ -161,6 +163,49 @@ inline void Softmax(Stream<cpu> *s, DType *in, OType *out, 
IType *length,
   }
 }
 
+// Element-wise "where" kernel with mask broadcasting and scaling:
+//   out[id] = cond[mask_pos] ? x[id] / scale : y
+// where mask_pos is the offset into `cond` that corresponds to flat index
+// `id` of an array shaped `data_shape`. Dimensions in which mask_shape is 1
+// contribute nothing to mask_pos (the mask is broadcast along them).
+struct masked_softmax_where_scale {
+  template<typename DType, int ndim>
+  MSHADOW_XINLINE static void Map(index_t id, DType* out, const bool* cond,
+                                  const DType* x, const double y,
+                                  Shape<ndim> data_shape, Shape<ndim> 
mask_shape,
+                                  const double scale) {
+    index_t mask_pos = 0;
+    index_t stride = 1;
+    // Peel coordinates off `id` from the last dimension to the first.
+    for (index_t i = ndim-1, j = id; i >=0; --i) {
+      auto tmp = j / data_shape[i];
+      if (mask_shape[i] != 1) {
+        // NOTE(review): the remainder is taken with mask_shape[i]; this equals
+        // the data coordinate only if mask_shape[i] == data_shape[i] here -
+        // presumably guaranteed by shape inference, confirm.
+        mask_pos += (j - tmp  * mask_shape[i]) * stride;
+      }
+      stride *= mask_shape[i];
+      j = tmp;
+    }
+    KERNEL_ASSIGN(out[id], kWriteTo, (cond[mask_pos] ? x[id] / 
static_cast<DType>(scale) :
+                                                       static_cast<DType>(y)));
+  }
+};
+
+// CPU masked softmax, built from three passes:
+//   1. write (mask ? in / scale : MinValue<DType>()) into temporary workspace,
+//   2. run the regular Softmax over that workspace into `out`,
+//   3. overwrite `out` in place with (mask ? out : 0).
+// The workspace holds data_shape.Size() elements and comes from
+// ctx.requested[0].
+// NOTE(review): `normalize` is accepted but never read in this body.
+template<typename OP, bool negate, typename AType, typename DType, int ndim>
+inline void MaskedSoftmax(Stream<cpu> *s, DType *in, DType *out, bool *mask,
+                          Shape<ndim> data_shape, Shape<ndim> mask_shape,
+                          int axis, const double scale,
+                          const double temperature, bool normalize,
+                          const OpContext& ctx) {
+  Tensor<cpu, 1, DType> workspace = ctx.requested[0].get_space_typed<cpu, 1, 
DType>(
+      Shape1(data_shape.Size()), s);
+  DType* masked_scaled_input = TBlob(workspace).dptr<DType>();
+
+  // Fill value used for masked-out positions before the softmax pass.
+  double neg = MinValue<DType>();
+  Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), 
masked_scaled_input,
+                                                  mask, in, neg, data_shape, 
mask_shape,
+                                                  scale);
+  // Null length pointer passed to Softmax ("max_lenghts" is a typo for
+  // "max_lengths").
+  int* max_lenghts = nullptr;
+  Softmax<OP, negate, AType, DType>(s, masked_scaled_input, out, max_lenghts,
+                                    data_shape, axis, temperature);
+  // Final pass: masked-out outputs become 0; kept outputs are divided by 1.0.
+  Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), out,
+                                                  mask, out, 0.0, data_shape, 
mask_shape,
+                                                  1.0);
+}
+
 struct softmax_bwd {
   template<typename DType, typename AType>
   MSHADOW_XINLINE static AType Map(DType ograd, DType out, AType sum) {
@@ -258,6 +303,28 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType 
*ograd,
   }
 }
 
+// CPU masked softmax gradient, built from three passes:
+//   1. write (mask ? ograd : 0) into temporary workspace,
+//   2. run the regular SoftmaxGrad on (out, masked ograd) into `igrad`,
+//   3. overwrite `igrad` in place with (mask ? igrad / scale : 0).
+// The workspace holds data_shape.Size() elements and comes from
+// ctx.requested[0].
+template<typename OP1, typename OP2, int Req, bool negate, typename AType, int 
ndim,
+         typename DType>
+inline void MaskedSoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
+                              DType *igrad, bool *mask, Shape<ndim> data_shape,
+                              Shape<ndim> mask_shape, int axis,
+                              const double scale, const double temperature,
+                              const OpContext& ctx) {
+  Tensor<cpu, 1, DType> workspace = ctx.requested[0].get_space_typed<cpu, 1, 
DType>(
+    Shape1(data_shape.Size()), s);
+  DType* masked_ograd = TBlob(workspace).dptr<DType>();
+  Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), 
masked_ograd,
+                                                  mask, ograd, 0.0, 
data_shape, mask_shape,
+                                                  1.0);
+  // Null length pointer passed to SoftmaxGrad ("max_lenghts" is a typo for
+  // "max_lengths").
+  int* max_lenghts = nullptr;
+  SoftmaxGrad<OP1, OP2, Req, negate, AType, DType, DType, int, ndim>(
+      s, out, masked_ograd, igrad,
+      max_lenghts, data_shape,
+      axis, temperature);
+  // Zero the masked-out gradients and divide the kept ones by `scale`.
+  Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), igrad,
+                                                  mask, igrad, 0.0, 
data_shape, mask_shape,
+                                                  scale);
+}
 
 #ifdef __CUDACC__
 template<int x_bits, typename OP, bool negate, typename AType, int ndim,
@@ -397,6 +464,202 @@ __global__ void softmax_stride1_compute_kernel(const 
DType *in, OType *out, ITyp
   }
 }
 
+// Maps flat index `idx` (decomposed over `data_shape`) to an offset into the
+// mask. The `axis` dimension and every dimension where mask_shape is 1 are
+// skipped, so the returned offset excludes any contribution from `axis`.
+// Output *stride_axis receives the product of mask_shape[i] over the
+// non-broadcast dimensions i > axis (1 if there are none).
+// The callers visible in this file pass a data shape whose `axis` extent has
+// been set to 1.
+template<int ndim>
+MSHADOW_XINLINE index_t get_mask_position(const index_t idx, const 
Shape<ndim>& data_shape,
+  const Shape<ndim>& mask_shape, int axis, index_t* stride_axis) {
+  index_t ret = 0;
+  index_t stride = 1;
+  *stride_axis = 1;
+  #pragma unroll
+  for (index_t i = ndim-1, j = idx; i >=0; --i) {
+    auto tmp = j / data_shape[i];
+    if (i != axis && mask_shape[i] != 1) {
+      ret += (j - tmp  * mask_shape[i]) * stride;
+      if (i > axis)
+        *stride_axis *= mask_shape[i];
+    }
+    // `stride` advances over every mask dimension, including `axis`.
+    stride *= mask_shape[i];
+    j = tmp;
+  }
+  return ret;
+}
+
+// General (arbitrary stride) masked softmax kernel.
+// blockIdx.x selects one softmax row (all indices fixed except `axis`); the
+// threads of the block walk the M elements of that row in steps of
+// x_size = 1 << x_bits. Three phases, separated by block-wide barriers:
+//   1. (only if `normalize`) max over unmasked elements, then divided by scale,
+//   2. sum of exp((in/scale - smax) / temperature) over unmasked elements,
+//   3. out = OP::Map((in/scale - smax) / temperature, sum) for unmasked
+//      elements and 0 for masked ones.
+// `shared` is dynamic shared memory reinterpreted as AType and indexed by
+// threadIdx.x.
+template<bool normalize, int x_bits, typename OP, bool negate, typename AType,
+         int ndim, typename DType>
+__global__ void masked_softmax_kernel(DType *in, DType *out, bool *in_mask,
+                                      index_t M, int axis, Shape<ndim> sshape,
+                                      Shape<ndim> stride, Shape<ndim> 
mask_shape,
+                                      const double scale, const double 
temperature) {
+  extern __shared__ double shared[];
+  AType* smem = reinterpret_cast<AType*>(shared);  // x_size
+
+  const unsigned x_size = 1 << x_bits;
+  index_t sa = stride[axis];
+  index_t base = unravel_dot(blockIdx.x, sshape, stride);
+  index_t sa_mask = 0;
+  index_t base_mask = get_mask_position(blockIdx.x, sshape, mask_shape, axis, 
&sa_mask);
+  // When the mask has extent 1 along `axis`, one mask value covers the row.
+  bool bcst_mask_axis = (mask_shape[axis] == 1);
+  index_t x = threadIdx.x;
+
+  // Without normalization no maximum is subtracted (smax stays 0).
+  DType smax = 0.0;
+  if (normalize) {
+    red::maximum::SetInitValue(smem[x]);
+    for (index_t i = x; i < M; i += x_size) {
+      bool mask_value = bcst_mask_axis ? in_mask[base_mask] : 
in_mask[base_mask + i*sa_mask];
+      if (mask_value)
+        smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
+    }
+    __syncthreads();
+    cuda::Reduce1D<red::maximum, x_bits>(smem);
+    __syncthreads();
+    smax = smem[0] / scale;
+    // Barrier before smem is reused for the sum below.
+    __syncthreads();
+  }
+
+  red::sum::SetInitValue(smem[x]);
+  DType val;
+  for (index_t i = x; i < M; i += x_size) {
+    bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask 
+ i*sa_mask];
+    if (mask_value) {
+      val = (negate ? -in[base + i*sa]:in[base + i*sa]) / scale;
+      smem[x] += static_cast<AType>(expf((val - smax) / 
static_cast<AType>(temperature)));
+    }
+  }
+  __syncthreads();
+  cuda::Reduce1D<red::sum, x_bits>(smem);
+  __syncthreads();
+  AType ssum = smem[0];
+  __syncthreads();
+
+  // Write results; masked-out positions are set to 0.
+  for (index_t i = x; i < M; i += x_size) {
+    val = (negate ? -in[base + i*sa] : in[base + i*sa]) / scale;
+    bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask 
+ i*sa_mask];
+    out[base + i*sa] =
+      mask_value ? DType(OP::Map((val - smax)/static_cast<DType>(temperature), 
ssum)) :
+                             DType(0.0f);
+  }
+}
+
+// Masked softmax kernel for the case where the softmax axis is contiguous.
+// A block processes `rows_per_block` rows; threads_per_row =
+// softmax_threads_per_block / rows_per_block threads cooperate on each row.
+// Each row of `in` and its mask are first copied into dynamic shared memory
+// using LType- and LTypeMask-wide loads. The softmax is then computed in
+// place on the staged row and copied to `out` with LType-wide stores.
+// Dynamic shared memory layout, in units of double:
+//   [0, size_input_shared)                 staged input rows,
+//   [size_input_shared, +size_mask_shared) staged mask rows,
+//   remainder                              AType scratch, one slot per thread.
+template<bool normalize, typename OP, bool negate, typename AType, typename 
LType,
+         typename LTypeMask, typename DType, int ndim>
+__global__ void masked_softmax_stride1_kernel(const DType *in, DType *out, 
bool *in_mask,
+                                              const index_t M, int axis, 
Shape<ndim> sshape,
+                                              Shape<ndim> mask_shape, const 
double scale,
+                                              const double temperature, const 
int rows_per_block,
+                                              const index_t total_rows,
+                                              const size_t size_input_shared,
+                                              const size_t size_mask_shared) {
+  const int entries_per_load = sizeof(LType)/sizeof(DType);
+  const int entries_per_load_mask = sizeof(LTypeMask)/sizeof(bool);
+  const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
+  // NOTE(review): this guard tests entries_per_load although the divisor is
+  // entries_per_load_mask - confirm which one was intended.
+  const int row_length_mask =  entries_per_load > 0 ? M / 
entries_per_load_mask : 0;
+  extern __shared__ double shared[];
+  LType* persistent_storage = reinterpret_cast<LType*>(shared);
+  // rows_per_block * M (DType), aligned to double
+  LTypeMask* mask_shared = 
reinterpret_cast<LTypeMask*>(&shared[size_input_shared]);
+  // rows_per_block * M (bool), aligned to double
+  AType* scratch = reinterpret_cast<AType*>(&shared[size_input_shared + 
size_mask_shared]);
+  // softmax_threads_per_block
+
+  const int warp_size = 32;
+  const int threads_per_row = softmax_threads_per_block / rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int my_row = blockIdx.x * rows_per_block + my_local_row;
+  // Threads mapped past the last row exit before any barrier below.
+  if (my_row >= total_rows) return;
+  const int my_id = threadIdx.x % threads_per_row;
+  size_t base = my_row * row_length;
+  // Offset of this row's first mask element: decompose my_row over the
+  // dimensions before `axis`, skipping dimensions where the mask has extent 1.
+  index_t pos_mask = 0;
+  index_t stride = mask_shape[axis];
+  #pragma unroll
+  for (index_t i = axis-1, j = my_row; i >=0; --i) {
+    auto tmp = j / sshape[i];
+    if (mask_shape[i]  != 1) {
+      pos_mask += (j - tmp  * mask_shape[i]) * stride;
+      stride *= mask_shape[i];
+    }
+    j = tmp;
+  }
+
+  // Stage the input row into shared memory.
+  const LType* in_aligned = reinterpret_cast<const LType*>(in);
+  for (index_t i = my_id; i < row_length; i += threads_per_row) {
+    persistent_storage[my_local_row * row_length + i] = in_aligned[base + i];
+  }
+  // Stage the mask row; if the mask has extent 1 along `axis`, element 0 is
+  // replicated across the row.
+  const LTypeMask* in_mask_aligned = reinterpret_cast<const 
LTypeMask*>(&in_mask[pos_mask]);
+  for (index_t i = my_id; i < row_length_mask; i += threads_per_row) {
+    mask_shared[my_local_row * row_length_mask + i] = (mask_shape[axis] > 1) ?
+                                                      in_mask_aligned[i] :
+                                                      in_mask_aligned[0];
+  }
+  DType* row = reinterpret_cast<DType*>(persistent_storage + my_local_row * 
row_length);
+  bool* row_mask = reinterpret_cast<bool*>(mask_shared + my_local_row * 
row_length_mask);
+  __syncthreads();
+
+  // Without normalization no maximum is subtracted (smax stays 0).
+  DType smax = 0.0;
+  if (normalize) {
+    DType my_max_value;
+    red::maximum::SetInitValue(my_max_value);
+    for (index_t i = my_id; i < M; i += threads_per_row) {
+      if (row_mask[i])
+        my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+    }
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    // Halving reduction through scratch down to warp_size partials per row,
+    // then warp_reduce combines those.
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x 
+ size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
+                                   [](AType x, AType y) { return ::max(x, y); 
});
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    // The row result is read from the scratch slot of the row's first thread.
+    smax = scratch[threadIdx.x - threadIdx.x % threads_per_row]  / scale;
+    __syncthreads();
+  }
+
+  // Sum of exp((x/scale - smax) / temperature) over unmasked elements.
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+  for (index_t i = my_id; i < M; i += threads_per_row) {
+    if (row_mask[i]) {
+      const DType val = (negate ? -row[i] : row[i]) / scale;
+      my_sum += static_cast<AType>(expf((val - smax) / 
static_cast<AType>(temperature)));
+    }
+  }
+  scratch[threadIdx.x] = my_sum;
+  __syncthreads();
+  for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+    if (my_id < size) {
+      scratch[threadIdx.x] += scratch[threadIdx.x + size];
+    }
+    __syncthreads();
+  }
+  if (my_id < warp_size) {
+    AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
+                                 [](AType x, AType y) { return x + y;});
+    scratch[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+
+  AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
+  __syncthreads();
+
+  // Overwrite the staged row with the result; masked-out positions become 0.
+  for (index_t i = my_id; i < M; i += threads_per_row) {
+    const DType val = (negate ? -row[i] : row[i]) / scale;
+    row[i] = row_mask[i] ? DType(OP::Map((val - 
smax)/static_cast<DType>(temperature), ssum)) :
+                           DType(0.0f);
+  }
+  __syncthreads();
+
+  LType* out_aligned = reinterpret_cast<LType*>(out);
+
+  // Copy the finished row from shared memory to `out`.
+  for (index_t i = my_id; i < row_length; i += threads_per_row) {
+    out_aligned[base + i] = persistent_storage[my_local_row * row_length + i];
+  }
+}
+
 template<typename OP, bool negate, typename AType, typename DType, typename 
OType,
          typename IType, int ndim>
 inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
@@ -436,6 +699,84 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out, 
IType *length,
   }
 }
 
+// GPU masked softmax launcher.
+// M is the extent of `axis`, N the number of softmax rows. Two code paths:
+//   * optimized: the axis has stride 1, one row fits the 20 kB shared-memory
+//     budget and DType == OType -> masked_softmax_stride1_kernel with
+//     vectorized loads, several rows per block;
+//   * general: masked_softmax_kernel, one block of x_size threads per row.
+// `normalize` selects the kernel instantiation that subtracts the row maximum.
+// Returns without launching when the data is empty. `ctx` is not used here.
+template<typename OP, bool negate, typename AType, typename DType,
+         typename OType, int ndim>
+inline void MaskedSoftmax(Stream<gpu> *s, DType *in, OType *out, bool *mask,
+                          Shape<ndim> data_shape, Shape<ndim> mask_shape,
+                          int axis, const double scale, const double temperature,
+                          bool normalize, const OpContext& ctx) {
+  const int x_bits = 7;
+  const int x_size = 1 << x_bits;
+  index_t M = data_shape[axis];
+  if (M == 0 || data_shape.Size() == 0) return;
+  index_t N = data_shape.Size() / M;
+  Shape<ndim> stride = calc_stride(data_shape);
+  // Shape used to enumerate rows: the softmax axis collapsed to extent 1.
+  Shape<ndim> sshape = data_shape;
+  sshape[axis] = 1;
+
+  const size_t DSize = sizeof(DType);
+  // Using max of 20 kB of shared memory for InputData in the optimized case
+  const size_t max_opt_M = 20 * 1024 / DSize;
+  if (stride[axis] == 1 &&
+      static_cast<size_t>(M) <= max_opt_M &&
+      std::is_same<DType, OType>::value) {
+    // Load widths are chosen separately for the data row and the mask row.
+    int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType));
+    int ltype_mask = mxnet::common::cuda::get_load_type(mask_shape[axis] * sizeof(bool));
+    MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+      CHECK_LE(sizeof(DType), sizeof(LType));
+      MXNET_LOAD_TYPE_SWITCH(ltype_mask, LTypeMask, {
+        CHECK_LE(sizeof(bool), sizeof(LTypeMask));
+        int rows_per_block = mxnet::common::cuda::
+                               get_rows_per_block(M *
+                                                  sizeof(DType) / sizeof(LType),
+                                                  softmax_threads_per_block);
+        // calculate amount of shared memory (slots aligned to double)
+        // (the original initializer assigned entries_per_load to itself)
+        int entries_per_load = sizeof(LType)/sizeof(DType);
+        int entries_per_load_mask = sizeof(LTypeMask)/sizeof(bool);
+        size_t size_input_shared = entries_per_load > 0 ?
+                                 rows_per_block * M / entries_per_load : 0;
+        size_t size_mask_shared = entries_per_load_mask > 0 ?
+                                  rows_per_block * M / entries_per_load_mask : 0;
+        // Convert both element counts to a count of doubles, rounding up.
+        size_input_shared = ((size_input_shared * sizeof(LType) + sizeof(double) - 1) /
+                             sizeof(double));
+        size_mask_shared = ((size_mask_shared * sizeof(LTypeMask) + sizeof(double) - 1) /
+                            sizeof(double));
+        // Input rows + mask rows + one AType scratch slot per thread.
+        size_t amount_shared = size_input_shared * sizeof(double) +
+                               size_mask_shared * sizeof(double) +
+                               softmax_threads_per_block * sizeof(AType);
+
+        int nblocks = (N + rows_per_block - 1) / rows_per_block;
+        if (normalize) {
+          masked_softmax_stride1_kernel<true, OP, negate, AType, LType, LTypeMask>
+            <<<nblocks, softmax_threads_per_block, amount_shared,
+               mshadow::Stream<gpu>::GetStream(s)>>>(
+              in, out, mask, M, axis, sshape, mask_shape, scale, temperature,
+              rows_per_block, N, size_input_shared, size_mask_shared);
+        } else {
+          masked_softmax_stride1_kernel<false, OP, negate, AType, LType, LTypeMask>
+            <<<nblocks, softmax_threads_per_block, amount_shared,
+               mshadow::Stream<gpu>::GetStream(s)>>>(
+              in, out, mask, M, axis, sshape, mask_shape, scale, temperature,
+              rows_per_block, N, size_input_shared, size_mask_shared);
+        }
+      });
+    });
+    MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_stride1_kernel);
+  } else {
+    // General path: one block per row, x_size AType slots of shared memory.
+    size_t amount_shared = x_size * sizeof(AType);
+    if (normalize) {
+      masked_softmax_kernel<true, x_bits, OP, negate, AType, ndim>
+        <<<N, x_size, amount_shared, mshadow::Stream<gpu>::GetStream(s)>>>(
+          in, out, mask, M, axis, sshape, stride, mask_shape, scale, temperature);
+    } else {
+      masked_softmax_kernel<false, x_bits, OP, negate, AType, ndim>
+        <<<N, x_size, amount_shared, mshadow::Stream<gpu>::GetStream(s)>>>(
+          in, out, mask, M, axis, sshape, stride, mask_shape, scale, temperature);
+    }
+    MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_kernel);
+  }
+}
+
 template<typename OP1, typename OP2, int Req, bool negate, typename AType, 
typename LType,
   typename DType, typename OType, typename IType>
 __global__ void softmax_stride1_grad_kernel(const OType *out, const OType 
*ograd,
@@ -550,6 +891,154 @@ __global__ void softmax_grad_kernel(OType *out, OType 
*ograd, DType *igrad,
   }
 }
 
+// Masked softmax gradient kernel for the case where the softmax axis is
+// contiguous. A block processes `rows_per_block` rows; threads_per_row =
+// softmax_threads_per_block / rows_per_block threads cooperate on each row.
+// For every row, `out` and `ograd` are staged back to back in dynamic shared
+// memory (out at [0, M), ograd at [M, 2M) of the row buffer) together with
+// the row's mask. The result is computed in place over the `out` half and
+// copied to `igrad` with LType-wide stores.
+template<typename OP1, typename OP2, int Req, bool negate, typename AType, 
typename LType,
+  typename LTypeMask, typename DType, typename OType, int ndim>
+__global__ void masked_softmax_stride1_grad_kernel(const OType *out, const 
OType *ograd,
+                                                   DType *igrad, const bool 
*in_mask,
+                                                   const index_t M, int axis,
+                                                   Shape<ndim> sshape,
+                                                   Shape<ndim> mask_shape,
+                                                   const double scale,
+                                                   const double temperature,
+                                                   const int rows_per_block,
+                                                   const index_t total_rows,
+                                                   const size_t 
size_input_shared,
+                                                   const size_t 
size_mask_shared) {
+  const int entries_per_load = sizeof(LType)/sizeof(DType);
+  const int entries_per_load_mask = sizeof(LTypeMask)/sizeof(bool);
+  const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
+  // NOTE(review): this guard tests entries_per_load although the divisor is
+  // entries_per_load_mask - confirm which one was intended.
+  const int row_length_mask =  entries_per_load > 0 ? M / 
entries_per_load_mask : 0;
+  extern __shared__ double shared[];
+  LType* persistent_storage = reinterpret_cast<LType*>(shared);
+  // 2 * rows_per_block * M (DType), aligned to double
+  LTypeMask* mask_shared = 
reinterpret_cast<LTypeMask*>(&shared[size_input_shared]);
+  // rows_per_block * M (bool), aligned to double
+  AType* scratch = reinterpret_cast<AType*>(&shared[size_input_shared + 
size_mask_shared]);
+  // softmax_threads_per_block
+
+  const int warp_size = 32;
+  const int threads_per_row = softmax_threads_per_block / rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int my_row = blockIdx.x * rows_per_block + my_local_row;
+  // Threads mapped past the last row exit before any barrier below.
+  if (my_row >= total_rows) return;
+  const int my_id = threadIdx.x % threads_per_row;
+  size_t base = my_row * row_length;
+  // Offset of this row's first mask element: decompose my_row over the
+  // dimensions before `axis`, skipping dimensions where the mask has extent 1.
+  index_t pos_mask = 0;
+  index_t stride = mask_shape[axis];
+  #pragma unroll
+  for (index_t i = axis - 1, j = my_row; i >=0; --i) {
+    auto tmp = j / sshape[i];
+    if (mask_shape[i]  != 1) {
+      pos_mask += (j - tmp  * mask_shape[i]) * stride;
+      stride *= mask_shape[i];
+    }
+    j = tmp;
+  }
+
+  // Stage `out` and `ograd` for this row, one after the other.
+  const LType* out_aligned = reinterpret_cast<const LType*>(out);
+  const LType* ograd_aligned = reinterpret_cast<const LType*>(ograd);
+  for (index_t i = my_id; i < row_length; i += threads_per_row) {
+    persistent_storage[my_local_row * row_length * 2 + i] = out_aligned[base + 
i];
+    persistent_storage[my_local_row * row_length * 2 + row_length + i] = 
ograd_aligned[base + i];
+  }
+  // Stage the mask row; if the mask has extent 1 along `axis`, element 0 is
+  // replicated across the row.
+  const LTypeMask* in_mask_aligned = reinterpret_cast<const 
LTypeMask*>(&in_mask[pos_mask]);
+  for (index_t i = my_id; i < row_length_mask; i += threads_per_row) {
+    mask_shared[my_local_row * row_length_mask + i] = (mask_shape[axis] > 1) ?
+                                                      in_mask_aligned[i] :
+                                                      in_mask_aligned[0];
+  }
+  DType* row = reinterpret_cast<DType *>(persistent_storage + my_local_row * 
row_length * 2);
+  bool* row_mask = reinterpret_cast<bool*>(mask_shared + my_local_row * 
row_length_mask);
+  __syncthreads();
+
+  AType my_sum_value;
+  red::sum::SetInitValue(my_sum_value);
+
+  // Sum OP1(ograd, out) over the unmasked elements of the row.
+  for (index_t i = my_id; i < M; i += threads_per_row) {
+    if (row_mask[i])
+      my_sum_value += OP1::Map(row[i + M], row[i]);
+  }
+  scratch[threadIdx.x] = my_sum_value;
+  __syncthreads();
+  // Halving reduction through scratch down to warp_size partials per row,
+  // then warp_reduce combines those.
+  for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+    if (my_id < size) {
+      scratch[threadIdx.x] = scratch[threadIdx.x] + scratch[threadIdx.x + 
size];
+    }
+    __syncthreads();
+  }
+  if (my_id < warp_size) {
+    AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
+                                 [](AType x, AType y) { return x + y; });
+    scratch[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  // The row result is read from the scratch slot of the row's first thread.
+  AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
+  __syncthreads();
+
+  // The gradient is divided by temperature * scale; masked positions get 0.
+  AType temperature_scale = static_cast<AType>(temperature) *
+                            static_cast<AType>(scale);
+  for (index_t i = my_id; i < M; i += threads_per_row) {
+    const DType val =
+      negate ?
+      -OP2::Map(row[i + M], row[i], ssum):
+      OP2::Map(row[i + M], row[i], ssum);
+    row[i] = row_mask[i] ? DType(val / static_cast<DType>(temperature_scale)) :
+                           DType(0.0f);
+    // For kAddTo the existing igrad value is added before the store below.
+    if (Req == kAddTo) {
+      row[i] += igrad[my_row * M + i];
+    }
+  }
+  __syncthreads();
+
+  LType* igrad_aligned = reinterpret_cast<LType*>(igrad);
+
+  // Copy the finished row (the former `out` half) to `igrad`.
+  for (index_t i = my_id; i < row_length; i += threads_per_row) {
+    igrad_aligned[base + i] = persistent_storage[my_local_row * row_length * 2 
+ i];
+  }
+}
+
+// General (arbitrary stride) masked softmax gradient kernel.
+// blockIdx.x selects one softmax row; the threads of the block walk the M
+// elements of that row in steps of x_size = 1 << x_bits. First the sum of
+// OP1(ograd, out) over unmasked elements is reduced in shared memory, then
+// each element's gradient OP2(ograd, out, sum) / (temperature * scale) is
+// assigned to igrad according to Req, with 0 for masked-out positions.
+template<int x_bits, typename OP1, typename OP2, int Req, bool negate, 
typename AType, int ndim,
+         typename DType, typename OType>
+__global__ void masked_softmax_grad_kernel(OType *out, OType *ograd, DType 
*igrad,
+                                           const bool *in_mask, index_t M, int 
axis,
+                                           Shape<ndim> sshape, Shape<ndim> 
stride,
+                                           Shape<ndim> mask_shape,
+                                           const double scale, const double 
temperature) {
+  const unsigned x_size = 1 << x_bits;
+  __shared__ AType smem[x_size];
+  index_t sa = stride[axis];
+  index_t base = unravel_dot(blockIdx.x, sshape, stride);
+  index_t sa_mask = 0;
+  index_t base_mask = get_mask_position(blockIdx.x, sshape, mask_shape, axis, 
&sa_mask);
+  // When the mask has extent 1 along `axis`, one mask value covers the row.
+  bool bcst_mask_axis = (mask_shape[axis] == 1);
+  index_t x = threadIdx.x;
+
+  red::sum::SetInitValue(smem[x]);
+  for (index_t i = x; i < M; i += x_size) {
+    bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask 
+ i*sa_mask];
+    if (mask_value)
+      smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
+  }
+  __syncthreads();
+  cuda::Reduce1D<red::sum, x_bits>(smem);
+  __syncthreads();
+  AType ssum = smem[0];
+  __syncthreads();
+
+  DType final_result;
+  AType temperature_scale = static_cast<AType>(temperature) *
+                            static_cast<AType>(scale);
+  for (index_t i = x; i < M; i += x_size) {
+    bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask 
+ i*sa_mask];
+    final_result =
+      negate ?
+      -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum):
+      OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
+    final_result = mask_value ? final_result / 
static_cast<DType>(temperature_scale) : DType(0.0f);
+    KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result);
+  }
+}
+
 template<typename OP1, typename OP2, int Req, bool negate, typename AType, int 
ndim,
          typename DType, typename OType, typename IType>
 inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
@@ -591,6 +1080,70 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType 
*ograd,
     MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_grad_kernel);
   }
 }
+
+/*!
+ * \brief Launch the CUDA backward kernel of masked softmax.
+ *
+ * If the softmax axis has stride 1, DType and OType are the same type and the axis
+ * length fits the 20 kB budget computed below, masked_softmax_stride1_grad_kernel is
+ * launched with dynamically sized shared memory. Otherwise masked_softmax_grad_kernel
+ * is launched with one block of (1 << x_bits) threads per reduced row.
+ * Nothing is launched when the data tensor is empty.
+ *
+ * \param s stream the kernel is launched on
+ * \param out forward output
+ * \param ograd gradient flowing in from the forward output
+ * \param igrad destination of the input gradient (assigned according to Req in the kernel)
+ * \param mask boolean mask
+ * \param data_shape shape of out / ograd / igrad
+ * \param mask_shape shape of mask (same rank as data_shape)
+ * \param axis axis along which the softmax was computed
+ * \param scale scaling factor forwarded to the kernel
+ * \param temperature temperature forwarded to the kernel
+ * \param ctx operator context (not used by this function)
+ */
+template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
+         typename DType, typename OType>
+inline void MaskedSoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
+                              DType *igrad, bool *mask, Shape<ndim> data_shape,
+                              Shape<ndim> mask_shape, int axis,
+                              const double scale, const double temperature,
+                              const OpContext& ctx) {
+  const int x_bits = 7;
+  const int x_size = 1 << x_bits;
+  index_t M = data_shape[axis];
+  if (M == 0 || data_shape.Size() == 0) return;
+  index_t N = data_shape.Size() / M;
+  Shape<ndim> stride = calc_stride(data_shape);
+  Shape<ndim> sshape = data_shape;
+  sshape[axis] = 1;
+
+  const size_t DSize = sizeof(DType);
+  // Using max of 20 kB of shared memory for InputData in the optimized case
+  const size_t max_opt_M = 20 * 1024 / DSize;
+  if (stride[axis] == 1 &&
+      static_cast<size_t>(M) <= max_opt_M &&
+      std::is_same<DType, OType>::value) {
+    int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType));
+    int ltype_mask = mxnet::common::cuda::get_load_type(mask_shape[axis] * sizeof(bool));
+    MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+      CHECK_LE(sizeof(DType), sizeof(LType));
+      MXNET_LOAD_TYPE_SWITCH(ltype_mask, LTypeMask, {
+        CHECK_LE(sizeof(bool), sizeof(LTypeMask));
+        int rows_per_block = mxnet::common::cuda::
+                               get_rows_per_block(M *
+                                                  sizeof(DType) / sizeof(LType),
+                                                  softmax_threads_per_block);
+        // calculate amount shared memory (slots aligned to double)
+        // Fixed: the variable used to be assigned inside its own initializer.
+        int entries_per_load = sizeof(LType)/sizeof(DType);
+        int entries_per_load_mask = sizeof(LTypeMask)/sizeof(bool);
+        size_t size_input_shared = entries_per_load > 0 ?
+                                 rows_per_block * M / entries_per_load : 0;
+        size_t size_mask_shared = entries_per_load_mask > 0 ?
+                                  rows_per_block * M / entries_per_load_mask : 0;
+        // Factor 2: presumably one buffer each for out and ograd -- confirm in the kernel.
+        size_input_shared = ((2 * size_input_shared * sizeof(LType) + sizeof(double) - 1) /
+                             sizeof(double));
+        size_mask_shared = ((size_mask_shared * sizeof(LTypeMask) + sizeof(double) - 1) /
+                            sizeof(double));
+        size_t amount_shared = size_input_shared * sizeof(double) +
+                               size_mask_shared * sizeof(double) +
+                               softmax_threads_per_block * sizeof(AType);
+
+        int nblocks = (N + rows_per_block - 1) / rows_per_block;
+        masked_softmax_stride1_grad_kernel<OP1, OP2, Req, negate, AType, LType, LTypeMask>
+          <<<nblocks, softmax_threads_per_block, amount_shared,
+             mshadow::Stream<gpu>::GetStream(s)>>>(
+            out, ograd, igrad, mask, M, axis, sshape, mask_shape,
+            scale, temperature, rows_per_block, N, size_input_shared, size_mask_shared);
+      });
+    });
+    MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_stride1_grad_kernel);
+  } else {
+    masked_softmax_grad_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
+      <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+        out, ograd, igrad, mask, M, axis, sshape, stride, mask_shape, scale, temperature);
+    MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_grad_kernel);
+  }
+}
 #endif
 
 }  // namespace mxnet_op
@@ -626,6 +1179,25 @@ struct SoftmaxParam : public 
dmlc::Parameter<SoftmaxParam> {
   }
 };
 
+// Parameters parsed for the masked_softmax operator and its backward node.
+struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {
+  // Axis along which softmax is computed (default -1).
+  int axis;
+  // Optional scaling factor; the compute functions in this file use 1.0 when unset.
+  dmlc::optional<double> scale_factor;
+  // Optional temperature; the compute functions in this file use 1.0 when unset.
+  dmlc::optional<double> temperature;
+  // NOTE(review): no DMLC_DECLARE_FIELD is issued for dtype below, so it is not
+  // exposed as an operator argument here -- confirm whether it is still needed.
+  dmlc::optional<int> dtype;
+  // Whether to subtract max(x) before exponentiation (default true).
+  dmlc::optional<bool> normalize;
+  DMLC_DECLARE_PARAMETER(MaskedSoftmaxParam) {
+    DMLC_DECLARE_FIELD(axis).set_default(-1)
+    .describe("The axis along which to compute softmax.");
+    DMLC_DECLARE_FIELD(scale_factor).set_default(dmlc::optional<double>())
+    .describe("Scaling factor applied before softmax");
+    DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional<double>())
+    .describe("Temperature parameter in softmax");
+    DMLC_DECLARE_FIELD(normalize)
+    .set_default(dmlc::optional<bool>(true))
+    .describe("Whether to normalize input data x: x = x - max(x)");
+  }
+};
+
 static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   return param.dtype.has_value() && param.dtype.value() != -1;
@@ -772,6 +1344,90 @@ struct SoftmaxFGradient {
   }
 };
 
+// Type inference for masked_softmax (inputs: data, mask; one output).
+// Only the data type (input 0) takes part in inference with the output;
+// the mask type (input 1) is neither read nor constrained here.
+static inline bool MaskedSoftmaxOpType(const nnvm::NodeAttrs& attrs,
+                                       std::vector<int>* in_attrs,
+                                       std::vector<int>* out_attrs) {
+  CHECK_EQ(out_attrs->size(), 1);
+  CHECK_EQ(in_attrs->size(), 2U);
+
+  // ElemwiseType works on a copy holding only the data type, so anything it
+  // writes into tmp is not copied back into in_attrs.
+  std::vector<int> tmp = {in_attrs->at(0)};
+  return ElemwiseType<1, 1>(attrs, &tmp, out_attrs);
+}
+
+// Shape inference for masked_softmax (inputs: data, mask; one output).
+// Returns false while the rank of data or mask is still unknown. Otherwise it
+// requires equal, non-zero rank and a mask whose every dimension either equals
+// the data dimension or is 1, then ties the output shape to the data shape.
+static inline bool MaskedSoftmaxOpShape(const nnvm::NodeAttrs& attrs,
+                                        mxnet::ShapeVector *in_shape,
+                                        mxnet::ShapeVector *out_shape) {
+  CHECK_EQ(out_shape->size(), 1U);
+  CHECK_EQ(in_shape->size(), 2U);
+
+  mxnet::TShape& data_shape = (*in_shape)[0];
+  mxnet::TShape& mask_shape = (*in_shape)[1];
+
+  if (!mxnet::ndim_is_known(data_shape) || !mxnet::ndim_is_known(mask_shape)) {
+    return false;
+  }
+  CHECK(data_shape.ndim() == mask_shape.ndim())
+      << "Number of dimensions in data and mask does not match";
+  CHECK(data_shape.ndim() > 0)
+      << "Empty tuple is not allowed";
+
+  // The mask may only differ from data in dimensions of extent 1.
+  for (int i = 0; i < data_shape.ndim(); ++i) {
+    CHECK(data_shape[i] == mask_shape[i] || mask_shape[i] == 1)
+      << "Mask cannot be broadcasted from " << mask_shape << " to " << 
data_shape;
+  }
+  // Shape information is exchanged in both directions between data and output.
+  SHAPE_ASSIGN_CHECK(*out_shape, 0, in_shape->at(0));
+  SHAPE_ASSIGN_CHECK(*in_shape, 0, out_shape->at(0));
+  return true;
+}
+
+// Shape inference for the backward node (three inputs, one output).
+// Input 0 is treated as the output gradient and input 1 as the mask; input 2
+// and the output are tied to the shape of input 0.
+// Returns false while the rank of input 0 or of the mask is still unknown.
+static inline bool MaskedSoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
+                                            mxnet::ShapeVector *in_shape,
+                                            mxnet::ShapeVector *out_shape) {
+  CHECK_EQ(out_shape->size(), 1U);
+  CHECK_EQ(in_shape->size(), 3U);
+
+  mxnet::TShape& ograd_shape = (*in_shape)[0];
+  mxnet::TShape& mask_shape = (*in_shape)[1];
+
+  if (!mxnet::ndim_is_known(ograd_shape) || !mxnet::ndim_is_known(mask_shape)) 
{
+    return false;
+  }
+  CHECK(ograd_shape.ndim() == mask_shape.ndim())
+      << "Number of dimensions in data and mask does not match";
+  CHECK(ograd_shape.ndim() > 0)
+      << "Empty tuple is not allowed";
+
+  // The mask may only differ from the gradient in dimensions of extent 1.
+  for (int i = 0; i < ograd_shape.ndim(); ++i) {
+    CHECK(ograd_shape[i] == mask_shape[i] || mask_shape[i] == 1)
+      << "Mask cannot be broadcasted from " << mask_shape << " to " << 
ograd_shape;
+  }
+
+  // Inputs 0 and 2 and the output all share one shape; exchange it both ways.
+  SHAPE_ASSIGN_CHECK(*out_shape, 0, in_shape->at(0));
+  SHAPE_ASSIGN_CHECK(*out_shape, 0, in_shape->at(2));
+  SHAPE_ASSIGN_CHECK(*in_shape, 0, out_shape->at(0));
+  SHAPE_ASSIGN_CHECK(*in_shape, 2, out_shape->at(0));
+  return true;
+}
+
+// Type inference for the backward node (three inputs, one output).
+// Inputs 0 and 2 and the output are tied to one dtype; input 1 is not touched.
+// NOTE(review): TYPE_ASSIGN_CHECK presumably fills in an unknown type or
+// verifies a known one -- confirm against its definition.
+static inline bool MaskedSoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
+                                           std::vector<int>* in_attrs,
+                                           std::vector<int>* out_attrs) {
+  CHECK_EQ(out_attrs->size(), 1U);
+  CHECK_EQ(in_attrs->size(), 3U);
+  // Start from the type of input 0 and apply it to input 2 and the output.
+  int data_dtype = (*in_attrs)[0];
+  TYPE_ASSIGN_CHECK(*in_attrs, 2, data_dtype);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, data_dtype);
+  // Re-read the output type and apply it back to input 0.
+  data_dtype = (*out_attrs)[0];
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, data_dtype);
+
+  return true;
+}
+
+// In-place options for _backward_masked_softmax as {input index, output index}.
+// The node is registered with three inputs (ograd, mask, output) and a single
+// output, so only pairs with input index < 3 and output index 0 are valid.
+// The former {2, 1} and {3, 0} pairs pointed past those bounds, and {1, 0}
+// paired the boolean mask with the data gradient. Only the two tensors that
+// share the gradient's shape and dtype are offered: ograd (0) and output (2).
+static inline std::vector<std::pair<int, int> >
+MaskedSoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
+  return std::vector<std::pair<int, int> >{{0, 0}, {2, 0}};
+}
+
 template<typename xpu, typename OP, bool negate = false>
 void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
@@ -836,6 +1492,48 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
   });
 }
 
+// FCompute entry point of the masked softmax forward pass.
+// inputs[0] is the data and inputs[1] the boolean mask; outputs[0] receives the result.
+// Does nothing when the request is kNullOp or the data is empty; kAddTo is rejected.
+template<typename xpu, typename OP, bool negate = false>
+void MaskedSoftmaxCompute(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
+  CHECK_NE(req[0], kAddTo);
+  const MaskedSoftmaxParam& param = 
nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
+  int axis = CheckAxis(param.axis, inputs[0].ndim());
+  // Unset scale_factor / temperature fall back to 1.0.
+  const double scale = param.scale_factor.has_value() ?
+    param.scale_factor.value() : 1.0;
+  const double temperature = param.temperature.has_value() ?
+    param.temperature.value() : 1.0;
+  // MXNET_SAFE_ACCUMULATION (default true) selects AType instead of DType
+  // as the accumulation type passed to MaskedSoftmax below.
+  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
+  if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
+    common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for 
masked_softmax with "
+                    "float16 inputs. "
+                    "See https://mxnet.apache.org/api/faq/env_var "
+                    "for more details.");
+  }
+  // Dispatch on the data dtype and on the number of dimensions.
+  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
+    MXNET_NDIM_SWITCH(inputs[0].ndim(), ndim, {
+      bool* mask_ptr = inputs[1].dptr<bool>();
+      if (safe_acc) {
+        MaskedSoftmax<OP, negate, AType>(
+          ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+          outputs[0].dptr<DType>(), mask_ptr,
+          inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
+          axis, scale, temperature, param.normalize.value(), ctx);
+      } else {
+        MaskedSoftmax<OP, negate, DType>(
+          ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+          outputs[0].dptr<DType>(), mask_ptr,
+          inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
+          axis, scale, temperature, param.normalize.value(), ctx);
+      }
+    });
+  });
+}
+
 
 template<typename xpu, typename OP1, typename OP2, bool negate = false>
 void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
@@ -907,6 +1605,50 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   });
 }
 
+/*!
+ * \brief FCompute entry point of the masked softmax backward pass.
+ *
+ * inputs[0] is the output gradient, inputs[1] the boolean mask and inputs[2]
+ * the forward output; outputs[0] receives the gradient w.r.t. the data.
+ * Does nothing when the request is kNullOp or the gradient is empty.
+ *
+ * scale and temperature are forwarded as double, exactly as in the forward
+ * pass. They used to be cast to DType first, which rounded them (or overflowed
+ * them) for float16 data before MaskedSoftmaxGrad received them as double.
+ */
+template<typename xpu, typename OP1, typename OP2, bool negate = false>
+void MaskedSoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+
+  if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
+  const MaskedSoftmaxParam& param = nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
+  int axis = CheckAxis(param.axis, inputs[0].ndim());
+  // Unset scale_factor / temperature fall back to 1.0.
+  const double scale = param.scale_factor.has_value() ?
+    param.scale_factor.value() : 1.0;
+  const double temperature = param.temperature.has_value() ?
+    param.temperature.value() : 1.0;
+
+  // MXNET_SAFE_ACCUMULATION (default true) selects AType instead of DType
+  // as the accumulation type.
+  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
+  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+      MXNET_NDIM_SWITCH(inputs[0].ndim(), ndim, {
+        DType* ograd_ptr = inputs[0].dptr<DType>();
+        DType* out_ptr = inputs[2].dptr<DType>();
+        bool* mask_ptr = inputs[1].dptr<bool>();
+        DType* grad_data = outputs[0].dptr<DType>();
+        if (safe_acc) {
+          MaskedSoftmaxGrad<OP1, OP2, Req, negate, AType>(
+              ctx.get_stream<xpu>(), out_ptr,
+              ograd_ptr, grad_data, mask_ptr,
+              inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
+              axis, scale, temperature, ctx);
+        } else {
+          MaskedSoftmaxGrad<OP1, OP2, Req, negate, DType>(
+              ctx.get_stream<xpu>(), out_ptr,
+              ograd_ptr, grad_data, mask_ptr,
+              inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
+              axis, scale, temperature, ctx);
+        }
+      });
+    });
+  });
+}
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 9b28b71..cf67853 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -34,6 +34,7 @@
 namespace mxnet {
 namespace op {
 DMLC_REGISTER_PARAMETER(SoftmaxParam);
+DMLC_REGISTER_PARAMETER(MaskedSoftmaxParam);
 
 #if MXNET_USE_MKLDNN == 1
 static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
@@ -187,5 +188,65 @@ NNVM_REGISTER_OP(_backward_softmax)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, 
op::mshadow_op::mul,
                                                         
mxnet_op::softmax_bwd>);
+
+// CPU registration of masked_softmax: two inputs (data, mask), one output.
+// The gradient w.r.t. data comes from _backward_masked_softmax, which receives
+// {ograd, mask, forward output}; the gradient w.r.t. mask is zeros_like(mask).
+NNVM_REGISTER_OP(masked_softmax)
+.add_alias("_npx_masked_softmax")
+.describe(R"code(Applies the softmax function masking elements according to the mask provided)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<MaskedSoftmaxParam>)
+// The attribute is typed with the FListInputNames alias that matches its key.
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::string>{"data", "mask"};
+  })
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"output"};
+  })
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu, mxnet_op::softmax_fwd>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+    auto data_grad = MakeNode("_backward_masked_softmax", n->attrs.name + "_backward_data",
+                              {ograds[0], n->inputs[1], nnvm::NodeEntry(n, 0, 0)},
+                              &n->attrs.dict, &n);
+    auto mask_grad = MakeNode("zeros_like", n->attrs.name + "_backward_mask",
+                              {n->inputs[1]}, nullptr, &n);
+    std::vector<nnvm::NodeEntry> ret;
+    ret.emplace_back(data_grad);
+    ret.emplace_back(mask_grad);
+    return ret;
+  })
+.set_attr<nnvm::FInferType>("FInferType", MaskedSoftmaxOpType)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<mxnet::FInferShape>("FInferShape", MaskedSoftmaxOpShape)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
+.add_argument("mask", "NDArray-or-Symbol", "Mask to apply.")
+.add_arguments(MaskedSoftmaxParam::__FIELDS__());
+
+// CPU registration of the masked_softmax backward node:
+// three inputs (ograd, mask, output) and one output.
+NNVM_REGISTER_OP(_backward_masked_softmax)
+.set_num_inputs(3)
+.set_num_outputs(1)
+// The attribute is typed with the FListInputNames alias that matches its key.
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::string>{"ograd", "mask", "output"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", MaskedSoftmaxGradOpShape)
+.set_attr<nnvm::FInferType>("FInferType", MaskedSoftmaxGradOpType)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", MaskedSoftmaxGradOpInplaceOption)
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
+.set_attr_parser(ParamParser<MaskedSoftmaxParam>)
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxGradCompute<cpu, op::mshadow_op::mul,
+                                                              mxnet_op::softmax_bwd>);
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/softmax.cu b/src/operator/nn/softmax.cu
index d5762cf..dc8fd99 100644
--- a/src/operator/nn/softmax.cu
+++ b/src/operator/nn/softmax.cu
@@ -34,6 +34,11 @@ NNVM_REGISTER_OP(softmax)
 NNVM_REGISTER_OP(_backward_softmax)
 .set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, 
op::mshadow_op::mul,
                                                         
mxnet_op::softmax_bwd>);
+// Attach the GPU FCompute implementation to masked_softmax.
+NNVM_REGISTER_OP(masked_softmax)
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, 
mxnet_op::softmax_fwd>);
 
+// Attach the GPU FCompute implementation to the masked_softmax backward node.
+NNVM_REGISTER_OP(_backward_masked_softmax)
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxGradCompute<gpu, 
op::mshadow_op::mul,
+                                                              
mxnet_op::softmax_bwd>);
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index 8a18d5d..31a6aa1 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4941,6 +4941,94 @@ def test_softmax_with_length():
                                 [np.zeros(shape), np.zeros(len_shape, 
dtype=np.int32)],
                                 rtol=1e-2, atol=2e-3 if dtype == np.float16 
else 1e-3, dtype="asnumpy")
 
+def np_softmax(x, axis=-1, scale_factor=1.0, temperature=1.0, normalize=True):
+    """NumPy reference softmax of ``x / scale_factor`` with a temperature.
+
+    When ``normalize`` is true the maximum along ``axis`` is subtracted
+    before exponentiation.
+    """
+    scaled = x / scale_factor
+    if normalize:
+        scaled = scaled - np.max(scaled, axis=axis, keepdims=True)
+    exps = np.exp(scaled / temperature)
+    exps /= np.sum(exps, axis=axis, keepdims=True)
+    return exps
+
+def np_masked_softmax(data, mask, axis=-1, scale_factor=1.0, temperature=1.0, normalize=True):
+    """NumPy reference for masked softmax.
+
+    Masked-out positions are replaced by a large negative value before the
+    softmax (a smaller magnitude is used for float16) and the result is
+    multiplied by ``mask`` afterwards.
+    """
+    neg = -1e4 if data.dtype == np.float16 else -1e18
+    filled = np.where(mask, data, neg)
+    probs = np_softmax(filled, axis=axis,
+                       scale_factor=scale_factor,
+                       temperature=temperature,
+                       normalize=normalize)
+    return probs * mask
+def np_masked_softmax_grad(out, grad_out, axis=-1, scale_factor=1.0, temperature=1.0):
+    """NumPy reference for the masked softmax input gradient.
+
+    ``out`` is the forward output and ``grad_out`` the gradient w.r.t. it.
+    """
+    weighted_sum = np.sum(out * grad_out, axis=axis, keepdims=True)
+    centered = grad_out - weighted_sum
+    return out * centered / (temperature * scale_factor)
+
+@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
+@pytest.mark.parametrize('axis', [0, -1, -2, -3])
+@pytest.mark.parametrize('ndims', [3, 4, 5])
+@pytest.mark.parametrize('n_broadcast_axis', [0, 1, 2])
+@pytest.mark.parametrize('temperature', [1, 5, 9, 11])
+@pytest.mark.parametrize('scale', [1, 2, 7, 12])
+@pytest.mark.parametrize('normalize', [True])
+def test_masked_softmax(dtype, axis, ndims, n_broadcast_axis, temperature, scale, normalize):
+    """Compare masked_softmax forward/backward against the NumPy reference.
+
+    The mask is broadcast along ``n_broadcast_axis`` randomly chosen axes.
+    """
+    n_broadcast_axis = min(n_broadcast_axis, ndims - 1)
+    shape = rand_shape_nd(ndims, dim=10)
+    mx_data = rand_ndarray(shape, dtype=dtype)
+    # Pick distinct axes along which the mask has extent 1.
+    bcst_dims = []
+    while len(bcst_dims) < n_broadcast_axis:
+        ax = np.random.randint(0, ndims)
+        if ax not in bcst_dims:
+            bcst_dims.append(ax)
+    shape_mask = list(shape)
+    for i in bcst_dims:
+        shape_mask[i] = 1
+
+    np_data = mx_data.asnumpy()
+    np_mask = np.random.randint(0, 2, shape_mask)
+    # np.bool_ replaces the removed np.bool alias.
+    mx_mask = mx.nd.array(np_mask, dtype=np.bool_)
+    mx_grad = rand_ndarray(shape, dtype=dtype)
+    np_grad = mx_grad.asnumpy()
+
+    np_out = np_masked_softmax(np_data, np_mask, axis,
+                               scale, temperature, normalize)
+    np_grad_out = np_masked_softmax_grad(np_out, np_grad,
+                                         axis, scale, temperature)
+    data = mx.sym.Variable("data")
+    mask = mx.sym.Variable("mask")
+    mx_sym = mx.sym.masked_softmax(data=data, mask=mask, scale_factor=scale,
+                                   temperature=temperature, axis=axis,
+                                   normalize=normalize)
+    location = {"data": mx_data, "mask": mx_mask}
+    rtol = 1e-2 if dtype == np.float16 else 1e-3
+    atol = 1e-4 if dtype == np.float16 else 1e-5
+    check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol,
+                           dtype="asnumpy", equal_nan=True)
+    # The mask gradient has the shape of the (possibly broadcast) mask.
+    check_symbolic_backward(mx_sym, location, [mx_grad],
+                            [np_grad_out, np.zeros(shape_mask, dtype=np.bool_)],
+                            rtol=1e-2, atol=2e-3 if dtype == np.float16 else 1e-3,
+                            dtype="asnumpy", equal_nan=True)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('ndims', [1, 2, 3, 4, 5])
+def test_masked_log_softmax(dtype, ndims):
+    """Compare masked_log_softmax against the NumPy reference and check its
+    gradient numerically, using a random shape and a random axis."""
+    shape = np.random.randint(1, 5, size=ndims)
+    axis = np.random.randint(0, ndims)
+    mx_data = rand_ndarray(shape, dtype=dtype)
+    np_data = mx_data.asnumpy()
+    np_mask = np.random.randint(0, 2, shape)
+    # np.bool_ replaces the removed np.bool alias.
+    mx_mask = mx.nd.array(np_mask, dtype=np.bool_)
+    # The small epsilon keeps log() finite at masked-out positions.
+    np_out = np.log(np_masked_softmax(np_data, np_mask, axis)+1e-20) * np_mask
+    data = mx.sym.Variable("data")
+    mask = mx.sym.Variable("mask")
+    # The operator is given the equivalent negative axis.
+    mx_sym = mx.sym.masked_log_softmax(data=data, mask=mask, axis=axis-ndims)
+    location = {"data": mx_data, "mask": mx_mask}
+    rtol = 1e-2 if dtype == np.float16 else 1e-3
+    atol = 1e-4 if dtype == np.float16 else 1e-5
+    check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype="asnumpy")
+    check_numeric_gradient(mx_sym, location, rtol=1e-1, atol=1e-2)
+
 
 def test_pick():
     def test_pick_helper(index_type=np.int32):

Reply via email to