This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new fc9e70b  Improve sparse adagrad update (#9651)
fc9e70b is described below

commit fc9e70bf2d349ad4c6cb65ff3f0958e23a7410bf
Author: Haibin Lin <linhaibin.e...@gmail.com>
AuthorDate: Sat Mar 3 14:12:23 2018 +0800

    Improve sparse adagrad update (#9651)
    
    * fix adagrad
    
    * add test
    
    * fix lint
    
    * CR comments
    
    * remove raise in python
    
    * enhance unit test
    
    * revert wd changes
    
    * revert dense op changes
---
 docs/api/python/ndarray/sparse.md       |   1 +
 python/mxnet/optimizer.py               |  44 +++-------
 src/operator/optimizer_op-inl.h         | 146 ++++++++++++++++++++++++++++++++
 src/operator/optimizer_op.cc            |  32 +++++++
 src/operator/optimizer_op.cu            |   3 +
 tests/python/unittest/test_optimizer.py |  68 +++++++++++++++
 6 files changed, 263 insertions(+), 31 deletions(-)

diff --git a/docs/api/python/ndarray/sparse.md b/docs/api/python/ndarray/sparse.md
index a7aaa1f..df33570 100644
--- a/docs/api/python/ndarray/sparse.md
+++ b/docs/api/python/ndarray/sparse.md
@@ -484,6 +484,7 @@ We summarize the interface for each class in the following sections.
     sgd_mom_update
     adam_update
     ftrl_update
+    adagrad_update
 ```
 
 ### More
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 065c08c..6589e77 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -27,8 +27,6 @@ from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs)
 from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                       mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
                       signsgd_update, signum_update)
-from .ndarray import _internal
-from .ndarray import op
 from .ndarray import sparse
 from .random import normal
 
@@ -1073,6 +1071,10 @@ class AdaGrad(Optimizer):
     This optimizer accepts the following parameters in addition to those accepted
     by :class:`.Optimizer`.
 
+    See Also
+    ----------
+    :meth:`mxnet.ndarray.sparse.adagrad_update`.
+
     Parameters
     ----------
     eps: float, optional
@@ -1093,39 +1095,19 @@ class AdaGrad(Optimizer):
         lr = self._get_lr(index)
         wd = self._get_wd(index)
 
-        is_sparse = True if weight.stype == 'row_sparse' and grad.stype == 'row_sparse' else False
-
-        if is_sparse is True:
-            grad_indices_count = len(grad.indices)
-
-        grad = grad * self.rescale_grad
-
-        if is_sparse is True:
-            grad_indices = grad.indices
-            # Make sure that the scalar multiply still has a sparse result
-            assert grad_indices_count == len(grad_indices)
-
-        if self.clip_gradient is not None:
-            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
+        is_sparse = weight.stype == 'row_sparse' and grad.stype == 'row_sparse'
         history = state
-        save_history_stype = history.stype
 
         if is_sparse:
-            history[:] = sparse.elemwise_add(sparse.square(grad),
-                                             sparse.retain(history, grad_indices))
-            history_indices = history.indices
-            assert len(history_indices) == grad_indices_count
-            adjusted_add = _internal._scatter_plus_scalar(history, self.float_stable_eps)
-            srt = op.sqrt(adjusted_add)
-            div = _internal._scatter_elemwise_div(grad, srt)
-            retained_weight = sparse.retain(weight, grad.indices)
-            to_add = sparse.elemwise_add(div, _internal._mul_scalar(retained_weight, float(wd)))
-            assert len(to_add.indices) == grad_indices_count
-            weight[:] = sparse.elemwise_add(weight, _internal._mul_scalar(to_add, float(-lr)))
-            state[:] = history
-            assert state.stype == save_history_stype
-            assert len(history_indices) == grad_indices_count
+            kwargs = {'epsilon': self.float_stable_eps,
+                      'rescale_grad': self.rescale_grad}
+            if self.clip_gradient:
+                kwargs['clip_gradient'] = self.clip_gradient
+            sparse.adagrad_update(weight, grad, history, out=weight, lr=lr, wd=wd, **kwargs)
         else:
+            grad = grad * self.rescale_grad
+            if self.clip_gradient is not None:
+                grad = clip(grad, -self.clip_gradient, self.clip_gradient)
             history[:] += square(grad)
             div = grad / sqrt(history + self.float_stable_eps)
             weight[:] += (div + weight * wd) * -lr
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 89d27e1..55d2156 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -1483,7 +1483,153 @@ inline void SignumUpdate(const nnvm::NodeAttrs& attrs,
     });
 }
 
+struct AdagradParam : public dmlc::Parameter<AdagradParam> {
+  float lr;
+  float epsilon;
+  float rescale_grad;
+  float clip_gradient;
+  float wd;
+  DMLC_DECLARE_PARAMETER(AdagradParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(epsilon)
+    .set_default(1.0e-7)
+    .describe("epsilon");
+    DMLC_DECLARE_FIELD(wd)
+    .set_default(0.0f)
+    .describe("weight decay");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Rescale gradient to grad = rescale_grad*grad.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+  }
+};
+
+inline bool AdagradStorageType(const nnvm::NodeAttrs& attrs,
+                               const int dev_mask,
+                               DispatchMode* dispatch_mode,
+                               std::vector<int>* in_attrs,
+                               std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+  bool dispatched = false;
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
+      common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
+      param.wd == 0.0f) {
+    // rsp, rsp, rsp -> rsp with wd = 0.0
+    dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  return dispatched;
+}
+
+
+struct AdagradDnsRspDnsKernel {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
+    DType* state_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const DType clip_gradient, const DType epsilon,
+    const DType lr, const DType rescale_grad) {
+    using nnvm::dim_t;
+    using namespace mshadow_op;
+    const dim_t data_i = grad_idx[i] * row_length;
+    const dim_t grad_i = i * row_length;
+    for (dim_t j = 0; j < row_length; j++) {
+      const dim_t data_j = data_i + j;
+      const dim_t grad_j = grad_i + j;
+      DType grad_rescaled = grad_data[grad_j] * rescale_grad;
+      if (clip_gradient >= 0.0f) {
+        grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+      }
+      const DType grad_squared = grad_rescaled * grad_rescaled;
+      state_data[data_j] += grad_squared;
+      const DType div = grad_rescaled / square_root::Map(state_data[data_j] + epsilon);
+      // No need to use KERNEL_ASSIGN, as we already checked req is kWriteInplace
+      out_data[data_j] = weight_data[data_j] - div * lr;
+    }
+  }
+};
 
+template<typename xpu>
+void AdagradUpdateDnsRspDnsImpl(const AdagradParam& param,
+                                const OpContext& ctx,
+                                const TBlob& weight,
+                                const NDArray& grad,
+                                const TBlob& state,
+                                const OpReqType& req,
+                                TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  CHECK_EQ(param.wd, 0.0f)
+    << "sparse adagrad_update does not support wd.";
+  if (req == kNullOp || !grad.storage_initialized()) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adagrad_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(state.shape_.Size(), 0);
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      const DType* weight_data = weight.dptr<DType>();
+      const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+      const DType* grad_val = grad.data().dptr<DType>();
+      DType* state_data = state.dptr<DType>();
+      DType* out_data = out->dptr<DType>();
+      const nnvm::dim_t nnr = grad.storage_shape()[0];
+      const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
+      Kernel<AdagradDnsRspDnsKernel, xpu>::Launch(s, nnr, row_length,
+        out_data, state_data, weight_data, grad_idx, grad_val,
+        static_cast<DType>(param.clip_gradient), static_cast<DType>(param.epsilon),
+        static_cast<DType>(param.lr), static_cast<DType>(param.rescale_grad));
+    });
+  });
+}
+
+template<typename xpu>
+inline void AdagradUpdateRspRspRspImpl(const AdagradParam& param,
+                                       const OpContext& ctx,
+                                       const NDArray& weight,
+                                       const NDArray& grad,
+                                       const NDArray& state,
+                                       const OpReqType& req,
+                                       NDArray *out) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdagradUpdate", "weights");
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  // fill history with zero values
+  if (!state.storage_initialized()) {
+    NDArray state_zeros = state;
+    FillDnsZerosRspImpl(s, &state_zeros);
+  }
+  TBlob out_blob = out->data();
+  // reuse dns rsp implementation when storage_shape == shape
+  AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+                                  state.data(), req, &out_blob);
+}
+
+template<typename xpu>
+inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs,
+                            const OpContext &ctx,
+                            const std::vector<NDArray> &inputs,
+                            const std::vector<OpReqType> &req,
+                            const std::vector<NDArray> &outputs) {
+  using namespace mxnet_op;
+  const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
+      common::ContainsOnlyStorage(outputs, kRowSparseStorage)) {
+    NDArray out = outputs[0];
+    AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], req[0], &out);
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+}
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 136769a..741092a 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -38,6 +38,7 @@ DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
 DMLC_REGISTER_PARAMETER(FtrlParam);
 DMLC_REGISTER_PARAMETER(SignSGDParam);
 DMLC_REGISTER_PARAMETER(SignumParam);
+DMLC_REGISTER_PARAMETER(AdagradParam);
 
 NNVM_REGISTER_OP(signsgd_update)
 .describe(R"code(Update function for SignSGD optimizer.
@@ -536,5 +537,36 @@ only the row slices whose indices appear in grad.indices are updated (for w, z a
 .add_argument("n", "NDArray-or-Symbol", "Square of grad")
 .add_arguments(FtrlParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_sparse_adagrad_update)
+.describe(R"code(Update function for AdaGrad optimizer.
+
+Referenced from *Adaptive Subgradient Methods for Online Learning and Stochastic Optimization*,
+and available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
+
+Updates are applied by::
+
+    rescaled_grad = clip(grad * rescale_grad, clip_gradient)
+    history = history + square(rescaled_grad)
+    w = w - learning_rate * rescaled_grad / sqrt(history + epsilon)
+
+Note that non-zero values for the weight decay option are not supported.
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<AdagradParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", AdagradStorageType)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2};
+  })
+.set_attr<FComputeEx>("FComputeEx<cpu>", AdagradUpdateEx<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("history", "NDArray-or-Symbol", "History")
+.add_arguments(AdagradParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 1bd6117..c49af68 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -200,5 +200,8 @@ NNVM_REGISTER_OP(ftrl_update)
 .set_attr<FCompute>("FCompute<gpu>", FtrlUpdate<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", FtrlUpdateEx<gpu>);
 
+NNVM_REGISTER_OP(_sparse_adagrad_update)
+.set_attr<FComputeEx>("FComputeEx<gpu>", AdagradUpdateEx<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 159c9ba..f71e2c8 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -963,6 +963,74 @@ def test_nadam():
             optimizer='nadam')
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.1
 
+# AdaGrad
+class PyAdaGrad(mx.optimizer.Optimizer):
+    """The python reference of AdaGrad optimizer.
+
+    This class implements the AdaGrad optimizer described in *Adaptive Subgradient
+    Methods for Online Learning and Stochastic Optimization*, and available at
+    http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
+
+    Updates are applied by::
+
+        rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
+        history = history + square(rescaled_grad)
+        w = w - learning_rate * rescaled_grad / sqrt(history + epsilon)
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    eps: float, optional
+        Small value to avoid division by 0.
+
+    """
+    def __init__(self, eps=1e-7, **kwargs):
+        super(PyAdaGrad, self).__init__(**kwargs)
+        self.float_stable_eps = eps
+
+    def create_state(self, index, weight):
+        return mx.nd.zeros(weight.shape, weight.context, stype=weight.stype)
+
+    def update(self, index, weight, grad, state):
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+
+        history = state
+        grad = grad * self.rescale_grad
+        if self.clip_gradient is not None:
+            grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
+        history[:] += mx.nd.square(grad)
+        div = grad / mx.nd.sqrt(history + self.float_stable_eps)
+        weight[:] += (div + weight * wd) * -lr
+
+def test_adagrad():
+    mx.random.seed(0)
+    opt1 = PyAdaGrad
+    opt2 = mx.optimizer.AdaGrad
+    shape = (3, 4, 5)
+    eps_options = [{}, {'eps': 1e-8}]
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.0}]
+    for dtype in [np.float32]:
+        for eps_option in eps_options:
+            for cg_option in cg_options:
+                for rg_option in rg_options:
+                    for wd_option in wd_options:
+                        kwarg = {}
+                        kwarg.update(eps_option)
+                        kwarg.update(cg_option)
+                        kwarg.update(rg_option)
+                        kwarg.update(wd_option)
+                        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+                        if wd_option.get('wd', 0.0) == 0.0:
+                            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+                                              w_stype='row_sparse', g_stype='row_sparse')
+
+
 
 if __name__ == '__main__':
     import nose

-- 
To stop receiving notification emails like this one, please contact
hai...@apache.org.

Reply via email to