This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git

commit 792da5d0aa2abd6e650a3a17f243795d0f9a4b35
Author: Niketan Pansare <npan...@us.ibm.com>
AuthorDate: Mon Mar 4 13:32:35 2019 -0800

    [SYSTEMML-540] Improved the performance of the lstm builtin function for sparse inputs
    
    This commit allows the matrix multiplication operator to exploit sparsity by splitting lstm into three cases:
    1. If W is sparse, perform cbind(X_t, out_prev) %*% W
    2. Else, if X_t is sparse, perform X_t %*% W1 + out_prev %*% W2
    3. Otherwise, perform cbind(X_t, out_prev) %*% W to maximize parallelism within the matrix multiplication operator
---
 .../sysml/runtime/matrix/data/LibMatrixDNN.java    | 114 ++++++++++-----------
 1 file changed, 53 insertions(+), 61 deletions(-)

diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index e2742d8..365d7a2 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -284,26 +284,34 @@ public class LibMatrixDNN {
        
        private static MatrixBlock add(MatrixBlock matBlock1, MatrixBlock 
matBlock2, boolean inplace) {
                BinaryOperator bop = new BinaryOperator(Plus.getPlusFnObject());
-//             if(inplace) {
-//                     matBlock1.binaryOperationsInPlace(bop, matBlock2);
-//                     return matBlock1;
-//             }
-//             else {
+               if(inplace && matBlock1.isInSparseFormat() == 
matBlock2.isInSparseFormat() &&
+                       matBlock1.getNumRows() == matBlock2.getNumRows() && 
matBlock1.getNumColumns() == matBlock2.getNumColumns()) {
+                       matBlock1.binaryOperationsInPlace(bop, matBlock2);
+                       return matBlock1;
+               }
+               else {
                        return (MatrixBlock) matBlock1.binaryOperations(bop, 
matBlock2, new MatrixBlock());
-//             }
+               }
+       }
+       private static MatrixBlock plusMultiply(MatrixBlock matBlock1, 
MatrixBlock matBlock2, MatrixBlock matBlock3) {
+               return matBlock1.ternaryOperations(new 
TernaryOperator(PlusMultiply.getFnObject()), 
+                               matBlock2, matBlock3, new MatrixBlock());
        }
        
+               
        private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock 
matBlock2, boolean inplace) {
                BinaryOperator bop = new 
BinaryOperator(Multiply.getMultiplyFnObject());
-//             if(inplace) {
-//                     matBlock1.binaryOperationsInPlace(bop, matBlock2);
-//                     return matBlock1;
-//             }
-//             else {
+               if(inplace && matBlock1.isInSparseFormat() == 
matBlock2.isInSparseFormat() &&
+                       matBlock1.getNumRows() == matBlock2.getNumRows() && 
matBlock1.getNumColumns() == matBlock2.getNumColumns()) {
+                       matBlock1.binaryOperationsInPlace(bop, matBlock2);
+                       return matBlock1;
+               }
+               else {
                        return (MatrixBlock) matBlock1.binaryOperations(bop, 
matBlock2, new MatrixBlock());
-//             }
+               }
        }
        
+       
        // sigmoid(0)*c_prev + sigmoid(0)*tanh(0);
        
        private static Builtin sigmoidOp = 
Builtin.getBuiltinFnObject(BuiltinCode.SIGMOID);
@@ -311,16 +319,10 @@ public class LibMatrixDNN {
        private static MatrixBlock sigmoid(MatrixBlock in, int numThreads, 
boolean inPlace) {
                return (MatrixBlock) in.unaryOperations(new 
UnaryOperator(sigmoidOp, numThreads, inPlace), new MatrixBlock());
        }
-       
        private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean 
inPlace) {
                return (MatrixBlock) in.unaryOperations(new 
UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock());
        }
        
-       private static MatrixBlock plusMultiply(MatrixBlock matBlock1, 
MatrixBlock matBlock2, MatrixBlock matBlock3) {
-               return matBlock1.ternaryOperations(new 
TernaryOperator(PlusMultiply.getFnObject()), 
-                               matBlock2, matBlock3, new MatrixBlock());
-       }
-       
        public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, 
MatrixBlock out0, MatrixBlock c0, 
                        boolean return_seq, int N, int T, int D, int M,
                        MatrixBlock out, MatrixBlock c, // output 
