[SYSTEMML-824] Performance scalar/agg/binary ops (sparse, par, overlap)

This patch makes various performance improvements, some of which change
the asymptotic behavior and yield huge improvements on large data.

(1) Scalar operations: Right scalar divisions by a non-zero scalar (e.g.,
X/7) are now marked as sparse-safe.

(2) Unary aggregate operations: The min size threshold for parallel
operations is now checked against the number of non-zeros instead of the
number of cells, in order to avoid unnecessary thread pool creation
overhead on small data. In addition, the minimum number of rows required
for parallel execution is relaxed from k to k/2.

(3) Ternary aggregate operations: The rewrite of sum(v1*v2*v3) to
tak+(v1,v2,v3) is now only applied if sum is the only consumer of the
intermediate, in order to avoid redundant computation.

(4) Binary operations: There is now an additional special case for "skip
empty" operations like * and / (with a dense rhs), which were so far
handled by the generic case (e.g., sparse-dense * with sparse output).
This leads to an asymptotic improvement from O(n log n) to O(n).

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2137a7e4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2137a7e4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2137a7e4

Branch: refs/heads/master
Commit: 2137a7e4a0af7a48326294fa6bb56084ade9d2b5
Parents: b1dc0d5
Author: Matthias Boehm <[email protected]>
Authored: Wed Aug 10 23:04:28 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Aug 11 11:52:13 2016 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  16 ++-
 .../sysml/runtime/matrix/data/LibMatrixAgg.java |   2 +-
 .../runtime/matrix/data/LibMatrixBincell.java   | 101 +++++++++++--------
 .../matrix/operators/RightScalarOperator.java   |   9 +-
 4 files changed, 70 insertions(+), 58 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2137a7e4/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 0152d54..8b44e4b 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -495,31 +495,29 @@ public class AggUnaryOp extends Hop implements 
MultiThreadedHop
                
                //currently we support only sum over binary multiply but 
potentially 
                //it can be generalized to any RC aggregate over two common 
