D-Roberts commented on a change in pull request #13467: [MXNET-1235] Add a test
for AdaMax optimizer
URL: https://github.com/apache/incubator-mxnet/pull/13467#discussion_r237684888
##########
File path: tests/python/unittest/test_optimizer.py
##########
@@ -613,6 +612,88 @@ def test_adam():
dtype, w_stype='default',
g_stype='row_sparse',
rtol=1e-4, atol=2e-5)
+
+# AdaMax
class PyAdamax(mx.optimizer.Optimizer):
    """Python reference implementation of the AdaMax optimizer.

    AdaMax is a variant of Adam based on the infinity norm, described in
    Section 7 of http://arxiv.org/abs/1412.6980.

    The weight update is::

        grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
        m = beta1 * m_t + (1 - beta1) * grad
        u = maximum(beta2 * u, abs(grad))
        weight -= lr / (1 - beta1**t) * m / u

    This optimizer accepts the following parameters in addition to those
    accepted by :class:`.Optimizer`.

    Parameters
    ----------
    beta1 : float, optional
        Exponential decay rate for the first moment estimates.
    beta2 : float, optional
        Exponential decay rate for the second moment estimates.
    """
    def __init__(self, learning_rate=0.002, beta1=0.9, beta2=0.999, **kwargs):
        super(PyAdamax, self).__init__(learning_rate=learning_rate, **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2

    def create_state(self, index, weight):
        # State is (mean, inf_norm): the first-moment estimate and the
        # exponentially weighted infinity norm, both zero-initialized with
        # the weight's shape, context and dtype.
        mean = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)
        inf_norm = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)
        return (mean, inf_norm)

    def update(self, index, weight, grad, state):
        self._update_count(index)
        wd = self._get_wd(index)
        step = self._index_update_count[index]
        # Fold the first-moment bias correction into the learning rate.
        lr = self._get_lr(index) / (1. - self.beta1 ** step)

        # Rescale, apply weight decay, then (optionally) clip the gradient.
        grad = grad * self.rescale_grad + wd * weight
        if self.clip_gradient is not None:
            grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)

        # Update the first moment and the infinity-norm accumulator in place.
        mean, inf_norm = state
        mean[:] = self.beta1 * mean + (1. - self.beta1) * grad
        inf_norm[:] = mx.nd.maximum(self.beta2 * inf_norm, mx.nd.abs(grad))

        # Apply the AdaMax step.
        weight[:] -= lr * mean / inf_norm
+
+
@with_seed()
def test_adamax():
    """Check mx.optimizer.Adamax against the PyAdamax python reference.

    Sweeps the cross product of clip_gradient, rescale_grad, wd and
    multi_precision settings over fp16/fp32/fp64 weights; fp16 is only
    exercised with multi_precision enabled.
    """
    py_opt = PyAdamax
    mx_opt = mx.optimizer.Adamax
    shape = (3, 4, 5)
    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
    mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]
    for dtype in [np.float16, np.float32, np.float64]:
        for cg in cg_options:
            for rg in rg_options:
                for wd in wd_options:
                    for mp in mp_options:
                        kwarg = {}
                        for opts in (cg, rg, wd, mp):
                            kwarg.update(opts)
                        # fp16 weights need the multi-precision code path.
                        if (dtype == np.float16 and
                                not kwarg.get('multi_precision', False)):
                            continue
                        compare_optimizer(py_opt(**kwarg), mx_opt(**kwarg),
                                          shape, dtype,
                                          rtol=1e-4, atol=2e-5)
Review comment:
done
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services