This is an automated email from the ASF dual-hosted git repository. niketanpansare pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemml.git
commit 792da5d0aa2abd6e650a3a17f243795d0f9a4b35 Author: Niketan Pansare <npan...@us.ibm.com> AuthorDate: Mon Mar 4 13:32:35 2019 -0800 [SYSTEMML-540] Improved the performance of lstm builtin function for sparse inputs This commit allows the matrix multiplication operator to exploit sparsity by separating lstm into three cases: 1. If W is sparse, perform cbind(X_t, out_prev) %*% W 2. Else if X_t is sparse, perform X_t %*% W1 + out_prev %*% W2 3. If neither case applies, perform cbind(X_t, out_prev) %*% W to maximize parallelism within the matrix multiplication operator --- .../sysml/runtime/matrix/data/LibMatrixDNN.java | 114 ++++++++++----------- 1 file changed, 53 insertions(+), 61 deletions(-) diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java index e2742d8..365d7a2 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java @@ -284,26 +284,34 @@ public class LibMatrixDNN { private static MatrixBlock add(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) { BinaryOperator bop = new BinaryOperator(Plus.getPlusFnObject()); -// if(inplace) { -// matBlock1.binaryOperationsInPlace(bop, matBlock2); -// return matBlock1; -// } -// else { + if(inplace && matBlock1.isInSparseFormat() == matBlock2.isInSparseFormat() && + matBlock1.getNumRows() == matBlock2.getNumRows() && matBlock1.getNumColumns() == matBlock2.getNumColumns()) { + matBlock1.binaryOperationsInPlace(bop, matBlock2); + return matBlock1; + } + else { return (MatrixBlock) matBlock1.binaryOperations(bop, matBlock2, new MatrixBlock()); -// } + } + } + private static MatrixBlock plusMultiply(MatrixBlock matBlock1, MatrixBlock matBlock2, MatrixBlock matBlock3) { + return matBlock1.ternaryOperations(new TernaryOperator(PlusMultiply.getFnObject()), + matBlock2, matBlock3, new MatrixBlock()); } + private static 
MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) { BinaryOperator bop = new BinaryOperator(Multiply.getMultiplyFnObject()); -// if(inplace) { -// matBlock1.binaryOperationsInPlace(bop, matBlock2); -// return matBlock1; -// } -// else { + if(inplace && matBlock1.isInSparseFormat() == matBlock2.isInSparseFormat() && + matBlock1.getNumRows() == matBlock2.getNumRows() && matBlock1.getNumColumns() == matBlock2.getNumColumns()) { + matBlock1.binaryOperationsInPlace(bop, matBlock2); + return matBlock1; + } + else { return (MatrixBlock) matBlock1.binaryOperations(bop, matBlock2, new MatrixBlock()); -// } + } } + // sigmoid(0)*c_prev + sigmoid(0)*tanh(0); private static Builtin sigmoidOp = Builtin.getBuiltinFnObject(BuiltinCode.SIGMOID); @@ -311,16 +319,10 @@ public class LibMatrixDNN { private static MatrixBlock sigmoid(MatrixBlock in, int numThreads, boolean inPlace) { return (MatrixBlock) in.unaryOperations(new UnaryOperator(sigmoidOp, numThreads, inPlace), new MatrixBlock()); } - private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean inPlace) { return (MatrixBlock) in.unaryOperations(new UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock()); } - private static MatrixBlock plusMultiply(MatrixBlock matBlock1, MatrixBlock matBlock2, MatrixBlock matBlock3) { - return matBlock1.ternaryOperations(new TernaryOperator(PlusMultiply.getFnObject()), - matBlock2, matBlock3, new MatrixBlock()); - } - public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0, boolean return_seq, int N, int T, int D, int M, MatrixBlock out, MatrixBlock c, // output @@ -329,61 +331,56 @@ public class LibMatrixDNN { MatrixBlock out_prev = out0; MatrixBlock c_prev = c0; - MatrixBlock W1 = W.slice(0, D-1); - MatrixBlock W2 = W.slice(D, D+M-1); + MatrixBlock W1 = null; + MatrixBlock W2 = null; MatrixBlock c_t = null; MatrixBlock out_t = null; - boolean profile = true; - long t1 = 0, t2 = 0, t3 = 0, t4 = 
0, t5 = 0; + MatrixBlock input = null; for(int t = 1; t <= T; t++) { - long s = profile ? System.nanoTime() : 0; - MatrixBlock X_t = X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock()); - if(profile) { - long e = System.nanoTime(); - t1 += e - s; - } - - s = profile ? System.nanoTime() : 0; - MatrixBlock ifog_raw = add(add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads), true), b, true); - if(profile) { - long e = System.nanoTime(); - t2 += e - s; + final MatrixBlock X_t = (T == 1) ? X : X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock()); + MatrixBlock ifog_raw = null; + // Logic: Exploit sparse matrix multiplication whenever possible: + // 1. If W is sparse, perform cbind(X_t, out_prev) %*% W + // 2. Else if X_t is sparse, perform X_t %*% W1 + out_prev %*% W2 + // 3. If none of the case is applicable, perform cbind(X_t, out_prev) %*% W + boolean isCase1 = W.isInSparseFormat(); + boolean isCase2 = !isCase1 && X_t.isInSparseFormat(); + if(isCase2) { + // Perform X_t %*% W1 + out_prev %*% W2 + if(W1 == null) { + // Lazy slicing: applicable only when atleast one X_t is sparse. + W1 = W.slice(0, D-1); + W2 = W.slice(D, D+M-1); + } + ifog_raw = add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads), true); + ifog_raw = add(ifog_raw, b, true); + } + else { + // Case 1 and 3: + // Perform input %*% W, where input = cbind(X_t, out_prev) + if(input == null) { + input = new MatrixBlock(N, D+M, false); + input.allocateDenseBlock(); + } + input = X_t.append(out_prev, input); + ifog_raw = add(matmult(input, W, numThreads), b, true); } - s = profile ? 
System.nanoTime() : 0; MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new MatrixBlock()); ifo = sigmoid(ifo, numThreads, true); MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new MatrixBlock()); MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new MatrixBlock()); MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock()); - - MatrixBlock g = ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock()); - g = tanh(g, numThreads, true); - if(profile) { - long e = System.nanoTime(); - t3 += e - s; - } - - s = profile ? System.nanoTime() : 0; + MatrixBlock g = tanh(ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock()), numThreads, true); + // c_t = f*c_prev + i*g c_t = plusMultiply(multiply(f, c_prev, true), i, g); // out_t = o*tanh(c) out_t = multiply(o, tanh(c_t, numThreads, false), true); - if(profile) { - long e = System.nanoTime(); - t4 += e - s; - } - - s = profile ? System.nanoTime() : 0; if(return_seq) { out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE); } - if(profile) { - long e = System.nanoTime(); - t5 += e - s; - } - out_prev = out_t; c_prev = c_t; @@ -398,11 +395,6 @@ public class LibMatrixDNN { c.copy(c_t); else c.copy(c0); - System.out.println("Time taken in lstm forward call: [X_t indexing:" + String.format("%.3f", t1*1e-9) + - ", ifog_raw computation:" + String.format("%.3f", t2*1e-9) + - ", lstm_squash computation:" + String.format("%.3f", t3*1e-9) + - ", c_t/out_t computation:" + String.format("%.3f", t4*1e-9) + - ", out leftIndexing computation:" + String.format("%.3f", t5*1e-9)); } /** @@ -1009,4 +1001,4 @@ public class LibMatrixDNN { params.end_indexes_w[q] = Math.min(ix+params.S, params.W); } } -} +} \ No newline at end of file