This is an automated email from the ASF dual-hosted git repository.
sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 87d69fa Masked softmax (& log-softmax) (#19460)
87d69fa is described below
commit 87d69faa8fd98a7565e2e3bb16aa4225ad614e4b
Author: Moises Hernandez <[email protected]>
AuthorDate: Thu Dec 3 20:12:25 2020 -0800
Masked softmax (& log-softmax) (#19460)
* Add masked-softmax operator. Includes scaling factor and optional
normalization
* Backwards and Log-Softmax
* fix lint
* Including mask broadcasting
* change name masked_log_softmax
* Include CPU implementation
* Mask as boolean dtype, and vectorize loading
* Custom BWD node removing mask gradient
* Remove OType (use DType) and fix type inferring
* fix lint
* update AMP list
* compare all grads in masked_log_softmax
* remove unused variable
---
python/mxnet/amp/lists/symbol_bf16.py | 2 +
python/mxnet/amp/lists/symbol_fp16.py | 2 +
src/operator/nn/log_softmax.cc | 57 +++
src/operator/nn/log_softmax.cu | 7 +
src/operator/nn/softmax-inl.h | 742 +++++++++++++++++++++++++++++++++
src/operator/nn/softmax.cc | 61 +++
src/operator/nn/softmax.cu | 5 +
tests/python/unittest/test_operator.py | 88 ++++
8 files changed, 964 insertions(+)
diff --git a/python/mxnet/amp/lists/symbol_bf16.py
b/python/mxnet/amp/lists/symbol_bf16.py
index 306635a..e7f14fa 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -478,6 +478,8 @@ FP32_FUNCS = [
'softmax',
'Softmax',
'log_softmax',
+ 'masked_softmax',
+ 'masked_log_softmax',
'InstanceNorm',
'LayerNorm',
'GroupNorm',
diff --git a/python/mxnet/amp/lists/symbol_fp16.py
b/python/mxnet/amp/lists/symbol_fp16.py
index 4b54448..7242a70 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -576,6 +576,8 @@ FP32_FUNCS = [
# Neural network
'softmax',
'log_softmax',
+ 'masked_softmax',
+ 'masked_log_softmax',
'InstanceNorm',
'LayerNorm',
'GroupNorm',
diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index 28ae8cf..6aae7e9 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -159,5 +159,62 @@ NNVM_REGISTER_OP(_backward_log_softmax)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, mshadow_op::left,
mxnet_op::log_softmax_bwd>);
+NNVM_REGISTER_OP(masked_log_softmax)
+.add_alias("_npx_masked_log_softmax")
+.describe(R"code(Computes the masked log softmax of the input.
+This is equivalent to computing masked softmax followed by log.)code")
+.set_attr_parser(ParamParser<MaskedSoftmaxParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs){
+ return std::vector<std::string>{"data", "mask"};
+ })
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu,
mxnet_op::log_softmax_fwd>)
+.set_attr<nnvm::FGradient>("FGradient",
+ [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+ auto data_grad = MakeNode("_backward_masked_log_softmax", n->attrs.name +
"_backward_data",
+ {ograds[0], n->inputs[1], nnvm::NodeEntry(n, 0,
0)},
+ &n->attrs.dict, &n);
+ auto mask_grad = MakeNode("zeros_like", n->attrs.name + "_backward_mask",
+ {n->inputs[1]}, nullptr, &n);
+ std::vector<nnvm::NodeEntry> ret;
+ ret.emplace_back(data_grad);
+ ret.emplace_back(mask_grad);
+ return ret;
+ })
+.set_attr<nnvm::FInferType>("FInferType", MaskedSoftmaxOpType)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<mxnet::FInferShape>("FInferShape", MaskedSoftmaxOpShape)
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+ [](const NodeAttrs& attrs){
+ return std::vector<std::pair<int, int> >{{0, 0}};
+ })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
+.add_argument("mask", "NDArray-or-Symbol", "Mask to apply.")
+.add_arguments(MaskedSoftmaxParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_masked_log_softmax)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs){
+ return std::vector<std::string>{"ograd", "mask", "output"};
+ })
+.set_attr<mxnet::FInferShape>("FInferShape", MaskedSoftmaxGradOpShape)
+.set_attr<nnvm::FInferType>("FInferType", MaskedSoftmaxGradOpType)
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
MaskedSoftmaxGradOpInplaceOption)
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
+.set_attr_parser(ParamParser<MaskedSoftmaxParam>)
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxGradCompute<cpu,
mshadow_op::left,
+
mxnet_op::log_softmax_bwd>);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/log_softmax.cu b/src/operator/nn/log_softmax.cu
index 8bff277..2a54cd3 100644
--- a/src/operator/nn/log_softmax.cu
+++ b/src/operator/nn/log_softmax.cu
@@ -35,5 +35,12 @@ NNVM_REGISTER_OP(_backward_log_softmax)
.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, mshadow_op::left,
mxnet_op::log_softmax_bwd>);
+NNVM_REGISTER_OP(masked_log_softmax)
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu,
mxnet_op::log_softmax_fwd>);
+
+NNVM_REGISTER_OP(_backward_masked_log_softmax)
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxGradCompute<gpu,
mshadow_op::left,
+
mxnet_op::log_softmax_bwd>);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 7806eaf..b53b8a4 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -36,6 +36,8 @@
#include "../tensor/broadcast_reduce_op.h"
#include "../../common/cuda/utils.h"
+using mshadow::red::limits::MinValue;
+
namespace mxnet {
namespace op {
namespace mxnet_op {
@@ -161,6 +163,49 @@ inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
IType *length,
}
}
+struct masked_softmax_where_scale {
+ template<typename DType, int ndim>
+ MSHADOW_XINLINE static void Map(index_t id, DType* out, const bool* cond,
+ const DType* x, const double y,
+ Shape<ndim> data_shape, Shape<ndim>
mask_shape,
+ const double scale) {
+ index_t mask_pos = 0;
+ index_t stride = 1;
+ for (index_t i = ndim-1, j = id; i >=0; --i) {
+ auto tmp = j / data_shape[i];
+ if (mask_shape[i] != 1) {
+ mask_pos += (j - tmp * mask_shape[i]) * stride;
+ }
+ stride *= mask_shape[i];
+ j = tmp;
+ }
+ KERNEL_ASSIGN(out[id], kWriteTo, (cond[mask_pos] ? x[id] /
static_cast<DType>(scale) :
+ static_cast<DType>(y)));
+ }
+};
+
+template<typename OP, bool negate, typename AType, typename DType, int ndim>
+inline void MaskedSoftmax(Stream<cpu> *s, DType *in, DType *out, bool *mask,
+ Shape<ndim> data_shape, Shape<ndim> mask_shape,
+ int axis, const double scale,
+ const double temperature, bool normalize,
+ const OpContext& ctx) {
+ Tensor<cpu, 1, DType> workspace = ctx.requested[0].get_space_typed<cpu, 1,
DType>(
+ Shape1(data_shape.Size()), s);
+ DType* masked_scaled_input = TBlob(workspace).dptr<DType>();
+
+ double neg = MinValue<DType>();
+ Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(),
masked_scaled_input,
+ mask, in, neg, data_shape,
mask_shape,
+ scale);
+  int* max_lengths = nullptr;
+  Softmax<OP, negate, AType, DType>(s, masked_scaled_input, out, max_lengths,
+ data_shape, axis, temperature);
+ Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), out,
+ mask, out, 0.0, data_shape,
mask_shape,
+ 1.0);
+}
+
struct softmax_bwd {
template<typename DType, typename AType>
MSHADOW_XINLINE static AType Map(DType ograd, DType out, AType sum) {
@@ -258,6 +303,28 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType
*ograd,
}
}
+template<typename OP1, typename OP2, int Req, bool negate, typename AType, int
ndim,
+ typename DType>
+inline void MaskedSoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
+ DType *igrad, bool *mask, Shape<ndim> data_shape,
+ Shape<ndim> mask_shape, int axis,
+ const double scale, const double temperature,
+ const OpContext& ctx) {
+ Tensor<cpu, 1, DType> workspace = ctx.requested[0].get_space_typed<cpu, 1,
DType>(
+ Shape1(data_shape.Size()), s);
+ DType* masked_ograd = TBlob(workspace).dptr<DType>();
+ Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(),
masked_ograd,
+ mask, ograd, 0.0,
data_shape, mask_shape,
+ 1.0);
+  int* max_lengths = nullptr;
+  SoftmaxGrad<OP1, OP2, Req, negate, AType, DType, DType, int, ndim>(
+                                  s, out, masked_ograd, igrad,
+                                  max_lengths, data_shape,
+ axis, temperature);
+ Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), igrad,
+ mask, igrad, 0.0,
data_shape, mask_shape,
+ scale);
+}
#ifdef __CUDACC__
template<int x_bits, typename OP, bool negate, typename AType, int ndim,
@@ -397,6 +464,202 @@ __global__ void softmax_stride1_compute_kernel(const
DType *in, OType *out, ITyp
}
}
+template<int ndim>
+MSHADOW_XINLINE index_t get_mask_position(const index_t idx, const
Shape<ndim>& data_shape,
+ const Shape<ndim>& mask_shape, int axis, index_t* stride_axis) {
+ index_t ret = 0;
+ index_t stride = 1;
+ *stride_axis = 1;
+ #pragma unroll
+ for (index_t i = ndim-1, j = idx; i >=0; --i) {
+ auto tmp = j / data_shape[i];
+ if (i != axis && mask_shape[i] != 1) {
+ ret += (j - tmp * mask_shape[i]) * stride;
+ if (i > axis)
+ *stride_axis *= mask_shape[i];
+ }
+ stride *= mask_shape[i];
+ j = tmp;
+ }
+ return ret;
+}
+
+template<bool normalize, int x_bits, typename OP, bool negate, typename AType,
+ int ndim, typename DType>
+__global__ void masked_softmax_kernel(DType *in, DType *out, bool *in_mask,
+ index_t M, int axis, Shape<ndim> sshape,
+ Shape<ndim> stride, Shape<ndim>
mask_shape,
+ const double scale, const double
temperature) {
+ extern __shared__ double shared[];
+ AType* smem = reinterpret_cast<AType*>(shared); // x_size
+
+ const unsigned x_size = 1 << x_bits;
+ index_t sa = stride[axis];
+ index_t base = unravel_dot(blockIdx.x, sshape, stride);
+ index_t sa_mask = 0;
+ index_t base_mask = get_mask_position(blockIdx.x, sshape, mask_shape, axis,
&sa_mask);
+ bool bcst_mask_axis = (mask_shape[axis] == 1);
+ index_t x = threadIdx.x;
+
+ DType smax = 0.0;
+ if (normalize) {
+ red::maximum::SetInitValue(smem[x]);
+ for (index_t i = x; i < M; i += x_size) {
+ bool mask_value = bcst_mask_axis ? in_mask[base_mask] :
in_mask[base_mask + i*sa_mask];
+ if (mask_value)
+ smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
+ }
+ __syncthreads();
+ cuda::Reduce1D<red::maximum, x_bits>(smem);
+ __syncthreads();
+ smax = smem[0] / scale;
+ __syncthreads();
+ }
+
+ red::sum::SetInitValue(smem[x]);
+ DType val;
+ for (index_t i = x; i < M; i += x_size) {
+ bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask
+ i*sa_mask];
+ if (mask_value) {
+ val = (negate ? -in[base + i*sa]:in[base + i*sa]) / scale;
+ smem[x] += static_cast<AType>(expf((val - smax) /
static_cast<AType>(temperature)));
+ }
+ }
+ __syncthreads();
+ cuda::Reduce1D<red::sum, x_bits>(smem);
+ __syncthreads();
+ AType ssum = smem[0];
+ __syncthreads();
+
+ for (index_t i = x; i < M; i += x_size) {
+ val = (negate ? -in[base + i*sa] : in[base + i*sa]) / scale;
+ bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask
+ i*sa_mask];
+ out[base + i*sa] =
+ mask_value ? DType(OP::Map((val - smax)/static_cast<DType>(temperature),
ssum)) :
+ DType(0.0f);
+ }
+}
+
+template<bool normalize, typename OP, bool negate, typename AType, typename
LType,
+ typename LTypeMask, typename DType, int ndim>
+__global__ void masked_softmax_stride1_kernel(const DType *in, DType *out,
bool *in_mask,
+ const index_t M, int axis,
Shape<ndim> sshape,
+ Shape<ndim> mask_shape, const
double scale,
+ const double temperature, const
int rows_per_block,
+ const index_t total_rows,
+ const size_t size_input_shared,
+ const size_t size_mask_shared) {
+ const int entries_per_load = sizeof(LType)/sizeof(DType);
+ const int entries_per_load_mask = sizeof(LTypeMask)/sizeof(bool);
+ const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
+ const int row_length_mask = entries_per_load > 0 ? M /
entries_per_load_mask : 0;
+ extern __shared__ double shared[];
+ LType* persistent_storage = reinterpret_cast<LType*>(shared);
+ // rows_per_block * M (DType), aligned to double
+ LTypeMask* mask_shared =
reinterpret_cast<LTypeMask*>(&shared[size_input_shared]);
+ // rows_per_block * M (bool), aligned to double
+ AType* scratch = reinterpret_cast<AType*>(&shared[size_input_shared +
size_mask_shared]);
+ // softmax_threads_per_block
+
+ const int warp_size = 32;
+ const int threads_per_row = softmax_threads_per_block / rows_per_block;
+ const int my_local_row = threadIdx.x / threads_per_row;
+ const int my_row = blockIdx.x * rows_per_block + my_local_row;
+ if (my_row >= total_rows) return;
+ const int my_id = threadIdx.x % threads_per_row;
+ size_t base = my_row * row_length;
+ index_t pos_mask = 0;
+ index_t stride = mask_shape[axis];
+ #pragma unroll
+ for (index_t i = axis-1, j = my_row; i >=0; --i) {
+ auto tmp = j / sshape[i];
+ if (mask_shape[i] != 1) {
+ pos_mask += (j - tmp * mask_shape[i]) * stride;
+ stride *= mask_shape[i];
+ }
+ j = tmp;
+ }
+
+ const LType* in_aligned = reinterpret_cast<const LType*>(in);
+ for (index_t i = my_id; i < row_length; i += threads_per_row) {
+ persistent_storage[my_local_row * row_length + i] = in_aligned[base + i];
+ }
+ const LTypeMask* in_mask_aligned = reinterpret_cast<const
LTypeMask*>(&in_mask[pos_mask]);
+ for (index_t i = my_id; i < row_length_mask; i += threads_per_row) {
+ mask_shared[my_local_row * row_length_mask + i] = (mask_shape[axis] > 1) ?
+ in_mask_aligned[i] :
+ in_mask_aligned[0];
+ }
+ DType* row = reinterpret_cast<DType*>(persistent_storage + my_local_row *
row_length);
+ bool* row_mask = reinterpret_cast<bool*>(mask_shared + my_local_row *
row_length_mask);
+ __syncthreads();
+
+ DType smax = 0.0;
+ if (normalize) {
+ DType my_max_value;
+ red::maximum::SetInitValue(my_max_value);
+ for (index_t i = my_id; i < M; i += threads_per_row) {
+ if (row_mask[i])
+ my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+ }
+ scratch[threadIdx.x] = my_max_value;
+ __syncthreads();
+ for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+ if (my_id < size) {
+ scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x
+ size]);
+ }
+ __syncthreads();
+ }
+ if (my_id < warp_size) {
+ AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
+ [](AType x, AType y) { return ::max(x, y);
});
+ scratch[threadIdx.x] = my_value;
+ }
+ __syncthreads();
+ smax = scratch[threadIdx.x - threadIdx.x % threads_per_row] / scale;
+ __syncthreads();
+ }
+
+ AType my_sum;
+ red::sum::SetInitValue(my_sum);
+ for (index_t i = my_id; i < M; i += threads_per_row) {
+ if (row_mask[i]) {
+ const DType val = (negate ? -row[i] : row[i]) / scale;
+ my_sum += static_cast<AType>(expf((val - smax) /
static_cast<AType>(temperature)));
+ }
+ }
+ scratch[threadIdx.x] = my_sum;
+ __syncthreads();
+ for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+ if (my_id < size) {
+ scratch[threadIdx.x] += scratch[threadIdx.x + size];
+ }
+ __syncthreads();
+ }
+ if (my_id < warp_size) {
+ AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
+ [](AType x, AType y) { return x + y;});
+ scratch[threadIdx.x] = my_value;
+ }
+ __syncthreads();
+
+ AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
+ __syncthreads();
+
+ for (index_t i = my_id; i < M; i += threads_per_row) {
+ const DType val = (negate ? -row[i] : row[i]) / scale;
+ row[i] = row_mask[i] ? DType(OP::Map((val -
smax)/static_cast<DType>(temperature), ssum)) :
+ DType(0.0f);
+ }
+ __syncthreads();
+
+ LType* out_aligned = reinterpret_cast<LType*>(out);
+
+ for (index_t i = my_id; i < row_length; i += threads_per_row) {
+ out_aligned[base + i] = persistent_storage[my_local_row * row_length + i];
+ }
+}
+
template<typename OP, bool negate, typename AType, typename DType, typename
OType,
typename IType, int ndim>
inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
@@ -436,6 +699,84 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out,
IType *length,
}
}
+template<typename OP, bool negate, typename AType, typename DType,
+ typename OType, int ndim>
+inline void MaskedSoftmax(Stream<gpu> *s, DType *in, OType *out, bool *mask,
+ Shape<ndim> data_shape, Shape<ndim> mask_shape,
+ int axis, const double scale, const double
temperature,
+ bool normalize, const OpContext& ctx) {
+ const int x_bits = 7;
+ const int x_size = 1 << x_bits;
+ index_t M = data_shape[axis];
+ if (M == 0 || data_shape.Size() == 0) return;
+ index_t N = data_shape.Size() / M;
+ Shape<ndim> stride = calc_stride(data_shape);
+ Shape<ndim> sshape = data_shape;
+ sshape[axis] = 1;
+
+ const size_t DSize = sizeof(DType);
+ // Using max of 20 kB of shared memory for InputData in the optimized case
+ const size_t max_opt_M = 20 * 1024 / DSize;
+ if (stride[axis] == 1 &&
+ static_cast<size_t>(M) <= max_opt_M &&
+ std::is_same<DType, OType>::value) {
+ int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType));
+ int ltype_mask = mxnet::common::cuda::get_load_type(mask_shape[axis] *
sizeof(bool));
+ MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+ CHECK_LE(sizeof(DType), sizeof(LType));
+ MXNET_LOAD_TYPE_SWITCH(ltype_mask, LTypeMask, {
+ CHECK_LE(sizeof(bool), sizeof(LTypeMask));
+ int rows_per_block = mxnet::common::cuda::
+ get_rows_per_block(M *
+ sizeof(DType) /
sizeof(LType),
+ softmax_threads_per_block);
+        // calculate amount of shared memory (slots aligned to double)
+        int entries_per_load = sizeof(LType)/sizeof(DType);
+ int entries_per_load_mask = sizeof(LTypeMask)/sizeof(bool);
+ size_t size_input_shared = entries_per_load > 0 ?
+ rows_per_block * M / entries_per_load : 0;
+ size_t size_mask_shared = entries_per_load_mask > 0 ?
+ rows_per_block * M / entries_per_load_mask :
0;
+ size_input_shared = ((size_input_shared * sizeof(LType) +
sizeof(double) - 1) /
+ sizeof(double));
+ size_mask_shared = ((size_mask_shared * sizeof(LTypeMask) +
sizeof(double) - 1) /
+ sizeof(double));
+ size_t amount_shared = size_input_shared * sizeof(double) +
+ size_mask_shared * sizeof(double) +
+ softmax_threads_per_block * sizeof(AType);
+
+ int nblocks = (N + rows_per_block - 1) / rows_per_block;
+ if (normalize) {
+ masked_softmax_stride1_kernel<true, OP, negate, AType, LType,
LTypeMask>
+ <<<nblocks, softmax_threads_per_block, amount_shared,
+ mshadow::Stream<gpu>::GetStream(s)>>>(
+ in, out, mask, M, axis, sshape, mask_shape, scale, temperature,
+ rows_per_block, N, size_input_shared, size_mask_shared);
+ } else {
+ masked_softmax_stride1_kernel<false, OP, negate, AType, LType,
LTypeMask>
+ <<<nblocks, softmax_threads_per_block, amount_shared,
+ mshadow::Stream<gpu>::GetStream(s)>>>(
+ in, out, mask, M, axis, sshape, mask_shape, scale, temperature,
+ rows_per_block, N, size_input_shared, size_mask_shared);
+ }
+ });
+ });
+ MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_stride1_kernel);
+ } else {
+ size_t amount_shared = x_size * sizeof(AType);
+ if (normalize) {
+ masked_softmax_kernel<true, x_bits, OP, negate, AType, ndim>
+ <<<N, x_size, amount_shared, mshadow::Stream<gpu>::GetStream(s)>>>(
+ in, out, mask, M, axis, sshape, stride, mask_shape, scale,
temperature);
+ } else {
+ masked_softmax_kernel<false, x_bits, OP, negate, AType, ndim>
+ <<<N, x_size, amount_shared, mshadow::Stream<gpu>::GetStream(s)>>>(
+ in, out, mask, M, axis, sshape, stride, mask_shape, scale,
temperature);
+ }
+ MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_kernel);
+ }
+}
+
template<typename OP1, typename OP2, int Req, bool negate, typename AType,
typename LType,
typename DType, typename OType, typename IType>
__global__ void softmax_stride1_grad_kernel(const OType *out, const OType
*ograd,
@@ -550,6 +891,154 @@ __global__ void softmax_grad_kernel(OType *out, OType
*ograd, DType *igrad,
}
}
+template<typename OP1, typename OP2, int Req, bool negate, typename AType,
typename LType,
+ typename LTypeMask, typename DType, typename OType, int ndim>
+__global__ void masked_softmax_stride1_grad_kernel(const OType *out, const
OType *ograd,
+ DType *igrad, const bool
*in_mask,
+ const index_t M, int axis,
+ Shape<ndim> sshape,
+ Shape<ndim> mask_shape,
+ const double scale,
+ const double temperature,
+ const int rows_per_block,
+ const index_t total_rows,
+ const size_t
size_input_shared,
+ const size_t
size_mask_shared) {
+ const int entries_per_load = sizeof(LType)/sizeof(DType);
+ const int entries_per_load_mask = sizeof(LTypeMask)/sizeof(bool);
+ const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
+ const int row_length_mask = entries_per_load > 0 ? M /
entries_per_load_mask : 0;
+ extern __shared__ double shared[];
+ LType* persistent_storage = reinterpret_cast<LType*>(shared);
+ // 2 * rows_per_block * M (DType), aligned to double
+ LTypeMask* mask_shared =
reinterpret_cast<LTypeMask*>(&shared[size_input_shared]);
+ // rows_per_block * M (bool), aligned to double
+ AType* scratch = reinterpret_cast<AType*>(&shared[size_input_shared +
size_mask_shared]);
+ // softmax_threads_per_block
+
+ const int warp_size = 32;
+ const int threads_per_row = softmax_threads_per_block / rows_per_block;
+ const int my_local_row = threadIdx.x / threads_per_row;
+ const int my_row = blockIdx.x * rows_per_block + my_local_row;
+ if (my_row >= total_rows) return;
+ const int my_id = threadIdx.x % threads_per_row;
+ size_t base = my_row * row_length;
+ index_t pos_mask = 0;
+ index_t stride = mask_shape[axis];
+ #pragma unroll
+ for (index_t i = axis - 1, j = my_row; i >=0; --i) {
+ auto tmp = j / sshape[i];
+ if (mask_shape[i] != 1) {
+ pos_mask += (j - tmp * mask_shape[i]) * stride;
+ stride *= mask_shape[i];
+ }
+ j = tmp;
+ }
+
+ const LType* out_aligned = reinterpret_cast<const LType*>(out);
+ const LType* ograd_aligned = reinterpret_cast<const LType*>(ograd);
+ for (index_t i = my_id; i < row_length; i += threads_per_row) {
+ persistent_storage[my_local_row * row_length * 2 + i] = out_aligned[base +
i];
+ persistent_storage[my_local_row * row_length * 2 + row_length + i] =
ograd_aligned[base + i];
+ }
+ const LTypeMask* in_mask_aligned = reinterpret_cast<const
LTypeMask*>(&in_mask[pos_mask]);
+ for (index_t i = my_id; i < row_length_mask; i += threads_per_row) {
+ mask_shared[my_local_row * row_length_mask + i] = (mask_shape[axis] > 1) ?
+ in_mask_aligned[i] :
+ in_mask_aligned[0];
+ }
+ DType* row = reinterpret_cast<DType *>(persistent_storage + my_local_row *
row_length * 2);
+ bool* row_mask = reinterpret_cast<bool*>(mask_shared + my_local_row *
row_length_mask);
+ __syncthreads();
+
+ AType my_sum_value;
+ red::sum::SetInitValue(my_sum_value);
+
+ for (index_t i = my_id; i < M; i += threads_per_row) {
+ if (row_mask[i])
+ my_sum_value += OP1::Map(row[i + M], row[i]);
+ }
+ scratch[threadIdx.x] = my_sum_value;
+ __syncthreads();
+ for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+ if (my_id < size) {
+ scratch[threadIdx.x] = scratch[threadIdx.x] + scratch[threadIdx.x +
size];
+ }
+ __syncthreads();
+ }
+ if (my_id < warp_size) {
+ AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
+ [](AType x, AType y) { return x + y; });
+ scratch[threadIdx.x] = my_value;
+ }
+ __syncthreads();
+ AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
+ __syncthreads();
+
+ AType temperature_scale = static_cast<AType>(temperature) *
+ static_cast<AType>(scale);
+ for (index_t i = my_id; i < M; i += threads_per_row) {
+ const DType val =
+ negate ?
+ -OP2::Map(row[i + M], row[i], ssum):
+ OP2::Map(row[i + M], row[i], ssum);
+ row[i] = row_mask[i] ? DType(val / static_cast<DType>(temperature_scale)) :
+ DType(0.0f);
+ if (Req == kAddTo) {
+ row[i] += igrad[my_row * M + i];
+ }
+ }
+ __syncthreads();
+
+ LType* igrad_aligned = reinterpret_cast<LType*>(igrad);
+
+ for (index_t i = my_id; i < row_length; i += threads_per_row) {
+ igrad_aligned[base + i] = persistent_storage[my_local_row * row_length * 2
+ i];
+ }
+}
+
+template<int x_bits, typename OP1, typename OP2, int Req, bool negate,
typename AType, int ndim,
+ typename DType, typename OType>
+__global__ void masked_softmax_grad_kernel(OType *out, OType *ograd, DType
*igrad,
+ const bool *in_mask, index_t M, int
axis,
+ Shape<ndim> sshape, Shape<ndim>
stride,
+ Shape<ndim> mask_shape,
+ const double scale, const double
temperature) {
+ const unsigned x_size = 1 << x_bits;
+ __shared__ AType smem[x_size];
+ index_t sa = stride[axis];
+ index_t base = unravel_dot(blockIdx.x, sshape, stride);
+ index_t sa_mask = 0;
+ index_t base_mask = get_mask_position(blockIdx.x, sshape, mask_shape, axis,
&sa_mask);
+ bool bcst_mask_axis = (mask_shape[axis] == 1);
+ index_t x = threadIdx.x;
+
+ red::sum::SetInitValue(smem[x]);
+ for (index_t i = x; i < M; i += x_size) {
+ bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask
+ i*sa_mask];
+ if (mask_value)
+ smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
+ }
+ __syncthreads();
+ cuda::Reduce1D<red::sum, x_bits>(smem);
+ __syncthreads();
+ AType ssum = smem[0];
+ __syncthreads();
+
+ DType final_result;
+ AType temperature_scale = static_cast<AType>(temperature) *
+ static_cast<AType>(scale);
+ for (index_t i = x; i < M; i += x_size) {
+ bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask
+ i*sa_mask];
+ final_result =
+ negate ?
+ -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum):
+ OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
+ final_result = mask_value ? final_result /
static_cast<DType>(temperature_scale) : DType(0.0f);
+ KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result);
+ }
+}
+
template<typename OP1, typename OP2, int Req, bool negate, typename AType, int
ndim,
typename DType, typename OType, typename IType>
inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
@@ -591,6 +1080,70 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType
*ograd,
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_grad_kernel);
}
}
+
+template<typename OP1, typename OP2, int Req, bool negate, typename AType, int
ndim,
+ typename DType, typename OType>
+inline void MaskedSoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
+ DType *igrad, bool *mask, Shape<ndim> data_shape,
+ Shape<ndim> mask_shape, int axis,
+ const double scale, const double temperature,
+ const OpContext& ctx) {
+ const int x_bits = 7;
+ const int x_size = 1 << x_bits;
+ index_t M = data_shape[axis];
+ if (M == 0 || data_shape.Size() == 0) return;
+ index_t N = data_shape.Size() / M;
+ Shape<ndim> stride = calc_stride(data_shape);
+ Shape<ndim> sshape = data_shape;
+ sshape[axis] = 1;
+
+ const size_t DSize = sizeof(DType);
+ // Using max of 20 kB of shared memory for InputData in the optimized case
+ const size_t max_opt_M = 20 * 1024 / DSize;
+ if (stride[axis] == 1 &&
+ static_cast<size_t>(M) <= max_opt_M &&
+ std::is_same<DType, OType>::value) {
+ int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType));
+ int ltype_mask = mxnet::common::cuda::get_load_type(mask_shape[axis] *
sizeof(bool));
+ MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+ CHECK_LE(sizeof(DType), sizeof(LType));
+ MXNET_LOAD_TYPE_SWITCH(ltype_mask, LTypeMask, {
+ CHECK_LE(sizeof(bool), sizeof(LTypeMask));
+ int rows_per_block = mxnet::common::cuda::
+ get_rows_per_block(M *
+ sizeof(DType) /
sizeof(LType),
+ softmax_threads_per_block);
+      // calculate amount of shared memory (slots aligned to double)
+      int entries_per_load = sizeof(LType)/sizeof(DType);
+ int entries_per_load_mask = sizeof(LTypeMask)/sizeof(bool);
+ size_t size_input_shared = entries_per_load > 0 ?
+ rows_per_block * M / entries_per_load : 0;
+ size_t size_mask_shared = entries_per_load_mask > 0 ?
+ rows_per_block * M / entries_per_load_mask :
0;
+ size_input_shared = ((2 * size_input_shared * sizeof(LType) +
sizeof(double) - 1) /
+ sizeof(double));
+ size_mask_shared = ((size_mask_shared * sizeof(LTypeMask) +
sizeof(double) - 1) /
+ sizeof(double));
+ size_t amount_shared = size_input_shared * sizeof(double) +
+ size_mask_shared * sizeof(double) +
+ softmax_threads_per_block * sizeof(AType);
+
+ int nblocks = (N + rows_per_block - 1) / rows_per_block;
+ masked_softmax_stride1_grad_kernel<OP1, OP2, Req, negate, AType,
LType, LTypeMask>
+ <<<nblocks, softmax_threads_per_block, amount_shared,
+ mshadow::Stream<gpu>::GetStream(s)>>>(
+ out, ograd, igrad, mask, M, axis, sshape, mask_shape,
+ scale, temperature, rows_per_block, N, size_input_shared,
size_mask_shared);
+ });
+ });
+ MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_stride1_grad_kernel);
+ } else {
+ masked_softmax_grad_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
+ <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+ out, ograd, igrad, mask, M, axis, sshape, stride, mask_shape, scale,
temperature);
+ MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_grad_kernel);
+ }
+}
#endif
} // namespace mxnet_op
@@ -626,6 +1179,25 @@ struct SoftmaxParam : public
dmlc::Parameter<SoftmaxParam> {
}
};
+struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {
+ int axis;
+ dmlc::optional<double> scale_factor;
+ dmlc::optional<double> temperature;
+ dmlc::optional<int> dtype;
+ dmlc::optional<bool> normalize;
+ DMLC_DECLARE_PARAMETER(MaskedSoftmaxParam) {
+ DMLC_DECLARE_FIELD(axis).set_default(-1)
+ .describe("The axis along which to compute softmax.");
+ DMLC_DECLARE_FIELD(scale_factor).set_default(dmlc::optional<double>())
+ .describe("Scaling factor applied before softmax");
+ DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional<double>())
+ .describe("Temperature parameter in softmax");
+ DMLC_DECLARE_FIELD(normalize)
+ .set_default(dmlc::optional<bool>(true))
+ .describe("Whether to normalize input data x: x = x - max(x)");
+ }
+};
+
static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
return param.dtype.has_value() && param.dtype.value() != -1;
@@ -772,6 +1344,90 @@ struct SoftmaxFGradient {
}
};
+static inline bool MaskedSoftmaxOpType(const nnvm::NodeAttrs& attrs,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(out_attrs->size(), 1);
+ CHECK_EQ(in_attrs->size(), 2U);
+
+ std::vector<int> tmp = {in_attrs->at(0)};
+ return ElemwiseType<1, 1>(attrs, &tmp, out_attrs);
+}
+
+static inline bool MaskedSoftmaxOpShape(const nnvm::NodeAttrs& attrs,
+ mxnet::ShapeVector *in_shape,
+ mxnet::ShapeVector *out_shape) {
+ CHECK_EQ(out_shape->size(), 1U);
+ CHECK_EQ(in_shape->size(), 2U);
+
+ mxnet::TShape& data_shape = (*in_shape)[0];
+ mxnet::TShape& mask_shape = (*in_shape)[1];
+
+ if (!mxnet::ndim_is_known(data_shape) || !mxnet::ndim_is_known(mask_shape)) {
+ return false;
+ }
+ CHECK(data_shape.ndim() == mask_shape.ndim())
+ << "Number of dimensions in data and mask does not match";
+ CHECK(data_shape.ndim() > 0)
+ << "Empty tuple is not allowed";
+
+ for (int i = 0; i < data_shape.ndim(); ++i) {
+ CHECK(data_shape[i] == mask_shape[i] || mask_shape[i] == 1)
+ << "Mask cannot be broadcasted from " << mask_shape << " to " <<
data_shape;
+ }
+ SHAPE_ASSIGN_CHECK(*out_shape, 0, in_shape->at(0));
+ SHAPE_ASSIGN_CHECK(*in_shape, 0, out_shape->at(0));
+ return true;
+}
+
+static inline bool MaskedSoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
+ mxnet::ShapeVector *in_shape,
+ mxnet::ShapeVector *out_shape) {
+ CHECK_EQ(out_shape->size(), 1U);
+ CHECK_EQ(in_shape->size(), 3U);
+
+ mxnet::TShape& ograd_shape = (*in_shape)[0];
+ mxnet::TShape& mask_shape = (*in_shape)[1];
+
+ if (!mxnet::ndim_is_known(ograd_shape) || !mxnet::ndim_is_known(mask_shape))
{
+ return false;
+ }
+ CHECK(ograd_shape.ndim() == mask_shape.ndim())
+ << "Number of dimensions in data and mask does not match";
+ CHECK(ograd_shape.ndim() > 0)
+ << "Empty tuple is not allowed";
+
+ for (int i = 0; i < ograd_shape.ndim(); ++i) {
+ CHECK(ograd_shape[i] == mask_shape[i] || mask_shape[i] == 1)
+ << "Mask cannot be broadcasted from " << mask_shape << " to " <<
ograd_shape;
+ }
+
+ SHAPE_ASSIGN_CHECK(*out_shape, 0, in_shape->at(0));
+ SHAPE_ASSIGN_CHECK(*out_shape, 0, in_shape->at(2));
+ SHAPE_ASSIGN_CHECK(*in_shape, 0, out_shape->at(0));
+ SHAPE_ASSIGN_CHECK(*in_shape, 2, out_shape->at(0));
+ return true;
+}
+
+static inline bool MaskedSoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(out_attrs->size(), 1U);
+ CHECK_EQ(in_attrs->size(), 3U);
+ int data_dtype = (*in_attrs)[0];
+ TYPE_ASSIGN_CHECK(*in_attrs, 2, data_dtype);
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, data_dtype);
+ data_dtype = (*out_attrs)[0];
+ TYPE_ASSIGN_CHECK(*in_attrs, 0, data_dtype);
+
+ return true;
+}
+
+static inline std::vector<std::pair<int, int> >
+MaskedSoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
+ return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 1}, {3, 0}};
+}
+
template<typename xpu, typename OP, bool negate = false>
void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -836,6 +1492,48 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
});
}
+template<typename xpu, typename OP, bool negate = false>
+void MaskedSoftmaxCompute(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mxnet_op;
+ if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
+ CHECK_NE(req[0], kAddTo);
+ const MaskedSoftmaxParam& param =
nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
+ int axis = CheckAxis(param.axis, inputs[0].ndim());
+ const double scale = param.scale_factor.has_value() ?
+ param.scale_factor.value() : 1.0;
+ const double temperature = param.temperature.has_value() ?
+ param.temperature.value() : 1.0;
+ bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
+ if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
+ common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for
masked_softmax with "
+ "float16 inputs. "
+ "See https://mxnet.apache.org/api/faq/env_var "
+ "for more details.");
+ }
+ MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
+ MXNET_NDIM_SWITCH(inputs[0].ndim(), ndim, {
+ bool* mask_ptr = inputs[1].dptr<bool>();
+ if (safe_acc) {
+ MaskedSoftmax<OP, negate, AType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<DType>(), mask_ptr,
+ inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
+ axis, scale, temperature, param.normalize.value(), ctx);
+ } else {
+ MaskedSoftmax<OP, negate, DType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<DType>(), mask_ptr,
+ inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
+ axis, scale, temperature, param.normalize.value(), ctx);
+ }
+ });
+ });
+}
+
template<typename xpu, typename OP1, typename OP2, bool negate = false>
void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
@@ -907,6 +1605,50 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
});
}
+template<typename xpu, typename OP1, typename OP2, bool negate = false>
+void MaskedSoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mxnet_op;
+
+ if (req[0] == kNullOp) return;
+ const MaskedSoftmaxParam& param =
nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
+ int axis = CheckAxis(param.axis, inputs[0].ndim());
+ const double scale = param.scale_factor.has_value() ?
+ param.scale_factor.value() : 1.0;
+ const double temperature = param.temperature.has_value() ?
+ param.temperature.value() : 1.0;
+
+ bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
+ MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
+ MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+ MXNET_NDIM_SWITCH(inputs[0].ndim(), ndim, {
+ DType* ograd_ptr = inputs[0].dptr<DType>();
+ DType* out_ptr = inputs[2].dptr<DType>();
+ bool* mask_ptr = inputs[1].dptr<bool>();
+ DType* grad_data = outputs[0].dptr<DType>();
+ if (safe_acc) {
+ MaskedSoftmaxGrad<OP1, OP2, Req, negate, AType>(
+ ctx.get_stream<xpu>(), out_ptr,
+ ograd_ptr, grad_data, mask_ptr,
+ inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
+ axis, static_cast<DType>(scale),
+ static_cast<DType>(temperature), ctx);
+ } else {
+ MaskedSoftmaxGrad<OP1, OP2, Req, negate, DType>(
+ ctx.get_stream<xpu>(), out_ptr,
+ ograd_ptr, grad_data, mask_ptr,
+ inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
+ axis, static_cast<DType>(scale),
+ static_cast<DType>(temperature), ctx);
+ }
+ });
+ });
+ });
+}
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 9b28b71..cf67853 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -34,6 +34,7 @@
namespace mxnet {
namespace op {
DMLC_REGISTER_PARAMETER(SoftmaxParam);
+DMLC_REGISTER_PARAMETER(MaskedSoftmaxParam);
#if MXNET_USE_MKLDNN == 1
static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
@@ -187,5 +188,65 @@ NNVM_REGISTER_OP(_backward_softmax)
#endif
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu,
op::mshadow_op::mul,
mxnet_op::softmax_bwd>);
+
+NNVM_REGISTER_OP(masked_softmax)
+.add_alias("_npx_masked_softmax")
+.describe(R"code(Applies the softmax function masking elements according to
the mask provided)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<MaskedSoftmaxParam>)
+.set_attr<nnvm::FListOutputNames>("FListInputNames",
+ [](const NodeAttrs& attrs){
+ return std::vector<std::string>{"data", "mask"};
+ })
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"output"};
+ })
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu,
mxnet_op::softmax_fwd>)
+.set_attr<nnvm::FGradient>("FGradient",
+ [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+ auto data_grad = MakeNode("_backward_masked_softmax", n->attrs.name +
"_backward_data",
+ {ograds[0], n->inputs[1], nnvm::NodeEntry(n, 0,
0)},
+ &n->attrs.dict, &n);
+ auto mask_grad = MakeNode("zeros_like", n->attrs.name + "_backward_mask",
+ {n->inputs[1]}, nullptr, &n);
+ std::vector<nnvm::NodeEntry> ret;
+ ret.emplace_back(data_grad);
+ ret.emplace_back(mask_grad);
+ return ret;
+ })
+.set_attr<nnvm::FInferType>("FInferType", MaskedSoftmaxOpType)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<mxnet::FInferShape>("FInferShape", MaskedSoftmaxOpShape)
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+ [](const NodeAttrs& attrs){
+ return std::vector<std::pair<int, int> >{{0, 0}};
+ })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
+.add_argument("mask", "NDArray-or-Symbol", "Mask to apply.")
+.add_arguments(MaskedSoftmaxParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_masked_softmax)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr<nnvm::FListOutputNames>("FListInputNames",
+ [](const NodeAttrs& attrs){
+ return std::vector<std::string>{"ograd", "mask", "output"};
+ })
+.set_attr<mxnet::FInferShape>("FInferShape", MaskedSoftmaxGradOpShape)
+.set_attr<nnvm::FInferType>("FInferType", MaskedSoftmaxGradOpType)
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
MaskedSoftmaxGradOpInplaceOption)
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
+.set_attr_parser(ParamParser<MaskedSoftmaxParam>)
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxGradCompute<cpu,
op::mshadow_op::mul,
+
mxnet_op::softmax_bwd>);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/softmax.cu b/src/operator/nn/softmax.cu
index d5762cf..dc8fd99 100644
--- a/src/operator/nn/softmax.cu
+++ b/src/operator/nn/softmax.cu
@@ -34,6 +34,11 @@ NNVM_REGISTER_OP(softmax)
NNVM_REGISTER_OP(_backward_softmax)
.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu,
op::mshadow_op::mul,
mxnet_op::softmax_bwd>);
// GPU kernels for masked softmax: reuse the generic templated compute
// functions declared in softmax-inl.h, instantiated for gpu.
+NNVM_REGISTER_OP(masked_softmax)
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu,
mxnet_op::softmax_fwd>);
// Backward: same elementwise ops (mul + softmax_bwd) as plain softmax,
// with the mask applied inside the kernel.
+NNVM_REGISTER_OP(_backward_masked_softmax)
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxGradCompute<gpu,
op::mshadow_op::mul,
+
mxnet_op::softmax_bwd>);
} // namespace op
} // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py
b/tests/python/unittest/test_operator.py
index 8a18d5d..31a6aa1 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4941,6 +4941,94 @@ def test_softmax_with_length():
[np.zeros(shape), np.zeros(len_shape,
dtype=np.int32)],
rtol=1e-2, atol=2e-3 if dtype == np.float16
else 1e-3, dtype="asnumpy")
def np_softmax(x, axis=-1, scale_factor=1.0, temperature=1.0, normalize=True):
    """NumPy reference softmax with scaling, temperature and optional
    max-subtraction normalization, matching the mxnet operator semantics."""
    scaled = x / scale_factor
    if normalize:
        # Subtract the axis-wise max for numerical stability.
        scaled = scaled - scaled.max(axis=axis, keepdims=True)
    exp = np.exp(scaled / temperature)
    return exp / exp.sum(axis=axis, keepdims=True)

