This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 0cabde0  [SYSTEMML-540] Improve the performance of lstm builtin function
0cabde0 is described below

commit 0cabde0ca26c99a55c62f7e7ffac67b450dea850
Author: Niketan Pansare <npan...@us.ibm.com>
AuthorDate: Wed Feb 27 21:03:15 2019 -0800

    [SYSTEMML-540] Improve the performance of lstm builtin function
    
    - Allow FunctionOp to be multi-threaded.
    - Currently, only the lstm builtin function will have a number of threads > 1.
    - Added more tests.
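
At a high level, the patch threads a degree-of-parallelism value from the HOP layer (FunctionOp) through the LOP layer (FunctionCallCP), which appends it as the last operand of the generated instruction string, down to the runtime (DnnCPInstruction), which parses it back out. Below is a minimal, self-contained sketch of that operand round-trip; ThreadCountRoundTrip, DELIM, serialize, and parseNumThreads are illustrative stand-ins, not SystemML classes, and the delimiter value is an assumption.

    public class ThreadCountRoundTrip {
        // stand-in for Lop.OPERAND_DELIMITOR (actual value assumed)
        private static final String DELIM = "\u00b0";

        // mirrors FunctionCallCP.getInstructions(): the thread count is only
        // appended when positive, so opcodes that never set it are unaffected
        static String serialize(String opcode, String[] operands, int numThreads) {
            StringBuilder sb = new StringBuilder(opcode);
            for (String op : operands)
                sb.append(DELIM).append(op);
            if (numThreads > 0)
                sb.append(DELIM).append(numThreads);
            return sb.toString();
        }

        // mirrors DnnCPInstruction.parseInstruction(): when the extra field is
        // present, the last operand carries the thread count; otherwise 0
        static int parseNumThreads(String instr, int numFieldsWithoutThreads) {
            String[] parts = instr.split(DELIM);
            return parts.length > numFieldsWithoutThreads + 1
                ? Integer.parseInt(parts[parts.length - 1]) : 0;
        }

        public static void main(String[] args) {
            String[] ops = {"X", "W", "b", "out0", "c0", "TRUE", "out", "c"};
            String instr = serialize("lstm", ops, 8);
            System.out.println(parseNumThreads(instr, ops.length)); // prints 8
        }
    }

This is also why InstructionUtils.checkNumFields(parts, 8) becomes checkNumFields(parts, 9) for the lstm opcode in the diff below: a CP lstm instruction now always carries the extra thread-count operand.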
---
 .../java/org/apache/sysml/hops/FunctionOp.java     | 10 +++-
 .../java/org/apache/sysml/lops/FunctionCallCP.java | 14 +++++-
 .../runtime/instructions/cp/DnnCPInstruction.java  | 13 ++---
 .../sysml/runtime/matrix/data/LibMatrixDNN.java    | 55 ++++++++++++++++------
 .../org/apache/sysml/test/gpu/LstmCPUTest.java     | 50 ++++++++++++++++++++
 5 files changed, 118 insertions(+), 24 deletions(-)

diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 66ce478..5fdc8e7 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -39,7 +39,7 @@ import org.apache.sysml.runtime.controlprogram.parfor.opt.CostEstimatorHops;
  * Note: Currently, we support expressions in function arguments along with function calls
  * in expressions with single outputs, leaving multiple outputs handling as it is.
  */
-public class FunctionOp extends Hop
+public class FunctionOp extends MultiThreadedHop
 {
        public enum FunctionType{
                DML,
@@ -253,8 +253,14 @@ public class FunctionOp extends Hop
                        tmp.add( in.constructLops() );
                
                //construct function call
+               int numThreads = 0;
+               if(getFunctionType() == FunctionType.MULTIRETURN_BUILTIN && isBuiltinFunction() && et == ExecType.CP &&
+                               (getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) {
+                       numThreads = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
+               }
+               
                Lop fcall = _singleOutFun ? new FunctionCallCPSingle( tmp, _fnamespace, _fname, et ) :
-                       new FunctionCallCP(tmp, _fnamespace, _fname, _inputNames, _outputNames, _outputHops, et);
+                       new FunctionCallCP(tmp, _fnamespace, _fname, _inputNames, _outputNames, _outputHops, et, numThreads);
                setLineNumbers(fcall);
                setLops(fcall);
                
diff --git a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
index 50d43de..237b806 100644
--- a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
+++ b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java
@@ -38,10 +38,12 @@ public class FunctionCallCP extends Lop
        private String[] _inputNames;
        private String[] _outputNames;
        private ArrayList<Lop> _outputLops = null;
+       private int _numThreads;
 
        public FunctionCallCP(ArrayList<Lop> inputs, String fnamespace, String fname, 
-               String[] inputNames, String[] outputNames, ArrayList<Hop> outputHops, ExecType et) {
+               String[] inputNames, String[] outputNames, ArrayList<Hop> outputHops, ExecType et, int numThreads) {
                this(inputs, fnamespace, fname, inputNames, outputNames, et);
+               _numThreads = numThreads;
                if(outputHops != null) {
                        _outputLops = new ArrayList<>();
                        setLevel();
@@ -104,6 +106,11 @@ public class FunctionCallCP extends Lop
                        sb.append(_outputNames[i]);
                }
                
+               if(_numThreads > 0) {
+                       sb.append(Lop.OPERAND_DELIMITOR);
+                       sb.append(_numThreads);
+               }
+               
                return sb.toString();
        }
        
@@ -145,6 +152,11 @@ public class FunctionCallCP extends Lop
                        inst.append(Lop.OPERAND_DELIMITOR);
                        inst.append(out);
                }
+               
+               if(_numThreads > 0) {
+                       inst.append(Lop.OPERAND_DELIMITOR);
+                       inst.append(_numThreads);
+               }
 
                return inst.toString();
        }
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
index 4043908..93ffd4f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
@@ -103,7 +103,7 @@ public class DnnCPInstruction extends UnaryCPInstruction {
        public DnnCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5,
                        CPOperand in6, CPOperand in7, CPOperand in8,
                        CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, 
-                       double intermediateMemoryBudget) throws DMLRuntimeException {
+                       double intermediateMemoryBudget, int numThreads) throws DMLRuntimeException {
                super(CPType.Dnn, null, in1, out, opcode, istr);
                _in2 = in2;
                _in3 = in3;
@@ -120,7 +120,7 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                _padding = null;
                _input_shape = null;
                _filter_shape = null;
-               _numThreads = 0;
+               _numThreads = numThreads;
                _intermediateMemoryBudget = intermediateMemoryBudget;
        }
 
@@ -246,7 +246,7 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        CPOperand out3 = new CPOperand(parts[11]); // retRunningVar
                        CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean
                        CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
-                       return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
+                       return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0, 0);
                }
                else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
                        InstructionUtils.checkNumFields(parts, 9);
@@ -259,10 +259,10 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        CPOperand out = new CPOperand(parts[7]);  // dX
                        CPOperand out2 = new CPOperand(parts[8]); // dScale
                        CPOperand out3 = new CPOperand(parts[9]); // dBias
-                       return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
+                       return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0, 0);
                }
                else if (opcode.equalsIgnoreCase("lstm")) {
-                       InstructionUtils.checkNumFields(parts, 8);
+                       InstructionUtils.checkNumFields(parts, 9);
                        CPOperand in1 = new CPOperand(parts[1]); // X
                        CPOperand in2 = new CPOperand(parts[2]); // W
                        CPOperand in3 = new CPOperand(parts[3]); // b
@@ -271,7 +271,8 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        CPOperand in6 = new CPOperand(parts[6]); // return_seq
                        CPOperand out = new CPOperand(parts[7]);  // out
                        CPOperand out2 = new CPOperand(parts[8]); // c
-                       return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, null, null, null, opcode, str, 0);
+                       int numThreads = Integer.parseInt(parts[9]);
+                       return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, null, null, null, opcode, str, 0, numThreads);
                }
                else {
                        throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str);
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 179b5d3..0f932ba 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -37,16 +37,16 @@ import org.apache.sysml.runtime.functionobjects.Builtin;
 import org.apache.sysml.runtime.functionobjects.KahanPlus;
 import org.apache.sysml.runtime.functionobjects.Multiply;
 import org.apache.sysml.runtime.functionobjects.Plus;
-import org.apache.sysml.runtime.functionobjects.ValueFunction;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysml.runtime.instructions.cp.KahanObject;
 import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.TernaryOperator;
 import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
 import org.apache.sysml.runtime.util.CommonThreadPool;
 import org.apache.sysml.runtime.util.DnnUtils;
-import org.apache.sysml.runtime.util.IndexRange;
 
 /*
  * This class allows users to invoke deep learning related operations 
@@ -282,11 +282,26 @@ public class LibMatrixDNN {
                return ret;
        }
        
-       private static MatrixBlock add(MatrixBlock matBlock1, MatrixBlock matBlock2) {
-               return (MatrixBlock) matBlock1.binaryOperations(new BinaryOperator(Plus.getPlusFnObject()), matBlock2, new MatrixBlock());
+       private static MatrixBlock add(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) {
+               BinaryOperator bop = new BinaryOperator(Plus.getPlusFnObject());
+//             if(inplace) {
+//                     matBlock1.binaryOperationsInPlace(bop, matBlock2);
+//                     return matBlock1;
+//             }
+//             else {
+                       return (MatrixBlock) matBlock1.binaryOperations(bop, matBlock2, new MatrixBlock());
+//             }
        }
-       private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2) {
-               return (MatrixBlock) matBlock1.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), matBlock2, new MatrixBlock());
+       
+       private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) {
+               BinaryOperator bop = new BinaryOperator(Multiply.getMultiplyFnObject());
+//             if(inplace) {
+//                     matBlock1.binaryOperationsInPlace(bop, matBlock2);
+//                     return matBlock1;
+//             }
+//             else {
+                       return (MatrixBlock) matBlock1.binaryOperations(bop, matBlock2, new MatrixBlock());
+//             }
        }
        
        // sigmoid(0)*c_prev + sigmoid(0)*tanh(0);
@@ -296,10 +311,16 @@ public class LibMatrixDNN {
        private static MatrixBlock sigmoid(MatrixBlock in, int numThreads, boolean inPlace) {
                return (MatrixBlock) in.unaryOperations(new UnaryOperator(sigmoidOp, numThreads, inPlace), new MatrixBlock());
        }
+       
        private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean inPlace) {
                return (MatrixBlock) in.unaryOperations(new UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock());
        }
        
+       private static MatrixBlock plusMultiply(MatrixBlock matBlock1, MatrixBlock matBlock2, MatrixBlock matBlock3) {
+               return matBlock1.ternaryOperations(new TernaryOperator(PlusMultiply.getFnObject()), 
+                               matBlock2, matBlock3, new MatrixBlock());
+       }
+       
        public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0, 
                        boolean return_seq, int N, int T, int D, int M,
                        MatrixBlock out, MatrixBlock c, // output 
@@ -314,19 +335,23 @@ public class LibMatrixDNN {
                MatrixBlock out_t = null;
                for(int t = 1; t <= T; t++) {
                        MatrixBlock X_t = X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock());
-                       MatrixBlock ifog_raw = add(add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads)), b);
-                       MatrixBlock i = ifog_raw.slice(0, N-1, 0, M-1, new MatrixBlock());
-                       MatrixBlock f = ifog_raw.slice(0, N-1, M, 2*M-1, new MatrixBlock());
-                       MatrixBlock o = ifog_raw.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
+                       MatrixBlock ifog_raw = add(add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads), true), b, true);
+                       
+                       MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new MatrixBlock());
+                       ifo = sigmoid(ifo, numThreads, true);
+                       MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new MatrixBlock());
+                       MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new MatrixBlock());
+                       MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
+                       
                        MatrixBlock g = ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock());
-                       i = sigmoid(i, numThreads, true);
-                       f = sigmoid(f, numThreads, true);
-                       o = sigmoid(o, numThreads, true);
                        g = tanh(g, numThreads, true);
+                       
                        // c_t = f*c_prev + i*g
-                       c_t = add(multiply(f, c_prev) , multiply(i, g));
+                       c_t = plusMultiply(multiply(f, c_prev, true), i, g);
+                       
                        // out_t = o*tanh(c)
-                       out_t = multiply(o, tanh(c_t, numThreads, false));
+                       out_t = multiply(o, tanh(c_t, numThreads, false), true);
+                       
                        if(return_seq) {
                                out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE);
                        }
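
Two changes in the LibMatrixDNN hunk above carry most of the speedup: the i, f, and o gates are activated with a single sigmoid call over the concatenated N x 3M slice instead of three separate M-wide calls, and the cell update c_t = f*c_prev + i*g now folds the second product and the addition into one fused ternary pass (PlusMultiply computes a + b*c). A conceptual sketch of that fusion on plain double arrays, not the MatrixBlock API, assuming dense data:

    public class FusedCellUpdate {
        // before the patch: three elementwise passes and two temporaries,
        // as in add(multiply(f, c_prev), multiply(i, g))
        static double[] unfused(double[] f, double[] cPrev, double[] i, double[] g) {
            int n = f.length;
            double[] t1 = new double[n], t2 = new double[n], c = new double[n];
            for (int k = 0; k < n; k++) t1[k] = f[k] * cPrev[k];
            for (int k = 0; k < n; k++) t2[k] = i[k] * g[k];
            for (int k = 0; k < n; k++) c[k] = t1[k] + t2[k];
            return c;
        }

        // after the patch: one multiply for f*c_prev, then a single fused
        // pass computing t1 + i*g, saving one pass and one temporary,
        // as in plusMultiply(multiply(f, c_prev, true), i, g)
        static double[] patchStyle(double[] f, double[] cPrev, double[] i, double[] g) {
            int n = f.length;
            double[] t1 = new double[n], c = new double[n];
            for (int k = 0; k < n; k++) t1[k] = f[k] * cPrev[k];
            for (int k = 0; k < n; k++) c[k] = t1[k] + i[k] * g[k];
            return c;
        }
    }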
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
index 3aa37ad..785c890 100644
--- a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
@@ -48,6 +48,46 @@ public class LstmCPUTest extends GPUTests {
        }
        
        @Test
+       public void testLstmForward1() {
+               testLstmCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.2);
+       }
+       
+       @Test
+       public void testLstmForward2() {
+               testLstmCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.1);
+       }
+       
+       @Test
+       public void testLstmForward3() {
+               testLstmCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.15);
+       }
+       
+       @Test
+       public void testLstmForward4() {
+               testLstmCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.1);
+       }
+       
+       @Test
+       public void testLstmForward5() {
+               testLstmCuDNNWithNNLayer(20, 13, 1, 10, "TRUE", 0.5);
+       }
+       
+       @Test
+       public void testLstmForward6() {
+               testLstmCuDNNWithNNLayer(20, 13, 1, 10, "FALSE", 0.3);
+       }
+       
+       @Test
+       public void testLstmForward7() {
+               testLstmCuDNNWithNNLayer(20, 13, 4, 1, "TRUE", 0.8);
+       }
+       
+       @Test
+       public void testLstmForward8() {
+               testLstmCuDNNWithNNLayer(20, 13, 4, 1, "FALSE", 0.9);
+       }
+       
+       @Test
        public void testLstmForward9() {
                testLstmCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9);
        }
@@ -67,6 +107,16 @@ public class LstmCPUTest extends GPUTests {
                testLstmCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9);
        }
        
+       @Test
+       public void testLstmForward13() {
+               testLstmCuDNNWithNNLayer(20, 1, 4, 10, "TRUE", 0.8);
+       }
+       
+       @Test
+       public void testLstmForward14() {
+               testLstmCuDNNWithNNLayer(20, 1, 4, 10, "FALSE", 0.9);
+       }
+       
        public void testLstmCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity) {
                String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
                                + "[output, c] = lstm::forward(x, w, b, " + 
returnSequences + ", out0, c0)";
