This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 349803c Multi-precision AdamW update op (#14171)
349803c is described below
commit 349803ce9e737248ef8eb97914fcd87d9a5d75d8
Author: Haibin Lin <[email protected]>
AuthorDate: Tue Feb 19 16:02:00 2019 -0800
Multi-precision AdamW update op (#14171)
* mp adamw update
* Softmax fp16 (#201)
* softmax for fp16 with fp32 accumulator
* return AType in kernel
* add dtype
* kernel
* adamw with nan check
* add doc
* Revert "Softmax fp16 (#201)"
This reverts commit 5869e0ae832437c839bb4ccbcc434971bf5c3486.
* add test
* more test for fp16
* skip update for rescale = 0
---
src/operator/contrib/adamw-inl.h | 165 ++++++++++++++++++------
src/operator/contrib/adamw.cc | 76 ++++++++++-
src/operator/contrib/adamw.cu | 27 +++-
tests/python/unittest/test_contrib_optimizer.py | 84 ++++++++++++
4 files changed, 310 insertions(+), 42 deletions(-)
diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h
index 3d76b33..66bd4f3 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adamw-inl.h
@@ -33,6 +33,7 @@
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <vector>
+#include <cmath>
#include "../operator_common.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
@@ -48,7 +49,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
float epsilon;
float wd;
float eta;
- float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(AdamWParam) {
DMLC_DECLARE_FIELD(lr)
@@ -69,9 +69,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
"The penalty scales with the square of the magnitude of each
weight.");
DMLC_DECLARE_FIELD(eta)
.describe("Learning rate schedule multiplier");
- DMLC_DECLARE_FIELD(rescale_grad)
- .set_default(1.0f)
- .describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
@@ -80,44 +77,138 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
}
};
+// rescale_grad is a reserved argument at position -1. Example:
+// n_in = 2: weight, grad (fp16)
+// n_out = 1: weight (fp16)
+// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
+template<int n_in, int n_out, int total_in>
+inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs,
+ std::vector<TShape> *in_attrs,
+ std::vector<TShape> *out_attrs) {
+ CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator "
<< attrs.name;
+ CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator "
<< attrs.name;
+ // rescale_grad.shape = (1,)
+ SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mshadow::Shape1(1));
+ return ElemwiseAttr<TShape, shape_is_none, shape_assign, true, shape_string,
n_in, n_out>(
+ attrs, in_attrs, out_attrs, TShape());
+}
+
+// rescale_grad is a reserved argument at position -1. Example:
+// n_in = 2: weight, grad (fp16)
+// n_out = 1: weight (fp16)
+// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
+template<int n_in, int n_out, int total_in>
+inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator "
<< attrs.name;
+ CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator "
<< attrs.name;
+ for (int i = n_in; i < total_in; ++i) {
+ TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
+ }
+ return ElemwiseAttr<int, type_is_none, type_assign, true, type_string, n_in,
n_out>(
+ attrs, in_attrs, out_attrs, -1);
+}
+
+template<int req>
+struct MPAdamWKernel {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data,
+ float* var_data, const DType* weight_data, const DType* grad_data, float*
weight32,
+ const float param_clip_gradient, const float param_beta1, const float
param_beta2,
+ const float param_eta, const float param_lr, const float param_wd,
+ const float param_rescale_grad, const float param_epsilon) {
+ float w = weight32[i];
+ float mean = mean_data[i];
+ float var = var_data[i];
+ float scaled_grad = param_rescale_grad*static_cast<float>(grad_data[i]);
+ if (param_clip_gradient >= 0.0f) {
+ mean = param_beta1 * mean +
+ (1 - param_beta1) * mshadow_op::clip::Map(scaled_grad,
param_clip_gradient);
+ var = param_beta2 * var + (1 - param_beta2) *
+ mshadow_op::square::Map(mshadow_op::clip::Map(scaled_grad,
param_clip_gradient));
+ } else {
+ mean = param_beta1 * mean + (1 - param_beta1) * scaled_grad;
+ var = param_beta2 * var + (1 - param_beta2) *
mshadow_op::square::Map(scaled_grad);
+ }
+ mean_data[i] = mean;
+ var_data[i] = var;
+ w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var)
+ param_epsilon)
+ + param_wd * w);
+ weight32[i] = w;
+ KERNEL_ASSIGN(out_data[i], req, w);
+ }
+};
+
+
+template<typename xpu>
+struct MPAdamWUpdate {
+ static inline void Forward(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &outputs,
+ const float rescale_grad) {
+ using namespace mxnet_op;
+ AdamWParam param = nnvm::get<AdamWParam>(attrs.parsed);
+ Stream<xpu>* s = ctx.get_stream<xpu>();
+ MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
+ Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
+ Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
+ Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+ MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+ Kernel<MPAdamWKernel<req_type>, xpu>::Launch(s, weight.shape_.Size(),
out.dptr_, mean.dptr_,
+ var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_,
param.clip_gradient, param.beta1,
+ param.beta2, param.eta, param.lr, param.wd, rescale_grad,
param.epsilon);
+ });
+ });
+ }
+};
+
/*
* \brief adam_w update.
*/
template<typename xpu>
-inline void AdamWUpdate(const nnvm::NodeAttrs& attrs,
- const OpContext &ctx,
- const std::vector<TBlob> &inputs,
- const std::vector<OpReqType> &req,
- const std::vector<TBlob> &outputs) {
- using namespace mshadow;
- using namespace mshadow::expr;
- using namespace mshadow_op;
- const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
- Stream<xpu>* s = ctx.get_stream<xpu>();
- MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
- Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
- Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
- Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
- Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
- Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+struct AdamWUpdate {
+ static inline void Forward(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &outputs,
+ const float rescale_grad) {
+ using namespace mshadow;
+ using namespace mshadow::expr;
+ using namespace mshadow_op;
+ const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
+ Stream<xpu>* s = ctx.get_stream<xpu>();
+ MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
- grad = scalar<DType>(param.rescale_grad) * grad;
- if (param.clip_gradient >= 0.0f) {
- mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
- F<clip>(grad, DType(param.clip_gradient));
- var = scalar<DType>(param.beta2)*var +
scalar<DType>(1.f-param.beta2)*F<square>(
- F<clip>(grad, DType(param.clip_gradient)));
- } else {
- mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1)
* grad;
- var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) *
F<square>(grad);
- }
- Assign(out, req[0],
- weight -
- scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
- mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
- (scalar<DType>(param.wd) * weight)));
- });
-}
+ grad = scalar<DType>(rescale_grad) * grad;
+ if (param.clip_gradient >= 0.0f) {
+ mean = scalar<DType>(param.beta1)*mean +
scalar<DType>(1.f-param.beta1) *
+ F<clip>(grad, DType(param.clip_gradient));
+ var = scalar<DType>(param.beta2)*var +
scalar<DType>(1.f-param.beta2)*F<square>(
+ F<clip>(grad, DType(param.clip_gradient)));
+ } else {
+ mean = scalar<DType>(param.beta1)*mean +
scalar<DType>(1.f-param.beta1) * grad;
+ var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)
* F<square>(grad);
+ }
+ Assign(out, req[0],
+ weight -
+ scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
+ mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
+ (scalar<DType>(param.wd) * weight)));
+ });
+ }
+};
} // namespace op
} // namespace mxnet
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index 94623fe..2fbc397 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -24,12 +24,76 @@
* \author Haibin Lin
*/
#include "./adamw-inl.h"
+#include "../optimizer_op-inl.h"
namespace mxnet {
namespace op {
DMLC_REGISTER_PARAMETER(AdamWParam);
+template<template <typename xpu> class F>
+inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &outputs) {
+  // read rescale_grad scalar (already resident on CPU) and skip the update on NaN/Inf/0
+ TBlob scale_blob = inputs[inputs.size() - 1];
+ MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
+ float scalef = static_cast<float>(*scale_blob.dptr<DType>());
+ if (!std::isfinite(scalef) || scalef == 0) return;
+ std::vector<TBlob> inputs_wo_scale;
+ size_t num_in = inputs.size();
+ inputs_wo_scale.reserve(num_in - 1);
+ for (size_t i = 0; i < num_in - 1; i++)
inputs_wo_scale.emplace_back(inputs[i]);
+ F<cpu>::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
+ });
+}
+
+NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+.describe(R"code(Update function for multi-precision AdamW optimizer.
+
+AdamW is seen as a modification of Adam by decoupling the weight decay from the
+optimization steps taken w.r.t. the loss function.
+
+Adam update consists of the following steps, where g represents gradient and
m, v
+are 1st and 2nd order moment estimates (mean and variance).
+
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
+ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd
W_{t-1})
+
+It updates the weights using::
+
+ m = beta1*m + (1-beta1)*grad
+ v = beta2*v + (1-beta2)*(grad**2)
+ w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+
+Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad
is NaN, Inf, or 0,
+the update is skipped.
+)code" ADD_FILELINE)
+.set_num_inputs(6)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<AdamWParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", MPUpdateInferShape<2, 1, 6>)
+.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<2, 1, 6>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+ [](const nnvm::NodeAttrs& attrs) {
+ return std::vector<uint32_t>{2, 3, 4};
+ })
+.set_attr<FCompute>("FCompute<cpu>", MPUpdateCPU<MPAdamWUpdate>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
+.add_argument("var", "NDArray-or-Symbol", "Moving variance")
+.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
+.add_argument("rescale_grad", "NDArray-or-Symbol",
+              "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is
skipped.")
+.add_arguments(AdamWParam::__FIELDS__());
+
NNVM_REGISTER_OP(_contrib_adamw_update)
.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a
modification of
Adam by decoupling the weight decay from the optimization steps taken w.r.t.
the loss function.
@@ -50,21 +114,25 @@ It updates the weights using::
v = beta2*v + (1-beta2)*(grad**2)
w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad
is NaN, Inf, or 0,
+the update is skipped.
)code" ADD_FILELINE)
-.set_num_inputs(4)
+.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamWParam>)
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", MPUpdateInferShape<4, 1, 5>)
+.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<4, 1, 5>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
-.set_attr<FCompute>("FCompute<cpu>", AdamWUpdate<cpu>)
+.set_attr<FCompute>("FCompute<cpu>", MPUpdateCPU<AdamWUpdate>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
+.add_argument("rescale_grad", "NDArray-or-Symbol",
+              "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is
skipped.")
.add_arguments(AdamWParam::__FIELDS__());
} // namespace op
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
index b7452f8..e21b83b 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adamw.cu
@@ -28,8 +28,33 @@
namespace mxnet {
namespace op {
+template<template <typename xpu> class F>
+inline void MPUpdateGPU(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &outputs) {
+ // copy to cpu and check NaN value
+ TBlob scale_blob = inputs[inputs.size() - 1];
+ MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
+ DType scale = 0;
+ CUDA_CALL(cudaMemcpy(&scale, scale_blob.dptr<DType>(), sizeof(DType),
+ cudaMemcpyDeviceToHost));
+ float scalef = static_cast<float>(scale);
+ if (!std::isfinite(scalef) || scalef == 0) return;
+ std::vector<TBlob> inputs_wo_scale;
+ size_t num_in = inputs.size();
+ inputs_wo_scale.reserve(num_in - 1);
+ for (size_t i = 0; i < num_in - 1; i++)
inputs_wo_scale.emplace_back(inputs[i]);
+ F<gpu>::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
+ });
+}
+
NNVM_REGISTER_OP(_contrib_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", AdamWUpdate<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<AdamWUpdate>);
+
+NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<MPAdamWUpdate>);
} // namespace op
} // namespace mxnet
diff --git a/tests/python/unittest/test_contrib_optimizer.py
b/tests/python/unittest/test_contrib_optimizer.py
index 8ff8a7e..dad7bed 100644
--- a/tests/python/unittest/test_contrib_optimizer.py
+++ b/tests/python/unittest/test_contrib_optimizer.py
@@ -94,6 +94,90 @@ def test_group_adagrad():
g_stype='row_sparse',
compare_states=False)
+def test_adamw():
+ shape = (3, 4)
+ weight = mx.nd.random.uniform(shape=shape)
+ weight_ref = weight.copy()
+ grad = mx.nd.random.uniform(shape=shape)
+ m = mx.nd.random.uniform(shape=shape)
+ v = mx.nd.random.uniform(shape=shape)
+ rescale_grad = mx.nd.array([10])
+ eta, lr, wd, epsilon = 1, 1, 0, 1e-8
+ beta1, beta2 = 0.9, 0.999
+ kwargs = {'eta': eta, 'lr': lr, 'wd': wd, 'epsilon': epsilon,
+ 'beta1': beta1, 'beta2': beta2}
+
+ # update is skipped for rescale = 0
+ mx.nd.contrib.adamw_update(weight, grad, m, v,
+ rescale_grad * 0, out=weight, **kwargs)
+ # weight remains unchanged
+ mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
+ # update is skipped for rescale = nan
+ mx.nd.contrib.adamw_update(weight, grad, m, v,
+ rescale_grad * np.nan, out=weight, **kwargs)
+ # weight remains unchanged
+ mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
+ # update is skipped for rescale = inf
+ mx.nd.contrib.adamw_update(weight, grad, m, v,
+ rescale_grad * np.inf, out=weight, **kwargs)
+ # weight remains unchanged
+ mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
+ # multi-precision update is skipped for rescale = nan
+ weight_fp16 = weight.astype('float16')
+ grad_fp16 = grad.astype('float16')
+ weight_fp16_ref = weight_fp16.copy()
+ mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+ rescale_grad * np.nan, out=weight_fp16,
**kwargs)
+ mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+ mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(),
weight_fp16.asnumpy())
+
+ # multi-precision update is skipped for rescale = inf
+ mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+ rescale_grad * np.inf, out=weight_fp16,
**kwargs)
+ mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+ mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(),
weight_fp16.asnumpy())
+
+ # multi-precision update is skipped for rescale = 0
+ mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+ rescale_grad * 0, out=weight_fp16, **kwargs)
+ mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+ mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(),
weight_fp16.asnumpy())
+
+ # reference normal update
+ grad_rescale = rescale_grad * grad
+ m_ref = beta1*m + (1-beta1)*grad_rescale
+ v_ref = beta2*v + (1-beta2)*(grad_rescale**2)
+ weight_ref = weight - eta * (1 * m_ref / (v_ref.sqrt() + epsilon) + weight
* wd)
+ m_test = m.copy()
+ v_test = v.copy()
+ weight_test = weight.copy()
+ # op normal update
+ mx.nd.contrib.adamw_update(weight_test, grad, m_test, v_test,
+ rescale_grad, out=weight_test, **kwargs)
+ mx.test_utils.assert_almost_equal(weight_ref.asnumpy(),
weight_test.asnumpy())
+ mx.test_utils.assert_almost_equal(m_ref.asnumpy(), m_test.asnumpy())
+ mx.test_utils.assert_almost_equal(v_ref.asnumpy(), v_test.asnumpy())
+
+ # reference normal multi-precision update
+ m_fp32 = m.copy()
+ v_fp32 = v.copy()
+ weight_fp32 = weight.copy()
+ grad_rescale = rescale_grad * grad_fp16.astype('float32')
+ m_ref = beta1*m_fp32 + (1-beta1)*grad_rescale
+ v_ref = beta2*v_fp32 + (1-beta2)*(grad_rescale**2)
+ weight_ref = weight - eta * (1 * m_ref / (v_ref.sqrt() + epsilon) + weight
* wd)
+ weight_fp16_ref = weight_ref.astype('float16')
+ # op normal multi-precision update
+ mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m_fp32, v_fp32,
weight_fp32,
+ rescale_grad, out=weight_fp16, **kwargs)
+ mx.test_utils.assert_almost_equal(m_ref.asnumpy(), m_fp32.asnumpy())
+ mx.test_utils.assert_almost_equal(v_ref.asnumpy(), v_fp32.asnumpy())
+ mx.test_utils.assert_almost_equal(weight_ref.asnumpy(),
weight_fp32.asnumpy())
+ mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(),
weight_fp16.asnumpy())
+
if __name__ == '__main__':
import nose