Repository: incubator-systemml
Updated Branches:
  refs/heads/master 3841ca88e -> fceb6620e


[SYSTEMML-809] Performance sparse-dense wdivmm (multi-level blocking)

So far, we used best-effort cache blocking to ensure row reuse.
However, with high sparsity and large factors, this led to severe cache
thrashing. We now use multi-level cache blocking: best-effort blocking
for reuse, followed by a second level of blocking sized to fit the L2
cache. This also generalizes the custom blocking for wdivmm_left.

The performance improvements on a 200k x 200k matrix (sparsity=0.001,
rank=100) were substantial for both wdivmm_left (2400ms -> 540ms) and
wdivmm_right (1500ms -> 600ms).

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/fceb6620
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/fceb6620
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/fceb6620

Branch: refs/heads/master
Commit: fceb6620e7bc29a097dd7ac6057f963b3a2ff235
Parents: 3841ca8
Author: Matthias Boehm <[email protected]>
Authored: Sun Jul 24 23:30:33 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sun Jul 24 23:30:33 2016 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixMult.java      | 99 +++++++++++---------
 1 file changed, 54 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/fceb6620/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index 57ce557..b5fe663 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -68,9 +68,10 @@ import org.apache.sysml.runtime.util.UtilFunctions;
 public class LibMatrixMult 
 {
        //internal configuration
-       public static final boolean LOW_LEVEL_OPTIMIZATION = true;
-       public static final long MEM_OVERHEAD_THRESHOLD = 2L*1024*1024; //MAX 2 
MB
+       private static final boolean LOW_LEVEL_OPTIMIZATION = true;
+       private static final long MEM_OVERHEAD_THRESHOLD = 2L*1024*1024; //MAX 
2 MB
        private static final long PAR_MINFLOP_THRESHOLD = 2L*1024*1024; //MIN 2 
MFLOP
+       private static final int L2_CACHESIZE = 256 *1024; //256KB (common size)
        
        private LibMatrixMult() {
                //prevent instantiation via private constructor
@@ -2737,15 +2738,18 @@ public class LibMatrixMult
                //approach: iterate over non-zeros of w, selective mm 
computation
                //blocked over ij, while maintaining front of column indexes, 
where the
                //blocksize is chosen such that we reuse each  Ui/Vj vector on 
average 8 times,
-               //with custom blocksizeJ for wdivmm_left to avoid LLC misses on 
output.
-               final int blocksizeI = (int) (8L*mW.rlen*mW.clen/mW.nonZeros);
-               final int blocksizeJ = left ? 
Math.max(8,Math.min(256*1024/(mU.clen*8), blocksizeI)) : blocksizeI; 
-               int[] curk = new int[blocksizeI];               
-               boolean[] aligned = (four&&!scalar) ? new boolean[blocksizeI] : 
null;
+               //we use an additional ij blocking to ensure that Ui/Vj as well 
as the output
+               //in case of wdivmm_left fit into L2 cache (avoid L2/LLC 
misses).
+               final int blocksizeIJ = (int) (8L*mW.rlen*mW.clen/mW.nonZeros);
+               final int blocksizeIJ2 = L2_CACHESIZE / (2*mU.clen*8); //mU 
guaranteed <=1000
                
-               for( int bi = rl; bi < ru; bi+=blocksizeI ) 
+               int[] curk = new int[blocksizeIJ];              
+               boolean[] aligned = (four&&!scalar) ? new boolean[blocksizeIJ] 
: null;
+               
+               //blocked execution over row/column blocks
+               for( int bi=rl; bi<ru; bi+=blocksizeIJ ) 
                {
-                       int bimin = Math.min(ru, bi+blocksizeI);
+                       int bimin = Math.min(ru, bi+blocksizeIJ);
                        //prepare starting indexes for block row
                        for( int i=bi; i<bimin; i++ ) {
                                int k = (cl==0||w.isEmpty(i)) ? 0 : 
w.posFIndexGTE(i,cl);
@@ -2755,47 +2759,52 @@ public class LibMatrixMult
                        if( four && !scalar )
                                for( int i=bi; i<bimin; i++ )
                                        aligned[i-bi] = w.isAligned(i-bi, x);
-                       //blocked execution over column blocks
-                       for( int bj = cl; bj < cu; bj+=blocksizeJ ) 
+                       
+                       for( int bj=cl; bj<cu; bj+=blocksizeIJ )  
                        {
-                               int bjmin = Math.min(cu, bj+blocksizeJ);
-                               for( int i=bi, uix=bi*cd; i<bimin; i++, uix+=cd 
) {
-                                       if( !w.isEmpty(i) ) {
-                                               int wpos = w.pos(i);
-                                               int wlen = w.size(i);
-                                               int[] wix = w.indexes(i);
-                                               double[] wval = w.values(i);    
                        
-                                               
-                                               int k = wpos + curk[i-bi];
-                                               if( basic ) {
-                                                       for( ; k<wpos+wlen && 
wix[k]<bjmin; k++ )
-                                                               
ret.appendValue( i, wix[k], wval[k] * dotProduct(u, v, uix, wix[k]*cd, cd));
-                                               }
-                                               else if( four ) { //left/right
-                                                       //checking alignment 
per row is ok because early abort if false, 
-                                                       //row nnz likely fit in 
L1/L2 cache, and asymptotically better if aligned
-                                                       if( !scalar && 
w.isAligned(i, x) ) {
-                                                               //O(n) where n 
is nnz in w/x 
-                                                               double[] xvals 
= x.values(i);
-                                                               for( ; 
k<wpos+wlen && wix[k]<bjmin; k++ )
-                                                                       
wdivmm(wval[k], xvals[k], u, v, c, uix, wix[k]*cd, left, scalar, cd);
+                               //blocked execution over row/column blocks for 
L2
+                               for( int bi2=bi, bjmin=Math.min(cu, 
bj+blocksizeIJ); bi2<bimin; bi2+=blocksizeIJ2)                                  
    
+                                       for( int bj2=bj, bimin2=Math.min(bimin, 
bi2+blocksizeIJ2); bj2<bjmin; bj2+=blocksizeIJ2 ) {
+                                                       
+                                               //core wdivmm block matrix mult
+                                               for( int i=bi2, uix=bi2*cd, 
bjmin2=Math.min(bjmin, bj2+blocksizeIJ2); i<bimin2; i++, uix+=cd ) {
+                                                       if( w.isEmpty(i) ) 
continue;
+                                                       
+                                                       int wpos = w.pos(i);
+                                                       int wlen = w.size(i);
+                                                       int[] wix = 
w.indexes(i);
+                                                       double[] wval = 
w.values(i);                            
+                                                       
+                                                       int k = wpos + 
curk[i-bi2];
+                                                       if( basic ) {
+                                                               for( ; 
k<wpos+wlen && wix[k]<bjmin2; k++ )
+                                                                       
ret.appendValue( i, wix[k], wval[k] * dotProduct(u, v, uix, wix[k]*cd, cd));
                                                        }
-                                                       else {
-                                                               //scalar or O(n 
log m) where n/m are nnz in w/x
-                                                               for( ; 
k<wpos+wlen && wix[k]<bjmin; k++ )
-                                                                       if 
(scalar)
-                                                                               
wdivmm(wval[k], eps, u, v, c, uix, wix[k]*cd, left, scalar, cd);
-                                                                       else
-                                                                               
wdivmm(wval[k], x.get(i, wix[k]), u, v, c, uix, wix[k]*cd, left, scalar, cd);
+                                                       else if( four ) { 
//left/right
+                                                               //checking 
alignment per row is ok because early abort if false, 
+                                                               //row nnz 
likely fit in L1/L2 cache, and asymptotically better if aligned
+                                                               if( !scalar && 
w.isAligned(i, x) ) {
+                                                                       //O(n) 
where n is nnz in w/x 
+                                                                       
double[] xvals = x.values(i);
+                                                                       for( ; 
k<wpos+wlen && wix[k]<bjmin2; k++ )
+                                                                               
wdivmm(wval[k], xvals[k], u, v, c, uix, wix[k]*cd, left, scalar, cd);
+                                                               }
+                                                               else {
+                                                                       
//scalar or O(n log m) where n/m are nnz in w/x
+                                                                       for( ; 
k<wpos+wlen && wix[k]<bjmin2; k++ )
+                                                                               
if (scalar)
+                                                                               
        wdivmm(wval[k], eps, u, v, c, uix, wix[k]*cd, left, scalar, cd);
+                                                                               
else
+                                                                               
        wdivmm(wval[k], x.get(i, wix[k]), u, v, c, uix, wix[k]*cd, left, 
scalar, cd);
+                                                               }
                                                        }
+                                                       else { //left/right 
minus default
+                                                               for( ; 
k<wpos+wlen && wix[k]<bjmin2; k++ )
+                                                                       
wdivmm(wval[k], u, v, c, uix, wix[k]*cd, left, mult, minus, cd);
+                                                       }
+                                                       curk[i-bi2] = k - wpos;
                                                }
-                                               else { //left/right minus 
default
-                                                       for( ; k<wpos+wlen && 
wix[k]<bjmin; k++ )
-                                                               wdivmm(wval[k], 
u, v, c, uix, wix[k]*cd, left, mult, minus, cd);
-                                               }
-                                               curk[i-bi] = k - wpos;
                                        }
-                               }
                        }
                }
        }

Reply via email to