@@ -329,61 +331,56 @@ public class LibMatrixDNN {
                MatrixBlock out_prev = out0;
                MatrixBlock c_prev = c0;
                
-               MatrixBlock W1 = W.slice(0, D-1);
-               MatrixBlock W2 = W.slice(D, D+M-1);
+               MatrixBlock W1 = null;
+               MatrixBlock W2 = null;
                MatrixBlock c_t = null;
                MatrixBlock out_t = null;
                
-               boolean profile = true;
-               long t1 = 0, t2 = 0, t3 = 0, t4 = 0, t5 = 0;
+               MatrixBlock input = null;
                for(int t = 1; t <= T; t++) {
-                       long s =  profile ? System.nanoTime() : 0;
-                       MatrixBlock X_t = X.slice(0, N-1, (t-1)*D, t*D-1, new 
MatrixBlock());
-                       if(profile) {
-                               long e = System.nanoTime();
-                               t1 += e - s;
-                       }
-                       
-                       s =  profile ? System.nanoTime() : 0;
-                       MatrixBlock ifog_raw = add(add(matmult(X_t, W1, 
numThreads), matmult(out_prev, W2, numThreads), true), b, true);
-                       if(profile) {
-                               long e = System.nanoTime();
-                               t2 += e - s;
+                       final MatrixBlock X_t = (T == 1) ? X : X.slice(0, N-1, 
(t-1)*D, t*D-1, new MatrixBlock());
+                       MatrixBlock ifog_raw = null;
+                       // Logic: Exploit sparse matrix multiplication whenever 
possible:
+                       // 1. If W is sparse, perform cbind(X_t, out_prev) %*% W
+                       // 2. Else if X_t is sparse, perform X_t %*% W1 + 
out_prev %*% W2
+                       // 3. If none of the case is applicable, perform 
cbind(X_t, out_prev) %*% W
+                       boolean isCase1 = W.isInSparseFormat();
+                       boolean isCase2 = !isCase1 && X_t.isInSparseFormat();
+                       if(isCase2) {
+                               // Perform X_t %*% W1 + out_prev %*% W2
+                               if(W1 == null) {
+                                       // Lazy slicing: applicable only when 
atleast one X_t is sparse.
+                                       W1 = W.slice(0, D-1);
+                                       W2 = W.slice(D, D+M-1);
+                               }
+                               ifog_raw = add(matmult(X_t, W1, numThreads), 
matmult(out_prev, W2, numThreads), true);
+                               ifog_raw = add(ifog_raw, b, true);
+                       } 
+                       else {
+                               // Case 1 and 3:
+                               // Perform input %*% W, where input = 
cbind(X_t, out_prev)
+                               if(input == null) {
+                                       input = new MatrixBlock(N, D+M, false);
+                                       input.allocateDenseBlock();
+                               }
+                               input = X_t.append(out_prev, input);
+                               ifog_raw = add(matmult(input, W, numThreads), 
b, true);
                        }
                        
-                       s =  profile ? System.nanoTime() : 0;
                        MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new 
MatrixBlock());
                        ifo = sigmoid(ifo, numThreads, true);
                        MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new 
MatrixBlock());
                        MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new 
MatrixBlock());
                        MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new 
MatrixBlock());
-                       
-                       MatrixBlock g = ifog_raw.slice(0, N-1, 3*M, 4*M-1, new 
MatrixBlock());
-                       g = tanh(g, numThreads, true);
-                       if(profile) {
-                               long e = System.nanoTime();
-                               t3 += e - s;
-                       }
-                       
-                       s =  profile ? System.nanoTime() : 0;
+                       MatrixBlock g = tanh(ifog_raw.slice(0, N-1, 3*M, 4*M-1, 
new MatrixBlock()), numThreads, true);
+                                       
                        // c_t = f*c_prev + i*g
                        c_t = plusMultiply(multiply(f, c_prev, true), i, g);
                        // out_t = o*tanh(c)
                        out_t = multiply(o, tanh(c_t, numThreads, false), true);
-                       if(profile) {
-                               long e = System.nanoTime();
-                               t4 += e - s;
-                       }
-                       
-                       s =  profile ? System.nanoTime() : 0;
                        if(return_seq) {
                                out = out.leftIndexingOperations(out_t, 0, N-1, 
(t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE);
                        }
-                       if(profile) {
-                               long e = System.nanoTime();
-                               t5 += e - s;
-                       }
-                       
                        out_prev = out_t;
                        c_prev = c_t;
                        
@@ -398,11 +395,6 @@ public class LibMatrixDNN {
                        c.copy(c_t);
                else
                        c.copy(c0);
-               System.out.println("Time taken in lstm forward call: [X_t 
indexing:" + String.format("%.3f", t1*1e-9) + 
-                               ", ifog_raw computation:" + 
String.format("%.3f", t2*1e-9) + 
-                               ", lstm_squash computation:" + 
String.format("%.3f", t3*1e-9) +  
-                               ", c_t/out_t computation:" + 
String.format("%.3f", t4*1e-9) + 
-                               ", out leftIndexing computation:" + 
String.format("%.3f", t5*1e-9));
        }
        
        /**
@@ -1009,4 +1001,4 @@ public class LibMatrixDNN {
                        params.end_indexes_w[q] = Math.min(ix+params.S, 
params.W);
                }
        }
-}
+}
\ No newline at end of file

Reply via email to