leezu closed pull request #11223: Allow specifying AdaGrad initial accumulator value
URL: https://github.com/apache/incubator-mxnet/pull/11223
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 0c3fc904fb1..e7727b7e586 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -1091,14 +1091,20 @@ class AdaGrad(Optimizer):
----------
eps: float, optional
Small value to avoid division by 0.
+ initial_accumulator_value: float, default 0
+ The Adagrad state is initially set to this value.
"""
- def __init__(self, eps=1e-7, **kwargs):
+ def __init__(self, eps=1e-7, initial_accumulator_value=0, **kwargs):
super(AdaGrad, self).__init__(**kwargs)
self.float_stable_eps = eps
+ self.initial_accumulator_value = initial_accumulator_value
def create_state(self, index, weight):
- return zeros(weight.shape, weight.context, stype=weight.stype) # history
+ history = zeros(weight.shape, weight.context, stype=weight.stype)
+ if self.initial_accumulator_value:
+ history[:] = self.initial_accumulator_value
+ return history
def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index fba10fb522a..cd516738130 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
+import itertools
import numpy as np
import mxnet as mx
import mxnet.lr_scheduler as lr_scheduler
@@ -991,12 +992,16 @@ class PyAdaGrad(mx.optimizer.Optimizer):
Small value to avoid division by 0.
"""
- def __init__(self, eps=1e-7, **kwargs):
+ def __init__(self, eps=1e-7, initial_accumulator_value=0, **kwargs):
super(PyAdaGrad, self).__init__(**kwargs)
self.float_stable_eps = eps
+ self.initial_accumulator_value = initial_accumulator_value
def create_state(self, index, weight):
- return mx.nd.zeros(weight.shape, weight.context, stype=weight.stype)
+ history = mx.nd.zeros(weight.shape, weight.context, stype=weight.stype)
+ if self.initial_accumulator_value:
+ history[:] = self.initial_accumulator_value
+ return history
def update(self, index, weight, grad, state):
self._update_count(index)
@@ -1020,21 +1025,21 @@ def test_adagrad():
cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
wd_options = [{}, {'wd': 0.0}]
- for dtype in [np.float32]:
- for eps_option in eps_options:
- for cg_option in cg_options:
- for rg_option in rg_options:
- for wd_option in wd_options:
- kwarg = {}
- kwarg.update(eps_option)
- kwarg.update(cg_option)
- kwarg.update(rg_option)
- kwarg.update(wd_option)
- compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
- if wd_option.get('wd', 0.0) == 0.0:
- compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
- w_stype='row_sparse', g_stype='row_sparse')
+ acc_options = [{}, {'initial_accumulator_value': 1.0}]
+ for dtype in [np.float32]:
+ for eps_option, cg_option, rg_option, wd_option, acc_option in itertools.product(
+ eps_options, cg_options, rg_options, wd_options, acc_options):
+ kwarg = {}
+ kwarg.update(eps_option)
+ kwarg.update(cg_option)
+ kwarg.update(rg_option)
+ kwarg.update(wd_option)
+ kwarg.update(acc_option)
+ compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+ if wd_option.get('wd', 0.0) == 0.0:
+ compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+ w_stype='row_sparse', g_stype='row_sparse')
if __name__ == '__main__':
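
For reference, below is a minimal NumPy sketch of the dense AdaGrad step (not MXNet's actual kernel; the helper name adagrad_step and the hyperparameter values are illustrative only), showing how seeding the squared-gradient history with a nonzero value damps the earliest updates:

import numpy as np

def adagrad_step(weight, grad, history, lr=0.01, eps=1e-7):
    # AdaGrad accumulates squared gradients and scales each step
    # by the inverse square root of that running history.
    history += grad * grad
    weight -= lr * grad / np.sqrt(history + eps)

w = np.ones(3)
g = np.full(3, 0.5)
# Seed the history as initial_accumulator_value=1.0 would;
# the first step is smaller than with a zero-initialized history.
hist = np.full_like(w, 1.0)
adagrad_step(w, g, hist)

With this change merged, the seeded history above corresponds to constructing the optimizer as, e.g., mx.optimizer.AdaGrad(initial_accumulator_value=1.0).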
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services