[SYSTEMML-809] Fix result correctness of sparse-dense wdivmm matrix mult

This patch essentially reverts commit
fceb6620e7bc29a097dd7ac6057f963b3a2ff235, which introduced a correctness
issue due to incorrect maintenance of the current positions (curk[i-bi2]
instead of curk[i-bi]). After fixing the index, the previous "performance
improvement" turned out to be ineffective. Hence, we revert that change
and fix our test suite accordingly. The issue did not show up earlier
because the test data was too small, i.e., both blocking levels were
larger than the block processed per thread.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/cc5d22c1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/cc5d22c1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/cc5d22c1

Branch: refs/heads/master
Commit: cc5d22c1d03515fa7e903449f6da227f2fb9bd76
Parents: 23de8e8
Author: Matthias Boehm <[email protected]>
Authored: Fri Aug 5 22:59:49 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sun Aug 7 12:30:40 2016 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixMult.java      | 98 ++++++++++----------
 .../quaternary/WeightedDivMatrixMultTest.java   | 14 ++-
 2 files changed, 58 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc5d22c1/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index b5fe663..302a1df 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -2738,18 +2738,17 @@ public class LibMatrixMult
                //approach: iterate over non-zeros of w, selective mm 
computation
                //blocked over ij, while maintaining front of column indexes, 
where the
                //blocksize is chosen such that we reuse each  Ui/Vj vector on 
average 8 times,
-               //we use an additional ij blocking to ensure that Ui/Vj as well 
as the output
-               //in case of wdivmm_left fit into L2 cache (avoid L2/LLC 
misses).
-               final int blocksizeIJ = (int) (8L*mW.rlen*mW.clen/mW.nonZeros);
-               final int blocksizeIJ2 = L2_CACHESIZE / (2*mU.clen*8); //mU 
guaranteed <=1000
+               //with custom blocksizeJ for wdivmm_left to avoid LLC misses on 
output.
+               final int blocksizeI = (int) (8L*mW.rlen*mW.clen/mW.nonZeros);
+               final int blocksizeJ = left ? 
Math.max(8,Math.min(L2_CACHESIZE/(mU.clen*8), blocksizeI)) : blocksizeI;
                
-               int[] curk = new int[blocksizeIJ];              
-               boolean[] aligned = (four&&!scalar) ? new boolean[blocksizeIJ] 
: null;
+               int[] curk = new int[blocksizeI];               
+               boolean[] aligned = (four&&!scalar) ? new boolean[blocksizeI] : 
null;
                
