yuxiangw commented on a change in pull request #9220: Signum optimizer
URL: https://github.com/apache/incubator-mxnet/pull/9220#discussion_r159722035
 
 

 ##########
 File path: src/operator/optimizer_op.cc
 ##########
 @@ -35,6 +35,74 @@ DMLC_REGISTER_PARAMETER(AdamParam);
 DMLC_REGISTER_PARAMETER(RMSPropParam);
 DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
 DMLC_REGISTER_PARAMETER(FtrlParam);
+DMLC_REGISTER_PARAMETER(SignSGDParam);
+DMLC_REGISTER_PARAMETER(SignumParam);
+
+NNVM_REGISTER_OP(signsgd_update)
+// MXNET_ADD_SPARSE_OP_ALIAS(signsgd_update)
+.describe(R"code(Update function for SignSGDoptimizer.
+It updates the weights using::
+
+ weight = weight - learning_rate * sign(gradient)
+
+
+** Sparse matrix not supported for this optimizer yet.
+
+If weight is of ``row_sparse`` storage type,
+only the row slices whose indices appear in grad.indices are updated::
+
+ for row in gradient.indices:
+     weight[row] = weight[row] - learning_rate * gradient[row]
+
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SignSGDParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<2, 1, false, true, false>)
+.set_attr<FCompute>("FCompute<cpu>", SignSGDUpdate<cpu>)
+// .set_attr<FComputeEx>("FComputeEx<cpu>", SignSGDUpdateEx<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_arguments(SignSGDParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(signum_update)
+// MXNET_ADD_SPARSE_OP_ALIAS(signum_update)
+.describe(R"code(SIGN momentUM (Signum) optimizer.
+
+ weight = weight - learning_rate * sign(momentum)
+
+Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
+
+** Sparse matrix not supported for this optimizer yet.
+
+If weight and momentum are both of ``row_sparse`` storage type,
+only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::
+
+  for row in gradient.indices:
+      v[row] = momentum[row] * v[row] - learning_rate * gradient[row]
+      weight[row] += v[row]
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SignumParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<3, 1, false, true, false>)
 
 Review comment:
   Done.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to