drivanov commented on a change in pull request #16398: Aggregated adamw update
URL: https://github.com/apache/incubator-mxnet/pull/16398#discussion_r332776832
##########
File path: src/operator/contrib/adamw-inl.h
##########
@@ -211,6 +192,314 @@ struct AdamWUpdate {
}
};
+////
+// Multiple gradients in single kernel
+////
+
+struct MultiAdamWParam : public dmlc::Parameter<MultiAdamWParam> {
+ mxnet::Tuple<float> lrs;
+ mxnet::Tuple<float> wds;
+ mxnet::Tuple<float> etas;
+ float beta1;
+ float beta2;
+ float epsilon;
+ float rescale_grad;
+ float clip_gradient;
+ int num_weights;
+ DMLC_DECLARE_PARAMETER(MultiAdamWParam) {
+ DMLC_DECLARE_FIELD(lrs)
+ .describe("Learning rates");
+ DMLC_DECLARE_FIELD(beta1)
+ .set_default(0.9f)
+ .describe("The decay rate for the 1st moment estimates.");
+ DMLC_DECLARE_FIELD(beta2)
+ .set_default(0.999f)
+ .describe("The decay rate for the 2nd moment estimates.");
+ DMLC_DECLARE_FIELD(epsilon)
+ .set_default(1e-8f)
+ .describe("A small constant for numerical stability.");
+ DMLC_DECLARE_FIELD(wds)
+ .describe("Weight decay augments the objective function with a "
+ "regularization term that penalizes large weights. "
+ "The penalty scales with the square of the magnitude of each
weight.");
+ DMLC_DECLARE_FIELD(etas)
+ .describe("Learning rates schedule multiplier");
+ DMLC_DECLARE_FIELD(clip_gradient)
+ .set_default(-1.0f)
+ .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+ "If clip_gradient <= 0, gradient clipping is turned off. "
+ "grad = max(min(grad, clip_gradient), -clip_gradient).");
+ DMLC_DECLARE_FIELD(num_weights)
+ .set_default(1)
+ .describe("Number of updated weights.");
+ }
+};
+
+
+template<typename ParamType, int input_stride>
+inline bool MP_MultiAdamW_InferShape(const nnvm::NodeAttrs& attrs,
+ mxnet::ShapeVector *in_attrs,
+ mxnet::ShapeVector *out_attrs) {
+ const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), input_stride * param.num_weights +1);
+ CHECK_EQ(out_attrs->size(), param.num_weights);
+
+ bool all_inferred = true;
+ auto& input_shapes = *in_attrs;
+ auto& output_shapes = *out_attrs;
+
+ // Learning rates
+ CHECK_EQ(param.lrs.ndim(), param.num_weights)
+ << "Number of learning rates is inconsistent with num_weights "
+ << "parameter passed. Expected number of learning rates: "
+ << param.num_weights << ", and got " << param.lrs.ndim();
+ // Weight decays
+ CHECK_EQ(param.wds.ndim(), param.num_weights)
+ << "Number of weight decays is inconsistent with num_weights "
+ << "parameter passed. Expected number of weight decays: "
+ << param.num_weights << ", and got " << param.wds.ndim();
+ // Learning rates schedule multiplier
+ CHECK_EQ(param.etas.ndim(), param.num_weights)
+ << "Number of learning rates schedule multiplier is inconsistent with
num_weights "
+ << "parameter passed. Expected number of learning rates schedule
multiplier: "
+ << param.num_weights << ", and got " << param.lrs.ndim();
+
+ // Weights, gradients, mean and variance
+ for (int i = 0; i < param.num_weights; ++i) {
+ mxnet::ShapeVector input_vec;
+ mxnet::ShapeVector output_vec({output_shapes[i]});
+ for (int j = 0; j < input_stride; ++j) {
+ input_vec.push_back(input_shapes[i * input_stride + j]);
+ }
+ all_inferred = all_inferred && ElemwiseShape<input_stride, 1>(attrs,
&input_vec, &output_vec);
+ }
+ // rescale_grad.shape = ()
+ SHAPE_ASSIGN_CHECK(*in_attrs, param.num_weights*input_stride,
mxnet::TShape());
+ return all_inferred;
+}
+
+template <typename ParamType, int input_stride, int num_fp32_inputs>
+inline bool MP_MultiAdamW_InferType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), input_stride * param.num_weights +1);
+ CHECK_EQ(out_attrs->size(), param.num_weights);
+
+ bool all_inferred = true;
+ auto& input_types = *in_attrs;
+ auto& output_types = *out_attrs;
+
+ // Weights, gradients,
+ for (int i = 0; i < param.num_weights; ++i) {
+ std::vector<int> input_vec;
+ std::vector<int> output_vec({output_types[i]});
+ for (int j = 0; j < input_stride - 2 - num_fp32_inputs; ++j) {
+ input_vec.push_back(input_types[i * input_stride + j]);
+ }
+ all_inferred = all_inferred &&
+ ElemwiseType<input_stride - 2 - num_fp32_inputs, 1>(attrs,
&input_vec, &output_vec);
+ }
+ // mean, var
+ for (int i = 0; i < param.num_weights; ++i) {
+ TYPE_ASSIGN_CHECK(input_types, input_stride * i +2, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(input_types, input_stride * i +3, mshadow::kFloat32);
+ }
+
+ // master copies of weights
+ for (int i = 0; i < param.num_weights; ++i) {
+ for (int j = 0; j < num_fp32_inputs; ++j) {
+ TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1 - j,
mshadow::kFloat32);
+ }
+ }
+ // rescale_grad.type = ()
+ TYPE_ASSIGN_CHECK(input_types, param.num_weights*input_stride,
mshadow::kFloat32);
+ return all_inferred;
+}
+
+
/// Metafunction mapping a data type to itself; selects the
/// non-mixed-precision code path (optimizer state uses DType directly).
template<typename T>
struct Adam_type_identity {
  using type = T;
};
+
+
/// Metafunction mapping any data type to float; selects the
/// mixed-precision code path (optimizer state kept in single precision).
template<typename T>
struct Adam_single_precision {
  using type = float;
};
+
/// Plain-old-data argument bundle for the aggregated AdamW kernel:
/// fixed-capacity arrays of pointers and per-tensor hyper-parameters,
/// passed by value so a single launch can update up to N tensors.
template<typename DType, typename MPDType>
struct MultiAdamKernelParam {
  static const int N = 50;       // maximum number of tensors per launch
  int count;                     // number of tensors actually used (<= N)
  size_t max_size;               // largest element count among the tensors
  size_t sizes[N];               // element count of each tensor
  DType* weights[N];             // current weights
  DType* grad_data[N];           // gradients
  MPDType* mean_data[N];         // 1st-moment (mean) state
  MPDType* var_data[N];          // 2nd-moment (variance) state
  MPDType* weights32[N];         // fp32 master weights (mixed precision only)
  DType* out_data[N];            // output weights
  MPDType clip_gradient;         // gradient clipping threshold (<0 disables)
  MPDType beta1;                 // 1st-moment decay rate
  MPDType beta2;                 // 2nd-moment decay rate
  MPDType etas[N];               // per-tensor schedule multipliers
  MPDType lrs[N];                // per-tensor learning rates
  MPDType wds[N];                // per-tensor weight decays
  MPDType epsilon;               // numerical-stability constant
};
+
+template<typename MPDType, bool has_mixed_precision>
+struct MultiMPAdamWKernel {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i, const MultiAdamKernelParam<DType,
MPDType>& param,
+ const OpReqType req, const float
rescale_grad){
+ for (int index = 0; index < param.count; ++index) {
+ if ((size_t)i < param.sizes[index]) {
+ MPDType w = has_mixed_precision ? param.weights32[index][i]:
+ MPDType(param.weights[index][i]);
+ MPDType scaled_grad = static_cast<MPDType>(rescale_grad)*
+ static_cast<MPDType>(param.grad_data[index][i]);
+
+ if (param.clip_gradient >= 0.0f)
+ scaled_grad = mshadow_op::clip::Map(scaled_grad,
param.clip_gradient);
+
+ const auto mean = param.beta1 * (param.mean_data[index][i]-
scaled_grad) + scaled_grad;
+ const auto adj = mshadow_op::square::Map(scaled_grad);
+ const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj;
+
+ param.mean_data[index][i] = mean;
+ param.var_data[index][i] = var;
+ w = w - param.etas[index] * (param.lrs[index] *
+ mean / (mshadow_op::square_root::Map(var) + param.epsilon)
+ + param.wds[index] * w);
+ if (has_mixed_precision)
+ param.weights32[index][i] = w;
+
+ KERNEL_ASSIGN(param.out_data[index][i], req, w);
+ }
+ }
+ }
+};
+
+template<typename xpu,
+ typename DType,
+ typename MPDType,
+ typename ParamType = MultiAdamWParam,
+ int input_stride = 4>
+void FillMultiAdamKernelParam(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<TBlob> &outputs,
+ MultiAdamKernelParam<DType, MPDType> *pParam) {
+ const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
+ mxnet_op::Stream<xpu>* s = ctx.get_stream<xpu>();
+ pParam->clip_gradient = p.clip_gradient;
+ pParam->beta1 = p.beta1;
+ pParam->beta2 = p.beta2;
+
+ pParam->epsilon = p.epsilon;
+
+ pParam->count = p.num_weights;
+ pParam->max_size = 0;
+ const bool isSame = std::is_same<DType, MPDType>::value;
Review comment:
I will make this change.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services