This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 85d3ef3  Lamb optimizer update (#16715)
85d3ef3 is described below

commit 85d3ef3a40da20a4aac3030950aa0f37f8cb89c5
Author: Rohit Kumar Srivastava <[email protected]>
AuthorDate: Sat Nov 23 22:19:07 2019 -0800

    Lamb optimizer update (#16715)
    
    * initial commit lamb optimizer
    
    * fixing base lamb optimizer
    
    * adding API doc for Lamb Phase 1 and 2
---
 python/mxnet/optimizer/optimizer.py     |  52 ++++++++-
 src/operator/optimizer_op-inl.h         | 186 ++++++++++++++++++++++++++++++++
 src/operator/optimizer_op.cc            |  81 ++++++++++++++
 src/operator/optimizer_op.cu            |   7 ++
 tests/python/unittest/test_optimizer.py |  73 +++++++++++++
 5 files changed, 397 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index b7311b2..00d130b 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -34,14 +34,14 @@ from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update,
                        multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
                        multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
                        preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
-                       preloaded_multi_mp_sgd_mom_update)
+                       preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2)
 from ..ndarray import sparse
 from ..random import normal
 from ..util import is_np_array
 
 __all__ = [
     'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD',
-    'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum',
+    'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'LAMB',
     'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
 ]
 
@@ -1244,6 +1244,54 @@ class LBSGD(Optimizer):
             kwargs = {}
             sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
 
+
+@register
+class LAMB(Optimizer):
+    """LAMB Optimizer (https://arxiv.org/pdf/1904.00962.pdf).
+    """
+    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
+                 lower_bound=None, upper_bound=None, bias_correction=True, **kwargs):
+        super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.bias_correction = bias_correction
+
+
+    def create_state(self, index, weight):
+        stype = weight.stype
+        dtype = weight.dtype
+        return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
+                zeros(weight.shape, weight.context, dtype=dtype, stype=stype))
+
+    def update(self, index, weight, grad, state):
+        assert(isinstance(weight, NDArray))
+        assert(isinstance(grad, NDArray))
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+        t = self._index_update_count[index]
+
+        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
+                  'bias_correction': self.bias_correction, 't': t,
+                  'rescale_grad': self.rescale_grad}
+        mean, var = state
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient
+        g = lamb_update_phase1(weight, grad, mean, var, wd=wd, **kwargs)
+
+        kwargs = {}
+        if self.lower_bound:
+            kwargs['lower_bound'] = self.lower_bound
+        if self.upper_bound:
+            kwargs['upper_bound'] = self.upper_bound
+        r_1 = weight.norm()
+        r_2 = g.norm()
+        lamb_update_phase2(weight, g, r_1, r_2, lr=lr, out=weight, **kwargs)
+
+
 # pylint: enable=line-too-long
 @register
 class DCASGD(Optimizer):
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index c211d32..698f797 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -1563,6 +1563,192 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
+struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
+    float beta1;
+    float beta2;
+    float epsilon;
+    float t;
+    bool bias_correction;
+    float wd;
+    float rescale_grad;
+    float clip_gradient;
+    DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
+      DMLC_DECLARE_FIELD(beta1)
+      .set_default(0.9f)
+      .describe("The decay rate for the 1st moment estimates.");
+      DMLC_DECLARE_FIELD(beta2)
+      .set_default(0.999f)
+      .describe("The decay rate for the 2nd moment estimates.");
+      DMLC_DECLARE_FIELD(epsilon)
+      .set_default(1e-6f)
+      .describe("A small constant for numerical stability.");
+      DMLC_DECLARE_FIELD(t)
+      .describe("Index update count.");
+      DMLC_DECLARE_FIELD(bias_correction)
+      .set_default(true)
+      .describe("Whether to use bias correction.");
+      DMLC_DECLARE_FIELD(wd)
+      .describe("Weight decay augments the objective function with a "
+                "regularization term that penalizes large weights. "
+                "The penalty scales with the square of the magnitude of each weight.");
+      DMLC_DECLARE_FIELD(rescale_grad)
+      .set_default(1.0f)
+      .describe("Rescale gradient to grad = rescale_grad*grad.");
+      DMLC_DECLARE_FIELD(clip_gradient)
+      .set_default(-1.0f)
+      .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+                "If clip_gradient <= 0, gradient clipping is turned off. "
+                "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    }
+};
+
+struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
+    float lr;
+    float lower_bound;
+    float upper_bound;
+    DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
+      DMLC_DECLARE_FIELD(lr)
+      .describe("Learning rate");
+      DMLC_DECLARE_FIELD(lower_bound)
+      .set_default(-1.0f)
+      .describe("Lower limit of norm of weight. If lower_bound <= 0, lower limit is not set");
+      DMLC_DECLARE_FIELD(upper_bound)
+      .set_default(-1.0f)
+      .describe("Upper limit of norm of weight. If upper_bound <= 0, upper limit is not set");
+    }
+};
+
+struct LambUpdatePhaseOneKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType beta1, const DType beta2, const DType wd,
+    const DType epsilon, const DType t,
+    bool bias_correction, const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType grad_rescaled = grad_data[i] * rescale_grad;
+    if (clip_gradient >= 0.f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+    var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
+
+    DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];
+
+    if (bias_correction) {
+      DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
+      DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
+      g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
+    }
+    KERNEL_ASSIGN(out_data[i], req, g);
+  }
+};
+
+template<typename xpu>
+inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
+                       const OpContext &ctx,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+  Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
+    out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
+    static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
+    static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
+    static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+    static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
+  });
+}
+
+inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
+                            mxnet::ShapeVector* in_attrs,
+                            mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 4U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
+
+  mxnet::TShape& weight_shape = in_attrs->at(0);
+  mxnet::TShape& g_shape = in_attrs->at(1);
+  CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
+           << "total no. of dimensions for weights and g must match";
+  for (int i=0; i < weight_shape.ndim(); ++i) {
+    CHECK_EQ(weight_shape[i], g_shape[i])
+           << "weight and g dimension size mismatch at " << i << "-th index";
+  }
+  mxnet::TShape& r1_shape = in_attrs->at(2);
+  mxnet::TShape& r2_shape = in_attrs->at(3);
+  CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
+  CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
+  for (int i=0; i < expected_out.ndim(); ++i) {
+    expected_out[i] = weight_shape[i];
+  }
+
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
+  return shape_is_known(expected_out);
+}
+
+struct LambUpdatePhaseTwoKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    const DType* weight_data, const DType* g,
+    const DType* r1, const DType* r2,
+    DType lr, const DType lower_bound,
+    const DType upper_bound, const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType new_r1 = r1[0];
+    if (lower_bound >= 0) {
+      new_r1 = maximum::Map(new_r1, lower_bound);
+    }
+    if (upper_bound >= 0) {
+      new_r1 = minimum::Map(new_r1, upper_bound);
+    }
+    if (new_r1 == 0.0f || r2[0] == 0.0f) {
+      lr = lr * 1.0f;
+    } else {
+      lr = lr * new_r1 / r2[0];
+    }
+
+    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
+  }
+};
+
+template<typename xpu>
+inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
+                       const OpContext &ctx,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+  Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
+    out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
+    static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
+    static_cast<DType>(param.upper_bound), req[0]);
+  });
+}
+
 // This RMSProp code follows the version in
 // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
 // by Alex Graves, 2013.
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 0141086..9cf3277 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -43,6 +43,8 @@ DMLC_REGISTER_PARAMETER(FtrlParam);
 DMLC_REGISTER_PARAMETER(SignSGDParam);
 DMLC_REGISTER_PARAMETER(SignumParam);
 DMLC_REGISTER_PARAMETER(AdagradParam);
+DMLC_REGISTER_PARAMETER(LambUpdatePhaseOneParam);
+DMLC_REGISTER_PARAMETER(LambUpdatePhaseTwoParam);
 
 NNVM_REGISTER_OP(signsgd_update)
 .describe(R"code(Update function for SignSGD optimizer.
@@ -921,5 +923,84 @@ Note that non-zero values for the weight decay option are not supported.
 .add_argument("history", "NDArray-or-Symbol", "History")
 .add_arguments(AdagradParam::__FIELDS__());
 
+NNVM_REGISTER_OP(lamb_update_phase1)
+.describe(R"code(Phase I of the LAMB update. It performs the following operations and returns g:
+
+Link to paper: https://arxiv.org/pdf/1904.00962.pdf
+
+.. math::
+    \begin{gather*}
+    grad = grad * rescale_grad
+    if (grad < -clip_gradient)
+    then
+         grad = -clip_gradient
+    if (grad > clip_gradient)
+    then
+         grad = clip_gradient
+
+    mean = beta1 * mean + (1 - beta1) * grad;
+    variance = beta2 * variance + (1. - beta2) * grad ^ 2;
+
+    if (bias_correction)
+    then
+         mean_hat = mean / (1. - beta1^t);
+         var_hat = var / (1 - beta2^t);
+         g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight;
+    else
+         g = mean / (var^(1/2) + epsilon) + wd * weight;
+    \end{gather*}
+
+)code" ADD_FILELINE)
+.set_num_inputs(4)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<LambUpdatePhaseOneParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
+.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseOne<cpu>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2, 3};
+  })
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
+.add_argument("var", "NDArray-or-Symbol", "Moving variance")
+.add_arguments(LambUpdatePhaseOneParam::__FIELDS__());
+
+NNVM_REGISTER_OP(lamb_update_phase2)
+.describe(R"code(Phase II of the LAMB update. It performs the following operations and updates the weight.
+
+Link to paper: https://arxiv.org/pdf/1904.00962.pdf
+
+.. math::
+    \begin{gather*}
+    if (lower_bound >= 0)
+    then
+         r1 = max(r1, lower_bound)
+    if (upper_bound >= 0)
+    then
+         r1 = min(r1, upper_bound)
+
+    if (r1 == 0 or r2 == 0)
+    then
+         lr = lr
+    else
+         lr = lr * (r1/r2)
+    weight = weight - lr * g
+    \end{gather*}
+
+)code" ADD_FILELINE)
+.set_num_inputs(4)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<LambUpdatePhaseTwoParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", LambUpdatePhaseTwoShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
+.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseTwo<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("g", "NDArray-or-Symbol", "Output of lamb_update_phase1")
+.add_argument("r1", "NDArray-or-Symbol", "r1")
+.add_argument("r2", "NDArray-or-Symbol", "r2")
+.add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 2c72462..a602b64 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -277,5 +277,12 @@ NNVM_REGISTER_OP(ftrl_update)
 NNVM_REGISTER_OP(_sparse_adagrad_update)
 .set_attr<FComputeEx>("FComputeEx<gpu>", AdagradUpdateEx<gpu>);
 
+NNVM_REGISTER_OP(lamb_update_phase1)
+.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseOne<gpu>);
+
+NNVM_REGISTER_OP(lamb_update_phase2)
+.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseTwo<gpu>);
+
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index f9adf63..4dbf251 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -425,6 +425,79 @@ def test_nag():
                 continue
             compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, rtol=1e-3, atol=1e-4)
 
+
+# LAMB optimizer
+class PyLAMB(mx.optimizer.Optimizer):
+    """
+       Python reference implementation of LAMB optimizer.
+    """
+    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
+                 lower_bound=None, upper_bound=None, bias_correction=True, **kwargs):
+        super(PyLAMB, self).__init__(learning_rate=learning_rate, **kwargs)
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.bias_correction = bias_correction
+
+    def create_state(self, index, weight):
+        stype = weight.stype
+        return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype),
+                mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype))
+
+    def update(self, index, weight, grad, state):
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+        t = self._index_update_count[index]
+
+        grad *= self.rescale_grad
+        if self.clip_gradient is not None:
+            grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
+
+        mean, var = state
+        mean[:] = self.beta1 * mean + (1. - self.beta1) * grad
+        var[:] = self.beta2 * var + (1. - self.beta2) * mx.nd.square(grad)
+
+        mean_hat = mean
+        var_hat = var
+        r1 = weight.norm()
+        if self.lower_bound:
+            r1 = mx.nd.maximum(r1, self.lower_bound)
+        if self.upper_bound:
+            r1 = mx.nd.minimum(r1, self.upper_bound)
+        if self.bias_correction:
+            mean_hat = mean / (1. - mx.nd.power(self.beta1, t))
+            var_hat = var / (1. - mx.nd.power(self.beta2, t))
+
+        g = mean_hat / (mx.nd.sqrt(var_hat) + self.epsilon) + wd * weight
+        r2 = g.norm()
+        # calculate lamb_trust_ratio
+        r = 1. if r1 == 0. or r2 == 0. else r1 / r2
+        lr *= r
+        # update weight
+        weight[:] -= lr * g
+
+    def update_multi_precision(self, index, weight, grad, state):
+        self.update(index, weight, grad, state)
+
+@with_seed()
+def test_lamb():
+    opt1 = PyLAMB
+    opt2 = mx.optimizer.LAMB
+    shape = (3, 4, 5)
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+    bc_options = [{}, {'bias_correction': False}, {'bias_correction': True}]
+    lb_options = [{}, {'lower_bound': None}, {'lower_bound': 1e-3}]
+    ub_options = [{}, {'upper_bound': None}, {'upper_bound': 10}]
+    for params in itertools.product(cg_options, rg_options, wd_options, bc_options, lb_options, ub_options):
+        kwarg = {k: v for param in params for k, v in param.items()}
+        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32)
+
+
 #SGLD
 class PySGLD(mx.optimizer.Optimizer):
     """python reference implementation of SGLD"""

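For readers skimming the math blocks in the lamb_update_phase1/lamb_update_phase2 descriptions above, the two phases can be read as the NumPy sketch below. This is only an illustration of the documented update rule, written against the same parameter names; the helper functions lamb_phase1/lamb_phase2 are hypothetical and are not the registered operators or kernels added by this commit.

    import numpy as np

    def lamb_phase1(weight, grad, mean, var, t, beta1=0.9, beta2=0.999, epsilon=1e-6,
                    wd=0.0, rescale_grad=1.0, clip_gradient=-1.0, bias_correction=True):
        # grad = grad * rescale_grad, optionally clipped to [-clip_gradient, clip_gradient]
        grad = grad * rescale_grad
        if clip_gradient >= 0:
            grad = np.clip(grad, -clip_gradient, clip_gradient)
        # update the moving moments in place, as the operator mutates mean/var
        mean[:] = beta1 * mean + (1. - beta1) * grad
        var[:] = beta2 * var + (1. - beta2) * grad * grad
        if bias_correction:
            mean_hat = mean / (1. - beta1 ** t)
            var_hat = var / (1. - beta2 ** t)
            return mean_hat / (np.sqrt(var_hat) + epsilon) + wd * weight
        return mean / (np.sqrt(var) + epsilon) + wd * weight

    def lamb_phase2(weight, g, lr, lower_bound=-1.0, upper_bound=-1.0):
        r1 = np.linalg.norm(weight)        # norm of the weight
        if lower_bound >= 0:
            r1 = max(r1, lower_bound)
        if upper_bound >= 0:
            r1 = min(r1, upper_bound)
        r2 = np.linalg.norm(g)             # norm of the phase-1 output
        if r1 != 0 and r2 != 0:
            lr = lr * r1 / r2              # layer-wise trust ratio
        weight[:] = weight - lr * g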
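
A minimal usage sketch of the new optimizer, assuming it is registered under the name 'lamb' (the lower-cased class name, as with the other optimizers registered via @register); the toy network, data, and hyperparameter values below are placeholders, not part of the commit.

    import mxnet as mx
    from mxnet import autograd, gluon

    net = gluon.nn.Dense(2)
    net.initialize()
    trainer = gluon.Trainer(net.collect_params(), 'lamb',
                            {'learning_rate': 1e-3, 'upper_bound': 10.0})

    x = mx.nd.random.normal(shape=(4, 8))
    y = mx.nd.random.normal(shape=(4, 2))
    loss_fn = gluon.loss.L2Loss()
    with autograd.record():
        loss = loss_fn(net(x), y)
    loss.backward()
    # each parameter update runs lamb_update_phase1 followed by lamb_update_phase2
    trainer.step(batch_size=4)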