This is an automated email from the ASF dual-hosted git repository.
anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new d7b39f4 Add support for kAddTo in softmax backward (#11836)
d7b39f4 is described below
commit d7b39f4c055a46c1b819b0dff4ae468c1983e016
Author: Hao Jin <[email protected]>
AuthorDate: Mon Aug 20 22:38:08 2018 -0700
Add support for kAddTo in softmax backward (#11836)
---
src/operator/contrib/ctc_loss-inl.h | 2 +-
src/operator/nn/softmax-inl.h | 44 ++++++++++++++++++----------------
tests/python/unittest/test_operator.py | 19 +++++++++------
3 files changed, 37 insertions(+), 28 deletions(-)
diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h
index 0e7b63e..72209ae 100644
--- a/src/operator/contrib/ctc_loss-inl.h
+++ b/src/operator/contrib/ctc_loss-inl.h
@@ -426,7 +426,7 @@ class CTCLossOp : public Operator {
workspace_bytes));
if (req_grad) {
- mxnet_op::SoftmaxGrad<mshadow_op::mul, mxnet_op::softmax_bwd>(s,
+ mxnet_op::SoftmaxGrad<mshadow_op::mul, mxnet_op::softmax_bwd, kWriteTo>(s,
prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0);
Assign(grad, mxnet::kWriteInplace, grad * alphabet_size);
}
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 64b436e..4a19db7 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -111,7 +111,7 @@ struct log_softmax_bwd {
};
-template<typename OP1, typename OP2, typename DType, int ndim>
+template<typename OP1, typename OP2, int Req, typename DType, int ndim>
inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
DType *igrad, Shape<ndim> shape, int axis,
const DType temperature) {
@@ -134,13 +134,16 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
// By default temperature is 1.0, and only in reinforcement training
// users would set it to other values.
// Adding a branch here to save the CPU 'divide-by-1' computation at runtime
+ DType final_result;
if (temperature == 1.0) {
for (index_t j = 0; j < M; ++j) {
- igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+ final_result = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+ KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
}
} else {
for (index_t j = 0; j < M; ++j) {
- igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum)/temperature;
+ final_result = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature;
+ KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
}
}
}
@@ -202,7 +205,7 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
}
-template<int x_bits, typename OP1, typename OP2, typename DType, int ndim>
+template<int x_bits, typename OP1, typename OP2, int Req, typename DType, int ndim>
__global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
index_t M, int axis, Shape<ndim> sshape,
Shape<ndim> stride, const double temperature) {
@@ -222,14 +225,16 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
DType ssum = smem[0];
__syncthreads();
+ DType final_result;
for (index_t i = x; i < M; i += x_size) {
- igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum)/
- static_cast<DType>(temperature);
+ final_result =
+ OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) / static_cast<DType>(temperature);
+ KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result);
}
}
-template<typename OP1, typename OP2, typename DType, int ndim>
+template<typename OP1, typename OP2, int Req, typename DType, int ndim>
inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
DType *igrad, Shape<ndim> shape, int axis,
const double temperature) {
@@ -241,7 +246,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
Shape<ndim> sshape = shape;
sshape[axis] = 1;
- softmax_gradient_kernel<x_bits, OP1, OP2, DType, ndim>
+ softmax_gradient_kernel<x_bits, OP1, OP2, Req, DType, ndim>
<<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
out, ograd, igrad, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
@@ -298,24 +303,23 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
if (req[0] == kNullOp) return;
- CHECK_NE(req[0], kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, inputs[0].ndim());
const double temperature = param.temperature.has_value() ?
param.temperature.value() : 1.0;
TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
- if (shape.ndim() == 2) {
- SoftmaxGrad<OP1, OP2>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
- inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
- shape.get<2>(), axis,
- static_cast<DType>(temperature));
- } else {
- SoftmaxGrad<OP1, OP2>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
- inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
- shape.get<3>(), axis,
- static_cast<DType>(temperature));
- }
+ MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+ if (shape.ndim() == 2) {
+ SoftmaxGrad<OP1, OP2, Req>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
+ inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
+ shape.get<2>(), axis, static_cast<DType>(temperature));
+ } else {
+ SoftmaxGrad<OP1, OP2, Req>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
+ inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
+ shape.get<3>(), axis, static_cast<DType>(temperature));
+ }
+ });
});
}
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 0ff9a10..e1e5c9e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4522,13 +4522,18 @@ def test_where():
@with_seed()
def test_new_softmax():
for ndim in range(1, 5):
- for _ in range(5):
- shape = np.random.randint(1, 5, size=ndim)
- axis = np.random.randint(-ndim, ndim)
- data = np.random.uniform(-2, 2, size=shape)
- sym = mx.sym.softmax(axis=axis)
- check_symbolic_forward(sym, [data], [np_softmax(data, axis=axis)])
- check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3)
+ shape = np.random.randint(1, 5, size=ndim)
+ axis = np.random.randint(-ndim, ndim)
+ data = np.random.uniform(-2, 2, size=shape)
+ sym = mx.sym.softmax(axis=axis)
+ expected_fwd = np_softmax(data, axis=axis)
+ expected_bwd = np.zeros(shape)
+ check_symbolic_forward(sym, [data], [expected_fwd])
+ for req in ['null', 'add', 'write']:
+ check_symbolic_backward(sym, [data], [np.ones(expected_fwd.shape)], [expected_bwd],
+ rtol=1e-2, atol=1e-3, grad_req=req)
+ check_numeric_gradient(sym, [data], rtol=1e-2, atol=1e-3)
+
@with_seed()
def test_softmax_with_temperature():