binary operations
-               if( _direction == Direction.RowCol 
-                       && _op == AggOp.SUM ) 
+               if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES &&
+                       _direction == Direction.RowCol && _op == AggOp.SUM ) 
                {
                        Hop input1 = getInput().get(0);
-                       if( input1 instanceof BinaryOp && 
((BinaryOp)input1).getOp()==OpOp2.MULT )
+                       if( input1.getParent().size() == 1 && //sum single 
consumer
+                               input1 instanceof BinaryOp && 
((BinaryOp)input1).getOp()==OpOp2.MULT )
                        {
                                Hop input11 = input1.getInput().get(0);
                                Hop input12 = input1.getInput().get(1);
                                
-                               if( input11 instanceof BinaryOp && 
((BinaryOp)input11).getOp()==OpOp2.MULT )
-                               {
+                               if( input11 instanceof BinaryOp && 
((BinaryOp)input11).getOp()==OpOp2.MULT ) {
                                        //ternary, arbitrary matrices but no 
mv/outer operations.
                                        ret = 
HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1)
                                                && 
HopRewriteUtils.isEqualSize(input11.getInput().get(1), input1)       
                                                && 
HopRewriteUtils.isEqualSize(input12, input1);
                                }
-                               else if( input12 instanceof BinaryOp && 
((BinaryOp)input12).getOp()==OpOp2.MULT )
-                               {
+                               else if( input12 instanceof BinaryOp && 
((BinaryOp)input12).getOp()==OpOp2.MULT ) {
                                        //ternary, arbitrary matrices but no 
mv/outer operations.
                                        ret = 
HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1)
                                                        && 
HopRewriteUtils.isEqualSize(input12.getInput().get(1), input1)       
                                                        && 
HopRewriteUtils.isEqualSize(input11, input1);
                                }
-                               else
-                               {
+                               else {
                                        //binary, arbitrary matrices but no 
mv/outer operations.
                                        ret = 
HopRewriteUtils.isEqualSize(input11, input12);
                                }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2137a7e4/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
index cf689b8..3357dee 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
@@ -229,7 +229,7 @@ public class LibMatrixAgg
                throws DMLRuntimeException
        {
                //fall back to sequential version if necessary
-               if(    k <= 1 || (long)in.rlen*in.clen < PAR_NUMCELL_THRESHOLD 
|| in.rlen <= k
+               if(    k <= 1 || (long)in.nonZeros < PAR_NUMCELL_THRESHOLD || 
in.rlen <= k/2
                        || (!(uaop.indexFn instanceof ReduceCol) &&  
out.clen*8*k > PAR_INTERMEDIATE_SIZE_THRESHOLD ) || 
                        !out.isThreadSafe()) {
                        aggregateUnaryMatrix(in, out, uaop);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2137a7e4/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
index f9269e8..8e7d330 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
@@ -228,8 +228,8 @@ public class LibMatrixBincell
        private static void safeBinary(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator op) 
                throws DMLRuntimeException 
        {
-               boolean isMultiply = (op.fn instanceof Multiply);
-               boolean skipEmpty = (isMultiply);
+               boolean skipEmpty = (op.fn instanceof Multiply 
+                               || isSparseSafeDivide(op, m2) );
                
                //skip empty blocks (since sparse-safe)
                if(    m1.isEmptyBlock(false) && m2.isEmptyBlock(false) 
@@ -425,18 +425,35 @@ public class LibMatrixBincell
                                }
                                ret.nonZeros = nnz;
                        }
+                       else if( skipEmpty && (m1.sparse || m2.sparse) ) 
+                       {
+                               SparseBlock a = m1.sparse ? m1.sparseBlock : 
m2.sparseBlock;
+                               if( a != null ) {
+                                       MatrixBlock b = m1.sparse ? m2 : m1;
+                                       for( int i=0; i<a.numRows(); i++ ) {
+                                               if( a.isEmpty(i) ) continue;
+                                               int apos = a.pos(i);
+                                               int alen = a.size(i);
+                                               int[] aix = a.indexes(i);
+                                               double[] avals = a.values(i);
+                                               for(int k = apos; k < 
apos+alen; k++) {
+                                                       double in2 = 
b.quickGetValue(i, aix[k]);
+                                                       if( in2==0 ) continue;
+                                                       double val = 
op.fn.execute(avals[k], in2);
+                                                       ret.appendValue(i, 
aix[k], val);
+                                               }
+                                       }
+                               }
+                       }
                        else //generic case
                        {
-                               double thisvalue, thatvalue, resultvalue;
                                for(int r=0; r<rlen; r++)
-                                       for(int c=0; c<clen; c++)
-                                       {
-                                               thisvalue=m1.quickGetValue(r, 
c);
-                                               thatvalue=m2.quickGetValue(r, 
c);
-                                               if(thisvalue==0 && thatvalue==0)
-                                                       continue;
-                                               
resultvalue=op.fn.execute(thisvalue, thatvalue);
-                                               ret.appendValue(r, c, 
resultvalue);
+                                       for(int c=0; c<clen; c++) {
+                                               double in1 = 
m1.quickGetValue(r, c);
+                                               double in2 = 
m2.quickGetValue(r, c);
+                                               if( in1==0 && in2==0) continue;
+                                               double val = op.fn.execute(in1, 
in2);
+                                               ret.appendValue(r, c, val);
                                        }
                        }
                }
@@ -995,23 +1012,9 @@ public class LibMatrixBincell
                        }
                        ret.nonZeros = nnz;
                }
-               else //DENSE <- DENSE
-               {
-                       //allocate dense block
-                       ret.allocateDenseBlock(true);
-               
-                       double[] a = m1.denseBlock;
-                       double[] c = ret.denseBlock;
-                       
-                       int limit = m1.rlen*m1.clen;
-                       int nnz = 0;
-                       for( int i=0; i<limit; i++ ) {
-                               c[i] = op.executeScalar( a[i] );
-                               nnz += (c[i] != 0) ? 1 : 0;
-                       }
-                       ret.nonZeros = nnz;
+               else { //DENSE <- DENSE
+                       denseBinaryScalar(m1, ret, op);
                }
-               
        }
        
        /**
@@ -1068,27 +1071,39 @@ public class LibMatrixBincell
                        }
                        ret.nonZeros = nnz;
                }
-               else //DENSE MATRIX
-               {
-                       //allocate dense block (if necessary), incl clear nnz
-                       ret.allocateDenseBlock(true);
-                       
-                       double[] a = m1.denseBlock;
-                       double[] c = ret.denseBlock;
-                       
-                       //compute scalar operation, incl nnz maintenance
-                       int limit = m1.rlen*m1.clen;
-                       int nnz = 0;
-                       for( int i=0; i<limit; i++ ) {
-                               c[i] = op.executeScalar( a[i] );
-                               nnz += (c[i] != 0) ? 1 : 0;
-                       }
-                       ret.nonZeros = nnz;
+               else { //DENSE MATRIX
+                       denseBinaryScalar(m1, ret, op);
                }
        }
 
        /**
         * 
+        * @param m1
+        * @param ret
+        * @param op
+        * @throws DMLRuntimeException 
+        */
+       private static void denseBinaryScalar(MatrixBlock m1, MatrixBlock ret, 
ScalarOperator op) 
+               throws DMLRuntimeException 
+       {
+               //allocate dense block (if necessary), incl clear nnz
+               ret.allocateDenseBlock(true);
+               
+               double[] a = m1.denseBlock;
+               double[] c = ret.denseBlock;
+               
+               //compute scalar operation, incl nnz maintenance
+               int limit = m1.rlen*m1.clen;
+               int nnz = 0;
+               for( int i=0; i<limit; i++ ) {
+                       c[i] = op.executeScalar( a[i] );
+                       nnz += (c[i] != 0) ? 1 : 0;
+               }
+               ret.nonZeros = nnz;
+       }
+       
+       /**
+        * 
         * @param m1ret
         * @param m2
         * @param op

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2137a7e4/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java
index 3a4c481..cea41f1 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java
@@ -21,6 +21,7 @@
 package org.apache.sysml.runtime.matrix.operators;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.functionobjects.Divide;
 import org.apache.sysml.runtime.functionobjects.GreaterThan;
 import org.apache.sysml.runtime.functionobjects.GreaterThanEquals;
 import org.apache.sysml.runtime.functionobjects.LessThan;
@@ -49,12 +50,10 @@ public class RightScalarOperator extends ScalarOperator
                super.setConstant(cst);
                
                //enable conditionally sparse safe operations
-               if(    (fn instanceof GreaterThan && _constant>=0)
+               sparseSafe = (fn instanceof GreaterThan && _constant>=0)
                        || (fn instanceof GreaterThanEquals && _constant>0)
                        || (fn instanceof LessThan && _constant<=0)
-                       || (fn instanceof LessThanEquals && _constant<0))
-               {
-                       sparseSafe = true;
-               }
+                       || (fn instanceof LessThanEquals && _constant<0)
+                       || (fn instanceof Divide && _constant!=0);
        }
 }

Reply via email to