This is an automated email from the ASF dual-hosted git repository.
niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new 592f1b0 [SYSTEMML-540] Added an initial CP operator for lstm builtin function
592f1b0 is described below
commit 592f1b0e9a5566195d0e73fede318e2a269bb4a0
Author: Niketan Pansare <[email protected]>
AuthorDate: Sat Feb 23 09:05:11 2019 -0800
[SYSTEMML-540] Added an initial CP operator for lstm builtin function
- This operator relies on existing slice, left indexing, binary and matrix multiplication operators.
- We can later create a fused implementation, especially for the lstm activation, to avoid unnecessary copies in the inner loop.
- To exploit sparsity in the input data and avoid unnecessary sparse-to-dense conversion (as out_prev is often dense), we perform two matrix multiplications followed by a binary addition (see the sketch below).
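- For illustration, a minimal sketch of that pre-activation step is given below, written against the same MatrixBlock operators this patch uses (slice, aggregateBinaryOperations, binaryOperations); the class and method names (LstmPreActivationSketch, preActivations) are hypothetical and not part of this commit.

    import org.apache.sysml.runtime.functionobjects.Multiply;
    import org.apache.sysml.runtime.functionobjects.Plus;
    import org.apache.sysml.runtime.matrix.data.MatrixBlock;
    import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
    import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
    import org.apache.sysml.runtime.matrix.operators.BinaryOperator;

    public class LstmPreActivationSketch {
        // Computes ifog_raw = X_t %*% W1 + out_prev %*% W2 + b for a single time step.
        // Keeping the two multiplications separate lets a sparse X_t stay sparse,
        // instead of being densified by a cbind with the (typically dense) out_prev.
        public static MatrixBlock preActivations(MatrixBlock X_t, MatrixBlock out_prev,
                MatrixBlock W, MatrixBlock b, int D, int M, int numThreads) {
            AggregateBinaryOperator mm = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(),
                new AggregateOperator(0, Plus.getPlusFnObject()), numThreads);
            BinaryOperator plus = new BinaryOperator(Plus.getPlusFnObject());
            MatrixBlock W1 = W.slice(0, D-1);     // rows 0..D-1: weights applied to the input X_t
            MatrixBlock W2 = W.slice(D, D+M-1);   // rows D..D+M-1: weights applied to the recurrent state
            MatrixBlock xw = X_t.aggregateBinaryOperations(X_t, W1, new MatrixBlock(), mm);
            MatrixBlock hw = out_prev.aggregateBinaryOperations(out_prev, W2, new MatrixBlock(), mm);
            MatrixBlock sum = (MatrixBlock) xw.binaryOperations(plus, hw, new MatrixBlock());
            return (MatrixBlock) sum.binaryOperations(plus, b, new MatrixBlock());
        }
    }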
---
.../java/org/apache/sysml/hops/FunctionOp.java | 42 +++++-
.../runtime/instructions/CPInstructionParser.java | 2 +
.../runtime/instructions/cp/DnnCPInstruction.java | 61 ++++++++
.../sysml/runtime/matrix/data/LibMatrixDNN.java | 88 +++++++++++
.../org/apache/sysml/test/gpu/LstmCPUTest.java | 162 +++++++++++++++++++++
5 files changed, 347 insertions(+), 8 deletions(-)
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 5f177bd..66ce478 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -274,19 +274,45 @@ public class FunctionOp extends Hop
{
checkAndSetForcedPlatform();
- if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) {
- boolean isBuiltinFunction = isBuiltinFunction();
+ if(getFunctionType() == FunctionType.MULTIRETURN_BUILTIN && isBuiltinFunction() &&
+ (getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) {
+
+ if(getFunctionName().equalsIgnoreCase("lstm_backward")) {
+ if(!ConfigurationManager.isGPU())
+ throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
+ _etype = ExecType.GPU;
+ }
+
+ ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+
+ if( _etypeForced != null ) {
+ _etype = _etypeForced;
+ }
+ else {
+ if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
+ _etype = findExecTypeByMemEstimate();
+ }
+ else {
+ _etype = REMOTE;
+ }
+
+ //check for valid CP dimensions and matrix size
+ checkAndSetInvalidCPDimsAndSize();
+ }
+
+ // Since lstm builtin functions are not supported on Spark
+ _etype = _etype == REMOTE ? ExecType.CP : _etype;
+
+ //mark for recompile (forever)
+ setRequiresRecompileIfNecessary();
+ }
+ else if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) {
// check if there is sufficient memory to execute this function
- if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("transformencode") ) {
+ if(isBuiltinFunction() && getFunctionName().equalsIgnoreCase("transformencode") ) {
_etype = ((_etypeForced==ExecType.SPARK
|| (getMemEstimate() >= OptimizerUtils.getLocalMemBudget()
&& OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
}
- else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) {
- if(!ConfigurationManager.isGPU())
- throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
- _etype = ExecType.GPU;
- }
else {
// Since the memory estimate is only conservative, do not throw
// exception if the estimated memory is larger than the budget
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index ea01dc2..6e253d8 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -247,6 +247,8 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "conv2d_bias_add" , CPType.Dnn);
String2CPInstructionType.put( "conv2d_backward_filter" , CPType.Dnn);
String2CPInstructionType.put( "conv2d_backward_data" , CPType.Dnn);
+ String2CPInstructionType.put( "lstm", CPType.Dnn);
+ String2CPInstructionType.put( "lstm_backward", CPType.Dnn);
String2CPInstructionType.put( "bias_add" , CPType.Dnn);
String2CPInstructionType.put( "bias_multiply" , CPType.Dnn);
String2CPInstructionType.put( "batch_norm2d", CPType.Dnn);
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
index 7bed33f..4043908 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
@@ -261,6 +261,18 @@ public class DnnCPInstruction extends UnaryCPInstruction {
CPOperand out3 = new CPOperand(parts[9]); // dBias
return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
}
+ else if (opcode.equalsIgnoreCase("lstm")) {
+ InstructionUtils.checkNumFields(parts, 8);
+ CPOperand in1 = new CPOperand(parts[1]); // X
+ CPOperand in2 = new CPOperand(parts[2]); // W
+ CPOperand in3 = new CPOperand(parts[3]); // b
+ CPOperand in4 = new CPOperand(parts[4]); // out0
+ CPOperand in5 = new CPOperand(parts[5]); // c0
+ CPOperand in6 = new CPOperand(parts[6]); // return_seq
+ CPOperand out = new CPOperand(parts[7]); // out
+ CPOperand out2 = new CPOperand(parts[8]); // c
+ return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, null, null, null, opcode, str, 0);
+ }
else {
throw new DMLRuntimeException("Unknown opcode while
parsing a DnnCPInstruction: " + str);
}
@@ -271,6 +283,51 @@ public class DnnCPInstruction extends UnaryCPInstruction {
aL.get(index).getValueType(), aL.get(index).isLiteral()).getLongValue();
}
+ public void processLstmInstruction(ExecutionContext ec) {
+ MatrixBlock X = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+ MatrixBlock W = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
+ MatrixBlock b = ec.getMatrixInput(_in3.getName(), getExtendedOpcode());
+ MatrixBlock out0 = ec.getMatrixInput(_in4.getName(), getExtendedOpcode());
+ MatrixBlock c0 = ec.getMatrixInput(_in5.getName(), getExtendedOpcode());
+ boolean return_seq = ec.getScalarInput(_in6.getName(), _in6.getValueType(), _in6.isLiteral()).getBooleanValue();
+
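+ // Expected shapes: X [N x T*D], W [(D+M) x 4*M], b [1 x 4*M], out0 and c0 [N x M]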
+ int N = X.getNumRows();
+ int TD = X.getNumColumns();
+ int DPlusM = W.getNumRows();
+ int M = W.getNumColumns() / 4;
+ if(b.getNumRows() != 1 || b.getNumColumns() != M*4) {
+ throw new DMLRuntimeException("Incorrect dimensions of
bias in lstm instruction. Expected [1, " + (M*4) + "], "
+ + "but found [" + b.getNumRows() + ","
+ b.getNumColumns() + "]");
+ }
+ if(out0.getNumRows() != N || out0.getNumColumns() != M) {
+ throw new DMLRuntimeException("Incorrect dimensions of
out0 in lstm instruction. Expected [" + N + ", " + M + "], "
+ + "but found [" + out0.getNumRows() +
"," + out0.getNumColumns() + "]");
+ }
+ if(c0.getNumRows() != N || c0.getNumColumns() != M) {
+ throw new DMLRuntimeException("Incorrect dimensions of
c0 in lstm instruction. Expected [" + N + ", " + M + "], "
+ + "but found [" + out0.getNumRows() +
"," + out0.getNumColumns() + "]");
+ }
+ int D = DPlusM - M;
+ int T = TD / D;
+
+ MatrixBlock out = new MatrixBlock(N, return_seq ? (T*M) : M, false);
+ MatrixBlock c = new MatrixBlock(N, M, false);
+
+ LibMatrixDNN.lstm(X, W, b, out0, c0,
+ return_seq, N, T, D, M,
+ out, c, null, null, null,
+ _numThreads);
+
+ // release inputs/outputs
+ ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+ ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+ ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
+ ec.releaseMatrixInput(_in4.getName(), getExtendedOpcode());
+ ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
+ ec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
+ ec.setMatrixOutput(_out2.getName(), c, getExtendedOpcode());
+ }
+
public void processReluBackwardInstruction(ExecutionContext ec) {
// (X > 0) * dout
MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
@@ -436,6 +493,10 @@ public class DnnCPInstruction extends UnaryCPInstruction {
processBatchNorm2dBackwardInstruction(ec);
return;
}
+ else if (instOpcode.equalsIgnoreCase("lstm")) {
+ processLstmInstruction(ec);
+ return;
+ }
// acquire inputs
MatrixBlock outputBlock = null;
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 4569dbe..179b5d3 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -31,10 +31,22 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
+import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
+import org.apache.sysml.runtime.functionobjects.Multiply;
+import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.ValueFunction;
+import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
+import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
import org.apache.sysml.runtime.util.CommonThreadPool;
import org.apache.sysml.runtime.util.DnnUtils;
+import org.apache.sysml.runtime.util.IndexRange;
/*
* This class allows users to invoke deep learning related operations
@@ -262,6 +274,79 @@ public class LibMatrixDNN {
outputBlock.examSparsity();
}
+ private static MatrixBlock matmult(MatrixBlock matBlock1, MatrixBlock matBlock2, int numThreads) {
+ AggregateBinaryOperator ab_op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(),
+ new AggregateOperator(0, Plus.getPlusFnObject()), numThreads);
+ MatrixBlock main = (matBlock2 instanceof CompressedMatrixBlock) ? matBlock2 : matBlock1;
+ MatrixBlock ret = main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op);
+ return ret;
+ }
+
+ private static MatrixBlock add(MatrixBlock matBlock1, MatrixBlock matBlock2) {
+ return (MatrixBlock) matBlock1.binaryOperations(new BinaryOperator(Plus.getPlusFnObject()), matBlock2, new MatrixBlock());
+ }
+ private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2) {
+ return (MatrixBlock) matBlock1.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), matBlock2, new MatrixBlock());
+ }
+
+ // sigmoid(0)*c_prev + sigmoid(0)*tanh(0);
+
+ private static Builtin sigmoidOp = Builtin.getBuiltinFnObject(BuiltinCode.SIGMOID);
+ private static Builtin tanhOp = Builtin.getBuiltinFnObject(BuiltinCode.TANH);
+ private static MatrixBlock sigmoid(MatrixBlock in, int numThreads, boolean inPlace) {
+ return (MatrixBlock) in.unaryOperations(new UnaryOperator(sigmoidOp, numThreads, inPlace), new MatrixBlock());
+ }
+ private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean inPlace) {
+ return (MatrixBlock) in.unaryOperations(new UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock());
+ }
+
+ public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0,
+ boolean return_seq, int N, int T, int D, int M,
+ MatrixBlock out, MatrixBlock c, // output
+ MatrixBlock cache_out, MatrixBlock cache_c, MatrixBlock cache_ifog, // if null, the cache values are not computed
+ int numThreads) {
+ MatrixBlock out_prev = out0;
+ MatrixBlock c_prev = c0;
+
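+ // W stacks the input weights (first D rows) on top of the recurrent weights (next M rows), each mapping to the 4*M gate columns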
+ MatrixBlock W1 = W.slice(0, D-1);
+ MatrixBlock W2 = W.slice(D, D+M-1);
+ MatrixBlock c_t = null;
+ MatrixBlock out_t = null;
+ for(int t = 1; t <= T; t++) {
+ MatrixBlock X_t = X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock());
+ MatrixBlock ifog_raw = add(add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads)), b);
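+ // Slice the [N x 4*M] pre-activations into the input (i), forget (f), output (o) and cell candidate (g) gates, each N x M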
+ MatrixBlock i = ifog_raw.slice(0, N-1, 0, M-1, new MatrixBlock());
+ MatrixBlock f = ifog_raw.slice(0, N-1, M, 2*M-1, new MatrixBlock());
+ MatrixBlock o = ifog_raw.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
+ MatrixBlock g = ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock());
+ i = sigmoid(i, numThreads, true);
+ f = sigmoid(f, numThreads, true);
+ o = sigmoid(o, numThreads, true);
+ g = tanh(g, numThreads, true);
+ // c_t = f*c_prev + i*g
+ c_t = add(multiply(f, c_prev) , multiply(i, g));
+ // out_t = o*tanh(c)
+ out_t = multiply(o, tanh(c_t, numThreads, false));
+ if(return_seq) {
+ out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE);
+ }
+ out_prev = out_t;
+ c_prev = c_t;
+
+ // TODO: Add this when implementing lstm_backward
+// cache_out[t,] = matrix(out_t, rows=1, cols=N*M) # reshape
+// cache_c[t,] = matrix(c, rows=1, cols=N*M) # reshape
+// cache_ifog[t,] = matrix(cbind(ifo, g), rows=1, cols=N*4*M) # reshape
+ }
+ if(out_t != null && !return_seq)
+ out.copy(out_t);
+ if(c_t != null)
+ c.copy(c_t);
+ else
+ c.copy(c0);
+
+ }
+
/**
* This method computes the backpropagation errors for previous layer
of relu operation
*
@@ -574,6 +659,9 @@ public class LibMatrixDNN {
if(inputArr != null) {
System.arraycopy(inputArr, 0, output, 0, inputArr.length);
}
+ else {
+ Arrays.fill(output, 0);
+ }
}
}
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
new file mode 100644
index 0000000..3aa37ad
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.gpu;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction;
+import org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction.LstmOperator;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests lstm builtin function
+ */
+public class LstmCPUTest extends GPUTests {
+
+ private final static String TEST_NAME = "LstmTests";
+ private final int seed = 42;
+
+ private final static String builtinDML = "\"nn/layers/lstm_staging.dml\"";
+ private final static String nnDML = "\"nn/layers/lstm.dml\"";
+
+ @Override
+ public void setUp() {
+ super.setUp();
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_DIR, TEST_NAME);
+ getAndLoadTestConfiguration(TEST_NAME);
+ }
+
+ @Test
+ public void testLstmForward9() {
+ testLstmCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9);
+ }
+
+ @Test
+ public void testLstmForward10() {
+ testLstmCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9);
+ }
+
+ @Test
+ public void testLstmForward11() {
+ testLstmCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9);
+ }
+
+ @Test
+ public void testLstmForward12() {
+ testLstmCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9);
+ }
+
+ public void testLstmCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity) {
+ String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
+ + "[output, c] = lstm::forward(x, w, b, " +
returnSequences + ", out0, c0)";
+ String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
+ + "[output, c, cache_out, cache_c, cache_ifog]
= lstm::forward(x, w, b, "
+ + T + ", " + D + ", " + returnSequences + ",
out0, c0)";
+
+ HashMap<String, Object> inputs = new HashMap<>();
+ inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10,
sparsity, seed));
+ inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10,
sparsity, seed));
+ inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10,
sparsity, seed));
+ inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10,
sparsity, seed));
+ inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10,
sparsity, seed));
+ List<String> outputs = Arrays.asList("output", "c");
+ List<Object> outGPUWithCuDNN = null;
+ List<Object> outCPUWithNN = null;
+ synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
+ try {
+ DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.CUDNN;
+ outGPUWithCuDNN = runOnCPU(spark, scriptStr1, inputs, outputs);
+ outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, outputs);
+ }
+ finally {
+ DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.NONE;
+ }
+ }
+ assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
+ assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
+ }
+
+
+
+// @Test
+// public void testLstmBackward7() {
+// testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9, 0.9);
+// }
+//
+// @Test
+// public void testLstmBackward8() {
+// testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9, 0.9);
+// }
+//
+// @Test
+// public void testLstmBackward9() {
+// testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9,
0.9);
+// }
+//
+// @Test
+// public void testLstmBackward10() {
+// testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9,
0.9);
+// }
+//
+//
+// public void testLstmBackwardCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity,
+// double weightSparsity) {
+// boolean returnSequences1 = returnSequences.equals("TRUE");
+//
+// String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
+// + "[dX, dW, db, dout0, dc0] =
lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+// String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
+// + "[output, c, cache_out, cache_c, cache_ifog]
= lstm::forward(x, w, b, "
+// + T + ", " + D + ", " + returnSequences + ",
out0, c0); \n"
+// + "[dX, dW, db, dout0, dc0] =
lstm::backward(dout, dc, x, w, b, "
+// + T + ", " + D + ", " + returnSequences + ",
out0, c0, cache_out, cache_c, cache_ifog);";
+//
+// HashMap<String, Object> inputs = new HashMap<>();
+// inputs.put("dout", generateInputMatrix(spark, N,
returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
+// inputs.put("dc", generateInputMatrix(spark, N, M, 0, 10,
sparsity, seed));
+// inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10,
sparsity, seed));
+// inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10,
weightSparsity, seed));
+// inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10,
sparsity, seed));
+// inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10,
sparsity, seed));
+// inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10,
sparsity, seed));
+// List<String> outputs = Arrays.asList("dX", "dW", "db", "dout0",
"dc0");
+// List<Object> outGPUWithCuDNN = null;
+// List<Object> outCPUWithNN = null;
+// synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
+// try {
+// DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.CUDNN;
+// outGPUWithCuDNN = runOnCPU(spark, scriptStr1, inputs, outputs);
+// }
+// finally {
+// DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.NONE;
+// }
+// outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, outputs);
+// }
+// assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
+// assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
+// assertEqualObjects(outGPUWithCuDNN.get(2), outCPUWithNN.get(2));
+// assertEqualObjects(outGPUWithCuDNN.get(3), outCPUWithNN.get(3));
+// assertEqualObjects(outGPUWithCuDNN.get(4), outCPUWithNN.get(4));
+// }
+}