Repository: systemml
Updated Branches:
  refs/heads/master f14255f46 -> a671883fc

[SYSTEMML-2129] Performance sparse relu_backward operations

This patch improves the performance of sparse relu_backward operations
(including the common dense-sparse case) by (1) allocating the output in
sparse representation if at least one input is sparse, (2) adding dedicated
implementations for all dense-sparse input combinations, and (3) maintaining
the number of non-zeros (nnz) thread-locally to avoid unnecessary passes
over the output.

On 1000 iterations of lenet over mnist, this patch reduced the relu_backward
runtime (for 2000 exec calls) from 24.9s to 5.6s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a671883f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a671883f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a671883f

Branch: refs/heads/master
Commit: a671883fc5cac2d81e0dcbd33c57abe1d45b593a
Parents: f14255f
Author: Matthias Boehm <[email protected]>
Authored: Sat Feb 3 20:08:22 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Sat Feb 3 20:45:51 2018 -0800

----------------------------------------------------------------------
 .../apache/sysml/api/ScriptExecutorUtils.java   |  15 +-
 .../cp/ConvolutionCPInstruction.java            |  10 +-
 .../sysml/runtime/matrix/data/LibMatrixDNN.java |   4 +-
 .../runtime/matrix/data/LibMatrixDNNRelu.java   | 146 ++++++++------
 4 files changed, 70 insertions(+), 105 deletions(-)
----------------------------------------------------------------------
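For orientation before the per-file diffs: relu_backward computes
c = (X > 0) * dout elementwise, so the result is zero wherever dout is zero.
The stand-alone sketch below (hypothetical names, not part of this commit)
shows this per row on raw CSR-style arrays, which is the core idea behind
both the dense-sparse kernels and the sparse output allocation:

    public class ReluBackwardRowSketch {
        // c = (x > 0) * dout for a single row: x and c are dense rows,
        // dout is given as CSR-style index/value arrays over [pos, pos+len).
        // Visiting only dout's non-zeros suffices because the product is
        // zero wherever dout is zero.
        public static void reluBackwardRow(double[] x, int[] doutIx,
            double[] doutVals, int pos, int len, double[] c) {
            for(int k = pos; k < pos+len; k++)
                c[doutIx[k]] = (x[doutIx[k]] > 0) ? doutVals[k] : 0;
        }
    }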
http://git-wip-us.apache.org/repos/asf/systemml/blob/a671883f/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
index 253a317..a6c276f 100644
--- a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
+++ b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
@@ -113,18 +113,9 @@ public class ScriptExecutorUtils {
 
 		// display statistics (incl caching stats if enabled)
 		Statistics.stopRunTimer();
-
-		if (!exceptionThrown) {
-			if (statisticsMaxHeavyHitters > 0)
-				System.out.println(Statistics.display(statisticsMaxHeavyHitters));
-			else
-				System.out.println(Statistics.display());
-		} else {
-			if (statisticsMaxHeavyHitters > 0)
-				System.err.println(Statistics.display(statisticsMaxHeavyHitters));
-			else
-				System.err.println(Statistics.display());
-		}
+		(exceptionThrown ? System.err : System.out)
+			.println(Statistics.display(statisticsMaxHeavyHitters > 0 ?
+				statisticsMaxHeavyHitters : DMLScript.STATISTICS_COUNT));
 	}
 }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/a671883f/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index 08d220e..e32c3cf 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -227,15 +227,13 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 			// (X > 0) * dout
 			MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
 			MatrixBlock dout = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
-			MatrixBlock outputBlock;
+			MatrixBlock outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(),
+				input.isInSparseFormat() || dout.isInSparseFormat() );
 			
-			if( !input.isEmpty() && !dout.isEmpty() ) {
-				outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), false);
-				outputBlock.allocateDenseBlock();
+			if( !input.isEmpty() && !dout.isEmpty() ) { //sparse-safe
+				outputBlock.allocateBlock();
 				LibMatrixDNN.reluBackward(input, dout, outputBlock, _numThreads);
 			}
-			else
-				outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), true);
 			
 			// release inputs/outputs
 			ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());

http://git-wip-us.apache.org/repos/asf/systemml/blob/a671883f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index e8a88d8..1ad56b2 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -273,10 +273,10 @@ public class LibMatrixDNN {
 				input.getNumRows() + " != " + dout.getNumRows() + " || " + input.getNumColumns() + " != " + dout.getNumColumns());
 		}
 		
-		execute(LibMatrixDNNRelu.getReluBackwardWorkers(params), params);
+		long nnz = execute(LibMatrixDNNRelu.getReluBackwardWorkers(params), params);
 		
 		// post-processing: maintain nnz
-		outputBlock.recomputeNonZeros();
+		outputBlock.setNonZeros(nnz);
 		outputBlock.examSparsity();
 	}
 
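Note that the change above relies on execute(...) returning the sum of the
per-worker results. A minimal sketch of that aggregation pattern
(hypothetical executor wiring, not the actual LibMatrixDNN helper) shows
how thread-local nnz counts make a final recomputeNonZeros() pass over the
whole output unnecessary:

    import java.util.List;
    import java.util.concurrent.Callable;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;

    public class ThreadLocalNnzSketch {
        // Each task processes its row range [rl, ru) and returns the nnz
        // it produced there; summing the partial counts yields the output
        // nnz without another pass over the output block.
        public static long execute(List<Callable<Long>> tasks, int numThreads)
            throws InterruptedException, ExecutionException {
            ExecutorService pool = Executors.newFixedThreadPool(numThreads);
            try {
                long nnz = 0;
                for(Future<Long> f : pool.invokeAll(tasks))
                    nnz += f.get();
                return nnz;
            }
            finally {
                pool.shutdown();
            }
        }
    }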
http://git-wip-us.apache.org/repos/asf/systemml/blob/a671883f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
index c8f2f41..c8f85b1 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
@@ -23,17 +23,12 @@ import java.util.concurrent.Callable;
 
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.instructions.InstructionUtils;
-import org.apache.sysml.runtime.matrix.operators.ScalarOperator;
 
 /**
  * This class contains the different implementation of rotate180 operation
 */
 public class LibMatrixDNNRelu
 {
-	private static ScalarOperator GT0 = InstructionUtils.parseScalarBinaryOperator(">", false, 0);
-	
-	
 	/**
 	 * Factory method that returns list of callable tasks for performing relu backward operation
 	 * 
@@ -64,94 +59,75 @@ public class LibMatrixDNNRelu
 		@Override
 		public Long call() throws Exception {
-			//note: X (m x n), dout (m x n) -> out (m x n)
-			DenseBlock out = _params.output.getDenseBlock();
-			final int n = _params.input1.getNumColumns();
-			if(!_params.input1.isInSparseFormat() && !_params.input2.isInSparseFormat()) {
-				DenseBlock x = _params.input1.getDenseBlock();
-				DenseBlock dout = _params.input2.getDenseBlock();
-				for(int i = _rl; i < _ru; i++) {
-					double[] xvals = x.values(i), doutvals = dout.values(i), cvals = out.values(i);
-					int xpos = x.pos(i), doutpos = dout.pos(i), cpos = out.pos(i);
-					for(int j=0; j<n; j++)
-						cvals[cpos+j] = xvals[xpos+j] > 0 ? doutvals[doutpos+j] : 0;
-				}
-			}
-			else {
-				scalarOperations(_params.input1, out, n, _rl, _ru, GT0); // (X > 0)
-				binaryOperationInPlaceMult(_params.input2, out, n, _rl, _ru); // (X > 0) * dout
-			}
-			return 0L;
+			MatrixBlock m1 = _params.input1;
+			MatrixBlock m2 = _params.input2;
+			MatrixBlock out = _params.output;
+			final int n = m1.getNumColumns();
+			if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) )
+				return 0L; //nothing to do
+			
+			//compute c = (a > 0) * b
+			//(if there is at least one sparse input, the output is allocated in sparse)
+			if(!m1.isInSparseFormat() && !m2.isInSparseFormat())
+				reluBackwardDenseDense(m1.getDenseBlock(), m2.getDenseBlock(), out.getDenseBlock(), n, _rl, _ru);
+			else if(!m1.isInSparseFormat() && m2.isInSparseFormat())
+				reluBackwardDenseSparse(m1.getDenseBlock(), m2.getSparseBlock(), out.getSparseBlock(), _rl, _ru);
+			else if(m1.isInSparseFormat() && !m2.isInSparseFormat())
+				reluBackwardSparseDense(m1.getSparseBlock(), m2.getDenseBlock(), out.getSparseBlock(), _rl, _ru);
+			else //sparse-sparse
+				reluBackwardSparseSparse(m1.getSparseBlock(), m2.getSparseBlock(), out.getSparseBlock(), _rl, _ru);
+			
+			//thread-local nnz maintenance
+			return out.recomputeNonZeros(_rl, _ru-1);
 		}
 	}
 	
-	private static void scalarOperations(MatrixBlock src, DenseBlock c,
-		int destNumCols, int src_rl, int src_ru, ScalarOperator op)
-		throws DMLRuntimeException
-	{
-		if(src.isInSparseFormat()) {
-			for(int i = src_rl; i < src_ru; i++) {
-				if( src.getSparseBlock().isEmpty(i) ) continue;
-				int apos = src.getSparseBlock().pos(i);
-				int alen = src.getSparseBlock().size(i);
-				int[] aix = src.getSparseBlock().indexes(i);
-				double[] avals = src.getSparseBlock().values(i);
-				double[] cvals = c.values(i);
-				int cix = c.pos(i);
-				for(int j = apos; j < apos+alen; j++)
-					cvals[ cix+aix[j] ] = op.executeScalar(avals[j]);
-			}
-		}
-		else {
-			DenseBlock a = src.getDenseBlock();
-			for(int i = src_rl; i < src_ru; i++) {
-				double[] avals = a.values(i), cvals = c.values(i);
-				int aix = a.pos(i), cix = c.pos(i);
-				for(int j=0; j<destNumCols; j++)
-					cvals[cix+j] = op.executeScalar(avals[aix+j]);
-			}
+	private static void reluBackwardDenseDense(DenseBlock a, DenseBlock b, DenseBlock c, int n, int rl, int ru) {
+		for(int i = rl; i < ru; i++) {
+			double[] avals = a.values(i), bvals = b.values(i);
+			double[] cvals = c.values(i);
+			int ix = a.pos(i);
+			for(int j=0; j<n; j++)
+				cvals[ix+j] = (avals[ix+j] > 0) ? bvals[ix+j] : 0;
 		}
 	}
 	
-	private static void binaryOperationInPlaceMult(MatrixBlock src,
-		DenseBlock c, int destNumCols, int src_rl, int src_ru)
-		throws DMLRuntimeException
-	{
-		if( src.isEmptyBlock(false) ) {
-			c.set(src_rl, src_rl, 0, destNumCols, 0);
-			return;
+	private static void reluBackwardDenseSparse(DenseBlock a, SparseBlock b, SparseBlock c, int rl, int ru) {
+		for(int i = rl; i < ru; i++) {
+			if( b.isEmpty(i) ) continue;
+			int bpos = b.pos(i), blen = b.size(i);
+			int[] bix = b.indexes(i);
+			double[] bvals = b.values(i), avals = a.values(i);
+			int aix = a.pos(i);
+			c.allocate(i, blen);
+			for(int k=bpos; k<bpos+blen; k++)
+				c.append(i, bix[k], (avals[aix+bix[k]] > 0) ? bvals[k] : 0);
 		}
-		
-		if(src.isInSparseFormat()) {
-			for(int i = src_rl; i < src_ru; i++) {
-				if( !src.getSparseBlock().isEmpty(i) ) {
-					int apos = src.getSparseBlock().pos(i);
-					int alen = src.getSparseBlock().size(i);
-					int[] aix = src.getSparseBlock().indexes(i);
-					double[] avals = src.getSparseBlock().values(i);
-					double[] cvals = c.values(i);
-					int cix = c.pos(i);
-					int prevDestIndex = 0;
-					for(int j = apos; j < apos+alen; j++) {
-						c.set(i, i+1, prevDestIndex, aix[j], 0);
-						prevDestIndex = aix[j]+1;
-						cvals[ cix+aix[j] ] *= avals[j];
-					}
-					c.set(i, i+1, prevDestIndex, destNumCols, 0);
-				}
-				else {
-					c.set(i, i+1, 0, destNumCols, 0);
-				}
-			}
+	}
+	
+	private static void reluBackwardSparseDense(SparseBlock a, DenseBlock b, SparseBlock c, int rl, int ru) {
+		for(int i = rl; i < ru; i++) {
+			if( a.isEmpty(i) ) continue;
+			int apos = a.pos(i), alen = a.size(i);
+			int[] aix = a.indexes(i);
+			double[] avals = a.values(i), bvals = b.values(i);
+			int bix = b.pos(i);
+			c.allocate(i, alen);
+			for(int k=apos; k<apos+alen; k++)
+				c.append(i, aix[k], (avals[k] > 0) ? bvals[bix+aix[k]] : 0);
 		}
-		else { //DENSE
-			DenseBlock a = src.getDenseBlock();
-			for(int i = src_rl; i < src_ru; i++) {
-				double[] avals = a.values(i), cvals = c.values(i);
-				int aix = a.pos(i), cix = c.pos(i);
-				for(int j=0; j<destNumCols; j++)
-					cvals[cix+j] *= avals[aix+j];
-			}
+	}
+	
+	private static void reluBackwardSparseSparse(SparseBlock a, SparseBlock b, SparseBlock c, int rl, int ru) {
+		//b is the driver as it has likely less non-zeros
+		for(int i = rl; i < ru; i++) {
+			if( a.isEmpty(i) || b.isEmpty(i) ) continue;
+			int bpos = b.pos(i), blen = b.size(i);
+			int[] bix = b.indexes(i);
+			double[] bvals = b.values(i);
+			c.allocate(i, blen);
+			for(int k=bpos; k<bpos+blen; k++)
+				c.append(i, bix[k], (a.get(i, bix[k]) > 0) ? bvals[k] : 0);
 		}
 	}
 }
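Taken together, a caller now follows the pattern below (a sketch mirroring
the ConvolutionCPInstruction change above; the wrapper method itself is
hypothetical). The output representation is decided up front from the
inputs, and examSparsity() inside reluBackward can still convert the result
if that initial choice turns out to be a poor fit:

    import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
    import org.apache.sysml.runtime.matrix.data.MatrixBlock;

    public class ReluBackwardUsageSketch {
        public static MatrixBlock reluBackward(MatrixBlock X, MatrixBlock dout,
            int numThreads) throws Exception {
            // allocate the output sparse if at least one input is sparse
            MatrixBlock out = new MatrixBlock(X.getNumRows(), X.getNumColumns(),
                X.isInSparseFormat() || dout.isInSparseFormat());
            if( !X.isEmpty() && !dout.isEmpty() ) { //sparse-safe: empty inputs keep out empty
                out.allocateBlock();
                LibMatrixDNN.reluBackward(X, dout, out, numThreads);
            }
            return out;
        }
    }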
