This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 117c509  Fix bug for Dropout with axes, also adding unit test (#10030)
117c509 is described below

commit 117c5095fb57b5e9bae36209e133626311d2b815
Author: Hang Zhang <8041160+zhanghang1...@users.noreply.github.com>
AuthorDate: Wed Mar 7 17:44:24 2018 -0800

    Fix bug for Dropout with axes, also adding unit test (#10030)
    
    * fix bug
    
    * add test for dropout with axes
---
 src/operator/nn/dropout-inl.h          |  2 +-
 tests/python/unittest/test_operator.py | 36 ++++++++++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index b57ab45..1af4798 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -259,7 +259,7 @@ class DropoutOp {
             return;
           }
           // initialize the mask
-          LaunchRNG<BernoulliKernel, xpu>(s, pgen, out.Size(),
+          LaunchRNG<BernoulliKernel, xpu>(s, pgen, mask.Size(),
                                           mask.dptr<DType>(),
                                           this->pkeep_);
           // broadcast mul
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1ee14b6..91b8faa 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4645,6 +4645,27 @@ def test_dropout():
             exe.backward([mx.nd.ones(shape)], is_train=False)
             assert (exe.grad_arrays[0].asnumpy() == exe.outputs[0].asnumpy()).all()
 
+    def get_slice(x, axis, idx):
+        ix = ()
+        for i in range(x.ndim):
+            if i == axis:
+                ix += (idx,)
+            else:
+                ix += (slice(None, None, None),)
+        return x[ix]
+
+    def check_dropout_axes(ratio, shape, axes):
+        compactshape = list(shape)
+        for axis in axes:
+            compactshape[axis] = 1
+        compactx = mx.random.uniform(shape=tuple(compactshape))
+        broadcastx = compactx.broadcast_to(shape)
+        dropouty = mx.nd.Dropout(broadcastx, p=ratio, axes=axes)
+        for axis in axes:
+            target = get_slice(dropouty, axis, 0).asnumpy()
+            for i in range(1, shape[axis]):
+                assert(get_slice(dropouty, axis, i).asnumpy() == target).all()
+
     shape = (100, 100)
     check_dropout_ratio(0.5, shape)
     check_dropout_ratio(0.0, shape)
@@ -4652,6 +4673,21 @@ def test_dropout():
     check_dropout_ratio(0.75, shape)
     check_dropout_ratio(0.25, shape)
 
+    nshape = (10, 10, 10, 10)
+    check_dropout_axes(0.25, nshape, axes = (0,))
+    check_dropout_axes(0.25, nshape, axes = (1,))
+    check_dropout_axes(0.25, nshape, axes = (2,))
+    check_dropout_axes(0.25, nshape, axes = (3,))
+    check_dropout_axes(0.25, nshape, axes = (0, 1))
+    check_dropout_axes(0.25, nshape, axes = (0, 2))
+    check_dropout_axes(0.25, nshape, axes = (0, 3))
+    check_dropout_axes(0.25, nshape, axes = (1, 2))
+    check_dropout_axes(0.25, nshape, axes = (1, 3))
+    check_dropout_axes(0.25, nshape, axes = (2, 3))
+    check_dropout_axes(0.25, nshape, axes = (0, 1, 2))
+    check_dropout_axes(0.25, nshape, axes = (0, 2, 3))
+    check_dropout_axes(0.25, nshape, axes = (1, 2, 3))
+
 
 @with_seed()
 def test_scatter_gather_nd():

-- 
To stop receiving notification emails like this one, please contact
zhash...@apache.org.

Reply via email to