Repository: systemml Updated Branches: refs/heads/master 426d7fa0d -> a97bc53f7
[SYSTEMML-1955] Memory-efficient ultra-sparse matrix multiplication So far, the code path for ultra-sparse matrix multiply was only used if either of the two inputs was classified as ultra-sparse in terms of its sparsity and absolute number of non-zeros. For special cases of permutation matrix multiply with large ultra-sparse matrices, this is not the case, leading to sparse-dense matrix multiplication, which allocates the output in dense format. This patch improves the code path selection for ultra-sparse matrix multiplication to include such permutation matrix multiplies; this code path allocates the output in sparse format and thus is much more memory efficient. Additionally, the patch improves performance by avoiding repeated reallocations. For example, on a scenario of 58825360 x 2519371 (nnz=58825360) times 2519371 x 300 (nnz=755810998), this patch enables the successful execution in CP (previously, it crashed due to allocating a dense output larger than 16GB). By avoiding reallocations, this patch further improved the execution time from 112s to 24s due to significantly reduced GC overhead. 
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0b5480b7 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0b5480b7 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0b5480b7 Branch: refs/heads/master Commit: 0b5480b77de3391bec0ee1e60ab7d71de80e9605 Parents: 426d7fa Author: Matthias Boehm <[email protected]> Authored: Wed Oct 11 14:20:00 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Oct 12 01:13:05 2017 -0700 ---------------------------------------------------------------------- .../runtime/matrix/data/LibMatrixMult.java | 88 +++++++++++--------- .../sysml/runtime/matrix/data/MatrixBlock.java | 22 +++-- 2 files changed, 66 insertions(+), 44 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/0b5480b7/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java index 0885508..aedf975 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java @@ -29,6 +29,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.apache.commons.math3.util.FastMath; +import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.lops.MapMultChain.ChainType; import org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType; import org.apache.sysml.lops.WeightedDivMM.WDivMMType; @@ -111,19 +112,20 @@ public class LibMatrixMult public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, boolean examSparsity) throws DMLRuntimeException - { + { //check inputs / outputs if( 
m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) { ret.examSparsity(); //turn empty dense into sparse return; } - + //Timing time = new Timing(true); - + //pre-processing: output allocation + boolean ultraSparse = isUltraSparseMatrixMult(m1, m2); boolean tm2 = checkPrepMatrixMultRightInput(m1,m2); m2 = prepMatrixMultRightInput(m1, m2); - ret.sparse = (m1.isUltraSparse() || m2.isUltraSparse()); + ret.sparse = ultraSparse; if( !ret.sparse ) ret.allocateDenseBlock(); @@ -133,7 +135,7 @@ public class LibMatrixMult int cu = m2.clen; //core matrix mult computation - if( m1.isUltraSparse() || m2.isUltraSparse() ) + if( ultraSparse ) matrixMultUltraSparse(m1, m2, ret, 0, ru2); else if(!m1.sparse && !m2.sparse) matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, cu); @@ -147,13 +149,11 @@ public class LibMatrixMult //post-processing: nnz/representation if( !ret.sparse ) ret.recomputeNonZeros(); - if(examSparsity) ret.examSparsity(); - //System.out.println("MM ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" + - // "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop()); + // "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop()); } /** @@ -168,13 +168,13 @@ public class LibMatrixMult */ public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) throws DMLRuntimeException - { + { //check inputs / outputs if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) { ret.examSparsity(); //turn empty dense into sparse return; } - + //check too high additional vector-matrix memory requirements (fallback to sequential) //check too small workload in terms of flops (fallback to sequential too) if( m1.rlen == 1 && (8L * m2.clen * k > MEM_OVERHEAD_THRESHOLD || !LOW_LEVEL_OPTIMIZATION || m2.clen==1 || m1.isUltraSparse() || m2.isUltraSparse()) @@ -188,15 +188,13 @@ public class LibMatrixMult 
//pre-processing: output allocation (in contrast to single-threaded, //we need to allocate sparse as well in order to prevent synchronization) + boolean ultraSparse = isUltraSparseMatrixMult(m1, m2); boolean tm2 = checkPrepMatrixMultRightInput(m1,m2); m2 = prepMatrixMultRightInput(m1, m2); - ret.sparse = (m1.isUltraSparse() || m2.isUltraSparse()); - if( !ret.sparse ) - ret.allocateDenseBlock(); - else - ret.allocateSparseRowsBlock(); + ret.sparse = ultraSparse; + ret.allocateBlock(); - if (!ret.isThreadSafe()){ + if (!ret.isThreadSafe()) { matrixMult(m1, m2, ret); return; } @@ -216,7 +214,7 @@ public class LibMatrixMult for( int i=0, lb=0; i<blklens.size(); lb+=blklens.get(i), i++ ) tasks.add(new MatrixMultTask(m1, m2, ret, tm2, pm2r, pm2c, lb, lb+blklens.get(i))); //execute tasks - List<Future<Object>> taskret = pool.invokeAll(tasks); + List<Future<Object>> taskret = pool.invokeAll(tasks); pool.shutdown(); //aggregate partial results (nnz, ret for vector/matrix) ret.nonZeros = 0; //reset after execute @@ -233,12 +231,11 @@ public class LibMatrixMult throw new DMLRuntimeException(ex); } - //post-processing (nnz maintained in parallel) ret.examSparsity(); //System.out.println("MM k="+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" + - // "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop()); + // "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop()); } /** @@ -752,7 +749,7 @@ public class LibMatrixMult try { ExecutorService pool = Executors.newFixedThreadPool(k); - ArrayList<MatrixMultWDivTask> tasks = new ArrayList<MatrixMultWDivTask>(); + ArrayList<MatrixMultWDivTask> tasks = new ArrayList<MatrixMultWDivTask>(); //create tasks (for wdivmm-left, parallelization over columns; //for wdivmm-right, parallelization over rows; both ensure disjoint results) if( wt.isLeft() ) { @@ -827,7 +824,7 @@ public class 
LibMatrixMult ret.allocateDenseBlock(); try - { + { ExecutorService pool = Executors.newFixedThreadPool(k); ArrayList<MatrixMultWCeTask> tasks = new ArrayList<MatrixMultWCeTask>(); int blklen = (int)(Math.ceil((double)mW.rlen/k)); @@ -1476,9 +1473,8 @@ public class LibMatrixMult private static void matrixMultUltraSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru) throws DMLRuntimeException { - //TODO perf sparse block, consider iterators - - boolean leftUS = m1.isUltraSparse(); + final boolean leftUS = m1.isUltraSparse() + || (m1.isUltraSparse(false) && !m2.isUltraSparse()); final int m = m1.rlen; final int cd = m1.clen; final int n = m2.clen; @@ -1486,6 +1482,7 @@ public class LibMatrixMult if( leftUS ) //left is ultra-sparse (IKJ) { SparseBlock a = m1.sparseBlock; + SparseBlock c = ret.sparseBlock; boolean rightSparse = m2.sparse; for( int i=rl; i<ru; i++ ) @@ -1504,13 +1501,19 @@ public class LibMatrixMult if( !m2.sparseBlock.isEmpty(aix) ) { ret.rlen=m; ret.allocateSparseRowsBlock(false); //allocation on demand - ret.sparseBlock.set(i, m2.sparseBlock.get(aix), true); + boolean ldeep = (m2.sparseBlock instanceof SparseBlockMCSR); + ret.sparseBlock.set(i, m2.sparseBlock.get(aix), ldeep); ret.nonZeros += ret.sparseBlock.size(i); } } else { //dense right matrix (append all values) - for( int j=0; j<n; j++ ) - ret.appendValue(i, j, m2.quickGetValue(aix, j)); + int lnnz = (int)m2.recomputeNonZeros(aix, aix, 0, n-1); + if( lnnz > 0 ) { + c.allocate(i, lnnz); //allocate once + for( int j=0; j<n; j++ ) + c.append(i, j, m2.quickGetValue(aix, j)); + ret.nonZeros += lnnz; + } } } else //GENERAL CASE @@ -1536,13 +1539,13 @@ public class LibMatrixMult SparseBlock b = m2.sparseBlock; for(int k = 0; k < cd; k++ ) - { + { if( !b.isEmpty(k) ) { int bpos = b.pos(k); int blen = b.size(k); int[] bixs = b.indexes(k); - double[] bvals = b.values(k); + double[] bvals = b.values(k); for( int j=bpos; j<bpos+blen; j++ ) { double bval = bvals[j]; @@ -3552,6 +3555,14 
@@ public class LibMatrixMult && m2.clen > k * 1024 && m1.rlen < k * 32 && !pm2r && 8*m1.rlen*m1.clen < 256*1024 ); //lhs fits in L2 cache } + + public static boolean isUltraSparseMatrixMult(MatrixBlock m1, MatrixBlock m2) { + //note: ultra-sparse matrix mult implies also sparse outputs, hence we need + //to be conservative an cannot use this for all ultra-sparse matrices. + return (m1.isUltraSparse() || m2.isUltraSparse()) //base case + || (m1.isUltraSparsePermutationMatrix() + && OptimizerUtils.getSparsity(m2.rlen, m2.clen, m2.nonZeros)<1.0); + } private static MatrixBlock prepMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2 ) throws DMLRuntimeException @@ -3643,17 +3654,16 @@ public class LibMatrixMult private static class MatrixMultTask implements Callable<Object> { - private MatrixBlock _m1 = null; - private MatrixBlock _m2 = null; + private final MatrixBlock _m1; + private final MatrixBlock _m2; private MatrixBlock _ret = null; - private boolean _tm2 = false; //transposed m2 - private boolean _pm2r = false; //par over m2 rows - private boolean _pm2c = false; //par over m2 rows - - private int _rl = -1; - private int _ru = -1; + private final boolean _tm2; //transposed m2 + private final boolean _pm2r; //par over m2 rows + private final boolean _pm2c; //par over m2 rows + private final int _rl; + private final int _ru; - protected MatrixMultTask( MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, + protected MatrixMultTask( MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean tm2, boolean pm2r, boolean pm2c, int rl, int ru ) { _m1 = m1; @@ -3687,7 +3697,7 @@ public class LibMatrixMult _ret.allocateDenseBlock(); //compute block matrix multiplication - if( _m1.isUltraSparse() || _m2.isUltraSparse() ) + if( _ret.sparse ) //ultra-sparse matrixMultUltraSparse(_m1, _m2, _ret, rl, ru); else if(!_m1.sparse && !_m2.sparse) matrixMultDenseDense(_m1, _m2, _ret, _tm2, _pm2r, rl, ru, cl, cu); 
http://git-wip-us.apache.org/repos/asf/systemml/blob/0b5480b7/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index ff05fa0..8f9ae9e 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ -916,16 +916,28 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab * * @return true if sparse */ - public boolean isInSparseFormat() - { + public boolean isInSparseFormat() { return sparse; } + + public boolean isUltraSparse() { + return isUltraSparse(true); + } - public boolean isUltraSparse() - { + public boolean isUltraSparse(boolean checkNnz) { double sp = ((double)nonZeros/rlen)/clen; //check for sparse representation in order to account for vectors in dense - return sparse && sp<ULTRA_SPARSITY_TURN_POINT && nonZeros<40; + return sparse && sp<ULTRA_SPARSITY_TURN_POINT && (!checkNnz || nonZeros<40); + } + + public boolean isUltraSparsePermutationMatrix() { + if( !isUltraSparse(false) ) + return false; + boolean isPM = true; + SparseBlock sblock = getSparseBlock(); + for( int i=0; i<rlen & isPM; i++ ) + isPM &= sblock.isEmpty(i) || sblock.size(i) == 1; + return isPM; } /**
