eric-haibin-lin closed pull request #9651: Improve sparse adagrad update
URL: https://github.com/apache/incubator-mxnet/pull/9651
 
 
   

This is a PR merged from a forked repository.
Because GitHub hides the original diff of a pull request from a fork
once it has been merged, the diff is reproduced below for the sake of
provenance:

diff --git a/docs/api/python/ndarray/sparse.md b/docs/api/python/ndarray/sparse.md
index a7aaa1fd41d..df335702aba 100644
--- a/docs/api/python/ndarray/sparse.md
+++ b/docs/api/python/ndarray/sparse.md
@@ -484,6 +484,7 @@ We summarize the interface for each class in the following sections.
     sgd_mom_update
     adam_update
     ftrl_update
+    adagrad_update
 ```
 
 ### More
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 065c08cee4e..6589e77e453 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -27,8 +27,6 @@
 from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                       mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
                       signsgd_update, signum_update)
-from .ndarray import _internal
-from .ndarray import op
 from .ndarray import sparse
 from .random import normal
 
@@ -1073,6 +1071,10 @@ class AdaGrad(Optimizer):
     This optimizer accepts the following parameters in addition to those accepted
     by :class:`.Optimizer`.
 
+    See Also
+    ----------
+    :meth:`mxnet.ndarray.sparse.adagrad_update`.
+
     Parameters
     ----------
     eps: float, optional
@@ -1093,39 +1095,19 @@ def update(self, index, weight, grad, state):
         lr = self._get_lr(index)
         wd = self._get_wd(index)
 
-        is_sparse = True if weight.stype == 'row_sparse' and grad.stype == 'row_sparse' else False
-
-        if is_sparse is True:
-            grad_indices_count = len(grad.indices)
-
-        grad = grad * self.rescale_grad
-
-        if is_sparse is True:
-            grad_indices = grad.indices
-            # Make sure that the scalar multiply still has a sparse result
-            assert grad_indices_count == len(grad_indices)
-
-        if self.clip_gradient is not None:
-            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
+        is_sparse = weight.stype == 'row_sparse' and grad.stype == 'row_sparse'
         history = state
-        save_history_stype = history.stype
 
         if is_sparse:
-            history[:] = sparse.elemwise_add(sparse.square(grad),
-                                             sparse.retain(history, grad_indices))
-            history_indices = history.indices
-            assert len(history_indices) == grad_indices_count
-            adjusted_add = _internal._scatter_plus_scalar(history, self.float_stable_eps)
-            srt = op.sqrt(adjusted_add)
-            div = _internal._scatter_elemwise_div(grad, srt)
-            retained_weight = sparse.retain(weight, grad.indices)
-            to_add = sparse.elemwise_add(div, _internal._mul_scalar(retained_weight, float(wd)))
-            assert len(to_add.indices) == grad_indices_count
-            weight[:] = sparse.elemwise_add(weight, _internal._mul_scalar(to_add, float(-lr)))
-            state[:] = history
-            assert state.stype == save_history_stype
-            assert len(history_indices) == grad_indices_count
+            kwargs = {'epsilon': self.float_stable_eps,
+                      'rescale_grad': self.rescale_grad}
+            if self.clip_gradient:
+                kwargs['clip_gradient'] = self.clip_gradient
+            sparse.adagrad_update(weight, grad, history, out=weight, lr=lr, wd=wd, **kwargs)
         else:
+            grad = grad * self.rescale_grad
+            if self.clip_gradient is not None:
+                grad = clip(grad, -self.clip_gradient, self.clip_gradient)
             history[:] += square(grad)
             div = grad / sqrt(history + self.float_stable_eps)
             weight[:] += (div + weight * wd) * -lr
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 89d27e17ec6..55d215602ee 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -1483,7 +1483,153 @@ inline void SignumUpdate(const nnvm::NodeAttrs& attrs,
     });
 }
 
+struct AdagradParam : public dmlc::Parameter<AdagradParam> {
+  float lr;
+  float epsilon;
+  float rescale_grad;
+  float clip_gradient;
+  float wd;
+  DMLC_DECLARE_PARAMETER(AdagradParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(epsilon)
+    .set_default(1.0e-7)
+    .describe("epsilon");
+    DMLC_DECLARE_FIELD(wd)
+    .set_default(0.0f)
+    .describe("weight decay");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Rescale gradient to grad = rescale_grad*grad.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+  }
+};
+
+inline bool AdagradStorageType(const nnvm::NodeAttrs& attrs,
+                               const int dev_mask,
+                               DispatchMode* dispatch_mode,
+                               std::vector<int>* in_attrs,
+                               std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+  bool dispatched = false;
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
+      common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
+      param.wd == 0.0f) {
+    // rsp, rsp, rsp -> rsp with wd = 0.0
+    dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  return dispatched;
+}
+
+
+struct AdagradDnsRspDnsKernel {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
+    DType* state_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const DType clip_gradient, const DType epsilon,
+    const DType lr, const DType rescale_grad) {
+    using nnvm::dim_t;
+    using namespace mshadow_op;
+    const dim_t data_i = grad_idx[i] * row_length;
+    const dim_t grad_i = i * row_length;
+    for (dim_t j = 0; j < row_length; j++) {
+      const dim_t data_j = data_i + j;
+      const dim_t grad_j = grad_i + j;
+      DType grad_rescaled = grad_data[grad_j] * rescale_grad;
+      if (clip_gradient >= 0.0f) {
+        grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+      }
+      const DType grad_squared = grad_rescaled * grad_rescaled;
+      state_data[data_j] += grad_squared;
+      const DType div = grad_rescaled / square_root::Map(state_data[data_j] + epsilon);
+      // No need to use KERNEL_ASSIGN, as we already checked req is kWriteInplace
+      out_data[data_j] = weight_data[data_j] - div * lr;
+    }
+  }
+};
 
+template<typename xpu>
+void AdagradUpdateDnsRspDnsImpl(const AdagradParam& param,
+                                const OpContext& ctx,
+                                const TBlob& weight,
+                                const NDArray& grad,
+                                const TBlob& state,
+                                const OpReqType& req,
+                                TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  CHECK_EQ(param.wd, 0.0f)
+    << "sparse adagrad_update does not support wd.";
+  if (req == kNullOp || !grad.storage_initialized()) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adagrad_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(state.shape_.Size(), 0);
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      const DType* weight_data = weight.dptr<DType>();
+      const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+      const DType* grad_val = grad.data().dptr<DType>();
+      DType* state_data = state.dptr<DType>();
+      DType* out_data = out->dptr<DType>();
+      const nnvm::dim_t nnr = grad.storage_shape()[0];
+      const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
+      Kernel<AdagradDnsRspDnsKernel, xpu>::Launch(s, nnr, row_length,
+        out_data, state_data, weight_data, grad_idx, grad_val,
+        static_cast<DType>(param.clip_gradient), static_cast<DType>(param.epsilon),
+        static_cast<DType>(param.lr), static_cast<DType>(param.rescale_grad));
+    });
+  });
+}
+
+template<typename xpu>
+inline void AdagradUpdateRspRspRspImpl(const AdagradParam& param,
+                                       const OpContext& ctx,
+                                       const NDArray& weight,
+                                       const NDArray& grad,
+                                       const NDArray& state,
+                                       const OpReqType& req,
+                                       NDArray *out) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdagradUpdate", "weights");
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  // fill history with zero values
+  if (!state.storage_initialized()) {
+    NDArray state_zeros = state;
+    FillDnsZerosRspImpl(s, &state_zeros);
+  }
+  TBlob out_blob = out->data();
+  // reuse dns rsp implementation when storage_shape == shape
+  AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+                                  state.data(), req, &out_blob);
+}
+
+template<typename xpu>
+inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs,
+                            const OpContext &ctx,
+                            const std::vector<NDArray> &inputs,
+                            const std::vector<OpReqType> &req,
+                            const std::vector<NDArray> &outputs) {
+  using namespace mxnet_op;
+  const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
+      common::ContainsOnlyStorage(outputs, kRowSparseStorage)) {
+    NDArray out = outputs[0];
+    AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], req[0], &out);
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+}
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 136769a1bf0..741092ad784 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -38,6 +38,7 @@ DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
 DMLC_REGISTER_PARAMETER(FtrlParam);
 DMLC_REGISTER_PARAMETER(SignSGDParam);
 DMLC_REGISTER_PARAMETER(SignumParam);
+DMLC_REGISTER_PARAMETER(AdagradParam);
 
 NNVM_REGISTER_OP(signsgd_update)
 .describe(R"code(Update function for SignSGD optimizer.
@@ -536,5 +537,36 @@ only the row slices whose indices appear in grad.indices are updated (for w, z a
 .add_argument("n", "NDArray-or-Symbol", "Square of grad")
 .add_arguments(FtrlParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_sparse_adagrad_update)
+.describe(R"code(Update function for AdaGrad optimizer.
+
+Referenced from *Adaptive Subgradient Methods for Online Learning and Stochastic Optimization*,
+and available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
+
+Updates are applied by::
+
+    rescaled_grad = clip(grad * rescale_grad, clip_gradient)
+    history = history + square(rescaled_grad)
+    w = w - learning_rate * rescaled_grad / sqrt(history + epsilon)
+
+Note that non-zero values for the weight decay option are not supported.
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<AdagradParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", AdagradStorageType)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2};
+  })
+.set_attr<FComputeEx>("FComputeEx<cpu>", AdagradUpdateEx<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("history", "NDArray-or-Symbol", "History")
+.add_arguments(AdagradParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 1bd6117432b..c49af68a5f6 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -200,5 +200,8 @@ NNVM_REGISTER_OP(ftrl_update)
 .set_attr<FCompute>("FCompute<gpu>", FtrlUpdate<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", FtrlUpdateEx<gpu>);
 
+NNVM_REGISTER_OP(_sparse_adagrad_update)
+.set_attr<FComputeEx>("FComputeEx<gpu>", AdagradUpdateEx<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 159c9bac89d..f71e2c81e27 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -963,6 +963,74 @@ def get_net(num_hidden, flatten=True):
             optimizer='nadam')
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.1
 
+# AdaGrad
+class PyAdaGrad(mx.optimizer.Optimizer):
+    """The python reference of AdaGrad optimizer.
+
+    This class implements the AdaGrad optimizer described in *Adaptive Subgradient
+    Methods for Online Learning and Stochastic Optimization*, and available at
+    http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
+
+    Updates are applied by::
+
+        rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
+        history = history + square(rescaled_grad)
+        w = w - learning_rate * rescaled_grad / sqrt(history + epsilon)
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    eps: float, optional
+        Small value to avoid division by 0.
+
+    """
+    def __init__(self, eps=1e-7, **kwargs):
+        super(PyAdaGrad, self).__init__(**kwargs)
+        self.float_stable_eps = eps
+
+    def create_state(self, index, weight):
+        return mx.nd.zeros(weight.shape, weight.context, stype=weight.stype)
+
+    def update(self, index, weight, grad, state):
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+
+        history = state
+        grad = grad * self.rescale_grad
+        if self.clip_gradient is not None:
+            grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
+        history[:] += mx.nd.square(grad)
+        div = grad / mx.nd.sqrt(history + self.float_stable_eps)
+        weight[:] += (div + weight * wd) * -lr
+
+def test_adagrad():
+    mx.random.seed(0)
+    opt1 = PyAdaGrad
+    opt2 = mx.optimizer.AdaGrad
+    shape = (3, 4, 5)
+    eps_options = [{}, {'eps': 1e-8}]
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.0}]
+    for dtype in [np.float32]:
+        for eps_option in eps_options:
+            for cg_option in cg_options:
+                for rg_option in rg_options:
+                    for wd_option in wd_options:
+                        kwarg = {}
+                        kwarg.update(eps_option)
+                        kwarg.update(cg_option)
+                        kwarg.update(rg_option)
+                        kwarg.update(wd_option)
+                        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+                        if wd_option.get('wd', 0.0) == 0.0:
+                            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+                                              w_stype='row_sparse', g_stype='row_sparse')
+
+
 
 if __name__ == '__main__':
     import nose


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to