[SYSTEMML-552] Cache-conscious sparse-dense wsloss (post_nz pattern)

Similar to cache-conscious sparse-dense wdivmm, wsloss also showed room
for performance improvements with regard to large factors. This patch
introduces cache-conscious sparse-dense operations for the pattern
post_nz, e.g., sum (ppred(X,0, "!=") * (U %*% t(V) - X) ^ 2). On a
scenario with a 100k x 100k matrix at sparsity 0.01, this change led to the
following improvements: for rank=50, 3.6s -> 1.4s; for rank=10, 650ms -> 520ms.

Furthermore, this patch also makes various cleanups with regard to
multi-threaded operations: (1) error checking via futures for wsloss and
wcemm, and (2) type-safe result/nnz aggregation for wsloss, wcemm, and
wdivmm.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/cfc561ee
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/cfc561ee
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/cfc561ee

Branch: refs/heads/master
Commit: cfc561eecc8dbc78a943532e19574b02600637c0
Parents: 63a56b0
Author: Matthias Boehm <[email protected]>
Authored: Tue Mar 8 20:27:17 2016 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Tue Mar 8 20:27:17 2016 -0800

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixMult.java      | 102 ++++++++++---------
 1 file changed, 53 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cfc561ee/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index d18f13d..d56b018 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -24,6 +24,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
@@ -584,16 +585,16 @@ public class LibMatrixMult
                try 
                {                       
                        ExecutorService pool = Executors.newFixedThreadPool(k);
-                       ArrayList<ScalarResultTask> tasks = new 
ArrayList<ScalarResultTask>();
+                       ArrayList<MatrixMultWSLossTask> tasks = new 
ArrayList<MatrixMultWSLossTask>();
                        int blklen = (int)(Math.ceil((double)mX.rlen/k));
                        for( int i=0; i<k & i*blklen<mX.rlen; i++ )
                                tasks.add(new MatrixMultWSLossTask(mX, mU, mV, 
mW, wt, i*blklen, Math.min((i+1)*blklen, mX.rlen)));
-                       pool.invokeAll(tasks);
+                       List<Future<Double>> taskret = pool.invokeAll(tasks);
                        pool.shutdown();
                        //aggregate partial results
-                       sumScalarResults(tasks, ret);
+                       sumScalarResults(taskret, ret);
                } 
-               catch (InterruptedException e) {
+               catch( Exception e ) {
                        throw new DMLRuntimeException(e);
                }
 
@@ -795,12 +796,12 @@ public class LibMatrixMult
                                        tasks.add(new MatrixMultWDivTask(mW, 
mU, mV, mX, ret, wt, i*blklen, Math.min((i+1)*blklen, mW.rlen), 0, mW.clen));
                        }
                        //execute tasks
-                       List<Future<Object>> taskret = pool.invokeAll(tasks);
+                       List<Future<Long>> taskret = pool.invokeAll(tasks);
                        pool.shutdown();
                        //aggregate partial nnz and check for errors
                        ret.nonZeros = 0;  //reset after execute
