This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f7f62c  [operator] add threshold for mish (#20339)
4f7f62c is described below

commit 4f7f62cf60b47265b142adec7fd348fcaf534406
Author: herewj <[email protected]>
AuthorDate: Wed Jun 16 06:21:17 2021 +0800

    [operator] add threshold for mish (#20339)
    
    * add threshold for mish
    
    Signed-off-by: Adnios <[email protected]>
    
    * sanity
    
    Signed-off-by: Adnios <[email protected]>
    
    * fix op::softrelu error
    
    * use op::softrelu
    
    Signed-off-by: Adnios <[email protected]>
    
    * try to fix error "op::softrelu and op::tanh"
    
    Signed-off-by: Adnios <[email protected]>
    
    * back to
    
    * move mish after tanh
---
 src/common/cuda/rtc/backward_functions-inl.h | 17 ++++++++---------
 src/common/cuda/rtc/forward_functions-inl.h  | 14 +++++---------
 src/operator/mshadow_op.h                    | 28 ++++++++++++++++++++++++----
 3 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h
index cacb8be..85135ae 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -52,15 +52,6 @@ backward_log_sigmoid(const DTypeGrad grad, const DType val) {
 
 template <typename DType, typename DTypeGrad>
 __device__ inline mixed_type<DTypeGrad, DType>
-backward_mish(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  const auto softrelu = op::log(1 + exp(v));
-  const auto tanh = op::tanh(softrelu);
-  return grad * (tanh + v * sigmoid(v) * (1 - tanh * tanh));
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
 backward_softrelu(const DTypeGrad grad, const DType val) {
   const mixed_type<DTypeGrad, DType> v = val;
   return grad * sigmoid(v);
@@ -214,6 +205,14 @@ backward_arctanh(const DTypeGrad grad, const DType val) {
 
 template <typename DType, typename DTypeGrad>
 __device__ inline mixed_type<DTypeGrad, DType>
+backward_mish(const DTypeGrad grad, const DType val) {
+  const auto softrelu = op::softrelu(val);
+  const auto tanh_sr = op::tanh(softrelu);
+  return grad * (tanh_sr + val * sigmoid(val) * (1 - tanh_sr * tanh_sr));
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
 backward_sqrt(const DTypeGrad grad, const DType out) {
   return 0.5 * grad / out;
 }
diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h
index b353e92..7a886a0 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -695,15 +695,6 @@ __device__ inline DType log_sigmoid(const DType val) {
 }
 
 template <typename DType>
-__device__ inline DType mish(const DType val) {
-  if (type_util::has_double_or_integral<DType>::value) {
-    return val * ::tanh(::log(1 + ::exp(val)));
-  } else {
-    return val * ::tanhf(logf(1 + expf(val)));
-  }
-}
-
-template <typename DType>
 __device__ inline DType softrelu(const DType val) {
   // Avoid overflow of exp for large inputs.
   // The threshold 20 is chosen such that softrelu(a) = a
@@ -780,6 +771,11 @@ DEFINE_UNARY_MATH_FUNC(arcsinh, ::asinh, ::asinhf)
 DEFINE_UNARY_MATH_FUNC(arccosh, ::acosh, ::acoshf)
 DEFINE_UNARY_MATH_FUNC(arctanh, ::atanh, ::atanhf)
 
+template <typename DType>
+__device__ inline DType mish(const DType val) {
+  return val * op::tanh(op::softrelu(val));
+}
+
 // sqrt
 
 DEFINE_UNARY_MATH_FUNC(sqrt, ::sqrt, ::sqrtf)
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index e28dc86..611ddbc 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -415,11 +415,31 @@ MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));
 
 MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));
 
-MXNET_UNARY_MATH_OP(mish, a * math::tanh(math::log(1.0f + math::exp(a))));
+struct mish : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // reference softrelu
+    auto softrelu = math::log1p(math::exp(a));
+    if (a > DType(20.0f)) {
+      softrelu = a;
+    }
+    return DType(a * math::tanh(softrelu));
+  }
+};
 
-MXNET_UNARY_MATH_OP(mish_grad, math::tanh(math::log(1.0f + math::exp(a))) +
-                        a * (1.0f / (1.0f + math::exp(-a))) *
-                        (1.0f - math::sqr(math::tanh(math::log(1.0f + math::exp(a))))));
+struct mish_grad : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // Note: the input(a) is x(not y)
+    auto softrelu = math::log1p(math::exp(a));
+    if (a > DType(20.0f)) {
+      softrelu = a;
+    }
+    auto tanh_sr = math::tanh(softrelu);
+    auto sr_grad = 1.0f / (1.0f + math::exp(-a));
+    return DType(tanh_sr + a * sr_grad * (1.0f - tanh_sr * tanh_sr));
+  }
+};
 
 MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));
 

Reply via email to