eric-haibin-lin closed pull request #10664: [MXNET-358] support dense weight & sparse grad for adam, sgd and sgd_momentum URL: https://github.com/apache/incubator-mxnet/pull/10664
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 2f7c51bff8d..8e845b2aaa8 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -433,7 +433,7 @@ def _get_wd(self, index): class SGD(Optimizer): """The SGD optimizer with momentum and weight decay. - If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \ + If the storage type of grad is ``row_sparse`` and ``lazy_update`` is True, \ **lazy updates** are applied by:: for row in grad.indices: @@ -493,8 +493,8 @@ def create_state_multi_precision(self, index, weight): def create_state(self, index, weight): momentum = None - stype = weight.stype if self.lazy_update else 'default' if self.momentum != 0.0: + stype = weight.stype if self.lazy_update else 'default' momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype) return momentum @@ -514,7 +514,7 @@ def _update_impl(self, index, weight, grad, state, multi_precision=False): if not multi_precision: if state is not None: sgd_mom_update(weight, grad, state, out=weight, - lr=lr, wd=wd, **kwargs) + lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs) else: sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs) @@ -985,7 +985,7 @@ class Adam(Optimizer): This class implements the optimizer described in *Adam: A Method for Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980. 
- If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \ + If the storage type of grad is ``row_sparse`` and ``lazy_update`` is True, \ **lazy updates** are applied by:: for row in grad.indices: @@ -1058,7 +1058,7 @@ def update(self, index, weight, grad, state): mean, var = state adam_update(weight, grad, mean, var, out=weight, - lr=lr, wd=wd, **kwargs) + lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs) @register class AdaGrad(Optimizer): diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index a629ba5eed8..0a9cd08db81 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -471,14 +471,18 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) { attrs->parsed = std::move(param); } -#define CHECK_RSP_ALL_ROWS_NON_ZERO(rsp, func, param) \ - { \ - CHECK(rsp.storage_shape()[0] == rsp.shape()[0]) << func \ - << " for RowSparse " << param << " is only implemented for " \ - << "RowSparse " << param << " with all rows containing non-zeros. " \ - << "Expects " << param << ".values.shape[0] (" << rsp.storage_shape()[0] \ - << ") == " << param << ".shape[0] (" << rsp.shape()[0] << ")."; \ +inline void CheckAllRowsPresent(const NDArray& arr, const std::string& func, + const std::string& param) { + if (arr.storage_type() == kRowSparseStorage) { + CHECK(arr.storage_shape()[0] == arr.shape()[0]) << func + << " for RowSparse " << param << " is only implemented for " + << "RowSparse " << param << " with all rows containing non-zeros. 
" + << "Expects " << param << ".data.shape[0] (" << arr.storage_shape()[0] + << ") == " << param << ".shape[0] (" << arr.shape()[0] << ")."; + } else { + CHECK(arr.storage_type() == kDefaultStorage); } +} inline void LogUnimplementedOp(const nnvm::NodeAttrs& attrs, const OpContext &ctx, diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index dfc7bef977d..28b382c92fb 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -42,6 +42,18 @@ namespace mxnet { namespace op { + +/* + * \brief log message for optimizers with lazy update. + */ +inline void LogLazyUpdate() { + common::LogOnce("Optimizer with lazy_update = True detected. " + "Be aware that lazy update with row_sparse gradient is different from " + "standard update, and may lead to different empirical results. See " + "https://mxnet.incubator.apache.org/api/python/optimization/optimization.html " + "for more details."); +} + struct SGDParam : public dmlc::Parameter<SGDParam> { float lr; float wd; @@ -66,7 +78,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> { "grad = max(min(grad, clip_gradient), -clip_gradient)."); DMLC_DECLARE_FIELD(lazy_update) .set_default(true) - .describe("If true, lazy updates are applied."); + .describe("If true, lazy updates are applied if gradient's stype is row_sparse."); } }; @@ -167,6 +179,10 @@ struct SGDDnsRspKernel<req, cpu> { } }; +/* + * \brief SGD update implementation for dense weight and row_sparse grad. + * Both standard update and lazy update are supported. 
+ */ template<typename xpu> inline void SGDUpdateDnsRspImpl(const SGDParam& param, const OpContext &ctx, @@ -190,6 +206,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param, MXNET_ASSIGN_REQ_SWITCH(req, req_type, { DType* weight_data = weight.dptr<DType>(); float wd = param.wd; + // apply standard weight decay if not lazy update if (!param.lazy_update) { Kernel<op_with_req<mshadow_op::mul, req_type>, xpu>::Launch(s, weight.Size(), weight_data, weight_data, static_cast<DType>(1 - param.lr * param.wd)); @@ -214,14 +231,18 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param, }); } +/* + * \brief SGD update implementation for row_sparse grad. + * Both standard update and lazy update are supported. + */ template<typename xpu> -inline void SGDUpdateRspRspImpl(const SGDParam& param, - const OpContext& ctx, - const NDArray& weight, - const NDArray& grad, - const OpReqType& req, - NDArray *out) { - CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights"); +inline void SGDUpdateRspImpl(const SGDParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const OpReqType& req, + NDArray *out) { + CheckAllRowsPresent(weight, "SGDUpdate", "weights"); // reuse dns rsp implementation when storage_shape == shape TBlob out_blob = out->data(); SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, req, &out_blob); @@ -233,15 +254,15 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req, const std::vector<NDArray> &outputs) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace mshadow_op; const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed); - auto out_stype = outputs[0].storage_type(); - if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) && - out_stype == kRowSparseStorage) { + const auto w_stype = inputs[0].storage_type(); + const auto g_stype = inputs[1].storage_type(); + const auto o_stype = outputs[0].storage_type(); + if 
(o_stype == w_stype && g_stype == kRowSparseStorage && + (w_stype == kDefaultStorage || w_stype == kRowSparseStorage)) { NDArray out = outputs[0]; - SGDUpdateRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out); + // std update and lazy update with rsp grad + SGDUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out); } else { LogUnimplementedOp(attrs, ctx, inputs, req, outputs); } @@ -253,6 +274,7 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> { float wd; float rescale_grad; float clip_gradient; + bool lazy_update; DMLC_DECLARE_PARAMETER(SGDMomParam) { DMLC_DECLARE_FIELD(lr) .describe("Learning rate"); @@ -272,6 +294,10 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> { .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " "If clip_gradient <= 0, gradient clipping is turned off. " "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(lazy_update) + .set_default(true) + .describe("If true, lazy updates are applied if gradient's stype is row_sparse " + "and both weight and momentum have the same stype"); } }; @@ -478,14 +504,17 @@ struct SGDMomDnsRspDnsKernel<req, gpu> { } }; +/* + * \brief sgd mom lazy update for dense weight, row_sparse grad, dense state. + */ template<typename xpu> -inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, - const OpContext& ctx, - const TBlob& weight, - const NDArray& grad, - const TBlob& mom, - const OpReqType& req, - TBlob *out) { +inline void SGDMomLazyUpdateDnsRspDnsImpl(const SGDMomParam& param, + const OpContext& ctx, + const TBlob& weight, + const NDArray& grad, + const TBlob& mom, + const OpReqType& req, + TBlob *out) { using namespace mxnet_op; using namespace rowsparse; Stream<xpu>* s = ctx.get_stream<xpu>(); @@ -518,69 +547,78 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, }); } +/* + * \brief sgd momentum lazy update for row_sparse grad. 
+ */ template<typename xpu> -inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param, - const OpContext& ctx, - const NDArray& weight, - const NDArray& grad, - const NDArray& mom, - const OpReqType& req, - NDArray *out) { - using namespace mshadow; - using namespace mshadow::expr; +inline void SGDMomLazyUpdateRspImpl(const SGDMomParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const NDArray& mom, + const OpReqType& req, + NDArray *out) { using namespace mxnet_op; using namespace rowsparse; - CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights"); + CheckAllRowsPresent(weight, "SGDMomUpdate", "weights"); Stream<xpu>* s = ctx.get_stream<xpu>(); - // fill mom with zero values in order to reuse the sgd mom dns impl - if (!mom.storage_initialized()) { + // fill mom with zero values (if it's in rsp storage) + // in order to reuse the sgd mom dns impl + if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) { NDArray mom_zeros = mom; FillDnsZerosRspImpl(s, &mom_zeros); } TBlob out_blob = out->data(); // reuse dns rsp implementation when storage_shape == shape - SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, - mom.data(), req, &out_blob); + SGDMomLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, + mom.data(), req, &out_blob); } /*! - * \brief Storge type inference function in optimizer. - * \param n_rsp The number of inputs that should be of row_sparse storage type - * if kFComputeEx is dispatched - * \param n_rsp_dns The number of inputs that should be of row_sparse or default storage type - * if kFComputeEx is dispatched + * \brief Storge type inference function for optimizers which support both + * lazy update and standard update, with states (e.g. 
2nd order moment) + * \param num_states The number of states that could be row_sparse or dense */ -template<int n_rsp, int n_rsp_dns> +template<size_t num_states, typename ParamType> inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, std::vector<int>* in_attrs, std::vector<int>* out_attrs) { using namespace common; - CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_rsp + n_rsp_dns)); + const ParamType& param = nnvm::get<ParamType>(attrs.parsed); + // weight, grad, state 0, state 1, ... -> weight + CHECK_EQ(in_attrs->size(), 2 + num_states); CHECK_EQ(out_attrs->size(), 1U); + const int weight_stype = in_attrs->at(0); + const int grad_stype = in_attrs->at(1); + const int state_stype = in_attrs->at(2); + // the storage type of all states should be the same + for (size_t i = 3; i < 2 + num_states; i++) { + CHECK_EQ(state_stype, in_attrs->at(i)) + << "Inconsistent storage types detected in state " << i; + } bool dispatched = false; if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { // dns, ... -> dns dispatched = storage_type_assign(out_attrs, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); } - const std::vector<int> rsp_stypes(in_attrs->begin(), in_attrs->begin() + n_rsp); - const std::vector<int> rsp_dns_stypes(in_attrs->begin() + n_rsp, in_attrs->end()); - if (!dispatched && ContainsOnlyStorage(rsp_stypes, kRowSparseStorage) && - (ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage) || - ContainsOnlyStorage(rsp_dns_stypes, kDefaultStorage))) { - // rsp, ..., rsp/dns, ... 
-> rsp - dispatched = storage_type_assign(out_attrs, kRowSparseStorage, + if (!dispatched && grad_stype == kRowSparseStorage && + (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) && + state_stype == weight_stype) { + // weight and state share stype, grad's stype = rsp + dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype), dispatch_mode, DispatchMode::kFComputeEx); // warn users if lazy_update is turned on - if (dispatched && ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage)) { - LogOnce("Optimizer with lazy_update = True detected. " - "Be aware that lazy update is different from standard update, " - "and may lead to different empirical results. See " - "https://mxnet.incubator.apache.org/api/python/optimization/optimization.html " - "for more details."); - } + if (dispatched && param.lazy_update) LogLazyUpdate(); + } + if (!dispatched && grad_stype == kRowSparseStorage && + weight_stype == kRowSparseStorage && state_stype == kDefaultStorage) { + // weight, grad, state, ... -> weight + // rsp, rsp, dns, ... -> rsp, standard update + dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype), + dispatch_mode, DispatchMode::kFComputeEx); } if (!dispatched) { dispatched = dispatch_fallback(out_attrs, dispatch_mode); @@ -588,10 +626,16 @@ inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs, return dispatched; } +/* + * \brief kernel for standard momentum update for dense weight, sparse grad and dense state. + */ template<int req, typename xpu> struct SGDMomStdDnsRspDnsKernel; +/* + * \brief standard momentum update for dense weight, row_sparse grad and dense states. + */ template<typename xpu> void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param, const OpContext& ctx, @@ -601,19 +645,28 @@ void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param, const OpReqType& req, TBlob *out); +/* + * \brief standard momentum update for row_sparse grad. 
+ * both row_sparse and dense weight are supported. + */ template<typename xpu> -inline void SGDMomStdUpdateRspRspDnsImpl(const SGDMomParam& param, - const OpContext& ctx, - const NDArray& weight, - const NDArray& grad, - const NDArray& mom, - const OpReqType& req, - NDArray *out) { - using namespace mshadow; - using namespace mshadow::expr; +inline void SGDMomStdUpdateRspImpl(const SGDMomParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const NDArray& mom, + const OpReqType& req, + NDArray *out) { using namespace mxnet_op; using namespace rowsparse; - CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights"); + CheckAllRowsPresent(weight, "SGDMomUpdate", "weights"); + Stream<xpu>* s = ctx.get_stream<xpu>(); + // fill mom with zero values (if it's in rsp storage) + // in order to reuse the sgd mom dns impl + if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) { + NDArray mom_zeros = mom; + FillDnsZerosRspImpl(s, &mom_zeros); + } TBlob out_blob = out->data(); SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mom.data(), req, &out_blob); @@ -630,16 +683,25 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, auto &weight = inputs[0]; auto &grad = inputs[1]; auto &mom = inputs[2]; + const auto w_stype = weight.storage_type(); + const auto m_stype = mom.storage_type(); const auto out_stype = outputs[0].storage_type(); NDArray out = outputs[0]; - if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) && - out_stype == kRowSparseStorage) { - SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out); - } else if (weight.storage_type() == kRowSparseStorage && - grad.storage_type() == kRowSparseStorage && - mom.storage_type() == kDefaultStorage && - out_stype == kRowSparseStorage) { - SGDMomStdUpdateRspRspDnsImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out); + const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage; + const bool 
valid_grad = grad.storage_type() == kRowSparseStorage; + const bool lazy_update = param.lazy_update; + CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype"; + if (valid_weight && valid_grad && m_stype == w_stype) { + if (lazy_update) { + // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update + SGDMomLazyUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out); + } else { + // rsp grad && m.stype = w.stype && lazy_update = false -> std update + SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out); + } + } else if (w_stype == kRowSparseStorage && valid_grad && m_stype == kDefaultStorage) { + // rsp weight, rsp grad, dns state -> std update + SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out); } else { LogUnimplementedOp(attrs, ctx, inputs, req, outputs); } @@ -742,6 +804,7 @@ struct AdamParam : public dmlc::Parameter<AdamParam> { float wd; float rescale_grad; float clip_gradient; + bool lazy_update; DMLC_DECLARE_PARAMETER(AdamParam) { DMLC_DECLARE_FIELD(lr) .describe("Learning rate"); @@ -767,6 +830,10 @@ struct AdamParam : public dmlc::Parameter<AdamParam> { .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " "If clip_gradient <= 0, gradient clipping is turned off. " "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(lazy_update) + .set_default(true) + .describe("If true, lazy updates are applied if gradient's stype is row_sparse " + "and all of w, m and v have the same stype"); } }; @@ -876,15 +943,18 @@ struct AdamDnsRspDnsKernel<req, gpu> { } }; +/* + * \brief lazy adam update for dense weight, dense states and rsp grad. 
+ */ template<typename xpu> -inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param, - const OpContext& ctx, - const TBlob& weight, - const NDArray& grad, - const TBlob& mean, - const TBlob& var, - const OpReqType& req, - TBlob *out) { +inline void AdamLazyUpdateDnsRspDnsImpl(const AdamParam& param, + const OpContext& ctx, + const TBlob& weight, + const NDArray& grad, + const TBlob& mean, + const TBlob& var, + const OpReqType& req, + TBlob *out) { using namespace mxnet_op; using namespace rowsparse; Stream<xpu>* s = ctx.get_stream<xpu>(); @@ -920,39 +990,47 @@ inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param, }); } +/* + * \brief lazy adam update for both row_sparse and dense weight. + * grad is expected to be row_sparse. + */ template<typename xpu> -inline void AdamUpdateRspRspRspImpl(const AdamParam& param, - const OpContext& ctx, - const NDArray& weight, - const NDArray& grad, - const NDArray& mean, - const NDArray& var, - const OpReqType& req, - NDArray *out) { - using namespace mshadow; - using namespace mshadow::expr; +inline void AdamLazyUpdateRspImpl(const AdamParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const NDArray& mean, + const NDArray& var, + const OpReqType& req, + NDArray *out) { using namespace mxnet_op; using namespace rowsparse; - CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamUpdate", "weights"); + CheckAllRowsPresent(weight, "AdamUpdate", "weights"); Stream<xpu>* s = ctx.get_stream<xpu>(); // fill mean and variance with zero values in order to reuse the sgd mom dns impl - if (!mean.storage_initialized()) { + if (mean.storage_type() == kRowSparseStorage && !mean.storage_initialized()) { NDArray mean_zeros = mean; FillDnsZerosRspImpl(s, &mean_zeros); } - if (!var.storage_initialized()) { + if (var.storage_type() == kRowSparseStorage && !var.storage_initialized()) { NDArray var_zeros = var; FillDnsZerosRspImpl(s, &var_zeros); } TBlob out_blob = out->data(); // reuse dns rsp implementation when 
storage_shape == shape - AdamUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(), - var.data(), req, &out_blob); + AdamLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(), + var.data(), req, &out_blob); } +/* + * \brief kernel for standard adam update for dense weight, row_sparse grad and dense states. + */ template<int req, typename xpu> struct AdamStdDnsRspDnsKernel; +/* + * \brief standard adam update for dense weight, row_sparse grad and dense states. + */ template<typename xpu> void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param, const OpContext& ctx, @@ -963,18 +1041,22 @@ void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param, const OpReqType& req, TBlob *out); +/* + * \brief standard adam update for both row_sparse and dense weight. + * states are expected to be dense, while grad is expected to be row_sparse. + */ template<typename xpu> -inline void AdamStdUpdateRspRspRspImpl(const AdamParam& param, - const OpContext& ctx, - const NDArray& weight, - const NDArray& grad, - const NDArray& mean, - const NDArray& var, - const OpReqType& req, - NDArray *out) { +inline void AdamStdUpdateRspImpl(const AdamParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const NDArray& mean, + const NDArray& var, + const OpReqType& req, + NDArray *out) { using namespace mxnet_op; using namespace rowsparse; - CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamStdUpdate", "weights"); + CheckAllRowsPresent(weight, "AdamStdUpdate", "weights"); TBlob out_blob = out->data(); // reuse dns rsp implementation when storage_shape == shape AdamStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(), @@ -988,21 +1070,30 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs, const std::vector<OpReqType> &req, const std::vector<NDArray> &outputs) { const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed); - const auto weight_stype = inputs[0].storage_type(); - const auto grad_stype = 
inputs[1].storage_type(); - const auto mean_stype = inputs[2].storage_type(); - const auto var_stype = inputs[3].storage_type(); + const auto w_stype = inputs[0].storage_type(); + const auto g_stype = inputs[1].storage_type(); + const auto m_stype = inputs[2].storage_type(); + const auto v_stype = inputs[3].storage_type(); const auto out_stype = outputs[0].storage_type(); NDArray out = outputs[0]; - if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) && - out_stype == kRowSparseStorage) { - AdamUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], + const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage; + CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype"; + CHECK(m_stype == v_stype) << "Inconsistent mean stype and var stype"; + if (valid_weight && g_stype == kRowSparseStorage && m_stype == w_stype) { + if (param.lazy_update) { + // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update + AdamLazyUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], inputs[3], req[0], &out); - } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage && - mean_stype == kDefaultStorage && var_stype == kDefaultStorage && - out_stype == kRowSparseStorage) { - AdamStdUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], - inputs[3], req[0], &out); + } else { + // rsp grad && m.stype = w.stype && lazy_update = false -> std update + AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], + inputs[3], req[0], &out); + } + } else if (w_stype == kRowSparseStorage && g_stype == kRowSparseStorage && + m_stype == kDefaultStorage) { + // rsp, rsp, dns, dns -> rsp, standard update + AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], + inputs[3], req[0], &out); } else { LogUnimplementedOp(attrs, ctx, inputs, req, outputs); } @@ -1361,7 +1452,7 @@ inline void FtrlUpdateRspRspRspImpl(const FtrlParam& param, using namespace 
mshadow::expr; using namespace mxnet_op; using namespace rowsparse; - CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "FtrlUpdate", "weights"); + CheckAllRowsPresent(weight, "FtrlUpdate", "weights"); Stream<xpu>* s = ctx.get_stream<xpu>(); // fill z and n with zero values in order to reuse the sgd mom dns impl if (!z.storage_initialized()) { @@ -1690,7 +1781,7 @@ inline void AdagradUpdateRspRspRspImpl(const AdagradParam& param, using namespace mshadow; using namespace mxnet_op; using namespace rowsparse; - CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdagradUpdate", "weights"); + CheckAllRowsPresent(weight, "AdagradUpdate", "weights"); Stream<xpu>* s = ctx.get_stream<xpu>(); // fill history with zero values if (!state.storage_initialized()) { diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index fe0be9d442f..cc7770b2e4c 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -132,6 +132,10 @@ struct SGDMomStdDnsRspDnsKernel<req, cpu> { } }; +/* + * \brief standard momentum update for dense weight on cpu. + * state is expected to be dense, while grad is expected to be row_sparse. 
+ */ template<> void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param, const OpContext& ctx, @@ -152,12 +156,12 @@ void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param, MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, { MXNET_ASSIGN_REQ_SWITCH(req, req_type, { DType* weight_data = weight.dptr<DType>(); - IType* grad_idx = grad.aux_data(kIdx).dptr<IType>(); - DType* grad_val = grad.data().dptr<DType>(); + const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>(); + const DType* grad_val = grad.data().dptr<DType>(); DType* mom_data = mom.dptr<DType>(); DType* out_data = out->dptr<DType>(); - nnvm::dim_t num_rows = weight.shape_[0]; - auto row_length = weight.shape_.ProdShape(1, weight.ndim()); + const nnvm::dim_t num_rows = weight.shape_[0]; + const auto row_length = weight.shape_.ProdShape(1, weight.ndim()); Tensor<cpu, 1, char> workspace = ctx.requested[0] .get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s); @@ -275,6 +279,40 @@ void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param, }); } +/*! + * \brief Storge type inference function for SGD. + */ +inline bool SGDStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int>* in_attrs, + std::vector<int>* out_attrs) { + using namespace common; + const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed); + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const int weight_stype = in_attrs->at(0); + const int grad_stype = in_attrs->at(1); + bool dispatched = false; + if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { + // dns, ... 
-> dns + dispatched = storage_type_assign(out_attrs, kDefaultStorage, + dispatch_mode, DispatchMode::kFCompute); + } + if (!dispatched && grad_stype == kRowSparseStorage && + (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage)) { + // grad's stype = rsp + dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype), + dispatch_mode, DispatchMode::kFComputeEx); + // warn users if lazy_update is turned on + if (dispatched && param.lazy_update) LogLazyUpdate(); + } + if (!dispatched) { + dispatched = dispatch_fallback(out_attrs, dispatch_mode); + } + return dispatched; +} + NNVM_REGISTER_OP(sgd_update) MXNET_ADD_SPARSE_OP_ALIAS(sgd_update) @@ -282,13 +320,13 @@ MXNET_ADD_SPARSE_OP_ALIAS(sgd_update) It updates the weights using:: - weight = weight - learning_rate * gradient + weight = weight - learning_rate * (gradient + wd * weight) -If weight is of ``row_sparse`` storage type, +However, if gradient is of ``row_sparse`` storage type and ``lazy_update`` is True, only the row slices whose indices appear in grad.indices are updated:: for row in gradient.indices: - weight[row] = weight[row] - learning_rate * gradient[row] + weight[row] = weight[row] - learning_rate * (gradient[row] + wd * weight[row]) )code" ADD_FILELINE) .set_num_inputs(2) @@ -296,7 +334,7 @@ only the row slices whose indices appear in grad.indices are updated:: .set_attr_parser(ParamParser<SGDParam>) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) -.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<2, 1, false, true, false>) +.set_attr<FInferStorageType>("FInferStorageType", SGDStorageType) .set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>) .set_attr<FComputeEx>("FComputeEx<cpu>", SGDUpdateEx<cpu>) .add_argument("weight", "NDArray-or-Symbol", "Weight") @@ -305,7 +343,7 @@ only the row slices whose indices appear in grad.indices are updated:: 
NNVM_REGISTER_OP(sgd_mom_update) MXNET_ADD_SPARSE_OP_ALIAS(sgd_mom_update) -.describe(R"code(Momentum update function for Stochastic Gradient Descent (SDG) optimizer. +.describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer. Momentum update has better convergence rates on neural networks. Mathematically it looks like below: @@ -323,10 +361,8 @@ It updates the weights using:: Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. -If weight and grad are both of ``row_sparse`` storage type and momentum is of ``default`` storage type, -standard update is applied. - -If weight, grad and momentum are all of ``row_sparse`` storage type, +However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage +type is the same as momentum's storage type, only the row slices whose indices appear in grad.indices are updated (for both weight and momentum):: for row in gradient.indices: @@ -339,7 +375,7 @@ only the row slices whose indices appear in grad.indices are updated (for both w .set_attr_parser(ParamParser<SGDMomParam>) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>) -.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 1>) +.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<1, SGDMomParam>) .set_attr<nnvm::FMutateInputs>("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { return std::vector<uint32_t>{2}; @@ -420,7 +456,7 @@ available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf. 
.add_argument("d", "NDArray-or-Symbol", "Internal state ``d_t``") .add_argument("v", "NDArray-or-Symbol", "Internal state ``v_t``") .add_argument("z", "NDArray-or-Symbol", "Internal state ``z_t``") -.add_arguments(AdamParam::__FIELDS__()); +.add_arguments(FTMLParam::__FIELDS__()); NNVM_REGISTER_OP(adam_update) MXNET_ADD_SPARSE_OP_ALIAS(adam_update) @@ -443,7 +479,8 @@ It updates the weights using:: v = beta2*v + (1-beta2)*(grad**2) w += - learning_rate * m / (sqrt(v) + epsilon) -If w, m and v are all of ``row_sparse`` storage type, +However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and the storage +type of weight is the same as those of m and v, only the row slices whose indices appear in grad.indices are updated (for w, m and v):: for row in grad.indices: @@ -461,7 +498,7 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; }) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>) -.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 2>) +.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, AdamParam>) .set_attr<nnvm::FMutateInputs>("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { return std::vector<uint32_t>{2, 3}; diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index d1dc31a31c5..90762f7620f 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -204,7 +204,6 @@ def update(self, index, weight, grad, state): def update_multi_precision(self, index, weight, grad, state): self.update(index, weight, grad, state) -@unittest.skip("Test fails intermittently. Temporarily disabled until fixed. 
Tracked at https://github.com/apache/incubator-mxnet/issues/9000") @with_seed() def test_sgd(): opt1 = PySGD @@ -233,16 +232,9 @@ def test_sgd(): continue compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) # test operator fallback on cpu - if (default_context() == mx.cpu()): - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, - g_stype='row_sparse') - if dtype != np.float16: - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2], - dtype, w_stype='csr', g_stype='csr') - # test optimizer with a big shape - big_shape = (54686454, 1) - kwarg = {'momentum': 0.9, 'wd': 0.05} - compare_optimizer(opt1(**kwarg), opt2(**kwarg), big_shape, np.float32) + if dtype != np.float16: + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2], + dtype, w_stype='csr', g_stype='csr') class PySparseSGD(mx.optimizer.Optimizer): """python reference implemenation of sgd""" @@ -337,9 +329,11 @@ def test_sparse_sgd(): kwarg.update(mp_option) compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, w_stype='row_sparse', g_stype='row_sparse') + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, + w_stype='default', g_stype='row_sparse') -@with_seed(0) +@with_seed() def test_std_sparse_sgd(): opt1 = PySGD opt2 = mx.optimizer.SGD @@ -360,6 +354,8 @@ def test_std_sparse_sgd(): kwarg.update(wd_option) compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape, dtype, w_stype='row_sparse', g_stype='row_sparse') + compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape, dtype, + w_stype='default', g_stype='row_sparse') class PyNAG(PySGD): @@ -543,7 +539,7 @@ def test_ftml(): class PyAdam(mx.optimizer.Optimizer): """python reference implemenation of adam""" def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, - decay_factor=(1 - 1e-8), lazy_update=False, **kwargs): + decay_factor=(1 - 1e-8), lazy_update=True, **kwargs): super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs) self.beta1 = 
beta1 self.beta2 = beta2 @@ -594,7 +590,7 @@ def update(self, index, weight, grad, state): for row in range(num_rows): # check row slices of all zeros all_zeros = mx.test_utils.almost_equal(grad[row].asnumpy(), np.zeros_like(grad[row].asnumpy())) - # skip zeros during sparse update + # skip zeros during lazy update if all_zeros and self.lazy_update: continue grad[row] = grad[row] * self.rescale_grad + wd * weight[row] @@ -635,15 +631,21 @@ def test_adam(): not kwarg['multi_precision'])): continue # atol 2e-5 needed to pass with seed 1248389097 - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, + compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(**kwarg), shape, dtype, rtol=1e-4, atol=2e-5) # atol 2e-5 needed to pass with seed 781809840 - compare_optimizer(opt1(lazy_update=True, **kwarg), opt2(**kwarg), shape, + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, w_stype='row_sparse', g_stype='row_sparse', rtol=1e-4, atol=2e-5) - compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape, + compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(lazy_update=False, **kwarg), shape, dtype, w_stype='row_sparse', g_stype='row_sparse', rtol=1e-4, atol=2e-5) + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, + dtype, w_stype='default', g_stype='row_sparse', + rtol=1e-4, atol=2e-5) + compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(lazy_update=False, **kwarg), shape, + dtype, w_stype='default', g_stype='row_sparse', + rtol=1e-4, atol=2e-5) # Signum class PySignum(mx.optimizer.Optimizer): ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services