ptrendx commented on a change in pull request #13346: Aggregate SGD
URL: https://github.com/apache/incubator-mxnet/pull/13346#discussion_r244893032
##########
File path: src/operator/optimizer_op-inl.h
##########
@@ -82,6 +82,301 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
}
};
+struct MultiSGDParam : public dmlc::Parameter<MultiSGDParam> {
+ nnvm::Tuple<float> lrs;
+ nnvm::Tuple<float> wds;
+ float rescale_grad;
+ float clip_gradient;
+ int num_weights;
+ DMLC_DECLARE_PARAMETER(MultiSGDParam) {
+ DMLC_DECLARE_FIELD(lrs)
+ .describe("Learning rates.");
+ DMLC_DECLARE_FIELD(wds)
+ .describe("Weight decay augments the objective function with a "
+ "regularization term that penalizes large weights. "
+ "The penalty scales with the square of the magnitude of each
weight.");
+ DMLC_DECLARE_FIELD(rescale_grad)
+ .set_default(1.0f)
+ .describe("Rescale gradient to grad = rescale_grad*grad.");
+ DMLC_DECLARE_FIELD(clip_gradient)
+ .set_default(-1.0f)
+ .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+ "If clip_gradient <= 0, gradient clipping is turned off. "
+ "grad = max(min(grad, clip_gradient), -clip_gradient).");
+ DMLC_DECLARE_FIELD(num_weights)
+ .set_default(1)
+ .describe("Number of updated weights.");
+ }
+};
+
+struct MultiSGDMomParam : public dmlc::Parameter<MultiSGDMomParam> {
+ nnvm::Tuple<float> lrs;
+ nnvm::Tuple<float> wds;
+ float momentum;
+ float rescale_grad;
+ float clip_gradient;
+ int num_weights;
+ DMLC_DECLARE_PARAMETER(MultiSGDMomParam) {
+ DMLC_DECLARE_FIELD(lrs)
+ .describe("Learning rates.");
+ DMLC_DECLARE_FIELD(wds)
+ .describe("Weight decay augments the objective function with a "
+ "regularization term that penalizes large weights. "
+ "The penalty scales with the square of the magnitude of each
weight.");
+ DMLC_DECLARE_FIELD(momentum)
+ .set_default(0.0f)
+ .describe("The decay rate of momentum estimates at each epoch.");
+ DMLC_DECLARE_FIELD(rescale_grad)
+ .set_default(1.0f)
+ .describe("Rescale gradient to grad = rescale_grad*grad.");
+ DMLC_DECLARE_FIELD(clip_gradient)
+ .set_default(-1.0f)
+ .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+ "If clip_gradient <= 0, gradient clipping is turned off. "
+ "grad = max(min(grad, clip_gradient), -clip_gradient).");
+ DMLC_DECLARE_FIELD(num_weights)
+ .set_default(1)
+ .describe("Number of updated weights.");
+ }
+};
+
+template<typename ParamType, int input_stride>
+inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs,
+ std::vector<TShape> *in_attrs,
+ std::vector<TShape> *out_attrs) {
+ const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), input_stride * param.num_weights);
+ CHECK_EQ(out_attrs->size(), param.num_weights);
+
+ bool all_inferred = true;
+ auto& input_shapes = *in_attrs;
+ auto& output_shapes = *out_attrs;
+ // Learning rates
+ CHECK_EQ(param.lrs.ndim(), param.num_weights)
+ << "Number of learning rates is inconsistent with num_weights "
+ << "parameter passed. Expected number of learning rates: "
+ << param.num_weights << ", and got " << param.lrs.ndim();
+ // Weight decays
+ CHECK_EQ(param.wds.ndim(), param.num_weights)
+ << "Number of weight decays is inconsistent with num_weights "
+ << "parameter passed. Expected number of weight decays: "
+ << param.num_weights << ", and got " << param.wds.ndim();
+ // Weights and gradients
+ for (int i = 0; i < param.num_weights; ++i) {
+ std::vector<TShape> input_vec;
+ std::vector<TShape> output_vec({output_shapes[i]});
+ for (int j = 0; j < input_stride; ++j) {
+ input_vec.push_back(input_shapes[i * input_stride + j]);
+ }
+ all_inferred = all_inferred && ElemwiseShape<input_stride, 1>(attrs, &input_vec, &output_vec);
+ }
+ return all_inferred;
+}
+
+template <typename ParamType, int input_stride, int num_fp32_inputs>
+inline bool MP_MultiSGD_InferType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), input_stride * param.num_weights);
+ CHECK_EQ(out_attrs->size(), param.num_weights);
+
+ bool all_inferred = true;
+ auto& input_types = *in_attrs;
+ auto& output_types = *out_attrs;
+ // Weights and gradients
+ for (int i = 0; i < param.num_weights; ++i) {
+ std::vector<int> input_vec;
+ std::vector<int> output_vec({output_types[i]});
+ for (int j = 0; j < input_stride - num_fp32_inputs; ++j) {
+ input_vec.push_back(input_types[i * input_stride + j]);
+ }
+ all_inferred = all_inferred &&
+ ElemwiseType<input_stride - num_fp32_inputs, 1>(attrs, &input_vec, &output_vec);
+ }
+ // master copies of weights
+ for (int i = 0; i < param.num_weights; ++i) {
+ for (int j = 0; j < num_fp32_inputs; ++j) {
+ TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1 - j, mshadow::kFloat32);
+ }
+ }
+ return all_inferred;
+}
+
+template<typename DType, typename MPDType>
+struct MultiSGDKernelParam {
+ static const int N = 60;
Review comment:
@anirudhacharya This is the reason for 60: I pass this struct as a kernel
parameter, and kernel parameters are limited to 4 kB.
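To make the 4 kB budget concrete, here is a rough sizing sketch. The field
layout of ExampleKernelParam below is hypothetical (modeled on the
per-tensor state an aggregated SGD update needs, not copied from this PR),
but it shows why an N of about 60 is close to the largest that fits in a
by-value kernel parameter:

    #include <cstddef>

    // Hypothetical kernel-parameter struct for an aggregated SGD update:
    // per-tensor device pointers plus per-tensor scalars.
    template <typename DType, typename MPDType>
    struct ExampleKernelParam {
      static const int N = 60;   // max tensors updated per kernel launch
      int count;                 // tensors actually used (<= N)
      std::size_t sizes[N];      // element counts:       60 * 8 B =  480 B
      DType* weights[N];         // 4 pointer arrays: 4 * 60 * 8 B = 1920 B
      DType* grads[N];
      MPDType* mom[N];
      MPDType* weights32[N];     // fp32 master copies (mixed precision)
      float lrs[N];              // 2 float arrays:   2 * 60 * 4 B =  480 B
      float wds[N];
    };

    // ~2.9 kB of arrays plus a few scalars fits under the 4 kB limit on
    // kernel parameters; N = 100 would already need ~4.8 kB and overflow.
    static_assert(sizeof(ExampleKernelParam<float, float>) < 4096,
                  "must fit in the 4 kB kernel-parameter space");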
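Relatedly, the input_stride indexing that MultiSGDShape and
MP_MultiSGD_InferType walk over can be shown with a small standalone
snippet. The role names below (weight, grad, mom, weight32) are my reading
of a mixed-precision momentum update with input_stride = 4 and
num_fp32_inputs = 1, not names taken from this PR:

    #include <cstdio>

    int main() {
      // Tensor i's j-th input lives at flat index i * input_stride + j;
      // the trailing num_fp32_inputs slots of each group are the fp32
      // master copies that MP_MultiSGD_InferType forces to kFloat32.
      const int num_weights = 2, input_stride = 4, num_fp32_inputs = 1;
      const char* roles[4] = {"weight", "grad", "mom", "weight32"};
      for (int i = 0; i < num_weights; ++i) {
        for (int j = 0; j < input_stride; ++j) {
          const bool fp32 = j >= input_stride - num_fp32_inputs;
          std::printf("input[%d] = %s of tensor %d%s\n",
                      i * input_stride + j, roles[j], i,
                      fp32 ? " (forced to kFloat32)" : "");
        }
      }
      return 0;
    }

Running it prints the eight flat input slots for two tensors, which makes
the stride-of-4 grouping that the inference loops assume visible.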