def np_masked_softmax(data, mask, axis=-1, scale_factor=1.0, temperature=1.0, normalize=True):
    """Reference masked softmax: masked-out entries are filled with a very
    negative logit before the softmax and zeroed in the result."""
    fill_value = -1e4 if data.dtype == np.float16 else -1e18
    filled = np.where(mask, data, fill_value)
    probs = np_softmax(filled, axis=axis,
                       scale_factor=scale_factor,
                       temperature=temperature,
                       normalize=normalize)
    return probs * mask

def np_masked_softmax_grad(out, grad_out, axis=-1, scale_factor=1.0, temperature=1.0):
    """Reference gradient of masked softmax w.r.t. the data input."""
    inner = (out * grad_out).sum(axis=axis, keepdims=True)
    return out * (grad_out - inner) / (temperature * scale_factor)
+
[email protected]('dtype', [np.float16, np.float32, np.float64])
[email protected]('axis', [0, -1, -2, -3])
[email protected]('ndims', [3, 4, 5])
[email protected]('n_broadcast_axis', [0, 1, 2])
[email protected]('temperature', [1, 5, 9 ,11])
[email protected]('scale', [1, 2, 7, 12])
[email protected]('normalize', [True])
+def test_masked_softmax(dtype, axis, ndims, n_broadcast_axis, temperature,
scale, normalize):
+ n_broadcast_axis = min(n_broadcast_axis, ndims - 1)
+ shape = rand_shape_nd(ndims, dim=10)
+ mx_data = rand_ndarray(shape, dtype=dtype)
+ bcst_dims = []
+ while len(bcst_dims) < n_broadcast_axis:
+ ax = np.random.randint(0, ndims)
+ if ax not in bcst_dims :
+ bcst_dims.append(ax)
+ shape_mask = list(shape)
+ for i in bcst_dims:
+ shape_mask[i] = 1
+
+ np_data = mx_data.asnumpy()
+ np_mask = np.random.randint(0, 2, shape_mask)
+ mx_mask = mx.nd.array(np_mask, dtype=np.bool)
+ mx_grad = rand_ndarray(shape, dtype=dtype)
+ np_grad = mx_grad.asnumpy()
+
+ np_out = np_masked_softmax(np_data, np_mask, axis,
+ scale, temperature, normalize)
+ np_grad_out = np_masked_softmax_grad(np_out, np_grad,
+ axis, scale, temperature)
+ data = mx.sym.Variable("data")
+ mask = mx.sym.Variable("mask")
+ mx_sym = mx.sym.masked_softmax(data=data, mask=mask, scale_factor=scale,
+ temperature=temperature, axis=axis,
+ normalize=normalize)
+ location = {"data": mx_data, "mask": mx_mask}
+ rtol = 1e-2 if dtype == np.float16 else 1e-3
+ atol = 1e-4 if dtype == np.float16 else 1e-5
+ check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol,
+ dtype="asnumpy", equal_nan=True)
+ check_symbolic_backward(mx_sym, location, [mx_grad],
+ [np_grad_out, np.zeros(shape, dtype=np.bool)],
+ rtol=1e-2, atol=2e-3 if dtype == np.float16 else
1e-3,
+ dtype="asnumpy", equal_nan=True)
+
+
[email protected]('dtype', ['float32'])
[email protected]('ndims', [1, 2, 3, 4, 5])
+def test_masked_log_softmax(dtype, ndims):
+ shape = np.random.randint(1, 5, size=ndims)
+ axis = np.random.randint(0, ndims)
+ mx_data = rand_ndarray(shape, dtype=dtype)
+ np_data = mx_data.asnumpy()
+ np_mask = np.random.randint(0, 2, shape)
+ mx_mask = mx.nd.array(np_mask, dtype=np.bool)
+ np_out = np.log(np_masked_softmax(np_data, np_mask, axis)+1e-20) * np_mask
+ data = mx.sym.Variable("data")
+ mask = mx.sym.Variable("mask")
+ mx_sym = mx.sym.masked_log_softmax(data=data, mask=mask, axis=axis-ndims)
+ location = {"data": mx_data, "mask": mx_mask}
+ rtol = 1e-2 if dtype == np.float16 else 1e-3
+ atol = 1e-4 if dtype == np.float16 else 1e-5
+ check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol,
dtype="asnumpy")
+ check_numeric_gradient(mx_sym, location, rtol=1e-1, atol=1e-2)
+
def test_pick():
def test_pick_helper(index_type=np.int32):