RafLit commented on code in PR #21034:
URL: https://github.com/apache/incubator-mxnet/pull/21034#discussion_r903408675
##########
src/operator/operator_tune.cc:
##########
@@ -277,6 +277,7 @@
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad);
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::selu);  // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::selu_grad);  // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gelu);  // NOLINT()
Review Comment:
Maybe change the name to gelu_erf to keep it consistent?
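For context on the suggested name: the existing `gelu` is the exact, erf-based GELU, 0.5·x·(1 + erf(x/√2)), while the new `gelu_tanh` is the tanh approximation, so an `_erf` suffix would mark which variant each operator computes. Below is a minimal standalone sketch of the erf variant as plain C++ functions rather than MXNet's `MXNET_UNARY_MATH_OP` macros; `gelu_erf` (the suggested name), `erf_grad`, and the two gradient helpers are hypothetical names for illustration only. It also mirrors the `gelu_grad` change in the `mshadow_op.h` hunk, which uses `b / a` in place of `0.5f * (1.0f + erf(a / SQRT_2))`, assuming `b` is the forward output as that term implies.

```cpp
#include <cmath>

// Hypothetical standalone helpers mirroring the erf-based GELU; these are
// not MXNet symbols, and the constants are local stand-ins for SQRT_2 / PI.
constexpr float kSqrt2 = 1.41421356237f;
constexpr float kPi = 3.14159265358979f;

// Forward: 0.5 * x * (1 + erf(x / sqrt(2))).
inline float gelu_erf(float x) {
  return 0.5f * x * (1.0f + std::erf(x / kSqrt2));
}

// Derivative of erf(t): 2 / sqrt(pi) * exp(-t^2).
inline float erf_grad(float t) {
  return 2.0f / std::sqrt(kPi) * std::exp(-t * t);
}

// Gradient from the input only (the shape of the removed "-" lines).
inline float gelu_erf_grad_direct(float x) {
  return 0.5f * (1.0f + std::erf(x / kSqrt2) +
                 x * erf_grad(x / kSqrt2) / kSqrt2);
}

// Gradient reusing the forward output y = gelu_erf(x) (the shape of the
// added "+" lines): y / x stands in for 0.5 * (1 + erf(x / sqrt(2))).
inline float gelu_erf_grad_from_output(float x, float y) {
  return y / x + 0.5f * x * erf_grad(x / kSqrt2) / kSqrt2;
}
```

The two gradient forms agree for x ≠ 0, but the output-based one divides by `x`, so it yields 0/0 (NaN) at x = 0, whereas the input-only form gives exactly 0.5 there.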
##########
src/operator/mshadow_op.h:
##########
@@ -617,9 +617,25 @@ MXNET_UNARY_MATH_OP(gelu,
(1.0f + math::erf(static_cast<float>(a) / SQRT_2))));
MXNET_BINARY_MATH_OP_NC(gelu_grad,
- DType(0.5f * (1.0f + math::erf(static_cast<float>(a) / SQRT_2) +
- static_cast<float>(a) *
- erf_grad::Map(static_cast<float>(a) / SQRT_2) / SQRT_2)));
+ DType(static_cast<float>(b) / static_cast<float>(a) +
+ 0.5f * static_cast<float>(a) *
+ erf_grad::Map(static_cast<float>(a) / SQRT_2) / SQRT_2));
+
+MXNET_UNARY_MATH_OP(gelu_tanh,
+ DType(0.5f * static_cast<float>(a) *
+ (1.0f + math::tanh(math::sqrt(2.0f / PI) *
+ (static_cast<float>(a) +
+ 0.044715 * math::pow(static_cast<float>(a), 3))))));
+
+MXNET_BINARY_MATH_OP_NC(
+ gelu_tanh_grad,
+ DType(static_cast<float>(b) *
+ (1.0f / static_cast<float>(a) +
+ (1.0f -
+ math::tanh(math::sqrt(2.0f / PI) *
+ (static_cast<float>(a) + 0.044715 * math::pow(static_cast<float>(a), 3))) *
+ (math::sqrt(2.0f / PI) *
+ (1.0f + 0.134145 * math::pow(static_cast<float>(a), 2)))))));
Review Comment:
It would be cleaner to define 0.044715 and 0.134145 as constants.
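Following that suggestion, here is a minimal standalone sketch of the tanh approximation with both literals named; `kGeluTanhCoeff`, `kGeluTanhCoeffX3`, `kPi`, and the helper functions are hypothetical names, not MXNet identifiers, and the code uses plain functions rather than the `MXNET_*_MATH_OP` macros. Naming the second constant as `3 * kGeluTanhCoeff` also documents that 0.134145 is 3 × 0.044715, the coefficient obtained by differentiating the cubic term. The gradient is written in terms of the input: 0.5·(1 + tanh(u)) + 0.5·x·(1 - tanh²(u))·u′(x), with u(x) = √(2/π)·(x + 0.044715·x³).

```cpp
#include <cmath>

// Hypothetical named constants for the tanh approximation of GELU.
constexpr float kPi = 3.14159265358979f;
constexpr float kGeluTanhCoeff = 0.044715f;               // cubic coefficient
constexpr float kGeluTanhCoeffX3 = 3.0f * kGeluTanhCoeff;  // ~0.134145, from d/dx of x^3

// Inner argument u(x) = sqrt(2/pi) * (x + 0.044715 * x^3).
inline float gelu_tanh_inner(float x) {
  return std::sqrt(2.0f / kPi) * (x + kGeluTanhCoeff * x * x * x);
}

// Forward: 0.5 * x * (1 + tanh(u(x))).
inline float gelu_tanh(float x) {
  return 0.5f * x * (1.0f + std::tanh(gelu_tanh_inner(x)));
}

// Gradient from the input only:
//   0.5 * (1 + t) + 0.5 * x * (1 - t^2) * u'(x),  t = tanh(u(x)),
//   u'(x) = sqrt(2/pi) * (1 + 0.134145 * x^2).
inline float gelu_tanh_grad(float x) {
  const float t = std::tanh(gelu_tanh_inner(x));
  const float du = std::sqrt(2.0f / kPi) * (1.0f + kGeluTanhCoeffX3 * x * x);
  return 0.5f * (1.0f + t) + 0.5f * x * (1.0f - t * t) * du;
}
```

With y = gelu_tanh(x), the same gradient can be written as y·(1/x + (1 - tanh(u))·u′(x)), which is the shape of the `gelu_tanh_grad` hunk; there the factor multiplying u′(x) has to be the whole (1 - tanh(u)). If the quoted parenthesization is accurate, `(1.0f -` closes only after the `math::sqrt(2.0f / PI) * (...)` factor, which would compute 1 - tanh(u)·u′(x) instead, so that grouping may be worth a second look. The sketch also avoids dividing by x, which is undefined at x = 0.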
--
This is an automated message from the Apache Git Service.