This is an automated email from the ASF dual-hosted git repository.

zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 785690c  Added in Large-Batch SGD with a warmup, and a LARS strategy. Also add… (#8918)
785690c is described below

commit 785690c0569f265b52c88ff3041849fd7c338d70
Author: Ashok Emani <ashok.em...@intel.com>
AuthorDate: Mon Jan 29 09:16:11 2018 -0800

    Added in Large-Batch SGD with a warmup, and a LARS strategy. Also add… (#8918)
    
    * Added in Large-Batch SGD with a warmup, and a LARS strategy. Also added in a Polynomial Decay learning rate scheduler. Modified the example image fit code to allow these options to be selectable.
    
    * Fix pylint issues
    
    * pylint fixes
    
    * remove duplicate num_update
    
    * remove unused count
---
 example/image-classification/common/fit.py | 138 +++++++++++++++------
 python/mxnet/lr_scheduler.py               |  32 +++++
 python/mxnet/optimizer.py                  | 190 +++++++++++++++++++++++++++++
 3 files changed, 324 insertions(+), 36 deletions(-)
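
The following is a minimal usage sketch of the pieces added by this commit, assuming illustrative hyperparameter values (every number below is a placeholder, not a recommendation); 'lbsgd' and PolyScheduler are the classes introduced in the diff:

    import mxnet as mx

    # Polynomial decay schedule added in python/mxnet/lr_scheduler.py
    poly = mx.lr_scheduler.PolyScheduler(max_update=500000, base_lr=0.1, pwr=2)

    # Large-Batch SGD optimizer registered as 'lbsgd' in python/mxnet/optimizer.py
    opt = mx.optimizer.create('lbsgd',
                              learning_rate=0.1,
                              momentum=0.9,
                              wd=1e-4,
                              multi_precision=True,
                              warmup_strategy='lars',
                              warmup_epochs=5,
                              batch_scale=32,
                              updates_per_epoch=5000,
                              begin_epoch=0,
                              num_epochs=90,
                              lr_scheduler=poly)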

diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index 2b002c7..d9f96d0 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -15,10 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import mxnet as mx
+""" example train fit utility """
 import logging
 import os
 import time
+import re
+import math
+import mxnet as mx
+
 
 def _get_lr_scheduler(args, kv):
     if 'lr_factor' not in args or args.lr_factor >= 1:
@@ -27,17 +31,26 @@ def _get_lr_scheduler(args, kv):
     if 'dist' in args.kv_store:
         epoch_size /= kv.num_workers
     begin_epoch = args.load_epoch if args.load_epoch else 0
+    if 'pow' in args.lr_step_epochs:
+        lr = args.lr
+        max_up = args.num_epochs * epoch_size
+        pwr = float(re.sub('pow[- ]*', '', args.lr_step_epochs))
+        poly_sched = mx.lr_scheduler.PolyScheduler(max_up, lr, pwr)
+        return (lr, poly_sched)
     step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
     lr = args.lr
     for s in step_epochs:
         if begin_epoch >= s:
             lr *= args.lr_factor
     if lr != args.lr:
-        logging.info('Adjust learning rate to %e for epoch %d' %(lr, begin_epoch))
+        logging.info('Adjust learning rate to %e for epoch %d',
+                     lr, begin_epoch)
 
-    steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0]
+    steps = [epoch_size * (x - begin_epoch)
+             for x in step_epochs if x - begin_epoch > 0]
     return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))
 
+
 def _load_model(args, rank=0):
     if 'load_epoch' not in args or args.load_epoch is None:
         return (None, None, None)
@@ -50,6 +63,7 @@ def _load_model(args, rank=0):
     logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
     return (sym, arg_params, aux_params)
 
+
 def _save_model(args, rank=0):
     if args.model_prefix is None:
         return None
@@ -59,6 +73,7 @@ def _save_model(args, rank=0):
     return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
         args.model_prefix, rank))
 
+
 def add_fit_args(parser):
     """
     parser : argparse.ArgumentParser
@@ -68,7 +83,8 @@ def add_fit_args(parser):
     train.add_argument('--network', type=str,
                        help='the neural network to use')
     train.add_argument('--num-layers', type=int,
-                       help='number of layers in the neural network, required by some networks such as resnet')
+                       help='number of layers in the neural network, \
+                             required by some networks such as resnet')
     train.add_argument('--gpus', type=str,
                        help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
     train.add_argument('--kv-store', type=str, default='device',
@@ -81,6 +97,8 @@ def add_fit_args(parser):
                        help='the ratio to reduce lr on each step')
     train.add_argument('--lr-step-epochs', type=str,
                        help='the epochs to reduce the lr, e.g. 30,60')
+    train.add_argument('--initializer', type=str, default='default',
+                       help='the initializer type')
     train.add_argument('--optimizer', type=str, default='sgd',
                        help='the optimizer type')
     train.add_argument('--mom', type=float, default=0.9,
@@ -108,8 +126,16 @@ def add_fit_args(parser):
                              takes `2bit` or `none` for now')
     train.add_argument('--gc-threshold', type=float, default=0.5,
                        help='threshold for 2bit gradient compression')
+    # additional parameters for large batch sgd
+    train.add_argument('--macrobatch-size', type=int, default=0,
+                       help='distributed effective batch size')
+    train.add_argument('--warmup-epochs', type=int, default=5,
+                       help='the epochs to ramp-up lr to scaled large-batch value')
+    train.add_argument('--warmup-strategy', type=str, default='linear',
+                       help='the ramping-up strategy for large batch sgd')
     return train
 
+
 def fit(args, network, data_loader, **kwargs):
     """
     train a model
@@ -135,14 +161,13 @@ def fit(args, network, data_loader, **kwargs):
         for i, batch in enumerate(train):
             for j in batch.data:
                 j.wait_to_read()
-            if (i+1) % args.disp_batches == 0:
-                logging.info('Batch [%d]\tSpeed: %.2f samples/sec' % (
-                    i, args.disp_batches*args.batch_size/(time.time()-tic)))
+            if (i + 1) % args.disp_batches == 0:
+                logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
+                             args.disp_batches * args.batch_size / (time.time() - tic))
                 tic = time.time()
 
         return
 
-
     # load model
     if 'arg_params' in kwargs and 'aux_params' in kwargs:
         arg_params = kwargs['arg_params']
@@ -156,7 +181,7 @@ def fit(args, network, data_loader, **kwargs):
     checkpoint = _save_model(args, kv.rank)
 
     # devices for training
-    devs = mx.cpu() if args.gpus is None or args.gpus is '' else [
+    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
         mx.gpu(int(i)) for i in args.gpus.split(',')]
 
     # learning rate
@@ -164,14 +189,14 @@ def fit(args, network, data_loader, **kwargs):
 
     # create model
     model = mx.mod.Module(
-        context       = devs,
-        symbol        = network
+        context=devs,
+        symbol=network
     )
 
-    lr_scheduler  = lr_scheduler
+    lr_scheduler = lr_scheduler
     optimizer_params = {
         'learning_rate': lr,
-        'wd' : args.wd,
+        'wd': args.wd,
         'lr_scheduler': lr_scheduler,
         'multi_precision': True}
 
@@ -180,40 +205,81 @@ def fit(args, network, data_loader, **kwargs):
     if args.optimizer in has_momentum:
         optimizer_params['momentum'] = args.mom
 
-    monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None
+    monitor = mx.mon.Monitor(
+        args.monitor, pattern=".*") if args.monitor > 0 else None
 
-    if args.network == 'alexnet':
-        # AlexNet will not converge using Xavier
-        initializer = mx.init.Normal()
-    else:
-        initializer = mx.init.Xavier(
-            rnd_type='gaussian', factor_type="in", magnitude=2)
+    # A limited number of optimizers have a warmup period
+    has_warmup = {'lbsgd', 'lbnag'}
+    if args.optimizer in has_warmup:
+        if 'dist' in args.kv_store:
+            nworkers = kv.num_workers
+        else:
+            nworkers = 1
+        epoch_size = args.num_examples / args.batch_size / nworkers
+        if epoch_size < 1:
+            epoch_size = 1
+        macrobatch_size = args.macrobatch_size
+        if macrobatch_size < args.batch_size * nworkers:
+            macrobatch_size = args.batch_size * nworkers
+        #batch_scale = round(float(macrobatch_size) / args.batch_size / nworkers +0.4999)
+        batch_scale = math.ceil(
+            float(macrobatch_size) / args.batch_size / nworkers)
+        optimizer_params['updates_per_epoch'] = epoch_size
+        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
+        optimizer_params['batch_scale'] = batch_scale
+        optimizer_params['warmup_strategy'] = args.warmup_strategy
+        optimizer_params['warmup_epochs'] = args.warmup_epochs
+        optimizer_params['num_epochs'] = args.num_epochs
+
+    if args.initializer == 'default':
+        if args.network == 'alexnet':
+            # AlexNet will not converge using Xavier
+            initializer = mx.init.Normal()
+        else:
+            initializer = mx.init.Xavier(
+                rnd_type='gaussian', factor_type="in", magnitude=2)
     # initializer   = mx.init.Xavier(factor_type="in", magnitude=2.34),
+    elif args.initializer == 'xavier':
+        initializer = mx.init.Xavier()
+    elif args.initializer == 'msra':
+        initializer = mx.init.MSRAPrelu()
+    elif args.initializer == 'orthogonal':
+        initializer = mx.init.Orthogonal()
+    elif args.initializer == 'normal':
+        initializer = mx.init.Normal()
+    elif args.initializer == 'uniform':
+        initializer = mx.init.Uniform()
+    elif args.initializer == 'one':
+        initializer = mx.init.One()
+    elif args.initializer == 'zero':
+        initializer = mx.init.Zero()
 
     # evaluation metrices
     eval_metrics = ['accuracy']
     if args.top_k > 0:
-        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=args.top_k))
+        eval_metrics.append(mx.metric.create(
+            'top_k_accuracy', top_k=args.top_k))
 
     # callbacks that run after each batch
-    batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)]
+    batch_end_callbacks = [mx.callback.Speedometer(
+        args.batch_size, args.disp_batches)]
     if 'batch_end_callback' in kwargs:
         cbs = kwargs['batch_end_callback']
         batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]
 
     # run
     model.fit(train,
-              begin_epoch        = args.load_epoch if args.load_epoch else 0,
-              num_epoch          = args.num_epochs,
-              eval_data          = val,
-              eval_metric        = eval_metrics,
-              kvstore            = kv,
-              optimizer          = args.optimizer,
-              optimizer_params   = optimizer_params,
-              initializer        = initializer,
-              arg_params         = arg_params,
-              aux_params         = aux_params,
-              batch_end_callback = batch_end_callbacks,
-              epoch_end_callback = checkpoint,
-              allow_missing      = True,
-              monitor            = monitor)
+              begin_epoch=args.load_epoch if args.load_epoch else 0,
+              num_epoch=args.num_epochs,
+              eval_data=val,
+              eval_metric=eval_metrics,
+              kvstore=kv,
+              optimizer=args.optimizer,
+              optimizer_params=optimizer_params,
+              initializer=initializer,
+              arg_params=arg_params,
+              aux_params=aux_params,
+              batch_end_callback=batch_end_callbacks,
+              epoch_end_callback=checkpoint,
+              allow_missing=True,
+              monitor=monitor)
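
As an aside on the fit.py changes above: the new 'pow' syntax accepted by --lr-step-epochs selects the polynomial schedule. A short sketch of the parsing (the argument value shown is only an example):

    import re

    lr_step_epochs = 'pow2'   # e.g. passed as --lr-step-epochs pow2
    pwr = float(re.sub('pow[- ]*', '', lr_step_epochs))
    print(pwr)                # 2.0, which is then fed to PolyScheduler(max_up, lr, pwr)
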
diff --git a/python/mxnet/lr_scheduler.py b/python/mxnet/lr_scheduler.py
index e4af77a..963560d 100644
--- a/python/mxnet/lr_scheduler.py
+++ b/python/mxnet/lr_scheduler.py
@@ -136,3 +136,35 @@ class MultiFactorScheduler(LRScheduler):
             else:
                 return self.base_lr
         return self.base_lr
+
+class PolyScheduler(LRScheduler):
+    """ Reduce the learning rate by given a list of steps.
+
+    Calculate the new learning rate by::
+
+       base_lr * (1 - nup/max_nup)^pwr if nup < max_nup,
+       0 otherwise.
+
+    Parameters
+    ----------
+       max_update: maximum number of updates before the decay reaches 0.
+       base_lr:    base learning rate
+       pwr:   power of the decay term as a function of the current number of updates.
+
+    """
+
+    def __init__(self, max_update, base_lr=0.01, pwr=2):
+        super(PolyScheduler, self).__init__(base_lr)
+        assert isinstance(max_update, int)
+        if max_update < 1:
+            raise ValueError("maximum number of updates must be strictly 
positive")
+        self.base_lr_orig = self.base_lr
+        self.max_update = max_update
+        self.power = pwr
+        self.base_lr = self.base_lr_orig
+
+    def __call__(self, num_update):
+        if num_update <= self.max_update:
+            self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
+                                                   self.power)
+        return self.base_lr
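
A small worked example of the decay implemented above (the numbers are illustrative): the scheduler returns base_lr * (1 - num_update/max_update)**pwr while num_update <= max_update, and keeps the last value (zero) afterwards.

    import mxnet as mx

    sched = mx.lr_scheduler.PolyScheduler(max_update=100, base_lr=0.1, pwr=2)
    print(sched(0))    # 0.1
    print(sched(50))   # 0.1 * (1 - 0.5)**2 = 0.025
    print(sched(100))  # 0.0, decay has reached zero
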
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 57340be..c3338f4 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -645,6 +645,195 @@ class FTML(Optimizer):
         ftml_update(weight, grad, prev_d, prev_v, prev_z, out=weight,
                     lr=lr, wd=wd, **kwargs)
 
+@register
+class LBSGD(Optimizer):
+    """The Large Batch SGD optimizer with momentum and weight decay.
+
+    The optimizer updates the weight by::
+
+        state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
+        weight = weight - state
+
+    For details of the update algorithm see :class:`~mxnet.ndarray.sgd_update` and
+    :class:`~mxnet.ndarray.sgd_mom_update`.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    momentum : float, optional
+       The momentum value.
+    multi_precision: bool, optional
+       Flag to control the internal precision of the optimizer.
+       ``False`` results in using the same precision as the weights (default),
+       ``True`` makes internal 32-bit copy of the weights and applies gradients
+                in 32-bit precision even if actual weights used in the model have lower precision.
+                Turning this on can improve convergence and accuracy when training with float16.
+    warmup_strategy: string ('linear', 'power2', 'sqrt', or 'lars'; default: 'linear')
+    warmup_epochs: unsigned, default: 5
+    batch_scale:   unsigned, default: 1 (a value of 1 means the macrobatch equals batch_size * num_workers)
+    updates_per_epoch: updates per epoch (default: 32; the default may not reflect the true number of batches per epoch, used for warmup)
+    begin_epoch: unsigned, default 0, starting epoch.
+    """
+
+    def __init__(self, momentum=0.0, multi_precision=False, warmup_strategy='linear',
+                 warmup_epochs=5, batch_scale=1, updates_per_epoch=32, begin_epoch=0, num_epochs=60,
+                 **kwargs):
+        super(LBSGD, self).__init__(**kwargs)
+        logging.info('Running Large-Batch SGD Algorithm')
+        logging.info('(Batch_scale=%f, warmup_epochs=%d, warmup_strategy=%s, updates_per_epoch=%d)',
+                     batch_scale, warmup_epochs, warmup_strategy, updates_per_epoch)
+        self.momentum = momentum
+        self.multi_precision = multi_precision
+        # new user parameters for large batch
+        self.warmup_strategy = warmup_strategy
+        self.warmup_epochs = warmup_epochs
+        self.batch_scale = batch_scale
+        self.updates_per_epoch = updates_per_epoch
+        self.init_updates = begin_epoch * updates_per_epoch
+        self.num_epochs = num_epochs
+        # addl internal usage parameters and storage
+        self.lbmult = 1
+        self.cumgrads = {}
+        # for adaptive lr
+        self.adaptive = False
+        self.admult = 1  # adaptation constant
+
+    def create_state(self, index, weight):
+        momentum = None
+        weight_master_copy = None
+        if self.multi_precision and weight.dtype == numpy.float16:
+            weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32)
+            if self.momentum != 0.0:
+                momentum = zeros(weight.shape, weight.context, dtype=numpy.float32,
+                                 stype=weight.stype)
+            return (momentum, weight_master_copy)
+        if weight.dtype == numpy.float16 and not self.multi_precision:
+            warnings.warn("Accumulating with float16 in optimizer can lead to "
+                          "poor accuracy or slow convergence. "
+                          "Consider using multi_precision=True option of the "
+                          "SGD optimizer")
+        if self.momentum != 0.0:
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+        return momentum
+
+    def _get_lbmult(self, nup):
+        """Returns lr scaling factor for large batch according to warmup 
schedule
+        (to be implemented)
+        """
+        nwup = self.warmup_epochs * self.updates_per_epoch
+        strategy = self.warmup_strategy
+        maxmult = float(self.batch_scale)
+        if nup >= nwup:
+            mult = maxmult
+        elif nwup <= 1:
+            mult = 1.0
+        else:
+            if (strategy == 'linear'):
+                mult = 1.0 + (maxmult - 1) * nup / nwup
+            elif (strategy == 'power2'):
+                mult = 1.0 + (maxmult-1) * (nup*nup)/(nwup*nwup)
+            elif (strategy == 'sqrt'):
+                mult = 1.0 + (maxmult - 1) * math.sqrt(float(nup) / nwup)
+            else:
+                mult = 1.0
+        return mult
+
+    def _get_lars(self, weight, g, wd):
+        """Returns a scaling factor for the learning rate for this layer
+        default is 1
+        """
+        weight2 = self._l2norm(weight)
+        grad2 = self._l2norm(g)
+        lars = math.sqrt(weight2 / (grad2 + wd * weight2 + 1e-18))
+        if lars < 0.01:
+            lars = 0.01
+        elif lars > 100:
+            lars = 100
+        return lars
+
+    def _l2norm(self, v):
+        "inner product implementation"
+        norm = multiply(v, v).asnumpy().sum()
+        return norm
+
+    def _reset_cum_gradient(self, index):
+        "called every macro-batch to reset cumulated gradients to 0 for a 
given index"
+        self.cumgrads[index]['cum_grad'] = 0
+
+    def _get_cum_gradient(self, index):
+        "get the cumulated gradient for index"
+        if index in self.cumgrads:
+            return self.cumgrads[index]
+        else:
+            return {}
+
+    def _put_cum_gradient(self, index, cgrad):
+        "store cumulated gradient for index"
+        self.cumgrads[index] = cgrad
+
+    def _cumulate_gradient(self, grad, index):
+        "Cumulate gradients for large-batch emulation. Cumulated by index 
(layer)"
+        cgrad = self._get_cum_gradient(index)
+        if cgrad:
+            num_cums = cgrad['num_cums']
+            if num_cums > 0:
+                cum_grad = cgrad['cum_grad'] + grad
+                num_cums += 1
+            else:
+                cum_grad = grad
+                num_cums = self.init_updates + 1
+        else:
+            cum_grad = grad
+            num_cums = self.init_updates + 1
+        cgrad = {'cum_grad': cum_grad, 'num_cums': num_cums}
+        self._put_cum_gradient(index, cgrad)
+        return cgrad
+
+    def update(self, index, weight, grad, state):
+        assert (isinstance(weight, NDArray))
+        assert (isinstance(grad, NDArray))
+
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+        self._update_count(index)
+
+        # new stuff for large batch
+        cgrad = self._cumulate_gradient(grad, index)
+        if (cgrad['num_cums'] % self.batch_scale) == 0:
+            grad = cgrad['cum_grad'] / self.batch_scale
+            if self.warmup_strategy == 'lars':
+                lbmult = self._get_lars(weight, grad, wd)
+            else:
+                lbmult = self._get_lbmult(cgrad['num_cums'])
+            lr = lr * lbmult
+            # do the regular sgd update flow
+            kwargs = {'rescale_grad': self.rescale_grad}
+            if self.momentum > 0:
+                kwargs['momentum'] = self.momentum
+            if self.clip_gradient:
+                kwargs['clip_gradient'] = self.clip_gradient
+            use_multi_precision = isinstance(state, (list, tuple))
+
+            if not use_multi_precision:
+                if state is not None:
+                    sgd_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
+                else:
+                    sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
+            else:
+                if state[0] is not None:
+                    mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, lr=lr, wd=wd,
+                                      **kwargs)
+                else:
+                    mp_sgd_update(weight, grad, state[1], out=weight, lr=lr, wd=wd, **kwargs)
+            # reset update count and cumulated gradient per large batch
+            self._reset_cum_gradient(index)
+        else:
+            lr = 0.0
+            kwargs = {}
+            sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
+
 # pylint: enable=line-too-long
 @register
 class DCASGD(Optimizer):
@@ -1282,6 +1471,7 @@ class Updater(object):
         self.optimizer.update_multi_precision(index, weight, grad, self.states[index])
 
     def sync_state_context(self, state, context):
+        """sync state context."""
         if isinstance(state, NDArray):
             return state.as_in_context(context)
         elif isinstance(state, (tuple, list)):
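
To make the LBSGD warmup behaviour above easier to follow, here is a simplified standalone sketch of the multipliers computed by _get_lbmult (edge cases omitted; nwup stands for warmup_epochs * updates_per_epoch and maxmult for batch_scale, with illustrative defaults):

    import math

    def warmup_mult(nup, nwup=1000, maxmult=32.0, strategy='linear'):
        # Ramp the lr multiplier from 1 up to maxmult over nwup updates.
        if nup >= nwup:
            return maxmult
        if strategy == 'linear':
            return 1.0 + (maxmult - 1) * nup / nwup
        if strategy == 'power2':
            return 1.0 + (maxmult - 1) * (nup * nup) / (nwup * nwup)
        if strategy == 'sqrt':
            return 1.0 + (maxmult - 1) * math.sqrt(float(nup) / nwup)
        return 1.0

    print(warmup_mult(500))                   # linear, halfway through the ramp: 16.5
    print(warmup_mult(500, strategy='sqrt'))  # about 22.9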

-- 
To stop receiving notification emails like this one, please contact
zhresh...@apache.org.
