Repository: systemml
Updated Branches:
  refs/heads/master 4d7606574 -> 3c7a2fc96
[SYSTEMML-2034] Performance sparse-dense maxpooling_backward

The existing sparse-dense maxpooling backward implementation scanned each
sparse row for every channel and every p/q output position, which caused
significant overhead on lenet over mnist. This patch addresses the issue by
introducing a "real" sparse-dense implementation that uses small auxiliary
data structures and a single, sequential scan over the sparse input data.
Additionally, the nnz counts are now maintained in a thread-local manner to
improve utilization in multi-threaded environments. There is additional
potential to exploit overlapping zeros in gaps, but since this is not a
major bottleneck, we can postpone it to the next release.

On lenet over mnist60k w/ C=1, epochs=1, Hin=28, Win=28, Hf=5, Wf=5,
stride=1, pad=2, F1=32, F2=64, N3=512, this patch improved end-to-end
performance from 1,355s (1,098s maxpooling backward) to 259s (24s
maxpooling backward).

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3c7a2fc9
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3c7a2fc9
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3c7a2fc9

Branch: refs/heads/master
Commit: 3c7a2fc96069e99332a21ecfa00907c99486e5a5
Parents: 4d76065
Author: Matthias Boehm <[email protected]>
Authored: Sat Dec 2 00:15:43 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Sat Dec 2 01:45:57 2017 -0800

----------------------------------------------------------------------
 .../sysml/runtime/matrix/data/LibMatrixDNN.java |  13 +-
 .../runtime/matrix/data/LibMatrixDNNHelper.java |  25 ++-
 .../data/LibMatrixDNNPoolingBackwardHelper.java | 167 ++++++++++++++-----
 3 files changed, 142 insertions(+), 63 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/3c7a2fc9/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 67d4a1a..0e4a468 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -313,23 +313,22 @@ public class LibMatrixDNN {
 		}
 		
 		if(DMLScript.FINEGRAINED_STATISTICS) {
-			if(input.isInSparseFormat() || dout.isInSparseFormat()) {
+			if(input.isInSparseFormat() || dout.isInSparseFormat())
 				maxPoolBwdSparseCount.addAndGet(1);
-			}
-			else {
+			else
 				maxPoolBwdDenseCount.addAndGet(1);
-			}
 		}
 		
 		if (params.output.isInSparseFormat())
 			throw new DMLRuntimeException("Sparse maxpooling_backward is not supported");
 		
-		fillIndexesArray(params);
+		if( !(params.input1.isInSparseFormat() && !params.input2.isInSparseFormat()) )
+			fillIndexesArray(params); //not needed for sparse-dense
 		
-		execute(LibMatrixDNNHelper.getMaxPoolingBackwardWorkers(params, performReluBackward), params);
+		long nnz = execute(LibMatrixDNNHelper.getMaxPoolingBackwardWorkers(params, performReluBackward), params);
 		
 		//post-processing: maintain nnz
-		outputBlock.recomputeNonZeros();
+		outputBlock.setNonZeros(nnz);
 		outputBlock.examSparsity();
 	}
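
Note: the nnz handoff above works because each pooling-backward worker
returns the nonzero count of its row range, so the partial counts can be
summed instead of recomputing nonzeros over the whole output. A minimal
standalone sketch of this aggregation pattern (the ExecutorService wrapper
below is hypothetical; the actual execute(...) helper is not shown in this
diff):

import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class NnzAggregationSketch {
	//runs all workers and sums their thread-local nnz results, mirroring
	//long nnz = execute(...); outputBlock.setNonZeros(nnz); above
	public static long executeAndSumNnz(List<Callable<Long>> workers, int k)
		throws InterruptedException, ExecutionException
	{
		ExecutorService pool = Executors.newFixedThreadPool(k);
		try {
			long nnz = 0;
			for( Future<Long> f : pool.invokeAll(workers) ) //blocks until all done
				nnz += f.get();
			return nnz;
		}
		finally {
			pool.shutdown();
		}
	}
}
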
http://git-wip-us.apache.org/repos/asf/systemml/blob/3c7a2fc9/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
index 7b749dd..ee21ce3 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
@@ -75,19 +75,17 @@ public class LibMatrixDNNHelper {
 		ArrayList<Callable<Long>> ret = new ArrayList<>();
 		int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
 		int taskSize = (int)(Math.ceil((double)params.N / k));
+		boolean sparse1 = params.input1.isInSparseFormat();
+		boolean sparse2 = params.input2.isInSparseFormat();
 		for(int i = 0; i*taskSize < params.N; i++) {
-			if(!params.input1.isInSparseFormat()) {
-				if(!params.input2.isInSparseFormat())
-					ret.add(new PoolingBackwardDenseDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
-				else
-					ret.add(new PoolingBackwardDenseSparse(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
-			}
-			else {
-				if(!params.input2.isInSparseFormat())
-					ret.add(new PoolingBackwardSparseDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
-				else
-					ret.add(new PoolingBackwardSparseSparse(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
-			}
+			if( !sparse1 && !sparse2 )
+				ret.add(new PoolingBackwardDenseDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
+			else if( !sparse1 && sparse2 )
+				ret.add(new PoolingBackwardDenseSparse(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
+			else if( sparse1 && !sparse2 )
+				ret.add(new PoolingBackwardSparseDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
+			else if( sparse1 && sparse2 )
+				ret.add(new PoolingBackwardSparseSparse(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
 		}
 		return ret;
 	}
@@ -417,9 +415,6 @@ public class LibMatrixDNNHelper {
 	 * @throws DMLRuntimeException if error occurs
 	 */
 	static int getMaxIndexSparse(int p, int q, int inputOffset, int n, int c, MatrixBlock input, ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
-		if(!input.isInSparseFormat())
-			throw new DMLRuntimeException("Incorrect usage: Only sparse format supported");
-		
 		int [] tensorIndexes = new int[3];
 		
 		int start_h = params.start_indexes_h[p];
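
Note: the refactored dispatch above hands each worker a contiguous row range
of size taskSize = ceil(N/k). A tiny standalone illustration of the resulting
ranges (class name and values are illustrative only, not from the codebase):

public class TaskPartitionSketch {
	public static void main(String[] args) {
		int N = 10, k = 4; //example row count and thread count
		int taskSize = (int) Math.ceil((double) N / k);
		//prints [0,3) [3,6) [6,9) [9,10): all N rows covered, no overlap
		for( int i = 0; i*taskSize < N; i++ )
			System.out.println("[" + (i*taskSize) + "," + Math.min((i+1)*taskSize, N) + ")");
	}
}
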
http://git-wip-us.apache.org/repos/asf/systemml/blob/3c7a2fc9/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
index 5b04e59..4d8319b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
@@ -18,6 +18,7 @@
  */
 package org.apache.sysml.runtime.matrix.data;
 
+import java.util.Arrays;
 import java.util.concurrent.Callable;
 
 /**
@@ -31,8 +32,9 @@ public class LibMatrixDNNPoolingBackwardHelper {
 	{
 		public int _rl; public int _ru;
 		private final ConvolutionParameters _params;
-		double [] outputArray; boolean performReluBackward;
-		double [] inputArray; double [] doutArray;
+		boolean performReluBackward;
+		double [] inputArray, doutArray;
+		MatrixBlock output;
 		int C; int CHW; int P; int Q; int HW; int CPQ; int PQ;
 		public PoolingBackwardDenseDense(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
 			_rl = rl; _ru = ru;
@@ -40,16 +42,17 @@ public class LibMatrixDNNPoolingBackwardHelper {
 			this.performReluBackward = performReluBackward;
 			inputArray = params.input1.getDenseBlock();
 			doutArray = params.input2.getDenseBlock();
-			outputArray = params.output.getDenseBlock();
+			output = params.output;
 			C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
 			P = params.P; Q = params.Q; CPQ = params.C*params.P*params.Q;
 			PQ = params.P*params.Q;
-			if (inputArray == null || doutArray == null || outputArray == null )
+			if (inputArray == null || doutArray == null || output.getDenseBlock() == null )
 				throw new RuntimeException("Incorrect usage: empty inputs");
 		}
 		
 		@Override
 		public Long call() throws Exception {
+			double[] out = output.getDenseBlock();
 			for(int n = _rl; n < _ru; n++) {
 				for (int c = 0; c < C; c++) {
 					final int inputOffset = n*CHW + c*HW;
@@ -58,12 +61,13 @@ public class LibMatrixDNNPoolingBackwardHelper {
 					for (int q = 0; q < Q; q++) {
 						int maxIndex = LibMatrixDNNHelper.getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
 						if(maxIndex != -1)
-							outputArray[maxIndex] += doutArray[outputOffset + p * Q + q];
+							out[maxIndex] += doutArray[outputOffset + p * Q + q];
 					}
 				}
 			}
 		}
-		return 0L;
+		//thread-local nnz maintenance
+		return output.recomputeNonZeros(_rl, _ru-1);
 	}
 }
@@ -74,7 +78,8 @@ public class LibMatrixDNNPoolingBackwardHelper {
 	{
 		public int _rl; public int _ru;
 		private final ConvolutionParameters _params;
-		double [] outputArray; boolean performReluBackward;
+		MatrixBlock output;
+		boolean performReluBackward;
 		double [] inputArray; MatrixBlock dout;
 		int C; int CHW; int P; int Q; int HW;
 		public PoolingBackwardDenseSparse(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
@@ -83,10 +88,10 @@ public class LibMatrixDNNPoolingBackwardHelper {
 			this.performReluBackward = performReluBackward;
 			inputArray = params.input1.getDenseBlock();
 			dout = params.input2;
-			outputArray = params.output.getDenseBlock();
+			output = params.output;
 			C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
 			P = params.P; Q = params.Q;
-			if (inputArray == null || outputArray == null )
+			if (inputArray == null || output.getDenseBlock() == null )
 				throw new RuntimeException("Incorrect usage: empty inputs");
 			if (!params.input2.isInSparseFormat())
 				throw new RuntimeException("Incorrect usage: Call optimized versions");
@@ -94,6 +99,7 @@ public class LibMatrixDNNPoolingBackwardHelper {
 		
 		@Override
 		public Long call() throws Exception {
+			double[] out = output.getDenseBlock();
 			for(int n = _rl; n < _ru; n++) {
 				if( !dout.sparseBlock.isEmpty(n) ) {
 					int [] tensorIndexes = new int[3];
@@ -109,11 +115,12 @@ public class LibMatrixDNNPoolingBackwardHelper {
 						final int inputOffset = n*CHW + c*HW;
 						int maxIndex = LibMatrixDNNHelper.getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
 						if(maxIndex != -1)
-							outputArray[maxIndex] += avals[j];
+							out[maxIndex] += avals[j];
 					}
 				}
 			}
-			return 0L;
+			//thread-local nnz maintenance
+			return output.recomputeNonZeros(_rl, _ru-1);
 		}
 	}
@@ -122,45 +129,120 @@ public class LibMatrixDNNPoolingBackwardHelper {
 	 */
 	public static class PoolingBackwardSparseDense implements Callable<Long> 
 	{
-		public int _rl; public int _ru;
+		private final int _rl, _ru;
 		private final ConvolutionParameters _params;
-		double [] outputArray; boolean performReluBackward;
-		double [] doutArray;
-		int C; int CHW; int P; int Q; int HW; int CPQ; int PQ;
-		public PoolingBackwardSparseDense(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
-			_rl = rl; _ru = ru;
-			_params = params;
-			this.performReluBackward = performReluBackward;
-			doutArray = params.input2.getDenseBlock();
-			outputArray = params.output.getDenseBlock();
-			C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
-			P = params.P; Q = params.Q; CPQ = params.C*params.P*params.Q;
-			PQ = params.P*params.Q;
-			if (doutArray == null || outputArray == null )
+		private final double[] dout;
+		private final MatrixBlock output;
+		private final boolean reluBack;
+		
+		public PoolingBackwardSparseDense(int rl, int ru, ConvolutionParameters params, boolean relu) {
+			_rl = rl; _ru = ru; _params = params;
+			reluBack = relu;
+			dout = params.input2.getDenseBlock();
+			output = params.output;
+			if (dout == null || output.getDenseBlock() == null )
 				throw new RuntimeException("Incorrect usage: empty inputs");
 			if (!params.input1.isInSparseFormat())
-				throw new RuntimeException("Incorrect usage: Call optimized versions");
+				throw new RuntimeException("Incorrect usage: sparse input1 expected");
 		}
 		
 		@Override
-		public Long call() throws Exception {
+		public Long call() throws Exception 
+		{
+			SparseBlock sblock = _params.input1.getSparseBlock();
+			double[] out = output.getDenseBlock();
+			final int P = _params.P, Q = _params.Q, W = _params.W;
+			final int C = _params.C, R = _params.R, S = _params.S;
+			final int padh = _params.pad_h, padw = _params.pad_w;
+			final int strideh = _params.stride_h, stridew = _params.stride_w;
+			final int PQ = _params.P * _params.Q;
+			final int CPQ = _params.C * _params.P * _params.Q;
+			final int HW = _params.H * _params.W;
+			final int CHW = _params.C * _params.H * _params.W;
+			
+			//allocate auxiliary data structures
+			double[] maxVal = new double[PQ];
+			int[] maxIx = new int[PQ];
+			
 			for(int n = _rl; n < _ru; n++) {
 				for (int c = 0; c < C; c++) {
+					//step 0: basic initializations
 					final int doutOffset = n*CPQ + c*PQ;
-					final int inputOffset = n*CHW + c*HW;
-					for (int p = 0; p < P; p++) {
-						for (int q = 0; q < Q; q++) {
-							double inVal = doutArray[doutOffset + p * Q + q];
-							if(inVal != 0) {
-								int maxIndex = LibMatrixDNNHelper.getMaxIndexSparse(p, q, inputOffset, n, c, _params.input1, _params, performReluBackward);
-								if(maxIndex != -1)
-									outputArray[maxIndex] += inVal;
+					final int outOffset = n*CHW + c*HW;
+					
+					//step 1: perform maxpooling w/ index maintenance in a 
+					//single, sequential pass over the sparse input matrix
+					if( !sblock.isEmpty(n) ) {
+						Arrays.fill(maxVal, -Double.MAX_VALUE);
+						int apos = sblock.pos(n);
+						int alen = sblock.size(n);
+						int[] aix = sblock.indexes(n);
+						double[] avals = sblock.values(n);
+						//find channel start and end, w/ robustness for non-existing entries
+						int cpos = (c==0) ? 0 : sblock.posFIndexGTE(n, c*HW);
+						int cpos2 = (c+1==C) ? alen : sblock.posFIndexGTE(n, (c+1)*HW);
+						cpos = (cpos>=0) ? cpos : alen;
+						cpos2 = (cpos2>=0) ? cpos2 : alen;
+						int lastix = c*HW-1;
+						for(int j=apos+cpos; j<apos+cpos2; j++) {
+							//handle skipped zero values
+							update0(lastix+1, aix[j], maxVal, maxIx, padh, padw, strideh, stridew, P, Q, R, S, HW, W);
+							//handle current non-zero value
+							int h = (aix[j] % HW) / W;
+							int w = aix[j] % W;
+							double val = reluBack && avals[j] < 0 ? 0 : avals[j];
+							update(val, maxVal, maxIx, h, w, padh, padw, strideh, stridew, P, Q, R, S, W);
+							//memoize last seen index
+							lastix = aix[j];
+						}
+						//handle skipped zero values at end of row
+						update0(lastix+1, (c+1)*HW, maxVal, maxIx, padh, padw, strideh, stridew, P, Q, R, S, HW, W);
+					}
+					else {
+						//handle empty row
+						Arrays.fill(maxVal, 0);
+						for(int p = 0, ix=0; p < P; p++) {
+							int h = Math.max(-padh+p*strideh, 0);
+							for(int q = 0; q < Q; q++, ix++) {
+								int w = Math.max(-padw+q*stridew, 0);
+								maxIx[ix] = h * W + w;
 							}
 						}
 					}
+					
+					//step 2: perform maxpooling backward
+					for (int pq = 0; pq < PQ; pq++)
+						out[ outOffset + maxIx[pq] ] += dout[ doutOffset + pq ];
 				}
 			}
-			return 0L;
+			//thread-local nnz maintenance
+			return output.recomputeNonZeros(_rl, _ru-1);
+		}
+		
+		private static void update0(int lix, int uix, double[] maxVal, int[] maxIx, int padh, int padw, int strideh, int stridew, int P, int Q, int R, int S, int HW, int W) {
+			//TODO exploit constant value and overlap for potential early abort
+			for(int i = lix; i<uix; i++)
+				update(0, maxVal, maxIx, (i%HW)/W, i%W, padh, padw, strideh, stridew, P, Q, R, S, W);
+		}
+		
+		private static void update(double val, double[] maxVal, int[] maxIx, int h, int w, int padh, int padw, int strideh, int stridew, int P, int Q, int R, int S, int W) {
+			//determine lower and upper bounds for p and q
+			//(see fillIndexesArray, solved for p and q, reversed)
+			int lp = Math.max((h+padh-R+strideh)/strideh, 0);
+			int up = Math.min((h+padh+strideh)/strideh, P);
+			int lq = Math.max((w+padw-S+stridew)/stridew, 0);
+			int uq = Math.min((w+padw+stridew)/stridew, Q);
+			
+			//maintain max index for all relevant p and q
+			int maxIndex = h * W + w;
+			for(int p = lp; p < up; p++)
+				for(int q = lq; q < uq; q++) {
+					int ix = p * Q + q;
+					if( maxVal[ix] < val ) {
+						maxVal[ix] = val;
+						maxIx[ix] = maxIndex;
+					}
+				}
 		}
 	}
@@ -171,16 +253,17 @@ public class LibMatrixDNNPoolingBackwardHelper {
 	{
 		public int _rl; public int _ru;
 		private final ConvolutionParameters _params;
-		double [] outputArray; boolean performReluBackward;
+		MatrixBlock output;
+		boolean performReluBackward;
 		int C; int CHW; int P; int Q; int HW;
 		public PoolingBackwardSparseSparse(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
 			_rl = rl; _ru = ru;
 			_params = params;
 			this.performReluBackward = performReluBackward;
-			outputArray = params.output.getDenseBlock();
+			output = params.output;
 			C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
 			P = params.P; Q = params.Q;
-			if (outputArray == null )
+			if (output.getDenseBlock() == null )
 				throw new RuntimeException("Incorrect usage: empty outputs");
 			if (!params.input1.isInSparseFormat() || !params.input2.isInSparseFormat())
 				throw new RuntimeException("Incorrect usage: Call optimized versions");
@@ -188,6 +271,7 @@ public class LibMatrixDNNPoolingBackwardHelper {
 		
 		@Override
 		public Long call() throws Exception {
+			double[] out = output.getDenseBlock();
 			for(int n = _rl; n < _ru; n++) {
 				if( !_params.input2.sparseBlock.isEmpty(n) ) {
 					int [] tensorIndexes = new int[3];
@@ -203,11 +287,12 @@ public class LibMatrixDNNPoolingBackwardHelper {
 						final int inputOffset = n*CHW + c*HW;
 						int maxIndex = LibMatrixDNNHelper.getMaxIndexSparse(p, q, inputOffset, n, c, _params.input1, _params, performReluBackward);
 						if(maxIndex != -1)
-							outputArray[maxIndex] += avals[j];
+							out[maxIndex] += avals[j];
 					}
 				}
 			}
-			return 0L;
+			//thread-local nnz maintenance
+			return output.recomputeNonZeros(_rl, _ru-1);
 		}
 	}
 }
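
Note: the lp/up and lq/uq bounds in update() invert the forward pooling
index mapping: output row p covers input row h iff
p*stride_h - pad_h <= h <= p*stride_h - pad_h + R - 1, which solves to
ceil((h+pad_h-R+1)/stride_h) <= p <= floor((h+pad_h)/stride_h); with
non-negative integer division these are exactly (h+padh-R+strideh)/strideh
(inclusive) and (h+padh+strideh)/strideh (exclusive), as used above. A small
self-contained check of this derivation against a brute-force scan (the
configuration values below are illustrative, not from the patch):

public class PoolBoundsCheck {
	public static void main(String[] args) {
		int H = 28, P = 14, R = 2, stride = 2, pad = 0; //e.g. 2x2 max pooling
		for( int h = 0; h < H; h++ ) {
			//closed-form bounds as in update(): lp inclusive, up exclusive
			int lp = Math.max((h + pad - R + stride) / stride, 0);
			int up = Math.min((h + pad + stride) / stride, P);
			for( int p = 0; p < P; p++ ) {
				//brute force: does output row p's window cover input row h?
				boolean covered = p*stride - pad <= h && h <= p*stride - pad + R - 1;
				boolean inBounds = lp <= p && p < up;
				if( covered != inBounds )
					throw new AssertionError("mismatch at h=" + h + ", p=" + p);
			}
		}
		System.out.println("closed-form bounds match brute force");
	}
}
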
