Repository: systemml
Updated Branches:
  refs/heads/master f14255f46 -> a671883fc


[SYSTEMML-2129] Performance sparse relu_backward operations

This patch improves the performance of sparse relu_backward operations
(including the common case of dense-sparse) by (1) allocating the output
in sparse representation if at least one input is sparse, (2) dedicated
implementations for all dense-sparse combinations, and (3) thread-local
nnz maintenance to avoid unnecessary passes over the output.

On 1000 iterations of LeNet over MNIST, this patch improved the
relu_backward runtime (for 2000 exec calls) from 24.9s to 5.6s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a671883f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a671883f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a671883f

Branch: refs/heads/master
Commit: a671883fc5cac2d81e0dcbd33c57abe1d45b593a
Parents: f14255f
Author: Matthias Boehm <[email protected]>
Authored: Sat Feb 3 20:08:22 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Sat Feb 3 20:45:51 2018 -0800

----------------------------------------------------------------------
 .../apache/sysml/api/ScriptExecutorUtils.java   |  15 +-
 .../cp/ConvolutionCPInstruction.java            |  10 +-
 .../sysml/runtime/matrix/data/LibMatrixDNN.java |   4 +-
 .../runtime/matrix/data/LibMatrixDNNRelu.java   | 146 ++++++++-----------
 4 files changed, 70 insertions(+), 105 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a671883f/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java 
b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
index 253a317..a6c276f 100644
--- a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
+++ b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
@@ -113,18 +113,9 @@ public class ScriptExecutorUtils {
                        
                        // display statistics (incl caching stats if enabled)
                        Statistics.stopRunTimer();
-
-                       if (!exceptionThrown) {
-                               if (statisticsMaxHeavyHitters > 0)
-                                       
System.out.println(Statistics.display(statisticsMaxHeavyHitters));
-                               else
-                                       
System.out.println(Statistics.display());
-                       } else {
-                               if (statisticsMaxHeavyHitters > 0)
-                                       
System.err.println(Statistics.display(statisticsMaxHeavyHitters));
-                               else
-                                       
System.err.println(Statistics.display());
-                       }
+                       (exceptionThrown ? System.err : System.out)
+                               
.println(Statistics.display(statisticsMaxHeavyHitters > 0 ?
+                                       statisticsMaxHeavyHitters : 
DMLScript.STATISTICS_COUNT));
                }
        }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/a671883f/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index 08d220e..e32c3cf 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -227,15 +227,13 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                // (X > 0) * dout
                MatrixBlock input = ec.getMatrixInput(input1.getName(), 
getExtendedOpcode());
                MatrixBlock dout = ec.getMatrixInput(_in2.getName(), 
getExtendedOpcode());
-               MatrixBlock outputBlock;
+               MatrixBlock outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(),
+                       input.isInSparseFormat() || dout.isInSparseFormat() );
                
-               if( !input.isEmpty() && !dout.isEmpty() ) {
-                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), false);
-                       outputBlock.allocateDenseBlock();
+               if( !input.isEmpty() && !dout.isEmpty() ) { //sparse-safe
+                       outputBlock.allocateBlock();
                        LibMatrixDNN.reluBackward(input, dout, outputBlock, 
_numThreads);
                }
-               else
-                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), true);
                
                // release inputs/outputs
                ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());

http://git-wip-us.apache.org/repos/asf/systemml/blob/a671883f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index e8a88d8..1ad56b2 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -273,10 +273,10 @@ public class LibMatrixDNN {
                                input.getNumRows() + " != " + dout.getNumRows() 
+ " || " + input.getNumColumns() + " != " + dout.getNumColumns());
                }
                
-               execute(LibMatrixDNNRelu.getReluBackwardWorkers(params), 
params);
+               long nnz = 
execute(LibMatrixDNNRelu.getReluBackwardWorkers(params), params);
                
                // post-processing: maintain nnz
