lichen11 commented on issue #7968: [R] Transfer Learning using VGG-16
URL: 
https://github.com/apache/incubator-mxnet/issues/7968#issuecomment-332622479
 
 
   Hi, I believe I have set up the transfer learning correctly, retraining the
last fully connected layer. I also made sure the layer names match the
pretrained VGG model. However, R always crashes. When I use the same approach
with Inception-BN or Inception-v3, it works fine.
   
   Is there another source for downloading the MXNet VGG weights?
   
   Below is my code for VGG transfer learning. The data are the same as in the
cats/dogs classification problem at
https://statist-bhfz.github.io/cats_dogs_finetune.
   
       # Transfer learning with a pretrained VGG-19: cut the network after the
       # "drop7" layer, attach a new 2-class "fc8" + softmax head, and retrain
       # only that head.
       # NOTE(review): assumes `train` and `val` are data iterators created
       # earlier (e.g. as in the cats/dogs fine-tuning example) -- not shown.
       vgg <- mx.model.load("vgg19", iteration = 0)
       symbol <- vgg$symbol
       internals <- symbol$get.internals()
       outputs <- internals$outputs

       # Fail loudly if the cut point is missing instead of handing an empty
       # index to get.output().
       drop7_idx <- which(outputs == "drop7_output")
       if (length(drop7_idx) != 1L) {
         stop("expected exactly one 'drop7_output' in the VGG symbol, found ",
              length(drop7_idx), call. = FALSE)
       }
       drop7 <- internals$get.output(drop7_idx)

       # New classification head. Reusing the name "fc8" means its parameters
       # ("fc8_weight", "fc8_bias") replace the pretrained fc8 ones below.
       fc_final <- mx.symbol.FullyConnected(data = drop7, num.hidden = 2,
                                            name = "fc8")
       new_soft <- mx.symbol.SoftmaxOutput(data = fc_final, name = "prob")

       # Only fc8's freshly initialised parameters are kept from this call, so
       # initialise on the CPU: doing it on the GPU allocates a throwaway copy
       # of every VGG-19 parameter in GPU memory.
       # output.shape is presumably the label shape for the batch of 8.
       arg_params_new <- mxnet:::mx.model.init.params(
         symbol = new_soft,
         input.shape = c(224, 224, 3, 8),
         output.shape = 8,
         initializer = mx.init.uniform(0.1),
         ctx = mx.cpu()
       )$arg.params

       fc8_weights_new <- arg_params_new[["fc8_weight"]]
       fc8_bias_new <- arg_params_new[["fc8_bias"]]

       # Start from the pretrained parameters and swap in the new head.
       arg_params_new <- vgg$arg.params
       arg_params_new[["fc8_weight"]] <- fc8_weights_new
       arg_params_new[["fc8_bias"]] <- fc8_bias_new

       # Freeze every pretrained parameter so that only the new fc8 layer is
       # retrained; without this, SGD updates (and keeps gradient/momentum
       # state for) every VGG layer.
       fixed_params <- setdiff(names(arg_params_new),
                               c("fc8_weight", "fc8_bias"))

       model <- mx.model.FeedForward.create(
         symbol             = new_soft,
         X                  = train,
         eval.data          = val,
         ctx                = mx.gpu(0),
         eval.metric        = mx.metric.accuracy,
         num.round          = 1,
         learning.rate      = 0.05,
         momentum           = 0.9,
         wd                 = 0.00001,
         kvstore            = "local",
         # NOTE(review): per the mxnet docs, array.batch.size only applies to
         # R array input; if `train` is an iterator, its own batch size is
         # used. VGG-19 at 224x224 is memory-hungry, so a smaller batch may
         # help if R still crashes.
         array.batch.size   = 128,
         epoch.end.callback = mx.callback.save.checkpoint("vgg"),
         batch.end.callback = mx.callback.log.train.metric(150),
         initializer        = mx.init.Xavier(factor_type = "in",
                                             magnitude = 2.34),
         optimizer          = "sgd",
         arg.params         = arg_params_new,
         aux.params         = vgg$aux.params,
         fixed.param        = fixed_params
       )
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to