hetong007 closed pull request #11374: [MXNET-563] Refactor R optimizers to fix 
memory leak
URL: https://github.com/apache/incubator-mxnet/pull/11374
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/R-package/R/model.R b/R-package/R/model.R
index b461f7973f6..a2c441968bc 100644
--- a/R-package/R/model.R
+++ b/R-package/R/model.R
@@ -147,7 +147,7 @@ mx.model.train <- function(symbol, ctx, input.shape, 
output.shape,
     kvstore$set.optimizer(optimizer)
   } else {
     updaters <- lapply(seq_len(ndevice), function(i) {
-      mx.opt.get.updater(optimizer, train.execs[[i]]$ref.arg.arrays)
+      mx.opt.get.updater(optimizer, train.execs[[i]]$ref.arg.arrays, ctx = 
ctx[[i]])
     })
   }
   if (!is.null(kvstore)) {
diff --git a/R-package/R/model.rnn.R b/R-package/R/model.rnn.R
index f328d1ba6b7..580c82a0a65 100644
--- a/R-package/R/model.rnn.R
+++ b/R-package/R/model.rnn.R
@@ -1,51 +1,50 @@
 # Internal function to do multiple device training on RNN
-mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data, 
-                                   dlist, arg.params, aux.params, 
-                                   grad.req, arg.update.idx, 
+mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
+                                   dlist, arg.params, aux.params,
+                                   grad.req, arg.update.idx,
                                    begin.round, end.round, optimizer, metric, 
metric_cpu,
-                                   epoch.end.callback, batch.end.callback, 
kvstore, verbose,
-                                   gc_freq) {
-  
+                                   epoch.end.callback, batch.end.callback, 
kvstore, verbose) {
+
   ndevice <- length(ctx)
-  if (verbose) 
+  if (verbose)
     message("Start training with ", ndevice, " devices")
-  
+
   input.names <- names(dlist)
   arg.params.names <- names(arg.params)
-  
+
   if (is.list(symbol)) sym_ini <- symbol[[names(train.data$bucketID)]] else 
sym_ini <- symbol
-  
+
   slices <- lapply(seq_len(ndevice), function(i) {
     sapply(names(dlist), function(n) mx.nd.split(data=dlist[[n]], num_outputs 
= ndevice, axis = 0, squeeze_axis = FALSE))
   })
-  
+
   train.execs <- lapply(seq_len(ndevice), function(i) {
     s <- slices[[i]]
-    mx.symbol.bind(symbol = sym_ini, arg.arrays = c(s, 
arg.params)[arg.update.idx], 
+    mx.symbol.bind(symbol = sym_ini, arg.arrays = c(s, 
arg.params)[arg.update.idx],
                    aux.arrays = aux.params, ctx = ctx[[i]], grad.req = 
grad.req)
   })
-  
+
   # KVStore related stuffs
   params.index <- as.integer(
     mx.util.filter.null(
       lapply(seq_along(train.execs[[1]]$ref.grad.arrays), function(k) {
         if (!is.null(train.execs[[1]]$ref.grad.arrays[[k]])) k else NULL}
       )))
-  
+
   update.on.kvstore <- FALSE
   if (!is.null(kvstore) && kvstore$update.on.kvstore) {
     update.on.kvstore <- TRUE
     kvstore$set.optimizer(optimizer)
   } else {
     updaters <- lapply(seq_len(ndevice), function(i) {
-      mx.opt.get.updater(optimizer, train.execs[[i]]$ref.arg.arrays)
+      mx.opt.get.updater(optimizer, train.execs[[i]]$ref.arg.arrays, ctx = 
ctx[[i]])
     })
   }
-  
+
   if (!is.null(kvstore)) {
     kvstore$init(params.index, train.execs[[1]]$ref.arg.arrays[params.index])
   }
-  
+
   # train over specified number of epochs
   for (iteration in begin.round:end.round) {
     nbatch <- 0
@@ -55,20 +54,20 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, 
eval.data,
     }
     train.data$reset()
     while (train.data$iter.next()) {
-      
+
       # Get iterator data
       dlist <- train.data$value()[input.names]
-      
+
       # Slice inputs for multi-devices
       slices <- lapply(seq_len(ndevice), function(i) {
         sapply(names(dlist), function(n) mx.nd.split(data=dlist[[n]], 
num_outputs = ndevice, axis = 0, squeeze_axis = F))
       })
-      
+
       # Assign input to each executor - bug on inference if using BatchNorm
       if (is.list(symbol)) {
         train.execs <- lapply(seq_len(ndevice), function(i) {
           s <- slices[[i]]
-          mx.symbol.bind(symbol = symbol[[names(train.data$bucketID)]], 
+          mx.symbol.bind(symbol = symbol[[names(train.data$bucketID)]],
                          arg.arrays = c(s, 
train.execs[[i]]$arg.arrays[arg.params.names])[arg.update.idx],
                          aux.arrays = train.execs[[i]]$aux.arrays, ctx = 
ctx[[i]], grad.req = grad.req)
         })
@@ -78,12 +77,12 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, 
eval.data,
           mx.exec.update.arg.arrays(train.execs[[i]], s, match.name=TRUE)
         }
       }
-      
+
       # forward pass
       for (texec in train.execs) {
         mx.exec.forward(texec, is.train = TRUE)
       }
-      
+
       # copy of preds and labels for metric
       if (!is.null(metric)) {
         preds <- lapply(train.execs, function(texec) {texec$ref.outputs[[1]]})
@@ -93,12 +92,12 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, 
eval.data,
           labels <- lapply(seq_along(train.execs), function(i) 
{mx.nd.copyto(labels[[i]], mx.cpu())})
         }
       }
-      
+
       # backward pass
       for (texec in train.execs) {
         mx.exec.backward(texec)
       }
-      
+
       if (!is.null(kvstore)) {
         # push the gradient
         kvstore$push(params.index, lapply(train.execs, function(texec) {
@@ -124,7 +123,7 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, 
eval.data,
           mx.exec.update.arg.arrays(train.execs[[i]], arg.blocks[[i]], 
skip.null = TRUE)
         }
       }
-      
+
       # Update the evaluation metrics
       if (!is.null(metric)) {
         for (i in seq_len(ndevice)) {
@@ -133,43 +132,40 @@ mx.model.train.buckets <- function(symbol, ctx, 
train.data, eval.data,
                                         state = train.metric)
         }
       }
-      
+
       nbatch <- nbatch + 1
-      if (!is.null(gc_freq)) {
-        if (nbatch %% gc_freq == 0) gc()
-      }
-      
+
       if (!is.null(batch.end.callback)) {
         batch.end.callback(iteration, nbatch, environment())
       }
     }
-    
+
     if (!is.null(metric)) {
       result <- metric$get(train.metric)
-      if (verbose) 
+      if (verbose)
         message("[", iteration, "] Train-", result$name, "=", result$value)
     }
-    
+
     if (!is.null(eval.data)) {
       if (!is.null(metric)) {
         eval.metric <- metric$init()
       }
       eval.data$reset()
       while (eval.data$iter.next()) {
-        
+
         # Get iterator data
         dlist <- eval.data$value()[input.names]
-        
+
         # Slice input to multiple devices
         slices <- lapply(seq_len(ndevice), function(i) {
           sapply(names(dlist), function(n) mx.nd.split(data=dlist[[n]], 
num_outputs = ndevice, axis = 0, squeeze_axis = FALSE))
         })
-        
+
         # Assign input to each executor - bug on inference if using BatchNorm
         if (is.list(symbol)) {
           train.execs <- lapply(seq_len(ndevice), function(i) {
             s <- slices[[i]]
-            mx.symbol.bind(symbol = symbol[[names(eval.data$bucketID)]], 
+            mx.symbol.bind(symbol = symbol[[names(eval.data$bucketID)]],
                            arg.arrays = c(s, 
train.execs[[i]]$arg.arrays[arg.params.names])[arg.update.idx],
                            aux.arrays = train.execs[[i]]$aux.arrays, ctx = 
ctx[[i]], grad.req = grad.req)
           })
@@ -179,12 +175,12 @@ mx.model.train.buckets <- function(symbol, ctx, 
train.data, eval.data,
             mx.exec.update.arg.arrays(train.execs[[i]], s, match.name=TRUE)
           }
         }
-        
+
         # forward pass
         for (texec in train.execs) {
           mx.exec.forward(texec, is.train = FALSE)
         }
-        
+
         # copy of preds and labels for metric and update metric
         if (!is.null(metric)) {
           preds <- lapply(train.execs, function(texec) 
{texec$ref.outputs[[1]]})
@@ -194,17 +190,17 @@ mx.model.train.buckets <- function(symbol, ctx, 
train.data, eval.data,
             labels <- lapply(seq_along(train.execs), function(i) 
{mx.nd.copyto(labels[[i]], mx.cpu())})
           }
           for (i in seq_len(ndevice)) {
-            eval.metric <- metric$update(label = labels[[i]], 
-                                         pred = preds[[i]], 
+            eval.metric <- metric$update(label = labels[[i]],
+                                         pred = preds[[i]],
                                          state = eval.metric)
           }
         }
       }
-      
+
       if (!is.null(metric)) {
         result <- metric$get(eval.metric)
         if (verbose) {
-          message("[", iteration, "] Validation-", result$name, "=", 
+          message("[", iteration, "] Validation-", result$name, "=",
                   result$value)
         }
       }
@@ -213,12 +209,12 @@ mx.model.train.buckets <- function(symbol, ctx, 
train.data, eval.data,
     }
     # get the model out
     model <- mx.model.extract.model(sym_ini, train.execs)
-    
+
     epoch_continue <- TRUE
     if (!is.null(epoch.end.callback)) {
       epoch_continue <- epoch.end.callback(iteration, 0, environment(), 
verbose = verbose)
     }
-    
+
     if (!epoch_continue) {
       break
     }
@@ -227,7 +223,7 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, 
eval.data,
 }
 
 
-# 
+#
 #' Train RNN with bucket support
 #'
 #' @param symbol Symbol or list of Symbols representing the model
@@ -245,33 +241,33 @@ mx.model.train.buckets <- function(symbol, ctx, 
train.data, eval.data,
 #' @param verbose
 #'
 #' @export
-mx.model.buckets <- function(symbol, train.data, eval.data = NULL, metric = 
NULL, 
-                             arg.params = NULL, aux.params = NULL, 
fixed.params = NULL, 
-                             num.round = 1, begin.round = 1, 
-                             initializer = mx.init.uniform(0.01), optimizer = 
"sgd", ctx = NULL, 
-                             batch.end.callback = NULL, epoch.end.callback = 
NULL, 
-                             kvstore = "local", verbose = TRUE, metric_cpu = 
TRUE, gc_freq = NULL) {
-  
+mx.model.buckets <- function(symbol, train.data, eval.data = NULL, metric = 
NULL,
+                             arg.params = NULL, aux.params = NULL, 
fixed.params = NULL,
+                             num.round = 1, begin.round = 1,
+                             initializer = mx.init.uniform(0.01), optimizer = 
"sgd", ctx = NULL,
+                             batch.end.callback = NULL, epoch.end.callback = 
NULL,
+                             kvstore = "local", verbose = TRUE, metric_cpu = 
TRUE) {
+
   if (!train.data$iter.next()) {
     train.data$reset()
-    if (!train.data$iter.next()) 
+    if (!train.data$iter.next())
       stop("Empty train.data")
   }
-  
+
   if (!is.null(eval.data)) {
     if (!eval.data$iter.next()) {
       eval.data$reset()
-      if (!eval.data$iter.next()) 
+      if (!eval.data$iter.next())
         stop("Empty eval.data")
     }
   }
-  
-  if (is.null(ctx)) 
+
+  if (is.null(ctx))
     ctx <- mx.ctx.default()
   if (is.mx.context(ctx)) {
     ctx <- list(ctx)
   }
-  if (!is.list(ctx)) 
+  if (!is.list(ctx))
     stop("ctx must be mx.context or list of mx.context")
   if (is.character(optimizer)) {
     if (is.numeric(input.shape)) {
@@ -283,75 +279,75 @@ mx.model.buckets <- function(symbol, train.data, 
eval.data = NULL, metric = NULL
     }
     optimizer <- mx.opt.create(optimizer, rescale.grad = (1/batchsize), ...)
   }
-  
+
   sym_ini <- if (is.list(symbol)) symbol[[names(train.data$bucketID)]] else 
symbol
-  
+
   arguments <- sym_ini$arguments
   input.names <- intersect(names(train.data$value()), arguments)
-  
+
   input.shape <- sapply(input.names, function(n) {
     dim(train.data$value()[[n]])
   }, simplify = FALSE)
-  
+
   shapes <- sym_ini$infer.shape(input.shape)
-  
+
   # assign arg.params and aux.params arguments to arg.params.input and 
aux.params.input
   arg.params.input <- arg.params
   aux.params.input <- aux.params
-  
+
   # initialize all arguments with zeros
   arg.params <- lapply(shapes$arg.shapes, function(shape) {
     mx.nd.zeros(shape = shape, ctx = mx.cpu())
   })
-  
+
   # initialize input parameters
   dlist <- arg.params[input.names]
-  
+
   # initialize parameters - only argument ending with _weight and _bias are 
initialized
   arg.params.ini <- mx.init.create(initializer = initializer, shape.array = 
shapes$arg.shapes, ctx = mx.cpu(), skip.unknown = TRUE)
-  
+
+  # assign initialized parameters to arg.params
   arg.params[names(arg.params.ini)] <- arg.params.ini
-  
+
   # assign input params to arg.params
   arg.params[names(arg.params.input)] <- arg.params.input
-  
+
   # remove input params from arg.params
   arg.params[input.names] <- NULL
-  
+
   # Grad request
   grad.req <- rep("null", length(arguments))
   grad.req.write <- arguments %in% setdiff(names(arg.params.ini), fixed.params)
   grad.req[grad.req.write] <- "write"
-  
+
   # Arg array order
   update_names <- c(input.names, names(arg.params))
   arg.update.idx <- match(arguments, update_names)
-  
+
   # aux parameters setup
   aux.params <- lapply(shapes$aux.shapes, function(shape) {
     mx.nd.zeros(shape = shape, ctx = mx.cpu())
   })
-  
+
   aux.params.ini <- mx.init.create(initializer, shapes$aux.shapes, ctx = 
mx.cpu(), skip.unknown = FALSE)
   if (length(aux.params) > 0) {
     aux.params[names(aux.params.ini)] <- aux.params.ini
   } else aux.params <- NULL
-  
+
   aux.params[names(aux.params.input)] <- aux.params.input
-  
+
   # kvstore initialization
-  kvstore <- mx.model.create.kvstore(kvstore, params$arg.params, length(ctx), 
+  kvstore <- mx.model.create.kvstore(kvstore, params$arg.params, length(ctx),
                                      verbose = verbose)
-  
+
   ### Execute training
-  model <- mx.model.train.buckets(symbol = symbol, ctx = ctx,  train.data = 
train.data, eval.data = eval.data, 
-                                  dlist = dlist,  arg.params = arg.params, 
aux.params = aux.params, 
-                                  grad.req = grad.req, arg.update.idx = 
arg.update.idx, 
-                                  optimizer = optimizer, metric = metric, 
-                                  begin.round = begin.round, end.round = 
num.round, 
-                                  batch.end.callback = batch.end.callback, 
epoch.end.callback = epoch.end.callback, 
-                                  kvstore = kvstore, verbose = verbose, 
metric_cpu = metric_cpu, gc_freq = gc_freq)
-  
+  model <- mx.model.train.buckets(symbol = symbol, ctx = ctx,  train.data = 
train.data, eval.data = eval.data,
+                                  dlist = dlist,  arg.params = arg.params, 
aux.params = aux.params,
+                                  grad.req = grad.req, arg.update.idx = 
arg.update.idx,
+                                  optimizer = optimizer, metric = metric,
+                                  begin.round = begin.round, end.round = 
num.round,
+                                  batch.end.callback = batch.end.callback, 
epoch.end.callback = epoch.end.callback,
+                                  kvstore = kvstore, verbose = verbose, 
metric_cpu = metric_cpu)
+
   return(model)
 }
diff --git a/R-package/R/optimizer.R b/R-package/R/optimizer.R
index 3c503c2e855..7283f677fe4 100644
--- a/R-package/R/optimizer.R
+++ b/R-package/R/optimizer.R
@@ -1,31 +1,69 @@
 #' Create an SGD optimizer with respective parameters.
 #' Perform SGD with momentum update
 #'
-mx.opt.sgd <- function(learning.rate,
-                       momentum=0,
-                       wd=0,
-                       rescale.grad=1,
-                       clip_gradient = NULL, 
+#' @param learning.rate float, default=0.01
+#'      The initial learning rate.
+#' @param momentum float, default=0
+#'      The momentum value.
+#' @param wd float, default=0.0
+#'      L2 regularization coefficient add to all the weights.
+#' @param rescale.grad float, default=1.0
+#'      rescaling factor of gradient.
+#' @param clip_gradient float, optional, default=-1 (no clipping if < 0)
+#'      clip gradient in range [-clip_gradient, clip_gradient].
+#' @param lr_scheduler function, optional
+#'      The learning rate scheduler.
+mx.opt.sgd <- function(learning.rate = 0.01,
+                       momentum = 0,
+                       wd = 0,
+                       rescale.grad = 1,
+                       clip_gradient = -1,
                        lr_scheduler = NULL) {
-  # use lr as short for learing rate.
+
   lr <- learning.rate
-  count       <- 0
-  num_update  <- 0
+  count <- 0
+  num_update <- 0
 
   sgd <- new.env()
   sgd$lr <- lr
   sgd$count <- 0
   sgd$num_update <- 0
 
-  create.state <- function(index, weight) {
+  create_exec <- function(index, weight_dim, ctx) {
+
     if (momentum == 0) {
-      return(NULL)
+
+      weight <- mx.symbol.Variable("weight")
+      grad <- mx.symbol.Variable("grad")
+
+      sym <- mx.symbol.sgd_update(weight,
+                                  grad,
+                                  lr = lr,
+                                  wd = wd,
+                                  rescale_grad = rescale.grad,
+                                  clip_gradient = clip_gradient,
+                                  name = "w")
     } else {
-      ret <- (mx.nd.zeros(dim(weight), ctx(weight)))
-      return(ret)
+
+      weight <- mx.symbol.Variable("weight")
+      grad <- mx.symbol.Variable("grad")
+      mom <- mx.symbol.Variable("mom")
+
+      sym <- mx.symbol.sgd_mom_update(weight,
+                                      grad,
+                                      mom,
+                                      lr = lr,
+                                      wd = wd,
+                                      momentum= momentum,
+                                      rescale_grad = rescale.grad,
+                                      clip_gradient = clip_gradient,
+                                      name = "w")
     }
+    exec <- mx.simple.bind(symbol = sym, weight = weight_dim, ctx = ctx, 
grad.req = "null")
+    return(exec)
   }
-  update <- function(index, weight, grad, state) {
+
+  update <- function(index, exec_w, weight, grad) {
 
     if (!is.null(lr_scheduler)){
       lr_scheduler(sgd) ## changing lr
@@ -40,77 +78,104 @@ mx.opt.sgd <- function(learning.rate,
         sgd$num_update <- max(sgd$num_update, sgd[[indexKey]])
       }
     }
-    grad <- grad * rescale.grad
-    if (!is.null(clip_gradient)){
-      if(clip_gradient >= 0){
-        grad <- mx.nd.clip(grad, -clip_gradient, clip_gradient)
-      } else {
-        stop("Error: clip_gradient should be positive number.")
-      }
-    }
-    if (is.null(state)) {
-      weight <- weight - lr * (grad + wd * weight)
-    } else {
-      mom <- state
-      mom <- mom * momentum
-      mom <- mom - lr * (grad + wd * weight)
-      weight <- weight + mom
-      state <- mom
-    }
-    return(list(weight=weight, state=state))
+
+    mx.exec.update.arg.arrays(exec_w, arg.arrays = list(weight = weight,grad = 
grad), match.name = T)
+    mx.exec.forward(exec_w, is.train = F)
+    return(exec_w$ref.outputs$w_output)
   }
-  return(list(create.state=create.state, update=update))
+  return(list(create_exec = create_exec, update = update))
 }
 
 #' Create an RMSProp optimizer with respective parameters.
 #' Reference: Tieleman T, Hinton G. Lecture 6.5-rmsprop: Divide the gradient 
by a running average of its recent magnitude[J]. COURSERA: Neural Networks for 
Machine Learning, 2012, 4(2).
 #' The code follows: http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) by 
Alex Graves, 2013.
-#' 
+#'
 #' @param learning.rate float, default=0.002
-#'      Step size.
+#'      The initial learning rate.
 #' @param gamma1 float, default=0.95
 #'      decay factor of moving average for gradient, gradient^2.
-#' @param gamm2 float, default=0.9
+#' @param gamma2 float, default=0.9
 #'      "momentum" factor.
+#' @param epsilon float, default=1e-4
 #' @param wd float, default=0.0
 #'      L2 regularization coefficient add to all the weights.
 #' @param rescale.grad float, default=1.0
 #'      rescaling factor of gradient.
-#' @param clip_gradient float, optional
+#' @param clip_gradient float, optional, default=-1 (no clipping if < 0)
 #'      clip gradient in range [-clip_gradient, clip_gradient].
 #' @param lr_scheduler function, optional
 #'      The learning rate scheduler.
 #'
-mx.opt.rmsprop <- function(learning.rate=0.002,
-                           gamma1=0.95,
-                           gamma2=0.9,
-                           wd=0,
-                           rescale.grad=1,
-                           clip_gradient = NULL, 
+mx.opt.rmsprop <- function(learning.rate = 0.002,
+                           centered = TRUE,
+                           gamma1 = 0.95,
+                           gamma2 = 0.9,
+                           epsilon = 1e-4,
+                           wd = 0,
+                           rescale.grad = 1,
+                           clip_gradient = -1,
                            lr_scheduler = NULL) {
-  # use lr as short for learing rate.
+
   lr <- learning.rate
-  count       <- 0
-  num_update  <- 0
+  count <- 0
+  num_update <- 0
 
   rmsprop <- new.env()
   rmsprop$lr <- lr
   rmsprop$count <- 0
   rmsprop$num_update <- 0
 
-  create.state <- function(index, weight) {
-      return (list(n=mx.nd.zeros(dim(weight), ctx(weight)),
-                   g=mx.nd.zeros(dim(weight), ctx(weight)),
-                   delta=mx.nd.zeros(dim(weight), ctx(weight))))
+  create_exec <- function(index, weight_dim, ctx) {
+
+    if (centered) {
+
+      weight <- mx.symbol.Variable("weight")
+      grad <- mx.symbol.Variable("grad")
+      n <- mx.symbol.Variable("n")
+      g <- mx.symbol.Variable("g")
+      delta <- mx.symbol.Variable("delta")
+
+      sym <- mx.symbol.rmspropalex_update(weight,
+                                          grad,
+                                          n,
+                                          g,
+                                          delta,
+                                          lr = lr,
+                                          gamma1 = gamma1,
+                                          gamma2 = gamma2,
+                                          epsilon = epsilon,
+                                          wd = wd,
+                                          rescale_grad = rescale.grad,
+                                          clip_gradient = clip_gradient,
+                                          name = "w")
+    } else {
+      weight <- mx.symbol.Variable("weight")
+      grad <- mx.symbol.Variable("grad")
+      n <- mx.symbol.Variable("n")
+
+      sym <- mx.symbol.rmsprop_update(weight,
+                                      grad,
+                                      n,
+                                      lr = lr,
+                                      gamma1 = gamma1,
+                                      epsilon = epsilon,
+                                      wd = wd,
+                                      rescale_grad = rescale.grad,
+                                      clip_gradient = clip_gradient,
+                                      name = "w")
+    }
+
+    exec <- mx.simple.bind(symbol = sym, weight = weight_dim, ctx = ctx, 
grad.req = "null")
+    return(exec)
   }
 
-  update <- function(index, weight, grad, state) {
+  update <- function(index, exec_w, weight, grad) {
     if (!is.null(lr_scheduler)){
       lr_scheduler(rmsprop) ## changing lr
       lr <- rmsprop$lr
       ## update count
       indexKey <- paste0('ik', index)
-      if (!exists(envir = rmsprop, x = indexKey, inherits = FALSE)){
+      if (!exists(envir = rmsprop, x = indexKey, inherits = FALSE)) {
         rmsprop[[indexKey]] <- 0
       } else {
         indexValue <- rmsprop[[indexKey]]
@@ -118,27 +183,12 @@ mx.opt.rmsprop <- function(learning.rate=0.002,
         rmsprop$num_update <- max(rmsprop$num_update, rmsprop[[indexKey]])
       }
     }
-    grad <- grad * rescale.grad
-    if (!is.null(clip_gradient)){
-      if(clip_gradient >= 0){
-        grad <- mx.nd.clip(grad, -clip_gradient, clip_gradient)
-      } else {
-        stop("Error: clip_gradient should be positive number.")
-      }
-    }
 
-    n <- state$n
-    g <- state$g
-    delta <- state$delta
-    n <- gamma1 * n + (1 - gamma1) * (grad * grad)
-    g <- gamma1 * g + (1 - gamma1) * grad
-    delta <- gamma2 * delta - lr * (grad / mx.nd.sqrt(n - g*g + 1e-4) + wd * 
weight)
-    weight <- weight + delta
-    state <- list(n=n, g=g, delta=delta)
-
-    return(list(weight=weight, state=state))
+    mx.exec.update.arg.arrays(exec_w, arg.arrays = list(weight = weight,grad = 
grad), match.name = T)
+    mx.exec.forward(exec_w, is.train = F)
+    return(exec_w$ref.outputs$w_output)
   }
-  return(list(create.state=create.state, update=update))
+  return(list(create_exec = create_exec, update = update))
 }
 
 #' Create an Adam optimizer with respective parameters.
@@ -148,8 +198,8 @@ mx.opt.rmsprop <- function(learning.rate=0.002,
 #' Adam: A Method for Stochastic Optimization,
 #' http://arxiv.org/abs/1412.6980
 #'
-#' @param learning.rate float, default=0.001
-#'      Step size.
+#' @param learning.rate float, default=1e-3
+#'      The initial learning rate.
 #' @param beta1 float, default=0.9
 #'      Exponential decay rate for the first moment estimates.
 #' @param beta2 float, default=0.999
@@ -159,41 +209,60 @@ mx.opt.rmsprop <- function(learning.rate=0.002,
 #'      L2 regularization coefficient add to all the weights.
 #' @param rescale.grad float, default=1.0
 #'      rescaling factor of gradient.
-#' @param clip_gradient float, optional
+#' @param clip_gradient float, optional, default=-1 (no clipping if < 0)
 #'      clip gradient in range [-clip_gradient, clip_gradient].
 #' @param lr_scheduler function, optional
 #'      The learning rate scheduler.
 #'
-mx.opt.adam <- function(learning.rate=0.001,
-                        beta1=0.9,
-                        beta2=0.999,
-                        epsilon=1e-8,
-                        wd=0,
-                        rescale.grad=1,
-                        clip_gradient = NULL,
+mx.opt.adam <- function(learning.rate = 1e-3,
+                        beta1 = 0.9,
+                        beta2 = 0.999,
+                        epsilon = 1e-8,
+                        wd = 0,
+                        rescale.grad = 1,
+                        clip_gradient = -1,
                         lr_scheduler = NULL) {
-  # use lr as short for learing rate.
+
   lr <- learning.rate
-  count       <- 0
-  num_update  <- 0
+  count <- 0
+  num_update <- 0
 
   adam <- new.env()
   adam$lr <- lr
   adam$count <- 0
   adam$num_update <- 0
 
-  create.state <- function(index, weight) {
-      return (list(mean=mx.nd.zeros(dim(weight), ctx(weight)),
-                   variance=mx.nd.zeros(dim(weight), ctx(weight))))
+  create_exec <- function(index, weight_dim, ctx) {
+
+    weight <- mx.symbol.Variable("weight")
+    grad <- mx.symbol.Variable("grad")
+    mean <- mx.symbol.Variable("mean")
+    var <- mx.symbol.Variable("var")
+
+    sym <- mx.symbol.adam_update(weight,
+                                 grad,
+                                 mean,
+                                 var,
+                                 lr = lr,
+                                 beta1 = beta1,
+                                 beta2 = beta2,
+                                 epsilon = epsilon,
+                                 wd = wd,
+                                 rescale_grad = rescale.grad,
+                                 clip_gradient = clip_gradient,
+                                 name = "w")
+
+    exec <- mx.simple.bind(symbol = sym, weight = weight_dim, ctx = ctx, 
grad.req = "null")
+    return(exec)
   }
 
-  update <- function(index, weight, grad, state) {
+  update <- function(index, exec_w, weight, grad) {
     if (!is.null(lr_scheduler)){
       lr_scheduler(adam) ## changing lr
       lr <- adam$lr
       ## update count
       indexKey <- paste0('ik', index)
-      if (!exists(envir = adam, x = indexKey, inherits = FALSE)){
+      if (!exists(envir = adam, x = indexKey, inherits = FALSE)) {
         adam[[indexKey]] <- 0
       } else {
         indexValue <- adam[[indexKey]]
@@ -202,44 +271,15 @@ mx.opt.adam <- function(learning.rate=0.001,
       }
     }
 
-    # increment time
-    time.key <- paste0('t', index)
-    if (!exists(envir = adam, x = time.key, inherits = FALSE)){
-      adam[[time.key]] <- 0
-    }
-    t <- adam[[time.key]]
-    t <- t + 1
-    adam[[time.key]] <- t
-
-    mean <- state$mean
-    variance <- state$variance
-
-    grad <- grad * rescale.grad
-    if (!is.null(clip_gradient)){
-      if(clip_gradient >= 0){
-        grad <- mx.nd.clip(grad, -clip_gradient, clip_gradient)
-      } else {
-        stop("Error: clip_gradient should be positive number.")
-      }
-    }
-
-    mean <- beta1 * mean + (1 - beta1) * grad
-    variance <- beta2 * variance + (1 - beta2) * (grad * grad)
-
-    coef1 <- 1 - beta1^t
-    coef2 <- 1 - beta2^t
-    lr <- lr * sqrt(coef2)/coef1
-
-    weight <- weight - lr * mean / (mx.nd.sqrt(variance) + epsilon)
-    weight <- weight - lr * wd * weight
-
-    state <- list(mean=mean, variance=variance)
-
-    return(list(weight=weight, state=state))
+    mx.exec.update.arg.arrays(exec_w, arg.arrays = list(weight = weight,grad = 
grad), match.name = T)
+    mx.exec.forward(exec_w, is.train = F)
+    return(exec_w$ref.outputs$w_output)
   }
-  return(list(create.state=create.state, update=update))
+  return(list(create_exec = create_exec, update = update))
 }
 
+
+
 #' Create an AdaGrad optimizer with respective parameters.
 #' AdaGrad optimizer of Duchi et al., 2011,
 #'
@@ -254,38 +294,58 @@ mx.opt.adam <- function(learning.rate=0.001,
 #'      L2 regularization coefficient add to all the weights.
 #' @param rescale.grad float, default=1.0
 #'      rescaling factor of gradient.
-#' @param clip_gradient float, optional
+#' @param clip_gradient float, default=-1.0 (no clipping if < 0)
 #'      clip gradient in range [-clip_gradient, clip_gradient].
 #' @param lr_scheduler function, optional
 #'      The learning rate scheduler.
 #'
-mx.opt.adagrad <- function(learning.rate=0.05,
-                           epsilon=1e-8,
-                           wd=0,
-                           rescale.grad=1,
-                           clip_gradient = NULL,
+mx.opt.adagrad <- function(learning.rate = 0.05,
+                           epsilon = 1e-8,
+                           wd = 0,
+                           rescale.grad = 1,
+                           clip_gradient = -1,
                            lr_scheduler = NULL) {
   # use lr as short for learning rate.
   lr <- learning.rate
-  count       <- 0
-  num_update  <- 0
+  count <- 0
+  num_update <- 0
 
   adagrad <- new.env()
   adagrad$lr <- lr
   adagrad$count <- 0
   adagrad$num_update <- 0
 
-  create.state <- function(index, weight) {
-      return (mx.nd.zeros(dim(weight), ctx(weight))) #history
+  create_exec <- function(index, weight_dim, ctx) {
+
+    weight <- mx.symbol.Variable("weight")
+    grad <- mx.symbol.Variable("grad")
+    history <- mx.symbol.Variable("history")
+
+    grad <- grad * rescale.grad
+    if (!is.null(clip_gradient)) {
+      if (clip_gradient >= 0) {
+        grad <- mx.symbol.clip(data = grad, a.min = -clip_gradient, a.max = 
clip_gradient)
+      }
+    }
+
+    history <- history + (grad * grad)
+    weight <- weight - lr * (grad / mx.symbol.sqrt(history + epsilon) + wd * 
weight)
+
+    w <- mx.symbol.identity(weight, name = "w")
+    h <- mx.symbol.identity(history, name = "h")
+    sym <- mx.symbol.Group(c(w, h))
+
+    exec <- mx.simple.bind(symbol = sym, weight = weight_dim, ctx = ctx, 
grad.req = "null")
+    return(exec)
   }
 
-  update <- function(index, weight, grad, state) {
-    if (!is.null(lr_scheduler)){
+  update <- function(index, exec_w, weight, grad) {
+    if (!is.null(lr_scheduler)) {
       lr_scheduler(adagrad) ## changing lr
       lr <- adagrad$lr
       ## update count
       indexKey <- paste0('ik', index)
-      if (!exists(envir = adagrad, x = indexKey, inherits = FALSE)){
+      if (!exists(envir = adagrad, x = indexKey, inherits = FALSE)) {
         adagrad[[indexKey]] <- 0
       } else {
         indexValue <- adagrad[[indexKey]]
@@ -294,25 +354,18 @@ mx.opt.adagrad <- function(learning.rate=0.05,
       }
     }
 
-    grad <- grad * rescale.grad
-    if (!is.null(clip_gradient)){
-      if(clip_gradient >= 0){
-        grad <- mx.nd.clip(grad, -clip_gradient, clip_gradient)
-      } else {
-        stop("Error: clip_gradient should be positive number.")
-      }
-    }
+    mx.exec.update.arg.arrays(exec_w, arg.arrays = list(weight = weight,grad = 
grad), match.name = T)
+    mx.exec.forward(exec_w, is.train = F)
 
-    history <- state
-    history <- history + (grad * grad)
-    weight <- weight - lr * (grad / mx.nd.sqrt(history + epsilon) + wd * 
weight)
-    state <- history
+    # update state
+    mx.exec.update.arg.arrays(exec_w, arg.arrays = list(history = 
exec_w$ref.outputs$h_output), match.name = T)
 
-    return(list(weight=weight, state=state))
+    return(exec_w$ref.outputs$w_output)
   }
-  return(list(create.state=create.state, update=update))
+  return(list(create_exec = create_exec, update = update))
 }
 
+
 #' Create an AdaDelta optimizer with respective parameters.
 #'
 #' AdaDelta optimizer as described in Zeiler, M. D. (2012).
@@ -325,50 +378,64 @@ mx.opt.adagrad <- function(learning.rate=0.05,
 #'      The constant as described in the thesis.
 #' @param wd float, default=0.0
 #'      L2 regularization coefficient add to all the weights.
-#' @param rescale.grad float, default=1.0
+#' @param rescale.grad float, default=1
 #'      rescaling factor of gradient.
-#' @param clip_gradient float, optional
+#' @param clip_gradient float, default=-1 (no clipping if < 0)
 #'      clip gradient in range [-clip_gradient, clip_gradient].
 #'
-mx.opt.adadelta <- function(rho=0.90,
-                            epsilon=1e-5,
-                            wd=0,
-                            rescale.grad=1,
-                            clip_gradient = NULL) {
+mx.opt.adadelta <- function(rho = 0.90,
+                            epsilon = 1e-5,
+                            wd = 0,
+                            rescale.grad = 1,
+                            clip_gradient = -1) {
   adadelta <- new.env()
 
-  create.state <- function(index, weight) {
-    return (list(acc.g=mx.nd.zeros(dim(weight), ctx(weight)),       # 
accumulated g
-                 acc.delta=mx.nd.zeros(dim(weight), ctx(weight))))  # 
accumulated delta
-  }
+  create_exec <- function(index, weight_dim, ctx) {
+    weight <- mx.symbol.Variable("weight")
+    grad <- mx.symbol.Variable("grad")
+    acc.g <- mx.symbol.Variable("acc.g")
+    acc.delta <- mx.symbol.Variable("acc.delta")
 
-  update <- function(index, weight, grad, state) {
-    # preprocess grad
     grad <- grad * rescale.grad
-    if (!is.null(clip_gradient)){
-      if(clip_gradient >= 0){
-        grad <- mx.nd.clip(grad, -clip_gradient, clip_gradient)
-      } else {
-        stop("Error: clip_gradient should be positive number.")
+    if (!is.null(clip_gradient)) {
+      if (clip_gradient >= 0) {
+        grad <- mx.symbol.clip(data = grad, a.min = -clip_gradient, a.max = 
clip_gradient)
       }
     }
 
-    # accumulated g and delta initlization
-    acc.g <- state$acc.g
-    acc.delta <- state$acc.delta
-
-    # update g, delta
+    # update state (acc.g, acc.delta)
     acc.g <- rho * acc.g + (1 - rho) * (grad * grad)
-    current.delta <- mx.nd.sqrt(acc.delta + epsilon) / mx.nd.sqrt(acc.g + 
epsilon) * grad
+    current.delta <- mx.symbol.sqrt(acc.delta + epsilon) / 
mx.symbol.sqrt(acc.g + epsilon) * grad
     acc.delta <- rho * acc.delta + (1 - rho) * (current.delta * current.delta)
     weight <- weight - current.delta - wd * weight
-    state <- list(acc.g=acc.g, acc.delta=acc.delta)
 
-    return(list(weight=weight, state=state))
+    w <- mx.symbol.identity(weight, name = "w")
+    g <- mx.symbol.identity(acc.g, name = "g")
+    delta <- mx.symbol.identity(acc.delta, name = "delta")
+    sym <- mx.symbol.Group(c(w, g, delta))
+
+    exec <- mx.simple.bind(symbol = sym, weight = weight_dim, ctx = ctx, 
grad.req = "null")
+    return(exec)
   }
-  return(list(create.state=create.state, update=update))
+
+  update <- function(index, exec_w, weight, grad) {
+
+    mx.exec.update.arg.arrays(exec_w, arg.arrays = list(weight = weight,grad = 
grad), match.name = T)
+    mx.exec.forward(exec_w, is.train = F)
+
+    # update state
+    mx.exec.update.arg.arrays(exec_w,
+                              arg.arrays = list(
+                                acc.g = exec_w$ref.outputs$g_output,
+                                acc.delta = exec_w$ref.outputs$delta_output),
+                              match.name = T)
+
+    return(exec_w$ref.outputs$w_output)
+  }
+  return(list(create_exec = create_exec, update = update))
 }
 
+
 #' Create an optimizer by name and parameters
 #'
 #' @param name The name of the optimizer
@@ -392,31 +459,28 @@ mx.opt.create <- function(name, ...) {
 #' @param weights The weights to be optimized
 #'
 #' @export
-mx.opt.get.updater <- function(optimizer, weights) {
-  # This is the list to keep track of internal states of optimzer
-  state.list <- lapply(seq_along(weights), function(i) {
-    if (is.null(weights[[i]])) return(NULL)
-    optimizer$create.state(i, weights[[i]])
+mx.opt.get.updater <- function(optimizer, weights, ctx) {
+
+  exec_list <- lapply(seq_along(weights), function(i) {
+    if (is.null(weights[[i]])) {
+      return(NULL)
+    } else {
+      optimizer$create_exec(index = i, weight_dim = dim(weights[[i]]), ctx = 
ctx)
+    }
   })
+
   update <- optimizer$update
 
   update.closure <- function(weight, grad) {
-    ulist <- lapply(seq_along(weight), function(i) {
+
+    weight_list <- lapply(seq_along(weight), function(i) {
       if (!is.null(grad[[i]])) {
-        update(i, weight[[i]], grad[[i]], state.list[[i]])
+        return(update(i, exec_list[[i]], weight[[i]], grad[[i]]))
       } else {
         return(NULL)
       }
     })
-    # update state list, use mutate assignment
-    state.list <<- lapply(ulist, function(x) {
-      x$state
-    })
-    # return updated weight list
-    weight.list <- lapply(ulist, function(x) {
-      x$weight
-    })
-    return(weight.list)
+    return(weight_list)
   }
   return(update.closure)
 }
diff --git a/R-package/tests/testthat/test_optimizer.R 
b/R-package/tests/testthat/test_optimizer.R
new file mode 100644
index 00000000000..c6dacaa728b
--- /dev/null
+++ b/R-package/tests/testthat/test_optimizer.R
@@ -0,0 +1,204 @@
+context("optimizer")
+
+test_that("sgd", {
+
+  data = mx.symbol.Variable('data')
+  label = mx.symbol.Variable('label')
+  fc_weight = mx.symbol.Variable('fc_weight')
+  fc = mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T, 
name = 'fc1', num_hidden = 1)
+  loss = mx.symbol.LinearRegressionOutput(data = fc, label = label, name = 
'loss')
+
+  x <- mx.nd.array(array(1:6, dim=2:3))
+  y <- mx.nd.array(c(5, 11, 16))
+  w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2,1)))
+
+  exec <- mxnet:::mx.symbol.bind(symbol = loss,
+                                 ctx = mx.cpu(),
+                                 arg.arrays = list(data = x,
+                                                   fc1_weight = w1,
+                                                   label = y),
+                                 aux.arrays = NULL,
+                                 grad.reqs = c("null", "write", "null"))
+
+  optimizer <- mx.opt.create("sgd",
+                             learning.rate = 1,
+                             momentum = 0,
+                             wd = 0,
+                             rescale.grad = 1,
+                             clip_gradient = -1)
+
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = 
mx.cpu())
+
+  mx.exec.forward(exec, is.train = T)
+  mx.exec.backward(exec)
+
+  arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
+  mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
+
+  expect_equal(as.array(arg.blocks[[2]]), array(c(1.4, 2.6), dim = c(2,1)), 
tolerance = 1e-1)
+
+})
+
+
+test_that("rmsprop", {
+
+  data = mx.symbol.Variable('data')
+  label = mx.symbol.Variable('label')
+  fc_weight = mx.symbol.Variable('fc_weight')
+  fc = mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T, 
name = 'fc1', num_hidden = 1)
+  loss = mx.symbol.LinearRegressionOutput(data = fc, label = label, name = 
'loss')
+
+  x <- mx.nd.array(array(1:6, dim=2:3))
+  y <- mx.nd.array(c(5, 11, 16))
+  w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2,1)))
+
+  exec <- mxnet:::mx.symbol.bind(symbol = loss,
+                                 ctx = mx.cpu(),
+                                 arg.arrays = list(data = x,
+                                                   fc1_weight = w1,
+                                                   label = y),
+                                 aux.arrays = NULL,
+                                 grad.reqs = c("null", "write", "null"))
+
+  optimizer <- mx.opt.create("rmsprop", learning.rate = 1,
+                             centered = TRUE,
+                             gamma1 = 0.95,
+                             gamma2 = 0.9,
+                             epsilon = 1e-4,
+                             wd = 0,
+                             rescale.grad = 1,
+                             clip_gradient = -1)
+
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = 
mx.cpu())
+
+  mx.exec.forward(exec, is.train = T)
+  mx.exec.backward(exec)
+
+  arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
+  mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
+
+  expect_equal(as.array(arg.blocks[[2]]), array(c(5.64, 6.38), dim = c(2,1)), 
tolerance = 1e-1)
+
+})
+
+
+test_that("adam", {
+
+  data = mx.symbol.Variable('data')
+  label = mx.symbol.Variable('label')
+  fc_weight = mx.symbol.Variable('fc_weight')
+  fc = mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T, 
name = 'fc1', num_hidden = 1)
+  loss = mx.symbol.LinearRegressionOutput(data = fc, label = label, name = 
'loss')
+
+  x <- mx.nd.array(array(1:6, dim=2:3))
+  y <- mx.nd.array(c(5, 11, 16))
+  w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2,1)))
+
+  exec <- mxnet:::mx.symbol.bind(symbol = loss,
+                                 ctx = mx.cpu(),
+                                 arg.arrays = list(data = x,
+                                                   fc1_weight = w1,
+                                                   label = y),
+                                 aux.arrays = NULL,
+                                 grad.reqs = c("null", "write", "null"))
+
+  optimizer <- mx.opt.create("adam",
+                             learning.rate = 1,
+                             beta1 = 0.9,
+                             beta2 = 0.999,
+                             epsilon = 1e-8,
+                             wd = 0,
+                             rescale.grad = 1,
+                             clip_gradient = -1)
+
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = 
mx.cpu())
+
+  mx.exec.forward(exec, is.train = T)
+  mx.exec.backward(exec)
+
+  arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
+  mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
+
+  expect_equal(as.array(arg.blocks[[2]]), array(c(4.26, 4.96), dim = c(2,1)), 
tolerance = 1e-1)
+
+})
+
+
+test_that("adagrad", {
+
+  data = mx.symbol.Variable('data')
+  label = mx.symbol.Variable('label')
+  fc_weight = mx.symbol.Variable('fc_weight')
+  fc = mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T, 
name = 'fc1', num_hidden = 1)
+  loss = mx.symbol.LinearRegressionOutput(data = fc, label = label, name = 
'loss')
+
+  x <- mx.nd.array(array(1:6, dim=2:3))
+  y <- mx.nd.array(c(5, 11, 16))
+  w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2,1)))
+
+  exec <- mxnet:::mx.symbol.bind(symbol = loss,
+                                 ctx = mx.cpu(),
+                                 arg.arrays = list(data = x,
+                                                   fc1_weight = w1,
+                                                   label = y),
+                                 aux.arrays = NULL,
+                                 grad.reqs = c("null", "write", "null"))
+
+  optimizer <- mx.opt.create("adagrad",
+                             learning.rate = 1,
+                             epsilon = 1e-8,
+                             wd = 0,
+                             rescale.grad = 1,
+                             clip_gradient = -1)
+
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = 
mx.cpu())
+
+  mx.exec.forward(exec, is.train = T)
+  mx.exec.backward(exec)
+
+  arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
+  mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
+
+  expect_equal(as.array(arg.blocks[[2]]), array(c(2.1, 2.8), dim = c(2,1)), 
tolerance = 1e-1)
+
+})
+
+
+test_that("adadelta", {
+
+  data = mx.symbol.Variable('data')
+  label = mx.symbol.Variable('label')
+  fc_weight = mx.symbol.Variable('fc_weight')
+  fc = mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T, 
name = 'fc1', num_hidden = 1)
+  loss = mx.symbol.LinearRegressionOutput(data = fc, label = label, name = 
'loss')
+
+  x <- mx.nd.array(array(1:6, dim=2:3))
+  y <- mx.nd.array(c(5, 11, 16))
+  w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2,1)))
+
+  exec <- mxnet:::mx.symbol.bind(symbol = loss,
+                                 ctx = mx.cpu(),
+                                 arg.arrays = list(data = x,
+                                                   fc1_weight = w1,
+                                                   label = y),
+                                 aux.arrays = NULL,
+                                 grad.reqs = c("null", "write", "null"))
+
+  optimizer <- mx.opt.create("adadelta",
+                             rho = 0.90,
+                             epsilon = 1e-5,
+                             wd = 0,
+                             rescale.grad = 1,
+                             clip_gradient = -1)
+
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = 
mx.cpu())
+
+  mx.exec.forward(exec, is.train = T)
+  mx.exec.backward(exec)
+
+  arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
+  mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
+
+  expect_equal(as.array(arg.blocks[[2]]), array(c(1.11, 1.81), dim = c(2,1)), 
tolerance = 1e-1)
+
+})


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to