This is an automated email from the ASF dual-hosted git repository.
samskalicky pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 468e21d Fix SoftReLU fused operator numerical stability (#17849) (#19391)
468e21d is described below
commit 468e21d2a85d76ad7de44dcaefd9439de7d15d4f
Author: Manu Seth <[email protected]>
AuthorDate: Fri Oct 23 00:01:19 2020 -0700
Fix SoftReLU fused operator numerical stability (#17849) (#19391)
* fix numerically unstable fused softrelu op
* implement test for softrelu numerical stability
Co-authored-by: RuRo <[email protected]>
---
src/operator/fusion/fused_op-inl.h | 5 ++++-
tests/python/gpu/test_fusion.py | 3 +++
2 files changed, 7 insertions(+), 1 deletion(-)
diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index c838d85..0b10f82 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -566,7 +566,10 @@ __device__ inline DType sigmoid(const DType val) {
template <typename DType>
__device__ inline DType softrelu(const DType val) {
- return logf(1 + expf(val));
+ // Avoid overflow of exp for large inputs: for a > 20, softrelu(a) == a in
+ // single (float) precision. Use log1pf so very negative inputs, where
+ // 1 + expf(val) would round to exactly 1, do not collapse to 0.
+ return val > 20 ? val : log1pf(expf(val));
}
template <typename DType>
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 1bbf598..1febf8d 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -138,6 +138,9 @@ def check_unary_ops():
for act_type in ['relu', 'sigmoid', 'tanh', 'softrelu', 'softsign']:
announce_check("Activation(act_type='{}')".format(act_type))
check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=arr)
+ if act_type == 'softrelu':
+ # Check that softrelu implementation doesn't overflow on large
inputs
+ check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=1000
* arr)
# Cast requires dtype
for dtype in ['float16', 'float32', 'float64', 'int32']: