Repository: systemml
Updated Branches:
  refs/heads/master 1e5984cca -> e4220e3bc


[SYSTEMML-445] Added a builtin function for efficient computation of lstm_backward

- The current implementation treats lstm and lstm_backward as stateless
  functions for simplicity. We can revisit this after performance testing.
- Removed the reserveSpace output from the lstm builtin function.
- Updated the language reference and lstm_staging.dml file.
- Added the necessary kernels for transforming inputs into the format
  required by the lstm_backward function.
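
For reference, the builtin signatures as invoked from lstm_staging.dml in
this commit (a usage sketch; all matrix names are the layer's own
variables):

  # forward: single-layer unidirectional LSTM (outputs: out, carryOut)
  [out, c] = lstm(X, W, b, out0, c0, return_sequences)
  # backward: gradients wrt input, weights, biases, and initial states
  [dX, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0, c0,
                                           given_sequences, dout, dc)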


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e4220e3b
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e4220e3b
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e4220e3b

Branch: refs/heads/master
Commit: e4220e3bc2559056f94ef08cd1c9006315cf229a
Parents: 1e5984c
Author: Niketan Pansare <[email protected]>
Authored: Fri Jun 15 13:07:59 2018 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Fri Jun 15 13:07:59 2018 -0700

----------------------------------------------------------------------
 docs/dml-language-reference.md                  |   2 +-
 scripts/nn/layers/lstm_staging.dml              |  85 +--
 src/main/cpp/kernels/SystemML.cu                | 121 ++-
 src/main/cpp/kernels/SystemML.ptx               | 760 ++++++++++++++-----
 .../java/org/apache/sysml/hops/FunctionOp.java  |   9 +-
 .../sysml/parser/BuiltinFunctionExpression.java |  60 +-
 .../org/apache/sysml/parser/DMLTranslator.java  |   1 +
 .../org/apache/sysml/parser/Expression.java     |   2 +-
 .../instructions/GPUInstructionParser.java      |   1 +
 .../gpu/ConvolutionGPUInstruction.java          |  79 +-
 .../instructions/gpu/context/GPUObject.java     |   4 +-
 .../runtime/matrix/data/LibMatrixCuDNN.java     |  98 ++-
 .../matrix/data/LibMatrixCuDNNRnnAlgorithm.java |  48 +-
 13 files changed, 962 insertions(+), 308 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/docs/dml-language-reference.md
----------------------------------------------------------------------
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index 3212806..5bf9099 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -1520,7 +1520,7 @@ Hence, the images are internally represented as a matrix with dimension (N, C *
 | max_pool_backward, avg_pool_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_out* width_out] | [batch_size X num_channels* height_image* width_image] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool] | Computes the gradients wrt input of 2D max pooling, average pooling |
 | bias_add | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Adds the bias (row vector of size num_channels) to input with the given num_channels |
 | bias_multiply | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Multiplies the bias (row vector of size num_channels) to input with the given num_channels |
-| lstm | X, W, bias, out0, c0 | [batch_size X seq_length*num_features] | [num_features+hidden_size X 4*hidden_size] | [batch_size X seq_length*hidden_size] if return_sequences else [batch_size X hidden_size] | return_sequences | Perform computation for single-layer unidirectional LSTM (outputs: out, carryOut, reserveSpace) |
+| lstm | X, W, bias, out0, c0 | [batch_size X seq_length*num_features] | [num_features+hidden_size X 4*hidden_size] | [batch_size X seq_length*hidden_size] if return_sequences else [batch_size X hidden_size] | return_sequences | Perform computation for single-layer unidirectional LSTM (outputs: out, carryOut) |
 | batch_norm2d | input | [batch_size X num_channels* height_image* width_image] | | [batch_size X num_channels* height_image* width_image] | scale, shift, exponentialMovingAverage_Mean, exponentialMovingAverage_Variance, mode, epsilon, momentum | Performs batch normalization operation (outputs: updated exponential moving average mean and variance, cache of the batch mean and variance) |
 | batch_norm2d_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | scale, epsilon, cache_mean (from forward), cache_inv_var (from forward) | Computes backpropagation error for batch normalization operation |
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/scripts/nn/layers/lstm_staging.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/lstm_staging.dml b/scripts/nn/layers/lstm_staging.dml
index d0949d9..886b88c 100644
--- a/scripts/nn/layers/lstm_staging.dml
+++ b/scripts/nn/layers/lstm_staging.dml
@@ -57,17 +57,15 @@ forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
    *  - out: If `return_sequences` is True, outputs for all timesteps,
    *      of shape (N, T*M).  Else, outputs for the final timestep, of
    *      shape (N, M).
-   *  - c: Cell state for final timestep, of shape (N, M).
-   *  - reserveSpace: reserveSpace to be passed to output (row-vector whose size is determined at runtime).
+   *  - c: Cell state for final timestep, of shape (N, M). 
    */
-  [out, c, reserveSpace] = lstm(X, W, b, out0, c0, return_sequences)
+  out = 0; c = c0;
+  [out, c] = lstm(X, W, b, out0, c0, return_sequences)
 }
 
-# TODO:
 backward = function(matrix[double] dout, matrix[double] dc,
-                    matrix[double] X, matrix[double] W, matrix[double] b, int T, int D,
-                    boolean given_sequences, matrix[double] out0, matrix[double] c0,
-                    matrix[double] cache_out, matrix[double] cache_c, matrix[double] cache_ifog)
+                    matrix[double] X, matrix[double] W, matrix[double] b,
+                    boolean given_sequences, matrix[double] out0, matrix[double] c0)
     return (matrix[double] dX, matrix[double] dW, matrix[double] db,
             matrix[double] dout0, matrix[double] dc0) {
   /*
@@ -87,8 +85,6 @@ backward = function(matrix[double] dout, matrix[double] dc,
    *  - X: Inputs, of shape (N, T*D).
    *  - W: Weights, of shape (D+M, 4M).
    *  - b: Biases, of shape (1, 4M).
-   *  - T: Length of example sequences (number of timesteps).
-   *  - D: Dimensionality of the input features.
    *  - given_sequences: Whether `dout` is for all timesteps,
    *      or just for the final timestep.  This is based on whether
    *      `return_sequences` was true in the forward pass.
@@ -110,75 +106,8 @@ backward = function(matrix[double] dout, matrix[double] dc,
    *  - dout0: Gradient wrt `out0`, of shape (N, M).
    *  - dc0: Gradient wrt `c0`, of shape (N, M).
    */
-  N = nrow(X)
-  M = as.integer(ncol(W)/4)
-  N1 = nrow(out0)
-  if(N != N1) {
-    # Allow for smaller out0 for last batch 
-    # out0 = out0[1:N,]
-    # c0 = c0[1:N,]
-    stop("Unsupported operation: The batch size of previous iteration " + N1 + " is different than the batch size of current iteration " + N)
-  }
-  dX = matrix(0, rows=N, cols=T*D)
-  dW = matrix(0, rows=D+M, cols=4*M)
-  db = matrix(0, rows=1, cols=4*M)
-  dout0 = matrix(0, rows=N, cols=M)
-  dc0 = matrix(0, rows=N, cols=M)
-  dct = dc
-  if (!given_sequences) {
-    # only given dout for output at final timestep, so prepend empty douts for all other timesteps
-    dout = cbind(matrix(0, rows=N, cols=(T-1)*M), dout)  # shape (N, T*M)
-  }
-
-  t = T
-  for (iter in 1:T) {  # each timestep in reverse order
-    X_t = X[,(t-1)*D+1:t*D]  # shape (N, D)
-    dout_t = dout[,(t-1)*M+1:t*M]  # shape (N, M)
-    out_t = matrix(cache_out[t,], rows=N, cols=M)  # shape (N, M)
-    ct = matrix(cache_c[t,], rows=N, cols=M)  # shape (N, M)
-    if (t == 1) {
-      out_prev = out0  # shape (N, M)
-      c_prev = c0  # shape (N, M)
-    }
-    else {
-      out_prev = matrix(cache_out[t-1,], rows=N, cols=M)  # shape (N, M)
-      c_prev = matrix(cache_c[t-1,], rows=N, cols=M)  # shape (N, M)
-    }
-    input = cbind(X_t, out_prev)  # shape (N, D+M)
-    ifog = matrix(cache_ifog[t,], rows=N, cols=4*M)
-    i = ifog[,1:M]  # input gate, shape (N, M)
-    f = ifog[,M+1:2*M]  # forget gate, shape (N, M)
-    o = ifog[,2*M+1:3*M]  # output gate, shape (N, M)
-    g = ifog[,3*M+1:4*M]  # g gate, shape (N, M)
-
-    dct = dct + o*tanh::backward(dout_t, ct)  # shape (N, M)
-    do = tanh::forward(ct) * dout_t  # output gate, shape (N, M)
-    df = c_prev * dct  # forget gate, shape (N, M)
-    dc_prev = f * dct  # shape (N, M)
-    di = g * dct  # input gate, shape (N, M)
-    dg = i * dct  # g gate, shape (N, M)
-
-    di_raw = i * (1-i) * di
-    df_raw = f * (1-f) * df
-    do_raw = o * (1-o) * do
-    dg_raw = (1-g^2) * dg
-    difog_raw = cbind(di_raw, df_raw, do_raw, dg_raw)  # shape (N, 4M)
-
-    dW = dW + t(input) %*% difog_raw  # shape (D+M, 4M)
-    db = db + colSums(difog_raw)  # shape (1, 4M)
-    dinput = difog_raw %*% t(W)  # shape (N, D+M)
-    dX[,(t-1)*D+1:t*D] = dinput[,1:D]
-    dout_prev = dinput[,D+1:D+M]  # shape (N, M)
-    if (t == 1) {
-      dout0 = dout_prev  # shape (N, M)
-      dc0 = dc_prev  # shape (N, M)
-    }
-    else {
-      dout[,(t-2)*M+1:(t-1)*M] = dout[,(t-2)*M+1:(t-1)*M] + dout_prev  # shape (N, M)
-      dct = dc_prev  # shape (N, M)
-    }
-    t = t - 1
-  }
+  dX = X; dW = W; db = b; dout0 = out0; dc0 = c0
+  [dX, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0, c0, given_sequences, dout, dc)
 }
 
 init = function(int N, int D, int M)
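
For illustration, a minimal calling sketch of the two builtins (not part of
this commit; sizes are hypothetical, shapes follow the language-reference
table above, and the GPU backend must be enabled since both builtins are
GPU-only per the FunctionOp.java change below):

  N = 32; T = 10; D = 64; M = 128          # batch, timesteps, features, hidden
  X = rand(rows=N, cols=T*D)               # inputs, shape (N, T*D)
  W = rand(rows=D+M, cols=4*M)             # weights, shape (D+M, 4M)
  b = rand(rows=1, cols=4*M)               # biases, shape (1, 4M)
  out0 = matrix(0, rows=N, cols=M)         # initial hidden state, shape (N, M)
  c0 = matrix(0, rows=N, cols=M)           # initial cell state, shape (N, M)
  [out, c] = lstm(X, W, b, out0, c0, TRUE)
  dout = rand(rows=N, cols=T*M)            # upstream gradient, all timesteps
  dc = matrix(0, rows=N, cols=M)           # upstream cell-state gradient
  [dX, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0, c0, TRUE, dout, dc)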

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index cc2d531..082daac 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -1987,11 +1987,7 @@ __device__ int swap_co(int offset) {
   return (offset < 2) ? offset : (offset == 2 ? 3 : 2);
 }
 
-template <typename T>
-__device__ void prepare_lstm_weight(T* smlWeight, T* smlBias, T* cudnnWeight, int D, int M) {
-  int DM = D*M; int MM = M*M; int DM4 = DM*4; 
-  int M4 = M*4;
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
+__device__ void compute_lstm_weight_indexes(int index, int D, int M, int* ret) {
   // input: cbind(X_t, out_prev) => [N, D+M], weight: [D+M, 4M]
   // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnGetRNNLinLayerMatrixParams states that
   // Elements in each weight matrix are arranged in the row-major order, but the column-major format works !!
@@ -1999,20 +1995,16 @@ __device__ void prepare_lstm_weight(T* smlWeight, T* smlBias, T* cudnnWeight, in
   // CuDNN weight order: w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o
   // SystemML weight order: i, f, o, c; TF weight order: i, c, f, o
   // SystemML performs (X_t %*% W + out_prev %*% R) => [N, 4*M]
-  
-  // bias layout: bi bf bc bo 0 0 0 0
-  // where W: [DxM], R: [MxM] and b: [1x1]
-  
-  // Maximum (D+M+2)*M4 threads
-  int srcIndex = -1; int destIndex;
+  int DM = D*M; int MM = M*M; int DM4 = DM*4; 
+  int M4 = M*4;
   if(index < DM4) {
     // Fill w_i, w_f, w_c and w_o
     int localIndex = index%DM;
     int smlRowIndex = localIndex/M; 
     int smlColIndex = swap_co(index/(DM))*M + localIndex%M;
     // Convert index to column-major where index = (index/(DM))*DM + (localIndex/M)*M + localIndex%M
-    destIndex = (index/(DM))*DM + (localIndex%M)*D + localIndex/M;
-    srcIndex = smlRowIndex*M4+smlColIndex;
+    ret[1] = (index/(DM))*DM + (localIndex%M)*D + localIndex/M;
+    ret[0] = smlRowIndex*M4+smlColIndex;
   }
   else if(index < (D+M)*M4) {
     // Fill r_i, r_f, r_c and r_o
@@ -2021,18 +2013,29 @@ __device__ void prepare_lstm_weight(T* smlWeight, T* smlBias, T* cudnnWeight, in
     int smlRowIndex = D + (localIndex / M);
     int smlColIndex = swap_co(tmpIndex/(MM))*M + localIndex%M;
     // Convert index to column-major where index = DM4 + (tmpIndex/(MM))*MM + (localIndex/M)*M + localIndex%M
-    destIndex = DM4 + (tmpIndex/(MM))*MM + (localIndex%M)*M + localIndex/M;
-    srcIndex = smlRowIndex*M4+smlColIndex;
+    ret[1] = DM4 + (tmpIndex/(MM))*MM + (localIndex%M)*M + localIndex/M;
+    ret[0] = smlRowIndex*M4+smlColIndex;
+  }
+}
+
+template <typename T>
+__device__ void prepare_lstm_weight(T* smlWeight, T* smlBias, T* cudnnWeight, int D, int M) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  // Maximum (D+M+2)*M4 threads
+  int M4 = M*4;
+  if(index < (D+M)*M4) {
+    int indexes[2];
+    compute_lstm_weight_indexes(index, D, M, indexes);
+    cudnnWeight[indexes[1]] = smlWeight[indexes[0]];
   }
   else if(index < (D+M+1)*M4) {
        // Fill bias
+       // bias layout: bi bf bc bo 0 0 0 0
+    // where W: [DxM], R: [MxM] and b: [1x1]
        int tmpIndex = index - (D+M)*M4;
        int smlColIndex = swap_co(tmpIndex/(M))*M + tmpIndex%M;
        cudnnWeight[index] = smlBias[smlColIndex];
   }
-  // __syncthreads();
-  if(srcIndex != -1)
-       cudnnWeight[destIndex] = smlWeight[srcIndex];
 }
 
extern "C" __global__ void prepare_lstm_weight_d(double* smlWeight, double* smlBias, double* cudnnWeight, int D, int M) {
@@ -2162,3 +2165,85 @@ extern "C" __global__ void prepare_lstm_output_d(double* smlInput, double* cudnn
extern "C" __global__ void prepare_lstm_output_f(float* smlInput, float* cudnnInput, int N, int T, int M, int size) {
   prepare_lstm_output(smlInput, cudnnInput, N, T, M, size);
 }
+
+template <typename T>
+__device__ void prepare_lstm_backward_gradients(T* smlDout, T* cudnnDy, int N, int T1, int M, int size, int return_sequences) {
+       int index = blockIdx.x * blockDim.x + threadIdx.x;
+       if(index < size && return_sequences != 0) {
+               // smlDout = [N, T, M]
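+               // note: cudnnDy is indexed as [T, N, M], i.e. element (t, n, m) at t*N*M + n*M + m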
+               int TM = T1*M;
+               int NT = T1*N;
+               int n = index / TM;
+               int tm = index % TM;
+               int t = tm / M;
+               int m = tm % M;
+               T val = smlDout[index];
+               cudnnDy[t*N*M + n*M + m] = val;
+       }
+       else if(index < size) {
+               // smlDout = [N, T, M]
+               int n = index / M;
+               int m = index % M;
+               T val = smlDout[index];
+               cudnnDy[(T1-1)*N*M + n*M + m] = val;
+       }
+}
+
+
+extern "C" __global__ void prepare_lstm_backward_gradients_d(double* smlInput, double* cudnnDy, int N, int T, int M, int size, int return_sequences) {
+  prepare_lstm_backward_gradients(smlInput, cudnnDy, N, T, M, size, return_sequences);
+}
+
+extern "C" __global__ void prepare_lstm_backward_gradients_f(float* smlInput, float* cudnnDy, int N, int T, int M, int size, int return_sequences) {
+  prepare_lstm_backward_gradients(smlInput, cudnnDy, N, T, M, size, return_sequences);
+}
+
+template <typename T>
+__device__ void prepare_lstm_dweight(T* smldWeight, T* smldBias, T* cudnndWeight, int D, int M) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  // Maximum (D+M+2)*M4 threads
+  int M4 = M*4;
+  if(index < (D+M)*M4) {
+    int indexes[2];
+    compute_lstm_weight_indexes(index, D, M, indexes);
+    smldWeight[indexes[0]] = cudnndWeight[indexes[1]];
+  }
+  else if(index < (D+M+1)*M4) {
+       // Fill bias
+       // bias layout: bi bf bc bo 0 0 0 0
+    // where W: [DxM], R: [MxM] and b: [1x1]
+       int tmpIndex = index - (D+M)*M4;
+       int smlColIndex = swap_co(tmpIndex/(M))*M + tmpIndex%M;
+       smldBias[smlColIndex] = cudnndWeight[index];
+  }
+}
+
+extern "C" __global__ void prepare_lstm_dweight_d(double* smldWeight, double* smldBias, double* cudnndWeight, int D, int M) {
+  prepare_lstm_dweight(smldWeight, smldBias, cudnndWeight, D, M);
+}
+
+extern "C" __global__ void prepare_lstm_dweight_f(float* smldWeight, float* smldBias, float* cudnndWeight, int D, int M) {
+  prepare_lstm_dweight(smldWeight, smldBias, cudnndWeight, D, M);
+}
+
+template <typename T>
+__device__ void prepare_lstm_dinput(T* smlInput, T* cudnnInput, int N, int D, int TD, int size) {
+       int index = blockIdx.x * blockDim.x + threadIdx.x;
+       if(index < size) {
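+               // cudnnInput is indexed as [T, N, D] (t*N*D + n*D + d); smlInput as [N, T, D]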
+               int n = index / TD;
+               int td = index % TD;
+               int t = td / D;
+               int d = td % D;
+               smlInput[index] = cudnnInput[t*N*D + n*D + d];
+       }
+}
+
+
+extern "C" __global__ void prepare_lstm_dinput_d(double* smlInput, double* cudnnInput, int N, int D, int TD, int size) {
+  prepare_lstm_dinput(smlInput, cudnnInput, N, D, TD, size);
+}
+
+extern "C" __global__ void prepare_lstm_dinput_f(float* smlInput, float* cudnnInput, int N, int D, int TD, int size) {
+  prepare_lstm_dinput(smlInput, cudnnInput, N, D, TD, size);
+}
+
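
To make the weight reshuffling concrete, a worked example for
compute_lstm_weight_indexes with D=2, M=3 (so DM=6, M4=12, DM4=24) and
thread index 13, which falls in the w_* range (index < DM4):

  index/DM   = 13/6 = 2                    -> cuDNN gate block w_c
  localIndex = 13%6 = 1
  srcIndex   = (localIndex/M)*M4 + swap_co(2)*M + localIndex%M
             = 0*12 + 3*3 + 1 = 10         (column 10 of W, SystemML's c block)
  destIndex  = (index/DM)*DM + (localIndex%M)*D + localIndex/M
             = 2*6 + 1*2 + 0 = 14          (column-major slot in the w_c block)

so prepare_lstm_weight copies smlWeight[10] to cudnnWeight[14]: the
row-major [D+M, 4M] SystemML weight in i,f,o,c gate order is scattered into
cuDNN's column-major per-gate blocks in i,f,c,o order (swap_co exchanges the
c and o blocks), and prepare_lstm_dweight applies the same index map in the
opposite direction to gather cuDNN's weight gradient back into the SystemML
layout.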

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index ed1a100..9d8b178 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -11794,8 +11794,8 @@ BB99_2:
        .param .u32 prepare_lstm_weight_d_param_4
 )
 {
-       .reg .pred      %p<11>;
-       .reg .b32       %r<53>;
+       .reg .pred      %p<8>;
+       .reg .b32       %r<48>;
        .reg .f64       %fd<3>;
        .reg .b64       %rd<15>;
 
@@ -11803,98 +11803,87 @@ BB99_2:
        ld.param.u64    %rd2, [prepare_lstm_weight_d_param_0];
        ld.param.u64    %rd3, [prepare_lstm_weight_d_param_1];
        ld.param.u64    %rd4, [prepare_lstm_weight_d_param_2];
-       ld.param.u32    %r13, [prepare_lstm_weight_d_param_3];
-       ld.param.u32    %r14, [prepare_lstm_weight_d_param_4];
+       ld.param.u32    %r45, [prepare_lstm_weight_d_param_3];
+       ld.param.u32    %r21, [prepare_lstm_weight_d_param_4];
        cvta.to.global.u64      %rd1, %rd4;
-       mul.lo.s32      %r1, %r14, %r13;
-       shl.b32         %r2, %r1, 2;
-       shl.b32         %r3, %r14, 2;
-       mov.u32         %r15, %ntid.x;
-       mov.u32         %r16, %ctaid.x;
-       mov.u32         %r17, %tid.x;
-       mad.lo.s32      %r4, %r15, %r16, %r17;
-       setp.lt.s32     %p1, %r4, %r2;
-       @%p1 bra        BB100_5;
+       mov.u32         %r22, %ntid.x;
+       mov.u32         %r23, %ctaid.x;
+       mov.u32         %r24, %tid.x;
+       mad.lo.s32      %r1, %r22, %r23, %r24;
+       add.s32         %r2, %r21, %r45;
+       shl.b32         %r3, %r21, 2;
+       mul.lo.s32      %r4, %r2, %r3;
+       setp.lt.s32     %p1, %r1, %r4;
+       @%p1 bra        BB100_3;
        bra.uni         BB100_1;
 
+BB100_3:
+       mul.lo.s32      %r5, %r21, %r45;
+       mul.lo.s32      %r47, %r21, %r21;
+       shl.b32         %r7, %r5, 2;
+       setp.lt.s32     %p5, %r1, %r7;
+       @%p5 bra        BB100_5;
+       bra.uni         BB100_4;
+
 BB100_5:
-       rem.s32         %r42, %r4, %r1;
-       div.s32         %r43, %r42, %r14;
-       div.s32         %r44, %r4, %r1;
-       setp.lt.s32     %p8, %r44, 2;
-       setp.eq.s32     %p9, %r44, 2;
-       selp.b32        %r45, 3, 2, %p9;
-       selp.b32        %r46, %r44, %r45, %p8;
-       rem.s32         %r47, %r42, %r14;
-       sub.s32         %r48, %r4, %r42;
-       add.s32         %r49, %r48, %r43;
-       mad.lo.s32      %r52, %r47, %r13, %r49;
-       mad.lo.s32      %r50, %r43, %r3, %r47;
-       mad.lo.s32      %r51, %r46, %r14, %r50;
+       rem.s32         %r44, %r1, %r5;
+       div.s32         %r42, %r44, %r21;
+       mov.u32         %r43, %r42;
+       mov.u32         %r46, %r1;
+       mov.u32         %r47, %r5;
        bra.uni         BB100_6;
 
 BB100_1:
-       add.s32         %r5, %r14, %r13;
-       mul.lo.s32      %r6, %r5, %r3;
-       setp.lt.s32     %p2, %r4, %r6;
-       @%p2 bra        BB100_4;
-       bra.uni         BB100_2;
-
-BB100_4:
-       mul.lo.s32      %r30, %r14, %r14;
-       sub.s32         %r31, %r4, %r2;
-       rem.s32         %r32, %r31, %r30;
-       div.s32         %r33, %r32, %r14;
-       add.s32         %r34, %r33, %r13;
-       div.s32         %r35, %r31, %r30;
-       setp.lt.s32     %p6, %r35, 2;
-       setp.eq.s32     %p7, %r35, 2;
-       selp.b32        %r36, 3, 2, %p7;
-       selp.b32        %r37, %r35, %r36, %p6;
-       rem.s32         %r38, %r32, %r14;
-       sub.s32         %r39, %r4, %r32;
-       add.s32         %r40, %r39, %r33;
-       mad.lo.s32      %r52, %r38, %r14, %r40;
-       mad.lo.s32      %r41, %r34, %r3, %r38;
-       mad.lo.s32      %r51, %r37, %r14, %r41;
-       bra.uni         BB100_6;
-
-BB100_2:
-       add.s32         %r20, %r5, 1;
-       mul.lo.s32      %r21, %r20, %r3;
-       mov.u32         %r51, -1;
-       setp.ge.s32     %p3, %r4, %r21;
-       @%p3 bra        BB100_6;
+       add.s32         %r25, %r2, 1;
+       mul.lo.s32      %r26, %r25, %r3;
+       setp.ge.s32     %p2, %r1, %r26;
+       @%p2 bra        BB100_7;
 
        cvta.to.global.u64      %rd5, %rd3;
-       sub.s32         %r24, %r4, %r6;
-       div.s32         %r25, %r24, %r14;
-       setp.lt.s32     %p4, %r25, 2;
-       setp.eq.s32     %p5, %r25, 2;
-       selp.b32        %r26, 3, 2, %p5;
-       selp.b32        %r27, %r25, %r26, %p4;
-       rem.s32         %r28, %r24, %r14;
-       mad.lo.s32      %r29, %r27, %r14, %r28;
-       mul.wide.s32    %rd6, %r29, 8;
+       sub.s32         %r27, %r1, %r4;
+       div.s32         %r28, %r27, %r21;
+       setp.lt.s32     %p3, %r28, 2;
+       setp.eq.s32     %p4, %r28, 2;
+       selp.b32        %r29, 3, 2, %p4;
+       selp.b32        %r30, %r28, %r29, %p3;
+       rem.s32         %r31, %r27, %r21;
+       mad.lo.s32      %r32, %r30, %r21, %r31;
+       mul.wide.s32    %rd6, %r32, 8;
        add.s64         %rd7, %rd5, %rd6;
        ld.global.f64   %fd1, [%rd7];
-       mul.wide.s32    %rd8, %r4, 8;
+       mul.wide.s32    %rd8, %r1, 8;
        add.s64         %rd9, %rd1, %rd8;
        st.global.f64   [%rd9], %fd1;
+       bra.uni         BB100_7;
 
-BB100_6:
-       setp.eq.s32     %p10, %r51, -1;
-       @%p10 bra       BB100_8;
+BB100_4:
+       sub.s32         %r46, %r1, %r7;
+       rem.s32         %r44, %r46, %r47;
+       div.s32         %r43, %r44, %r21;
+       add.s32         %r42, %r43, %r45;
+       mov.u32         %r45, %r21;
 
+BB100_6:
        cvta.to.global.u64      %rd10, %rd2;
-       mul.wide.s32    %rd11, %r51, 8;
+       div.s32         %r33, %r46, %r47;
+       setp.eq.s32     %p6, %r33, 2;
+       selp.b32        %r34, 3, 2, %p6;
+       setp.lt.s32     %p7, %r33, 2;
+       selp.b32        %r35, %r33, %r34, %p7;
+       rem.s32         %r36, %r44, %r21;
+       sub.s32         %r37, %r1, %r44;
+       add.s32         %r38, %r37, %r43;
+       mad.lo.s32      %r39, %r36, %r45, %r38;
+       mad.lo.s32      %r40, %r42, %r3, %r36;
+       mad.lo.s32      %r41, %r35, %r21, %r40;
+       mul.wide.s32    %rd11, %r41, 8;
        add.s64         %rd12, %rd10, %rd11;
        ld.global.f64   %fd2, [%rd12];
-       mul.wide.s32    %rd13, %r52, 8;
+       mul.wide.s32    %rd13, %r39, 8;
        add.s64         %rd14, %rd1, %rd13;
        st.global.f64   [%rd14], %fd2;
 
-BB100_8:
+BB100_7:
        ret;
 }
 
@@ -11907,107 +11896,96 @@ BB100_8:
        .param .u32 prepare_lstm_weight_f_param_4
 )
 {
-       .reg .pred      %p<11>;
+       .reg .pred      %p<8>;
        .reg .f32       %f<3>;
-       .reg .b32       %r<53>;
+       .reg .b32       %r<48>;
        .reg .b64       %rd<15>;
 
 
        ld.param.u64    %rd2, [prepare_lstm_weight_f_param_0];
        ld.param.u64    %rd3, [prepare_lstm_weight_f_param_1];
        ld.param.u64    %rd4, [prepare_lstm_weight_f_param_2];
-       ld.param.u32    %r13, [prepare_lstm_weight_f_param_3];
-       ld.param.u32    %r14, [prepare_lstm_weight_f_param_4];
+       ld.param.u32    %r45, [prepare_lstm_weight_f_param_3];
+       ld.param.u32    %r21, [prepare_lstm_weight_f_param_4];
        cvta.to.global.u64      %rd1, %rd4;
-       mul.lo.s32      %r1, %r14, %r13;
-       shl.b32         %r2, %r1, 2;
-       shl.b32         %r3, %r14, 2;
-       mov.u32         %r15, %ntid.x;
-       mov.u32         %r16, %ctaid.x;
-       mov.u32         %r17, %tid.x;
-       mad.lo.s32      %r4, %r15, %r16, %r17;
-       setp.lt.s32     %p1, %r4, %r2;
-       @%p1 bra        BB101_5;
+       mov.u32         %r22, %ntid.x;
+       mov.u32         %r23, %ctaid.x;
+       mov.u32         %r24, %tid.x;
+       mad.lo.s32      %r1, %r22, %r23, %r24;
+       add.s32         %r2, %r21, %r45;
+       shl.b32         %r3, %r21, 2;
+       mul.lo.s32      %r4, %r2, %r3;
+       setp.lt.s32     %p1, %r1, %r4;
+       @%p1 bra        BB101_3;
        bra.uni         BB101_1;
 
+BB101_3:
+       mul.lo.s32      %r5, %r21, %r45;
+       mul.lo.s32      %r47, %r21, %r21;
+       shl.b32         %r7, %r5, 2;
+       setp.lt.s32     %p5, %r1, %r7;
+       @%p5 bra        BB101_5;
+       bra.uni         BB101_4;
+
 BB101_5:
-       rem.s32         %r42, %r4, %r1;
-       div.s32         %r43, %r42, %r14;
-       div.s32         %r44, %r4, %r1;
-       setp.lt.s32     %p8, %r44, 2;
-       setp.eq.s32     %p9, %r44, 2;
-       selp.b32        %r45, 3, 2, %p9;
-       selp.b32        %r46, %r44, %r45, %p8;
-       rem.s32         %r47, %r42, %r14;
-       sub.s32         %r48, %r4, %r42;
-       add.s32         %r49, %r48, %r43;
-       mad.lo.s32      %r52, %r47, %r13, %r49;
-       mad.lo.s32      %r50, %r43, %r3, %r47;
-       mad.lo.s32      %r51, %r46, %r14, %r50;
+       rem.s32         %r44, %r1, %r5;
+       div.s32         %r42, %r44, %r21;
+       mov.u32         %r43, %r42;
+       mov.u32         %r46, %r1;
+       mov.u32         %r47, %r5;
        bra.uni         BB101_6;
 
 BB101_1:
-       add.s32         %r5, %r14, %r13;
-       mul.lo.s32      %r6, %r5, %r3;
-       setp.lt.s32     %p2, %r4, %r6;
-       @%p2 bra        BB101_4;
-       bra.uni         BB101_2;
-
-BB101_4:
-       mul.lo.s32      %r30, %r14, %r14;
-       sub.s32         %r31, %r4, %r2;
-       rem.s32         %r32, %r31, %r30;
-       div.s32         %r33, %r32, %r14;
-       add.s32         %r34, %r33, %r13;
-       div.s32         %r35, %r31, %r30;
-       setp.lt.s32     %p6, %r35, 2;
-       setp.eq.s32     %p7, %r35, 2;
-       selp.b32        %r36, 3, 2, %p7;
-       selp.b32        %r37, %r35, %r36, %p6;
-       rem.s32         %r38, %r32, %r14;
-       sub.s32         %r39, %r4, %r32;
-       add.s32         %r40, %r39, %r33;
-       mad.lo.s32      %r52, %r38, %r14, %r40;
-       mad.lo.s32      %r41, %r34, %r3, %r38;
-       mad.lo.s32      %r51, %r37, %r14, %r41;
-       bra.uni         BB101_6;
-
-BB101_2:
-       add.s32         %r20, %r5, 1;
-       mul.lo.s32      %r21, %r20, %r3;
-       mov.u32         %r51, -1;
-       setp.ge.s32     %p3, %r4, %r21;
-       @%p3 bra        BB101_6;
+       add.s32         %r25, %r2, 1;
+       mul.lo.s32      %r26, %r25, %r3;
+       setp.ge.s32     %p2, %r1, %r26;
+       @%p2 bra        BB101_7;
 
        cvta.to.global.u64      %rd5, %rd3;
-       sub.s32         %r24, %r4, %r6;
-       div.s32         %r25, %r24, %r14;
-       setp.lt.s32     %p4, %r25, 2;
-       setp.eq.s32     %p5, %r25, 2;
-       selp.b32        %r26, 3, 2, %p5;
-       selp.b32        %r27, %r25, %r26, %p4;
-       rem.s32         %r28, %r24, %r14;
-       mad.lo.s32      %r29, %r27, %r14, %r28;
-       mul.wide.s32    %rd6, %r29, 4;
+       sub.s32         %r27, %r1, %r4;
+       div.s32         %r28, %r27, %r21;
+       setp.lt.s32     %p3, %r28, 2;
+       setp.eq.s32     %p4, %r28, 2;
+       selp.b32        %r29, 3, 2, %p4;
+       selp.b32        %r30, %r28, %r29, %p3;
+       rem.s32         %r31, %r27, %r21;
+       mad.lo.s32      %r32, %r30, %r21, %r31;
+       mul.wide.s32    %rd6, %r32, 4;
        add.s64         %rd7, %rd5, %rd6;
        ld.global.f32   %f1, [%rd7];
-       mul.wide.s32    %rd8, %r4, 4;
+       mul.wide.s32    %rd8, %r1, 4;
        add.s64         %rd9, %rd1, %rd8;
        st.global.f32   [%rd9], %f1;
+       bra.uni         BB101_7;
 
-BB101_6:
-       setp.eq.s32     %p10, %r51, -1;
-       @%p10 bra       BB101_8;
+BB101_4:
+       sub.s32         %r46, %r1, %r7;
+       rem.s32         %r44, %r46, %r47;
+       div.s32         %r43, %r44, %r21;
+       add.s32         %r42, %r43, %r45;
+       mov.u32         %r45, %r21;
 
+BB101_6:
        cvta.to.global.u64      %rd10, %rd2;
-       mul.wide.s32    %rd11, %r51, 4;
+       div.s32         %r33, %r46, %r47;
+       setp.eq.s32     %p6, %r33, 2;
+       selp.b32        %r34, 3, 2, %p6;
+       setp.lt.s32     %p7, %r33, 2;
+       selp.b32        %r35, %r33, %r34, %p7;
+       rem.s32         %r36, %r44, %r21;
+       sub.s32         %r37, %r1, %r44;
+       add.s32         %r38, %r37, %r43;
+       mad.lo.s32      %r39, %r36, %r45, %r38;
+       mad.lo.s32      %r40, %r42, %r3, %r36;
+       mad.lo.s32      %r41, %r35, %r21, %r40;
+       mul.wide.s32    %rd11, %r41, 4;
        add.s64         %rd12, %rd10, %rd11;
        ld.global.f32   %f2, [%rd12];
-       mul.wide.s32    %rd13, %r52, 4;
+       mul.wide.s32    %rd13, %r39, 4;
        add.s64         %rd14, %rd1, %rd13;
        st.global.f32   [%rd14], %f2;
 
-BB101_8:
+BB101_7:
        ret;
 }
 
@@ -12463,12 +12441,448 @@ BB105_2:
        ret;
 }
 
+       // .globl       prepare_lstm_backward_gradients_d
+.visible .entry prepare_lstm_backward_gradients_d(
+       .param .u64 prepare_lstm_backward_gradients_d_param_0,
+       .param .u64 prepare_lstm_backward_gradients_d_param_1,
+       .param .u32 prepare_lstm_backward_gradients_d_param_2,
+       .param .u32 prepare_lstm_backward_gradients_d_param_3,
+       .param .u32 prepare_lstm_backward_gradients_d_param_4,
+       .param .u32 prepare_lstm_backward_gradients_d_param_5,
+       .param .u32 prepare_lstm_backward_gradients_d_param_6
+)
+{
+       .reg .pred      %p<5>;
+       .reg .b32       %r<20>;
+       .reg .f64       %fd<3>;
+       .reg .b64       %rd<11>;
+
+
+       ld.param.u64    %rd3, [prepare_lstm_backward_gradients_d_param_0];
+       ld.param.u64    %rd4, [prepare_lstm_backward_gradients_d_param_1];
+       ld.param.u32    %r2, [prepare_lstm_backward_gradients_d_param_2];
+       ld.param.u32    %r3, [prepare_lstm_backward_gradients_d_param_3];
+       ld.param.u32    %r4, [prepare_lstm_backward_gradients_d_param_4];
+       ld.param.u32    %r5, [prepare_lstm_backward_gradients_d_param_5];
+       ld.param.u32    %r6, [prepare_lstm_backward_gradients_d_param_6];
+       cvta.to.global.u64      %rd1, %rd4;
+       mov.u32         %r7, %ntid.x;
+       mov.u32         %r8, %ctaid.x;
+       mov.u32         %r9, %tid.x;
+       mad.lo.s32      %r1, %r7, %r8, %r9;
+       setp.lt.s32     %p1, %r1, %r5;
+       setp.ne.s32     %p2, %r6, 0;
+       and.pred        %p3, %p1, %p2;
+       cvta.to.global.u64      %rd5, %rd3;
+       mul.wide.s32    %rd6, %r1, 8;
+       add.s64         %rd2, %rd5, %rd6;
+       @%p3 bra        BB106_3;
+       bra.uni         BB106_1;
+
+BB106_3:
+       mul.lo.s32      %r13, %r4, %r3;
+       div.s32         %r14, %r1, %r13;
+       rem.s32         %r15, %r1, %r13;
+       div.s32         %r16, %r15, %r4;
+       rem.s32         %r17, %r15, %r4;
+       ld.global.f64   %fd2, [%rd2];
+       mad.lo.s32      %r18, %r16, %r2, %r14;
+       mad.lo.s32      %r19, %r18, %r4, %r17;
+       mul.wide.s32    %rd9, %r19, 8;
+       add.s64         %rd10, %rd1, %rd9;
+       st.global.f64   [%rd10], %fd2;
+       bra.uni         BB106_4;
+
+BB106_1:
+       setp.ge.s32     %p4, %r1, %r5;
+       @%p4 bra        BB106_4;
+
+       ld.global.f64   %fd1, [%rd2];
+       add.s32         %r10, %r3, -1;
+       mul.lo.s32      %r11, %r10, %r2;
+       mad.lo.s32      %r12, %r11, %r4, %r1;
+       mul.wide.s32    %rd7, %r12, 8;
+       add.s64         %rd8, %rd1, %rd7;
+       st.global.f64   [%rd8], %fd1;
+
+BB106_4:
+       ret;
+}
+
+       // .globl       prepare_lstm_backward_gradients_f
+.visible .entry prepare_lstm_backward_gradients_f(
+       .param .u64 prepare_lstm_backward_gradients_f_param_0,
+       .param .u64 prepare_lstm_backward_gradients_f_param_1,
+       .param .u32 prepare_lstm_backward_gradients_f_param_2,
+       .param .u32 prepare_lstm_backward_gradients_f_param_3,
+       .param .u32 prepare_lstm_backward_gradients_f_param_4,
+       .param .u32 prepare_lstm_backward_gradients_f_param_5,
+       .param .u32 prepare_lstm_backward_gradients_f_param_6
+)
+{
+       .reg .pred      %p<5>;
+       .reg .f32       %f<3>;
+       .reg .b32       %r<20>;
+       .reg .b64       %rd<11>;
+
+
+       ld.param.u64    %rd3, [prepare_lstm_backward_gradients_f_param_0];
+       ld.param.u64    %rd4, [prepare_lstm_backward_gradients_f_param_1];
+       ld.param.u32    %r2, [prepare_lstm_backward_gradients_f_param_2];
+       ld.param.u32    %r3, [prepare_lstm_backward_gradients_f_param_3];
+       ld.param.u32    %r4, [prepare_lstm_backward_gradients_f_param_4];
+       ld.param.u32    %r5, [prepare_lstm_backward_gradients_f_param_5];
+       ld.param.u32    %r6, [prepare_lstm_backward_gradients_f_param_6];
+       cvta.to.global.u64      %rd1, %rd4;
+       mov.u32         %r7, %ntid.x;
+       mov.u32         %r8, %ctaid.x;
+       mov.u32         %r9, %tid.x;
+       mad.lo.s32      %r1, %r7, %r8, %r9;
+       setp.lt.s32     %p1, %r1, %r5;
+       setp.ne.s32     %p2, %r6, 0;
+       and.pred        %p3, %p1, %p2;
+       cvta.to.global.u64      %rd5, %rd3;
+       mul.wide.s32    %rd6, %r1, 4;
+       add.s64         %rd2, %rd5, %rd6;
+       @%p3 bra        BB107_3;
+       bra.uni         BB107_1;
+
+BB107_3:
+       mul.lo.s32      %r13, %r4, %r3;
+       div.s32         %r14, %r1, %r13;
+       rem.s32         %r15, %r1, %r13;
+       div.s32         %r16, %r15, %r4;
+       rem.s32         %r17, %r15, %r4;
+       ld.global.f32   %f2, [%rd2];
+       mad.lo.s32      %r18, %r16, %r2, %r14;
+       mad.lo.s32      %r19, %r18, %r4, %r17;
+       mul.wide.s32    %rd9, %r19, 4;
+       add.s64         %rd10, %rd1, %rd9;
+       st.global.f32   [%rd10], %f2;
+       bra.uni         BB107_4;
+
+BB107_1:
+       setp.ge.s32     %p4, %r1, %r5;
+       @%p4 bra        BB107_4;
+
+       ld.global.f32   %f1, [%rd2];
+       add.s32         %r10, %r3, -1;
+       mul.lo.s32      %r11, %r10, %r2;
+       mad.lo.s32      %r12, %r11, %r4, %r1;
+       mul.wide.s32    %rd7, %r12, 4;
+       add.s64         %rd8, %rd1, %rd7;
+       st.global.f32   [%rd8], %f1;
+
+BB107_4:
+       ret;
+}
+
+       // .globl       prepare_lstm_dweight_d
+.visible .entry prepare_lstm_dweight_d(
+       .param .u64 prepare_lstm_dweight_d_param_0,
+       .param .u64 prepare_lstm_dweight_d_param_1,
+       .param .u64 prepare_lstm_dweight_d_param_2,
+       .param .u32 prepare_lstm_dweight_d_param_3,
+       .param .u32 prepare_lstm_dweight_d_param_4
+)
+{
+       .reg .pred      %p<8>;
+       .reg .b32       %r<48>;
+       .reg .f64       %fd<3>;
+       .reg .b64       %rd<15>;
+
+
+       ld.param.u64    %rd2, [prepare_lstm_dweight_d_param_0];
+       ld.param.u64    %rd3, [prepare_lstm_dweight_d_param_1];
+       ld.param.u64    %rd4, [prepare_lstm_dweight_d_param_2];
+       ld.param.u32    %r45, [prepare_lstm_dweight_d_param_3];
+       ld.param.u32    %r21, [prepare_lstm_dweight_d_param_4];
+       cvta.to.global.u64      %rd1, %rd4;
+       mov.u32         %r22, %ntid.x;
+       mov.u32         %r23, %ctaid.x;
+       mov.u32         %r24, %tid.x;
+       mad.lo.s32      %r1, %r22, %r23, %r24;
+       add.s32         %r2, %r21, %r45;
+       shl.b32         %r3, %r21, 2;
+       mul.lo.s32      %r4, %r2, %r3;
+       setp.lt.s32     %p1, %r1, %r4;
+       @%p1 bra        BB108_3;
+       bra.uni         BB108_1;
+
+BB108_3:
+       mul.lo.s32      %r5, %r21, %r45;
+       mul.lo.s32      %r47, %r21, %r21;
+       shl.b32         %r7, %r5, 2;
+       setp.lt.s32     %p5, %r1, %r7;
+       @%p5 bra        BB108_5;
+       bra.uni         BB108_4;
+
+BB108_5:
+       rem.s32         %r44, %r1, %r5;
+       div.s32         %r42, %r44, %r21;
+       mov.u32         %r43, %r42;
+       mov.u32         %r46, %r1;
+       mov.u32         %r47, %r5;
+       bra.uni         BB108_6;
+
+BB108_1:
+       add.s32         %r25, %r2, 1;
+       mul.lo.s32      %r26, %r25, %r3;
+       setp.ge.s32     %p2, %r1, %r26;
+       @%p2 bra        BB108_7;
+
+       cvta.to.global.u64      %rd5, %rd3;
+       sub.s32         %r27, %r1, %r4;
+       div.s32         %r28, %r27, %r21;
+       setp.lt.s32     %p3, %r28, 2;
+       setp.eq.s32     %p4, %r28, 2;
+       selp.b32        %r29, 3, 2, %p4;
+       selp.b32        %r30, %r28, %r29, %p3;
+       rem.s32         %r31, %r27, %r21;
+       mad.lo.s32      %r32, %r30, %r21, %r31;
+       mul.wide.s32    %rd6, %r1, 8;
+       add.s64         %rd7, %rd1, %rd6;
+       ld.global.f64   %fd1, [%rd7];
+       mul.wide.s32    %rd8, %r32, 8;
+       add.s64         %rd9, %rd5, %rd8;
+       st.global.f64   [%rd9], %fd1;
+       bra.uni         BB108_7;
+
+BB108_4:
+       sub.s32         %r46, %r1, %r7;
+       rem.s32         %r44, %r46, %r47;
+       div.s32         %r43, %r44, %r21;
+       add.s32         %r42, %r43, %r45;
+       mov.u32         %r45, %r21;
+
+BB108_6:
+       cvta.to.global.u64      %rd10, %rd2;
+       div.s32         %r33, %r46, %r47;
+       setp.eq.s32     %p6, %r33, 2;
+       selp.b32        %r34, 3, 2, %p6;
+       setp.lt.s32     %p7, %r33, 2;
+       selp.b32        %r35, %r33, %r34, %p7;
+       rem.s32         %r36, %r44, %r21;
+       sub.s32         %r37, %r1, %r44;
+       add.s32         %r38, %r37, %r43;
+       mad.lo.s32      %r39, %r36, %r45, %r38;
+       mad.lo.s32      %r40, %r42, %r3, %r36;
+       mad.lo.s32      %r41, %r35, %r21, %r40;
+       mul.wide.s32    %rd11, %r39, 8;
+       add.s64         %rd12, %rd1, %rd11;
+       ld.global.f64   %fd2, [%rd12];
+       mul.wide.s32    %rd13, %r41, 8;
+       add.s64         %rd14, %rd10, %rd13;
+       st.global.f64   [%rd14], %fd2;
+
+BB108_7:
+       ret;
+}
+
+       // .globl       prepare_lstm_dweight_f
+.visible .entry prepare_lstm_dweight_f(
+       .param .u64 prepare_lstm_dweight_f_param_0,
+       .param .u64 prepare_lstm_dweight_f_param_1,
+       .param .u64 prepare_lstm_dweight_f_param_2,
+       .param .u32 prepare_lstm_dweight_f_param_3,
+       .param .u32 prepare_lstm_dweight_f_param_4
+)
+{
+       .reg .pred      %p<8>;
+       .reg .f32       %f<3>;
+       .reg .b32       %r<48>;
+       .reg .b64       %rd<15>;
+
+
+       ld.param.u64    %rd2, [prepare_lstm_dweight_f_param_0];
+       ld.param.u64    %rd3, [prepare_lstm_dweight_f_param_1];
+       ld.param.u64    %rd4, [prepare_lstm_dweight_f_param_2];
+       ld.param.u32    %r45, [prepare_lstm_dweight_f_param_3];
+       ld.param.u32    %r21, [prepare_lstm_dweight_f_param_4];
+       cvta.to.global.u64      %rd1, %rd4;
+       mov.u32         %r22, %ntid.x;
+       mov.u32         %r23, %ctaid.x;
+       mov.u32         %r24, %tid.x;
+       mad.lo.s32      %r1, %r22, %r23, %r24;
+       add.s32         %r2, %r21, %r45;
+       shl.b32         %r3, %r21, 2;
+       mul.lo.s32      %r4, %r2, %r3;
+       setp.lt.s32     %p1, %r1, %r4;
+       @%p1 bra        BB109_3;
+       bra.uni         BB109_1;
+
+BB109_3:
+       mul.lo.s32      %r5, %r21, %r45;
+       mul.lo.s32      %r47, %r21, %r21;
+       shl.b32         %r7, %r5, 2;
+       setp.lt.s32     %p5, %r1, %r7;
+       @%p5 bra        BB109_5;
+       bra.uni         BB109_4;
+
+BB109_5:
+       rem.s32         %r44, %r1, %r5;
+       div.s32         %r42, %r44, %r21;
+       mov.u32         %r43, %r42;
+       mov.u32         %r46, %r1;
+       mov.u32         %r47, %r5;
+       bra.uni         BB109_6;
+
+BB109_1:
+       add.s32         %r25, %r2, 1;
+       mul.lo.s32      %r26, %r25, %r3;
+       setp.ge.s32     %p2, %r1, %r26;
+       @%p2 bra        BB109_7;
+
+       cvta.to.global.u64      %rd5, %rd3;
+       sub.s32         %r27, %r1, %r4;
+       div.s32         %r28, %r27, %r21;
+       setp.lt.s32     %p3, %r28, 2;
+       setp.eq.s32     %p4, %r28, 2;
+       selp.b32        %r29, 3, 2, %p4;
+       selp.b32        %r30, %r28, %r29, %p3;
+       rem.s32         %r31, %r27, %r21;
+       mad.lo.s32      %r32, %r30, %r21, %r31;
+       mul.wide.s32    %rd6, %r1, 4;
+       add.s64         %rd7, %rd1, %rd6;
+       ld.global.f32   %f1, [%rd7];
+       mul.wide.s32    %rd8, %r32, 4;
+       add.s64         %rd9, %rd5, %rd8;
+       st.global.f32   [%rd9], %f1;
+       bra.uni         BB109_7;
+
+BB109_4:
+       sub.s32         %r46, %r1, %r7;
+       rem.s32         %r44, %r46, %r47;
+       div.s32         %r43, %r44, %r21;
+       add.s32         %r42, %r43, %r45;
+       mov.u32         %r45, %r21;
+
+BB109_6:
+       cvta.to.global.u64      %rd10, %rd2;
+       div.s32         %r33, %r46, %r47;
+       setp.eq.s32     %p6, %r33, 2;
+       selp.b32        %r34, 3, 2, %p6;
+       setp.lt.s32     %p7, %r33, 2;
+       selp.b32        %r35, %r33, %r34, %p7;
+       rem.s32         %r36, %r44, %r21;
+       sub.s32         %r37, %r1, %r44;
+       add.s32         %r38, %r37, %r43;
+       mad.lo.s32      %r39, %r36, %r45, %r38;
+       mad.lo.s32      %r40, %r42, %r3, %r36;
+       mad.lo.s32      %r41, %r35, %r21, %r40;
+       mul.wide.s32    %rd11, %r39, 4;
+       add.s64         %rd12, %rd1, %rd11;
+       ld.global.f32   %f2, [%rd12];
+       mul.wide.s32    %rd13, %r41, 4;
+       add.s64         %rd14, %rd10, %rd13;
+       st.global.f32   [%rd14], %f2;
+
+BB109_7:
+       ret;
+}
+
+       // .globl       prepare_lstm_dinput_d
+.visible .entry prepare_lstm_dinput_d(
+       .param .u64 prepare_lstm_dinput_d_param_0,
+       .param .u64 prepare_lstm_dinput_d_param_1,
+       .param .u32 prepare_lstm_dinput_d_param_2,
+       .param .u32 prepare_lstm_dinput_d_param_3,
+       .param .u32 prepare_lstm_dinput_d_param_4,
+       .param .u32 prepare_lstm_dinput_d_param_5
+)
+{
+       .reg .pred      %p<2>;
+       .reg .b32       %r<15>;
+       .reg .f64       %fd<2>;
+       .reg .b64       %rd<9>;
+
+
+       ld.param.u64    %rd1, [prepare_lstm_dinput_d_param_0];
+       ld.param.u64    %rd2, [prepare_lstm_dinput_d_param_1];
+       ld.param.u32    %r2, [prepare_lstm_dinput_d_param_2];
+       ld.param.u32    %r3, [prepare_lstm_dinput_d_param_3];
+       ld.param.u32    %r4, [prepare_lstm_dinput_d_param_4];
+       ld.param.u32    %r5, [prepare_lstm_dinput_d_param_5];
+       mov.u32         %r6, %ctaid.x;
+       mov.u32         %r7, %ntid.x;
+       mov.u32         %r8, %tid.x;
+       mad.lo.s32      %r1, %r7, %r6, %r8;
+       setp.ge.s32     %p1, %r1, %r5;
+       @%p1 bra        BB110_2;
+
+       cvta.to.global.u64      %rd3, %rd2;
+       rem.s32         %r9, %r1, %r4;
+       div.s32         %r10, %r9, %r3;
+       rem.s32         %r11, %r9, %r3;
+       div.s32         %r12, %r1, %r4;
+       mad.lo.s32      %r13, %r10, %r2, %r12;
+       mad.lo.s32      %r14, %r13, %r3, %r11;
+       mul.wide.s32    %rd4, %r14, 8;
+       add.s64         %rd5, %rd3, %rd4;
+       ld.global.f64   %fd1, [%rd5];
+       cvta.to.global.u64      %rd6, %rd1;
+       mul.wide.s32    %rd7, %r1, 8;
+       add.s64         %rd8, %rd6, %rd7;
+       st.global.f64   [%rd8], %fd1;
+
+BB110_2:
+       ret;
+}
+
+       // .globl       prepare_lstm_dinput_f
+.visible .entry prepare_lstm_dinput_f(
+       .param .u64 prepare_lstm_dinput_f_param_0,
+       .param .u64 prepare_lstm_dinput_f_param_1,
+       .param .u32 prepare_lstm_dinput_f_param_2,
+       .param .u32 prepare_lstm_dinput_f_param_3,
+       .param .u32 prepare_lstm_dinput_f_param_4,
+       .param .u32 prepare_lstm_dinput_f_param_5
+)
+{
+       .reg .pred      %p<2>;
+       .reg .f32       %f<2>;
+       .reg .b32       %r<15>;
+       .reg .b64       %rd<9>;
+
+
+       ld.param.u64    %rd1, [prepare_lstm_dinput_f_param_0];
+       ld.param.u64    %rd2, [prepare_lstm_dinput_f_param_1];
+       ld.param.u32    %r2, [prepare_lstm_dinput_f_param_2];
+       ld.param.u32    %r3, [prepare_lstm_dinput_f_param_3];
+       ld.param.u32    %r4, [prepare_lstm_dinput_f_param_4];
+       ld.param.u32    %r5, [prepare_lstm_dinput_f_param_5];
+       mov.u32         %r6, %ctaid.x;
+       mov.u32         %r7, %ntid.x;
+       mov.u32         %r8, %tid.x;
+       mad.lo.s32      %r1, %r7, %r6, %r8;
+       setp.ge.s32     %p1, %r1, %r5;
+       @%p1 bra        BB111_2;
+
+       cvta.to.global.u64      %rd3, %rd2;
+       rem.s32         %r9, %r1, %r4;
+       div.s32         %r10, %r9, %r3;
+       rem.s32         %r11, %r9, %r3;
+       div.s32         %r12, %r1, %r4;
+       mad.lo.s32      %r13, %r10, %r2, %r12;
+       mad.lo.s32      %r14, %r13, %r3, %r11;
+       mul.wide.s32    %rd4, %r14, 4;
+       add.s64         %rd5, %rd3, %rd4;
+       ld.global.f32   %f1, [%rd5];
+       cvta.to.global.u64      %rd6, %rd1;
+       mul.wide.s32    %rd7, %r1, 4;
+       add.s64         %rd8, %rd6, %rd7;
+       st.global.f32   [%rd8], %f1;
+
+BB111_2:
+       ret;
+}
+
 .func  (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
        .param .b64 __internal_trig_reduction_slowpathd_param_0,
        .param .b64 __internal_trig_reduction_slowpathd_param_1
 )
 {
-       .local .align 8 .b8     __local_depot106[40];
+       .local .align 8 .b8     __local_depot112[40];
        .reg .b64       %SP;
        .reg .b64       %SPL;
        .reg .pred      %p<9>;
@@ -12477,7 +12891,7 @@ BB105_2:
        .reg .b64       %rd<102>;
 
 
-       mov.u64         %rd101, __local_depot106;
+       mov.u64         %rd101, __local_depot112;
        cvta.local.u64  %SP, %rd101;
        ld.param.f64    %fd4, [__internal_trig_reduction_slowpathd_param_0];
        ld.param.u64    %rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -12491,7 +12905,7 @@ BB105_2:
        shr.u32         %r3, %r1, 20;
        bfe.u32         %r4, %r1, 20, 11;
        setp.eq.s32     %p1, %r4, 2047;
-       @%p1 bra        BB106_13;
+       @%p1 bra        BB112_13;
 
        add.s32         %r15, %r4, -1024;
        shr.u32         %r16, %r15, 6;
@@ -12504,7 +12918,7 @@ BB105_2:
        mov.u64         %rd94, 0;
        setp.ge.s32     %p2, %r5, %r6;
        mov.u64         %rd93, %rd1;
-       @%p2 bra        BB106_4;
+       @%p2 bra        BB112_4;
 
        mov.b64          %rd41, %fd4;
        shl.b64         %rd42, %rd41, 11;
@@ -12521,7 +12935,7 @@ BB105_2:
        mov.u64         %rd91, %rd1;
        mov.u32         %r39, %r5;
 
-BB106_3:
+BB112_3:
        .pragma "nounroll";
        ld.const.u64    %rd47, [%rd89];
        // inline asm
@@ -12551,15 +12965,15 @@ BB106_3:
        add.s64         %rd93, %rd93, 8;
        add.s64         %rd89, %rd89, 8;
        setp.lt.s32     %p3, %r39, %r6;
-       @%p3 bra        BB106_3;
+       @%p3 bra        BB112_3;
 
-BB106_4:
+BB112_4:
        st.local.u64    [%rd93], %rd94;
        ld.local.u64    %rd95, [%rd1+16];
        ld.local.u64    %rd96, [%rd1+24];
        and.b32         %r9, %r3, 63;
        setp.eq.s32     %p4, %r9, 0;
-       @%p4 bra        BB106_6;
+       @%p4 bra        BB112_6;
 
        mov.u32         %r27, 64;
        sub.s32         %r28, %r27, %r9;
@@ -12571,7 +12985,7 @@ BB106_4:
        shr.u64         %rd55, %rd54, %r28;
        or.b64          %rd95, %rd55, %rd53;
 
-BB106_6:
+BB112_6:
        cvta.to.local.u64       %rd56, %rd37;
        shr.u64         %rd57, %rd96, 62;
        cvt.u32.u64     %r29, %rd57;
@@ -12588,7 +13002,7 @@ BB106_6:
        selp.b32        %r34, %r32, %r33, %p5;
        st.local.u32    [%rd56], %r34;
        setp.eq.s32     %p6, %r31, 0;
-       @%p6 bra        BB106_8;
+       @%p6 bra        BB112_8;
 
        mov.u64         %rd64, 0;
        // inline asm
@@ -12608,10 +13022,10 @@ BB106_6:
        // inline asm
        xor.b32         %r40, %r40, -2147483648;
 
-BB106_8:
+BB112_8:
        clz.b64         %r41, %rd98;
        setp.eq.s32     %p7, %r41, 0;
-       @%p7 bra        BB106_10;
+       @%p7 bra        BB112_10;
 
        shl.b64         %rd67, %rd98, %r41;
        mov.u32         %r35, 64;
@@ -12619,7 +13033,7 @@ BB106_8:
        shr.u64         %rd68, %rd97, %r36;
        or.b64          %rd98, %rd68, %rd67;
 
-BB106_10:
+BB112_10:
        mov.u64         %rd72, -3958705157555305931;
        // inline asm
        {
@@ -12640,7 +13054,7 @@ BB106_10:
        }
        // inline asm
        setp.lt.s64     %p8, %rd100, 1;
-       @%p8 bra        BB106_12;
+       @%p8 bra        BB112_12;
 
        // inline asm
        {
@@ -12659,7 +13073,7 @@ BB106_10:
        // inline asm
        add.s32         %r41, %r41, 1;
 
-BB106_12:
+BB112_12:
        cvt.u64.u32     %rd79, %r40;
        shl.b64         %rd80, %rd79, 32;
        mov.u32         %r37, 1022;
@@ -12674,7 +13088,7 @@ BB106_12:
        or.b64          %rd88, %rd87, %rd80;
        mov.b64          %fd4, %rd88;
 
-BB106_13:
+BB112_13:
        st.param.f64    [func_retval0+0], %fd4;
        ret;
 }
@@ -12702,7 +13116,7 @@ BB106_13:
        }
        shr.u32         %r51, %r50, 20;
        setp.ne.s32     %p1, %r51, 0;
-       @%p1 bra        BB107_2;
+       @%p1 bra        BB113_2;
 
        mul.f64         %fd14, %fd12, 0d4350000000000000;
        {
@@ -12716,13 +13130,13 @@ BB106_13:
        shr.u32         %r16, %r50, 20;
        add.s32         %r51, %r16, -54;
 
-BB107_2:
+BB113_2:
        add.s32         %r52, %r51, -1023;
        and.b32         %r17, %r50, -2146435073;
        or.b32          %r18, %r17, 1072693248;
        mov.b64         %fd135, {%r49, %r18};
        setp.lt.u32     %p2, %r18, 1073127583;
-       @%p2 bra        BB107_4;
+       @%p2 bra        BB113_4;
 
        {
        .reg .b32 %temp; 
@@ -12736,7 +13150,7 @@ BB107_2:
        mov.b64         %fd135, {%r19, %r21};
        add.s32         %r52, %r51, -1022;
 
-BB107_4:
+BB113_4:
        add.f64         %fd15, %fd135, 0d3FF0000000000000;
        rcp.approx.ftz.f64      %fd16, %fd15;
        neg.f64         %fd17, %fd15;
@@ -12899,13 +13313,13 @@ BB107_4:
        mov.b32          %f2, %r35;
        abs.f32         %f1, %f2;
        setp.lt.f32     %p4, %f1, 0f4086232B;
-       @%p4 bra        BB107_7;
+       @%p4 bra        BB113_7;
 
        setp.lt.f64     %p5, %fd4, 0d0000000000000000;
        add.f64         %fd129, %fd4, 0d7FF0000000000000;
        selp.f64        %fd136, 0d0000000000000000, %fd129, %p5;
        setp.geu.f32    %p6, %f1, 0f40874800;
-       @%p6 bra        BB107_7;
+       @%p6 bra        BB113_7;
 
        mov.f64         %fd134, 0d4338000000000000;
        mov.f64         %fd133, 0d3FF71547652B82FE;
@@ -12927,26 +13341,26 @@ BB107_4:
        mov.b64         %fd131, {%r44, %r43};
        mul.f64         %fd136, %fd130, %fd131;
 
-BB107_7:
+BB113_7:
        {
        .reg .b32 %temp; 
        mov.b64         {%temp, %r45}, %fd136;
        }
        and.b32         %r46, %r45, 2147483647;
        setp.ne.s32     %p7, %r46, 2146435072;
-       @%p7 bra        BB107_9;
+       @%p7 bra        BB113_9;
 
        {
        .reg .b32 %temp; 
        mov.b64         {%r47, %temp}, %fd136;
        }
        setp.eq.s32     %p8, %r47, 0;
-       @%p8 bra        BB107_10;
+       @%p8 bra        BB113_10;
 
-BB107_9:
+BB113_9:
        fma.rn.f64      %fd136, %fd136, %fd5, %fd136;
 
-BB107_10:
+BB113_10:
        st.param.f64    [func_retval0+0], %fd136;
        ret;
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 428f357..64963d9 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -169,7 +169,7 @@ public class FunctionOp extends Hop
                                long outputValues = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 1, 1.0);
                                return outputVectors+outputValues; 
                        }
-                       else if ( getFunctionName().equalsIgnoreCase("lstm") ) {
+                       else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward")  ) {
                                // TODO: To allow for initial version to always run on the GPU
                                return 0; 
                        }
@@ -218,7 +218,7 @@ public class FunctionOp extends Hop
                        else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
                                return 0; 
                        }
-                       else if ( getFunctionName().equalsIgnoreCase("lstm") ) {
+                       else if ( getFunctionName().equalsIgnoreCase("lstm") ||  getFunctionName().equalsIgnoreCase("lstm_backward") ) {
                                // TODO: To allow for initial version to always run on the GPU
                                return 0; 
                        }
@@ -239,7 +239,8 @@ public class FunctionOp extends Hop
        
        @Override
        public boolean isGPUEnabled() {
-               if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) 
+               if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ||  
+                       getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) 
                        return true;
                else
                        return false;
@@ -288,7 +289,7 @@ public class FunctionOp extends Hop
                                        || (getMemEstimate() >= OptimizerUtils.getLocalMemBudget()
                                                && OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
                        }
-                       else if( getFunctionName().equalsIgnoreCase("lstm")) {
+                       else if(getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("lstm_backward")) {
                                if(!DMLScript.USE_ACCELERATOR)
                                        throw new RuntimeException("The 
function " + getFunctionName() + " is only supported on GPU.");
                                _etype = ExecType.GPU;
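
Note on the checks above: the same four builtin names are now tested with
repeated equalsIgnoreCase chains in three places (memory estimation,
isGPUEnabled(), and exec-type selection). A minimal sketch of one way to keep
the three sites in sync; this set-based variant is an illustration, not part
of the patch:

    // Sketch only (not in this patch): centralize the GPU-only builtin names.
    private static final java.util.Set<String> GPU_ONLY_BUILTINS =
        new java.util.HashSet<>(java.util.Arrays.asList(
            "lstm", "lstm_backward", "batch_norm2d", "batch_norm2d_backward"));

    @Override
    public boolean isGPUEnabled() {
        return GPU_ONLY_BUILTINS.contains(getFunctionName().toLowerCase());
    }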

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 3514e6b..3ca8e1d 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -102,6 +102,15 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
        public Expression getSixthExpr() {
                return (_args.length >= 6 ? _args[5] : null);
        }
+       
+       public Expression getSeventhExpr() {
+               return (_args.length >= 7 ? _args[6] : null);
+       }
+
+       public Expression getEighthExpr() {
+               return (_args.length >= 8 ? _args[7] : null);
+       }
+
 
        public Expression[] getAllExpr(){
                return _args;
@@ -210,13 +219,12 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        checkMatrixParam(getFifthExpr());
                        
                        // setup output properties
-                       if(getOutputs() == null || getOutputs().length != 3) {
+                       if(getOutputs() == null || getOutputs().length != 2) {
                                int numOutputs = getOutputs() == null ? 0 : 
getOutputs().length;
-                               raiseValidateError("The builtin function lstm 
has three outputs, but instead found: " + numOutputs, conditional);
+                               raiseValidateError("The builtin function lstm 
has two outputs, but instead found: " + numOutputs, conditional);
                        }
                        DataIdentifier out = (DataIdentifier) getOutputs()[0];
                        DataIdentifier cy = (DataIdentifier) getOutputs()[1];
-                       DataIdentifier reserveSpace = (DataIdentifier) 
getOutputs()[2];
                        
                        // Output1 - out: If `return_sequences` is True, 
outputs for all timesteps, else outputs for the final timestep.
                        out.setDataType(DataType.MATRIX);
@@ -230,12 +238,36 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        cy.setDimensions(getExpr(4).getOutput().getDim1(), 
getExpr(4).getOutput().getDim2());
                        
cy.setBlockDimensions(getExpr(4).getOutput().getRowsInBlock(), 
getExpr(4).getOutput().getColumnsInBlock());
                        
-                       // Output3 - reserve space.
-                       reserveSpace.setDataType(DataType.MATRIX);
-                       reserveSpace.setValueType(ValueType.DOUBLE);
-                       reserveSpace.setDimensions(-1, -1);
-                       
reserveSpace.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), 
getFirstExpr().getOutput().getColumnsInBlock());
+                       break;
+               }
+               case LSTM_BACKWARD:
+               {
+                       // Input: X, W, b, out0, c0, return_sequences, dout, cy
+                       checkNumParameters(8);
+                       checkMatrixParam(getFirstExpr());
+                       checkMatrixParam(getSecondExpr());
+                       checkMatrixParam(getThirdExpr());
+                       checkMatrixParam(getFourthExpr());
+                       checkMatrixParam(getFifthExpr());
+                       checkMatrixParam(getSeventhExpr());
+                       checkMatrixParam(getEighthExpr());
                        
+                       // Output: dx, dw, db, dout0, dc0
+                       // setup output properties
+                       if(getOutputs() == null || getOutputs().length != 5) {
+                               int numOutputs = getOutputs() == null ? 0 : getOutputs().length;
+                               raiseValidateError("The builtin function lstm_backward has five outputs, but instead found: " + numOutputs, conditional);
+                       }
+                        
+                       DataIdentifier dx = (DataIdentifier) getOutputs()[0];
+                       DataIdentifier dw = (DataIdentifier) getOutputs()[1];
+                       DataIdentifier db = (DataIdentifier) getOutputs()[2];
+                       DataIdentifier dout0 = (DataIdentifier) getOutputs()[3];
+                       DataIdentifier dc0 = (DataIdentifier) getOutputs()[4];
+                       
+                       setDimensions(dx, getFirstExpr());
+                       setDimensions(dw, getSecondExpr());
+                       setDimensions(db, getThirdExpr());
+                       setDimensions(dout0, getFourthExpr());
+                       setDimensions(dc0, getFifthExpr());
                        break;
                }
                case BATCH_NORM2D:
@@ -1414,7 +1446,8 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                                // always unconditional (because unsupported 
operation)
                                BuiltinFunctionOp op = getOpCode();
                                if( op==BuiltinFunctionOp.EIGEN || 
op==BuiltinFunctionOp.LU || op==BuiltinFunctionOp.QR || 
op==BuiltinFunctionOp.SVD 
-                                               || op==BuiltinFunctionOp.LSTM 
|| op==BuiltinFunctionOp.BATCH_NORM2D || 
op==BuiltinFunctionOp.BATCH_NORM2D_BACKWARD)
+                                               || op==BuiltinFunctionOp.LSTM 
|| op==BuiltinFunctionOp.LSTM_BACKWARD
+                                               || 
op==BuiltinFunctionOp.BATCH_NORM2D || 
op==BuiltinFunctionOp.BATCH_NORM2D_BACKWARD)
                                        raiseValidateError("Function "+op+" 
needs to be called with multi-return assignment.", false, 
LanguageErrorCodes.INVALID_PARAMETERS);
                                else
                                        raiseValidateError("Unsupported 
function "+op, false, LanguageErrorCodes.INVALID_PARAMETERS);
@@ -1496,6 +1529,7 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                case LU:
                case EIGEN:
                case LSTM:
+               case LSTM_BACKWARD:
                case BATCH_NORM2D:
                case BATCH_NORM2D_BACKWARD:
                case SVD:
@@ -1624,7 +1658,8 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        raiseValidateError("Missing argument for function " + 
this.getOpCode(), false,
                                        LanguageErrorCodes.INVALID_PARAMETERS);
                }
-
+               
+               // Not sure of the rationale for the first two checks; keep them for backward compatibility.
                if (((count == 1) && (getSecondExpr() != null || getThirdExpr() 
!= null))
                                || ((count == 2) && (getThirdExpr() != null))) {
                        raiseValidateError("Invalid number of arguments for 
function " + this.getOpCode().toString().toLowerCase()
@@ -1633,6 +1668,9 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                                || ((count == 3) && (getSecondExpr() == null || 
getThirdExpr() == null))) {
                        raiseValidateError("Missing argument for function " + 
this.getOpCode(), false,
                                        LanguageErrorCodes.INVALID_PARAMETERS);
+               } else if(count > 0 && (_args == null || _args.length < count)) 
{
+                       raiseValidateError("Missing argument for function " + 
this.getOpCode(), false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
                }
        }
 
@@ -1909,6 +1947,8 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        bifop = Expression.BuiltinFunctionOp.EIGEN;
                else if (functionName.equals("lstm"))
                        bifop = Expression.BuiltinFunctionOp.LSTM;
+               else if (functionName.equals("lstm_backward"))
+                       bifop = Expression.BuiltinFunctionOp.LSTM_BACKWARD;
                else if (functionName.equals("batch_norm2d"))
                        bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D;
                else if (functionName.equals("batch_norm2d_backward"))
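
The LSTM_BACKWARD case above propagates shapes with a setDimensions helper:
each gradient output takes the dimensions of the input it differentiates
(dx from X, dw from W, and so on). The helper body is outside this diff, so
the following sketch of its assumed semantics simply mirrors the explicit
out/cy setup in the lstm case:

    // Assumed semantics of setDimensions(target, src): the gradient output
    // inherits the dimensions and block sizes of the corresponding input.
    private void setDimensions(DataIdentifier target, Expression src) {
        target.setDataType(DataType.MATRIX);
        target.setValueType(ValueType.DOUBLE);
        target.setDimensions(src.getOutput().getDim1(),
            src.getOutput().getDim2());
        target.setBlockDimensions(src.getOutput().getRowsInBlock(),
            src.getOutput().getColumnsInBlock());
    }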

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 102ca54..d29a8f4 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2225,6 +2225,7 @@ public class DMLTranslator
                case LU:
                case EIGEN:
                case LSTM:
+               case LSTM_BACKWARD:
                case BATCH_NORM2D:
                case BATCH_NORM2D_BACKWARD:
                case SVD:

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java 
b/src/main/java/org/apache/sysml/parser/Expression.java
index 2808609..1708ed3 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -92,7 +92,7 @@ public abstract class Expression implements ParseInfo
                EXISTS,
                CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIAS_ADD, 
BIAS_MULTIPLY,
                MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD, AVG_POOL_BACKWARD,
-               LSTM, BATCH_NORM2D, BATCH_NORM2D_BACKWARD,
+               LSTM, LSTM_BACKWARD, BATCH_NORM2D, BATCH_NORM2D_BACKWARD,
                EXP,
                FLOOR,
                IFELSE,

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 049577b..8e9bb47 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -57,6 +57,7 @@ public class GPUInstructionParser  extends InstructionParser
                String2GPUInstructionType.put( "bias_multiply",          
GPUINSTRUCTION_TYPE.Convolution);
                String2GPUInstructionType.put( "channel_sums",          
GPUINSTRUCTION_TYPE.Convolution);
                String2GPUInstructionType.put( "lstm",                  
GPUINSTRUCTION_TYPE.Convolution);
+               String2GPUInstructionType.put( "lstm_backward",         
GPUINSTRUCTION_TYPE.Convolution);
                String2GPUInstructionType.put( "batch_norm2d",           
GPUINSTRUCTION_TYPE.Convolution);
                String2GPUInstructionType.put( "batch_norm2d_backward",  
GPUINSTRUCTION_TYPE.Convolution);
                

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
index e523a45..7d8a0fc 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
@@ -73,7 +73,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction 
{
                _intermediateMemoryBudget = intermediateMemoryBudget;
        }
        public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, 
CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, 
-                       CPOperand out, CPOperand out2, CPOperand out3, String 
opcode, String istr, 
+                       CPOperand out, CPOperand out2, String opcode, String 
istr, 
                        double intermediateMemoryBudget) throws 
DMLRuntimeException {
                super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), 
opcode, istr);
                _input1 = in1;
@@ -85,7 +85,6 @@ public class ConvolutionGPUInstruction extends GPUInstruction 
{
                _gputype = GPUINSTRUCTION_TYPE.Convolution;
                _output = out;
                _output2 = out2;
-               _output3 = out3;
                _intermediateMemoryBudget = intermediateMemoryBudget;
        }
        
@@ -283,7 +282,7 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                        return new ConvolutionGPUInstruction(in, in2, in3, out, 
opcode, str, 0);
                }
                else if (opcode.equalsIgnoreCase("lstm")) {
-                       InstructionUtils.checkNumFields(parts, 9);
+                       InstructionUtils.checkNumFields(parts, 8);
                        CPOperand in1 = new CPOperand(parts[1]);
                        CPOperand in2 = new CPOperand(parts[2]);
                        CPOperand in3 = new CPOperand(parts[3]);
@@ -292,10 +291,9 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                        CPOperand in6 = new CPOperand(parts[6]);
                        CPOperand out = new CPOperand(parts[7]);
                        CPOperand out2 = new CPOperand(parts[8]);
-                       CPOperand out3 = new CPOperand(parts[9]);
-                       return new ConvolutionGPUInstruction(in1, in2, in3, 
in4, in5, in6, out, out2, out3, opcode, str, 0);
+                       return new ConvolutionGPUInstruction(in1, in2, in3, 
in4, in5, in6, out, out2, opcode, str, 0);
                }
-               else if (opcode.equalsIgnoreCase("batch_norm2d")) {
+               else if (opcode.equalsIgnoreCase("batch_norm2d") || 
opcode.equalsIgnoreCase("lstm_backward")) {
                        InstructionUtils.checkNumFields(parts, 13);
                        CPOperand in1 = new CPOperand(parts[1]); // image
                        CPOperand in2 = new CPOperand(parts[2]); // scale
@@ -471,6 +469,66 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
 //             return tX;
 //     }
        
+       private void processLstmBackwardInstruction(ExecutionContext ec) throws 
DMLRuntimeException {
+               GPUStatistics.incrementNoOfExecutedGPUInst();
+               GPUContext gCtx = ec.getGPUContext(0);
+               String instructionName = getExtendedOpcode();
+               
+               MatrixObject out0 = getMatrixInputForGPUInstruction(ec, 
_input4.getName());
+               int M = toInt(out0.getNumColumns()); // hiddenSize .. since 
out0: (N, M)
+               Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, 
out0, instructionName);
+               
+               MatrixObject W = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
+               MatrixObject bias = getMatrixInputForGPUInstruction(ec, 
_input3.getName());
+               long numRowsW = W.getNumRows();
+               int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... 
numFeatures 
+               Pointer sysmlWPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M);
+               Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M);
+               Pointer cudnnWPointer = gCtx.allocate(instructionName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
+                               
ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
+                               sysmlWPointer, sysmlBiasPointer, cudnnWPointer, 
D, M);
+               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+               
+               
+               MatrixObject X = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
+               Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, 
instructionName); 
+               int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
+               long numColsX = X.getNumColumns();
+               int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
+               Pointer cudnnInput = gCtx.allocate(instructionName, 
(N*T*D)*LibMatrixCUDA.sizeOfDataType);
+               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
+                               
ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
+                               xPointer, cudnnInput, N, D, T*D, N*T*D);
+               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+               
+               Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName);
+               boolean return_sequences = ec.getScalarInput(_input6.getName(), 
_input6.getValueType(), _input6.isLiteral()).getBooleanValue();
+               
+               // Bind the gradient outputs (dx, dw, db, dhx, dcx) and the
+               // remaining gradient inputs (dout, dcy) for lstmBackward.
+               String dxName = _output.getName();
+               String dwName = _output2.getName();
+               String dbName = _output3.getName();
+               String dhxName = _output4.getName();
+               String dcxName = _output5.getName();
+               String doutName = _input7.getName();
+               String dcyName = _input8.getName();
+               LibMatrixCuDNN.lstmBackward(ec, gCtx, instructionName, 
+                               cudnnInput, out0Pointer, c0Pointer, 
cudnnWPointer, doutName, dcyName,  // input
+                               dxName, dwName, dbName, dhxName, dcxName, // 
output 
+                               return_sequences, N, M, D, T);
+               gCtx.cudaFreeHelper(instructionName, cudnnWPointer, 
DMLScript.EAGER_CUDA_FREE);
+               gCtx.cudaFreeHelper(instructionName, cudnnInput, 
DMLScript.EAGER_CUDA_FREE);
+               
+               // release inputs/outputs
+               ec.releaseMatrixInputForGPUInstruction(_input4.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input5.getName());
+       }
+       
        private void processLstmInstruction(ExecutionContext ec) throws 
DMLRuntimeException {
                // batchSize=N, seqLength=T, numFeatures=D and hiddenSize=M
                // input  X:(N, T*D),   ==> (T, D, N)
@@ -496,6 +554,7 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                                
ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
                                sysmlWPointer, sysmlBiasPointer, cudnnWPointer, 
D, M);
                ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
                
                boolean return_sequences = ec.getScalarInput(_input6.getName(), 
_input6.getValueType(), _input6.isLiteral()).getBooleanValue();
                
@@ -513,17 +572,15 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                
                Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); 
                
-               LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, 
cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), 
_output2.getName(), _output3.getName(), N, M, D, T);
+               LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, 
cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), 
_output2.getName(), N, M, D, T);
                gCtx.cudaFreeHelper(instructionName, cudnnWPointer, 
DMLScript.EAGER_CUDA_FREE);
                gCtx.cudaFreeHelper(instructionName, cudnnInput, 
DMLScript.EAGER_CUDA_FREE);
                
                // release inputs/outputs
-               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
                ec.releaseMatrixInputForGPUInstruction(_input4.getName());
                ec.releaseMatrixInputForGPUInstruction(_input5.getName());
                ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
                ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-               ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
        }
        
        @Override
@@ -544,6 +601,10 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                        processLstmInstruction(ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("lstm_backward")) {
+                       processLstmBackwardInstruction(ec);
+                       return;
+               }
                else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
                        processBatchNorm2dInstruction(ec);
                        return;
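
Before calling into cuDNN, processLstmBackwardInstruction reorders X from
SystemML's row-major (N, T*D) layout into a time-major buffer via the
prepare_lstm_input kernel. A CPU-side sketch of the index mapping follows;
note the in-code comment in processLstmInstruction writes the target as
(T, D, N), while the per-timestep (N, D) tensor descriptors suggest (N, D)
slices, so the sketch assumes a (T, N, D) destination. This is an inference,
not taken from the kernel source:

    // CPU sketch (assumed kernel semantics): timestep t of the destination
    // holds an (N, D) slice, i.e. the destination is laid out as (T, N, D).
    static double[] prepareLstmInput(double[] x, int N, int D, int T) {
        double[] cudnnInput = new double[N * T * D];
        for (int n = 0; n < N; n++)
            for (int t = 0; t < T; t++)
                for (int d = 0; d < D; d++)
                    cudnnInput[t * N * D + n * D + d] =
                        x[n * T * D + t * D + d];
        return cudnnInput;
    }

The prepare_lstm_dinput launch in LibMatrixCuDNN.lstmBackward applies the
inverse mapping to bring the computed gradient back into the (N, T*D) layout
of the dx output.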

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index 576584b..328d1d4 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -739,9 +739,9 @@ public class GPUObject {
                long rows = mat.getNumRows();
                long cols = mat.getNumColumns();
                if(rows <= 0)
-                       throw new DMLRuntimeException("Internal error - invalid 
number of rows when allocating dense matrix");
+                       throw new DMLRuntimeException("Internal error - invalid 
number of rows when allocating dense matrix:" + rows);
                if(cols <= 0)
-                       throw new DMLRuntimeException("Internal error - invalid 
number of columns when allocating dense matrix;");
+                       throw new DMLRuntimeException("Internal error - invalid 
number of columns when allocating dense matrix:" + cols);
                long size = getDatatypeSizeOf(rows * cols);
                Pointer tmp = allocate(size);
                setDensePointer(tmp);

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index e84dce7..a692739 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -846,6 +846,12 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                }
        }
        
+       static Pointer getDenseInputPointer(ExecutionContext ec, GPUContext 
gCtx, String instName, String inputName,
+                       long numRows, long numCols) throws DMLRuntimeException {
+               MatrixObject output = 
ec.getMatrixInputForGPUInstruction(inputName, instName);
+               return LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, output, 
instName, toInt(numRows), toInt(numCols));
+       }
+       
        static Pointer getDenseOutputPointer(ExecutionContext ec, GPUContext 
gCtx, String instName, String outputName,
                        long numRows, long numCols) throws DMLRuntimeException {
                MatrixObject output = ec.getMatrixObject(outputName);
@@ -867,7 +873,6 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
         * @param return_sequences Whether to return `out` at all timesteps, or 
just for the final timestep.
         * @param outputName name of the out variable. If `return_sequences` is 
True, outputs for all timesteps.
         * @param cyName name of the output cell state. Cell state for final 
timestep.
-        * @param reserveSpaceName name of reserve space.
         * @param N minibatch size
         * @param M hidden size
         * @param D number of features
@@ -876,13 +881,13 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
         */
        public static void lstm(ExecutionContext ec, GPUContext gCtx, String 
instName,
                        Pointer X,  Pointer wPointer, Pointer out0, Pointer c0, 
boolean return_sequences,
-                       String outputName, String cyName, String 
reserveSpaceName, int N, int M, int D, int T) throws DMLRuntimeException {
-               singleLayerUnidirectionalRNNForward(ec, gCtx, instName, X, 
out0, c0, wPointer, outputName, cyName, reserveSpaceName, "lstm", 
return_sequences, N, M, D, T);
+                       String outputName, String cyName, int N, int M, int D, 
int T) throws DMLRuntimeException {
+               singleLayerUnidirectionalRNNForward(ec, gCtx, instName, X, 
out0, c0, wPointer, outputName, cyName, "lstm", return_sequences, N, M, D, T);
        }
        
        private static void 
singleLayerUnidirectionalRNNForward(ExecutionContext ec, GPUContext gCtx, 
String instName,
                        Pointer x, Pointer hx, Pointer cx, Pointer wPointer,  
// input
-                       String outputName, String cyName, String 
reserveSpaceName,  // output
+                       String outputName, String cyName,                       
                 // output
                        String rnnMode, boolean return_sequences, int N, int M, 
int D, int T) throws DMLRuntimeException {
                boolean hasCarry = rnnMode.equalsIgnoreCase("lstm");
                // Get output pointers
@@ -891,8 +896,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                Pointer cyPointer = hasCarry ? getDenseOutputPointer(ec, gCtx, 
instName, cyName, N, M) : new Pointer();
                // Pointer wPointer = getDensePointerForCuDNN(gCtx, w, 
instName, D+M+2, 4*M);
                
-               try(LibMatrixCuDNNRnnAlgorithm algo = new 
LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, rnnMode, N, T, M, D, true, 
wPointer, reserveSpaceName)) {
-                       jcuda.runtime.JCuda.cudaDeviceSynchronize();
+               try(LibMatrixCuDNNRnnAlgorithm algo = new 
LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, rnnMode, N, T, M, D, true, 
wPointer)) {
                        JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), 
algo.rnnDesc, T, 
                                        algo.xDesc, x, 
                                        algo.hxDesc, hx, 
@@ -915,6 +919,88 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                gCtx.cudaFreeHelper(instName, cudnnYPointer, 
DMLScript.EAGER_CUDA_FREE);
        }
        
+       public static void lstmBackward(ExecutionContext ec, GPUContext gCtx, 
String instName,
+                       Pointer x, Pointer hx, Pointer cx, Pointer wPointer, 
String doutName, String dcyName,  // input
+                       String dxName, String dwName, String dbName, String 
dhxName, String dcxName,    // output
+                       boolean return_sequences, int N, int M, int D, int T) 
throws DMLRuntimeException {
+               // Transform the input dout and prepare them for 
cudnnRNNBackwardData
+               Pointer dy = gCtx.allocate(instName, N*T*M*sizeOfDataType);
+               int size = return_sequences ? N*T*M : N*M;
+               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_backward_gradients",
+                               
ExecutionConfig.getConfigForSimpleVectorOperations(size),
+                               getDenseInputPointer(ec, gCtx, instName, 
doutName, N, return_sequences ? T*M : M),
+                               dy, N, T, M, size, return_sequences ? 1 : 0);
+               ec.releaseMatrixInputForGPUInstruction(doutName);
+                               
+               // Allocate intermediate pointers computed by forward
+               Pointer yPointer = gCtx.allocate(instName, 
N*T*M*sizeOfDataType);
+               try(LibMatrixCuDNNRnnAlgorithm algo = new 
LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, "lstm", N, T, M, D, true, 
wPointer)) {
+                       JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), 
algo.rnnDesc, T, 
+                                       algo.xDesc, x, 
+                                       algo.hxDesc, hx, 
+                                       algo.cxDesc, cx, 
+                                       algo.wDesc, wPointer, 
+                                       algo.yDesc, yPointer, 
+                                       algo.hyDesc, new Pointer(), 
+                                       algo.cyDesc, new Pointer(), 
+                                       algo.workSpace, algo.sizeInBytes, 
+                                       algo.reserveSpace, 
algo.reserveSpaceSizeInBytes);
+                       
+                       Pointer cudnnDx = gCtx.allocate(instName, 
N*T*D*LibMatrixCUDA.sizeOfDataType);
+                       JCudnn.cudnnRNNBackwardData(gCtx.getCudnnHandle(), 
algo.rnnDesc, T, 
+                                       algo.yDesc, yPointer,
+                                       // ----------------------
+                                       // Additional inputs:
+                                       algo.dyDesc, dy, 
+                                       algo.dhyDesc, new Pointer(), 
+                                       algo.dcyDesc, getDenseInputPointer(ec, 
gCtx, instName, dcyName, N, M),
+                                       // ----------------------
+                                       algo.wDesc, wPointer, 
+                                       algo.hxDesc, hx,
+                                       algo.cxDesc, cx,
+                                       // ----------------------
+                                       // Output:
+                                       algo.dxDesc, cudnnDx, 
+                                       algo.dhxDesc, getDenseOutputPointer(ec, 
gCtx, instName, dhxName, N, M), 
+                                       algo.dcxDesc, getDenseOutputPointer(ec, 
gCtx, instName, dcxName, N, M),
+                                       // ----------------------
+                                       algo.workSpace, algo.sizeInBytes, 
+                                       algo.reserveSpace, 
algo.reserveSpaceSizeInBytes);
+                       gCtx.cudaFreeHelper(instName, dy, 
DMLScript.EAGER_CUDA_FREE);
+                       ec.releaseMatrixInputForGPUInstruction(dcyName);
+                       ec.releaseMatrixOutputForGPUInstruction(dhxName);
+                       ec.releaseMatrixOutputForGPUInstruction(dcxName);
+                       
+                       Pointer smlDx = getDenseOutputPointer(ec, gCtx, 
instName, dxName, N, T*D);
+                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dinput",
+                                       
ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
+                                       smlDx, cudnnDx, N, D, T*D, N*T*D);
+                       ec.releaseMatrixOutputForGPUInstruction(dxName);
+                       
+                       // 
-------------------------------------------------------------------------------------------
+                       Pointer cudnnDwPointer = gCtx.allocate(instName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+                       JCudnn.cudnnRNNBackwardWeights(gCtx.getCudnnHandle(), 
algo.rnnDesc, T, 
+                                       algo.xDesc, x, 
+                                       algo.hxDesc, hx, 
+                                       algo.yDesc, yPointer, 
+                                       algo.workSpace, algo.sizeInBytes, 
+                                       algo.dwDesc, cudnnDwPointer, 
+                                       algo.reserveSpace, 
algo.reserveSpaceSizeInBytes);
+                       
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dweight",
+                                       
ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
+                                       getDenseOutputPointer(ec, gCtx, 
instName, dwName, D+M, 4*M), 
+                                       getDenseOutputPointer(ec, gCtx, 
instName, dbName, 1, 4*M), cudnnDwPointer, D, M);
+                       gCtx.cudaFreeHelper(instName, cudnnDwPointer, 
DMLScript.EAGER_CUDA_FREE);
+                       ec.releaseMatrixOutputForGPUInstruction(dwName);
+                       ec.releaseMatrixOutputForGPUInstruction(dbName);
+                       // 
-------------------------------------------------------------------------------------------
+                       
+                       gCtx.cudaFreeHelper(instName, yPointer, 
DMLScript.EAGER_CUDA_FREE);
+               }
+       }
+       
        /**
         * Performs the forward BatchNormalization layer computation for 
training
         * @param gCtx   a valid {@link GPUContext}
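
Two details of lstmBackward above are worth noting. First, since lstm and
lstm_backward are stateless, lstmBackward re-runs cudnnRNNForwardTraining to
regenerate the reserve space before calling cudnnRNNBackwardData and
cudnnRNNBackwardWeights, trading extra compute for not passing a
reserve-space matrix between the two builtins. Second, the
prepare_lstm_backward_gradients kernel expands the incoming dout into the
dense, time-major dy buffer that cudnnRNNBackwardData consumes. A CPU-side
sketch of the assumed kernel semantics:

    // CPU sketch (assumed kernel semantics): scatter dout into a
    // zero-initialized dy of shape (T, N, M). When return_sequences is
    // false, only the final timestep carries nonzero gradients.
    static double[] prepareLstmBackwardGradients(double[] dout,
            int N, int T, int M, boolean returnSequences) {
        double[] dy = new double[N * T * M];
        for (int n = 0; n < N; n++)
            for (int t = (returnSequences ? 0 : T - 1); t < T; t++)
                for (int m = 0; m < M; m++)
                    dy[t * N * M + n * M + m] = returnSequences
                        ? dout[n * T * M + t * M + m]  // dout: (N, T*M)
                        : dout[n * M + m];             // dout: (N, M)
        return dy;
    }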

http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
index d772a55..68d308e 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
@@ -49,27 +49,36 @@ public class LibMatrixCuDNNRnnAlgorithm implements 
java.lang.AutoCloseable {
        String instName;
        cudnnDropoutDescriptor dropoutDesc;
        cudnnRNNDescriptor rnnDesc;
-       cudnnTensorDescriptor[] xDesc, yDesc; // of length T
-       cudnnTensorDescriptor hxDesc, cxDesc, hyDesc, cyDesc; 
+       cudnnTensorDescriptor[] xDesc, dxDesc, yDesc, dyDesc; // of length T
+       cudnnTensorDescriptor hxDesc, cxDesc, hyDesc, cyDesc, dhxDesc, dcxDesc, 
dhyDesc, dcyDesc; 
        cudnnFilterDescriptor wDesc;
+       cudnnFilterDescriptor dwDesc;
        long sizeInBytes; Pointer workSpace;
        long reserveSpaceSizeInBytes; Pointer reserveSpace;
        public LibMatrixCuDNNRnnAlgorithm(ExecutionContext ec, GPUContext gCtx, 
String instName, 
-                       String rnnMode, int N, int T, int M, int D, boolean 
isTraining, Pointer w, String reserveSpaceName) throws DMLRuntimeException {
+                       String rnnMode, int N, int T, int M, int D, boolean 
isTraining, Pointer w) throws DMLRuntimeException {
                this.gCtx = gCtx;
                this.instName = instName;
                
                // Allocate input/output descriptors
                xDesc = new cudnnTensorDescriptor[T];
+               dxDesc = new cudnnTensorDescriptor[T];
                yDesc = new cudnnTensorDescriptor[T];
+               dyDesc = new cudnnTensorDescriptor[T];
                for(int t = 0; t < T; t++) {
                        xDesc[t] = allocateTensorDescriptorWithStride(N, D, 1);
+                       dxDesc[t] = allocateTensorDescriptorWithStride(N, D, 1);
                        yDesc[t] = allocateTensorDescriptorWithStride(N, M, 1);
+                       dyDesc[t] = allocateTensorDescriptorWithStride(N, M, 1);
                }
                hxDesc = allocateTensorDescriptorWithStride(1, N, M); 
+               dhxDesc = allocateTensorDescriptorWithStride(1, N, M);
                cxDesc = allocateTensorDescriptorWithStride(1, N, M);
+               dcxDesc = allocateTensorDescriptorWithStride(1, N, M);
                hyDesc = allocateTensorDescriptorWithStride(1, N, M);
+               dhyDesc = allocateTensorDescriptorWithStride(1, N, M);
                cyDesc = allocateTensorDescriptorWithStride(1, N, M);
+               dcyDesc = allocateTensorDescriptorWithStride(1, N, M);
                
                // Initial dropout descriptor
                dropoutDesc = new cudnnDropoutDescriptor();
@@ -94,6 +103,7 @@ public class LibMatrixCuDNNRnnAlgorithm implements 
java.lang.AutoCloseable {
                        throw new DMLRuntimeException("Incorrect number of RNN 
parameters " +  (D+M+2)*4*M + " != " +  expectedNumWeights + ", where 
numFeatures=" + D + ", hiddenSize=" + M);
                }
                wDesc = allocateFilterDescriptor(expectedNumWeights);
+               dwDesc = allocateFilterDescriptor(expectedNumWeights);
                
                // Setup workspace
                workSpace = new Pointer(); reserveSpace = new Pointer();
@@ -104,12 +114,11 @@ public class LibMatrixCuDNNRnnAlgorithm implements 
java.lang.AutoCloseable {
                if(isTraining) {
                        reserveSpaceSizeInBytes = getReservespaceSize(T);
                        if (reserveSpaceSizeInBytes != 0) {
-                               int numCols =  (int) 
Math.ceil(((double)reserveSpaceSizeInBytes) / LibMatrixCUDA.sizeOfDataType);
-                               reserveSpace = 
LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, reserveSpaceName, 1, 
numCols);
+                               reserveSpace = 
gCtx.allocate(reserveSpaceSizeInBytes);
                        }
                }
                if (reserveSpaceSizeInBytes == 0) {
-                       reserveSpace = LibMatrixCuDNN.getDenseOutputPointer(ec, 
gCtx, instName, reserveSpaceName, 1, 1);
+                       reserveSpace = gCtx.allocate(reserveSpaceSizeInBytes);
                }
                
                /*
@@ -241,30 +250,57 @@ public class LibMatrixCuDNNRnnAlgorithm implements 
java.lang.AutoCloseable {
                if(hxDesc != null)
                        cudnnDestroyTensorDescriptor(hxDesc);
                hxDesc = null;
+               if(dhxDesc != null)
+                       cudnnDestroyTensorDescriptor(dhxDesc);
+               dhxDesc = null;
                if(hyDesc != null)
                        cudnnDestroyTensorDescriptor(hyDesc);
                hyDesc = null;
+               if(dhyDesc != null)
+                       cudnnDestroyTensorDescriptor(dhyDesc);
+               dhyDesc = null;
                if(cxDesc != null)
                        cudnnDestroyTensorDescriptor(cxDesc);
                cxDesc = null;
+               if(dcxDesc != null)
+                       cudnnDestroyTensorDescriptor(dcxDesc);
+               dcxDesc = null;
                if(cyDesc != null)
                        cudnnDestroyTensorDescriptor(cyDesc);
                cyDesc = null;
+               if(dcyDesc != null)
+                       cudnnDestroyTensorDescriptor(dcyDesc);
+               dcyDesc = null;
                if(wDesc != null)
                        cudnnDestroyFilterDescriptor(wDesc);
                wDesc = null;
+               if(dwDesc != null)
+                       cudnnDestroyFilterDescriptor(dwDesc);
+               dwDesc = null;
                if(xDesc != null) {
                        for(cudnnTensorDescriptor dsc : xDesc) {
                                cudnnDestroyTensorDescriptor(dsc);
                        }
                        xDesc = null;
                }
+               if(dxDesc != null) {
+                       for(cudnnTensorDescriptor dsc : dxDesc) {
+                               cudnnDestroyTensorDescriptor(dsc);
+                       }
+                       dxDesc = null;
+               }
                if(yDesc != null) {
                        for(cudnnTensorDescriptor dsc : yDesc) {
                                cudnnDestroyTensorDescriptor(dsc);
                        }
                        yDesc = null;
                }
+               if(dyDesc != null) {
+                       for(cudnnTensorDescriptor dsc : dyDesc) {
+                               cudnnDestroyTensorDescriptor(dsc);
+                       }
+                       dyDesc = null;
+               }
                if(sizeInBytes != 0) {
                        try {
                                gCtx.cudaFreeHelper(instName, workSpace, 
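
Every forward descriptor in this class now has a d-prefixed gradient twin of
identical shape (dxDesc, dyDesc, dhxDesc, dcxDesc, dhyDesc, dcyDesc, dwDesc),
and close() destroys both sets. The allocateTensorDescriptorWithStride helper
itself is outside this diff; the sketch below shows its assumed behavior,
where both the double data type and the fully-packed strides are assumptions:

    import jcuda.jcudnn.JCudnn;
    import jcuda.jcudnn.cudnnDataType;
    import jcuda.jcudnn.cudnnTensorDescriptor;

    // Sketch (assumed helper semantics): a 3D tensor descriptor with dense,
    // fully-packed strides, e.g. (N, D, 1) per timestep for xDesc/dxDesc.
    static cudnnTensorDescriptor allocateTensorDescriptorWithStride(
            int dim1, int dim2, int dim3) {
        cudnnTensorDescriptor desc = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(desc);
        int[] dims = {dim1, dim2, dim3};
        int[] strides = {dim2 * dim3, dim3, 1};
        JCudnn.cudnnSetTensorNdDescriptor(desc,
            cudnnDataType.CUDNN_DATA_DOUBLE, 3, dims, strides);
        return desc;
    }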
DMLScript.EAGER_CUDA_FREE);