-                       for( Future<Object> task : taskret )
-                               ret.nonZeros += (Long)task.get();
+                       for( Future<Long> task : taskret )
+                               ret.nonZeros += task.get();
                } 
                catch (Exception e) {
                        throw new DMLRuntimeException(e);
@@ -877,16 +878,16 @@ public class LibMatrixMult
                try 
                {                       
                        ExecutorService pool = Executors.newFixedThreadPool(k);
-                       ArrayList<ScalarResultTask> tasks = new 
ArrayList<ScalarResultTask>();
+                       ArrayList<MatrixMultWCeTask> tasks = new 
ArrayList<MatrixMultWCeTask>();
                        int blklen = (int)(Math.ceil((double)mW.rlen/k));
                        for( int i=0; i<k & i*blklen<mW.rlen; i++ )
                                tasks.add(new MatrixMultWCeTask(mW, mU, mV, wt, 
i*blklen, Math.min((i+1)*blklen, mW.rlen)));
-                       pool.invokeAll(tasks);
+                       List<Future<Double>> taskret = pool.invokeAll(tasks);
                        pool.shutdown();
                        //aggregate partial results
-                       sumScalarResults(tasks, ret);
+                       sumScalarResults(taskret, ret);
                } 
-               catch (InterruptedException e) {
+               catch( Exception e ) {
                        throw new DMLRuntimeException(e);
                }
                
@@ -2230,17 +2231,34 @@ public class LibMatrixMult
                else if( wt==WeightsType.POST_NZ )
                {
                        // approach: iterate over W, point-wise in order to 
exploit sparsity
-                       for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd )
-                               if( !x.isEmpty(i) ) {
-                                       int xpos = x.pos(i);
-                                       int xlen = x.size(i);
-                                       int[] xix = x.indexes(i);
-                                       double[] xval = x.values(i);
-                                       for( int k=xpos; k<xpos+xlen; k++ ) {
-                                               double uvij = dotProduct(u, v, 
uix, xix[k]*cd, cd);
-                                               wsloss += 
(xval[k]-uvij)*(xval[k]-uvij);
+                       // blocked over ij, while maintaining front of column 
indexes, where the
+                       // blocksize is chosen such that we reuse each vector 
on average 8 times.
+                       final int blocksizeIJ = (int) 
(8L*mX.rlen*mX.clen/mX.nonZeros); 
+                       int[] curk = new int[blocksizeIJ];                      
+                       
+                       for( int bi=rl; bi<ru; bi+=blocksizeIJ ) {
+                               int bimin = Math.min(ru, bi+blocksizeIJ);
+                               //prepare starting indexes for block row
+                               Arrays.fill(curk, 0); 
+                               //blocked execution over column blocks
+                               for( int bj=0; bj<n; bj+=blocksizeIJ ) {
+                                       int bjmin = Math.min(n, bj+blocksizeIJ);
+                                       for( int i=bi, uix=bi*cd; i<bimin; i++, 
uix+=cd ) {
+                                               if( !x.isEmpty(i) ) {
+                                                       int xpos = x.pos(i);
+                                                       int xlen = x.size(i);
+                                                       int[] xix = 
x.indexes(i);
+                                                       double[] xval = 
x.values(i);
+                                                       int k = xpos + 
curk[i-bi];
+                                                       for( ; k<xpos+xlen && 
xix[k]<bjmin; k++ ) {
+                                                               double uvij = 
dotProduct(u, v, uix, xix[k]*cd, cd);
+                                                               wsloss += 
(xval[k]-uvij)*(xval[k]-uvij);
+                                                       }
+                                                       curk[i-bi] = k - xpos;
+                                               }
                                        }
-                               }       
+                               }
+                       }
                }
                // Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
                else if( wt==WeightsType.PRE )
@@ -2650,7 +2668,7 @@ public class LibMatrixMult
                SparseBlock x = (mX==null) ? null : mX.sparseBlock;
                
                //approach: iterate over non-zeros of w, selective mm 
computation
-               //blocked over ij, while maintaining font of column indexes, 
where the
+               //blocked over ij, while maintaining front of column indexes, 
where the
                //blocksize is chosen such that we reuse each vector on average 
8 times.
                final int blocksizeIJ = (int) (8L*mW.rlen*mW.clen/mW.nonZeros); 
                int[] curk = new int[blocksizeIJ];              
@@ -3985,13 +4003,16 @@ public class LibMatrixMult
         * 
         * @param tasks
         * @param ret
+        * @throws ExecutionException 
+        * @throws InterruptedException 
         */
-       private static void sumScalarResults(ArrayList<ScalarResultTask> tasks, 
MatrixBlock ret)
+       private static void sumScalarResults(List<Future<Double>> tasks, 
MatrixBlock ret) 
+               throws InterruptedException, ExecutionException
        {
-               //aggregate partial results
+               //aggregate partial results and check for errors
                double val = 0;
-               for(ScalarResultTask task : tasks)
-                       val += task.getScalarResult();
+               for(Future<Double> task : tasks)
+                       val += task.get();
                ret.quickSetValue(0, 0, val);
        }
        
@@ -4207,19 +4228,12 @@ public class LibMatrixMult
                        return null;
                }
        }
-
-       /**
-        * 
-        */
-       private static interface ScalarResultTask extends Callable<Object>{
-               public double getScalarResult();
-       }
        
        /**
         * 
         * 
         */
-       private static class MatrixMultWSLossTask implements ScalarResultTask
+       private static class MatrixMultWSLossTask implements Callable<Double>
        {
                private MatrixBlock _mX = null;
                private MatrixBlock _mU = null;
@@ -4247,7 +4261,7 @@ public class LibMatrixMult
                }
                
                @Override
-               public Object call() throws DMLRuntimeException
+               public Double call() throws DMLRuntimeException
                {
                        if( !_mX.sparse && !_mU.sparse && !_mV.sparse && 
(_mW==null || !_mW.sparse) 
                                && !_mX.isEmptyBlock() && !_mU.isEmptyBlock() 
&& !_mV.isEmptyBlock() 
@@ -4260,11 +4274,6 @@ public class LibMatrixMult
                        else
                                matrixMultWSLossGeneric(_mX, _mU, _mV, _mW, 
_ret, _wt, _rl, _ru);
 
-                       return null;
-               }
-               
-               @Override
-               public double getScalarResult() {
                        return _ret.quickGetValue(0, 0);
                }
        }
@@ -4322,7 +4331,7 @@ public class LibMatrixMult
         * 
         * 
         */
-       private static class MatrixMultWDivTask implements Callable<Object> 
+       private static class MatrixMultWDivTask implements Callable<Long> 
        {
                private MatrixBlock _mW = null;
                private MatrixBlock _mU = null;
@@ -4351,7 +4360,7 @@ public class LibMatrixMult
                }
                
                @Override
-               public Object call() throws DMLRuntimeException
+               public Long call() throws DMLRuntimeException
                {
                        //core weighted div mm computation
                        boolean scalarX = _wt.hasScalar();
@@ -4369,7 +4378,7 @@ public class LibMatrixMult
                }
        }
        
-       private static class MatrixMultWCeTask implements ScalarResultTask
+       private static class MatrixMultWCeTask implements Callable<Double>
        {
                private MatrixBlock _mW = null;
                private MatrixBlock _mU = null;
@@ -4395,7 +4404,7 @@ public class LibMatrixMult
                }
                
                @Override
-               public Object call() throws DMLRuntimeException
+               public Double call() throws DMLRuntimeException
                {
                        //core weighted div mm computation
                        if( !_mW.sparse && !_mU.sparse && !_mV.sparse && 
!_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
@@ -4406,11 +4415,6 @@ public class LibMatrixMult
                                matrixMultWCeMMGeneric(_mW, _mU, _mV, _ret, 
_wt, _rl, _ru);
                        
                        
-                       return null;
-               }
-               
-               @Override
-               public double getScalarResult() {
                        return _ret.quickGetValue(0, 0);
                }
        }

Reply via email to