ptrendx commented on a change in pull request #20339:
URL: https://github.com/apache/incubator-mxnet/pull/20339#discussion_r650117882



##########
File path: src/common/cuda/rtc/backward_functions-inl.h
##########
@@ -53,10 +53,9 @@ backward_log_sigmoid(const DTypeGrad grad, const DType val) {
 template <typename DType, typename DTypeGrad>
 __device__ inline mixed_type<DTypeGrad, DType>
 backward_mish(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  const auto softrelu = op::log(1 + exp(v));
-  const auto tanh = op::tanh(softrelu);
-  return grad * (tanh + v * sigmoid(v) * (1 - tanh * tanh));
+  const auto softrelu = (val > 20) ? val : op::log(1 + op::exp(val));

Review comment:
       Why not just use `op::softrelu(val)` here instead of replicating that logic?
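
   With `op::softrelu` the whole body could presumably read something like the sketch below; the gradient expression is carried over unchanged from the lines this diff removes, and it is an assumption here that `op::softrelu` covers the same `(val > 20)` overflow guard as the inlined version:
   ```
   template <typename DType, typename DTypeGrad>
   __device__ inline mixed_type<DTypeGrad, DType>
   backward_mish(const DTypeGrad grad, const DType val) {
     const mixed_type<DTypeGrad, DType> v = val;
     // Assumed helper: op::softrelu(v) == log(1 + exp(v)), ideally with the
     // large-input guard folded in.
     const auto softrelu = op::softrelu(v);
     const auto tanh = op::tanh(softrelu);
     return grad * (tanh + v * sigmoid(v) * (1 - tanh * tanh));
   }
   ```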

##########
File path: src/common/cuda/rtc/forward_functions-inl.h
##########
@@ -697,9 +697,11 @@ __device__ inline DType log_sigmoid(const DType val) {
 template <typename DType>
 __device__ inline DType mish(const DType val) {
   if (type_util::has_double_or_integral<DType>::value) {
-    return val * ::tanh(::log(1 + ::exp(val)));
+    const auto softrelu = (val > 20) ? val : ::log(1 + ::exp(val));
+    return val * ::tanh(softrelu);
   } else {
-    return val * ::tanhf(logf(1 + expf(val)));
+    const auto softrelu = (val > 20) ? val : logf(1 + expf(val));
+    return val * ::tanhf(softrelu);

Review comment:
       Similarly here. This function body could just be
   ```
   return val * op::tanh(op::softrelu(val));
   ```
   right?
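
   Spelled out against the signature in this hunk, that would presumably look like the sketch below; it assumes `op::softrelu` and `op::tanh` already pick the right single/double-precision math internally, which is the point of reusing them:
   ```
   template <typename DType>
   __device__ inline DType mish(const DType val) {
     // One code path for all types; precision handling is assumed to live
     // inside op::softrelu / op::tanh rather than being duplicated here.
     return val * op::tanh(op::softrelu(val));
   }
   ```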

##########
File path: src/operator/mshadow_op.h
##########
@@ -415,11 +415,31 @@ MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));
 
 MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));
 
-MXNET_UNARY_MATH_OP(mish, a * math::tanh(math::log(1.0f + math::exp(a))));
+struct mish : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // reference softrelu
+    auto softrelu = math::log1p(math::exp(a));

Review comment:
       And similarly here you should be able to use
   ```
   auto softrelu = math::softrelu(a);
   ```
   right? This way the logic stays in one place and is easier to maintain.
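
   If a `math::softrelu` helper is available (not verified here), the `Map` body could presumably shrink to something like the sketch below; the return expression is taken from the one-liner this diff removes:
   ```
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a) {
     // Assumed helper: math::softrelu(a) == log1p(exp(a)), ideally with the
     // same large-input guard as the hand-written version.
     auto softrelu = math::softrelu(a);
     return DType(a * math::tanh(softrelu));
   }
   ```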




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]

