This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 494c29e  [BUGFIX] fix ELU function will appear nan when calculating the gradient (#14673)
494c29e is described below

commit 494c29e88cbb0ebf96bd8bb83a9a738b9f4d67e5
Author: 夏鲁豫 <[email protected]>
AuthorDate: Tue Apr 23 13:58:42 2019 +0800

    [BUGFIX] fix ELU function will appear nan when calculating the gradient (#14673)
    
    * fix ELU
    
    * fix
    
    * fix
    
    * fix
    
    * fix
    
    * fix
---
 python/mxnet/gluon/nn/activations.py | 3 ++-
 tests/python/unittest/test_gluon.py  | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py
index c7dc831..8c51b0a 100644
--- a/python/mxnet/gluon/nn/activations.py
+++ b/python/mxnet/gluon/nn/activations.py
@@ -153,12 +153,13 @@ class ELU(HybridBlock):
     Outputs:
         - **out**: output tensor with the same shape as `data`.
     """
+
     def __init__(self, alpha=1.0, **kwargs):
         super(ELU, self).__init__(**kwargs)
         self._alpha = alpha
 
     def hybrid_forward(self, F, x):
-        return F.where(x > 0, x, self._alpha * (F.exp(x) - 1.0))
+        return F.LeakyReLU(x, act_type='elu', slope=self._alpha)
 
 
 class SELU(HybridBlock):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 8c60ef6..efa04f4 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1180,7 +1180,7 @@ def test_activations():
     elu = mx.gluon.nn.ELU()
     def elu_test(x):
         def elu(x):
-            return 1.0 * (mx.nd.exp(x) - 1) if x < 0 else x
+            return mx.nd.expm1(x) if x <= 0.0 else x
         return [elu(x_i) for x_i in x]
 
     for test_point, ref_point in zip(elu_test(point_to_validate), elu(point_to_validate)):

Reply via email to