leezu closed pull request #11223: Allow specifying AdaGrad initial accumulator value
URL: https://github.com/apache/incubator-mxnet/pull/11223
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 0c3fc904fb1..e7727b7e586 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -1091,14 +1091,20 @@ class AdaGrad(Optimizer):
     ----------
     eps: float, optional
         Small value to avoid division by 0.
+    initial_accumulator_value: float, default 0
+        The Adagrad state is initially set to this value.
 
     """
-    def __init__(self, eps=1e-7, **kwargs):
+    def __init__(self, eps=1e-7, initial_accumulator_value=0, **kwargs):
         super(AdaGrad, self).__init__(**kwargs)
         self.float_stable_eps = eps
+        self.initial_accumulator_value = initial_accumulator_value
 
     def create_state(self, index, weight):
-        return zeros(weight.shape, weight.context, stype=weight.stype)  # history
+        history = zeros(weight.shape, weight.context, stype=weight.stype)
+        if self.initial_accumulator_value:
+            history[:] = self.initial_accumulator_value
+        return history
 
     def update(self, index, weight, grad, state):
         assert(isinstance(weight, NDArray))
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index fba10fb522a..cd516738130 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import itertools
 import numpy as np
 import mxnet as mx
 import mxnet.lr_scheduler as lr_scheduler
@@ -991,12 +992,16 @@ class PyAdaGrad(mx.optimizer.Optimizer):
         Small value to avoid division by 0.
 
     """
-    def __init__(self, eps=1e-7, **kwargs):
+    def __init__(self, eps=1e-7, initial_accumulator_value=0, **kwargs):
         super(PyAdaGrad, self).__init__(**kwargs)
         self.float_stable_eps = eps
+        self.initial_accumulator_value = initial_accumulator_value
 
     def create_state(self, index, weight):
-        return mx.nd.zeros(weight.shape, weight.context, stype=weight.stype)
+        history = mx.nd.zeros(weight.shape, weight.context, stype=weight.stype)
+        if self.initial_accumulator_value:
+            history[:] = self.initial_accumulator_value
+        return history
 
     def update(self, index, weight, grad, state):
         self._update_count(index)
@@ -1020,21 +1025,21 @@ def test_adagrad():
     cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
     rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
     wd_options = [{}, {'wd': 0.0}]
-    for dtype in [np.float32]:
-        for eps_option in eps_options:
-            for cg_option in cg_options:
-                for rg_option in rg_options:
-                    for wd_option in wd_options:
-                        kwarg = {}
-                        kwarg.update(eps_option)
-                        kwarg.update(cg_option)
-                        kwarg.update(rg_option)
-                        kwarg.update(wd_option)
-                        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
-                        if wd_option.get('wd', 0.0) == 0.0:
-                            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
-                                              w_stype='row_sparse', g_stype='row_sparse')
+    acc_options = [{}, {'initial_accumulator_value': 1.0}]
 
+    for dtype in [np.float32]:
+        for eps_option, cg_option, rg_option, wd_option, acc_option in itertools.product(
+                eps_options, cg_options, rg_options, wd_options, acc_options):
+            kwarg = {}
+            kwarg.update(eps_option)
+            kwarg.update(cg_option)
+            kwarg.update(rg_option)
+            kwarg.update(wd_option)
+            kwarg.update(acc_option)
+            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+            if wd_option.get('wd', 0.0) == 0.0:
+                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+                                  w_stype='row_sparse', g_stype='row_sparse')
 
 
 if __name__ == '__main__':


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to