piiswrong commented on a change in pull request #9220: Signum optimizer
URL: https://github.com/apache/incubator-mxnet/pull/9220#discussion_r160506516
##########
File path: python/mxnet/optimizer.py
##########
@@ -534,6 +535,66 @@ def update_multi_precision(self, index, weight, grad, state):
self._update_impl(index, weight, grad, state,
multi_precision=use_multi_precision)
+@register
+class Signum(Optimizer):
+ """The Signum optimizer that takes the sign of gradient or momentum.
+
+ The optimizer updates the weight by:
+
+ rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
+ state = momentum * state + (1-momentum)*rescaled_grad
+ weight = (1 - lr * wd_lh) * weight - lr * sign(state)
+
+ See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
+
+ For details of the update algorithm see
+ :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
+
+ This optimizer accepts the following parameters in addition to those accepted
+ by :class:`.Optimizer`.
+
+ Parameters
+ ----------
+ momentum : float, optional
+ The momentum value.
+ wd_lh : float, optional
+ The amount of decoupled weight decay regularization.
+ """
+ def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
+ super(Signum, self).__init__(learning_rate=learning_rate, **kwargs)
+ self.momentum = momentum
+ self.wd_lh = wd_lh
+
+ def create_state(self, index, weight):
+ momentum = None
+ if self.momentum != 0.0:
+ momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+ return momentum
+
+ def _update_impl(self, index, weight, grad, state):
+ assert(isinstance(weight, NDArray))
+ assert(isinstance(grad, NDArray))
+ self._update_count(index)
+ lr = self._get_lr(index)
+ wd = self._get_wd(index)
+
+ kwargs = {'rescale_grad': self.rescale_grad}
+ if self.momentum > 0:
+ kwargs['momentum'] = self.momentum
+ if self.clip_gradient:
+ kwargs['clip_gradient'] = self.clip_gradient
+ if self.wd_lh:
+ kwargs['wd_lh'] = self.wd_lh
+
+ if state is not None:
+ signum_update(weight, grad, state, out=weight,
Review comment:
Call these signum_momentum_update and signum_update to be consistent with the others.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services