[SYSTEMML-552] Cache-conscious sparse-dense wsloss (post_nz pattern) Similar to cache-conscious sparse-dense wdivmm, wsloss also showed room for performance improvements with regard to large factors. This patch introduces cache-conscious sparse-dense operations for the pattern post_nz, e.g., sum (ppred(X,0, "!=") * (U %*% t(V) - X) ^ 2). On a scenario with 100k x 100k, sp=0.01, this change led to the following improvements: for rank=50, 3.6s -> 1.4s; rank=10, 650ms -> 520ms.
Furthermore, this patch also makes various cleanups with regard to multi-threaded operations: (1) error checking via futures for wsloss and wcemm, and (2) type safe result/nnz aggregation for wsloss, wcemm, and wdivmm. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/cfc561ee Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/cfc561ee Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/cfc561ee Branch: refs/heads/master Commit: cfc561eecc8dbc78a943532e19574b02600637c0 Parents: 63a56b0 Author: Matthias Boehm <[email protected]> Authored: Tue Mar 8 20:27:17 2016 -0800 Committer: Matthias Boehm <[email protected]> Committed: Tue Mar 8 20:27:17 2016 -0800 ---------------------------------------------------------------------- .../runtime/matrix/data/LibMatrixMult.java | 102 ++++++++++--------- 1 file changed, 53 insertions(+), 49 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cfc561ee/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java index d18f13d..d56b018 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -584,16 +585,16 @@ public class LibMatrixMult try { ExecutorService pool = 
Executors.newFixedThreadPool(k); - ArrayList<ScalarResultTask> tasks = new ArrayList<ScalarResultTask>(); + ArrayList<MatrixMultWSLossTask> tasks = new ArrayList<MatrixMultWSLossTask>(); int blklen = (int)(Math.ceil((double)mX.rlen/k)); for( int i=0; i<k & i*blklen<mX.rlen; i++ ) tasks.add(new MatrixMultWSLossTask(mX, mU, mV, mW, wt, i*blklen, Math.min((i+1)*blklen, mX.rlen))); - pool.invokeAll(tasks); + List<Future<Double>> taskret = pool.invokeAll(tasks); pool.shutdown(); //aggregate partial results - sumScalarResults(tasks, ret); + sumScalarResults(taskret, ret); } - catch (InterruptedException e) { + catch( Exception e ) { throw new DMLRuntimeException(e); } @@ -795,12 +796,12 @@ public class LibMatrixMult tasks.add(new MatrixMultWDivTask(mW, mU, mV, mX, ret, wt, i*blklen, Math.min((i+1)*blklen, mW.rlen), 0, mW.clen)); } //execute tasks - List<Future<Object>> taskret = pool.invokeAll(tasks); + List<Future<Long>> taskret = pool.invokeAll(tasks); pool.shutdown(); //aggregate partial nnz and check for errors ret.nonZeros = 0; //reset after execute - for( Future<Object> task : taskret ) - ret.nonZeros += (Long)task.get(); + for( Future<Long> task : taskret ) + ret.nonZeros += task.get(); } catch (Exception e) { throw new DMLRuntimeException(e); @@ -877,16 +878,16 @@ public class LibMatrixMult try { ExecutorService pool = Executors.newFixedThreadPool(k); - ArrayList<ScalarResultTask> tasks = new ArrayList<ScalarResultTask>(); + ArrayList<MatrixMultWCeTask> tasks = new ArrayList<MatrixMultWCeTask>(); int blklen = (int)(Math.ceil((double)mW.rlen/k)); for( int i=0; i<k & i*blklen<mW.rlen; i++ ) tasks.add(new MatrixMultWCeTask(mW, mU, mV, wt, i*blklen, Math.min((i+1)*blklen, mW.rlen))); - pool.invokeAll(tasks); + List<Future<Double>> taskret = pool.invokeAll(tasks); pool.shutdown(); //aggregate partial results - sumScalarResults(tasks, ret); + sumScalarResults(taskret, ret); } - catch (InterruptedException e) { + catch( Exception e ) { throw new DMLRuntimeException(e); 
} @@ -2230,17 +2231,34 @@ public class LibMatrixMult else if( wt==WeightsType.POST_NZ ) { // approach: iterate over W, point-wise in order to exploit sparsity - for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd ) - if( !x.isEmpty(i) ) { - int xpos = x.pos(i); - int xlen = x.size(i); - int[] xix = x.indexes(i); - double[] xval = x.values(i); - for( int k=xpos; k<xpos+xlen; k++ ) { - double uvij = dotProduct(u, v, uix, xix[k]*cd, cd); - wsloss += (xval[k]-uvij)*(xval[k]-uvij); + // blocked over ij, while maintaining front of column indexes, where the + // blocksize is chosen such that we reuse each vector on average 8 times. + final int blocksizeIJ = (int) (8L*mX.rlen*mX.clen/mX.nonZeros); + int[] curk = new int[blocksizeIJ]; + + for( int bi=rl; bi<ru; bi+=blocksizeIJ ) { + int bimin = Math.min(ru, bi+blocksizeIJ); + //prepare starting indexes for block row + Arrays.fill(curk, 0); + //blocked execution over column blocks + for( int bj=0; bj<n; bj+=blocksizeIJ ) { + int bjmin = Math.min(n, bj+blocksizeIJ); + for( int i=bi, uix=bi*cd; i<bimin; i++, uix+=cd ) { + if( !x.isEmpty(i) ) { + int xpos = x.pos(i); + int xlen = x.size(i); + int[] xix = x.indexes(i); + double[] xval = x.values(i); + int k = xpos + curk[i-bi]; + for( ; k<xpos+xlen && xix[k]<bjmin; k++ ) { + double uvij = dotProduct(u, v, uix, xix[k]*cd, cd); + wsloss += (xval[k]-uvij)*(xval[k]-uvij); + } + curk[i-bi] = k - xpos; + } } - } + } + } } // Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting) else if( wt==WeightsType.PRE ) @@ -2650,7 +2668,7 @@ public class LibMatrixMult SparseBlock x = (mX==null) ? null : mX.sparseBlock; //approach: iterate over non-zeros of w, selective mm computation - //blocked over ij, while maintaining font of column indexes, where the + //blocked over ij, while maintaining front of column indexes, where the //blocksize is chosen such that we reuse each vector on average 8 times. 
final int blocksizeIJ = (int) (8L*mW.rlen*mW.clen/mW.nonZeros); int[] curk = new int[blocksizeIJ]; @@ -3985,13 +4003,16 @@ public class LibMatrixMult * * @param tasks * @param ret + * @throws ExecutionException + * @throws InterruptedException */ - private static void sumScalarResults(ArrayList<ScalarResultTask> tasks, MatrixBlock ret) + private static void sumScalarResults(List<Future<Double>> tasks, MatrixBlock ret) + throws InterruptedException, ExecutionException { - //aggregate partial results + //aggregate partial results and check for errors double val = 0; - for(ScalarResultTask task : tasks) - val += task.getScalarResult(); + for(Future<Double> task : tasks) + val += task.get(); ret.quickSetValue(0, 0, val); } @@ -4207,19 +4228,12 @@ public class LibMatrixMult return null; } } - - /** - * - */ - private static interface ScalarResultTask extends Callable<Object>{ - public double getScalarResult(); - } /** * * */ - private static class MatrixMultWSLossTask implements ScalarResultTask + private static class MatrixMultWSLossTask implements Callable<Double> { private MatrixBlock _mX = null; private MatrixBlock _mU = null; @@ -4247,7 +4261,7 @@ public class LibMatrixMult } @Override - public Object call() throws DMLRuntimeException + public Double call() throws DMLRuntimeException { if( !_mX.sparse && !_mU.sparse && !_mV.sparse && (_mW==null || !_mW.sparse) && !_mX.isEmptyBlock() && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() @@ -4260,11 +4274,6 @@ public class LibMatrixMult else matrixMultWSLossGeneric(_mX, _mU, _mV, _mW, _ret, _wt, _rl, _ru); - return null; - } - - @Override - public double getScalarResult() { return _ret.quickGetValue(0, 0); } } @@ -4322,7 +4331,7 @@ public class LibMatrixMult * * */ - private static class MatrixMultWDivTask implements Callable<Object> + private static class MatrixMultWDivTask implements Callable<Long> { private MatrixBlock _mW = null; private MatrixBlock _mU = null; @@ -4351,7 +4360,7 @@ public class LibMatrixMult } @Override 
- public Object call() throws DMLRuntimeException + public Long call() throws DMLRuntimeException { //core weighted div mm computation boolean scalarX = _wt.hasScalar(); @@ -4369,7 +4378,7 @@ public class LibMatrixMult } } - private static class MatrixMultWCeTask implements ScalarResultTask + private static class MatrixMultWCeTask implements Callable<Double> { private MatrixBlock _mW = null; private MatrixBlock _mU = null; @@ -4395,7 +4404,7 @@ public class LibMatrixMult } @Override - public Object call() throws DMLRuntimeException + public Double call() throws DMLRuntimeException { //core weighted div mm computation if( !_mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() ) @@ -4406,11 +4415,6 @@ public class LibMatrixMult matrixMultWCeMMGeneric(_mW, _mU, _mV, _ret, _wt, _rl, _ru); - return null; - } - - @Override - public double getScalarResult() { return _ret.quickGetValue(0, 0); } }
