This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b1f2f44 [MXNET-560] Add temperature parameter in Softmax operator
(#11466)
b1f2f44 is described below
commit b1f2f44118887debc380d43711257b7df099678a
Author: Lin Yuan <[email protected]>
AuthorDate: Thu Jul 19 12:56:06 2018 -0700
[MXNET-560] Add temperature parameter in Softmax operator (#11466)
* Add temperature parameter in softmax operator and add a unit test
* Optimize runtime when temperature is set to default 1.0
* Add temperature parameter in softmax operator and add a unit test
---
CONTRIBUTORS.md | 1 +
cpp-package/scripts/OpWrapperGenerator.py | 1 +
src/operator/contrib/ctc_loss-inl.h | 4 +-
src/operator/nn/mkldnn/mkldnn_base-inl.h | 2 +
src/operator/nn/mkldnn/mkldnn_softmax.cc | 9 ++++
src/operator/nn/softmax-inl.h | 82 ++++++++++++++++++++++---------
src/operator/nn/softmax.cc | 7 ++-
tests/python/unittest/test_operator.py | 16 +++++-
8 files changed, 94 insertions(+), 28 deletions(-)
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index c28214d..b04e4a3 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -172,6 +172,7 @@ List of Contributors
* [Thomas Delteil](https://github.com/ThomasDelteil)
* [Jesse Brizzi](https://github.com/jessebrizzi)
* [Hang Zhang](http://hangzh.com)
+* [Lin Yuan](https://github.com/apeforest)
* [Kou Ding](https://github.com/chinakook)
* [Istvan Fehervari](https://github.com/ifeherva)
* [Aaron Markham](https://github.com/aaronmarkham)
diff --git a/cpp-package/scripts/OpWrapperGenerator.py
b/cpp-package/scripts/OpWrapperGenerator.py
index 8facde1..1b5f8b5 100644
--- a/cpp-package/scripts/OpWrapperGenerator.py
+++ b/cpp-package/scripts/OpWrapperGenerator.py
@@ -95,6 +95,7 @@ class Arg:
'int or None':'dmlc::optional<int>',\
'long':'int64_t',\
'double':'double',\
+ 'double or None':'dmlc::optional<double>',\
'Shape or None':'dmlc::optional<Shape>',\
'string':'const std::string&'}
name = ''
diff --git a/src/operator/contrib/ctc_loss-inl.h
b/src/operator/contrib/ctc_loss-inl.h
index ef58c51..0e7b63e 100644
--- a/src/operator/contrib/ctc_loss-inl.h
+++ b/src/operator/contrib/ctc_loss-inl.h
@@ -409,7 +409,7 @@ class CTCLossOp : public Operator {
// since the input is activation before softmax and cudnn ctc takes softmax
// apply softmax to inputs first.
- mxnet_op::Softmax<mxnet_op::softmax_fwd>(s, data.dptr_, prob.dptr_, data.shape_, 2);
+ mxnet_op::Softmax<mxnet_op::softmax_fwd>(s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0);
CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_,
prob_desc_,
@@ -427,7 +427,7 @@ class CTCLossOp : public Operator {
if (req_grad) {
mxnet_op::SoftmaxGrad<mshadow_op::mul, mxnet_op::softmax_bwd>(s,
- prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2);
+ prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0);
Assign(grad, mxnet::kWriteInplace, grad * alphabet_size);
}
}
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h
b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index f77d113..bbfb873 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -146,9 +146,11 @@ namespace op {
struct ActivationParam;
struct ConvolutionParam;
struct DeconvolutionParam;
+struct SoftmaxParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray
&input);
+bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
}
static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc
b/src/operator/nn/mkldnn/mkldnn_softmax.cc
index acfa358..7268ed3 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc
@@ -32,6 +32,15 @@
namespace mxnet {
namespace op {
+bool SupportMKLDNNSoftmax(const SoftmaxParam &param) {
+ // MKLDNN does not support temperature argument in their softmax function
+ // now. Need update this once they start to support it.
+ if (param.temperature.has_value()) {
+ return false;
+ }
+ return true;
+}
+
void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data) {
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 080bc08..64b436e 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -53,7 +53,7 @@ struct log_softmax_fwd {
template<typename OP, typename DType, int ndim>
inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
- Shape<ndim> shape, int axis) {
+ Shape<ndim> shape, int axis, const DType temperature) {
index_t M = shape[axis];
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
@@ -71,12 +71,25 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
}
DType sum = DType(0);
- for (index_t j = 0; j < M; ++j) {
- sum += std::exp(in[base + j*sa] - mmax);
- }
+ // By default temperature is 1.0, and only in reinforcement training
+ // users would set it to other values.
+ // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
+ if (temperature == 1.0) {
+ for (index_t j = 0; j < M; ++j) {
+ sum += std::exp(in[base + j*sa] - mmax);
+ }
+
+ for (index_t j = 0; j < M; ++j) {
+ out[base + j*sa] = OP::Map(in[base + j*sa] - mmax, sum);
+ }
+ } else {
+ for (index_t j = 0; j < M; ++j) {
+ sum += std::exp((in[base + j*sa] - mmax)/temperature);
+ }
- for (index_t j = 0; j < M; ++j) {
- out[base + j*sa] = OP::Map(in[base + j*sa] - mmax, sum);
+ for (index_t j = 0; j < M; ++j) {
+ out[base + j*sa] = OP::Map((in[base + j*sa] - mmax)/temperature, sum);
+ }
}
}
}
@@ -100,7 +113,8 @@ struct log_softmax_bwd {
template<typename OP1, typename OP2, typename DType, int ndim>
inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
- DType *igrad, Shape<ndim> shape, int axis) {
+ DType *igrad, Shape<ndim> shape, int axis,
+ const DType temperature) {
index_t M = shape[axis];
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
@@ -117,8 +131,17 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType
*ograd,
sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
}
- for (index_t j = 0; j < M; ++j) {
- igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+ // By default temperature is 1.0, and only in reinforcement training
+ // users would set it to other values.
+ // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
+ if (temperature == 1.0) {
+ for (index_t j = 0; j < M; ++j) {
+ igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa],
sum);
+ }
+ } else {
+ for (index_t j = 0; j < M; ++j) {
+ igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa],
sum)/temperature;
+ }
}
}
}
@@ -127,7 +150,8 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType
*ograd,
#ifdef __CUDACC__
template<int x_bits, typename OP, typename DType, int ndim>
__global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int
axis,
- Shape<ndim> sshape, Shape<ndim> stride)
{
+ Shape<ndim> sshape, Shape<ndim> stride,
+ const double temperature) {
const unsigned x_size = 1 << x_bits;
__shared__ DType smem[x_size];
index_t sa = stride[axis];
@@ -146,7 +170,8 @@ __global__ void softmax_compute_kernel(DType *in, DType
*out, index_t M, int axi
red::sum::SetInitValue(smem[x]);
for (index_t i = x; i < M; i += x_size) {
- red::sum::Reduce(smem[x], static_cast<DType>(expf(in[base + i*sa] -
smax)));
+ red::sum::Reduce(smem[x], static_cast<DType>(expf((in[base + i*sa] - smax)/
+ static_cast<DType>(temperature))));
}
__syncthreads();
cuda::Reduce1D<red::sum, x_bits>(smem);
@@ -155,13 +180,13 @@ __global__ void softmax_compute_kernel(DType *in, DType
*out, index_t M, int axi
__syncthreads();
for (index_t i = x; i < M; i += x_size) {
- out[base + i*sa] = OP::Map(in[base + i*sa] - smax, ssum);
+ out[base + i*sa] = OP::Map((in[base + i*sa] -
smax)/static_cast<DType>(temperature), ssum);
}
}
template<typename OP, typename DType, int ndim>
inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
- Shape<ndim> shape, int axis) {
+ Shape<ndim> shape, int axis, const double temperature) {
const int x_bits = 7;
const int x_size = 1 << x_bits;
index_t M = shape[axis];
@@ -172,7 +197,7 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
softmax_compute_kernel<x_bits, OP, DType, ndim>
<<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
- in, out, M, axis, sshape, stride);
+ in, out, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
}
@@ -180,7 +205,7 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
template<int x_bits, typename OP1, typename OP2, typename DType, int ndim>
__global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
index_t M, int axis, Shape<ndim>
sshape,
- Shape<ndim> stride) {
+ Shape<ndim> stride, const double
temperature) {
const unsigned x_size = 1 << x_bits;
__shared__ DType smem[x_size];
index_t sa = stride[axis];
@@ -198,14 +223,16 @@ __global__ void softmax_gradient_kernel(DType *out, DType
*ograd, DType *igrad,
__syncthreads();
for (index_t i = x; i < M; i += x_size) {
- igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
+ igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum)/
+ static_cast<DType>(temperature);
}
}
template<typename OP1, typename OP2, typename DType, int ndim>
inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
- DType *igrad, Shape<ndim> shape, int axis) {
+ DType *igrad, Shape<ndim> shape, int axis,
+ const double temperature) {
const int x_bits = 7;
const int x_size = 1 << x_bits;
index_t M = shape[axis];
@@ -216,7 +243,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType
*ograd,
softmax_gradient_kernel<x_bits, OP1, OP2, DType, ndim>
<<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
- out, ograd, igrad, M, axis, sshape, stride);
+ out, ograd, igrad, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
}
#endif
@@ -226,9 +253,12 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType
*ograd,
struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
int axis;
+ dmlc::optional<double> temperature;
DMLC_DECLARE_PARAMETER(SoftmaxParam) {
DMLC_DECLARE_FIELD(axis).set_default(-1)
.describe("The axis along which to compute softmax.");
+ DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional<double>())
+ .describe("Temperature parameter in softmax");
}
};
@@ -243,14 +273,18 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
CHECK_NE(req[0], kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, inputs[0].ndim());
+ const double temperature = param.temperature.has_value() ?
+ param.temperature.value() : 1.0;
TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
if (shape.ndim() == 2) {
Softmax<OP>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<DType>(), shape.get<2>(), axis);
+ outputs[0].dptr<DType>(), shape.get<2>(), axis,
+ static_cast<DType>(temperature));
} else {
Softmax<OP>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<DType>(), shape.get<3>(), axis);
+ outputs[0].dptr<DType>(), shape.get<3>(), axis,
+ static_cast<DType>(temperature));
}
});
}
@@ -267,16 +301,20 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
CHECK_NE(req[0], kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, inputs[0].ndim());
+ const double temperature = param.temperature.has_value() ?
+ param.temperature.value() : 1.0;
TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
if (shape.ndim() == 2) {
SoftmaxGrad<OP1, OP2>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
- shape.get<2>(), axis);
+ shape.get<2>(), axis,
+ static_cast<DType>(temperature));
} else {
SoftmaxGrad<OP1, OP2>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
- shape.get<3>(), axis);
+ shape.get<3>(), axis,
+ static_cast<DType>(temperature));
}
});
}
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index e9b104f..e855608 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -39,7 +39,8 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
// It seems MKLDNN softmax doesn't support training.
- if (SupportMKLDNN(inputs[0]) && !ctx.is_train) {
+ const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+ if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmax(param)) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]);
auto fn = SoftmaxCompute<cpu, mxnet_op::softmax_fwd>;
@@ -77,10 +78,12 @@ MXNET_OPERATOR_REGISTER_UNARY(softmax)
The resulting array contains elements in the range (0,1) and the elements
along the given axis sum up to 1.
.. math::
- softmax(\mathbf{z})_j = \frac{e^{z_j}}{\sum_{k=1}^K e^{z_k}}
+ softmax(\mathbf{z/t})_j = \frac{e^{z_j/t}}{\sum_{k=1}^K e^{z_k/t}}
for :math:`j = 1, ..., K`
+t is the temperature parameter in softmax function. By default, t equals 1.0
+
Example::
x = [[ 1. 1. 1.]
diff --git a/tests/python/unittest/test_operator.py
b/tests/python/unittest/test_operator.py
index 814266a..c870709 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -267,11 +267,11 @@ def test_rnnrelu_dropout():
out = exe.forward(is_train=True)
out[0].wait_to_read()
-def np_softmax(x, axis=-1):
+def np_softmax(x, axis=-1, temperature=1.0):
# fix for old numpy on Travis not supporting keepdims
# x = x - np.max(x, axis=-1, keepdims=True)
x = x - np.max(x, axis=axis, keepdims=True)
- x = np.exp(x)
+ x = np.exp(x/temperature)
# x /= np.sum(x, axis=-1, keepdims=True)
x /= np.sum(x, axis=axis, keepdims=True)
return x
@@ -4357,6 +4357,18 @@ def test_new_softmax():
check_symbolic_forward(sym, [data], [np_softmax(data, axis=axis)])
check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3)
+@with_seed()
+def test_softmax_with_temperature():
+ for ndim in range(1, 5):
+ shape = np.random.randint(1, 5, size=ndim)
+ data = np.random.uniform(-2, 2, size=shape)
+ for temp in range(1, 11):
+ sym = mx.sym.softmax(axis=0, temperature=temp)
+ expected_fwd = np_softmax(data, axis=0, temperature=temp)
+ expected_bwd = np.zeros(shape)
+ check_symbolic_forward(sym, [data], [expected_fwd], rtol=0.05,
atol=1e-3)
+ check_symbolic_backward(sym, [data], [np.ones(shape)],
[expected_bwd], rtol=0.05, atol=1e-3)
+ check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3)
@with_seed()
def test_log_softmax():