-               outputBlock.recomputeNonZeros(); 
+               outputBlock.setNonZeros(nnz);
                outputBlock.examSparsity();
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/a671883f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
index c8f2f41..c8f85b1 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
@@ -23,17 +23,12 @@ import java.util.concurrent.Callable;
 
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.instructions.InstructionUtils;
-import org.apache.sysml.runtime.matrix.operators.ScalarOperator;
 
 /**
  * This class contains the different implementation of rotate180 operation
  */
 public class LibMatrixDNNRelu
 {
-       private static ScalarOperator GT0 = 
InstructionUtils.parseScalarBinaryOperator(">", false, 0);
-
-       
        /**
         * Factory method that returns list of callable tasks for performing 
relu backward operation
         * 
@@ -64,94 +59,75 @@ public class LibMatrixDNNRelu
                
                @Override
                public Long call() throws Exception {
-                       //note: X (m x n), dout (m x n) -> out (m x n)
-                       DenseBlock out = _params.output.getDenseBlock();
-                       final int n = _params.input1.getNumColumns();
-                       if(!_params.input1.isInSparseFormat() && 
!_params.input2.isInSparseFormat()) {
-                               DenseBlock x = _params.input1.getDenseBlock();
-                               DenseBlock dout = 
_params.input2.getDenseBlock();
-                               for(int i = _rl; i < _ru; i++) {
-                                       double[] xvals = x.values(i), doutvals 
= dout.values(i), cvals = out.values(i);
-                                       int xpos = x.pos(i), doutpos = 
dout.pos(i), cpos = out.pos(i);
-                                       for(int j=0; j<n; j++)
-                                               cvals[cpos+j] = xvals[xpos+j] > 
0 ? doutvals[doutpos +j] : 0;
-                               }
-                       }
-                       else {
-                               scalarOperations(_params.input1, out, n, _rl, 
_ru, GT0);      // (X > 0)
-                               binaryOperationInPlaceMult(_params.input2, out, 
n, _rl, _ru); // (X > 0) * dout
-                       }
-                       return 0L;
+                       MatrixBlock m1 = _params.input1;
+                       MatrixBlock m2 = _params.input2;
+                       MatrixBlock out = _params.output;
+                       final int n = m1.getNumColumns();
+                       if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) )
+                               return 0L; //nothing to do
+                       
+                       //compute c = (a > 0) * b
+                       //(if there is at least one sparse input, the output is 
allocated in sparse)
+                       if(!m1.isInSparseFormat() && !m2.isInSparseFormat())
+                               reluBackwardDenseDense(m1.getDenseBlock(), 
m2.getDenseBlock(), out.getDenseBlock(), n, _rl, _ru);
+                       else if(!m1.isInSparseFormat() && m2.isInSparseFormat())
+                               reluBackwardDenseSparse(m1.getDenseBlock(), 
m2.getSparseBlock(), out.getSparseBlock(), _rl, _ru);
+                       else if(m1.isInSparseFormat() && !m2.isInSparseFormat())
+                               reluBackwardSparseDense(m1.getSparseBlock(), 
m2.getDenseBlock(), out.getSparseBlock(), _rl, _ru);
+                       else //sparse-sparse
+                               reluBackwardSparseDense(m1.getSparseBlock(), 
m2.getSparseBlock(), out.getSparseBlock(), _rl, _ru);
+                       
+                       //thread-local nnz maintenance
+                       return out.recomputeNonZeros(_rl, _ru-1);
                }
        }
        
-       private static void scalarOperations(MatrixBlock src, DenseBlock c,
-                       int destNumCols, int src_rl, int src_ru, ScalarOperator 
op)
-               throws DMLRuntimeException
-       {
-               if(src.isInSparseFormat()) {
-                       for(int i = src_rl; i < src_ru; i++) {
-                               if( src.getSparseBlock().isEmpty(i) ) continue;
-                               int apos = src.getSparseBlock().pos(i);
-                               int alen = src.getSparseBlock().size(i);
-                               int[] aix = src.getSparseBlock().indexes(i);
-                               double[] avals = src.getSparseBlock().values(i);
-                               double[] cvals = c.values(i);
-                               int cix = c.pos(i);
-                               for(int j = apos; j < apos+alen; j++)
-                                       cvals[ cix+aix[j] ] = 
op.executeScalar(avals[j]);
-                       }
-               }
-               else {
-                       DenseBlock a = src.getDenseBlock();
-                       for(int i = src_rl; i < src_ru; i++) {
-                               double[] avals = a.values(i), cvals = 
c.values(i);
-                               int aix = a.pos(i), cix = c.pos(i);
-                               for(int j=0; j<destNumCols; j++)
-                                       cvals[cix+j] = 
op.executeScalar(avals[aix+j]);
-                       }
+       private static void reluBackwardDenseDense(DenseBlock a, DenseBlock b, 
DenseBlock c, int n, int rl, int ru) {
+               for(int i = rl; i < ru; i++) {
+                       double[] avals = a.values(i), bvals = b.values(i);
+                       double[] cvals = c.values(i);
+                       int ix = a.pos(i);
+                       for(int j=0; j<n; j++)
+                               cvals[ix+j] = (avals[ix+j] > 0) ? bvals[ix +j] 
: 0;
                }
        }
        
-       private static void binaryOperationInPlaceMult(MatrixBlock src,
-               DenseBlock c, int destNumCols, int src_rl, int src_ru)
-               throws DMLRuntimeException 
-       {
-               if( src.isEmptyBlock(false) ) {
-                       c.set(src_rl, src_rl, 0, destNumCols, 0);
-                       return;
+       private static void reluBackwardDenseSparse(DenseBlock a, SparseBlock 
b, SparseBlock c, int rl, int ru) {
+               for(int i = rl; i < ru; i++) {
+                       if( b.isEmpty(i) ) continue;
+                       int bpos = b.pos(i), blen = b.size(i);
+                       int[] bix = b.indexes(i);
+                       double[] bvals = b.values(i), avals = a.values(i);
+                       int aix = a.pos(i);
+                       c.allocate(i, blen);
+                       for(int k=bpos; k<bpos+blen; k++)
+                               c.append(i, bix[k], (avals[aix+bix[k]] > 0) ? 
bvals[k] : 0);
                }
-               
-               if(src.isInSparseFormat()) {
-                       for(int i = src_rl; i < src_ru; i++) {
-                               if( !src.getSparseBlock().isEmpty(i) ) {
-                                       int apos = src.getSparseBlock().pos(i);
-                                       int alen = src.getSparseBlock().size(i);
-                                       int[] aix = 
src.getSparseBlock().indexes(i);
-                                       double[] avals = 
src.getSparseBlock().values(i);
-                                       double[] cvals = c.values(i);
-                                       int cix = c.pos(i);
-                                       int prevDestIndex = 0;
-                                       for(int j = apos; j < apos+alen; j++) {
-                                               c.set(i, i+1, prevDestIndex, 
aix[j], 0);
-                                               prevDestIndex = aix[j]+1;
-                                               cvals[ cix+aix[j] ] *= avals[j];
-                                       }
-                                       c.set(i, i+1, prevDestIndex, 
destNumCols, 0);
-                               }
-                               else {
-                                       c.set(i, i+1, 0, destNumCols, 0);
-                               }
-                       }
+       }
+       
+       private static void reluBackwardSparseDense(SparseBlock a, DenseBlock 
b, SparseBlock c, int rl, int ru) {
+               for(int i = rl; i < ru; i++) {
+                       if( a.isEmpty(i) ) continue;
+                       int apos = a.pos(i), alen = a.size(i);
+                       int[] aix = a.indexes(i);
+                       double[] avals = a.values(i), bvals = b.values(i);
+                       int bix = b.pos(i);
+                       c.allocate(i, alen);
+                       for(int k=apos; k<apos+alen; k++)
+                               c.append(i, aix[k], (avals[k] > 0) ? 
bvals[bix+aix[k]] : 0);
                }
-               else { //DENSE
-                       DenseBlock a = src.getDenseBlock();
-                       for(int i = src_rl; i < src_ru; i++) {
-                               double[] avals = a.values(i), cvals = 
c.values(i);
-                               int aix = a.pos(i), cix = c.pos(i);
-                               for(int j=0; j<destNumCols; j++)
-                                       cvals[cix+j] *= avals[aix+j];
-                       }
+       }
+       
+       private static void reluBackwardSparseDense(SparseBlock a, SparseBlock 
b, SparseBlock c, int rl, int ru) {
+               //b is the driver as it has likely less non-zeros
+               for(int i = rl; i < ru; i++) {
+                       if( a.isEmpty(i) || b.isEmpty(i) ) continue;
+                       int bpos = b.pos(i), blen = b.size(i);
+                       int[] bix = b.indexes(i);
+                       double[] bvals = b.values(i);
+                       c.allocate(i, blen);
+                       for(int k=bpos; k<bpos+blen; k++)
+                               c.append(i, bix[k], (a.get(i, bix[k]) > 0) ? 
bvals[k] : 0);
                }
        }
 }

Reply via email to