sandeep-krishnamurthy closed pull request #13467: [MXNET-1235] Add a test for AdaMax optimizer
URL: https://github.com/apache/incubator-mxnet/pull/13467
 
 
   

This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:

diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 334b7d4c0fd..b03dcdcfba4 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import numpy as np
+import itertools
 import mxnet as mx
 import mxnet.lr_scheduler as lr_scheduler
 from mxnet import gluon
@@ -501,7 +502,6 @@ def test_ftml():
 
 
 # ADAM
-
 class PyAdam(mx.optimizer.Optimizer):
     """python reference implemenation of adam"""
     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
@@ -613,6 +613,80 @@ def test_adam():
                                          dtype, w_stype='default', g_stype='row_sparse',
                                           rtol=1e-4, atol=2e-5)
 
+
+# AdaMax
+class PyAdamax(mx.optimizer.Optimizer):
+    """The python reference of AdaMax optimizer.
+
+    This class implements the AdaMax optimizer, one variant of Adam based on the infinity norm,
+    available at http://arxiv.org/abs/1412.6980 Section 7.
+
+    The optimizer updates the weight by::
+        grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
+        m = beta1 * m_t + (1 - beta1) * grad
+        u = maximum(beta2 * u, abs(grad))
+        weight -= lr / (1 - beta1**t) * m / u
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    beta1 : float, optional
+        Exponential decay rate for the first moment estimates.
+    beta2 : float, optional
+        Exponential decay rate for the second moment estimates.
+    """
+    def __init__(self, learning_rate=0.002, beta1=0.9, beta2=0.999, **kwargs):
+        super(PyAdamax, self).__init__(learning_rate=learning_rate, **kwargs)
+        self.beta1 = beta1
+        self.beta2 = beta2
+
+    def create_state(self, index, weight):
+        return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype),  # mean
+                mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype))  # variance
+
+    def update(self, index, weight, grad, state):
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+
+        t = self._index_update_count[index]
+        lr /= (1. - self.beta1**t)
+
+        # preprocess grad
+        grad = grad * self.rescale_grad + wd * weight
+        if self.clip_gradient is not None:
+            grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
+
+        # update m_t and u_t
+        m_t, u_t = state
+        m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad
+        u_t[:] = mx.nd.maximum(self.beta2 * u_t, mx.nd.abs(grad))
+
+        # update weight
+        weight[:] -= lr * m_t / u_t
+
+
+@with_seed()
+def test_adamax():
+    opt1 = PyAdamax
+    opt2 = mx.optimizer.Adamax
+    shape = (3, 4, 5)
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+    mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]
+    for dtype in [np.float16, np.float32, np.float64]:
+        for params in itertools.product(cg_options, rg_options, wd_options, mp_options):
+            kwarg = {k: v for param in params for k, v in param.items()}
+            if (dtype == np.float16 and
+                    ('multi_precision' not in kwarg or
+                    not kwarg['multi_precision'])):
+                continue
+            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+
+
 # Signum
 class PySignum(mx.optimizer.Optimizer):
     """The python reference of Signum optimizer.

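For context on what the new test sweeps, the loop in test_adamax merges one dict from each option group into a single kwargs dict and, for float16, keeps only the combinations where multi_precision is enabled. A standalone illustration of the merge, using the same option values as the test:

import itertools

cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]

for params in itertools.product(cg_options, rg_options, wd_options, mp_options):
    # Flatten the tuple of option dicts into a single keyword-argument dict.
    kwarg = {k: v for param in params for k, v in param.items()}
    print(kwarg)  # 3 * 3 * 4 * 3 = 108 combinations per dtype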

 
