piiswrong commented on a change in pull request #9662: Gluon PReLU, ELU, SELU, Swish
URL: https://github.com/apache/incubator-mxnet/pull/9662#discussion_r166086384
 
 

 ##########
 File path: tests/python/unittest/test_gluon.py
 ##########
 @@ -719,6 +719,42 @@ def test_inline():
     assert len_1 == len_2 + 2
 
 
+def test_activations():
+    point_to_validate = mx.nd.array([-0.1, 0.1])
+
+    swish = mx.gluon.nn.Swish()
+    def swish_test(x):
+        return x * mx.nd.sigmoid(x)
+
+    for test_point, ref_point in zip(swish_test(point_to_validate), 
swish(point_to_validate)):
+        assert test_point == ref_point
+
+    elu = mx.gluon.nn.ELU()
+    def elu_test(x):
+        def elu(x):
+            return 1.0 * (mx.nd.exp(x) - 1) if x < 0 else x
+        return [elu(x_i) for x_i in x]
+
+    for test_point, ref_point in zip(elu_test(point_to_validate), 
elu(point_to_validate)):
+        assert test_point == ref_point
+
+    selu = mx.gluon.nn.SELU()
+    def selu_test(x):
+        def selu(x):
+            scale, alpha = 1.0507009873554804934193349852946, 
1.6732632423543772848170429916717
+            return scale * x if x >= 0 else alpha * mx.nd.exp(x) - alpha
+        return [selu(x_i) for x_i in x]
+
+    for test_point, ref_point in zip(selu(point_to_validate), 
selu(point_to_validate)):
+        assert test_point == ref_point
+
+    prelu = mx.gluon.nn.PReLU()
+    prelu.initialize()
+    x = point_to_validate.reshape((1, 1, 2))
 
 Review comment:
   Use a different input shape, one that can catch the infer-shape problem.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to