access2rohit commented on a change in pull request #16885: Multi Precision Lamb Update operator
URL: https://github.com/apache/incubator-mxnet/pull/16885#discussion_r352933498
##########
File path: src/operator/optimizer_op-inl.h
##########
@@ -1749,6 +1749,164 @@ inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
});
}
+template<int n_in, int n_out, int total_in>
+inline bool MPLambPhaseOneType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
+ CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
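+ // The first n_in inputs must be FP16; the remaining inputs and all outputs must be FP32.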
+ for (int i = 0; i < n_in; ++i) {
+ TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat16);
+ }
+ for (int i = n_in; i < total_in; ++i) {
+ TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
+ }
+ for (int i = 0; i < n_out; ++i) {
+ TYPE_ASSIGN_CHECK(*out_attrs, i, mshadow::kFloat32);
+ }
+ return true;
+}
+
+struct MPLambUpdatePhaseOneKernel {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i, float* out_data,
+ float* mean_data, float* var_data, const DType* weight_data,
+ const DType* grad_data, const float* weight32_data,
+ const float clip_gradient, const float rescale_grad,
+ const float beta1_t, const float beta1,
+ const float beta2_t, const float beta2,
+ const float wd, const float epsilon, const int t,
+ bool bias_correction, const OpReqType req) {
+ using namespace mshadow_op;
+
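+ // Rescale the incoming gradient and clip it if clip_gradient is non-negative.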
+ float grad_rescaled = grad_data[i] * rescale_grad;
+ if (clip_gradient >= 0.f) {
+ grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+ }
+
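+ // Update the biased first (mean) and second (variance) moment estimates.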
+ mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+ var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
+
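+ // Default update direction: m / (sqrt(v) + eps) + wd * master_weight.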
+ float g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight32_data[i];
+
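+ // With bias correction enabled, recompute g from the bias-corrected estimates.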
+ if (bias_correction) {
+ float mean_hat = mean_data[i] / (1.f - beta1_t);
+ float var_hat = var_data[i] / (1.f - beta2_t);
+ g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight32_data[i];
+ }
+ KERNEL_ASSIGN(out_data[i], req, g);
+ }
+};
+
+template<typename xpu>
+inline void MPLambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &outputs) {
+ using namespace mxnet_op;
+ const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
+ Stream<xpu>* s = ctx.get_stream<xpu>();
+ MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ float beta1_t = std::pow(param.beta1, param.t);
+ float beta2_t = std::pow(param.beta2, param.t);
+ Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
+ Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
+ Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
+ Tensor<xpu, 2, float> out = outputs[0].FlatTo2D<xpu, float>(s);
+
+ Kernel<MPLambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
+ out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_,
+ param.clip_gradient, param.rescale_grad, beta1_t, param.beta1, beta2_t, param.beta2,
+ param.wd, param.epsilon, static_cast<int>(param.t), param.bias_correction, req[0]);
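
For reference, here is a minimal standalone sketch of the per-element math the
kernel implements (the helper name mp_lamb_phase_one is hypothetical, and clip
is assumed to clamp to [-clip_gradient, clip_gradient], matching
mshadow_op::clip):

    #include <algorithm>
    #include <cmath>

    // Phase-one LAMB update for one element, using the FP32 master weight.
    float mp_lamb_phase_one(float grad, float* mean, float* var, float weight32,
                            float rescale_grad, float clip_gradient,
                            float beta1, float beta2,
                            float beta1_t, float beta2_t,
                            float wd, float epsilon, bool bias_correction) {
      float g = grad * rescale_grad;
      if (clip_gradient >= 0.f)
        g = std::max(std::min(g, clip_gradient), -clip_gradient);
      *mean = beta1 * (*mean) + (1.f - beta1) * g;      // first moment
      *var  = beta2 * (*var)  + (1.f - beta2) * g * g;  // second moment
      float m = *mean, v = *var;
      if (bias_correction) {                            // Adam-style correction
        m /= (1.f - beta1_t);
        v /= (1.f - beta2_t);
      }
      return m / (std::sqrt(v) + epsilon) + wd * weight32;
    }

Note that the returned update direction stays in FP32, matching the type
inference above: only weight and grad are FP16, while mean, var, weight32, and
the output are kept in full precision.
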
Review comment:
I will remove the static_cast, since `t` in `LambUpdatePhaseOneParam` has
already been changed to `int` in the open PR:
https://github.com/apache/incubator-mxnet/pull/16903
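
Concretely, once #16903 lands, the launch argument would shrink from
`static_cast<int>(param.t)` to `param.t` (sketch of the follow-up change):

    - param.wd, param.epsilon, static_cast<int>(param.t), param.bias_correction, req[0]);
    + param.wd, param.epsilon, param.t, param.bias_correction, req[0]);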