Repository: systemml
Updated Branches:
  refs/heads/master b67f18641 -> 083cc77f8
[SYSTEMML-1761] Performance wsloss w/o weights (sparsity exploitation)

So far the wsloss operator w/o weights for sum((X-U%*%t(V))^2) was not
sparsity-exploiting due to the missing sparse driver (a sparse matrix with a
sparse-safe operation such as multiply or divide). However, this expression
can be rewritten into a sparsity-exploiting form with
sum((X-U%*%t(V))^2) -> sum(X^2) - sum(2*X*(U%*%t(V))) + sum((t(U)%*%U)*(t(V)%*%V)).
This patch leverages this rewrite for a much more efficient,
sparsity-exploiting and cache-conscious block-level implementation.

The performance improvements of the entire wsloss operation were as follows:
  100K x 100K, sparsity=0.1,   rank=100: 92.5s -> 8.5s
  100K x 100K, sparsity=0.01,  rank=100: 92.2s -> 1.3s
  100K x 100K, sparsity=0.001, rank=100: 92.1s -> 0.4s

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/083cc77f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/083cc77f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/083cc77f

Branch: refs/heads/master
Commit: 083cc77f82b70748b70f8cc311fde040034bcc7c
Parents: b67f186
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Tue Jul 11 21:41:20 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Tue Jul 11 22:28:52 2017 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixMult.java | 126 +++++++++++++------
 1 file changed, 88 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/083cc77f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index 5996a51..da3b12b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -390,7 +390,7 @@ public class LibMatrixMult
     //check no parallelization benefit (fallback to sequential)
     //check too small workload in terms of flops (fallback to sequential too)
-    if( ret.rlen == 1
+    if( ret.rlen == 1 || k <= 1
       || leftTranspose && 1L * m1.rlen * m1.clen * m1.clen < PAR_MINFLOP_THRESHOLD
       || !leftTranspose && 1L * m1.clen * m1.rlen * m1.rlen < PAR_MINFLOP_THRESHOLD)
     {
@@ -533,6 +533,10 @@ public class LibMatrixMult
     else
       matrixMultWSLossGeneric(mX, mU, mV, mW, ret, wt, 0, mX.rlen);

+    //add correction for sparse wsloss w/o weight
+    if( mX.sparse && wt==WeightsType.NONE )
+      addMatrixMultWSLossNoWeightCorrection(mU, mV, ret, 1);
+
     //System.out.println("MMWSLoss " +wt.toString()+ " ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" +
     //  "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
   }
@@ -572,6 +576,10 @@ public class LibMatrixMult
       throw new DMLRuntimeException(e);
     }

+    //add correction for sparse wsloss w/o weight
+    if( mX.sparse && wt==WeightsType.NONE )
+      addMatrixMultWSLossNoWeightCorrection(mU, mV, ret, k);
+
     //System.out.println("MMWSLoss "+wt.toString()+" k="+k+" ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" +
     //  "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
   }
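For intuition, the rewrite is the cell-wise binomial expansion (x - w)^2 = x^2 - 2*x*w + w^2 with w = (U %*% t(V))[i,j]: the first two terms only touch the nonzeros of X (which provides the missing sparse driver), and the remaining sum((U %*% t(V))^2) collapses to two tsmm results as shown in the diff below. A minimal standalone sketch checking the expansion on a tiny dense example (illustrative only, not part of the patch; all names are made up):

public class WslossRewriteCheck {
  public static void main(String[] args) {
    double[][] X = {{1, 0, 3}, {0, 2, 0}};   // m x n data matrix, mostly zeros
    double[][] U = {{1, 2}, {3, 4}};         // m x r factor
    double[][] V = {{5, 6}, {7, 8}, {9, 1}}; // n x r factor
    int m = 2, n = 3, r = 2;
    double lhs = 0, sumX2 = 0, sum2XUV = 0, sumUV2 = 0;
    for (int i = 0; i < m; i++)
      for (int j = 0; j < n; j++) {
        double uv = 0; // (U %*% t(V))[i,j]
        for (int k = 0; k < r; k++)
          uv += U[i][k] * V[j][k];
        lhs += (X[i][j] - uv) * (X[i][j] - uv);
        sumX2 += X[i][j] * X[i][j];
        sum2XUV += 2 * X[i][j] * uv;
        sumUV2 += uv * uv;
      }
    // both statements print the same value
    System.out.println(lhs);
    System.out.println(sumX2 - sum2XUV + sumUV2);
  }
}

Note that sumX2 and sum2XUV only require iterating over the nonzeros of X, which is exactly what the block-level implementation below exploits.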
@@ -2163,36 +2171,34 @@ public class LibMatrixMult
     // Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
     else if( wt==WeightsType.NONE )
     {
-      // approach: iterate over all cells of X and
-      for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd )
-      {
-        if( x.isEmpty(i) ) { //empty row
-          for( int j=0, vix=0; j<n; j++, vix+=cd) {
-            double uvij = dotProduct(u, v, uix, vix, cd);
-            wsloss += (-uvij)*(-uvij);
-          }
-        }
-        else { //non-empty row
-          int xpos = x.pos(i);
-          int xlen = x.size(i);
-          int[] xix = x.indexes(i);
-          double[] xval = x.values(i);
-          int last = -1;
-          for( int k=xpos; k<xpos+xlen; k++ ) {
-            //process last nnz til current nnz
-            for( int k2=last+1; k2<xix[k]; k2++ ){
-              double uvij = dotProduct(u, v, uix, k2*cd, cd);
-              wsloss += (-uvij)*(-uvij);
+      //approach: use sparsity-exploiting pattern rewrite sum((X-(U%*%t(V)))^2)
+      //-> sum(X^2)-sum(2*X*(U%*%t(V)))+sum((t(U)%*%U)*(t(V)%*%V)), where each
+      //parallel task computes sum(X^2)-sum(2*X*(U%*%t(V))) and the last term
+      //sum((t(U)%*%U)*(t(V)%*%V)) is computed once via two tsmm operations.
+
+      final int blocksizeIJ = (int) (8L*mX.rlen*mX.clen/mX.nonZeros);
+      int[] curk = new int[blocksizeIJ];
+
+      for( int bi=rl; bi<ru; bi+=blocksizeIJ ) {
+        int bimin = Math.min(ru, bi+blocksizeIJ);
+        //prepare starting indexes for block row
+        Arrays.fill(curk, 0);
+        //blocked execution over column blocks
+        for( int bj=0; bj<n; bj+=blocksizeIJ ) {
+          int bjmin = Math.min(n, bj+blocksizeIJ);
+          for( int i=bi, uix=bi*cd; i<bimin; i++, uix+=cd ) {
+            if( x.isEmpty(i) ) continue;
+            int xpos = x.pos(i);
+            int xlen = x.size(i);
+            int[] xix = x.indexes(i);
+            double[] xval = x.values(i);
+            int k = xpos + curk[i-bi];
+            for( ; k<xpos+xlen && xix[k]<bjmin; k++ ) {
+              double xij = xval[k];
+              double uvij = dotProduct(u, v, uix, xix[k]*cd, cd);
+              wsloss += xij * xij - 2 * xij * uvij;
             }
-            //process current nnz
-            double uvij = dotProduct(u, v, uix, xix[k]*cd, cd);
-            wsloss += (xval[k]-uvij)*(xval[k]-uvij);
-            last = xix[k];
-          }
-          //process last nnz til end of row
-          for( int k2=last+1; k2<n; k2++ ) {
-            double uvij = dotProduct(u, v, uix, k2*cd, cd);
-            wsloss += (-uvij)*(-uvij);
+            curk[i-bi] = k - xpos;
           }
         }
       }
@@ -2291,18 +2297,52 @@ public class LibMatrixMult
     // Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
     else if( wt==WeightsType.NONE )
     {
-      // approach: iterate over all cells of X and
-      for( int i=rl; i<ru; i++ )
-        for( int j=0; j<n; j++)
-        {
-          double xij = mX.quickGetValue(i, j);
-          double uvij = dotProductGeneric(mU, mV, i, j, cd);
-          wsloss += (xij-uvij)*(xij-uvij);
-        }
+      //approach: use sparsity-exploiting pattern rewrite sum((X-(U%*%t(V)))^2)
+      //-> sum(X^2)-sum(2*X*(U%*%t(V)))+sum((t(U)%*%U)*(t(V)%*%V)), where each
+      //parallel task computes sum(X^2)-sum(2*X*(U%*%t(V))) and the last term
+      //sum((t(U)%*%U)*(t(V)%*%V)) is computed once via two tsmm operations.
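The blocked sparse loop above keeps one cursor per row of the current row block (curk), so each row's column-index array is scanned exactly once even though the row is revisited for every column block; this keeps the accesses into U and V cache-conscious without re-scanning sparse rows. A simplified sketch of just this traversal over plain CSR-style arrays (hypothetical names, not SystemML's SparseBlock API):

import java.util.Arrays;

public class BlockedSparseScan {
  // Sums all nonzeros of a CSR matrix in blocked order: curk[i-bi] remembers
  // how far row i's column scan has advanced, so the inner loop resumes where
  // the previous column block stopped and visits each nonzero exactly once.
  static double sumBlocked(int[] rowPtr, int[] colIdx, double[] val,
                           int m, int n, int blocksize) {
    double sum = 0;
    int[] curk = new int[blocksize];
    for (int bi = 0; bi < m; bi += blocksize) {
      int bimin = Math.min(m, bi + blocksize);
      Arrays.fill(curk, 0); // reset cursors for this row block
      for (int bj = 0; bj < n; bj += blocksize) {
        int bjmin = Math.min(n, bj + blocksize);
        for (int i = bi; i < bimin; i++) {
          int pos = rowPtr[i], end = rowPtr[i + 1];
          int k = pos + curk[i - bi];
          for (; k < end && colIdx[k] < bjmin; k++)
            sum += val[k]; // process nonzero (i, colIdx[k])
          curk[i - bi] = k - pos; // remember position for the next column block
        }
      }
    }
    return sum;
  }

  public static void main(String[] args) {
    // 3x4 CSR example: row 0 has 1@col0, row 1 has 2@col1 and 3@col3, row 2 has 4@col2
    int[] rowPtr = {0, 1, 3, 4};
    int[] colIdx = {0, 1, 3, 2};
    double[] val = {1, 2, 3, 4};
    System.out.println(sumBlocked(rowPtr, colIdx, val, 3, 4, 2)); // 10.0
  }
}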
+
+      if( mX.sparse ) { //SPARSE
+        SparseBlock x = mX.sparseBlock;
+        for( int i=rl; i<ru; i++ ) {
+          if( x.isEmpty(i) ) continue;
+          int xpos = x.pos(i);
+          int xlen = x.size(i);
+          int[] xix = x.indexes(i);
+          double[] xval = x.values(i);
+          for( int k=xpos; k<xpos+xlen; k++ ) {
+            double xij = xval[k];
+            double uvij = dotProductGeneric(mU, mV, i, xix[k], cd);
+            wsloss += xij * xij - 2 * xij * uvij;
+          }
+        }
+      }
+      else { //DENSE
+        double[] x = mX.denseBlock;
+        for( int i=rl, xix=rl*n; i<ru; i++, xix+=n )
+          for( int j=0; j<n; j++)
+            if( x[xix+j] != 0 ) {
+              double xij = x[xix+j];
+              double uvij = dotProductGeneric(mU, mV, i, j, cd);
+              wsloss += xij * xij - 2 * xij * uvij;
+            }
+      }
     }

     ret.quickSetValue(0, 0, wsloss);
   }
+
+  private static void addMatrixMultWSLossNoWeightCorrection(MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, int k)
+    throws DMLRuntimeException
+  {
+    MatrixBlock tmp1 = new MatrixBlock(mU.clen, mU.clen, false);
+    MatrixBlock tmp2 = new MatrixBlock(mU.clen, mU.clen, false);
+    matrixMultTransposeSelf(mU, tmp1, true, k);
+    matrixMultTransposeSelf(mV, tmp2, true, k);
+    ret.quickSetValue(0, 0, ret.quickGetValue(0, 0) +
+      ((tmp1.sparse || tmp2.sparse) ? dotProductGeneric(tmp1, tmp2) :
+      dotProduct(tmp1.denseBlock, tmp2.denseBlock, mU.clen*mU.clen)));
+  }

   private static void matrixMultWSigmoidDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru)
     throws DMLRuntimeException
@@ -3405,6 +3445,16 @@ public class LibMatrixMult
     return val;
   }

+  private static double dotProductGeneric(MatrixBlock a, MatrixBlock b)
+  {
+    double val = 0;
+    for( int i=0; i<a.getNumRows(); i++ )
+      for( int j=0; j<a.getNumColumns(); j++ )
+        val += a.quickGetValue(i, j) * b.quickGetValue(i, j);
+
+    return val;
+  }
+
   /**
    * Used for all versions of TSMM where the result is known to be symmetric.
    * Hence, we compute only the upper triangular matrix and copy this partial
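The correction helper addMatrixMultWSLossNoWeightCorrection relies on the identity sum((U %*% t(V))^2) == sum((t(U)%*%U) * (t(V)%*%V)): instead of materializing the m x n product (O(m*n*r) work), it multiplies the two r x r Gram matrices element-wise and sums, which costs only O((m+n)*r^2) via the two tsmm calls. A small sketch comparing both sides of the identity (hypothetical helper names, not the patch's code):

public class WslossCorrectionCheck {
  // naive side: O(m*n*r) sum over all cells of (U %*% t(V))^2
  static double direct(double[][] U, double[][] V, int m, int n, int r) {
    double s = 0;
    for (int i = 0; i < m; i++)
      for (int j = 0; j < n; j++) {
        double uv = 0;
        for (int k = 0; k < r; k++)
          uv += U[i][k] * V[j][k];
        s += uv * uv;
      }
    return s;
  }

  // Gram side: O((m+n)*r^2) via (t(U)%*%U)[k1,k2] and (t(V)%*%V)[k1,k2]
  static double viaGram(double[][] U, double[][] V, int m, int n, int r) {
    double s = 0;
    for (int k1 = 0; k1 < r; k1++)
      for (int k2 = 0; k2 < r; k2++) {
        double uu = 0, vv = 0;
        for (int i = 0; i < m; i++) uu += U[i][k1] * U[i][k2];
        for (int j = 0; j < n; j++) vv += V[j][k1] * V[j][k2];
        s += uu * vv; // element-wise product of the Gram matrices, summed
      }
    return s;
  }

  public static void main(String[] args) {
    double[][] U = {{1, 2}, {3, 4}};
    double[][] V = {{5, 6}, {7, 8}, {9, 1}};
    System.out.println(direct(U, V, 2, 3, 2));  // prints the same value
    System.out.println(viaGram(U, V, 2, 3, 2)); // prints the same value
  }
}

This also explains why the correction is added once after the parallel tasks complete: the Gram-matrix term is independent of the row range each task processes.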