zhreshold commented on a change in pull request #8005: add warmup lr_scheduler
URL: https://github.com/apache/incubator-mxnet/pull/8005#discussion_r141697319
 
 

 ##########
 File path: tests/python/unittest/test_lr_scheduler.py
 ##########
 @@ -0,0 +1,57 @@
+import logging
+import mxnet as mx 
+import mxnet.optimizer as opt              
+
def test_lr_sceduler(lr=0.01, steps=None, lr_factor=0.1, warmup_step=0, warmup_lr=0):
    """Check MultiFactorScheduler learning-rate decay, with optional warmup.

    Runs an SGD optimizer for a few steps past the last decay boundary and
    asserts that the observed learning rate matches the expected schedule.

    Parameters
    ----------
    lr : float
        Base learning rate (default 0.01).
    steps : list of int or None
        Decay boundaries; defaults to [8, 12] when not given, so the test
        can be called with no arguments (or steps only).
    lr_factor : float
        Multiplicative decay factor applied at each boundary.
    warmup_step : int
        Number of warmup steps; 0 disables warmup.
    warmup_lr : float
        Target learning rate at the end of warmup; must exceed `lr` for
        warmup to be enabled.
    """
    logging.basicConfig(level=logging.DEBUG)

    # None sentinel instead of a mutable default argument.
    if steps is None:
        steps = [8, 12]

    # Warmup only kicks in when it actually ramps up toward a higher lr.
    if warmup_step > 0 and warmup_lr > lr:
        lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(
            step=steps, factor=lr_factor,
            warmup_step=warmup_step, begin_lr=lr, stop_lr=warmup_lr)
    else:
        lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(
            step=steps, factor=lr_factor)

    optimizer = opt.create('sgd', learning_rate=lr, lr_scheduler=lr_scheduler)
    updater = opt.get_updater(optimizer)

    # A small deterministic weight/gradient pair just to drive updates.
    x = mx.nd.array([[[[i * 10 + j for j in range(10)] for i in range(10)]]],
                    dtype='float32')
    y = mx.nd.ones(shape=x.shape, dtype='float32')

    # Record the lr seen at every update, a few steps past the last boundary.
    res_lr = []
    for i in range(1, steps[-1] + 5):
        updater(0, y, x)
        cur_lr = optimizer._get_lr(0)
        res_lr.append(cur_lr)
        logging.info("step %d lr = %f", i, cur_lr)

    if warmup_step > 1:
        # After warmup completes the lr must have reached warmup_lr, and the
        # subsequent decay schedule starts from warmup_lr rather than lr.
        assert mx.test_utils.almost_equal(res_lr[warmup_step], warmup_lr, 1e-10)
        lr = warmup_lr
    for i in range(len(steps)):
        assert mx.test_utils.almost_equal(
            res_lr[steps[i]], lr * pow(lr_factor, i + 1), 1e-10)
+
if __name__ == "__main__":
    # Legal inputs: one case without warmup, three with increasing warmup.
    cases = [
        dict(lr=0.2, steps=[8, 12], lr_factor=0.1, warmup_step=0, warmup_lr=0.1),
        dict(lr=0.02, steps=[8, 12], lr_factor=0.1, warmup_step=1, warmup_lr=0.1),
        dict(lr=0.02, steps=[8, 12], lr_factor=0.3, warmup_step=5, warmup_lr=0.1),
        dict(lr=0.002, steps=[8, 12], lr_factor=0.1, warmup_step=7, warmup_lr=0.1),
    ]
    for kwargs in cases:
        test_lr_sceduler(**kwargs)
 
 Review comment:
   add a default test like: test_lr_sceduler(steps=[100, 200]) to test the 
default parameters if the user doesn't specify anything.
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to