This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
     new 592f1b0  [SYSTEMML-540] Added an initial CP operator for lstm builtin function
592f1b0 is described below

commit 592f1b0e9a5566195d0e73fede318e2a269bb4a0
Author: Niketan Pansare <npan...@us.ibm.com>
AuthorDate: Sat Feb 23 09:05:11 2019 -0800

    [SYSTEMML-540] Added an initial CP operator for lstm builtin function
    
    - This operator relies on the existing slice, left-indexing, binary, and
      matrix multiplication operators.
    - We can later create a fused implementation, especially for the lstm
      activations, to avoid unnecessary copies in the inner loop.
    - To exploit sparsity in the input data and avoid unnecessary
      sparse-to-dense conversion (as out_prev is often dense), we perform
      two matrix multiplications followed by a binary addition.
---
 .../java/org/apache/sysml/hops/FunctionOp.java     |  42 +++++-
 .../runtime/instructions/CPInstructionParser.java  |   2 +
 .../runtime/instructions/cp/DnnCPInstruction.java  |  61 ++++++++
 .../sysml/runtime/matrix/data/LibMatrixDNN.java    |  88 +++++++++++
 .../org/apache/sysml/test/gpu/LstmCPUTest.java     | 162 +++++++++++++++++++++
 5 files changed, 347 insertions(+), 8 deletions(-)
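The third bullet is the core design choice: rather than materializing cbind(X_t, out_prev) and multiplying once against W, the operator slices W row-wise into W1 (first D rows) and W2 (last M rows) and computes X_t %*% W1 + out_prev %*% W2 + b, so a sparse X_t never has to be merged with the typically dense out_prev. A minimal dense-array sketch of that decomposition (class and array layout are illustrative only; the patch itself works on MatrixBlock operators, see LibMatrixDNN below):

public class IfogSketch {
	// X_t: N x D, out_prev: N x M, W: (D+M) x 4M, b: length 4M (row vector)
	static double[][] ifogRaw(double[][] Xt, double[][] outPrev, double[][] W, double[] b,
			int N, int D, int M) {
		double[][] ret = new double[N][4 * M];
		for (int r = 0; r < N; r++) {
			for (int col = 0; col < 4 * M; col++) {
				double sum = b[col];                      // + b, broadcast over rows
				for (int k = 0; k < D; k++)               // X_t %*% W[0:D-1, ]   (W1)
					sum += Xt[r][k] * W[k][col];
				for (int k = 0; k < M; k++)               // out_prev %*% W[D:D+M-1, ] (W2)
					sum += outPrev[r][k] * W[D + k][col];
				ret[r][col] = sum;
			}
		}
		return ret;
	}
}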
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 5f177bd..66ce478 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -274,19 +274,45 @@ public class FunctionOp extends Hop {
 		checkAndSetForcedPlatform();
 		
-		if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) {
-			boolean isBuiltinFunction = isBuiltinFunction();
+		if(getFunctionType() == FunctionType.MULTIRETURN_BUILTIN && isBuiltinFunction() && 
+			(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) {
+			
+			if(getFunctionName().equalsIgnoreCase("lstm_backward")) {
+				if(!ConfigurationManager.isGPU())
+					throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
+				_etype = ExecType.GPU;
+			}
+			
+			ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+			
+			if( _etypeForced != null ) {
+				_etype = _etypeForced;
+			}
+			else {
+				if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
+					_etype = findExecTypeByMemEstimate();
+				}
+				else {
+					_etype = REMOTE;
+				}
+				
+				//check for valid CP dimensions and matrix size
+				checkAndSetInvalidCPDimsAndSize();
+			}
+			
+			// Since lstm builtin functions are not supported on Spark
+			_etype = _etype == REMOTE ? ExecType.CP : _etype;
+			
+			//mark for recompile (forever)
+			setRequiresRecompileIfNecessary();
+		}
+		else if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) {
 			// check if there is sufficient memory to execute this function
-			if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("transformencode") ) {
+			if(isBuiltinFunction() && getFunctionName().equalsIgnoreCase("transformencode") ) {
 				_etype = ((_etypeForced==ExecType.SPARK 
 					|| (getMemEstimate() >= OptimizerUtils.getLocalMemBudget()
 						&& OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
 			}
-			else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) {
-				if(!ConfigurationManager.isGPU())
-					throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
-				_etype = ExecType.GPU;
-			}
 			else {
 				// Since the memory estimate is only conservative, do not throw
 				// exception if the estimated memory is larger than the budget
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index ea01dc2..6e253d8 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -247,6 +247,8 @@ public class CPInstructionParser extends InstructionParser
 		String2CPInstructionType.put( "conv2d_bias_add" , CPType.Dnn);
 		String2CPInstructionType.put( "conv2d_backward_filter" , CPType.Dnn);
 		String2CPInstructionType.put( "conv2d_backward_data" , CPType.Dnn);
+		String2CPInstructionType.put( "lstm", CPType.Dnn);
+		String2CPInstructionType.put( "lstm_backward", CPType.Dnn);
 		String2CPInstructionType.put( "bias_add" , CPType.Dnn);
 		String2CPInstructionType.put( "bias_multiply" , CPType.Dnn);
 		String2CPInstructionType.put( "batch_norm2d", CPType.Dnn);
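The CP instruction added in the next file recovers every LSTM dimension from the operand shapes alone: N and T*D from X, 4*M from the columns of W, then D = nrow(W) - M and T = (T*D)/D. A standalone check, instantiated with the sizes used by testLstmForward11 in the new test file (the class name here is hypothetical):

public class LstmShapes {
	public static void main(String[] args) {
		int N = 20;              // X.getNumRows()
		int TD = 13 * 50;        // X.getNumColumns() = T*D = 650
		int DPlusM = 50 + 10;    // W.getNumRows() = D+M = 60
		int M = (4 * 10) / 4;    // W.getNumColumns() / 4 = 10
		int D = DPlusM - M;      // 60 - 10 = 50
		int T = TD / D;          // 650 / 50 = 13
		System.out.println("N=" + N + ", T=" + T + ", D=" + D + ", M=" + M);
	}
}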
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
index 7bed33f..4043908 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
@@ -261,6 +261,18 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 			CPOperand out3 = new CPOperand(parts[9]); // dBias
 			return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
 		}
+		else if (opcode.equalsIgnoreCase("lstm")) {
+			InstructionUtils.checkNumFields(parts, 8);
+			CPOperand in1 = new CPOperand(parts[1]); // X
+			CPOperand in2 = new CPOperand(parts[2]); // W
+			CPOperand in3 = new CPOperand(parts[3]); // b
+			CPOperand in4 = new CPOperand(parts[4]); // out0
+			CPOperand in5 = new CPOperand(parts[5]); // c0
+			CPOperand in6 = new CPOperand(parts[6]); // return_seq
+			CPOperand out = new CPOperand(parts[7]); // out
+			CPOperand out2 = new CPOperand(parts[8]); // c
+			return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, null, null, null, opcode, str, 0);
+		}
 		else {
 			throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str);
 		}
@@ -271,6 +283,51 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 			aL.get(index).getValueType(), aL.get(index).isLiteral()).getLongValue();
 	}
 	
+	public void processLstmInstruction(ExecutionContext ec) {
+		MatrixBlock X = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+		MatrixBlock W = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
+		MatrixBlock b = ec.getMatrixInput(_in3.getName(), getExtendedOpcode());
+		MatrixBlock out0 = ec.getMatrixInput(_in4.getName(), getExtendedOpcode());
+		MatrixBlock c0 = ec.getMatrixInput(_in5.getName(), getExtendedOpcode());
+		boolean return_seq = ec.getScalarInput(_in6.getName(), _in6.getValueType(), _in6.isLiteral()).getBooleanValue();
+		
+		int N = X.getNumRows();
+		int TD = X.getNumColumns();
+		int DPlusM = W.getNumRows();
+		int M = W.getNumColumns() / 4;
+		if(b.getNumRows() != 1 || b.getNumColumns() != M*4) {
+			throw new DMLRuntimeException("Incorrect dimensions of bias in lstm instruction. Expected [1, " + (M*4) + "], "
+				+ "but found [" + b.getNumRows() + "," + b.getNumColumns() + "]");
+		}
+		if(out0.getNumRows() != N || out0.getNumColumns() != M) {
+			throw new DMLRuntimeException("Incorrect dimensions of out0 in lstm instruction. Expected [" + N + ", " + M + "], "
+				+ "but found [" + out0.getNumRows() + "," + out0.getNumColumns() + "]");
+		}
+		if(c0.getNumRows() != N || c0.getNumColumns() != M) {
+			throw new DMLRuntimeException("Incorrect dimensions of c0 in lstm instruction. Expected [" + N + ", " + M + "], "
+				+ "but found [" + c0.getNumRows() + "," + c0.getNumColumns() + "]");
+		}
+		int D = DPlusM - M;
+		int T = TD / D;
+		
+		MatrixBlock out = new MatrixBlock(N, return_seq ? (T*M) : M, false);
+		MatrixBlock c = new MatrixBlock(N, M, false);
+		
+		LibMatrixDNN.lstm(X, W, b, out0, c0,
+			return_seq, N, T, D, M,
+			out, c, null, null, null,
+			_numThreads);
+		
+		// release inputs/outputs
+		ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in4.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
+		ec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
+		ec.setMatrixOutput(_out2.getName(), c, getExtendedOpcode());
+	}
+	
 	public void processReluBackwardInstruction(ExecutionContext ec) {
 		// (X > 0) * dout
 		MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
@@ -436,6 +493,10 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 			processBatchNorm2dBackwardInstruction(ec);
 			return;
 		}
+		else if (instOpcode.equalsIgnoreCase("lstm")) {
+			processLstmInstruction(ec);
+			return;
+		}
 		
 		// acquire inputs
 		MatrixBlock outputBlock = null;
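Each iteration of the loop in LibMatrixDNN.lstm (next file) evaluates the standard LSTM cell on whole N x M blocks. A scalar sketch of the same formulas for a single unit, using plain java.lang.Math (the helper names are hypothetical; the patch applies these elementwise via UnaryOperator(SIGMOID/TANH) plus elementwise multiply/add):

public class LstmUnitSketch {
	static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

	// rawI..rawG are one element each of the i/f/o/g columns of ifog_raw;
	// returns {out_t, c_t} for that unit
	static double[] lstmUnit(double rawI, double rawF, double rawO, double rawG, double cPrev) {
		double i = sigmoid(rawI);        // input gate
		double f = sigmoid(rawF);        // forget gate
		double o = sigmoid(rawO);        // output gate
		double g = Math.tanh(rawG);      // candidate cell update
		double c = f * cPrev + i * g;    // c_t = f*c_prev + i*g
		double out = o * Math.tanh(c);   // out_t = o*tanh(c_t)
		return new double[] { out, c };
	}
}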
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 4569dbe..179b5d3 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -31,10 +31,22 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
+import org.apache.sysml.runtime.functionobjects.Builtin;
 import org.apache.sysml.runtime.functionobjects.KahanPlus;
+import org.apache.sysml.runtime.functionobjects.Multiply;
+import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.ValueFunction;
+import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysml.runtime.instructions.cp.KahanObject;
+import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
 import org.apache.sysml.runtime.util.CommonThreadPool;
 import org.apache.sysml.runtime.util.DnnUtils;
+import org.apache.sysml.runtime.util.IndexRange;
 
 /*
  * This class allows users to invoke deep learning related operations 
@@ -262,6 +274,79 @@ public class LibMatrixDNN {
 		outputBlock.examSparsity();
 	}
 	
+	private static MatrixBlock matmult(MatrixBlock matBlock1, MatrixBlock matBlock2, int numThreads) {
+		AggregateBinaryOperator ab_op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), 
+			new AggregateOperator(0, Plus.getPlusFnObject()), numThreads);
+		MatrixBlock main = (matBlock2 instanceof CompressedMatrixBlock) ? matBlock2 : matBlock1;
+		MatrixBlock ret = main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op);
+		return ret;
+	}
+	
+	private static MatrixBlock add(MatrixBlock matBlock1, MatrixBlock matBlock2) {
+		return (MatrixBlock) matBlock1.binaryOperations(new BinaryOperator(Plus.getPlusFnObject()), matBlock2, new MatrixBlock());
+	}
+	private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2) {
+		return (MatrixBlock) matBlock1.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), matBlock2, new MatrixBlock());
+	}
+	
+	// sigmoid(0)*c_prev + sigmoid(0)*tanh(0);
+	private static Builtin sigmoidOp = Builtin.getBuiltinFnObject(BuiltinCode.SIGMOID);
+	private static Builtin tanhOp = Builtin.getBuiltinFnObject(BuiltinCode.TANH);
+	private static MatrixBlock sigmoid(MatrixBlock in, int numThreads, boolean inPlace) {
+		return (MatrixBlock) in.unaryOperations(new UnaryOperator(sigmoidOp, numThreads, inPlace), new MatrixBlock());
+	}
+	private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean inPlace) {
+		return (MatrixBlock) in.unaryOperations(new UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock());
+	}
+	
+	public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0, 
+			boolean return_seq, int N, int T, int D, int M,
+			MatrixBlock out, MatrixBlock c, // output
+			MatrixBlock cache_out, MatrixBlock cache_c, MatrixBlock cache_ifog, // if null, the cache values are not computed
+			int numThreads) {
+		MatrixBlock out_prev = out0;
+		MatrixBlock c_prev = c0;
+		
+		MatrixBlock W1 = W.slice(0, D-1);
+		MatrixBlock W2 = W.slice(D, D+M-1);
+		MatrixBlock c_t = null;
+		MatrixBlock out_t = null;
+		for(int t = 1; t <= T; t++) {
+			MatrixBlock X_t = X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock());
+			MatrixBlock ifog_raw = add(add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads)), b);
+			MatrixBlock i = ifog_raw.slice(0, N-1, 0, M-1, new MatrixBlock());
+			MatrixBlock f = ifog_raw.slice(0, N-1, M, 2*M-1, new MatrixBlock());
+			MatrixBlock o = ifog_raw.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
+			MatrixBlock g = ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock());
+			i = sigmoid(i, numThreads, true);
+			f = sigmoid(f, numThreads, true);
+			o = sigmoid(o, numThreads, true);
+			g = tanh(g, numThreads, true);
+			// c_t = f*c_prev + i*g
+			c_t = add(multiply(f, c_prev), multiply(i, g));
+			// out_t = o*tanh(c)
+			out_t = multiply(o, tanh(c_t, numThreads, false));
+			if(return_seq) {
+				out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE);
+			}
+			out_prev = out_t;
+			c_prev = c_t;
+			
+			// TODO: Add this when implementing lstm_backward
+			// cache_out[t,] = matrix(out_t, rows=1, cols=N*M)  # reshape
+			// cache_c[t,] = matrix(c, rows=1, cols=N*M)  # reshape
+			// cache_ifog[t,] = matrix(cbind(ifo, g), rows=1, cols=N*4*M)  # reshape
+		}
+		if(out_t != null && !return_seq)
+			out.copy(out_t);
+		if(c_t != null)
+			c.copy(c_t);
+		else
+			c.copy(c0);
+	}
+	
 	/**
 	 * This method computes the backpropagation errors for previous layer of relu operation
 	 * 
@@ -574,6 +659,9 @@ public class LibMatrixDNN {
 			if(inputArr != null) {
 				System.arraycopy(inputArr, 0, output, 0, inputArr.length);
 			}
+			else {
+				Arrays.fill(output, 0);
+			}
 		}
 	}
 	
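With return_seq=true, the leftIndexingOperations call above writes the N x M output of timestep t into columns (t-1)*M through t*M-1 of the N x T*M result; with return_seq=false only the final out_t is copied out. A dense-array sketch of that column layout (illustrative only, not the MatrixBlock code path):

public class ReturnSeqSketch {
	// out: N x T*M sequence output; outT: N x M output of timestep t (1-based)
	static void writeTimestep(double[][] out, double[][] outT, int t, int N, int M) {
		for (int r = 0; r < N; r++)
			System.arraycopy(outT[r], 0, out[r], (t - 1) * M, M);
	}
}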
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
new file mode 100644
index 0000000..3aa37ad
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.gpu;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction;
+import org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction.LstmOperator;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests the lstm builtin function
+ */
+public class LstmCPUTest extends GPUTests {
+
+	private final static String TEST_NAME = "LstmTests";
+	private final int seed = 42;
+
+	private final static String builtinDML = "\"nn/layers/lstm_staging.dml\"";
+	private final static String nnDML = "\"nn/layers/lstm.dml\"";
+
+	@Override
+	public void setUp() {
+		super.setUp();
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_DIR, TEST_NAME);
+		getAndLoadTestConfiguration(TEST_NAME);
+	}
+
+	@Test
+	public void testLstmForward9() {
+		testLstmCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9);
+	}
+
+	@Test
+	public void testLstmForward10() {
+		testLstmCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9);
+	}
+
+	@Test
+	public void testLstmForward11() {
+		testLstmCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9);
+	}
+
+	@Test
+	public void testLstmForward12() {
+		testLstmCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9);
+	}
+
+	public void testLstmCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity) {
+		String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
+			+ "[output, c] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0)";
+		String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
+			+ "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, "
+			+ T + ", " + D + ", " + returnSequences + ", out0, c0)";
+
+		HashMap<String, Object> inputs = new HashMap<>();
+		inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, sparsity, seed));
+		inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, sparsity, seed));
+		inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, sparsity, seed));
+		inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+		inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+		List<String> outputs = Arrays.asList("output", "c");
+		List<Object> outGPUWithCuDNN = null;
+		List<Object> outCPUWithNN = null;
+		synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
+			try {
+				DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.CUDNN;
+				outGPUWithCuDNN = runOnCPU(spark, scriptStr1, inputs, outputs);
+				outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, outputs);
+			}
+			finally {
+				DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.NONE;
+			}
+		}
+		assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
+		assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
+	}
+
+
+
+//	@Test
+//	public void testLstmBackward7() {
+//		testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9, 0.9);
+//	}
+//
+//	@Test
+//	public void testLstmBackward8() {
+//		testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9, 0.9);
+//	}
+//
+//	@Test
+//	public void testLstmBackward9() {
+//		testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9, 0.9);
+//	}
+//
+//	@Test
+//	public void testLstmBackward10() {
+//		testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9, 0.9);
+//	}
+//
+//
+//	public void testLstmBackwardCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity,
+//			double weightSparsity) {
+//		boolean returnSequences1 = returnSequences.equals("TRUE");
+//
+//		String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
+//			+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+//		String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
+//			+ "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, "
+//			+ T + ", " + D + ", " + returnSequences + ", out0, c0); \n"
+//			+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, "
+//			+ T + ", " + D + ", " + returnSequences + ", out0, c0, cache_out, cache_c, cache_ifog);";
+//
+//		HashMap<String, Object> inputs = new HashMap<>();
+//		inputs.put("dout", generateInputMatrix(spark, N, returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
+//		inputs.put("dc", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+//		inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, sparsity, seed));
+//		inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, weightSparsity, seed));
+//		inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, sparsity, seed));
+//		inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+//		inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+//		List<String> outputs = Arrays.asList("dX", "dW", "db", "dout0", "dc0");
+//		List<Object> outGPUWithCuDNN = null;
+//		List<Object> outCPUWithNN = null;
+//		synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
+//			try {
+//				DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.CUDNN;
+//				outGPUWithCuDNN = runOnCPU(spark, scriptStr1, inputs, outputs);
+//			}
+//			finally {
+//				DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.NONE;
+//			}
+//			outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, outputs);
+//		}
+//		assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
+//		assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
+//		assertEqualObjects(outGPUWithCuDNN.get(2), outCPUWithNN.get(2));
+//		assertEqualObjects(outGPUWithCuDNN.get(3), outCPUWithNN.get(3));
+//		assertEqualObjects(outGPUWithCuDNN.get(4), outCPUWithNN.get(4));
+//	}
+}