diyang commented on issue #10805: SKIP RNN is incorrect in LSTnet
URL: 
https://github.com/apache/incubator-mxnet/issues/10805#issuecomment-386563096
 
 
   I have used MXNet R to implement a skip RNN.
   You can find it in this function:
   
https://github.com/diyang/deeplearning.mxnet/blob/master/LSTnet/src/lstnet_model.R
   
   ```R
   #' Unroll a stacked skip RNN (GRU or LSTM cells) over a sequence.
   #'
   #' The input is split into `seq.len` steps with
   #' `mx.symbol.SliceChannel(axis = 2, squeeze_axis = 1)`.
   #' For steps 1..seasonal.period, layer i receives its own state from
   #' the previous step (t - 1). For later steps, layer i receives the
   #' state it produced `seasonal.period` steps earlier
   #' (t - seasonal.period); this is the periodic "skip" connection.
   #'
   #' NOTE(review): in the LSTNet formulation each of the p interleaved
   #' sub-sequences usually starts from the initial state. Here steps
   #' 1..p instead chain through t - 1. Confirm this is intended.
   #'
   #' @param data Input symbol to unroll.
   #' @param num.rnn.layer Number of stacked recurrent layers.
   #' @param seq.len Number of time steps to unroll.
   #' @param num.hidden Hidden size forwarded to each cell.
   #' @param seasonal.period Skip distance p; must be a single value >= 1.
   #' @param dropout Dropout ratio for layers 2..num.rnn.layer. Layer 1,
   #'   which reads the raw input slice, always uses 0.
   #' @param config "gru" selects `gru.cell`; any other value selects
   #'   `lstm.cell` (both are assumed to be defined elsewhere).
   #' @return list(outputs = top-layer hidden symbol for each step,
   #'   last.states = per-layer state after the final step).
   rnn.skip.unroll <- function(data,
                               num.rnn.layer = 1,
                               seq.len,
                               num.hidden,
                               seasonal.period,
                               dropout = 0,
                               config = "gru") {
     stopifnot(length(seasonal.period) == 1, seasonal.period >= 1)
     use.gru <- config == "gru"

     # Per-layer weight/bias names; each becomes
     # mx.symbol.Variable("l<i>.<name>").
     param.names <- if (use.gru) {
       c("gates.i2h.weight", "gates.i2h.bias",
         "gates.h2h.weight", "gates.h2h.bias",
         "trans.i2h.weight", "trans.i2h.bias",
         "trans.h2h.weight", "trans.h2h.bias")
     } else {
       c("i2h.weight", "i2h.bias", "h2h.weight", "h2h.bias")
     }
     state.names <- if (use.gru) "h" else c("c", "h")
     state.prefix <- if (use.gru) ".gru.init." else ".lstm.init."

     # Build a named list of variables "l<i><prefix><suffix>" for layer i.
     layer.vars <- function(i, suffixes, prefix) {
       vars <- lapply(suffixes, function(sfx) {
         mx.symbol.Variable(paste0("l", i, prefix, sfx))
       })
       names(vars) <- suffixes
       vars
     }

     layers <- seq_len(num.rnn.layer)
     param.cells <- lapply(layers, layer.vars,
                           suffixes = param.names, prefix = ".")
     last.states <- lapply(layers, layer.vars,
                           suffixes = state.names, prefix = state.prefix)

     # Apply the configured cell for layer i. The dropout ratio lives in
     # its own variable so the caller's `dropout` argument is never
     # overwritten (previously it was clobbered to 0 by layer 1).
     run.cell <- function(i, indata, prev.state, seqidx) {
       layer.dropout <- if (i == 1) 0 else dropout
       cell <- if (use.gru) gru.cell else lstm.cell
       cell(num.hidden,
            indata = indata,
            prev.state = prev.state,
            param = param.cells[[i]],
            seqidx = seqidx,
            layeridx = i,
            dropout = layer.dropout)
     }

     data_seq_slice <- mx.symbol.SliceChannel(data = data,
                                              num_outputs = seq.len,
                                              axis = 2,
                                              squeeze_axis = 1)

     last.hidden <- vector("list", seq.len)
     # history[[t]] snapshots every layer's state after step t, so the
     # skip connection can read the state `seasonal.period` steps back.
     history <- vector("list", seq.len)
     for (seqidx in seq_len(seq.len)) {
       hidden <- data_seq_slice[[seqidx]]
       for (i in layers) {
         prev.state <- if (seqidx <= seasonal.period) {
           last.states[[i]]
         } else {
           history[[seqidx - seasonal.period]][[i]]
         }
         next.state <- run.cell(i, hidden, prev.state, seqidx)
         hidden <- next.state$h
         last.states[[i]] <- next.state
       }
       history[[seqidx]] <- last.states
       # Aggregate the top layer's output from each time step.
       last.hidden[[seqidx]] <- hidden
     }

     list(outputs = last.hidden, last.states = last.states)
   }
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to