rahul003 commented on a change in pull request #11234: [MXNET-535] Fix bugs in
LR Schedulers and add warmup
URL: https://github.com/apache/incubator-mxnet/pull/11234#discussion_r211042880
##########
File path: python/mxnet/lr_scheduler.py
##########
@@ -138,33 +171,73 @@ def __call__(self, num_update):
return self.base_lr
class PolyScheduler(LRScheduler):
+ """ Reduce the learning rate according to a polynomial of given power.
+
+ Calculate the new learning rate by::
+
+ final_lr + (start_lr - final_lr) * (1-nup/max_nup)^pwr
+ if nup < max_nup, final_lr otherwise.
+
+ Parameters
+ ----------
+ max_update: maximum number of updates before the decay reaches final
learning rate.
+ base_lr: base learning rate to start from
+ pwr: power of the decay term as a function of the current number of
updates.
+ final_lr: final learning rate after all steps
+ warmup_steps: number of warmup steps used before this scheduler starts
decay
+ """
+
+ def __init__(self, max_update, base_lr=0.01, pwr=2, final_lr=0,
+ warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
+ super(PolyScheduler, self).__init__(base_lr, warmup_steps,
warmup_begin_lr, warmup_mode)
+ assert isinstance(max_update, int)
+ if max_update < 1:
+ raise ValueError("maximum number of updates must be strictly
positive")
+ self.power = pwr
+ self.base_lr_orig = self.base_lr
+ self.max_update = max_update
+ self.final_lr = final_lr
+ self.max_steps = self.max_update - self.warmup_steps
+
+ def __call__(self, num_update):
+ if num_update < self.warmup_steps:
+ return self.get_warmup_lr(num_update)
+ if num_update <= self.max_update:
+ self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr)
* \
+ pow(1 - float(num_update - self.warmup_steps) /
float(self.max_steps), self.power)
+ return self.base_lr
+
+class CosineScheduler(LRScheduler):
""" Reduce the learning rate given a schedule of steps.
Calculate the new learning rate by::
- base_lr * (1-nup/max_nup)^pwr
+ final_lr + (start_lr - final_lr) * (1+cos(pi * nup/max_nup))/2
+ if nup < max_nup, final_lr otherwise.
Parameters
----------
- max_update: maximum number of updates before the decay reaches 0.
+ max_update: maximum number of updates before the decay reaches 0
base_lr: base learning rate
- pwr: power of the decay term as a function of the current number of
updates.
-
+ final_lr: final learning rate after all steps
+ warmup_steps: number of warmup steps used before this scheduler starts
decay
"""
- def __init__(self, max_update, base_lr=0.01, pwr=2):
- super(PolyScheduler, self).__init__(base_lr)
+ def __init__(self, max_update, base_lr=0.01, final_lr=0,
Review comment:
Please refer to
https://github.com/apache/incubator-mxnet/pull/11234#issuecomment-413998969
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services