eric-haibin-lin commented on a change in pull request #16885: Multi Precision
Lamb Update operator
URL: https://github.com/apache/incubator-mxnet/pull/16885#discussion_r352905436
##########
File path: src/operator/optimizer_op-inl.h
##########
@@ -1749,6 +1749,161 @@ inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs&
attrs,
});
}
+template<int n_in, int n_out, int total_in>
+inline bool MPLambPhaseOneType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator "
<< attrs.name;
+ CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator "
<< attrs.name;
+ for (int i = 0; i < n_in; ++i) {
+ TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat16);
+ }
+ for (int i = n_in; i < total_in; ++i) {
+ TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
+ }
+ for (int i = 0; i < n_out; ++i) {
+ TYPE_ASSIGN_CHECK(*out_attrs, i, mshadow::kFloat32);
+ }
+ return true;
+}
+
+struct MPLambUpdatePhaseOneKernel {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i, float* out_data,
+ float* mean_data, float* var_data, const DType* weight_data,
+ const DType* grad_data, const float* weight32_data,
+ const float clip_gradient, const float rescale_grad,
+ const float beta1, const float beta2, const float wd,
+ const float epsilon, const float t,
+ bool bias_correction, const OpReqType req) {
+ using namespace mshadow_op;
+
+ float grad_rescaled = grad_data[i] * rescale_grad;
+ if (clip_gradient >= 0.f) {
+ grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+ }
+
+ mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+ var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled *
grad_rescaled;
+
+ float g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd *
weight32_data[i];
+
+ if (bias_correction) {
+ float mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
Review comment:
Why are we not using `std::pow` with an integer `t` here, instead of `power::Map(beta1, t)` with a float `t`?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services