Repository: systemml
Updated Branches:
  refs/heads/master b67f18641 -> 083cc77f8


[SYSTEMML-1761] Performance wsloss w/o weights (sparsity exploitation)

So far, the wsloss operator w/o weights for sum((X-U%*%t(V))^2) was not
sparsity-exploiting due to the missing sparse driver (a sparse matrix with
a sparse-safe operation such as multiply or divide). However, this
expression can be rewritten into a sparsity-exploiting form via
sum((X-U%*%t(V))^2) -> sum(X^2) - sum(2*X*(U%*%t(V))) +
sum((t(U)%*%U)*(t(V)%*%V)).
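
For intuition, the rewrite follows from expanding the square:
sum((X-U%*%t(V))^2) = sum(X^2) - 2*sum(X*(U%*%t(V))) + sum((U%*%t(V))^2),
and the last term equals sum((t(U)%*%U)*(t(V)%*%V)), which is independent
of X; only the non-zero cells of X need to be touched for the first two
terms. Below is a minimal, self-contained Java sketch that checks this
identity numerically on a tiny example; it uses plain double[][] arrays
and illustrative names (WslossRewriteCheck, dot, rand), not the SystemML
MatrixBlock API:

import java.util.Random;

public class WslossRewriteCheck {
	public static void main(String[] args) {
		int m = 5, n = 4, r = 3; //X is m x n, U is m x r, V is n x r
		Random rnd = new Random(7);
		double[][] X = rand(m, n, rnd), U = rand(m, r, rnd), V = rand(n, r, rnd);
		
		//direct evaluation: sum((X - U %*% t(V))^2)
		double direct = 0;
		for( int i=0; i<m; i++ )
			for( int j=0; j<n; j++ ) {
				double d = X[i][j] - dot(U[i], V[j]);
				direct += d * d;
			}
		
		//rewritten evaluation: sum(X^2) - sum(2*X*(U%*%t(V))), touches only non-zero cells of X
		double part12 = 0;
		for( int i=0; i<m; i++ )
			for( int j=0; j<n; j++ )
				if( X[i][j] != 0 )
					part12 += X[i][j] * X[i][j] - 2 * X[i][j] * dot(U[i], V[j]);
		
		//correction term sum((t(U)%*%U)*(t(V)%*%V)), independent of X (r x r work)
		double part3 = 0;
		for( int a=0; a<r; a++ )
			for( int b=0; b<r; b++ ) {
				double utu = 0, vtv = 0;
				for( int i=0; i<m; i++ ) utu += U[i][a] * U[i][b];
				for( int j=0; j<n; j++ ) vtv += V[j][a] * V[j][b];
				part3 += utu * vtv;
			}
		
		//both values agree up to floating-point round-off
		System.out.println(direct + " vs " + (part12 + part3));
	}
	
	private static double dot(double[] u, double[] v) {
		double s = 0;
		for( int k=0; k<u.length; k++ )
			s += u[k] * v[k];
		return s;
	}
	
	private static double[][] rand(int rows, int cols, Random rnd) {
		double[][] A = new double[rows][cols];
		for( int i=0; i<rows; i++ )
			for( int j=0; j<cols; j++ )
				A[i][j] = rnd.nextDouble();
		return A;
	}
}

In the patch, the first two terms are accumulated per non-zero of X, and
the rank x rank correction term is added once via two tsmm operations
(see addMatrixMultWSLossNoWeightCorrection in the diff below).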

This patch leverages this rewrite for a much more efficient,
sparsity-exploiting, and cache-conscious block-level implementation. The
performance improvements of the entire wsloss operation were as follows:

100K x 100K, sparsity=0.1, rank=100: 92.5s -> 8.5s
100K x 100K, sparsity=0.01, rank=100: 92.2s -> 1.3s
100K x 100K, sparsity=0.001, rank=100: 92.1s -> 0.4s


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/083cc77f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/083cc77f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/083cc77f

Branch: refs/heads/master
Commit: 083cc77f82b70748b70f8cc311fde040034bcc7c
Parents: b67f186
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Tue Jul 11 21:41:20 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Tue Jul 11 22:28:52 2017 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixMult.java      | 126 +++++++++++++------
 1 file changed, 88 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/083cc77f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index 5996a51..da3b12b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -390,7 +390,7 @@ public class LibMatrixMult
                
                //check no parallelization benefit (fallback to sequential)
                //check too small workload in terms of flops (fallback to sequential too)
-               if( ret.rlen == 1 
+               if( ret.rlen == 1 || k <= 1
                        || leftTranspose && 1L * m1.rlen * m1.clen * m1.clen < PAR_MINFLOP_THRESHOLD
                        || !leftTranspose && 1L * m1.clen * m1.rlen * m1.rlen < PAR_MINFLOP_THRESHOLD) 
                { 
@@ -533,6 +533,10 @@ public class LibMatrixMult
                else
                        matrixMultWSLossGeneric(mX, mU, mV, mW, ret, wt, 0, mX.rlen);
                
+               //add correction for sparse wsloss w/o weight
+               if( mX.sparse && wt==WeightsType.NONE )
+                       addMatrixMultWSLossNoWeightCorrection(mU, mV, ret, 1);
+               
                //System.out.println("MMWSLoss " +wt.toString()+ " ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" +
                //                  "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
        }
@@ -572,6 +576,10 @@ public class LibMatrixMult
                        throw new DMLRuntimeException(e);
                }
 
+               //add correction for sparse wsloss w/o weight
+               if( mX.sparse && wt==WeightsType.NONE )
+                       addMatrixMultWSLossNoWeightCorrection(mU, mV, ret, k);
+               
                //System.out.println("MMWSLoss "+wt.toString()+" k="+k+" ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" +
                //                   "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
        }
@@ -2163,36 +2171,34 @@ public class LibMatrixMult
                // Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
                else if( wt==WeightsType.NONE )
                {
-                       // approach: iterate over all cells of X and 
-                       for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd ) 
-                       {
-                               if( x.isEmpty(i) ) { //empty row
-                                       for( int j=0, vix=0; j<n; j++, vix+=cd) {
-                                               double uvij = dotProduct(u, v, uix, vix, cd);
-                                               wsloss += (-uvij)*(-uvij);
-                                       }
-                               }
-                               else { //non-empty row
-                                       int xpos = x.pos(i);
-                                       int xlen = x.size(i);
-                                       int[] xix = x.indexes(i);
-                                       double[] xval = x.values(i);
-                                       int last = -1;
-                                       for( int k=xpos; k<xpos+xlen; k++ ) {
-                                               //process last nnz til current nnz
-                                               for( int k2=last+1; k2<xix[k]; k2++ ){
-                                                       double uvij = dotProduct(u, v, uix, k2*cd, cd);
-                                                       wsloss += (-uvij)*(-uvij);
+                       //approach: use sparsity-exploiting pattern rewrite sum((X-(U%*%t(V)))^2) 
+                       //-> sum(X^2)-sum(2*X*(U%*%t(V))))+sum((t(U)%*%U)*(t(V)%*%V)), where each
+                       //parallel task computes sum(X^2)-sum(2*X*(U%*%t(V)))) and the last term
+                       //sum((t(U)%*%U)*(t(V)%*%V)) is computed once via two tsmm operations.
+                       
+                       final int blocksizeIJ = (int) (8L*mX.rlen*mX.clen/mX.nonZeros); 
+                       int[] curk = new int[blocksizeIJ];                      
+                       
+                       for( int bi=rl; bi<ru; bi+=blocksizeIJ ) {
+                               int bimin = Math.min(ru, bi+blocksizeIJ);
+                               //prepare starting indexes for block row
+                               Arrays.fill(curk, 0); 
+                               //blocked execution over column blocks
+                               for( int bj=0; bj<n; bj+=blocksizeIJ ) {
+                                       int bjmin = Math.min(n, bj+blocksizeIJ);
+                                       for( int i=bi, uix=bi*cd; i<bimin; i++, uix+=cd ) {
+                                               if( x.isEmpty(i) ) continue; 
+                                               int xpos = x.pos(i);
+                                               int xlen = x.size(i);
+                                               int[] xix = x.indexes(i);
+                                               double[] xval = x.values(i);
+                                               int k = xpos + curk[i-bi];
+                                               for( ; k<xpos+xlen && xix[k]<bjmin; k++ ) {
+                                                       double xij = xval[k];
+                                                       double uvij = dotProduct(u, v, uix, xix[k]*cd, cd);
+                                                       wsloss += xij * xij - 2 * xij * uvij;
                                                }
-                                               //process current nnz
-                                               double uvij = dotProduct(u, v, uix, xix[k]*cd, cd);
-                                               wsloss += (xval[k]-uvij)*(xval[k]-uvij);
-                                               last = xix[k];
-                                       }
-                                       //process last nnz til end of row
-                                       for( int k2=last+1; k2<n; k2++ ) { 
-                                               double uvij = dotProduct(u, v, uix, k2*cd, cd);
-                                               wsloss += (-uvij)*(-uvij);
+                                               curk[i-bi] = k - xpos;
                                        }
                                }
                        }
@@ -2291,18 +2297,52 @@ public class LibMatrixMult
                // Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
                else if( wt==WeightsType.NONE )
                {
-                       // approach: iterate over all cells of X and 
-                       for( int i=rl; i<ru; i++ )
-                               for( int j=0; j<n; j++)
-                               {
-                                       double xij = mX.quickGetValue(i, j);
-                                       double uvij = dotProductGeneric(mU, mV, i, j, cd);
-                                       wsloss += (xij-uvij)*(xij-uvij);
-                               }
+                       //approach: use sparsity-exploiting pattern rewrite sum((X-(U%*%t(V)))^2) 
+                       //-> sum(X^2)-sum(2*X*(U%*%t(V))))+sum((t(U)%*%U)*(t(V)%*%V)), where each
+                       //parallel task computes sum(X^2)-sum(2*X*(U%*%t(V)))) and the last term
+                       //sum((t(U)%*%U)*(t(V)%*%V)) is computed once via two tsmm operations.
+                       
+                       if( mW.sparse ) { //SPARSE
+                               SparseBlock x = mX.sparseBlock;
+                               for( int i=rl; i<ru; i++ ) {
+                                       if( x.isEmpty(i) ) continue;
+                                       int xpos = x.pos(i);
+                                       int xlen = x.size(i);
+                                       int[] xix = x.indexes(i);
+                                       double[] xval = x.values(i);
+                                       for( int k=xpos; k<xpos+xlen; k++ ) {
+                                               double xij = xval[k];
+                                               double uvij = dotProductGeneric(mU, mV, i, xix[k], cd);
+                                               wsloss += xij * xij - 2 * xij * uvij;
+                                       }
+                               }       
+                       }
+                       else { //DENSE
+                               double[] x = mX.denseBlock;
+                               for( int i=rl, xix=rl*n; i<ru; i++, xix+=n )
+                                       for( int j=0; j<n; j++)
+                                               if( x[xix+j] != 0 ) {
+                                                       double xij = x[xix+j];
+                                                       double uvij = dotProductGeneric(mU, mV, i, j, cd);
+                                                       wsloss += xij * xij - 2 * xij * uvij;
+                                               }
+                       }
                }
 
                ret.quickSetValue(0, 0, wsloss);
        }
+       
+       private static void addMatrixMultWSLossNoWeightCorrection(MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, int k) 
+               throws DMLRuntimeException 
+       {
+               MatrixBlock tmp1 = new MatrixBlock(mU.clen, mU.clen, false);
+               MatrixBlock tmp2 = new MatrixBlock(mU.clen, mU.clen, false);
+               matrixMultTransposeSelf(mU, tmp1, true, k);
+               matrixMultTransposeSelf(mV, tmp2, true, k);
+               ret.quickSetValue(0, 0, ret.quickGetValue(0, 0) + 
+                       ((tmp1.sparse || tmp2.sparse) ? dotProductGeneric(tmp1, tmp2) :
+                       dotProduct(tmp1.denseBlock, tmp2.denseBlock, mU.clen*mU.clen)));
+       }
 
        private static void matrixMultWSigmoidDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru) 
                throws DMLRuntimeException 
@@ -3405,6 +3445,16 @@ public class LibMatrixMult
                return val;
        }
        
+       private static double dotProductGeneric(MatrixBlock a, MatrixBlock b)
+       {
+               double val = 0;
+               for( int i=0; i<a.getNumRows(); i++ )
+                       for( int j=0; j<a.getNumColumns(); j++ )
+                               val += a.quickGetValue(i, j) * b.quickGetValue(i, j);
+               
+               return val;
+       }
+       
        /**
         * Used for all version of TSMM where the result is known to be symmetric.
         * Hence, we compute only the upper triangular matrix and copy this partial
