This is an automated email from the ASF dual-hosted git repository.

zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new 785690c  Added in Large-Batch SGD with a warmup, and a LARS strategy. Also add… (#8918)
785690c is described below

commit 785690c0569f265b52c88ff3041849fd7c338d70
Author: Ashok Emani <ashok.em...@intel.com>
AuthorDate: Mon Jan 29 09:16:11 2018 -0800

    Added in Large-Batch SGD with a warmup, and a LARS strategy. Also add… (#8918)

    * Added in Large-Batch SGD with a warmup, and a LARS strategy. Also added
      in a Polynomial Decay learning rate scheduler. Modified the example image
      fit code to allow these options to be selectable.

    * Fix pylint issues

    * pylint fixes

    * remove duplicate num_update

    * remove unused count
---
 example/image-classification/common/fit.py | 138 +++++++++++++++------
 python/mxnet/lr_scheduler.py               |  32 +++++
 python/mxnet/optimizer.py                  | 190 +++++++++++++++++++++++++++++
 3 files changed, 324 insertions(+), 36 deletions(-)
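The gist of the patch: LBSGD emulates a large "macro" batch by accumulating gradients over several per-worker batches, and ramps a learning-rate multiplier from 1 up to the batch scale during a warmup window. A minimal plain-Python sketch of the three warmup schedules added below; the function and variable names here are illustrative, not part of the patch:

    import math

    def warmup_multiplier(nup, nwup, maxmult, strategy='linear'):
        # nup: updates done so far; nwup: updates in the warmup window;
        # maxmult: target multiplier (the batch scale). Mirrors _get_lbmult below.
        if nup >= nwup:
            return maxmult          # warmup finished: full large-batch scaling
        if nwup <= 1:
            return 1.0              # degenerate window: no scaling
        if strategy == 'linear':
            return 1.0 + (maxmult - 1) * float(nup) / nwup
        if strategy == 'power2':
            return 1.0 + (maxmult - 1) * float(nup * nup) / (nwup * nwup)
        if strategy == 'sqrt':
            return 1.0 + (maxmult - 1) * math.sqrt(float(nup) / nwup)
        return 1.0                  # unknown strategy: leave lr unscaled

    print(warmup_multiplier(250, 500, 8.0))  # halfway through a linear warmup to 8x -> 4.5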
diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index 2b002c7..d9f96d0 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -15,10 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import mxnet as mx
+""" example train fit utility """
 import logging
 import os
 import time
+import re
+import math
+import mxnet as mx
+
 
 def _get_lr_scheduler(args, kv):
     if 'lr_factor' not in args or args.lr_factor >= 1:
@@ -27,17 +31,26 @@ def _get_lr_scheduler(args, kv):
     if 'dist' in args.kv_store:
         epoch_size /= kv.num_workers
     begin_epoch = args.load_epoch if args.load_epoch else 0
+    if 'pow' in args.lr_step_epochs:
+        lr = args.lr
+        max_up = args.num_epochs * epoch_size
+        pwr = float(re.sub('pow[- ]*', '', args.lr_step_epochs))
+        poly_sched = mx.lr_scheduler.PolyScheduler(max_up, lr, pwr)
+        return (lr, poly_sched)
     step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
     lr = args.lr
     for s in step_epochs:
         if begin_epoch >= s:
             lr *= args.lr_factor
     if lr != args.lr:
-        logging.info('Adjust learning rate to %e for epoch %d' %(lr, begin_epoch))
+        logging.info('Adjust learning rate to %e for epoch %d',
+                     lr, begin_epoch)
 
-    steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0]
+    steps = [epoch_size * (x - begin_epoch)
+             for x in step_epochs if x - begin_epoch > 0]
     return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))
 
+
 def _load_model(args, rank=0):
     if 'load_epoch' not in args or args.load_epoch is None:
         return (None, None, None)
@@ -50,6 +63,7 @@ def _load_model(args, rank=0):
     logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
     return (sym, arg_params, aux_params)
 
+
 def _save_model(args, rank=0):
     if args.model_prefix is None:
         return None
@@ -59,6 +73,7 @@ def _save_model(args, rank=0):
     return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
         args.model_prefix, rank))
 
+
 def add_fit_args(parser):
     """
     parser : argparse.ArgumentParser
@@ -68,7 +83,8 @@ def add_fit_args(parser):
     train.add_argument('--network', type=str,
                        help='the neural network to use')
     train.add_argument('--num-layers', type=int,
-                       help='number of layers in the neural network, required by some networks such as resnet')
+                       help='number of layers in the neural network, \
+                             required by some networks such as resnet')
     train.add_argument('--gpus', type=str,
                        help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
     train.add_argument('--kv-store', type=str, default='device',
@@ -81,6 +97,8 @@ def add_fit_args(parser):
                        help='the ratio to reduce lr on each step')
     train.add_argument('--lr-step-epochs', type=str,
                        help='the epochs to reduce the lr, e.g. 30,60')
+    train.add_argument('--initializer', type=str, default='default',
+                       help='the initializer type')
     train.add_argument('--optimizer', type=str, default='sgd',
                        help='the optimizer type')
     train.add_argument('--mom', type=float, default=0.9,
@@ -108,8 +126,16 @@ def add_fit_args(parser):
                        takes `2bit` or `none` for now')
     train.add_argument('--gc-threshold', type=float, default=0.5,
                        help='threshold for 2bit gradient compression')
+    # additional parameters for large batch sgd
+    train.add_argument('--macrobatch-size', type=int, default=0,
+                       help='distributed effective batch size')
+    train.add_argument('--warmup-epochs', type=int, default=5,
+                       help='the epochs to ramp-up lr to scaled large-batch value')
+    train.add_argument('--warmup-strategy', type=str, default='linear',
+                       help='the ramping-up strategy for large batch sgd')
     return train
 
+
 def fit(args, network, data_loader, **kwargs):
     """
     train a model
@@ -135,14 +161,13 @@ def fit(args, network, data_loader, **kwargs):
         for i, batch in enumerate(train):
             for j in batch.data:
                 j.wait_to_read()
-            if (i+1) % args.disp_batches == 0:
-                logging.info('Batch [%d]\tSpeed: %.2f samples/sec' % (
-                    i, args.disp_batches*args.batch_size/(time.time()-tic)))
+            if (i + 1) % args.disp_batches == 0:
+                logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
+                             args.disp_batches * args.batch_size / (time.time() - tic))
                 tic = time.time()
 
         return
 
-
     # load model
     if 'arg_params' in kwargs and 'aux_params' in kwargs:
         arg_params = kwargs['arg_params']
@@ -156,7 +181,7 @@ def fit(args, network, data_loader, **kwargs):
     checkpoint = _save_model(args, kv.rank)
 
     # devices for training
-    devs = mx.cpu() if args.gpus is None or args.gpus is '' else [
+    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
         mx.gpu(int(i)) for i in args.gpus.split(',')]
 
     # learning rate
@@ -164,14 +189,14 @@ def fit(args, network, data_loader, **kwargs):
     # create model
     model = mx.mod.Module(
-        context = devs,
-        symbol = network
+        context=devs,
+        symbol=network
     )
 
-    lr_scheduler  = lr_scheduler
+    lr_scheduler = lr_scheduler
     optimizer_params = {
         'learning_rate': lr,
-        'wd' : args.wd,
+        'wd': args.wd,
         'lr_scheduler': lr_scheduler,
         'multi_precision': True}
@@ -180,40 +205,81 @@ def fit(args, network, data_loader, **kwargs):
     if args.optimizer in has_momentum:
         optimizer_params['momentum'] = args.mom
 
-    monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None
+    monitor = mx.mon.Monitor(
+        args.monitor, pattern=".*") if args.monitor > 0 else None
 
-    if args.network == 'alexnet':
-        # AlexNet will not converge using Xavier
-        initializer = mx.init.Normal()
-    else:
-        initializer = mx.init.Xavier(
-            rnd_type='gaussian', factor_type="in", magnitude=2)
+    # A limited number of optimizers have a warmup period
+    has_warmup = {'lbsgd', 'lbnag'}
+    if args.optimizer in has_warmup:
+        if 'dist' in args.kv_store:
+            nworkers = kv.num_workers
+        else:
+            nworkers = 1
+        epoch_size = args.num_examples / args.batch_size / nworkers
+        if epoch_size < 1:
+            epoch_size = 1
+        macrobatch_size = args.macrobatch_size
+        if macrobatch_size < args.batch_size * nworkers:
+            macrobatch_size = args.batch_size * nworkers
+        # batch_scale = round(float(macrobatch_size) / args.batch_size / nworkers + 0.4999)
+        batch_scale = math.ceil(
+            float(macrobatch_size) / args.batch_size / nworkers)
+        optimizer_params['updates_per_epoch'] = epoch_size
+        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
+        optimizer_params['batch_scale'] = batch_scale
+        optimizer_params['warmup_strategy'] = args.warmup_strategy
+        optimizer_params['warmup_epochs'] = args.warmup_epochs
+        optimizer_params['num_epochs'] = args.num_epochs
+
+    if args.initializer == 'default':
+        if args.network == 'alexnet':
+            # AlexNet will not converge using Xavier
+            initializer = mx.init.Normal()
+        else:
+            initializer = mx.init.Xavier(
+                rnd_type='gaussian', factor_type="in", magnitude=2)
+            # initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
+    elif args.initializer == 'xavier':
+        initializer = mx.init.Xavier()
+    elif args.initializer == 'msra':
+        initializer = mx.init.MSRAPrelu()
+    elif args.initializer == 'orthogonal':
+        initializer = mx.init.Orthogonal()
+    elif args.initializer == 'normal':
+        initializer = mx.init.Normal()
+    elif args.initializer == 'uniform':
+        initializer = mx.init.Uniform()
+    elif args.initializer == 'one':
+        initializer = mx.init.One()
+    elif args.initializer == 'zero':
+        initializer = mx.init.Zero()
 
     # evaluation metrics
     eval_metrics = ['accuracy']
     if args.top_k > 0:
-        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=args.top_k))
+        eval_metrics.append(mx.metric.create(
+            'top_k_accuracy', top_k=args.top_k))
 
     # callbacks that run after each batch
-    batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)]
+    batch_end_callbacks = [mx.callback.Speedometer(
+        args.batch_size, args.disp_batches)]
     if 'batch_end_callback' in kwargs:
         cbs = kwargs['batch_end_callback']
         batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]
 
     # run
     model.fit(train,
-              begin_epoch = args.load_epoch if args.load_epoch else 0,
-              num_epoch = args.num_epochs,
-              eval_data = val,
-              eval_metric = eval_metrics,
-              kvstore = kv,
-              optimizer = args.optimizer,
-              optimizer_params = optimizer_params,
-              initializer = initializer,
-              arg_params = arg_params,
-              aux_params = aux_params,
-              batch_end_callback = batch_end_callbacks,
-              epoch_end_callback = checkpoint,
-              allow_missing = True,
-              monitor = monitor)
+              begin_epoch=args.load_epoch if args.load_epoch else 0,
+              num_epoch=args.num_epochs,
+              eval_data=val,
+              eval_metric=eval_metrics,
+              kvstore=kv,
+              optimizer=args.optimizer,
+              optimizer_params=optimizer_params,
+              initializer=initializer,
+              arg_params=arg_params,
+              aux_params=aux_params,
+              batch_end_callback=batch_end_callbacks,
+              epoch_end_callback=checkpoint,
+              allow_missing=True,
+              monitor=monitor)
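To make the warmup plumbing above concrete, this is roughly what the has_warmup block computes for a hypothetical distributed run; the numbers (4 workers, per-worker batch 256, --macrobatch-size 8192) are illustrative only:

    import math

    num_examples, batch_size, nworkers = 1281167, 256, 4  # hypothetical ImageNet-style run
    macrobatch_size = 8192                                # from --macrobatch-size

    epoch_size = max(num_examples / batch_size / nworkers, 1)
    macrobatch_size = max(macrobatch_size, batch_size * nworkers)
    # how many per-worker batches get folded into one effective macro-batch
    batch_scale = math.ceil(float(macrobatch_size) / batch_size / nworkers)

    print(int(epoch_size), int(batch_scale))  # ~1251 updates/epoch, batch_scale 8

These values land in optimizer_params (batch_scale, updates_per_epoch, warmup_strategy, ...) and are passed straight through to the optimizer by model.fit.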
diff --git a/python/mxnet/lr_scheduler.py b/python/mxnet/lr_scheduler.py
index e4af77a..963560d 100644
--- a/python/mxnet/lr_scheduler.py
+++ b/python/mxnet/lr_scheduler.py
@@ -136,3 +136,35 @@ class MultiFactorScheduler(LRScheduler):
             else:
                 return self.base_lr
         return self.base_lr
+
+class PolyScheduler(LRScheduler):
+    """ Reduce the learning rate according to a polynomial of the given power.
+
+    Calculate the new learning rate by::
+
+       base_lr * (1 - nup / max_nup)^pwr
+       if nup < max_nup, 0 otherwise.
+
+    Parameters
+    ----------
+    max_update: maximum number of updates before the decay reaches 0.
+    base_lr: base learning rate
+    pwr: power of the decay term as a function of the current number of updates.
+
+    """
+
+    def __init__(self, max_update, base_lr=0.01, pwr=2):
+        super(PolyScheduler, self).__init__(base_lr)
+        assert isinstance(max_update, int)
+        if max_update < 1:
+            raise ValueError("maximum number of updates must be strictly positive")
+        self.base_lr_orig = self.base_lr
+        self.max_update = max_update
+        self.power = pwr
+        self.base_lr = self.base_lr_orig
+
+    def __call__(self, num_update):
+        if num_update <= self.max_update:
+            self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
+                                                   self.power)
+        return self.base_lr
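A quick usage sketch for the new PolyScheduler; with pwr=2 the rate decays quadratically from base_lr to 0 over max_update updates (the numbers are illustrative):

    import mxnet as mx

    # decay from 0.1 to 0 over 1000 updates with a power-2 polynomial
    sched = mx.lr_scheduler.PolyScheduler(max_update=1000, base_lr=0.1, pwr=2)

    for nup in (0, 250, 500, 1000):
        print(nup, sched(nup))  # 0.1 * (1 - nup/1000)^2 -> 0.1, 0.05625, 0.025, 0.0

In the example fit.py above, this scheduler is selected by passing something like --lr-step-epochs pow2: the re.sub call strips the 'pow' prefix and uses the remainder as the power.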
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 57340be..c3338f4 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -645,6 +645,195 @@ class FTML(Optimizer):
         ftml_update(weight, grad, prev_d, prev_v, prev_z, out=weight, lr=lr, wd=wd, **kwargs)
 
+@register
+class LBSGD(Optimizer):
+    """The Large Batch SGD optimizer with momentum and weight decay.
+
+    The optimizer updates the weight by::
+
+        state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
+        weight = weight - state
+
+    For details of the update algorithm see :class:`~mxnet.ndarray.lbsgd_update` and
+    :class:`~mxnet.ndarray.lbsgd_mom_update`.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    momentum : float, optional
+        The momentum value.
+    multi_precision: bool, optional
+        Flag to control the internal precision of the optimizer.
+        ``False`` results in using the same precision as the weights (default),
+        ``True`` makes internal 32-bit copy of the weights and applies gradients
+        in 32-bit precision even if actual weights used in the model have lower precision.
+        Turning this on can improve convergence and accuracy when training with float16.
+    warmup_strategy: string ('linear', 'power2', 'sqrt' or 'lars'; default: 'linear')
+    warmup_epochs: unsigned, default: 5
+    batch_scale: unsigned, default: 1 (same as batch size * numworkers)
+    updates_per_epoch: unsigned, default: 32 (the default might not reflect the true
+        number of batches per epoch; used for warmup)
+    begin_epoch: unsigned, default 0, starting epoch.
+    """
+
+    def __init__(self, momentum=0.0, multi_precision=False, warmup_strategy='linear',
+                 warmup_epochs=5, batch_scale=1, updates_per_epoch=32, begin_epoch=0,
+                 num_epochs=60, **kwargs):
+        super(LBSGD, self).__init__(**kwargs)
+        logging.info('Running Large-Batch SGD Algorithm')
+        logging.info('(Batch_scale=%f, warmup_epochs=%d, warmup_strategy=%s, updates_per_epoch=%d)',
+                     batch_scale, warmup_epochs, warmup_strategy, updates_per_epoch)
+        self.momentum = momentum
+        self.multi_precision = multi_precision
+        # new user parameters for large batch
+        self.warmup_strategy = warmup_strategy
+        self.warmup_epochs = warmup_epochs
+        self.batch_scale = batch_scale
+        self.updates_per_epoch = updates_per_epoch
+        self.init_updates = begin_epoch * updates_per_epoch
+        self.num_epochs = num_epochs
+        # addl internal usage parameters and storage
+        self.lbmult = 1
+        self.cumgrads = {}
+        # for adaptive lr
+        self.adaptive = False
+        self.admult = 1  # adaptation constant
+
+    def create_state(self, index, weight):
+        momentum = None
+        weight_master_copy = None
+        if self.multi_precision and weight.dtype == numpy.float16:
+            weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32)
+            if self.momentum != 0.0:
+                momentum = zeros(weight.shape, weight.context, dtype=numpy.float32,
+                                 stype=weight.stype)
+            return (momentum, weight_master_copy)
+        if weight.dtype == numpy.float16 and not self.multi_precision:
+            warnings.warn("Accumulating with float16 in optimizer can lead to "
+                          "poor accuracy or slow convergence. "
+                          "Consider using multi_precision=True option of the "
+                          "SGD optimizer")
+        if self.momentum != 0.0:
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+        return momentum
+
+    def _get_lbmult(self, nup):
+        """Returns lr scaling factor for large batch according to warmup schedule"""
+        nwup = self.warmup_epochs * self.updates_per_epoch
+        strategy = self.warmup_strategy
+        maxmult = float(self.batch_scale)
+        if nup >= nwup:
+            mult = maxmult
+        elif nwup <= 1:
+            mult = 1.0
+        else:
+            if (strategy == 'linear'):
+                mult = 1.0 + (maxmult - 1) * nup / nwup
+            elif (strategy == 'power2'):
+                mult = 1.0 + (maxmult-1) * (nup*nup)/(nwup*nwup)
+            elif (strategy == 'sqrt'):
+                mult = 1.0 + (maxmult - 1) * math.sqrt(float(nup) / nwup)
+            else:
+                mult = 1.0
+        return mult
+
+    def _get_lars(self, weight, g, wd):
+        """Returns a scaling factor for the learning rate for this layer,
+        default is 1
+        """
+        weight2 = self._l2norm(weight)
+        grad2 = self._l2norm(g)
+        lars = math.sqrt(weight2 / (grad2 + wd * weight2 + 1e-18))
+        if lars < 0.01:
+            lars = 0.01
+        elif lars > 100:
+            lars = 100
+        return lars
+
+    def _l2norm(self, v):
+        "returns the squared L2 norm (inner product of v with itself)"
+        norm = multiply(v, v).asnumpy().sum()
+        return norm
+
+    def _reset_cum_gradient(self, index):
+        "called every macro-batch to reset cumulated gradients to 0 for a given index"
+        self.cumgrads[index]['cum_grad'] = 0
+
+    def _get_cum_gradient(self, index):
+        "get the cumulated gradient for index"
+        if index in self.cumgrads:
+            return self.cumgrads[index]
+        else:
+            return {}
+
+    def _put_cum_gradient(self, index, cgrad):
+        "store cumulated gradient for index"
+        self.cumgrads[index] = cgrad
+
+    def _cumulate_gradient(self, grad, index):
+        "Cumulate gradients for large-batch emulation. Cumulated by index (layer)"
+        cgrad = self._get_cum_gradient(index)
+        if cgrad:
+            num_cums = cgrad['num_cums']
+            if num_cums > 0:
+                cum_grad = cgrad['cum_grad'] + grad
+                num_cums += 1
+            else:
+                cum_grad = grad
+                num_cums = self.init_updates + 1
+        else:
+            cum_grad = grad
+            num_cums = self.init_updates + 1
+        cgrad = {'cum_grad': cum_grad, 'num_cums': num_cums}
+        self._put_cum_gradient(index, cgrad)
+        return cgrad
+
+    def update(self, index, weight, grad, state):
+        assert (isinstance(weight, NDArray))
+        assert (isinstance(grad, NDArray))
+
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+        self._update_count(index)
+
+        # new stuff for large batch
+        cgrad = self._cumulate_gradient(grad, index)
+        if (cgrad['num_cums'] % self.batch_scale) == 0:
+            grad = cgrad['cum_grad'] / self.batch_scale
+            if self.warmup_strategy == 'lars':
+                lbmult = self._get_lars(weight, grad, wd)
+            else:
+                lbmult = self._get_lbmult(cgrad['num_cums'])
+            lr = lr * lbmult
+            # do the regular sgd update flow
+            kwargs = {'rescale_grad': self.rescale_grad}
+            if self.momentum > 0:
+                kwargs['momentum'] = self.momentum
+            if self.clip_gradient:
+                kwargs['clip_gradient'] = self.clip_gradient
+            use_multi_precision = isinstance(state, (list, tuple))
+
+            if not use_multi_precision:
+                if state is not None:
+                    sgd_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
+                else:
+                    sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
+            else:
+                if state[0] is not None:
+                    mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, lr=lr, wd=wd,
+                                      **kwargs)
+                else:
+                    mp_sgd_update(weight, grad, state[1], out=weight, lr=lr, wd=wd, **kwargs)
+            # reset update count and cumulated gradient per large batch
+            self._reset_cum_gradient(index)
+        else:
+            lr = 0.0
+            kwargs = {}
+            sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
+
 # pylint: enable=line-too-long
 @register
 class DCASGD(Optimizer):
@@ -1282,6 +1471,7 @@ class Updater(object):
             self.optimizer.update_multi_precision(index, weight, grad, self.states[index])
 
     def sync_state_context(self, state, context):
+        """sync state context."""
         if isinstance(state, NDArray):
             return state.as_in_context(context)
         elif isinstance(state, (tuple, list)):
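A minimal construction sketch for the new optimizer; the parameter values are illustrative, and in the example script they are derived from the CLI flags as shown in fit.py above:

    import mxnet as mx

    # LBSGD is @register'ed, so model.fit(..., optimizer='lbsgd', ...) also works
    opt = mx.optimizer.LBSGD(learning_rate=0.1, momentum=0.9,
                             warmup_strategy='linear',  # or 'power2', 'sqrt', 'lars'
                             warmup_epochs=5,
                             batch_scale=8,             # per-worker batches per macro-batch
                             updates_per_epoch=1251,    # true batches/epoch, not the default 32
                             begin_epoch=0, num_epochs=60)

With warmup_strategy='lars', the warmup multiplier is replaced by a per-layer factor of sqrt(||w||^2 / (||g||^2 + wd * ||w||^2 + 1e-18)), clipped to [0.01, 100], as computed by _get_lars above.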
--
To stop receiving notification emails like this one, please contact
zhresh...@apache.org.