This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new b4ef84b  [SYSTEMML-540] Added an initial CP operator for lstm_backward builtin function
b4ef84b is described below

commit b4ef84ba2568dc96fe20d1a45faeb4c1bd2d47b4
Author: Niketan Pansare <[email protected]>
AuthorDate: Tue Mar 5 12:52:11 2019 -0800

    [SYSTEMML-540] Added an initial CP operator for lstm_backward builtin function
---
 scripts/nn/layers/lstm_staging.dml                 |   6 -
 .../java/org/apache/sysml/hops/FunctionOp.java     |   9 +-
 .../runtime/instructions/cp/DnnCPInstruction.java  | 104 +++++++
 .../instructions/gpu/DnnGPUInstruction.java        |  42 +--
 .../sysml/runtime/matrix/data/LibMatrixDNN.java    | 341 ++++++++++++++++++++-
 .../org/apache/sysml/test/gpu/LstmCPUTest.java     | 189 ++++++++----
 6 files changed, 575 insertions(+), 116 deletions(-)
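
For readers skimming the Java below, the per-timestep gradients that the new CP kernel (LibMatrixDNN.lstm_backward) computes follow the standard LSTM backward recurrence. In the code's notation (i, f, o, g are the cached gate activations, \circ the elementwise product, [.,.] column-wise concatenation):

\begin{align*}
dc_t &\mathrel{+}= o_t \circ (1 - \tanh^2(c_t)) \circ dout_t \\
di_{raw} &= i_t \circ (1 - i_t) \circ g_t \circ dc_t \\
df_{raw} &= f_t \circ (1 - f_t) \circ c_{t-1} \circ dc_t \\
do_{raw} &= o_t \circ (1 - o_t) \circ \tanh(c_t) \circ dout_t \\
dg_{raw} &= (1 - g_t^2) \circ i_t \circ dc_t \\
dW &\mathrel{+}= [X_t, out_{t-1}]^\top difog_{raw}, \qquad db \mathrel{+}= colSums(difog_{raw}) \\
dinput &= difog_{raw}\, W^\top, \qquad dc_{t-1} = f_t \circ dc_t
\end{align*}

with difog_{raw} = [di_{raw}, df_{raw}, do_{raw}, dg_{raw}]; the first D columns of dinput form dX_t and the remaining M columns accumulate into dout_{t-1}.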

diff --git a/scripts/nn/layers/lstm_staging.dml b/scripts/nn/layers/lstm_staging.dml
index 886b88c..2f71f22 100644
--- a/scripts/nn/layers/lstm_staging.dml
+++ b/scripts/nn/layers/lstm_staging.dml
@@ -92,12 +92,6 @@ backward = function(matrix[double] dout, matrix[double] dc,
    *      Note: This is *optional* and could just be an empty matrix.
    *  - c0: Initial cell state, of shape (N, M).
    *      Note: This is *optional* and could just be an empty matrix.
-   *  - cache_out: Cache of outputs, of shape (T, N*M).
-   *      Note: This is used for performance during training.
-   *  - cache_c: Cache of cell state, of shape (T, N*M).
-   *      Note: This is used for performance during training.
-   *  - cache_ifog: Cache of intermediate values, of shape (T, N*4*M).
-   *      Note: This is used for performance during training.
    *
    * Outputs:
    *  - dX: Gradient wrt `X`, of shape (N, T*D).
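(Note on this hunk: the cache_out/cache_c/cache_ifog arguments disappear from the backward documentation because the new CP operator recomputes the forward pass internally, per the "invoke lstm redundantly" comment in DnnCPInstruction below; the builtin form exercised by the tests is simply [dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, X, W, b, return_seq, out0, c0).)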
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 5fdc8e7..dedbad6 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -282,13 +282,6 @@ public class FunctionOp extends MultiThreadedHop
                
                if(getFunctionType() == FunctionType.MULTIRETURN_BUILTIN && isBuiltinFunction() &&
                        (getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) {
-                       
-                       if(getFunctionName().equalsIgnoreCase("lstm_backward")) {
-                               if(!ConfigurationManager.isGPU())
-                                       throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
-                               _etype = ExecType.GPU;
-                       }
-                       
                        ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
                        
                        if( _etypeForced != null ) {
@@ -306,7 +299,7 @@ public class FunctionOp extends MultiThreadedHop
                                checkAndSetInvalidCPDimsAndSize();
                        }
                        
-                       // Since lstm builtin functions are not supported on Spark
+                       // Fall back to CP, since the lstm builtin functions are not supported on Spark or MR
                        _etype = _etype == REMOTE ?  ExecType.CP : _etype;
                        
                        //mark for recompile (forever)
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
index 93ffd4f..50a11de 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
@@ -274,6 +274,24 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        int numThreads = Integer.parseInt(parts[9]);
                        return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, null, null, null, opcode, str, 0, numThreads);
                }
+               else if (opcode.equalsIgnoreCase("lstm_backward")) {
+                       InstructionUtils.checkNumFields(parts, 14);
+                       CPOperand in1 = new CPOperand(parts[1]); // X
+                       CPOperand in2 = new CPOperand(parts[2]); // W
+                       CPOperand in3 = new CPOperand(parts[3]); // b
+                       CPOperand in4 = new CPOperand(parts[4]); // out0
+                       CPOperand in5 = new CPOperand(parts[5]); // c0
+                       CPOperand in6 = new CPOperand(parts[6]); // return_seq
+                       CPOperand in7 = new CPOperand(parts[7]); // dout
+                       CPOperand in8 = new CPOperand(parts[8]); // dc
+                       CPOperand out = new CPOperand(parts[9]);  // dX
+                       CPOperand out2 = new CPOperand(parts[10]); // dW
+                       CPOperand out3 = new CPOperand(parts[11]); // db
+                       CPOperand out4 = new CPOperand(parts[12]); // dout0
+                       CPOperand out5 = new CPOperand(parts[13]); // dc0
+                       int numThreads = Integer.parseInt(parts[14]);
+                       return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0, numThreads);
+               }
                else {
                        throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str);
                }
@@ -329,6 +347,88 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                ec.setMatrixOutput(_out2.getName(), c, getExtendedOpcode());
        }
        
+       public void processLstmBackwardInstruction(ExecutionContext ec) {
+               MatrixBlock X = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+               MatrixBlock W = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
+               MatrixBlock b = ec.getMatrixInput(_in3.getName(), getExtendedOpcode());
+               MatrixBlock out0 = ec.getMatrixInput(_in4.getName(), getExtendedOpcode());
+               MatrixBlock c0 = ec.getMatrixInput(_in5.getName(), getExtendedOpcode());
+               boolean return_seq = ec.getScalarInput(_in6.getName(), _in6.getValueType(), _in6.isLiteral()).getBooleanValue();
+               MatrixBlock dout = ec.getMatrixInput(_in7.getName(), getExtendedOpcode());
+               MatrixBlock dc = ec.getMatrixInput(_in8.getName(), getExtendedOpcode());
+               
+               int N = X.getNumRows();
+               int TD = X.getNumColumns();
+               int DPlusM = W.getNumRows();
+               int M = W.getNumColumns() / 4;
+               int D = DPlusM - M;
+               int T = TD / D;
+               if(b.getNumRows() != 1 || b.getNumColumns() != M*4) {
+                       throw new DMLRuntimeException("Incorrect dimensions of bias in lstm_backward instruction. Expected [1, " + (M*4) + "], "
+                                       + "but found [" + b.getNumRows() + "," + b.getNumColumns() + "]");
+               }
+               if(out0.getNumRows() != N) {
+                       throw new DMLRuntimeException("Unsupported operation: The batch size of previous iteration " + out0.getNumRows() + 
+                                       " is different than the batch size of current iteration " + N);
+               }
+               if(out0.getNumColumns() != M) {
+                       throw new DMLRuntimeException("Incorrect dimensions of out0 in lstm_backward instruction. Expected [" + N + ", " + M + "], "
+                                       + "but found [" + out0.getNumRows() + "," + out0.getNumColumns() + "]");
+               }
+               if(c0.getNumRows() != N || c0.getNumColumns() != M) {
+                       throw new DMLRuntimeException("Incorrect dimensions of c0 in lstm_backward instruction. Expected [" + N + ", " + M + "], "
+                                       + "but found [" + c0.getNumRows() + "," + c0.getNumColumns() + "]");
+               }
+               if(dout.getNumRows() != N || dout.getNumColumns() != (return_seq ? (T*M) : M)) {
+                       throw new DMLRuntimeException("Incorrect dimensions of dout in lstm_backward instruction. Expected [" + N + ", " + (return_seq ? (T*M) : M) + "], "
+                                       + "but found [" + dout.getNumRows() + "," + dout.getNumColumns() + "]");
+               }
+               if(dc.getNumRows() != N || dc.getNumColumns() != M) {
+                       throw new DMLRuntimeException("Incorrect dimensions of dc in lstm_backward instruction. Expected [" + N + ", " + M + "], "
+                                       + "but found [" + dc.getNumRows() + "," + dc.getNumColumns() + "]");
+               }
+               
+               MatrixBlock out = new MatrixBlock(N, return_seq ? (T*M) : M, false);
+               MatrixBlock c = new MatrixBlock(N, M, false);
+               MatrixBlock cache_out = new MatrixBlock(T, N*M, false);
+               MatrixBlock cache_c = new MatrixBlock(T, N*M, false);
+               MatrixBlock cache_ifog = new MatrixBlock(T, N*4*M, false);
+               
+               // In the initial implementation, invoke lstm redundantly.
+               // TODO: Optimize this later.
+               cache_out.allocateDenseBlock();
+               cache_c.allocateDenseBlock();
+               cache_ifog.allocateDenseBlock();
+               LibMatrixDNN.lstm(X, W, b, out0, c0, 
+                               return_seq, N, T, D, M,
+                               out,  c, cache_out, cache_c, cache_ifog,
+                               _numThreads);
+               
+               MatrixBlock dX = new MatrixBlock(N, T*D, false);
+               MatrixBlock dW = new MatrixBlock(D+M, 4*M, false);
+               MatrixBlock db = new MatrixBlock(1, 4*M, false);
+               MatrixBlock dout0 = new MatrixBlock(N, M, false);
+               MatrixBlock dc0 = new MatrixBlock(N, M, false);
+               LibMatrixDNN.lstm_backward(dout, dc, X, W, b, out0, c0, return_seq, N, T, D, M,
+                               cache_out, cache_c, cache_ifog, // from forward invocation
+                               dX, dW, db, dout0, dc0, // output
+                               _numThreads);
+               
+               // release inputs/outputs
+               ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in4.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in7.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in8.getName(), getExtendedOpcode());
+               ec.setMatrixOutput(output.getName(), dX, getExtendedOpcode());
+               ec.setMatrixOutput(_out2.getName(), dW, getExtendedOpcode());
+               ec.setMatrixOutput(_out3.getName(), db, getExtendedOpcode());
+               ec.setMatrixOutput(_out4.getName(), dout0, getExtendedOpcode());
+               ec.setMatrixOutput(_out5.getName(), dc0, getExtendedOpcode());
+       }
+       
        public void processReluBackwardInstruction(ExecutionContext ec) {
                // (X > 0) * dout
                MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
@@ -498,6 +598,10 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        processLstmInstruction(ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("lstm_backward")) {
+                       processLstmBackwardInstruction(ec);
+                       return;
+               }
                
                // acquire inputs
                MatrixBlock outputBlock = null;
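
For orientation, the following sketch (not part of the commit; sizes and the zero-valued inputs are illustrative) shows how the two LibMatrixDNN entry points added in this push chain together, mirroring processLstmBackwardInstruction: the forward pass runs once to materialize the three caches, which the backward pass then consumes.

    import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
    import org.apache.sysml.runtime.matrix.data.MatrixBlock;

    public class LstmBackwardSketch {
        public static void main(String[] args) {
            int N = 2, T = 3, D = 4, M = 5, k = 1; // batch, timesteps, features, hidden, threads
            boolean return_seq = true;
            MatrixBlock X = new MatrixBlock(N, T*D, false), W = new MatrixBlock(D+M, 4*M, false);
            MatrixBlock b = new MatrixBlock(1, 4*M, false);
            MatrixBlock out0 = new MatrixBlock(N, M, false), c0 = new MatrixBlock(N, M, false);
            MatrixBlock dout = new MatrixBlock(N, return_seq ? T*M : M, false);
            MatrixBlock dc = new MatrixBlock(N, M, false);

            // forward pass, materializing the caches needed by the backward pass
            MatrixBlock out = new MatrixBlock(N, return_seq ? T*M : M, false);
            MatrixBlock c = new MatrixBlock(N, M, false);
            MatrixBlock cache_out = new MatrixBlock(T, N*M, false);
            MatrixBlock cache_c = new MatrixBlock(T, N*M, false);
            MatrixBlock cache_ifog = new MatrixBlock(T, N*4*M, false);
            cache_out.allocateDenseBlock(); cache_c.allocateDenseBlock(); cache_ifog.allocateDenseBlock();
            LibMatrixDNN.lstm(X, W, b, out0, c0, return_seq, N, T, D, M,
                    out, c, cache_out, cache_c, cache_ifog, k);

            // backward pass into preallocated gradient blocks
            MatrixBlock dX = new MatrixBlock(N, T*D, false), dW = new MatrixBlock(D+M, 4*M, false);
            MatrixBlock db = new MatrixBlock(1, 4*M, false);
            MatrixBlock dout0 = new MatrixBlock(N, M, false), dc0 = new MatrixBlock(N, M, false);
            LibMatrixDNN.lstm_backward(dout, dc, X, W, b, out0, c0, return_seq, N, T, D, M,
                    cache_out, cache_c, cache_ifog, dX, dW, db, dout0, dc0, k);
        }
    }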
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index 35c9591..fbe7c9d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -361,31 +361,31 @@ public class DnnGPUInstruction extends GPUInstruction {
                }
                else if (opcode.equalsIgnoreCase("lstm")) {
                        InstructionUtils.checkNumFields(parts, 8);
-                       CPOperand in1 = new CPOperand(parts[1]);
-                       CPOperand in2 = new CPOperand(parts[2]);
-                       CPOperand in3 = new CPOperand(parts[3]);
-                       CPOperand in4 = new CPOperand(parts[4]);
-                       CPOperand in5 = new CPOperand(parts[5]);
-                       CPOperand in6 = new CPOperand(parts[6]);
-                       CPOperand out = new CPOperand(parts[7]);
-                       CPOperand out2 = new CPOperand(parts[8]);
+                       CPOperand in1 = new CPOperand(parts[1]); // X
+                       CPOperand in2 = new CPOperand(parts[2]); // W
+                       CPOperand in3 = new CPOperand(parts[3]); // b
+                       CPOperand in4 = new CPOperand(parts[4]); // out0
+                       CPOperand in5 = new CPOperand(parts[5]); // c0
+                       CPOperand in6 = new CPOperand(parts[6]); // return_seq
+                       CPOperand out = new CPOperand(parts[7]); // out
+                       CPOperand out2 = new CPOperand(parts[8]); // c
                        return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0);
                }
                else if (opcode.equalsIgnoreCase("lstm_backward")) {
                        InstructionUtils.checkNumFields(parts, 13);
-                       CPOperand in1 = new CPOperand(parts[1]); // image
-                       CPOperand in2 = new CPOperand(parts[2]); // scale
-                       CPOperand in3 = new CPOperand(parts[3]); // bias
-                       CPOperand in4 = new CPOperand(parts[4]); // runningMean
-                       CPOperand in5 = new CPOperand(parts[5]); // runningVar
-                       CPOperand in6 = new CPOperand(parts[6]); // mode
-                       CPOperand in7 = new CPOperand(parts[7]); // epsilon
-                       CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor
-                       CPOperand out = new CPOperand(parts[9]);  // ret
-                       CPOperand out2 = new CPOperand(parts[10]); // retRunningMean
-                       CPOperand out3 = new CPOperand(parts[11]); // retRunningVar
-                       CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean
-                       CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
+                       CPOperand in1 = new CPOperand(parts[1]); // X
+                       CPOperand in2 = new CPOperand(parts[2]); // W
+                       CPOperand in3 = new CPOperand(parts[3]); // b
+                       CPOperand in4 = new CPOperand(parts[4]); // out0
+                       CPOperand in5 = new CPOperand(parts[5]); // c0
+                       CPOperand in6 = new CPOperand(parts[6]); // return_seq
+                       CPOperand in7 = new CPOperand(parts[7]); // dout
+                       CPOperand in8 = new CPOperand(parts[8]); // dc
+                       CPOperand out = new CPOperand(parts[9]);  // dX
+                       CPOperand out2 = new CPOperand(parts[10]); // dW
+                       CPOperand out3 = new CPOperand(parts[11]); // db
+                       CPOperand out4 = new CPOperand(parts[12]); // dout0
+                       CPOperand out5 = new CPOperand(parts[13]); // dc0
                        return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
                }
                else if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
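
One detail worth noting in this hunk: the GPU parser checks for 13 fields while the new CP parser above checks for 14, because the CP instruction carries a trailing numThreads operand. The operand meanings (inputs X, W, b, out0, c0, return_seq, dout, dc; outputs dX, dW, db, dout0, dc0) are otherwise identical; the rewritten comments replace stale batch-norm labels (image, scale, runningMean, ...) that had been copied from another opcode.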
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 365d7a2..0005932 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -20,6 +20,7 @@ package org.apache.sysml.runtime.matrix.data;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
@@ -35,19 +36,29 @@ import org.apache.sysml.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import org.apache.sysml.runtime.functionobjects.Builtin;
 import org.apache.sysml.runtime.functionobjects.KahanPlus;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
 import org.apache.sysml.runtime.functionobjects.Multiply;
 import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.functionobjects.PlusMultiply;
+import org.apache.sysml.runtime.functionobjects.Power;
+import org.apache.sysml.runtime.functionobjects.Power2;
+import org.apache.sysml.runtime.functionobjects.SwapIndex;
 import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysml.runtime.instructions.cp.KahanObject;
 import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.LeftScalarOperator;
+import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysml.runtime.matrix.operators.RightScalarOperator;
+import org.apache.sysml.runtime.matrix.operators.ScalarOperator;
 import org.apache.sysml.runtime.matrix.operators.TernaryOperator;
 import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
 import org.apache.sysml.runtime.util.CommonThreadPool;
 import org.apache.sysml.runtime.util.DnnUtils;
 
+
 /*
  * This class allows users to invoke deep learning related operations 
 * (such as conv2d, conv2d_backward_data, conv2d_backward_filter, maxpooling, maxpooling_backward, bias_add)
@@ -297,6 +308,10 @@ public class LibMatrixDNN {
                return matBlock1.ternaryOperations(new TernaryOperator(PlusMultiply.getFnObject()), 
                                matBlock2, matBlock3, new MatrixBlock());
        }
+       private static MatrixBlock minusMultiply(MatrixBlock matBlock1, MatrixBlock matBlock2, MatrixBlock matBlock3) {
+               return matBlock1.ternaryOperations(new TernaryOperator(MinusMultiply.getFnObject()), 
+                               matBlock2, matBlock3, new MatrixBlock());
+       }
        
                
        private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) {
@@ -311,6 +326,11 @@ public class LibMatrixDNN {
                }
        }
        
+       private static MatrixBlock multiply(MatrixBlock matBlock1, double scalar, boolean inplace) {
+               ScalarOperator sc_op = new LeftScalarOperator(Multiply.getMultiplyFnObject(), scalar);
+               return (MatrixBlock) matBlock1.scalarOperations(sc_op, new MatrixBlock());
+       }
+       
        
        // sigmoid(0)*c_prev + sigmoid(0)*tanh(0);
        
@@ -322,6 +342,175 @@ public class LibMatrixDNN {
        private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean inPlace) {
                return (MatrixBlock) in.unaryOperations(new UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock());
        }
+       private static MatrixBlock power(MatrixBlock in, double exponent) {
+               return (MatrixBlock) in.scalarOperations(new RightScalarOperator(Power.getPowerFnObject(), exponent), new MatrixBlock());
+       }
+       private static MatrixBlock minus(double scalar, MatrixBlock in) {
+               return (MatrixBlock) in.scalarOperations(new LeftScalarOperator(org.apache.sysml.runtime.functionobjects.Minus.getMinusFnObject(), scalar), new MatrixBlock());
+       }
+       private static MatrixBlock tanh_backward(MatrixBlock dout, MatrixBlock X, int numThreads) {
+               MatrixBlock out = tanh(X, numThreads, false);
+               return minusMultiply(dout, power(out, 2), dout);
+       }
+       
+       public static void lstm_backward(MatrixBlock dout, MatrixBlock dc,
+                       MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0, 
+                       boolean given_sequences, int N, int T, int D, int M,
+                       MatrixBlock cache_out, MatrixBlock cache_c, MatrixBlock cache_ifog, // from forward invocation
+                       MatrixBlock dX, MatrixBlock dW, MatrixBlock db, MatrixBlock dout0, MatrixBlock dc0,
+                       int numThreads) {
+               MatrixBlock dct = dc;
+               if (!given_sequences) {
+                       // only given dout for output at final timestep, so prepend empty douts for all other timesteps
+                       dout = new MatrixBlock(N, (T-1)*M, true).append(dout, new MatrixBlock());
+               }
+               MatrixBlock dW_ret = dW;
+               MatrixBlock db_ret = db;
+               MatrixBlock dout_t = dout.slice(0, N-1, (T-1)*M, T*M-1, new MatrixBlock());
+               for(int t = T; t > 0; t--) {
+                       MatrixBlock X_t = (T == 1) ? X : X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock());
+                       MatrixBlock ct = sliceAndReshape(cache_c, new MatrixBlock(), t-1, N, M);
+                       MatrixBlock out_prev = (t == 1) ? out0 : sliceAndReshape(cache_out, new MatrixBlock(), t-2, N, M);
+                       MatrixBlock c_prev = (t == 1) ? c0 : sliceAndReshape(cache_c, new MatrixBlock(), t-2, N, M);
+                       MatrixBlock input = X_t.append(out_prev, new MatrixBlock());
+                       MatrixBlock ifog = sliceAndReshape(cache_ifog, new MatrixBlock(), t-1, N, 4*M);
+                       MatrixBlock i = ifog.slice(0, N-1, 0, M-1, new MatrixBlock());
+                       MatrixBlock f = ifog.slice(0, N-1, M, 2*M-1, new MatrixBlock());
+                       MatrixBlock o = ifog.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
+                       MatrixBlock g = ifog.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock());
+                       dct = plusMultiply(dct, o, tanh_backward(dout_t, ct, numThreads));
+                       MatrixBlock dc_prev = multiply(f, dct, false);
+                       
+                       MatrixBlock di_raw = multiply(new MatrixBlock[] {i, minus(1, i), g, dct}); 
+                       MatrixBlock df_raw = multiply(new MatrixBlock[] {f, minus(1, f), c_prev, dct});
+                       MatrixBlock do_raw = multiply(new MatrixBlock[] {o, minus(1, o), tanh(ct, numThreads, false), dout_t});
+                       MatrixBlock dg_raw = multiply(new MatrixBlock[] {minus(1, power(g, 2)), i, dct});
+                       MatrixBlock difog_raw = di_raw.append(new MatrixBlock[] { df_raw, do_raw, dg_raw}, new MatrixBlock(), true);
+                       
+                       // dW = dW + t(input) %*% difog_raw
+                       dW = add(matmult(transpose(input, numThreads), difog_raw, numThreads), dW, true);
+                       // db = db + colSums(difog_raw)
+                       db = add(colSums(difog_raw), db, true);
+                       // dinput = difog_raw %*% t(W)
+                       MatrixBlock dinput = matmult(difog_raw, transpose(W, numThreads), numThreads);
+                       // dX[,(t-1)*D+1:t*D] = dinput[,1:D]
+                       dX.leftIndexingOperations(dinput.slice(0, N-1, 0, D-1, new MatrixBlock()), 0, N-1, (t-1)*D, t*D-1, dX, UpdateType.INPLACE);
+                       // dout_prev = dinput[,D+1:D+M]
+                       MatrixBlock dout_prev = dinput.slice(0, N-1, D, D+M-1, new MatrixBlock());
+                       
+                       if(t == 1) {
+                               dout0.copy(dout_prev);
+                               dc0.copy(dc_prev);
+                       }
+                       else {
+                               dout_t = add(dout.slice(0, N-1, (t-2)*M, (t-1)*M-1, new MatrixBlock()), dout_prev, true);
+                               dct = dc_prev;
+                       }
+               }
+               dW_ret.copy(dW);
+               db_ret.copy(db);
+       }
+       
+       
+       private static MatrixBlock colSums(MatrixBlock in) {
+               MatrixBlock ret = new MatrixBlock(1, in.getNumColumns(), false);
+               if(in.isEmpty()) {
+                       // Do nothing
+                       ret.setNonZeros(0);
+               }
+               else if(in.isInSparseFormat()) {
+                       ret.allocateDenseBlock();
+                       double [] retArr = ret.getDenseBlockValues();
+                       SparseBlock sblock = in.getSparseBlock();
+                       for(int n = 0; n < in.getNumRows(); n++) {
+                               if( sblock.isEmpty(n) )
+                                       continue;
+                               int apos = sblock.pos(n);
+                               int alen = sblock.size(n);
+                               int[] aix = sblock.indexes(n);
+                               double[] avals = sblock.values(n);
+                               
+                               // Iterate over the sparse block
+                               for(int j=apos; j<apos+alen; j++) {
+                                       retArr[aix[j]] += avals[j];
+                               }
+                       }
+                       ret.recomputeNonZeros();
+               }
+               else {
+                       double [] inArr = in.getDenseBlockValues();
+                       if(inArr != null) {
+                               int index = 0;
+                               ret.allocateDenseBlock();
+                               double [] retArr = ret.getDenseBlockValues();
+                               for(int r = 0; r < in.getNumRows(); r++) {
+                                       for(int c = 0; c < in.getNumColumns(); c++, index++) {
+                                               retArr[c] += inArr[index];
+                                       }
+                               }
+                               ret.recomputeNonZeros();
+                       }
+                       else {
+                               ret.setNonZeros(0);
+                       }
+               }
+               return ret;
+       }
+       
+       private static MatrixBlock transpose(MatrixBlock in, int numThreads) {
+               ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+               return (MatrixBlock) (in.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0));
+       }
+       
+       private static MatrixBlock multiply(MatrixBlock [] in) {
+               boolean allDense = true;
+               int rows = 0; int cols = 0;
+               for(MatrixBlock mb : in) {
+                       rows = Math.max(rows, mb.getNumRows());
+                       cols = Math.max(cols, mb.getNumColumns());
+               }
+               for(MatrixBlock mb : in) {
+                       if(mb.isEmpty() || (!mb.isInSparseFormat() && mb.getDenseBlockValues() == null)) {
+                               MatrixBlock ret = new MatrixBlock(rows, cols, true);
+                               ret.setNonZeros(0);
+                               return ret;
+                       }
+                       allDense = allDense && !mb.isInSparseFormat();
+               }
+               if(allDense) {
+                       MatrixBlock ret = new MatrixBlock(rows, cols, false);
+                       ret.allocateDenseBlock();
+                       double [] retArr = null;
+                       // Avoids (in.length-1) recomputeNonZeros calls
+                       for(MatrixBlock mb : in) {
+                               if(retArr == null) {
+                                       retArr = ret.getDenseBlockValues();
+                                       System.arraycopy(mb.getDenseBlockValues(), 0, retArr, 0, retArr.length);
+                               }
+                               else {
+                                       double [] inArr = mb.getDenseBlockValues();
+                                       for(int index = 0; index < retArr.length; index++) {
+                                               retArr[index] *= inArr[index];
+                                       }
+                               }
+                       }
+                       ret.recomputeNonZeros();
+                       return ret;
+               }
+               else {
+                       Arrays.sort(in, (mb1, mb2) -> Long.compare(mb1.getNonZeros(), mb2.getNonZeros()));
+                       MatrixBlock ret = new MatrixBlock(rows, cols, in[0].isInSparseFormat());
+                       for(MatrixBlock mb : in) {
+                               ret = multiply(ret, mb, true);
+                       }
+                       return ret;
+               }
+       }
+       
+       // Performs the following operation: ret = matrix(in[rowIndex+1,], rows=numRows, cols=numCols)
+       public static MatrixBlock sliceAndReshape(MatrixBlock in, MatrixBlock ret, int rowIndex, int numRows, int numCols) {
+               return LibMatrixReorg.reshape(in.slice(rowIndex, rowIndex), ret, numRows, numCols, true);
+       }
        
        public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0, 
                        boolean return_seq, int N, int T, int D, int M,
@@ -367,27 +556,84 @@ public class LibMatrixDNN {
                                ifog_raw = add(matmult(input, W, numThreads), b, true);
                        }
                        
-                       MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new MatrixBlock());
-                       ifo = sigmoid(ifo, numThreads, true);
-                       MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new MatrixBlock());
-                       MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new MatrixBlock());
-                       MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
-                       MatrixBlock g = tanh(ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock()), numThreads, true);
-                                       
-                       // c_t = f*c_prev + i*g
-                       c_t = plusMultiply(multiply(f, c_prev, true), i, g);
-                       // out_t = o*tanh(c)
-                       out_t = multiply(o, tanh(c_t, numThreads, false), true);
+                       if(!ifog_raw.isInSparseFormat() && !c_prev.isInSparseFormat()) {
+                               double [] ifog_rawArr = ifog_raw.getDenseBlockValues();
+                               double [] c_prevArr = c_prev.getDenseBlockValues();
+                               double [] cache_ifogArr = null;
+                               if(cache_ifog != null) {
+                                       cache_ifogArr = cache_ifog.getDenseBlockValues();
+                                       if(cache_ifogArr == null)
+                                               throw new DMLRuntimeException("Expected cache_ifog to be allocated in the dense format");
+                               }
+                               if(ifog_rawArr == null && c_prevArr == null) {
+                                       // Both ifog_raw and c_prev are empty matrices
+                                       c_t = new MatrixBlock(N, M, 0);
+                                       out_t = new MatrixBlock(N, M, 0);
+                                       c_t.setNonZeros(0);
+                                       out_t.setNonZeros(0);
+                                       updateIfogCache(cache_ifogArr, t, N, M);
+                               }
+                               else if(ifog_rawArr == null) {
+                                       // ifog_raw is an empty matrix
+                                       // c_t = f*c_prev + i*g 
+                                       //     = 0.5*c_prev
+                                       c_t = multiply(c_prev, 0.5, false);
+                                       // out_t = o*tanh(c)
+                                       //       = 0.5*tanh(c)
+                                       out_t = multiply(tanh(c_t, numThreads, false), 0.5, false);
+                                       updateIfogCache(cache_ifogArr, t, N, M);
+                               }
+                               else {
+                                       // ifog_raw is not an empty matrix
+                                       c_t = new MatrixBlock(N, M, false); c_t.allocateDenseBlock();
+                                       double [] c_tArr = c_t.getDenseBlockValues();
+                                       out_t = new MatrixBlock(N, M, false); out_t.allocateDenseBlock();
+                                       double [] out_tArr = out_t.getDenseBlockValues();
+                                       int index = 0;
+                                       int offset = (t-1)*N*4*M;
+                                       for(int n = 0; n < N; n++) {
+                                               for(int m = 0; m < M; m++, index++) {
+                                                       double c_prevVal = (c_prevArr == null) ? 0 : c_prevArr[index];
+                                                       // c_t = f*c_prev + i*g
+                                                       double i = sigmoidOp.execute(ifog_rawArr[n*4*M + m]);
+                                                       double f = sigmoidOp.execute(ifog_rawArr[n*4*M + M + m]);
+                                                       double o = sigmoidOp.execute(ifog_rawArr[n*4*M + 2*M + m]);
+                                                       double g = tanhOp.execute(ifog_rawArr[n*4*M + 3*M + m]);
+                                                       c_tArr[index] = f*c_prevVal + i*g;
+                                                       // out_t = o*tanh(c)
+                                                       out_tArr[index] = o*tanhOp.execute(c_tArr[index]);
+                                                       updateIfogCache(cache_ifogArr, i, f, o, g, offset, n, m, N, M);
+                                               }
+                                       }
+                                       c_t.recomputeNonZeros();
+                                       out_t.recomputeNonZeros();
+                               }
+                       }
+                       else {
+                               MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new MatrixBlock());
+                               ifo = sigmoid(ifo, numThreads, true);
+                               MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new MatrixBlock());
+                               MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new MatrixBlock());
+                               MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
+                               MatrixBlock g = tanh(ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock()), numThreads, true);
+                                               
+                               // c_t = f*c_prev + i*g
+                               c_t = plusMultiply(multiply(f, c_prev, true), i, g);
+                               // out_t = o*tanh(c)
+                               out_t = multiply(o, tanh(c_t, numThreads, false), true);
+                               updateIfogCache(cache_ifog, ifo, g, t, N, M);
+                       }
+                       
                        if(return_seq) {
-                               out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE);
+                               out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, out, UpdateType.INPLACE);
                        }
                        out_prev = out_t;
                        c_prev = c_t;
                        
-                       // TODO: Add this when implementing lstm_backward
-//                     cache_out[t,] = matrix(out_t, rows=1, cols=N*M)  # reshape
-//                 cache_c[t,] = matrix(c, rows=1, cols=N*M)  # reshape
-//                 cache_ifog[t,] = matrix(cbind(ifo, g), rows=1, cols=N*4*M)  # reshape
+                       if(cache_out != null) {
+                               reshapeAsRowMatrixAndLeftIndex(cache_out, out_t, t-1, N*M);
+                               reshapeAsRowMatrixAndLeftIndex(cache_c, c_t, t-1, N*M);
+                       }
                }
                if(out_t != null && !return_seq)
                        out.copy(out_t);
@@ -395,6 +641,69 @@ public class LibMatrixDNN {
                        c.copy(c_t);
                else
                        c.copy(c0);
+               if(cache_out != null) {
+                       cache_out.recomputeNonZeros();
+                       cache_c.recomputeNonZeros();
+                       cache_ifog.recomputeNonZeros();
+               }
+       }
+       
+       private static void updateIfogCache(MatrixBlock cache_ifog, MatrixBlock ifo, MatrixBlock g, int t, int N, int M) {
+               if(cache_ifog != null) {
+                       reshapeAsRowMatrixAndLeftIndex(cache_ifog, ifo.append(g, new MatrixBlock()), t-1, N*M);
+               }
+       }
+       
+       // ifog_raw is an empty matrix
+       private static void updateIfogCache(double[] cache_ifogArr, int t, int N, int M) {
+               if(cache_ifogArr != null) {
+                       int offset = (t-1)*N*4*M;
+                       for(int n = 0 ; n < N; n++) {
+                               int srcIndex = offset + n*4*M;
+                               Arrays.fill(cache_ifogArr, srcIndex, srcIndex + 3*M, 0.5);
+                       }
+               }
+       }
+       
+       private static void updateIfogCache(double[] cache_ifogArr, double i, double f, double o, double g, int offset, int n, int m, int N, int M) {
+               if(cache_ifogArr != null) {
+                       cache_ifogArr[offset + n*4*M + m] = i;
+                       cache_ifogArr[offset + n*4*M + M + m] = f;
+                       cache_ifogArr[offset + n*4*M + 2*M + m] = o;
+                       cache_ifogArr[offset + n*4*M + 3*M + m] = g;
+               }
+       }
+       
+       // Performs operation: lhsMatrix[rowIndex+1, ] =  matrix(rhsMatrix, rows=1, cols=numCols)
+       private static void reshapeAsRowMatrixAndLeftIndex(MatrixBlock lhsMatrix, MatrixBlock rhsMatrix, int rowIndex, int numCols) {
+               double [] lhsArr = lhsMatrix.getDenseBlockValues();
+               if(lhsArr == null)
+                       throw new DMLRuntimeException("Incorrect usage: lhsMatrix needs to be allocated in dense format before invocation of this method.");
+               if(rhsMatrix.isInSparseFormat()) {
+                       SparseBlock sblock = rhsMatrix.getSparseBlock();
+                       for(int n = 0; n < rhsMatrix.getNumRows(); n++) {
+                               if( sblock.isEmpty(n) )
+                                       continue;
+                               int apos = sblock.pos(n);
+                               int alen = sblock.size(n);
+                               int[] aix = sblock.indexes(n);
+                               double[] avals = sblock.values(n);
+                               
+                               // Iterate over the sparse block
+                               for(int j=apos; j<apos+alen; j++) {
+                                       lhsArr[n*numCols + aix[j]] = avals[j];
+                               }
+                       }
+               }
+               else if(!rhsMatrix.isInSparseFormat()) {
+                       double [] rhsArr = rhsMatrix.getDenseBlockValues();
+                       if(rhsArr != null) {
+                               System.arraycopy(rhsArr, 0, lhsArr, rowIndex*numCols, numCols);
+                       }
+                       else {
+                               // Do nothing: assumption => lhsMatrix is initialized to 0 before invocation.
+               }
        }
        
        /**
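
Two implementation notes on LibMatrixDNN. First, the varargs multiply() sorts its operands by non-zero count so that the sparsest block seeds the elementwise chain, which keeps the intermediates as sparse as possible. Second, the forward caches store one timestep per row: cache_out and cache_c are T x (N*M) and cache_ifog is T x (N*4*M). The plain-array sketch below (illustrative only, not SystemML API) shows the round trip that reshapeAsRowMatrixAndLeftIndex and sliceAndReshape perform on that layout.

    // Illustrative: row t-1 of the cache holds timestep t's N x M block, flattened row-major.
    public class CacheLayoutSketch {
        public static void main(String[] args) {
            int T = 2, N = 3, M = 4, t = 2;
            double[][] cache = new double[T][N * M];
            double[][] out_t = new double[N][M]; // output of timestep t
            out_t[1][2] = 7.0;
            // reshapeAsRowMatrixAndLeftIndex: cache[t-1, ] = matrix(out_t, rows=1, cols=N*M)
            for (int n = 0; n < N; n++)
                for (int m = 0; m < M; m++)
                    cache[t - 1][n * M + m] = out_t[n][m];
            // sliceAndReshape: matrix(cache[t-1, ], rows=N, cols=M) recovers out_t
            System.out.println(cache[t - 1][1 * M + 2]); // prints 7.0
        }
    }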
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
index 785c890..5c93bca 100644
--- a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
@@ -131,82 +131,141 @@ public class LstmCPUTest extends GPUTests {
                inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
                inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
                List<String> outputs = Arrays.asList("output", "c");
-               List<Object> outGPUWithCuDNN = null;
-               List<Object> outCPUWithNN = null;
-               synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
-                       try {
-                               DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.CUDNN;
-                               outGPUWithCuDNN = runOnCPU(spark, scriptStr1, inputs, outputs);
-                               outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, outputs);
-                       }
-                       finally {
-                               DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.NONE;
-                       }
-               }
-               assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
-               assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
+               List<Object> outBuiltin = runOnCPU(spark, scriptStr1, inputs, outputs);
+               List<Object> outNNLayer = runOnCPU(spark, scriptStr2, inputs, outputs);
+               assertEqualObjects(outBuiltin.get(0), outNNLayer.get(0));
+               assertEqualObjects(outBuiltin.get(1), outNNLayer.get(1));
        }
        
        
+       @Test
+       public void testLstmBackward1() {
+               testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "TRUE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward2() {
+               testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "FALSE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward3() {
+               testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "TRUE", 0.9, 0.9);
+       }
        
 //     @Test
-//     public void testLstmBackward7() {
-//             testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9, 0.9);
+//     public void testLstmBackward4() {
+//             testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "FALSE", 0.9, 0.9);
 //     }
-//     
+       
+       @Test
+       public void testLstmBackward5() {
+               testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "TRUE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward6() {
+               testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "FALSE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward7() {
+               testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward8() {
+               testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward9() {
+               testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward10() {
+               testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9, 0.9);
+       }
+       
 //     @Test
-//     public void testLstmBackward8() {
-//             testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9, 0.9);
+//     public void testLstmBackward11() {
+//             testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "TRUE", 0.2, 0.3);
 //     }
-//     
+       
+       @Test
+       public void testLstmBackward12() {
+               testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "FALSE", 0.2, 0.9);
+       }
+       
 //     @Test
-//     public void testLstmBackward9() {
-//             testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9, 0.9);
+//     public void testLstmBackward13() {
+//             testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "TRUE", 0.9, 0.1);
 //     }
-//     
+       
 //     @Test
-//     public void testLstmBackward10() {
-//             testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9, 0.9);
+//     public void testLstmBackward14() {
+//             testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "FALSE", 0.3, 0.6);
 //     }
-//     
-//     
-//     public void testLstmBackwardCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity,
-//                     double weightSparsity) {
-//             boolean returnSequences1 = returnSequences.equals("TRUE");
-//             
-//             String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
-//                             + "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
-//             String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
-//                             + "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, " 
-//                             + T + ", " + D + ", " + returnSequences + ", out0, c0); \n"
-//                             + "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " 
-//                             + T + ", " + D + ", " + returnSequences + ", out0, c0, cache_out, cache_c, cache_ifog);";
-//             
-//             HashMap<String, Object> inputs = new HashMap<>();
-//             inputs.put("dout", generateInputMatrix(spark, N, returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
-//             inputs.put("dc", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
-//             inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, sparsity, seed));
-//             inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, weightSparsity, seed));
-//             inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, sparsity, seed));
-//             inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
-//             inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
-//             List<String> outputs = Arrays.asList("dX", "dW", "db", "dout0", "dc0");
-//             List<Object> outGPUWithCuDNN = null;
-//             List<Object> outCPUWithNN = null;
-//             synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
-//                     try {
-//                             DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.CUDNN;
-//                             outGPUWithCuDNN = runOnCPU(spark, scriptStr1, inputs, outputs);
-//                     }
-//                     finally {
-//                             DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.NONE;
-//                     }
-//                     outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, outputs);
-//             }
-//             assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
-//             assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
-//             assertEqualObjects(outGPUWithCuDNN.get(2), outCPUWithNN.get(2));
-//             assertEqualObjects(outGPUWithCuDNN.get(3), outCPUWithNN.get(3));
-//             assertEqualObjects(outGPUWithCuDNN.get(4), outCPUWithNN.get(4));
+       
+       @Test
+       public void testLstmBackward15() {
+               testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "TRUE", 0.2, 0.9);
+       }
+       
+//     @Test
+//     public void testLstmBackward16() {
+//             testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "FALSE", 0.3, 0.1);
 //     }
+       
+       @Test
+       public void testLstmBackward17() {
+               testLstmBackwardCuDNNWithNNLayer(20, 13, 15, 25, "TRUE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward18() {
+               testLstmBackwardCuDNNWithNNLayer(20, 13, 15, 25, "FALSE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward19() {
+               testLstmBackwardCuDNNWithNNLayer(12, 17, 15, 26, "TRUE", 0.9, 0.9);
+       }
+       
+       @Test
+       public void testLstmBackward20() {
+               testLstmBackwardCuDNNWithNNLayer(12, 17, 15, 26, "FALSE", 0.9, 0.9);
+       }
+       
+       
+       public void testLstmBackwardCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity,
+                       double weightSparsity) {
+               boolean returnSequences1 = returnSequences.equals("TRUE");
+               
+               String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
+                               + "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+               String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
+                               + "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, " 
+                               + T + ", " + D + ", " + returnSequences + ", out0, c0); \n"
+                               + "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " 
+                               + T + ", " + D + ", " + returnSequences + ", out0, c0, cache_out, cache_c, cache_ifog);";
+               
+               HashMap<String, Object> inputs = new HashMap<>();
+               inputs.put("dout", generateInputMatrix(spark, N, returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
+               inputs.put("dc", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+               inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, sparsity, seed));
+               inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, weightSparsity, seed));
+               inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, sparsity, seed));
+               inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+               inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+               List<String> outputs = Arrays.asList("dX", "dW", "db", "dout0", "dc0");
+               List<Object> outBuiltin = runOnCPU(spark, scriptStr1, inputs, outputs);
+               List<Object> outNN = runOnCPU(spark, scriptStr2, inputs, outputs);
+               assertEqualObjects(outBuiltin.get(0), outNN.get(0));
+               assertEqualObjects(outBuiltin.get(1), outNN.get(1));
+               assertEqualObjects(outBuiltin.get(2), outNN.get(2));
+               assertEqualObjects(outBuiltin.get(3), outNN.get(3));
+               assertEqualObjects(outBuiltin.get(4), outNN.get(4));
+       }
 }
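
Taken together, the re-enabled tests compare the builtin CP lstm_backward against the nn/layers DML reference across batch size N, sequence length T, feature and hidden sizes D and M, both values of return_seq, and several input/weight sparsities; a few configurations (testLstmBackward 4, 11, 13, 14, and 16) remain commented out.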
