diyang commented on issue #10805: SKIP RNN is incorrect in LSTnet
URL: 
https://github.com/apache/incubator-mxnet/issues/10805#issuecomment-386563096
 
 
   I have implemented the skip RNN with the MXNet R package.
   You can find it in the following function:
   
https://github.com/diyang/deeplearning.mxnet/blob/master/LSTnet/src/lstnet_model.R
   
   I use a queue to hold the hidden states of the last 24 hours: at each step I 
pop the state at the head of the queue, use it as the previous state, and then 
push the newly computed hidden state of the current hour onto the queue tail. 
   
   ```R
   # Unroll a stacked "skip" (seasonal) recurrent network over seq.len steps.
   #
   # For the first `seasonal.period` steps every layer recurs on the state it
   # produced at the immediately preceding step (or on its initial-state
   # variable at step 1). From step `seasonal.period + 1` onwards every layer
   # recurs on the state it produced `seasonal.period` steps earlier; those
   # states are kept in a FIFO queue.
   #
   # Arguments:
   #   data            symbol that is split into `seq.len` slices along axis 2
   #                   (assumes the time dimension is on that axis -- TODO
   #                   confirm against the caller).
   #   num.rnn.layer   number of stacked recurrent layers (>= 1).
   #   seq.len         number of time steps to unroll (>= 1).
   #   num.hidden      passed through to gru.cell / lstm.cell.
   #   seasonal.period length of the skip, in time steps (>= 1).
   #   dropout         dropout value handed to every layer except the first;
   #                   the first layer always receives 0.
   #   config          "gru" selects gru.cell; any other value selects
   #                   lstm.cell.
   #
   # Returns a list with:
   #   outputs      list of length seq.len holding the top layer's `h` at
   #                each step.
   #   last.states  list of length num.rnn.layer holding each layer's state
   #                after the final step.
   #
   # gru.cell and lstm.cell are defined elsewhere and must be in scope.
   rnn.skip.unroll <- function(data,
                               num.rnn.layer = 1,
                               seq.len,
                               num.hidden,
                               seasonal.period,
                               dropout = 0,
                               config = "gru")
   {
     # Fail early with a clear message instead of an obscure subscript error
     # further down (e.g. popping from an empty queue when the period is 0).
     stopifnot(num.rnn.layer >= 1, seq.len >= 1, seasonal.period >= 1)

     # Everything that depends on the cell type is decided once, up front.
     if (config == "gru") {
       cell.fn <- gru.cell
       param.names <- c("gates.i2h.weight", "gates.i2h.bias",
                        "gates.h2h.weight", "gates.h2h.bias",
                        "trans.i2h.weight", "trans.i2h.bias",
                        "trans.h2h.weight", "trans.h2h.bias")
       state.suffixes <- c(h = ".gru.init.h")
     } else {
       cell.fn <- lstm.cell
       param.names <- c("i2h.weight", "i2h.bias", "h2h.weight", "h2h.bias")
       state.suffixes <- c(c = ".lstm.init.c", h = ".lstm.init.h")
     }
     # Named vector so that lapply() below yields a list keyed by param name.
     param.suffixes <- paste0(".", param.names)
     names(param.suffixes) <- param.names

     # Variable named "l<layer><suffix>", e.g. "l1.gates.i2h.weight".
     layer.var <- function(layer, suffix) {
       mx.symbol.Variable(paste0("l", layer, suffix))
     }

     param.cells <- vector("list", num.rnn.layer)
     last.states <- vector("list", num.rnn.layer)
     for (i in seq_len(num.rnn.layer)) {
       param.cells[[i]] <- lapply(param.suffixes,
                                  function(suffix) layer.var(i, suffix))
       last.states[[i]] <- lapply(state.suffixes,
                                  function(suffix) layer.var(i, suffix))
     }

     data_seq_slice <- mx.symbol.SliceChannel(data = data,
                                              num_outputs = seq.len,
                                              axis = 2,
                                              squeeze_axis = 1)

     outputs <- vector("list", seq.len)
     # FIFO queue of per-layer states. Each step appends one entry per layer
     # (layer 1 first), so popping once per layer in the same order keeps the
     # layers aligned.
     seasonal.states <- list()

     for (seqidx in seq_len(seq.len)) {
       hidden <- data_seq_slice[[seqidx]]
       warming.up <- seqidx <= seasonal.period

       for (i in seq_len(num.rnn.layer)) {
         # Use a per-layer local: overwriting `dropout` itself would zero it
         # for every later layer and step after the first layer is visited.
         layer.dropout <- if (i == 1) 0 else dropout

         if (warming.up) {
           prev.state <- last.states[[i]]
         } else {
           # Head of the queue is this layer's state from
           # `seasonal.period` steps ago.
           prev.state <- seasonal.states[[1]]
           seasonal.states <- seasonal.states[-1]
         }

         next.state <- cell.fn(num.hidden,
                               indata = hidden,
                               prev.state = prev.state,
                               param = param.cells[[i]],
                               seqidx = seqidx,
                               layeridx = i,
                               dropout = layer.dropout)
         hidden <- next.state$h
         last.states[[i]] <- next.state
       }

       # Enqueue this step's states, one entry per layer.
       seasonal.states <- c(seasonal.states, last.states)

       # Aggregate the top-layer output of each time step.
       outputs[[seqidx]] <- hidden
     }

     list(outputs = outputs, last.states = last.states)
   }
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to