eric-haibin-lin commented on a change in pull request #14568: NAG Optimizer with multi-precision support
URL: https://github.com/apache/incubator-mxnet/pull/14568#discussion_r275141006
##########
File path: src/operator/optimizer_op.cc
##########
@@ -705,6 +707,98 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a
.add_arguments(AdamParam::__FIELDS__());
+NNVM_REGISTER_OP(nag_update)
+MXNET_ADD_SPARSE_OP_ALIAS(nag_update)
+.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG)
optimizer.
+It updates the weights using the following formula,
+
+weight = weight - (lr * (grad + wd * weight))
+
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NAGParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FCompute>("FCompute<cpu>", NAGUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_arguments(NAGParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(nag_mom_update)
+MXNET_ADD_SPARSE_OP_ALIAS(nag_mom_update)
+.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG)
optimizer.
+It updates the weights using the following formula,
+
+.. math::
+ v_t = \gamma v_{t-1} + \eta \nabla J(W_{t-1} - \gamma v_{t-1})\\
+ W_t = W_{t-1} - v_t
+
+where
+:math:`\eta` is the learning rate of the optimizer,
+:math:`\gamma` is the decay rate of the momentum estimate,
+:math:`v_t` is the update vector at time step `t`, and
+:math:`W_t` is the weight vector at time step `t`.
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NAGMomParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+ [](const nnvm::NodeAttrs& attrs) {
+ return std::vector<uint32_t>{2};
+ })
+.set_attr<FCompute>("FCompute<cpu>", NAGMomUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mom", "NDArray-or-Symbol", "Momentum")
+.add_arguments(NAGMomParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(mp_nag_update)
+MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_update)
Review comment:
remove MXNET_ADD_SPARSE_OP_ALIAS
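
As a reading aid, here is a minimal standalone C++ sketch of the two update rules described in the docstrings above. It is an illustration only, not the MXNet kernels: the names nag_step, nag_mom_step, lr, gamma and wd are placeholders, grad is assumed to already be the gradient evaluated at the look-ahead point W_{t-1} - gamma * v_{t-1}, and folding wd * weight into the gradient of the momentum variant is an assumption borrowed from the nag_update formula (the docstring math omits weight decay).

#include <cstddef>
#include <vector>

// nag_update rule from the docstring:
//   weight = weight - lr * (grad + wd * weight)
void nag_step(std::vector<float>* weight,
              const std::vector<float>& grad,
              float lr, float wd) {
  for (std::size_t i = 0; i < weight->size(); ++i) {
    (*weight)[i] -= lr * (grad[i] + wd * (*weight)[i]);
  }
}

// nag_mom_update rule from the docstring:
//   v_t = gamma * v_{t-1} + lr * grad   (grad taken at the look-ahead point)
//   W_t = W_{t-1} - v_t
// Weight decay is folded into the gradient here, mirroring nag_update; this
// is an assumption, since the docstring math itself does not show wd.
void nag_mom_step(std::vector<float>* weight,
                  const std::vector<float>& grad,
                  std::vector<float>* mom,
                  float lr, float gamma, float wd) {
  for (std::size_t i = 0; i < weight->size(); ++i) {
    (*mom)[i] = gamma * (*mom)[i] + lr * (grad[i] + wd * (*weight)[i]);
    (*weight)[i] -= (*mom)[i];
  }
}

In the operators registered above, the same arithmetic is carried out over NDArrays by NAGUpdate<cpu> and NAGMomUpdate<cpu> rather than over std::vector.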
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services