ptrendx commented on a change in pull request #13346: Aggregate SGD
URL: https://github.com/apache/incubator-mxnet/pull/13346#discussion_r248392897
 
 

 ##########
 File path: src/operator/optimizer_op.cc
 ##########
 @@ -313,6 +315,209 @@ inline bool SGDStorageType(const nnvm::NodeAttrs& attrs,
   return dispatched;
 }
 
+NNVM_REGISTER_OP(multi_sgd_update)
+.describe(R"code(Update function for Stochastic Gradient Descent (SDG) 
optimizer.
+
+It updates the weights using::
+
+ weight = weight - learning_rate * (gradient + wd * weight)
+
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
+    return static_cast<uint32_t>(param.num_weights * 2);
+  })
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+    const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
+    return static_cast<uint32_t>(param.num_weights);
+  })
+.set_attr_parser(ParamParser<MultiSGDParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", MultiSGDShape<MultiSGDParam, 2>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    uint32_t num_args = dmlc::get<MultiSGDParam>(attrs.parsed).num_weights;
+    std::vector<std::string> ret;
+    for (uint32_t i = 0; i < num_args; ++i) {
+      ret.push_back(std::string("weight_") + std::to_string(i));
+      ret.push_back(std::string("grad_") + std::to_string(i));
+    }
+    return ret;
+  })
+.set_attr<FCompute>("FCompute<cpu>", MultiSGDUpdate<cpu, type_identity, 2>)
+.add_argument("data", "NDArray-or-Symbol[]", "Weights")
+.add_arguments(MultiSGDParam::__FIELDS__());
+
+NNVM_REGISTER_OP(multi_sgd_mom_update)
+.describe(R"code(Momentum update function for Stochastic Gradient Descent 
(SGD) optimizer.
+
+Momentum update has better convergence rates on neural networks. Mathematically it looks like below:
+
+.. math::
+
+  v_1 = -\alpha * \nabla J(W_0)\\
+  v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\
+  W_t = W_{t-1} + v_t
+
+It updates the weights using::
+
+  v = momentum * v - learning_rate * gradient
+  weight += v
+
+Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
+
+However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage
 
 Review comment:
   Done
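For reference, a minimal standalone sketch (a hypothetical helper, not the operator's actual kernel) of the aggregated update described in the `multi_sgd_update` docstring above, assuming dense float tensors with one learning rate and one weight decay per weight:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference helper: applies the documented update
//   weight = weight - learning_rate * (gradient + wd * weight)
// independently to each (weight, grad) pair, mirroring how
// multi_sgd_update accepts all tensors in a single call.
void MultiSgdReference(std::vector<std::vector<float>>* weights,
                       const std::vector<std::vector<float>>& grads,
                       const std::vector<float>& lrs,
                       const std::vector<float>& wds) {
  for (std::size_t i = 0; i < weights->size(); ++i) {
    std::vector<float>& w = (*weights)[i];
    const std::vector<float>& g = grads[i];
    for (std::size_t j = 0; j < w.size(); ++j) {
      w[j] -= lrs[i] * (g[j] + wds[i] * w[j]);
    }
  }
}
```

Folding many per-tensor updates into one operator call is presumably the point of the aggregation: the per-element math is unchanged, only the batching over weights differs.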

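Similarly, a hedged sketch of the momentum variant documented for `multi_sgd_mom_update`, assuming a single scalar ``momentum`` shared across weights and one momentum-state tensor per weight (again a hypothetical helper, not the operator's kernel):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference helper: the documented momentum update
//   v      = momentum * v - learning_rate * gradient
//   weight = weight + v
// applied per (weight, grad, momentum-state) triple, as
// multi_sgd_mom_update does for its aggregated inputs.
void MultiSgdMomReference(std::vector<std::vector<float>>* weights,
                          std::vector<std::vector<float>>* moms,
                          const std::vector<std::vector<float>>& grads,
                          const std::vector<float>& lrs,
                          float momentum) {
  for (std::size_t i = 0; i < weights->size(); ++i) {
    std::vector<float>& w = (*weights)[i];
    std::vector<float>& v = (*moms)[i];
    const std::vector<float>& g = grads[i];
    for (std::size_t j = 0; j < w.size(); ++j) {
      v[j] = momentum * v[j] - lrs[i] * g[j];
      w[j] += v[j];
    }
  }
}
```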
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
