This is an automated email from the ASF dual-hosted git repository. zhasheng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new 117c509  Fix bug for Dropout with axes, also adding unit test (#10030)
117c509 is described below

commit 117c5095fb57b5e9bae36209e133626311d2b815
Author: Hang Zhang <8041160+zhanghang1...@users.noreply.github.com>
AuthorDate: Wed Mar 7 17:44:24 2018 -0800

    Fix bug for Dropout with axes, also adding unit test (#10030)

    * fix bug

    * add test for dropout with axes
---
 src/operator/nn/dropout-inl.h          |  2 +-
 tests/python/unittest/test_operator.py | 36 ++++++++++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index b57ab45..1af4798 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -259,7 +259,7 @@ class DropoutOp {
           return;
         }
         // initialize the mask
-        LaunchRNG<BernoulliKernel, xpu>(s, pgen, out.Size(),
+        LaunchRNG<BernoulliKernel, xpu>(s, pgen, mask.Size(),
                                         mask.dptr<DType>(),
                                         this->pkeep_);
         // broadcast mul
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1ee14b6..91b8faa 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4645,6 +4645,27 @@ def test_dropout():
         exe.backward([mx.nd.ones(shape)], is_train=False)
         assert (exe.grad_arrays[0].asnumpy() == exe.outputs[0].asnumpy()).all()
 
+    def get_slice(x, axis, idx):
+        ix = ()
+        for i in range(x.ndim):
+            if i == axis:
+                ix += (idx,)
+            else:
+                ix += (slice(None, None, None),)
+        return x[ix]
+
+    def check_dropout_axes(ratio, shape, axes):
+        compactshape = list(shape)
+        for axis in axes:
+            compactshape[axis] = 1
+        compactx = mx.random.uniform(shape=tuple(compactshape))
+        broadcastx = compactx.broadcast_to(shape)
+        dropouty = mx.nd.Dropout(broadcastx, p=ratio, axes=axes)
+        for axis in axes:
+            target = get_slice(dropouty, axis, 0).asnumpy()
+            for i in range(1, shape[axis]):
+                assert(get_slice(dropouty, axis, i).asnumpy() == target).all()
+
     shape = (100, 100)
     check_dropout_ratio(0.5, shape)
     check_dropout_ratio(0.0, shape)
@@ -4652,6 +4673,21 @@ def test_dropout():
     check_dropout_ratio(0.75, shape)
     check_dropout_ratio(0.25, shape)
 
+    nshape = (10, 10, 10, 10)
+    check_dropout_axes(0.25, nshape, axes = (0,))
+    check_dropout_axes(0.25, nshape, axes = (1,))
+    check_dropout_axes(0.25, nshape, axes = (2,))
+    check_dropout_axes(0.25, nshape, axes = (3,))
+    check_dropout_axes(0.25, nshape, axes = (0, 1))
+    check_dropout_axes(0.25, nshape, axes = (0, 2))
+    check_dropout_axes(0.25, nshape, axes = (0, 3))
+    check_dropout_axes(0.25, nshape, axes = (1, 2))
+    check_dropout_axes(0.25, nshape, axes = (1, 3))
+    check_dropout_axes(0.25, nshape, axes = (2, 3))
+    check_dropout_axes(0.25, nshape, axes = (0, 1, 2))
+    check_dropout_axes(0.25, nshape, axes = (0, 2, 3))
+    check_dropout_axes(0.25, nshape, axes = (1, 2, 3))
+
 
 @with_seed()
 def test_scatter_gather_nd():

-- 
To stop receiving notification emails like this one, please contact
zhash...@apache.org.