This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 4f7f62c [operator] add threshold for mish (#20339)
4f7f62c is described below
commit 4f7f62cf60b47265b142adec7fd348fcaf534406
Author: herewj <[email protected]>
AuthorDate: Wed Jun 16 06:21:17 2021 +0800
[operator] add threshold for mish (#20339)
* add threshold for mish
Signed-off-by: Adnios <[email protected]>
* sanity
Signed-off-by: Adnios <[email protected]>
* fix op::softrelu error
* use op::softrelu
Signed-off-by: Adnios <[email protected]>
* try to fix "op::softrelu and op::tanh" error
Signed-off-by: Adnios <[email protected]>
* back to
* move mish after tanh
---
src/common/cuda/rtc/backward_functions-inl.h | 17 ++++++++---------
src/common/cuda/rtc/forward_functions-inl.h | 14 +++++---------
src/operator/mshadow_op.h | 28 ++++++++++++++++++++++++----
3 files changed, 37 insertions(+), 22 deletions(-)
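For context: mish(x) = x * tanh(softplus(x)), where softplus(x) = log(1 + exp(x)). Evaluating
exp(x) directly overflows for large x, so this commit reroutes mish through the existing
softrelu helpers, which simply return x once x > 20 (beyond that point, log1p(exp(x)) and x
agree to within float precision). A minimal standalone C++ sketch of the idea (illustrative
only; naive_softplus and thresholded_softplus are not MXNet names):

    #include <cmath>
    #include <cstdio>

    // Naive softplus: std::exp overflows to inf for float x above roughly 88.
    float naive_softplus(float x) { return std::log1p(std::exp(x)); }

    // Thresholded softplus, mirroring the softrelu guard in this commit:
    // for x > 20, log1p(exp(x)) rounds to x in float, so return x directly.
    float thresholded_softplus(float x) {
      return x > 20.0f ? x : std::log1p(std::exp(x));
    }

    int main() {
      std::printf("naive(20)        = %g\n", naive_softplus(20.0f));         // 20
      std::printf("thresholded(20)  = %g\n", thresholded_softplus(20.0f));   // 20
      std::printf("naive(100)       = %g\n", naive_softplus(100.0f));        // inf
      std::printf("thresholded(100) = %g\n", thresholded_softplus(100.0f));  // 100
    }

With that guard, the softplus term itself stays finite for any input, which is what the three
file diffs below wire into mish.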
diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h
index cacb8be..85135ae 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -52,15 +52,6 @@ backward_log_sigmoid(const DTypeGrad grad, const DType val) {
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
-backward_mish(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  const auto softrelu = op::log(1 + exp(v));
-  const auto tanh = op::tanh(softrelu);
-  return grad * (tanh + v * sigmoid(v) * (1 - tanh * tanh));
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
backward_softrelu(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return grad * sigmoid(v);
@@ -214,6 +205,14 @@ backward_arctanh(const DTypeGrad grad, const DType val) {
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
+backward_mish(const DTypeGrad grad, const DType val) {
+  const auto softrelu = op::softrelu(val);
+  const auto tanh_sr = op::tanh(softrelu);
+  return grad * (tanh_sr + val * sigmoid(val) * (1 - tanh_sr * tanh_sr));
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
backward_sqrt(const DTypeGrad grad, const DType out) {
  return 0.5 * grad / out;
}
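The rewritten backward_mish leans on the chain-rule identity: with s(x) = softplus(x),
d/dx [x * tanh(s(x))] = tanh(s(x)) + x * sigmoid(x) * (1 - tanh^2(s(x))), since s'(x) =
sigmoid(x) and d/du tanh(u) = 1 - tanh^2(u). A quick host-side sanity check of that identity
against a central difference (standalone C++ with doubles, not MXNet code; all names here are
illustrative):

    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    double softrelu(double x) { return x > 20.0 ? x : std::log1p(std::exp(x)); }
    double mish(double x) { return x * std::tanh(softrelu(x)); }

    // Analytic gradient, matching backward_mish above with grad = 1.
    double mish_grad(double x) {
      const double t = std::tanh(softrelu(x));
      const double sig = 1.0 / (1.0 + std::exp(-x));
      return t + x * sig * (1.0 - t * t);
    }

    int main() {
      for (double x : {-5.0, -0.5, 0.0, 3.0, 25.0}) {
        const double h = 1e-6;
        const double fd = (mish(x + h) - mish(x - h)) / (2 * h);  // central difference
        std::printf("x=%6.2f  analytic=%.8f  numeric=%.8f\n", x, mish_grad(x), fd);
      }
    }

Note that the committed kernel reuses op::softrelu for s(x), so the gradient inherits the same
a > 20 early-out as the forward pass.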
diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h
index b353e92..7a886a0 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -695,15 +695,6 @@ __device__ inline DType log_sigmoid(const DType val) {
}

template <typename DType>
-__device__ inline DType mish(const DType val) {
-  if (type_util::has_double_or_integral<DType>::value) {
-    return val * ::tanh(::log(1 + ::exp(val)));
-  } else {
-    return val * ::tanhf(logf(1 + expf(val)));
-  }
-}
-
-template <typename DType>
__device__ inline DType softrelu(const DType val) {
  // Avoid overflow of exp for large inputs.
  // The threshold 20 is chosen such that softrelu(a) = a
@@ -780,6 +771,11 @@ DEFINE_UNARY_MATH_FUNC(arcsinh, ::asinh, ::asinhf)
DEFINE_UNARY_MATH_FUNC(arccosh, ::acosh, ::acoshf)
DEFINE_UNARY_MATH_FUNC(arctanh, ::atanh, ::atanhf)

+template <typename DType>
+__device__ inline DType mish(const DType val) {
+  return val * op::tanh(op::softrelu(val));
+}
+
// sqrt
DEFINE_UNARY_MATH_FUNC(sqrt, ::sqrt, ::sqrtf)
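After this hunk, the device-side mish is a one-line composition of two existing helpers, so the
has_double_or_integral branching (::tanh/::log/::exp vs. ::tanhf/logf/expf) could be dropped:
op::tanh and op::softrelu already dispatch on precision, and softrelu carries the overflow
guard. This is presumably also what the "move mish after tanh" commit bullet refers to: these
RTC headers are embedded as source strings and compiled at runtime, so mish has to appear
textually after the helpers it calls. A host-side analogue of the committed definition
(illustrative; op_mimic is not the real op:: namespace):

    #include <cmath>

    namespace op_mimic {
    // Stand-ins for the op::softrelu / op::tanh device helpers.
    inline float softrelu(float x) { return x > 20.0f ? x : std::log1p(std::exp(x)); }
    inline float tanh(float x) { return std::tanh(x); }

    // mish as committed: defined after its dependencies, reusing softrelu's guard.
    inline float mish(float x) { return x * op_mimic::tanh(op_mimic::softrelu(x)); }
    }  // namespace op_mimic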
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index e28dc86..611ddbc 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -415,11 +415,31 @@ MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));
MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));

-MXNET_UNARY_MATH_OP(mish, a * math::tanh(math::log(1.0f + math::exp(a))));
+struct mish : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // Mirror softrelu's overflow guard: softrelu(a) = a for a > 20.
+    auto softrelu = math::log1p(math::exp(a));
+    if (a > DType(20.0f)) {
+      softrelu = a;
+    }
+    return DType(a * math::tanh(softrelu));
+  }
+};

-MXNET_UNARY_MATH_OP(mish_grad, math::tanh(math::log(1.0f + math::exp(a))) +
-                    a * (1.0f / (1.0f + math::exp(-a))) *
-                    (1.0f - math::sqr(math::tanh(math::log(1.0f + math::exp(a))))));
+struct mish_grad : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // Note: the input a is x (the forward input), not y (the forward output).
+    auto softrelu = math::log1p(math::exp(a));
+    if (a > DType(20.0f)) {
+      softrelu = a;
+    }
+    auto tanh_sr = math::tanh(softrelu);
+    auto sr_grad = 1.0f / (1.0f + math::exp(-a));
+    return DType(tanh_sr + a * sr_grad * (1.0f - tanh_sr * tanh_sr));
+  }
+};

MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));
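One more numerical detail in the mshadow_op.h rewrite: the old macros computed
math::log(1.0f + math::exp(a)), while the new structs use math::log1p(math::exp(a)). For very
negative a, 1.0f + exp(a) rounds to 1.0f and the old form returns exactly 0, whereas log1p
keeps the small positive value. A standalone illustration (not MXNet code):

    #include <cmath>
    #include <cstdio>

    int main() {
      // exp(-20) is about 2.06e-9, far below float epsilon (~1.19e-7) relative to 1.
      float a = -20.0f;
      std::printf("log(1+exp(a)) = %g\n", std::log(1.0f + std::exp(a)));  // 0
      std::printf("log1p(exp(a)) = %g\n", std::log1p(std::exp(a)));       // ~2.06e-09
    }

Together with the a > 20 branch, the rewrite tightens both tails of softplus: log1p preserves
precision for very negative inputs, and the threshold replaces the possibly-overflowed softplus
value with a itself for large positive ones.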