hetong007 commented on a change in pull request #14023: [R] Add NAG optimizer
URL: https://github.com/apache/incubator-mxnet/pull/14023#discussion_r252482701
 
 

 ##########
 File path: R-package/tests/testthat/test_optimizer.R
 ##########
 @@ -164,22 +170,82 @@ test_that("adadelta", {
   y <- mx.nd.array(c(5, 11, 16))
   w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))
   
-  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = 
list(data = x, 
+  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), 
arg.arrays = list(data = x, 
     fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", 
"write", 
     "null"))
   
   optimizer <- mx.opt.create("adadelta", rho = 0.9, epsilon = 1e-05, wd = 0, 
rescale.grad = 1, 
     clip_gradient = -1)
   
-  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = 
mx.cpu())
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = 
mx.ctx.default())
   
   mx.exec.forward(exec, is.train = T)
   mx.exec.backward(exec)
   
   arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
   mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
-  
+
   expect_equal(as.array(arg.blocks[[2]]), array(c(1.11, 1.81), dim = c(2, 1)), 
     tolerance = 0.1)
   
 })
+
+
+test_that("nag_no_momentum", {
+  data <- mx.symbol.Variable("data")
+  label <- mx.symbol.Variable("label")
+  fc_weight <- mx.symbol.Variable("fc_weight")
+  fc <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T,
+       name = "fc1", num_hidden = 1)
+  loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = 
"loss")
+
+  x <- mx.nd.array(array(1:6, dim = 2:3))
+       y <- mx.nd.array(c(5, 11, 16))
+       w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))
+
+       exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), 
arg.arrays = list(data = x,
+    fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", 
"write", "null"))
+
+  optimizer <- mx.opt.create("nag", learning.rate = 1, momentum = 0, wd = 0, 
rescale.grad = 1,
+         clip_gradient = -1)
+
+       updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = 
mx.ctx.default())
+       
+  mx.exec.forward(exec, is.train = T)
+       mx.exec.backward(exec)
+               
+  arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
+       mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
+               
+  expect_equal(as.array(arg.blocks[[2]]), array(c(1.4, 2.6), dim = c(2, 1)), 
tolerance = 0.05)
+})
+
+
+test_that("nag_momentum", {
+  data <- mx.symbol.Variable("data")
+  label <- mx.symbol.Variable("label")
+  fc_weight <- mx.symbol.Variable("fc_weight")
+  fc <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T,
+                                 name = "fc1", num_hidden = 1)
+  loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = 
"loss")
+  
+  x <- mx.nd.array(array(1:6, dim = 2:3))
+  y <- mx.nd.array(c(5, 11, 16))
+  w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))
+  
+  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), 
arg.arrays = list(data = x,
+                                                                               
           fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = 
c("null", "write", "null"))
+  
+  optimizer <- mx.opt.create("nag", learning.rate = 1, momentum = 0.1, wd = 0, 
rescale.grad = 1,
+                             clip_gradient = 5)
+  
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = 
mx.ctx.default())
+  
+  mx.exec.forward(exec, is.train = T)
+  mx.exec.backward(exec)
+  
+  arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
+  mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
+  
+  expect_equal(as.array(arg.blocks[[2]]), array(c(1.45, 2.65), dim = c(2, 1)), 
tolerance = 0.1)
 
 Review comment:
   How do we come up with the coefficients `c(1.45, 2.65)`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to