kmarlen1988 commented on issue #1546: predict returning the same value for all 
observations
URL: 
https://github.com/apache/incubator-mxnet/issues/1546#issuecomment-356256041
 
 
   I'm using googlenet in mxnetR but the predicted probabilities are the same 
for all images. I tried to change the activation function, the batch size and i 
normalized the input to [-1,1]
   The code is the following:
   num_classes = 2
          
        ConvFactory <- function(data, num_filter, kernel, stride = c(1, 1), pad 
= c(0, 0),
                           name = '', suffix = '') {
                conv <- mx.symbol.Convolution(data = data, num_filter = 
num_filter, kernel = kernel, stride = stride,
                                                                          pad = 
pad, name = paste('conv_', name, suffix, sep = ""))
                act <- mx.symbol.Activation(data = conv, act_type = 'relu', 
name = paste('relu_', name, suffix, sep = ''))
                return(act)
        }
   
        InceptionFactory <- function(data, num_1x1, num_3x3red, num_3x3,
                                num_d5x5red, num_d5x5, pool, proj, name) {
                # 1x1
                c1x1 <- ConvFactory(data = data, num_filter = num_1x1, kernel = 
c(1, 1),
                                                        name = paste(name, 
'_1x1', sep = ''))
                # 3x3 reduce + 3x3
                c3x3r = ConvFactory(data = data, num_filter = num_3x3red, 
kernel = c(1, 1),
                                                        name = paste(name, 
'_3x3', sep = ''), suffix = '_reduce')
                c3x3 = ConvFactory(data = c3x3r, num_filter = num_3x3, kernel = 
c(3, 3),
                                                   pad = c(1, 1), name = 
paste(name, '_3x3', sep = ''))
                # double 3x3 reduce + double 3x3
                cd5x5r = ConvFactory(data = data, num_filter = num_d5x5red, 
kernel = c(1, 1),
                                                         name = paste(name, 
'_5x5', sep = ''), suffix = '_reduce')
                cd5x5 = ConvFactory(data = cd5x5r, num_filter = num_d5x5, 
kernel = c(5, 5), pad = c(2, 2),
                                                        name = paste(name, 
'_5x5', sep = ''))
                # pool + proj
                pooling = mx.symbol.Pooling(data = data, kernel = c(3, 3), 
stride = c(1, 1), 
                                                                        pad = 
c(1, 1), pool_type = pool,
                                                                        name = 
paste(pool, '_pool_', name, '_pool', sep = ''))
   
                cproj = ConvFactory(data = pooling, num_filter = proj, kernel = 
c(1, 1), 
                                                        name = paste(name, 
'_proj', sep = ''))
                # concat
                concat_lst <- list()
                concat_lst <- c(c1x1, c3x3, cd5x5, cproj)
                concat_lst$num.args = 4
                concat_lst$name = paste('ch_concat_', name, '_chconcat', sep = 
'')
                concat = mxnet:::mx.varg.symbol.Concat(concat_lst)
                return(concat)
        }
   
        
        data <- mx.symbol.Variable("data")
        conv1 <- ConvFactory(data, 64, kernel = c(7, 7), stride = c(2, 2), pad 
= c(3, 3), name = "conv1")
        pool1 <- mx.symbol.Pooling(conv1, kernel = c(3, 3), stride = c(2, 2), 
pool_type = "max", pooling_convention='full')
        conv2 <- ConvFactory(pool1, 64, kernel = c(1, 1), stride = c(1, 1), 
name = "conv2")
        conv3 <- ConvFactory(conv2, 192, kernel = c(3, 3), stride = c(1, 1), 
pad = c(1, 1), name = "conv3")
        pool3 <- mx.symbol.Pooling(conv3, kernel = c(3, 3), stride = c(2, 2), 
pool_type = "max", pooling_convention='full')
          
        in3a <- InceptionFactory(pool3, 64, 96, 128, 16, 32, "max", 32, name = 
"in3a")
        in3b <- InceptionFactory(in3a, 128, 128, 192, 32, 96, "max", 64, name = 
"in3b")
        pool4 <- mx.symbol.Pooling(in3b, kernel = c(3, 3), stride = c(2, 2), 
pool_type = "max", pooling_convention='full')
        in4a <- InceptionFactory(pool4, 192, 96, 208, 16, 48, "max", 64, name = 
"in4a")
        in4b <- InceptionFactory(in4a, 160, 112, 224, 24, 64, "max", 64, name = 
"in4b")
        in4c <- InceptionFactory(in4b, 128, 128, 256, 24, 64, "max", 64, name = 
"in4c")
        in4d <- InceptionFactory(in4c, 112, 144, 288, 32, 64, "max", 64, name = 
"in4d")
        in4e <- InceptionFactory(in4d, 256, 160, 320, 32, 128, "max", 128, name 
= "in4e")
        pool5 <- mx.symbol.Pooling(in4e, kernel = c(3, 3), stride = c(2, 2), 
pool_type = "max", pooling_convention='full')
        in5a <- InceptionFactory(pool5, 256, 160, 320, 32, 128, "max", 128, 
name = "in5a")
        in5b <- InceptionFactory(in5a, 384, 192, 384, 48, 128, "max", 128, name 
= "in5b")
        pool6 <- mx.symbol.Pooling(in5b, kernel = c(7, 7), stride = c(1, 1), 
pool_type = "avg" )
        flatten <- mx.symbol.Flatten(data = pool6)
        fc1 <- mx.symbol.FullyConnected(data = flatten, num_hidden = 
num_classes)
        googlenet <- mx.symbol.SoftmaxOutput(data = fc1, name = 'softmax')
   
   model <- mx.model.FeedForward.create(googlenet, X = train_array, y = train_y,
                                        ctx = device.cpu,
                                        num.round = 20,
                                        array.batch.size = 32,
                                        learning.rate = 0.05,
                                        momentum = 0.9,
                                        wd = 0.00001,
                                        eval.metric = mx.metric.accuracy, 
initializer=mx.init.uniform(0.07),
                                        epoch.end.callback = 
mx.callback.save.checkpoint("checkpoint"))
   
        predict_probs <- predict(model, test_array)

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to