This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 592f1b0  [SYSTEMML-540] Added an initial CP operator for lstm builtin 
function
592f1b0 is described below

commit 592f1b0e9a5566195d0e73fede318e2a269bb4a0
Author: Niketan Pansare <npan...@us.ibm.com>
AuthorDate: Sat Feb 23 09:05:11 2019 -0800

    [SYSTEMML-540] Added an initial CP operator for lstm builtin function
    
    - This operator relies on existing slice, left indexing, binary and matrix 
multiplication operators.
    - We can later create a fused implementation especially for lstm activation 
to avoid unnecessary copies in the inner loop.
    - To exploit sparsity in the input data and avoid unnecessary 
sparse-to-dense conversion (as out_prev is often dense), we perform two matrix 
multiplication followed by binary addition.
---
 .../java/org/apache/sysml/hops/FunctionOp.java     |  42 +++++-
 .../runtime/instructions/CPInstructionParser.java  |   2 +
 .../runtime/instructions/cp/DnnCPInstruction.java  |  61 ++++++++
 .../sysml/runtime/matrix/data/LibMatrixDNN.java    |  88 +++++++++++
 .../org/apache/sysml/test/gpu/LstmCPUTest.java     | 162 +++++++++++++++++++++
 5 files changed, 347 insertions(+), 8 deletions(-)

diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java 
b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 5f177bd..66ce478 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -274,19 +274,45 @@ public class FunctionOp extends Hop
        {
                checkAndSetForcedPlatform();
                
-               if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) {
-                       boolean isBuiltinFunction = isBuiltinFunction();
+               if(getFunctionType() == FunctionType.MULTIRETURN_BUILTIN && 
isBuiltinFunction() &&
+                       (getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("lstm_backward"))) {
+                       
+                       if(getFunctionName().equalsIgnoreCase("lstm_backward")) 
{
+                               if(!ConfigurationManager.isGPU())
+                                       throw new RuntimeException("The 
function " + getFunctionName() + " is only supported on GPU.");
+                               _etype = ExecType.GPU;
+                       }
+                       
+                       ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() 
? ExecType.SPARK : ExecType.MR;
+                       
+                       if( _etypeForced != null ) {
+                               _etype = _etypeForced;
+                       }
+                       else {  
+                               if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
+                                       _etype = findExecTypeByMemEstimate();
+                               }
+                               else {
+                                       _etype = REMOTE;
+                               }
+                               
+                               //check for valid CP dimensions and matrix size
+                               checkAndSetInvalidCPDimsAndSize();
+                       }
+                       
+                       // Since lstm builtin functions are not supported on 
Spark
+                       _etype = _etype == REMOTE ?  ExecType.CP : _etype;
+                       
+                       //mark for recompile (forever)
+                       setRequiresRecompileIfNecessary();
+               }
+               else if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN 
) {
                        // check if there is sufficient memory to execute this 
function
-                       if(isBuiltinFunction && 
getFunctionName().equalsIgnoreCase("transformencode") ) {
+                       if(isBuiltinFunction() && 
getFunctionName().equalsIgnoreCase("transformencode") ) {
                                _etype = ((_etypeForced==ExecType.SPARK 
                                        || (getMemEstimate() >= 
OptimizerUtils.getLocalMemBudget()
                                                && 
OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
                        }
-                       else if(isBuiltinFunction && 
(getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("lstm_backward"))) {
-                               if(!ConfigurationManager.isGPU())
-                                       throw new RuntimeException("The 
function " + getFunctionName() + " is only supported on GPU.");
-                               _etype = ExecType.GPU;
-                       }
                        else {
                                // Since the memory estimate is only 
conservative, do not throw
                                // exception if the estimated memory is larger 
than the budget
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index ea01dc2..6e253d8 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -247,6 +247,8 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "conv2d_bias_add"      , 
CPType.Dnn);
                String2CPInstructionType.put( "conv2d_backward_filter"      , 
CPType.Dnn);
                String2CPInstructionType.put( "conv2d_backward_data"      , 
CPType.Dnn);
+               String2CPInstructionType.put( "lstm",                   
CPType.Dnn);
+               String2CPInstructionType.put( "lstm_backward",          
CPType.Dnn);
                String2CPInstructionType.put( "bias_add"      , CPType.Dnn);
                String2CPInstructionType.put( "bias_multiply"      , 
CPType.Dnn);
                String2CPInstructionType.put( "batch_norm2d",           
CPType.Dnn);
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
index 7bed33f..4043908 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
@@ -261,6 +261,18 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        CPOperand out3 = new CPOperand(parts[9]); // dBias
                        return new DnnCPInstruction(in1, in2, in3, in4, in5, 
in6, null, null, out, out2, out3, null, null, opcode, str, 0);
                }
+               else if (opcode.equalsIgnoreCase("lstm")) {
+                       InstructionUtils.checkNumFields(parts, 8);
+                       CPOperand in1 = new CPOperand(parts[1]); // X
+                       CPOperand in2 = new CPOperand(parts[2]); // W
+                       CPOperand in3 = new CPOperand(parts[3]); // b
+                       CPOperand in4 = new CPOperand(parts[4]); // out0
+                       CPOperand in5 = new CPOperand(parts[5]); // c0
+                       CPOperand in6 = new CPOperand(parts[6]); // return_seq
+                       CPOperand out = new CPOperand(parts[7]);  // out
+                       CPOperand out2 = new CPOperand(parts[8]); // c
+                       return new DnnCPInstruction(in1, in2, in3, in4, in5, 
in6, null, null, out, out2, null, null, null, opcode, str, 0);
+               }
                else {
                        throw new DMLRuntimeException("Unknown opcode while 
parsing a DnnCPInstruction: " + str);
                }
@@ -271,6 +283,51 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        aL.get(index).getValueType(), 
aL.get(index).isLiteral()).getLongValue();
        }
        
+       public void processLstmInstruction(ExecutionContext ec) {
+               MatrixBlock X = ec.getMatrixInput(input1.getName(), 
getExtendedOpcode());
+               MatrixBlock W = ec.getMatrixInput(_in2.getName(), 
getExtendedOpcode());
+               MatrixBlock b = ec.getMatrixInput(_in3.getName(), 
getExtendedOpcode());
+               MatrixBlock out0 = ec.getMatrixInput(_in4.getName(), 
getExtendedOpcode());
+               MatrixBlock c0 = ec.getMatrixInput(_in5.getName(), 
getExtendedOpcode());
+               boolean return_seq = ec.getScalarInput(_in6.getName(), 
_in6.getValueType(), _in6.isLiteral()).getBooleanValue();
+               
+               int N = X.getNumRows();
+               int TD = X.getNumColumns();
+               int DPlusM = W.getNumRows();
+               int M = W.getNumColumns() / 4;
+               if(b.getNumRows() != 1 || b.getNumColumns() != M*4) {
+                       throw new DMLRuntimeException("Incorrect dimensions of 
bias in lstm instruction. Expected [1, " + (M*4) + "], "
+                                       + "but found [" + b.getNumRows() + "," 
+ b.getNumColumns() + "]");
+               }
+               if(out0.getNumRows() != N || out0.getNumColumns() != M) {
+                       throw new DMLRuntimeException("Incorrect dimensions of 
out0 in lstm instruction. Expected [" + N + ", " + M + "], "
+                                       + "but found [" + out0.getNumRows() + 
"," + out0.getNumColumns() + "]");
+               }
+               if(c0.getNumRows() != N || c0.getNumColumns() != M) {
+                       throw new DMLRuntimeException("Incorrect dimensions of 
c0 in lstm instruction. Expected [" + N + ", " + M + "], "
+                                       + "but found [" + out0.getNumRows() + 
"," + out0.getNumColumns() + "]");
+               }
+               int D = DPlusM - M;
+               int T = TD / D;
+               
+               MatrixBlock out = new MatrixBlock(N, return_seq ? (T*M) : M, 
false);
+               MatrixBlock c = new MatrixBlock(N, M, false);
+               
+               LibMatrixDNN.lstm(X, W, b, out0, c0, 
+                               return_seq, N, T, D, M,
+                               out,  c, null, null, null,
+                               _numThreads);
+               
+               // release inputs/outputs
+               ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in4.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
+               ec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
+               ec.setMatrixOutput(_out2.getName(), c, getExtendedOpcode());
+       }
+       
        public void processReluBackwardInstruction(ExecutionContext ec) {
                // (X > 0) * dout
                MatrixBlock input = ec.getMatrixInput(input1.getName(), 
getExtendedOpcode());
@@ -436,6 +493,10 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        processBatchNorm2dBackwardInstruction(ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("lstm")) {
+                       processLstmInstruction(ec);
+                       return;
+               }
                
                // acquire inputs
                MatrixBlock outputBlock = null;
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 4569dbe..179b5d3 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -31,10 +31,22 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
+import org.apache.sysml.runtime.functionobjects.Builtin;
 import org.apache.sysml.runtime.functionobjects.KahanPlus;
+import org.apache.sysml.runtime.functionobjects.Multiply;
+import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.ValueFunction;
+import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysml.runtime.instructions.cp.KahanObject;
+import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
 import org.apache.sysml.runtime.util.CommonThreadPool;
 import org.apache.sysml.runtime.util.DnnUtils;
+import org.apache.sysml.runtime.util.IndexRange;
 
 /*
  * This class allows users to invoke deep learning related operations 
@@ -262,6 +274,79 @@ public class LibMatrixDNN {
                outputBlock.examSparsity();
        }
        
+       private static MatrixBlock matmult(MatrixBlock matBlock1, MatrixBlock 
matBlock2, int numThreads) {
+               AggregateBinaryOperator ab_op = new 
AggregateBinaryOperator(Multiply.getMultiplyFnObject(), 
+                               new AggregateOperator(0, 
Plus.getPlusFnObject()), numThreads);
+               MatrixBlock main = (matBlock2 instanceof CompressedMatrixBlock) 
? matBlock2 : matBlock1;
+               MatrixBlock ret = main.aggregateBinaryOperations(matBlock1, 
matBlock2, new MatrixBlock(), ab_op);
+               return ret;
+       }
+       
+       private static MatrixBlock add(MatrixBlock matBlock1, MatrixBlock 
matBlock2) {
+               return (MatrixBlock) matBlock1.binaryOperations(new 
BinaryOperator(Plus.getPlusFnObject()), matBlock2, new MatrixBlock());
+       }
+       private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock 
matBlock2) {
+               return (MatrixBlock) matBlock1.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), matBlock2, new MatrixBlock());
+       }
+       
+       // sigmoid(0)*c_prev + sigmoid(0)*tanh(0);
+       
+       private static Builtin sigmoidOp = 
Builtin.getBuiltinFnObject(BuiltinCode.SIGMOID);
+       private static Builtin tanhOp = 
Builtin.getBuiltinFnObject(BuiltinCode.TANH);
+       private static MatrixBlock sigmoid(MatrixBlock in, int numThreads, 
boolean inPlace) {
+               return (MatrixBlock) in.unaryOperations(new 
UnaryOperator(sigmoidOp, numThreads, inPlace), new MatrixBlock());
+       }
+       private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean 
inPlace) {
+               return (MatrixBlock) in.unaryOperations(new 
UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock());
+       }
+       
+       public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, 
MatrixBlock out0, MatrixBlock c0, 
+                       boolean return_seq, int N, int T, int D, int M,
+                       MatrixBlock out, MatrixBlock c, // output 
+                       MatrixBlock cache_out, MatrixBlock cache_c, MatrixBlock 
cache_ifog, // if null, the cache values are not computed
+                       int numThreads) {
+               MatrixBlock out_prev = out0;
+               MatrixBlock c_prev = c0;
+               
+               MatrixBlock W1 = W.slice(0, D-1);
+               MatrixBlock W2 = W.slice(D, D+M-1);
+               MatrixBlock c_t = null;
+               MatrixBlock out_t = null;
+               for(int t = 1; t <= T; t++) {
+                       MatrixBlock X_t = X.slice(0, N-1, (t-1)*D, t*D-1, new 
MatrixBlock());
+                       MatrixBlock ifog_raw = add(add(matmult(X_t, W1, 
numThreads), matmult(out_prev, W2, numThreads)), b);
+                       MatrixBlock i = ifog_raw.slice(0, N-1, 0, M-1, new 
MatrixBlock());
+                       MatrixBlock f = ifog_raw.slice(0, N-1, M, 2*M-1, new 
MatrixBlock());
+                       MatrixBlock o = ifog_raw.slice(0, N-1, 2*M, 3*M-1, new 
MatrixBlock());
+                       MatrixBlock g = ifog_raw.slice(0, N-1, 3*M, 4*M-1, new 
MatrixBlock());
+                       i = sigmoid(i, numThreads, true);
+                       f = sigmoid(f, numThreads, true);
+                       o = sigmoid(o, numThreads, true);
+                       g = tanh(g, numThreads, true);
+                       // c_t = f*c_prev + i*g
+                       c_t = add(multiply(f, c_prev) , multiply(i, g));
+                       // out_t = o*tanh(c)
+                       out_t = multiply(o, tanh(c_t, numThreads, false));
+                       if(return_seq) {
+                               out = out.leftIndexingOperations(out_t, 0, N-1, 
(t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE);
+                       }
+                       out_prev = out_t;
+                       c_prev = c_t;
+                       
+                       // TODO: Add this when implementing lstm_backward
+//                     cache_out[t,] = matrix(out_t, rows=1, cols=N*M)  # 
reshape
+//                 cache_c[t,] = matrix(c, rows=1, cols=N*M)  # reshape
+//                 cache_ifog[t,] = matrix(cbind(ifo, g), rows=1, cols=N*4*M)  
# reshape
+               }
+               if(out_t != null && !return_seq)
+                       out.copy(out_t);
+               if(c_t != null)
+                       c.copy(c_t);
+               else
+                       c.copy(c0);
+               
+       }
+       
        /**
         * This method computes the backpropagation errors for previous layer 
of relu operation
         * 
@@ -574,6 +659,9 @@ public class LibMatrixDNN {
                        if(inputArr != null) {
                                System.arraycopy(inputArr, 0, output, 0, 
inputArr.length);
                        }
+                       else {
+                               Arrays.fill(output, 0);
+                       }
                }
        }
        
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java 
b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
new file mode 100644
index 0000000..3aa37ad
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.gpu;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction;
+import 
org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction.LstmOperator;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests lstm builtin function
+ */
+public class LstmCPUTest extends GPUTests {
+
+       private final static String TEST_NAME = "LstmTests";
+       private final int seed = 42;
+       
+       private final static String builtinDML = 
"\"nn/layers/lstm_staging.dml\"";
+       private final static String nnDML = "\"nn/layers/lstm.dml\"";
+
+       @Override
+       public void setUp() {
+               super.setUp();
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_DIR, TEST_NAME);
+               getAndLoadTestConfiguration(TEST_NAME);
+       }
+       
+       @Test
+       public void testLstmForward9() {
+               testLstmCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9);
+       }
+       
+       @Test
+       public void testLstmForward10() {
+               testLstmCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9);
+       }
+       
+       @Test
+       public void testLstmForward11() {
+               testLstmCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9);
+       }
+       
+       @Test
+       public void testLstmForward12() {
+               testLstmCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9);
+       }
+       
+       public void testLstmCuDNNWithNNLayer(int N, int T, int D, int M, String 
returnSequences, double sparsity) {
+               String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
+                               + "[output, c] = lstm::forward(x, w, b, " + 
returnSequences + ", out0, c0)";
+               String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
+                               + "[output, c, cache_out, cache_c, cache_ifog] 
= lstm::forward(x, w, b, " 
+                               + T + ", " + D + ", " + returnSequences + ", 
out0, c0)";
+               
+               HashMap<String, Object> inputs = new HashMap<>();
+               inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, 
sparsity, seed));
+               inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, 
sparsity, seed));
+               inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, 
sparsity, seed));
+               inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+               List<String> outputs = Arrays.asList("output", "c");
+               List<Object> outGPUWithCuDNN = null;
+               List<Object> outCPUWithNN = null;
+               synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
+                       try {
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.CUDNN;
+                               outGPUWithCuDNN = runOnCPU(spark, scriptStr1, 
inputs, outputs);
+                               outCPUWithNN = runOnCPU(spark, scriptStr2, 
inputs, outputs);
+                       }
+                       finally {
+                               DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.NONE;
+                       }
+               }
+               assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
+               assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
+       }
+       
+       
+       
+//     @Test
+//     public void testLstmBackward7() {
+//             testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9, 0.9);
+//     }
+//     
+//     @Test
+//     public void testLstmBackward8() {
+//             testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9, 0.9);
+//     }
+//     
+//     @Test
+//     public void testLstmBackward9() {
+//             testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9, 
0.9);
+//     }
+//     
+//     @Test
+//     public void testLstmBackward10() {
+//             testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9, 
0.9);
+//     }
+//     
+//     
+//     public void testLstmBackwardCuDNNWithNNLayer(int N, int T, int D, int 
M, String returnSequences, double sparsity,
+//                     double weightSparsity) {
+//             boolean returnSequences1 = returnSequences.equals("TRUE");
+//             
+//             String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
+//                             + "[dX, dW, db, dout0, dc0] = 
lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+//             String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
+//                             + "[output, c, cache_out, cache_c, cache_ifog] 
= lstm::forward(x, w, b, " 
+//                             + T + ", " + D + ", " + returnSequences + ", 
out0, c0); \n"
+//                             + "[dX, dW, db, dout0, dc0] = 
lstm::backward(dout, dc, x, w, b, " 
+//                             + T + ", " + D + ", " + returnSequences + ", 
out0, c0, cache_out, cache_c, cache_ifog);";
+//             
+//             HashMap<String, Object> inputs = new HashMap<>();
+//             inputs.put("dout", generateInputMatrix(spark, N, 
returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
+//             inputs.put("dc", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+//             inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, 
sparsity, seed));
+//             inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, 
weightSparsity, seed));
+//             inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, 
sparsity, seed));
+//             inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+//             inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, 
sparsity, seed));
+//             List<String> outputs = Arrays.asList("dX", "dW", "db", "dout0", 
"dc0");
+//             List<Object> outGPUWithCuDNN = null;
+//             List<Object> outCPUWithNN = null;
+//             synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
+//                     try {
+//                             DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.CUDNN;
+//                             outGPUWithCuDNN = runOnCPU(spark, scriptStr1, 
inputs, outputs);
+//                     }
+//                     finally {
+//                             DnnGPUInstruction.FORCED_LSTM_OP = 
LstmOperator.NONE;
+//                     }
+//                     outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, 
outputs);
+//             }
+//             assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
+//             assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
+//             assertEqualObjects(outGPUWithCuDNN.get(2), outCPUWithNN.get(2));
+//             assertEqualObjects(outGPUWithCuDNN.get(3), outCPUWithNN.get(3));
+//             assertEqualObjects(outGPUWithCuDNN.get(4), outCPUWithNN.get(4));
+//     }
+}

Reply via email to