eric-haibin-lin closed pull request #10664: [MXNET-358] support dense weight & sparse grad for adam, sgd and sgd_momentum
URL: https://github.com/apache/incubator-mxnet/pull/10664
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 2f7c51bff8d..8e845b2aaa8 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -433,7 +433,7 @@ def _get_wd(self, index):
 class SGD(Optimizer):
     """The SGD optimizer with momentum and weight decay.
 
-    If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
+    If the storage type of grad is ``row_sparse`` and ``lazy_update`` is True, \
     **lazy updates** are applied by::
 
         for row in grad.indices:
@@ -493,8 +493,8 @@ def create_state_multi_precision(self, index, weight):
 
     def create_state(self, index, weight):
         momentum = None
-        stype = weight.stype if self.lazy_update else 'default'
         if self.momentum != 0.0:
+            stype = weight.stype if self.lazy_update else 'default'
             momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, 
stype=stype)
         return momentum
 
@@ -514,7 +514,7 @@ def _update_impl(self, index, weight, grad, state, multi_precision=False):
         if not multi_precision:
             if state is not None:
                 sgd_mom_update(weight, grad, state, out=weight,
-                               lr=lr, wd=wd, **kwargs)
+                               lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
             else:
                 sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
                            lr=lr, wd=wd, **kwargs)
@@ -985,7 +985,7 @@ class Adam(Optimizer):
     This class implements the optimizer described in *Adam: A Method for
     Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
 
-    If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
+    If the storage type of grad is ``row_sparse`` and ``lazy_update`` is True, \
     **lazy updates** are applied by::
 
         for row in grad.indices:
@@ -1058,7 +1058,7 @@ def update(self, index, weight, grad, state):
 
         mean, var = state
         adam_update(weight, grad, mean, var, out=weight,
-                    lr=lr, wd=wd, **kwargs)
+                    lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
 
 @register
 class AdaGrad(Optimizer):
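
For readers following the diff: the net effect of the optimizer.py changes above is that
``lazy_update`` together with a ``row_sparse`` gradient now selects lazy updates even when the
weight (and momentum) are dense. A minimal sketch of driving this from the Python frontend,
assuming the MXNet sparse NDArray API; shapes and values below are illustrative only:

    import mxnet as mx

    weight = mx.nd.ones((4, 2))                       # dense weight
    grad = mx.nd.ones((4, 2)).tostype('row_sparse')   # row_sparse gradient

    # lazy_update=True (the default): only rows listed in grad.indices are updated
    lazy_opt = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9, lazy_update=True)
    lazy_opt.update(0, weight, grad, lazy_opt.create_state(0, weight))

    # lazy_update=False: every row receives the full, weight-decayed update
    std_opt = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9, lazy_update=False)
    std_opt.update(0, weight, grad, std_opt.create_state(0, weight))
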
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index a629ba5eed8..0a9cd08db81 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -471,14 +471,18 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) {
   attrs->parsed = std::move(param);
 }
 
-#define CHECK_RSP_ALL_ROWS_NON_ZERO(rsp, func, param)                              \
-  {                                                                                \
-    CHECK(rsp.storage_shape()[0] == rsp.shape()[0]) << func                        \
-          << " for RowSparse " << param << " is only implemented for "             \
-          << "RowSparse " << param << " with all rows containing non-zeros. "      \
-          << "Expects " << param << ".values.shape[0] (" << rsp.storage_shape()[0] \
-          << ") == " << param << ".shape[0] (" << rsp.shape()[0] << ").";          \
+inline void CheckAllRowsPresent(const NDArray& arr, const std::string& func,
+                                const std::string& param) {
+  if (arr.storage_type() == kRowSparseStorage) {
+    CHECK(arr.storage_shape()[0] == arr.shape()[0]) << func
+          << " for RowSparse " << param << " is only implemented for "
+          << "RowSparse " << param << " with all rows containing non-zeros. "
+          << "Expects " << param << ".data.shape[0] (" << arr.storage_shape()[0]
+          << ") == " << param << ".shape[0] (" << arr.shape()[0] << ").";
+  } else {
+    CHECK(arr.storage_type() == kDefaultStorage);
   }
+}
 
 inline void LogUnimplementedOp(const nnvm::NodeAttrs& attrs,
                                const OpContext &ctx,
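
The new ``CheckAllRowsPresent`` helper accepts a ``default`` (dense) array unconditionally and,
for ``row_sparse`` storage, requires that every row be materialized. A rough Python illustration
of that invariant, assuming the ``mxnet.ndarray.sparse`` API (the arrays below are illustrative only):

    import mxnet as mx

    # every row materialized: data.shape[0] == shape[0], so the check passes
    full = mx.nd.ones((3, 2)).tostype('row_sparse')
    assert full.data.shape[0] == full.shape[0]

    # only one of three rows stored: such a weight would trigger the error above,
    # since storage_shape[0] (1) != shape[0] (3)
    partial = mx.nd.sparse.row_sparse_array(([[1., 1.]], [0]), shape=(3, 2))
    assert partial.data.shape[0] != partial.shape[0]
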
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index dfc7bef977d..28b382c92fb 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -42,6 +42,18 @@
 
 namespace mxnet {
 namespace op {
+
+/*
+ * \brief log message for optimizers with lazy update.
+ */
+inline void LogLazyUpdate() {
+  common::LogOnce("Optimizer with lazy_update = True detected. "
+                  "Be aware that lazy update with row_sparse gradient is different from "
+                  "standard update, and may lead to different empirical results. See "
+                  "https://mxnet.incubator.apache.org/api/python/optimization/optimization.html "
+                  "for more details.");
+}
+
 struct SGDParam : public dmlc::Parameter<SGDParam> {
   float lr;
   float wd;
@@ -66,7 +78,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
               "grad = max(min(grad, clip_gradient), -clip_gradient).");
     DMLC_DECLARE_FIELD(lazy_update)
     .set_default(true)
-    .describe("If true, lazy updates are applied.");
+    .describe("If true, lazy updates are applied if gradient's stype is row_sparse.");
   }
 };
 
@@ -167,6 +179,10 @@ struct SGDDnsRspKernel<req, cpu> {
   }
 };
 
+/*
+ * \brief SGD update implementation for dense weight and row_sparse grad.
+ *        Both standard update and lazy update are supported.
+ */
 template<typename xpu>
 inline void SGDUpdateDnsRspImpl(const SGDParam& param,
                                 const OpContext &ctx,
@@ -190,6 +206,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
         DType* weight_data = weight.dptr<DType>();
         float wd = param.wd;
+        // apply standard weight decay if not lazy update
         if (!param.lazy_update) {
          Kernel<op_with_req<mshadow_op::mul, req_type>, xpu>::Launch(s, weight.Size(),
            weight_data, weight_data, static_cast<DType>(1 - param.lr * param.wd));
@@ -214,14 +231,18 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
   });
 }
 
+/*
+ * \brief SGD update implementation for row_sparse grad.
+ *        Both standard update and lazy update are supported.
+ */
 template<typename xpu>
-inline void SGDUpdateRspRspImpl(const SGDParam& param,
-                                const OpContext& ctx,
-                                const NDArray& weight,
-                                const NDArray& grad,
-                                const OpReqType& req,
-                                NDArray *out) {
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights");
+inline void SGDUpdateRspImpl(const SGDParam& param,
+                             const OpContext& ctx,
+                             const NDArray& weight,
+                             const NDArray& grad,
+                             const OpReqType& req,
+                             NDArray *out) {
+  CheckAllRowsPresent(weight, "SGDUpdate", "weights");
   // reuse dns rsp implementation when storage_shape == shape
   TBlob out_blob = out->data();
   SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, req, &out_blob);
@@ -233,15 +254,15 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
                         const std::vector<NDArray> &inputs,
                         const std::vector<OpReqType> &req,
                         const std::vector<NDArray> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
   const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
-  auto out_stype = outputs[0].storage_type();
-  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
-      out_stype == kRowSparseStorage) {
+  const auto w_stype = inputs[0].storage_type();
+  const auto g_stype = inputs[1].storage_type();
+  const auto o_stype = outputs[0].storage_type();
+  if (o_stype == w_stype && g_stype == kRowSparseStorage &&
+      (w_stype == kDefaultStorage || w_stype == kRowSparseStorage)) {
     NDArray out = outputs[0];
-    SGDUpdateRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
+    // std update and lazy update with rsp grad
+    SGDUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
@@ -253,6 +274,7 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
   float wd;
   float rescale_grad;
   float clip_gradient;
+  bool lazy_update;
   DMLC_DECLARE_PARAMETER(SGDMomParam) {
     DMLC_DECLARE_FIELD(lr)
     .describe("Learning rate");
@@ -272,6 +294,10 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
               "If clip_gradient <= 0, gradient clipping is turned off. "
               "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(lazy_update)
+    .set_default(true)
+    .describe("If true, lazy updates are applied if gradient's stype is row_sparse "
+              "and both weight and momentum have the same stype");
   }
 };
 
@@ -478,14 +504,17 @@ struct SGDMomDnsRspDnsKernel<req, gpu> {
   }
 };
 
+/*
+ * \brief sgd mom lazy update for dense weight, row_sparse grad, dense state.
+ */
 template<typename xpu>
-inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
-                                      const OpContext& ctx,
-                                      const TBlob& weight,
-                                      const NDArray& grad,
-                                      const TBlob& mom,
-                                      const OpReqType& req,
-                                      TBlob *out) {
+inline void SGDMomLazyUpdateDnsRspDnsImpl(const SGDMomParam& param,
+                                          const OpContext& ctx,
+                                          const TBlob& weight,
+                                          const NDArray& grad,
+                                          const TBlob& mom,
+                                          const OpReqType& req,
+                                          TBlob *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
   Stream<xpu>* s = ctx.get_stream<xpu>();
@@ -518,69 +547,78 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
   });
 }
 
+/*
+ * \brief sgd momentum lazy update for row_sparse grad.
+ */
 template<typename xpu>
-inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
-                                      const OpContext& ctx,
-                                      const NDArray& weight,
-                                      const NDArray& grad,
-                                      const NDArray& mom,
-                                      const OpReqType& req,
-                                      NDArray *out) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
+inline void SGDMomLazyUpdateRspImpl(const SGDMomParam& param,
+                                    const OpContext& ctx,
+                                    const NDArray& weight,
+                                    const NDArray& grad,
+                                    const NDArray& mom,
+                                    const OpReqType& req,
+                                    NDArray *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
+  CheckAllRowsPresent(weight, "SGDMomUpdate", "weights");
   Stream<xpu>* s = ctx.get_stream<xpu>();
-  // fill mom with zero values in order to reuse the sgd mom dns impl
-  if (!mom.storage_initialized()) {
+  // fill mom with zero values (if it's in rsp storage)
+  // in order to reuse the sgd mom dns impl
+  if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) {
     NDArray mom_zeros = mom;
     FillDnsZerosRspImpl(s, &mom_zeros);
   }
   TBlob out_blob = out->data();
   // reuse dns rsp implementation when storage_shape == shape
-  SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
-                                 mom.data(), req, &out_blob);
+  SGDMomLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+                                     mom.data(), req, &out_blob);
 }
 
 /*!
- * \brief Storge type inference function in optimizer.
- * \param n_rsp     The number of inputs that should be of row_sparse storage type
- *                  if kFComputeEx is dispatched
- * \param n_rsp_dns The number of inputs that should be of row_sparse or default storage type
- *                  if kFComputeEx is dispatched
+ * \brief Storage type inference function for optimizers which support both
+ *        lazy update and standard update, with states (e.g. 2nd order moment)
+ * \param num_states The number of states that could be row_sparse or dense
  */
-template<int n_rsp, int n_rsp_dns>
+template<size_t num_states, typename ParamType>
 inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
                               const int dev_mask,
                               DispatchMode* dispatch_mode,
                               std::vector<int>* in_attrs,
                               std::vector<int>* out_attrs) {
   using namespace common;
-  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_rsp + n_rsp_dns));
+  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
+  // weight, grad, state 0, state 1, ... -> weight
+  CHECK_EQ(in_attrs->size(), 2 + num_states);
   CHECK_EQ(out_attrs->size(), 1U);
+  const int weight_stype = in_attrs->at(0);
+  const int grad_stype = in_attrs->at(1);
+  const int state_stype = in_attrs->at(2);
+  // the storage type of all states should be the same
+  for (size_t i = 3; i <  2 + num_states; i++) {
+    CHECK_EQ(state_stype, in_attrs->at(i))
+      << "Inconsistent storage types detected in state " << i;
+  }
   bool dispatched = false;
   if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
     // dns, ... -> dns
     dispatched = storage_type_assign(out_attrs, kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
   }
-  const std::vector<int> rsp_stypes(in_attrs->begin(), in_attrs->begin() + n_rsp);
-  const std::vector<int> rsp_dns_stypes(in_attrs->begin() + n_rsp, in_attrs->end());
-  if (!dispatched && ContainsOnlyStorage(rsp_stypes, kRowSparseStorage) &&
-      (ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage) ||
-       ContainsOnlyStorage(rsp_dns_stypes, kDefaultStorage))) {
-    // rsp, ..., rsp/dns, ... -> rsp
-    dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
+  if (!dispatched && grad_stype == kRowSparseStorage &&
+      (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
+      state_stype == weight_stype) {
+    // weight and state share stype, grad's stype = rsp
+    dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
                                      dispatch_mode, DispatchMode::kFComputeEx);
     // warn users if lazy_update is turned on
-    if (dispatched && ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage)) {
-      LogOnce("Optimizer with lazy_update = True detected. "
-      "Be aware that lazy update is different from standard update, "
-      "and may lead to different empirical results. See "
-      "https://mxnet.incubator.apache.org/api/python/optimization/optimization.html "
-      "for more details.");
-    }
+    if (dispatched && param.lazy_update) LogLazyUpdate();
+  }
+  if (!dispatched && grad_stype == kRowSparseStorage &&
+      weight_stype == kRowSparseStorage && state_stype == kDefaultStorage) {
+    // weight,  grad, state, ...  -> weight
+    // rsp,     rsp,  dns,   ...  -> rsp, standard update
+    dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
+                                     dispatch_mode, DispatchMode::kFComputeEx);
   }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
@@ -588,10 +626,16 @@ inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
   return dispatched;
 }
 
+/*
+ * \brief kernel for standard momentum update for dense weight, sparse grad and dense state.
+ */
 template<int req, typename xpu>
 struct SGDMomStdDnsRspDnsKernel;
 
 
+/*
+ * \brief standard momentum update for dense weight, row_sparse grad and dense states.
+ */
 template<typename xpu>
 void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
                                   const OpContext& ctx,
@@ -601,19 +645,28 @@ void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
                                   const OpReqType& req,
                                   TBlob *out);
 
+/*
+ * \brief standard momentum update for row_sparse grad.
+ *        both row_sparse and dense weight are supported.
+ */
 template<typename xpu>
-inline void SGDMomStdUpdateRspRspDnsImpl(const SGDMomParam& param,
-                                         const OpContext& ctx,
-                                         const NDArray& weight,
-                                         const NDArray& grad,
-                                         const NDArray& mom,
-                                         const OpReqType& req,
-                                         NDArray *out) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
+inline void SGDMomStdUpdateRspImpl(const SGDMomParam& param,
+                                   const OpContext& ctx,
+                                   const NDArray& weight,
+                                   const NDArray& grad,
+                                   const NDArray& mom,
+                                   const OpReqType& req,
+                                   NDArray *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
+  CheckAllRowsPresent(weight, "SGDMomUpdate", "weights");
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  // fill mom with zero values (if it's in rsp storage)
+  // in order to reuse the sgd mom dns impl
+  if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) {
+    NDArray mom_zeros = mom;
+    FillDnsZerosRspImpl(s, &mom_zeros);
+  }
   TBlob out_blob = out->data();
   SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
                                     mom.data(), req, &out_blob);
@@ -630,16 +683,25 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
   auto &weight = inputs[0];
   auto &grad = inputs[1];
   auto &mom = inputs[2];
+  const auto w_stype = weight.storage_type();
+  const auto m_stype = mom.storage_type();
   const auto out_stype = outputs[0].storage_type();
   NDArray out = outputs[0];
-  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
-      out_stype == kRowSparseStorage) {
-    SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
-  } else if (weight.storage_type() == kRowSparseStorage &&
-             grad.storage_type() == kRowSparseStorage &&
-             mom.storage_type() == kDefaultStorage &&
-             out_stype == kRowSparseStorage) {
-    SGDMomStdUpdateRspRspDnsImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+  const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage;
+  const bool valid_grad = grad.storage_type() == kRowSparseStorage;
+  const bool lazy_update = param.lazy_update;
+  CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype";
+  if (valid_weight && valid_grad && m_stype == w_stype) {
+    if (lazy_update) {
+      // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update
+      SGDMomLazyUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+    } else {
+      // rsp grad && m.stype = w.stype && lazy_update = false -> std update
+      SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+    }
+  } else if (w_stype == kRowSparseStorage && valid_grad && m_stype == kDefaultStorage) {
+    // rsp weight, rsp grad, dns state -> std update
+    SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
@@ -742,6 +804,7 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
   float wd;
   float rescale_grad;
   float clip_gradient;
+  bool lazy_update;
   DMLC_DECLARE_PARAMETER(AdamParam) {
     DMLC_DECLARE_FIELD(lr)
     .describe("Learning rate");
@@ -767,6 +830,10 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
               "If clip_gradient <= 0, gradient clipping is turned off. "
               "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(lazy_update)
+    .set_default(true)
+    .describe("If true, lazy updates are applied if gradient's stype is row_sparse "
+              "and all of w, m and v have the same stype");
   }
 };
 
@@ -876,15 +943,18 @@ struct AdamDnsRspDnsKernel<req, gpu> {
   }
 };
 
+/*
+ * \brief lazy adam update for dense weight, dense states and rsp grad.
+ */
 template<typename xpu>
-inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
-                                    const OpContext& ctx,
-                                    const TBlob& weight,
-                                    const NDArray& grad,
-                                    const TBlob& mean,
-                                    const TBlob& var,
-                                    const OpReqType& req,
-                                    TBlob *out) {
+inline void AdamLazyUpdateDnsRspDnsImpl(const AdamParam& param,
+                                        const OpContext& ctx,
+                                        const TBlob& weight,
+                                        const NDArray& grad,
+                                        const TBlob& mean,
+                                        const TBlob& var,
+                                        const OpReqType& req,
+                                        TBlob *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
   Stream<xpu>* s = ctx.get_stream<xpu>();
@@ -920,39 +990,47 @@ inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
   });
 }
 
+/*
+ * \brief lazy adam update for both row_sparse and dense weight.
+ *        grad is expected to be row_sparse.
+ */
 template<typename xpu>
-inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
-                                    const OpContext& ctx,
-                                    const NDArray& weight,
-                                    const NDArray& grad,
-                                    const NDArray& mean,
-                                    const NDArray& var,
-                                    const OpReqType& req,
-                                    NDArray *out) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
+inline void AdamLazyUpdateRspImpl(const AdamParam& param,
+                                  const OpContext& ctx,
+                                  const NDArray& weight,
+                                  const NDArray& grad,
+                                  const NDArray& mean,
+                                  const NDArray& var,
+                                  const OpReqType& req,
+                                  NDArray *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamUpdate", "weights");
+  CheckAllRowsPresent(weight, "AdamUpdate", "weights");
   Stream<xpu>* s = ctx.get_stream<xpu>();
   // fill mean and variance with zero values in order to reuse the sgd mom dns impl
-  if (!mean.storage_initialized()) {
+  if (mean.storage_type() == kRowSparseStorage && !mean.storage_initialized()) {
     NDArray mean_zeros = mean;
     FillDnsZerosRspImpl(s, &mean_zeros);
   }
-  if (!var.storage_initialized()) {
+  if (var.storage_type() == kRowSparseStorage && !var.storage_initialized()) {
     NDArray var_zeros = var;
     FillDnsZerosRspImpl(s, &var_zeros);
   }
   TBlob out_blob = out->data();
   // reuse dns rsp implementation when storage_shape == shape
-  AdamUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
-                               var.data(), req, &out_blob);
+  AdamLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
+                                   var.data(), req, &out_blob);
 }
 
+/*
+ * \brief kernel for standard adam update for dense weight, row_sparse grad and dense states.
+ */
 template<int req, typename xpu>
 struct AdamStdDnsRspDnsKernel;
 
+/*
+ * \brief standard adam update for dense weight, row_sparse grad and dense states.
+ */
 template<typename xpu>
 void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
                                 const OpContext& ctx,
@@ -963,18 +1041,22 @@ void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
                                 const OpReqType& req,
                                 TBlob *out);
 
+/*
+ * \brief standard adam update for both row_sparse and dense weight.
+ *        states are expected to be dense, while grad is expected to be row_sparse.
+ */
 template<typename xpu>
-inline void AdamStdUpdateRspRspRspImpl(const AdamParam& param,
-                                       const OpContext& ctx,
-                                       const NDArray& weight,
-                                       const NDArray& grad,
-                                       const NDArray& mean,
-                                       const NDArray& var,
-                                       const OpReqType& req,
-                                       NDArray *out) {
+inline void AdamStdUpdateRspImpl(const AdamParam& param,
+                                 const OpContext& ctx,
+                                 const NDArray& weight,
+                                 const NDArray& grad,
+                                 const NDArray& mean,
+                                 const NDArray& var,
+                                 const OpReqType& req,
+                                 NDArray *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamStdUpdate", "weights");
+  CheckAllRowsPresent(weight, "AdamStdUpdate", "weights");
   TBlob out_blob = out->data();
   // reuse dns rsp implementation when storage_shape == shape
   AdamStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
@@ -988,21 +1070,30 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
                          const std::vector<OpReqType> &req,
                          const std::vector<NDArray> &outputs) {
   const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
-  const auto weight_stype = inputs[0].storage_type();
-  const auto grad_stype = inputs[1].storage_type();
-  const auto mean_stype = inputs[2].storage_type();
-  const auto var_stype = inputs[3].storage_type();
+  const auto w_stype = inputs[0].storage_type();
+  const auto g_stype = inputs[1].storage_type();
+  const auto m_stype = inputs[2].storage_type();
+  const auto v_stype = inputs[3].storage_type();
   const auto out_stype = outputs[0].storage_type();
   NDArray out = outputs[0];
-  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
-      out_stype == kRowSparseStorage) {
-     AdamUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+  const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage;
+  CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype";
+  CHECK(m_stype == v_stype) << "Inconsistent mean stype and var stype";
+  if (valid_weight && g_stype == kRowSparseStorage && m_stype == w_stype) {
+     if (param.lazy_update) {
+       // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update
+       AdamLazyUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
                                   inputs[3], req[0], &out);
-  } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
-             mean_stype == kDefaultStorage && var_stype == kDefaultStorage &&
-             out_stype == kRowSparseStorage) {
-     AdamStdUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
-                                     inputs[3], req[0], &out);
+     } else {
+       // rsp grad && m.stype = w.stype && lazy_update = false -> std update
+       AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+                                 inputs[3], req[0], &out);
+     }
+  } else if (w_stype == kRowSparseStorage && g_stype == kRowSparseStorage &&
+             m_stype == kDefaultStorage) {
+     // rsp, rsp, dns, dns -> rsp, standard update
+     AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+                               inputs[3], req[0], &out);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
@@ -1361,7 +1452,7 @@ inline void FtrlUpdateRspRspRspImpl(const FtrlParam& param,
   using namespace mshadow::expr;
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "FtrlUpdate", "weights");
+  CheckAllRowsPresent(weight, "FtrlUpdate", "weights");
   Stream<xpu>* s = ctx.get_stream<xpu>();
   // fill z and n with zero values in order to reuse the sgd mom dns impl
   if (!z.storage_initialized()) {
@@ -1690,7 +1781,7 @@ inline void AdagradUpdateRspRspRspImpl(const AdagradParam& param,
   using namespace mshadow;
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdagradUpdate", "weights");
+  CheckAllRowsPresent(weight, "AdagradUpdate", "weights");
   Stream<xpu>* s = ctx.get_stream<xpu>();
   // fill history with zero values
   if (!state.storage_initialized()) {
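
To summarize the dispatch rules encoded by ``StdOptStorageType`` and the
``SGDMomUpdateEx``/``AdamUpdateEx`` changes above, here is a small pure-Python sketch;
``momentum_update_kind`` is a made-up name used only for illustration:

    def momentum_update_kind(w_stype, g_stype, m_stype, lazy_update=True):
        """Which code path sgd_mom_update/adam_update picks for given storage types."""
        if w_stype == g_stype == m_stype == 'default':
            return 'dense FCompute'
        if g_stype == 'row_sparse' and w_stype in ('default', 'row_sparse'):
            if m_stype == w_stype:
                return 'lazy update' if lazy_update else 'standard update'
            if w_stype == 'row_sparse' and m_stype == 'default':
                return 'standard update'
        return None  # storage fallback / LogUnimplementedOp

    assert momentum_update_kind('default', 'row_sparse', 'default') == 'lazy update'
    assert momentum_update_kind('row_sparse', 'row_sparse', 'default') == 'standard update'
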
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index fe0be9d442f..cc7770b2e4c 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -132,6 +132,10 @@ struct SGDMomStdDnsRspDnsKernel<req, cpu> {
   }
 };
 
+/*
+ * \brief standard momentum update for dense weight on cpu.
+ *        state is expected to be dense, while grad is expected to be row_sparse.
+ */
 template<>
 void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
                                        const OpContext& ctx,
@@ -152,12 +156,12 @@ void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
     MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
         DType* weight_data = weight.dptr<DType>();
-        IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
-        DType* grad_val = grad.data().dptr<DType>();
+        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        const DType* grad_val = grad.data().dptr<DType>();
         DType* mom_data = mom.dptr<DType>();
         DType* out_data = out->dptr<DType>();
-        nnvm::dim_t num_rows = weight.shape_[0];
-        auto row_length = weight.shape_.ProdShape(1, weight.ndim());
+        const nnvm::dim_t num_rows = weight.shape_[0];
+        const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
         Tensor<cpu, 1, char> workspace = ctx.requested[0]
           .get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s);
 
@@ -275,6 +279,40 @@ void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
   });
 }
 
+/*!
+ * \brief Storage type inference function for SGD.
+ */
+inline bool SGDStorageType(const nnvm::NodeAttrs& attrs,
+                           const int dev_mask,
+                           DispatchMode* dispatch_mode,
+                           std::vector<int>* in_attrs,
+                           std::vector<int>* out_attrs) {
+  using namespace common;
+  const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const int weight_stype = in_attrs->at(0);
+  const int grad_stype = in_attrs->at(1);
+  bool dispatched = false;
+  if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+    // dns, ... -> dns
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched && grad_stype == kRowSparseStorage &&
+      (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage)) {
+    // grad's stype = rsp
+    dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+    // warn users if lazy_update is turned on
+    if (dispatched && param.lazy_update) LogLazyUpdate();
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
+}
+
 
 NNVM_REGISTER_OP(sgd_update)
 MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
@@ -282,13 +320,13 @@ MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
 
 It updates the weights using::
 
- weight = weight - learning_rate * gradient
+ weight = weight - learning_rate * (gradient + wd * weight)
 
-If weight is of ``row_sparse`` storage type,
+However, if gradient is of ``row_sparse`` storage type and ``lazy_update`` is True,
 only the row slices whose indices appear in grad.indices are updated::
 
  for row in gradient.indices:
-     weight[row] = weight[row] - learning_rate * gradient[row]
+     weight[row] = weight[row] - learning_rate * (gradient[row] + wd * weight[row])
 
 )code" ADD_FILELINE)
 .set_num_inputs(2)
@@ -296,7 +334,7 @@ only the row slices whose indices appear in grad.indices are updated::
 .set_attr_parser(ParamParser<SGDParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<2, 1, false, true, false>)
+.set_attr<FInferStorageType>("FInferStorageType", SGDStorageType)
 .set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", SGDUpdateEx<cpu>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
@@ -305,7 +343,7 @@ only the row slices whose indices appear in grad.indices are updated::
 
 NNVM_REGISTER_OP(sgd_mom_update)
 MXNET_ADD_SPARSE_OP_ALIAS(sgd_mom_update)
-.describe(R"code(Momentum update function for Stochastic Gradient Descent (SDG) optimizer.
+.describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer.
 
 Momentum update has better convergence rates on neural networks. Mathematically it looks
 like below:
@@ -323,10 +361,8 @@ It updates the weights using::
 
 Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
 
-If weight and grad are both of ``row_sparse`` storage type and momentum is of ``default`` storage type,
-standard update is applied.
-
-If weight, grad and momentum are all of ``row_sparse`` storage type,
+However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage
+type is the same as momentum's storage type,
 only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::
 
   for row in gradient.indices:
@@ -339,7 +375,7 @@ only the row slices whose indices appear in grad.indices are updated (for both w
 .set_attr_parser(ParamParser<SGDMomParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<1, SGDMomParam>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2};
@@ -420,7 +456,7 @@ available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.
 .add_argument("d", "NDArray-or-Symbol", "Internal state ``d_t``")
 .add_argument("v", "NDArray-or-Symbol", "Internal state ``v_t``")
 .add_argument("z", "NDArray-or-Symbol", "Internal state ``z_t``")
-.add_arguments(AdamParam::__FIELDS__());
+.add_arguments(FTMLParam::__FIELDS__());
 
 NNVM_REGISTER_OP(adam_update)
 MXNET_ADD_SPARSE_OP_ALIAS(adam_update)
@@ -443,7 +479,8 @@ It updates the weights using::
  v = beta2*v + (1-beta2)*(grad**2)
  w += - learning_rate * m / (sqrt(v) + epsilon)
 
-If w, m and v are all of ``row_sparse`` storage type,
+However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and the storage
+type of weight is the same as those of m and v,
 only the row slices whose indices appear in grad.indices are updated (for w, m and v)::
 
  for row in grad.indices:
@@ -461,7 +498,7 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 2>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, AdamParam>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3};
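
The revised ``sgd_update`` documentation above spells out both update rules. A numpy reference
of the two, purely illustrative (``sgd_standard``/``sgd_lazy`` are made-up names, not the
operator kernels themselves):

    import numpy as np

    def sgd_standard(weight, grad, lr, wd):
        # every row takes the weight-decayed gradient step
        return weight - lr * (grad + wd * weight)

    def sgd_lazy(weight, grad_values, grad_indices, lr, wd):
        # only rows listed in grad.indices are touched
        out = weight.copy()
        for row, g in zip(grad_indices, grad_values):
            out[row] -= lr * (g + wd * out[row])
        return out
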
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index d1dc31a31c5..90762f7620f 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -204,7 +204,6 @@ def update(self, index, weight, grad, state):
     def update_multi_precision(self, index, weight, grad, state):
         self.update(index, weight, grad, state)
 
-@unittest.skip("Test fails intermittently. Temporarily disabled until fixed. Tracked at https://github.com/apache/incubator-mxnet/issues/9000")
 @with_seed()
 def test_sgd():
     opt1 = PySGD
@@ -233,16 +232,9 @@ def test_sgd():
                                 continue
                             compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
                             # test operator fallback on cpu
-                            if (default_context() == mx.cpu()):
-                                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
-                                                  g_stype='row_sparse')
-                                if dtype != np.float16:
-                                    compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2],
-                                                      dtype, w_stype='csr', g_stype='csr')
-    # test optimizer with a big shape
-    big_shape = (54686454, 1)
-    kwarg = {'momentum': 0.9, 'wd': 0.05}
-    compare_optimizer(opt1(**kwarg), opt2(**kwarg), big_shape, np.float32)
+                            if dtype != np.float16:
+                                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2],
+                                                  dtype, w_stype='csr', g_stype='csr')
 
 class PySparseSGD(mx.optimizer.Optimizer):
     """python reference implemenation of sgd"""
@@ -337,9 +329,11 @@ def test_sparse_sgd():
                             kwarg.update(mp_option)
                             compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
                                               w_stype='row_sparse', g_stype='row_sparse')
+                            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+                                              w_stype='default', g_stype='row_sparse')
 
 
-@with_seed(0)
+@with_seed()
 def test_std_sparse_sgd():
     opt1 = PySGD
     opt2 = mx.optimizer.SGD
@@ -360,6 +354,8 @@ def test_std_sparse_sgd():
                         kwarg.update(wd_option)
                         compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape, dtype,
                                           w_stype='row_sparse', g_stype='row_sparse')
+                        compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape, dtype,
+                                          w_stype='default', g_stype='row_sparse')
 
 
 class PyNAG(PySGD):
@@ -543,7 +539,7 @@ def test_ftml():
 class PyAdam(mx.optimizer.Optimizer):
     """python reference implemenation of adam"""
     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
-                 decay_factor=(1 - 1e-8), lazy_update=False, **kwargs):
+                 decay_factor=(1 - 1e-8), lazy_update=True, **kwargs):
         super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
@@ -594,7 +590,7 @@ def update(self, index, weight, grad, state):
         for row in range(num_rows):
             # check row slices of all zeros
             all_zeros = mx.test_utils.almost_equal(grad[row].asnumpy(), np.zeros_like(grad[row].asnumpy()))
-            # skip zeros during sparse update
+            # skip zeros during lazy update
             if all_zeros and self.lazy_update:
                 continue
             grad[row] = grad[row] * self.rescale_grad + wd * weight[row]
@@ -635,15 +631,21 @@ def test_adam():
                                     not kwarg['multi_precision'])):
                             continue
                         # atol 2e-5 needed to pass with seed 1248389097
-                        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+                        compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(**kwarg), shape, dtype,
                                           rtol=1e-4, atol=2e-5)
                         # atol 2e-5 needed to pass with seed 781809840
-                        compare_optimizer(opt1(lazy_update=True, **kwarg), opt2(**kwarg), shape,
+                        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape,
                                           dtype, w_stype='row_sparse', g_stype='row_sparse',
                                           rtol=1e-4, atol=2e-5)
-                        compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape,
+                        compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(lazy_update=False, **kwarg), shape,
                                           dtype, w_stype='row_sparse', g_stype='row_sparse',
                                           rtol=1e-4, atol=2e-5)
+                        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape,
+                                          dtype, w_stype='default', g_stype='row_sparse',
+                                          rtol=1e-4, atol=2e-5)
+                        compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(lazy_update=False, **kwarg), shape,
+                                          dtype, w_stype='default', g_stype='row_sparse',
+                                          rtol=1e-4, atol=2e-5)
 
 # Signum
 class PySignum(mx.optimizer.Optimizer):
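
The ``PyAdam`` reference used by these tests skips all-zero gradient rows when ``lazy_update``
is on. A compact numpy sketch of that lazy rule (illustrative only; ``adam_lazy_reference`` is a
made-up name, and bias correction, which the Python Adam class applies to the learning rate, is
omitted as in the operator docstring):

    import numpy as np

    def adam_lazy_reference(weight, m, v, grad_values, grad_indices,
                            lr, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.0):
        # rows absent from grad.indices keep weight, m and v unchanged
        for row, g in zip(grad_indices, grad_values):
            g = g + wd * weight[row]
            m[row] = beta1 * m[row] + (1 - beta1) * g
            v[row] = beta2 * v[row] + (1 - beta2) * g * g
            weight[row] -= lr * m[row] / (np.sqrt(v[row]) + eps)
        return weight, m, v
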


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
