ptrendx commented on a change in pull request #13346: Aggregate SGD
URL: https://github.com/apache/incubator-mxnet/pull/13346#discussion_r248392864
##########
File path: src/operator/optimizer_op.cc
##########
@@ -313,6 +315,209 @@ inline bool SGDStorageType(const nnvm::NodeAttrs& attrs,
return dispatched;
}
+NNVM_REGISTER_OP(multi_sgd_update)
+.describe(R"code(Update function for Stochastic Gradient Descent (SDG)
optimizer.
+
+It updates the weights using::
+
+ weight = weight - learning_rate * (gradient + wd * weight)
+
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+ const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
+ return static_cast<uint32_t>(param.num_weights * 2);
+ })
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+ const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
+ return static_cast<uint32_t>(param.num_weights);
+ })
+.set_attr_parser(ParamParser<MultiSGDParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", MultiSGDShape<MultiSGDParam, 2>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ uint32_t num_args = dmlc::get<MultiSGDParam>(attrs.parsed).num_weights;
+ std::vector<std::string> ret;
+ for (uint32_t i = 0; i < num_args; ++i) {
+ ret.push_back(std::string("weight_") + std::to_string(i));
+ ret.push_back(std::string("grad_") + std::to_string(i));
+ }
+ return ret;
+ })
+.set_attr<FCompute>("FCompute<cpu>", MultiSGDUpdate<cpu, type_identity, 2>)
+.add_argument("data", "NDArray-or-Symbol[]", "Weights")
+.add_arguments(MultiSGDParam::__FIELDS__());
+
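As a reference for the rule stated in the docstring above, here is a minimal plain-C++ sketch (not the MXNet kernel; the per-weight ``lrs``/``wds`` and ``rescale_grad`` names are assumed from ``MultiSGDParam``) of the aggregated update over the interleaved ``weight_i``/``grad_i`` inputs::

    #include <cstddef>
    #include <vector>

    // Aggregated SGD reference: inputs arrive interleaved as weight_0, grad_0,
    // weight_1, grad_1, ..., which is why the op declares num_weights * 2 inputs.
    void multi_sgd_reference(std::vector<std::vector<float>>* weights,
                             const std::vector<std::vector<float>>& grads,
                             const std::vector<float>& lrs,
                             const std::vector<float>& wds,
                             float rescale_grad = 1.0f) {
      for (std::size_t i = 0; i < weights->size(); ++i) {
        for (std::size_t j = 0; j < (*weights)[i].size(); ++j) {
          const float g = rescale_grad * grads[i][j];
          // weight = weight - learning_rate * (gradient + wd * weight)
          (*weights)[i][j] -= lrs[i] * (g + wds[i] * (*weights)[i][j]);
        }
      }
    }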
+NNVM_REGISTER_OP(multi_sgd_mom_update)
+.describe(R"code(Momentum update function for Stochastic Gradient Descent
(SGD) optimizer.
+
+Momentum update has better convergence rates on neural networks. Mathematically it looks
+like below:
+
+.. math::
+
 v_1 = -\alpha * \nabla J(W_0)\\
+ v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\
+ W_t = W_{t-1} + v_t
+
+It updates the weights using::
+
+ v = momentum * v - learning_rate * gradient
+ weight += v
+
+Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
+
+However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage
+type is the same as momentum's storage type,
+only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::
+
+ for row in gradient.indices:
+ v[row] = momentum[row] * v[row] - learning_rate * gradient[row]
+ weight[row] += v[row]
+
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+ const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
+ return static_cast<uint32_t>(param.num_weights * 3);
+ })
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+ const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
+ return static_cast<uint32_t>(param.num_weights);
+ })
+.set_attr_parser(ParamParser<MultiSGDMomParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", MultiSGDShape<MultiSGDMomParam, 3>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+    uint32_t num_args = dmlc::get<MultiSGDMomParam>(attrs.parsed).num_weights;
+ std::vector<std::string> ret;
+ for (uint32_t i = 0; i < num_args; ++i) {
+ ret.push_back(std::string("weight_") + std::to_string(i));
+ ret.push_back(std::string("grad_") + std::to_string(i));
+ ret.push_back(std::string("mom_") + std::to_string(i));
+ }
+ return ret;
+ })
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+ [](const nnvm::NodeAttrs& attrs) {
+ std::vector<uint32_t> ret;
+ const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
+ for (int i = 0; i < param.num_weights; ++i) {
+ ret.push_back(i * 3 + 2);
+ }
+ return ret;
+ })
+.set_attr<FCompute>("FCompute<cpu>", MultiSGDMomUpdate<cpu, type_identity, 3>)
+.add_argument("data", "NDArray-or-Symbol[]", "Weights, gradients and momentum")
+.add_arguments(MultiSGDMomParam::__FIELDS__());
+
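Analogously, a hedged plain-C++ sketch of the momentum rule applied per (weight, grad, mom) triple; the ``mom_i`` inputs sit at indices ``i * 3 + 2``, which is what the FMutateInputs list above marks as updated in place (folding the weight decay into the gradient is an assumption carried over from the plain SGD rule)::

    #include <cstddef>
    #include <vector>

    // Aggregated SGD-with-momentum reference: inputs are weight_0, grad_0, mom_0,
    // weight_1, ... (num_weights * 3 in total); each mom_i is mutated in place.
    void multi_sgd_mom_reference(std::vector<std::vector<float>>* weights,
                                 std::vector<std::vector<float>>* moms,
                                 const std::vector<std::vector<float>>& grads,
                                 const std::vector<float>& lrs,
                                 const std::vector<float>& wds,
                                 float momentum, float rescale_grad = 1.0f) {
      for (std::size_t i = 0; i < weights->size(); ++i) {
        for (std::size_t j = 0; j < (*weights)[i].size(); ++j) {
          const float g = rescale_grad * grads[i][j] + wds[i] * (*weights)[i][j];
          // v = momentum * v - learning_rate * gradient; weight += v
          (*moms)[i][j] = momentum * (*moms)[i][j] - lrs[i] * g;
          (*weights)[i][j] += (*moms)[i][j];
        }
      }
    }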
+NNVM_REGISTER_OP(multi_mp_sgd_update)
+.describe(R"code(Update function for multi-precision Stochastic Gradient
Descent (SDG) optimizer.
+
+It updates the weights using::
+
+ weight = weight - learning_rate * (gradient + wd * weight)
+
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+ const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
+ return static_cast<uint32_t>(param.num_weights * 3);
+ })
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+ const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
+ return static_cast<uint32_t>(param.num_weights);
+ })
+.set_attr_parser(ParamParser<MultiSGDParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", MultiSGDShape<MultiSGDParam, 3>)
+.set_attr<nnvm::FInferType>("FInferType", MP_MultiSGD_InferType<MultiSGDParam,
3, 1>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ uint32_t num_args = dmlc::get<MultiSGDParam>(attrs.parsed).num_weights;
+ std::vector<std::string> ret;
+ for (uint32_t i = 0; i < num_args; ++i) {
+ ret.push_back(std::string("weight_") + std::to_string(i));
+ ret.push_back(std::string("grad_") + std::to_string(i));
+ ret.push_back(std::string("weight32_") + std::to_string(i));
+ }
+ return ret;
+ })
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+ [](const nnvm::NodeAttrs& attrs) {
+ std::vector<uint32_t> ret;
+ const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
+ for (int i = 0; i < param.num_weights; ++i) {
+ ret.push_back(i * 3 + 2);
+ }
+ return ret;
+ })
+.set_attr<FCompute>("FCompute<cpu>", MultiSGDUpdate<cpu, single_precision, 3>)
+.add_argument("data", "NDArray-or-Symbol[]", "Weights")
+.add_arguments(MultiSGDParam::__FIELDS__());
+
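And a sketch of the multi-precision variant registered above, assuming the usual low-precision-weight / float32-master-copy scheme implied by the ``weight32_i`` inputs (a plain ``float`` stands in for a real float16 type; not the actual kernel)::

    #include <cstddef>
    #include <vector>

    using half_t = float;  // stand-in for a real float16 type

    // Multi-precision reference: inputs are weight_0, grad_0, weight32_0, ...
    // (num_weights * 3); the fp32 master copy at index i * 3 + 2 is the one
    // mutated, then cast back into the low-precision weight.
    void multi_mp_sgd_reference(std::vector<std::vector<half_t>>* weights,
                                std::vector<std::vector<float>>* weights32,
                                const std::vector<std::vector<half_t>>& grads,
                                const std::vector<float>& lrs,
                                const std::vector<float>& wds,
                                float rescale_grad = 1.0f) {
      for (std::size_t i = 0; i < weights->size(); ++i) {
        for (std::size_t j = 0; j < (*weights)[i].size(); ++j) {
          const float g = rescale_grad * static_cast<float>(grads[i][j]);
          float& w32 = (*weights32)[i][j];
          w32 -= lrs[i] * (g + wds[i] * w32);           // update in float32
          (*weights)[i][j] = static_cast<half_t>(w32);  // write back low precision
        }
      }
    }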
+NNVM_REGISTER_OP(multi_mp_sgd_mom_update)
+.describe(R"code(Momentum update function for multi-precision Stochastic
Gradient Descent (SGD) optimizer.
+
+Momentum update has better convergence rates on neural networks. Mathematically it looks
+like below:
+
+.. math::
+
 v_1 = -\alpha * \nabla J(W_0)\\
+ v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\
+ W_t = W_{t-1} + v_t
+
+It updates the weights using::
+
+ v = momentum * v - learning_rate * gradient
+ weight += v
+
+Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
+
+However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage
Review comment:
Done