This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7740cca [operator] Add Mish Activation Function (#20320)
7740cca is described below
commit 7740ccae6f8efeba87fb0e0f6a4801ba60a9e4c1
Author: herewj <[email protected]>
AuthorDate: Tue Jun 8 11:54:29 2021 -0500
[operator] Add Mish Activation Function (#20320)
* Add Mish Activation Function
Signed-off-by: Adnios <[email protected]>
* fix test_mish
Signed-off-by: Adnios <[email protected]>
* fix test_mish_gpu
Signed-off-by: Adnios <[email protected]>
* fix test_mish_gpu
* add npx.activation(mish) test
Signed-off-by: Adnios <[email protected]>
* add npx.activation(mish) test
Signed-off-by: Adnios <[email protected]>
* backward error test
Signed-off-by: Adnios <[email protected]>
* mish is not supported by cudnn for now
Signed-off-by: Adnios <[email protected]>
* mv test_mish
Signed-off-by: Adnios <[email protected]>
* fix backward bug
Signed-off-by: Adnios <[email protected]>
* fix backward bug in cudnn
Signed-off-by: Adnios <[email protected]>
---
.../nd_operations/nn_activation_operators.py | 7 ++++---
benchmark/opperf/rules/default_params.py | 2 +-
python/mxnet/amp/lists/symbol_bf16.py | 1 +
python/mxnet/amp/lists/symbol_fp16.py | 1 +
python/mxnet/ndarray/ndarray.py | 8 ++++++++
python/mxnet/ndarray/numpy_extension/_op.py | 4 +++-
python/mxnet/numpy/multiarray.py | 8 ++++++++
python/mxnet/numpy_extension/_op.py | 4 +++-
python/mxnet/symbol/symbol.py | 8 ++++++++
.../operator/numpy_extension/npx_activation_op.cc | 2 ++
src/common/cuda/rtc/backward_functions-inl.h | 9 +++++++++
src/common/cuda/rtc/forward_functions-inl.h | 9 +++++++++
src/operator/fusion/fused_op-inl.h | 2 ++
src/operator/mshadow_op.h | 6 ++++++
src/operator/nn/activation-inl.h | 13 +++++++++++-
src/operator/nn/activation.cc | 3 +++
src/operator/nn/activation.cu | 10 ++++++++--
src/operator/nn/mkldnn/mkldnn_act.cc | 3 +++
src/operator/operator_tune.cc | 2 ++
src/operator/tensor/elemwise_unary_op_basic.cc | 17 ++++++++++++++++
src/operator/tensor/elemwise_unary_op_basic.cu | 6 ++++++
tests/cpp/operator/activation_perf.cc | 1 +
tests/python/mkl/subgraphs/test_conv_subgraph.py | 4 ++++
tests/python/mkl/subgraphs/test_fc_subgraph.py | 4 ++--
tests/python/unittest/test_numpy_op.py | 23 ++++++++++++++++++++++
tests/python/unittest/test_operator.py | 20 ++++++++++++++++++-
26 files changed, 165 insertions(+), 12 deletions(-)
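For reference while reading the diff below, here is a minimal NumPy sketch of the math this patch wires up (it mirrors the expressions added to src/operator/mshadow_op.h and the Python docstrings; the helper names are illustrative only, not part of the patch):

    import numpy as np

    def mish_ref(x):
        # forward: y = x * tanh(softplus(x)), where softplus(x) = log(1 + exp(x))
        return x * np.tanh(np.log1p(np.exp(x)))

    def mish_grad_ref(x):
        # backward: dy/dx = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2)
        sp = np.log1p(np.exp(x))
        t = np.tanh(sp)
        sig = 1.0 / (1.0 + np.exp(-x))
        return t + x * sig * (1.0 - t * t)

    # quick sanity check of the analytic gradient against a central finite difference
    x = np.linspace(-3.0, 3.0, 25)
    eps = 1e-5
    numeric = (mish_ref(x + eps) - mish_ref(x - eps)) / (2.0 * eps)
    assert np.allclose(mish_grad_ref(x), numeric, rtol=1e-4, atol=1e-6)
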
diff --git a/benchmark/opperf/nd_operations/nn_activation_operators.py b/benchmark/opperf/nd_operations/nn_activation_operators.py
index 7c59065..33b320a 100644
--- a/benchmark/opperf/nd_operations/nn_activation_operators.py
+++ b/benchmark/opperf/nd_operations/nn_activation_operators.py
@@ -37,9 +37,10 @@ from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
8.1 relu
8.2 sigmoid
8.3 log_sigmoid
- 8.4 softrelu
- 8.5 softsign
- 8.6 tanh
+ 8.4 mish
+ 8.5 softrelu
+ 8.6 softsign
+ 8.7 tanh
"""
diff --git a/benchmark/opperf/rules/default_params.py b/benchmark/opperf/rules/default_params.py
index 9418193..4c90338 100644
--- a/benchmark/opperf/rules/default_params.py
+++ b/benchmark/opperf/rules/default_params.py
@@ -375,7 +375,7 @@ DEFAULT_LABEL_SMCE_LARGE_TENSOR = [(2**32 + 1,)]
# For NN operators
DEFAULT_ACT_TYPE_LR = ['leaky', 'elu', 'selu', 'gelu']
-DEFAULT_ACT_TYPE_ACTIVATION = ['relu', 'sigmoid', 'log_sigmoid', 'softrelu', 'softsign', 'tanh']
+DEFAULT_ACT_TYPE_ACTIVATION = ['relu', 'sigmoid', 'log_sigmoid', 'mish', 'softrelu', 'softsign', 'tanh']
DEFAULT_LABEL_SOFTMAX = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_LABEL_SOFTMAX_LARGE_TENSOR = [(2**32, 1)]
diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py
index b7cb853..2990429 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -293,6 +293,7 @@ FP32_FUNCS = [
'max',
'min',
'min_axis',
+ 'mish',
'mp_sgd_mom_update',
'mp_sgd_update',
'multi_all_finite',
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index 0170acf..d886973 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -398,6 +398,7 @@ FP16_FP32_FUNCS = [
'log_sigmoid',
'max',
'min',
+ 'mish',
'mp_lamb_update_phase1',
'mp_lamb_update_phase2',
'mp_nag_mom_update',
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index aa0ab51..cbd0c51 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2267,6 +2267,14 @@ fixed-size items.
"""
return op.softmin(self, *args, **kwargs)
+ def mish(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`mish`.
+
+ The arguments are the same as for :py:func:`mish`, with
+ this array as data.
+ """
+ return op.mish(self, *args, **kwargs)
+
def squeeze(self, axis=None, inplace=False):
"""Remove dimensions with size 1 from this array without altering any
data.
diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py
index 5d5ca1a..f3d2db2 100644
--- a/python/mxnet/ndarray/numpy_extension/_op.py
+++ b/python/mxnet/ndarray/numpy_extension/_op.py
@@ -217,6 +217,8 @@ def activation(data, act_type='relu', **kwargs):
The following activation functions are supported:
+ - `log_sigmoid`: :math:`y = log(\frac{1}{1 + exp(-x)})`
+ - `mish`: :math:`y = x * tanh(log(1 + exp(x)))`
- `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
- `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
- `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
@@ -227,7 +229,7 @@ def activation(data, act_type='relu', **kwargs):
----------
data : NDArray
The input array.
- act_type : {'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required
+ act_type : {'log_sigmoid', 'mish', 'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required
Activation function to be applied.
Returns
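A short usage sketch of the extended act_type documented above (it mirrors the new test_npx_activation_mish test added later in this patch; the input values are illustrative):

    import mxnet as mx

    mx.npx.set_np()  # enable NumPy semantics, as the unit tests do via @use_np
    x = mx.np.random.uniform(low=-1.0, high=1.0, size=(3, 4))
    y = mx.npx.activation(x, act_type='mish')  # y = x * tanh(log(1 + exp(x)))
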
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 229725a..5cca1fa 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -2356,6 +2356,14 @@ class ndarray(NDArray): # pylint: disable=invalid-name
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute softmin')
+ def mish(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`mish`.
+
+ The arguments are the same as for :py:func:`mish`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute mish')
+
def squeeze(self, axis=None): # pylint: disable=arguments-differ
"""Remove single-dimensional entries from the shape of a."""
return squeeze(self, axis=axis)
diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py
index 1d672f4..a84404e 100644
--- a/python/mxnet/numpy_extension/_op.py
+++ b/python/mxnet/numpy_extension/_op.py
@@ -202,6 +202,8 @@ def activation(data, act_type='relu', **kwargs):
The following activation functions are supported:
+ - `log_sigmoid`: :math:`y = log(\frac{1}{1 + exp(-x)})`
+ - `mish`: :math:`y = x * tanh(log(1 + exp(x)))`
- `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
- `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
- `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
@@ -212,7 +214,7 @@ def activation(data, act_type='relu', **kwargs):
----------
data : NDArray
The input array.
- act_type : {'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required
+ act_type : {'log_sigmoid', 'mish', 'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required
Activation function to be applied.
Returns
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 4962656..34e53da 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -2527,6 +2527,14 @@ class Symbol(SymbolBase):
"""
return op.log_sigmoid(self, *args, **kwargs)
+ def mish(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`mish`.
+
+ The arguments are the same as for :py:func:`mish`, with
+ this array as data.
+ """
+ return op.mish(self, *args, **kwargs)
+
def sqrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sqrt`.
diff --git a/src/api/operator/numpy_extension/npx_activation_op.cc b/src/api/operator/numpy_extension/npx_activation_op.cc
index ad8cc3c..27810fb 100644
--- a/src/api/operator/numpy_extension/npx_activation_op.cc
+++ b/src/api/operator/numpy_extension/npx_activation_op.cc
@@ -36,6 +36,8 @@ inline int String2MXNetActType(const std::string& s) {
return activation::kSigmoid;
} else if (s == "log_sigmoid") {
return activation::kLogSigmoid;
+ } else if (s == "mish") {
+ return activation::kMish;
} else if (s == "tanh") {
return activation::kTanh;
} else if (s == "softrelu") {
diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h
index cb1bae8..cacb8be 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -52,6 +52,15 @@ backward_log_sigmoid(const DTypeGrad grad, const DType val) {
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
+backward_mish(const DTypeGrad grad, const DType val) {
+ const mixed_type<DTypeGrad, DType> v = val;
+ const auto softrelu = op::log(1 + exp(v));
+ const auto tanh = op::tanh(softrelu);
+ return grad * (tanh + v * sigmoid(v) * (1 - tanh * tanh));
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
backward_softrelu(const DTypeGrad grad, const DType val) {
const mixed_type<DTypeGrad, DType> v = val;
return grad * sigmoid(v);
diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h
index 4f87db6..b353e92 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -695,6 +695,15 @@ __device__ inline DType log_sigmoid(const DType val) {
}
template <typename DType>
+__device__ inline DType mish(const DType val) {
+ if (type_util::has_double_or_integral<DType>::value) {
+ return val * ::tanh(::log(1 + ::exp(val)));
+ } else {
+ return val * ::tanhf(logf(1 + expf(val)));
+ }
+}
+
+template <typename DType>
__device__ inline DType softrelu(const DType val) {
// Avoid overflow of exp for large inputs.
// The threshold 20 is chosen such that softrelu(a) = a
diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index df6d67e..acf5815 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -57,6 +57,7 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
{"relu" , {{"op::relu(%)", "_0"}}},
{"sigmoid" , {{"op::sigmoid(%)", "_0"}}},
{"log_sigmoid" , {{"op::log_sigmoid(%)", "_0"}}},
+ {"mish" , {{"op::mish(%)", "_0"}}},
{"softsign" , {{"op::softsign(%)", "_0"}}},
{"exp" , {{"op::exp(%)", "_0"}}},
{"expm1" , {{"op::expm1(%)", "_0"}}},
@@ -137,6 +138,7 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
  {"_backward_relu" , {{"op::backward_relu(%, %)", "_0", "_1"}}},
  {"_backward_sigmoid" , {{"op::backward_sigmoid(%, %)", "_0", "_1"}}},
  {"_backward_log_sigmoid" , {{"op::backward_log_sigmoid(%, %)", "_0", "_1"}}},
+ {"_backward_mish" , {{"op::backward_mish(%, %)", "_0", "_1"}}},
  {"_backward_expm1" , {{"op::backward_expm1(%, %)", "_0", "_1"}}},
  {"_backward_log" , {{"op::backward_log(%, %)", "_0", "_1"}}},
  {"_backward_log10" , {{"op::backward_log10(%, %)", "_0", "_1"}}},
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index be49f0f..e28dc86 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -415,6 +415,12 @@ MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));
MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));
+MXNET_UNARY_MATH_OP(mish, a * math::tanh(math::log(1.0f + math::exp(a))));
+
+MXNET_UNARY_MATH_OP(mish_grad, math::tanh(math::log(1.0f + math::exp(a))) +
+ a * (1.0f / (1.0f + math::exp(-a))) *
+ (1.0f - math::sqr(math::tanh(math::log(1.0f + math::exp(a))))));
+
MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));
MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));
diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index 647debf..7be6cba 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -47,7 +47,7 @@ namespace activation {
enum ActivationOpInputs {kData};
enum ActivationOpOutputs {kOut};
enum ActivationOpResource {kTempSpace};
-enum ActivationOpType {kReLU, kSigmoid, kLogSigmoid, kTanh, kSoftReLU, kSoftSign};
+enum ActivationOpType {kReLU, kSigmoid, kLogSigmoid, kMish, kTanh, kSoftReLU, kSoftSign};
// Get the number of inputs to the gradient depending on the activation type
int GradNumInputs(int act_type);
@@ -61,6 +61,7 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
.add_enum("relu", activation::kReLU)
.add_enum("sigmoid", activation::kSigmoid)
.add_enum("log_sigmoid", activation::kLogSigmoid)
+ .add_enum("mish", activation::kMish)
.add_enum("tanh", activation::kTanh)
.add_enum("softrelu", activation::kSoftReLU)
.add_enum("softsign", activation::kSoftSign)
@@ -78,6 +79,8 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
return "sigmoid";
case activation::kLogSigmoid:
return "log_sigmoid";
+ case activation::kMish:
+ return "mish";
case activation::kTanh:
return "tanh";
case activation::kSoftReLU:
@@ -166,6 +169,10 @@ void ActivationComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
ActivationForward<xpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs[0], req[0], outputs[0]);
break;
+ case activation::kMish:
+ ActivationForward<xpu, mshadow_op::mish, mshadow_op::mish_grad>(
+ ctx, inputs[0], req[0], outputs[0]);
+ break;
case activation::kTanh:
ActivationForward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
ctx, inputs[0], req[0], outputs[0]);
@@ -201,6 +208,10 @@ void ActivationGradComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ct
ActivationBackward<xpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
break;
+ case activation::kMish:
+ ActivationBackward<xpu, mshadow_op::mish, mshadow_op::mish_grad>(
+ ctx, inputs[0], inputs[2], req[0], outputs[0]);
+ break;
case activation::kTanh:
ActivationBackward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index 12a8084..a81d74e 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -52,6 +52,7 @@ int GradNumInputs(int act_type) {
case kTanh:
case kSigmoid:
case kLogSigmoid:
+ case kMish:
return 3;
default:
CHECK(false) << "missing activation type";
@@ -93,6 +94,7 @@ struct ActivationGrad {
case kTanh:
case kSigmoid:
case kLogSigmoid:
+ case kMish:
heads.push_back(n->inputs[activation::kData]);
break;
default:
@@ -171,6 +173,7 @@ The following activation functions are supported:
- `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
- `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
- `log_sigmoid`: :math:`y = log(\frac{1}{1 + exp(-x)})`
+- `mish`: :math:`y = x * tanh(log(1 + exp(x)))`
- `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
- `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`
- `softsign`: :math:`y = \frac{x}{1 + abs(x)}`
diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu
index 18962f5..bb16624 100644
--- a/src/operator/nn/activation.cu
+++ b/src/operator/nn/activation.cu
@@ -56,13 +56,16 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
const int act_type = param.act_type;
- // SoftReLU and kSoftSign are both not supported by CUDNN yet
+ // SoftReLU, kSoftSign and Mish are not supported by CUDNN yet
if (act_type == activation::kSoftReLU) {
ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (act_type == activation::kSoftSign) {
ActivationForward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(ctx,
inputs[0], req[0], outputs[0]);
+ } else if (act_type == activation::kMish) {
+ ActivationForward<gpu, mshadow_op::mish, mshadow_op::mish_grad>(ctx,
+ inputs[0], req[0], outputs[0]);
} else {
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
get_cudnn_op<DType>(param).Forward(ctx, inputs[0], req[0], outputs[0]);
@@ -84,10 +87,13 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
bool do_memory_opt = dmlc::GetEnv("MXNET_MEMORY_OPT", 0);
- // both SoftReLU and SoftSign not supported by CUDNN yet
+ // SoftReLU, SoftSign and Mish not supported by CUDNN yet
if (act_type == activation::kSoftReLU) {
ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
+ } else if (act_type == activation::kMish) {
+ ActivationBackward<gpu, mshadow_op::mish, mshadow_op::mish_grad>(
+ ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]);
} else if (act_type == activation::kSoftSign) {
if (do_memory_opt) {
ActivationBackward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index a4fe780..43c198f 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -44,6 +44,7 @@ bool SupportMKLDNNAct(const ActivationParam& param) {
return param.act_type == activation::kReLU
|| param.act_type == activation::kSigmoid
|| param.act_type == activation::kLogSigmoid
+ || param.act_type == activation::kMish
|| param.act_type == activation::kSoftReLU
|| param.act_type == activation::kTanh;
}
@@ -86,6 +87,8 @@ mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
return mkldnn::algorithm::eltwise_logistic;
case activation::kLogSigmoid:
return mkldnn::algorithm::eltwise_logsigmoid;
+ case activation::kMish:
+ return mkldnn::algorithm::eltwise_mish;
case activation::kTanh:
return mkldnn::algorithm::eltwise_tanh;
case activation::kSoftReLU:
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 5a4b27a..b097787 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -238,6 +238,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log_sigmoid); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_sigmoid_grad); // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mish); // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mish_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softsign); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softsign_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu); // NOLINT()
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 7a951e2..a739db3 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -166,6 +166,23 @@ The storage type of ``log_sigmoid`` output is always dense
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_log_sigmoid,
unary_bwd<mshadow_op::log_sigmoid_grad>);
+// mish
+MXNET_OPERATOR_REGISTER_UNARY(mish)
+MXNET_ADD_SPARSE_OP_ALIAS(mish)
+.describe(R"code(Computes mish of x element-wise.
+
+.. math::
+ y = x * tanh(log(1 + exp(x)))
+
+The storage type of ``mish`` output is always dense
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::mish>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_mish"});
+
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_mish,
+ unary_bwd<mshadow_op::mish_grad>);
+
DMLC_REGISTER_PARAMETER(HardSigmoidParam);
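With the CPU registration above in place, the operator should be callable both as a standalone unary op and through the fluent method added to NDArray earlier in this patch; a minimal sketch against a build containing this commit (array values are illustrative):

    import mxnet as mx

    x = mx.nd.array([-1.0, 0.0, 1.0])
    y1 = mx.nd.mish(x)  # unary operator registered in elemwise_unary_op_basic.cc
    y2 = x.mish()       # fluent NDArray method added in ndarray.py
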
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index e9f52f1..2c2abe2 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -45,6 +45,12 @@ NNVM_REGISTER_OP(log_sigmoid)
NNVM_REGISTER_OP(_backward_log_sigmoid)
.set_attr<FCompute>("FCompute<gpu>",
ElemwiseBinaryRTCCompute{"backward_log_sigmoid"});
+NNVM_REGISTER_OP(mish)
+.set_attr<FCompute>("FCompute<gpu>", UnaryRTCCompute{"mish"});
+
+NNVM_REGISTER_OP(_backward_mish)
+.set_attr<FCompute>("FCompute<gpu>",
ElemwiseBinaryRTCCompute{"backward_mish"});
+
NNVM_REGISTER_OP(hard_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", HardSigmoidForward<gpu>);
diff --git a/tests/cpp/operator/activation_perf.cc b/tests/cpp/operator/activation_perf.cc
index 61b9626..0dfefe5 100644
--- a/tests/cpp/operator/activation_perf.cc
+++ b/tests/cpp/operator/activation_perf.cc
@@ -44,6 +44,7 @@ TEST(ACTIVATION_PERF, ExecuteBidirectional) {
"relu",
"sigmoid",
"log_sigmoid",
+ "mish",
"tanh",
"softrelu",
"softsign"
diff --git a/tests/python/mkl/subgraphs/test_conv_subgraph.py b/tests/python/mkl/subgraphs/test_conv_subgraph.py
index c38c75c..18ebc73 100644
--- a/tests/python/mkl/subgraphs/test_conv_subgraph.py
+++ b/tests/python/mkl/subgraphs/test_conv_subgraph.py
@@ -108,6 +108,7 @@ def test_pos_conv_add2(no_bias, data_shape):
("relu", False), #TODO(bgawrych): investigate
("sigmoid", True),
("log_sigmoid", False),
+ ("mish", False),
("tanh", False), #TODO(bgawrych): investigate
#("softrelu", True), #TODO(bgawrych): bug in oneDNN with AVX
("relu6", False), #TODO(bgawrych): investigate
@@ -149,6 +150,7 @@ def test_pos_conv_act_add(data_shape, alg, quantize, use_bias):
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
+ ("mish", True),
("tanh", True),
("softrelu", True),
("relu6", True),
@@ -186,6 +188,7 @@ def test_pos_conv_bn_act(use_bias, data_shape, alg, quantize):
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
+ ("mish", True),
("tanh", True),
#("softrelu", True), #TODO(bgawrych): failing fusion check - difference in
random single element
("relu6", True),
@@ -293,6 +296,7 @@ def test_pos_concat_scale_align(data_shape, out_type):
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
+ ("mish", True),
("tanh", True),
("softrelu", True),
("relu6", True),
diff --git a/tests/python/mkl/subgraphs/test_fc_subgraph.py b/tests/python/mkl/subgraphs/test_fc_subgraph.py
index 5b4c61d..39c7959 100644
--- a/tests/python/mkl/subgraphs/test_fc_subgraph.py
+++ b/tests/python/mkl/subgraphs/test_fc_subgraph.py
@@ -23,7 +23,7 @@ from mxnet.contrib import quantization
from mxnet.gluon import nn
from mxnet.test_utils import assert_almost_equal_with_err
-fc_post_ops_list=['relu', 'sigmoid', 'log_sigmoid', 'tanh', 'softrelu', 'gelu', 'elu', 'leaky',
+fc_post_ops_list=['relu', 'sigmoid', 'log_sigmoid', 'mish', 'tanh', 'softrelu', 'gelu', 'elu', 'leaky',
'square', 'square_root', 'abs', 'exp', 'bounded_relu']
def test_float64_fallback():
@@ -69,7 +69,7 @@ def test_fc_eltwise(data_shape, use_bias, flatten, alg):
def hybrid_forward(self, F, x):
fc_out = self.fc(x)
- if self.alg in ['relu', 'sigmoid', 'log_sigmoid', 'tanh', 'softrelu']:
+ if self.alg in ['relu', 'sigmoid', 'log_sigmoid', 'mish', 'tanh', 'softrelu']:
out = F.Activation(fc_out, act_type=self.alg)
elif self.alg in ['gelu', 'elu', 'leaky']:
out = F.LeakyReLU(fc_out, act_type=self.alg)
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 1fc7b8e..9f02784 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3450,6 +3450,29 @@ def test_npx_relu():
@use_np
+def test_npx_activation_mish():
+ def np_mish(a):
+ return a * _np.tanh(_np.log1p(_np.exp(a)))
+ def np_mish_grad(a):
+ softrelu = _np.log1p(_np.exp(a))
+ tanh = _np.tanh(softrelu)
+ sigmoid = _np.divide(1.0, (1.0 + _np.exp(-a)))
+ return tanh + a * sigmoid * (1.0 - tanh * tanh)
+
+ shape = (3, 4)
+ A = mx.np.random.uniform(low=-1.0, high=1.0, size=shape)
+ A.attach_grad()
+ np_out = np_mish(A.asnumpy())
+ with mx.autograd.record():
+ B = mx.npx.activation(A, act_type='mish')
+ assert B.shape == np_out.shape
+ assert_almost_equal(B.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ B.backward()
+ np_backward = np_mish_grad(A.asnumpy())
+ assert_almost_equal(A.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
+
+
+@use_np
def test_npx_sigmoid():
def np_sigmoid(x):
return _np.divide(1.0, (1.0 + _np.exp(-x)))
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index b2995ec..e013988 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -692,6 +692,24 @@ def test_log_sigmoid():
check_symbolic_forward(y, [xa], [ya])
check_symbolic_backward(y, [xa], [np.ones(shape)], [ya_grad])
+def test_mish():
+ def fmish(a):
+ return a * np.tanh(np.log1p(np.exp(a)))
+ def fmish_grad(a):
+ softrelu = np.log1p(np.exp(a))
+ tanh = np.tanh(softrelu)
+ sigmoid = np.divide(1.0, (1.0 + np.exp(-a)))
+ return tanh + a * sigmoid * (1.0 - tanh * tanh)
+ shape = (3, 4)
+ x = mx.symbol.Variable("x")
+ y = mx.sym.mish(x)
+ xa = np.random.uniform(low=-1.0,high=1.0,size=shape)
+ ya = fmish(xa)
+ ya_grad = fmish_grad(xa)
+ check_numeric_gradient(y, [xa], numeric_eps=1E-3)
+ check_symbolic_forward(y, [xa], [ya])
+ check_symbolic_backward(y, [xa], [np.ones(shape)], [ya_grad])
+
def test_shape_array():
for i in range(1,6):
shape = rand_shape_nd(i)
@@ -8712,7 +8730,7 @@ def test_get_operator_arguments():
assert isinstance(operator_arguments, OperatorArguments)
assert operator_arguments.names == ['data', 'act_type']
assert operator_arguments.types \
- == ['NDArray-or-Symbol', "{'log_sigmoid', 'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required"]
+ == ['NDArray-or-Symbol', "{'log_sigmoid', 'mish', 'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required"]
assert operator_arguments.narg == 2
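Assuming the repository's usual pytest-based workflow, the new coverage added by this commit can presumably be exercised locally with keyword filters, e.g.:

    pytest tests/python/unittest/test_operator.py -k test_mish
    pytest tests/python/unittest/test_numpy_op.py -k test_npx_activation_mish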