-               //blocked execution over row/column blocks
-               for( int bi=rl; bi<ru; bi+=blocksizeIJ ) 
+               //blocked execution over row blocks
+               for( int bi=rl; bi<ru; bi+=blocksizeI ) 
                {
-                       int bimin = Math.min(ru, bi+blocksizeIJ);
+                       int bimin = Math.min(ru, bi+blocksizeI);
                        //prepare starting indexes for block row
                        for( int i=bi; i<bimin; i++ ) {
                                int k = (cl==0||w.isEmpty(i)) ? 0 : 
w.posFIndexGTE(i,cl);
@@ -2760,51 +2759,48 @@ public class LibMatrixMult
                                for( int i=bi; i<bimin; i++ )
                                        aligned[i-bi] = w.isAligned(i-bi, x);
                        
-                       for( int bj=cl; bj<cu; bj+=blocksizeIJ )  
+                       //blocked execution over column blocks
+                       for( int bj=cl; bj<cu; bj+=blocksizeJ )  
                        {
-                               //blocked execution over row/column blocks for 
L2
-                               for( int bi2=bi, bjmin=Math.min(cu, 
bj+blocksizeIJ); bi2<bimin; bi2+=blocksizeIJ2)                                  
    
-                                       for( int bj2=bj, bimin2=Math.min(bimin, 
bi2+blocksizeIJ2); bj2<bjmin; bj2+=blocksizeIJ2 ) {
-                                                       
-                                               //core wdivmm block matrix mult
-                                               for( int i=bi2, uix=bi2*cd, 
bjmin2=Math.min(bjmin, bj2+blocksizeIJ2); i<bimin2; i++, uix+=cd ) {
-                                                       if( w.isEmpty(i) ) 
continue;
-                                                       
-                                                       int wpos = w.pos(i);
-                                                       int wlen = w.size(i);
-                                                       int[] wix = 
w.indexes(i);
-                                                       double[] wval = 
w.values(i);                            
-                                                       
-                                                       int k = wpos + 
curk[i-bi2];
-                                                       if( basic ) {
-                                                               for( ; 
k<wpos+wlen && wix[k]<bjmin2; k++ )
-                                                                       
ret.appendValue( i, wix[k], wval[k] * dotProduct(u, v, uix, wix[k]*cd, cd));
-                                                       }
-                                                       else if( four ) { 
//left/right
-                                                               //checking 
alignment per row is ok because early abort if false, 
-                                                               //row nnz 
likely fit in L1/L2 cache, and asymptotically better if aligned
-                                                               if( !scalar && 
w.isAligned(i, x) ) {
-                                                                       //O(n) 
where n is nnz in w/x 
-                                                                       
double[] xvals = x.values(i);
-                                                                       for( ; 
k<wpos+wlen && wix[k]<bjmin2; k++ )
-                                                                               
wdivmm(wval[k], xvals[k], u, v, c, uix, wix[k]*cd, left, scalar, cd);
-                                                               }
-                                                               else {
-                                                                       
//scalar or O(n log m) where n/m are nnz in w/x
-                                                                       for( ; 
k<wpos+wlen && wix[k]<bjmin2; k++ )
-                                                                               
if (scalar)
-                                                                               
        wdivmm(wval[k], eps, u, v, c, uix, wix[k]*cd, left, scalar, cd);
-                                                                               
else
-                                                                               
        wdivmm(wval[k], x.get(i, wix[k]), u, v, c, uix, wix[k]*cd, left, 
scalar, cd);
-                                                               }
-                                                       }
-                                                       else { //left/right 
minus default
-                                                               for( ; 
k<wpos+wlen && wix[k]<bjmin2; k++ )
-                                                                       
wdivmm(wval[k], u, v, c, uix, wix[k]*cd, left, mult, minus, cd);
-                                                       }
-                                                       curk[i-bi2] = k - wpos;
+                               int bjmin = Math.min(cu, bj+blocksizeJ);
+                               //core wdivmm block matrix mult
+                               for( int i=bi, uix=bi*cd; i<bimin; i++, uix+=cd 
) {
+                                       if( w.isEmpty(i) ) continue;
+                                       
+                                       int wpos = w.pos(i);
+                                       int wlen = w.size(i);
+                                       int[] wix = w.indexes(i);
+                                       double[] wval = w.values(i);            
                
+                                       
+                                       int k = wpos + curk[i-bi];
+                                       if( basic ) {
+                                               for( ; k<wpos+wlen && 
wix[k]<bjmin; k++ )
+                                                       ret.appendValue( i, 
wix[k], wval[k] * dotProduct(u, v, uix, wix[k]*cd, cd));
+                                       }
+                                       else if( four ) { //left/right
+                                               //checking alignment per row is 
ok because early abort if false, 
+                                               //row nnz likely fit in L1/L2 
cache, and asymptotically better if aligned
+                                               if( !scalar && w.isAligned(i, 
x) ) {
+                                                       //O(n) where n is nnz 
in w/x 
+                                                       double[] xvals = 
x.values(i);
+                                                       for( ; k<wpos+wlen && 
wix[k]<bjmin; k++ )
+                                                               wdivmm(wval[k], 
xvals[k], u, v, c, uix, wix[k]*cd, left, scalar, cd);
+                                               }
+                                               else {
+                                                       //scalar or O(n log m) 
where n/m are nnz in w/x
+                                                       for( ; k<wpos+wlen && 
wix[k]<bjmin; k++ )
+                                                               if (scalar)
+                                                                       
wdivmm(wval[k], eps, u, v, c, uix, wix[k]*cd, left, scalar, cd);
+                                                               else
+                                                                       
wdivmm(wval[k], x.get(i, wix[k]), u, v, c, uix, wix[k]*cd, left, scalar, cd);
                                                }
                                        }
+                                       else { //left/right minus default
+                                               for( ; k<wpos+wlen && 
wix[k]<bjmin; k++ )
+                                                       wdivmm(wval[k], u, v, 
c, uix, wix[k]*cd, left, mult, minus, cd);
+                                       }
+                                       curk[i-bi] = k - wpos;
+                               }
                        }
                }
        }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc5d22c1/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java
index 6ff8de6..d8053b3 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java
@@ -66,9 +66,13 @@ public class WeightedDivMatrixMultTest extends 
AutomatedTestBase
        private final static double eps = 1e-6;
        private final static double div_eps = 0.1;
        
-       private final static int rows = 1201;
-       private final static int cols = 1103;
-       private final static int rank = 10;
+       private final static int rows1 = 1201;
+       private final static int cols1 = 1103;
+       private final static int rows2 = 3401;
+       private final static int cols2 = 2403;
+       private final static int rank1 = 10;
+       private final static int rank2 = 100;
+       
        private final static double spSparse = 0.001;
        private final static double spDense = 0.6;
        
@@ -617,6 +621,10 @@ public class WeightedDivMatrixMultTest extends 
AutomatedTestBase
                        fullRScriptName = HOME + TEST_NAME + ".R";
                        rCmd = "Rscript" + " " + fullRScriptName + " " + 
inputDir() + " " + expectedDir() + " " + div_eps;
        
+                       int rows = sparse ? rows2 : rows1;
+                       int cols = sparse ? cols2 : cols1;
+                       int rank = sparse ? rank2 : rank1;
+                       
                        //generate actual dataset 
                        double[][] W = getRandomMatrix(rows, cols, 0, 1, 
sparsity, 7); 
                        writeInputMatrixWithMTD("W", W, true);

Reply via email to