[SYSTEMML-2444] Fix robustness wsloss rewrite order, Kmeans-predict

SystemML 1.2 includes several fixes for correct constant propagation.
These modifications, however, led to a performance regression for
Kmeans-predict 10Mx1K because, due to partial unknowns, the sum-sq rewrite
was applied before the wsloss rewrite, which destroyed the wsloss
pattern. We have now made the overlapping patterns more robust by matching
against both sum(tmp^2) and sumSq(tmp).

On the perftest Kmeans-predict 10Mx1K (80GB) scenario, this patch
improved end-to-end runtime (incl. read and Spark context creation) from
252s to 63s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0aaf11d8
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0aaf11d8
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0aaf11d8

Branch: refs/heads/master
Commit: 0aaf11d82b5680dbf4f21d1195f397e66edc22a2
Parents: 56c81cb
Author: Matthias Boehm <[email protected]>
Authored: Mon Jul 16 19:36:16 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Jul 16 19:36:44 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  11 +
 .../RewriteAlgebraicSimplificationDynamic.java  | 224 ++++++++++---------
 2 files changed, 127 insertions(+), 108 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/0aaf11d8/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 48b95cc..025f98a 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -777,6 +777,13 @@ public class HopRewriteUtils
                return ( hop.getNnz()==0 );
        }
        
+       public static boolean isEqualMatrixSize(BinaryOp hop) {
+               return hop.getDataType().isMatrix()
+                       && hop.getInput().get(0).getDataType().isMatrix()
+                       && hop.getInput().get(1).getDataType().isMatrix()
+                       && isEqualSize(hop.getInput().get(0), 
hop.getInput().get(1));
+       }
+       
        public static boolean isEqualSize( Hop hop1, Hop hop2 ) {
                return (hop1.dimsKnown() && hop2.dimsKnown()
                        && hop1.getDim1() == hop2.getDim1()
@@ -1016,6 +1023,10 @@ public class HopRewriteUtils
                return hop instanceof AggBinaryOp && 
((AggBinaryOp)hop).isMatrixMultiply();
        }
        
+       public static boolean isAggUnaryOp(Hop hop, AggOp op, Direction dir) {
+               return isAggUnaryOp(hop, op) && 
((AggUnaryOp)hop).getDirection()==dir;
+       }
+       
        public static boolean isAggUnaryOp(Hop hop, AggOp...op) {
                if( !(hop instanceof AggUnaryOp) )
                        return false;

http://git-wip-us.apache.org/repos/asf/systemml/blob/0aaf11d8/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 062da2f..4f0ef51 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -1200,6 +1200,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
         * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
         * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
         * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
+        * 4) sumSq (X - U %*% t(V)) (no weighting sumSq)
         * 
         * NOTE: We include transpose into the pattern because during runtime 
we need to compute
         * U%*% t(V) pointwise; having V and not t(V) at hand allows for a 
cache-friendly implementation
@@ -1220,69 +1221,59 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                //NOTE: there might be also a general simplification without 
custom operator
                //via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
                Hop hnew = null;
+               boolean appliedPattern = false;
                
-               if( hi instanceof AggUnaryOp && 
((AggUnaryOp)hi).getDirection()==Direction.RowCol
-                       && ((AggUnaryOp)hi).getOp() == AggOp.SUM     //all 
patterns rooted by sum()
+               if( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, 
Direction.RowCol) //all patterns rooted by sum()
                        && hi.getInput().get(0) instanceof BinaryOp  //all 
patterns subrooted by binary op
                        && hi.getInput().get(0).getDim2() > 1  )     //not 
applied for vector-vector mult
                {
                        BinaryOp bop = (BinaryOp) hi.getInput().get(0);
-                       boolean appliedPattern = false;
                        
                        //Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post 
weighting)
                        //alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
                        if( bop.getOp()==OpOp2.MULT && 
HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW)
-                               && 
bop.getInput().get(0).getDataType()==DataType.MATRIX 
+                               && 
bop.getInput().get(0).getDataType()==DataType.MATRIX
                                && 
HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) 
//prevent mv
-                               && bop.getInput().get(1).getInput().get(1) 
instanceof LiteralOp
-                               && 
HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1).getInput().get(1))==2)
+                               && 
HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1).getInput().get(1), 2) )
                        {
                                Hop W = bop.getInput().get(0);
                                Hop tmp = 
bop.getInput().get(1).getInput().get(0); //(X - U %*% t(V))
                                
                                if( HopRewriteUtils.isBinary(tmp, OpOp2.MINUS)
-                                       && 
HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) 
//prevent mv       
+                                       && 
HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) 
//prevent mv
                                        && tmp.getInput().get(0).getDataType() 
== DataType.MATRIX )
                                {
                                        //a) sum (W * (X - U %*% t(V)) ^ 2)
                                        int uvIndex = -1;
                                        if( tmp.getInput().get(1) instanceof 
AggBinaryOp  //ba gurantees matrices
-                                                       && 
HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) 
//BLOCKSIZE CONSTRAINT
-                                       {
-                                               uvIndex = 1;   
+                                               && 
HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) { 
//BLOCKSIZE CONSTRAINT
+                                               uvIndex = 1;
                                        }
                                        //b) sum (W * (U %*% t(V) - X) ^ 2)
                                        else if(tmp.getInput().get(0) 
instanceof AggBinaryOp  //ba gurantees matrices
-                                               && 
HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0),true)) 
//BLOCKSIZE CONSTRAINT
-                                       {
+                                               && 
HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0),true)) { 
//BLOCKSIZE CONSTRAINT
                                                uvIndex = 0;
-                                       }   
-                                
-                                       if( uvIndex >= 0 ) //rewrite match
-                                       {
+                                       }
+
+                                       if( uvIndex >= 0 ) { //rewrite match
                                                Hop X = 
tmp.getInput().get((uvIndex==0)?1:0); 
                                                Hop U = 
tmp.getInput().get(uvIndex).getInput().get(0);
                                                Hop V = 
tmp.getInput().get(uvIndex).getInput().get(1);
-                           
-                                               if( 
!HopRewriteUtils.isTransposeOperation(V) ) {
-                                                       V = 
HopRewriteUtils.createTranspose(V);
-                                               }
-                                               else{
-                                                       V = V.getInput().get(0);
-                                               }
-                           
+                                               V = 
!HopRewriteUtils.isTransposeOperation(V) ?
+                                                       
HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
+
                                                //handle special case of post_nz
                                                if( 
HopRewriteUtils.isNonZeroIndicator(W, X) ){
                                                        W = new LiteralOp(1);
                                                }
-                                               
+
                                                //construct quaternary hop
-                                               hnew = new 
QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, 
-                                                               OpOp4.WSLOSS, 
X, U, V, W, true);
+                                               hnew = new 
QuaternaryOp(hi.getName(), DataType.SCALAR,
+                                                       ValueType.DOUBLE,  
OpOp4.WSLOSS, X, U, V, W, true);
                                                
HopRewriteUtils.setOutputParametersForScalar(hnew);
-       
+
                                                appliedPattern = true;
-                                               LOG.debug("Applied 
simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")");  
+                                               LOG.debug("Applied 
simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")");
                                        }
                                }
                        }
@@ -1290,107 +1281,124 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        //Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre 
weighting)
                        //alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
                        if( !appliedPattern
-                               && bop.getOp()==OpOp2.POW && 
bop.getInput().get(1) instanceof LiteralOp
-                               && 
HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2
-                               && 
HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) 
-                               && 
bop.getInput().get(0).getDataType()==DataType.MATRIX 
-                               && 
HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), 
bop.getInput().get(0).getInput().get(1)) //prevent mv
-                               && 
bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
+                               && bop.getOp()==OpOp2.POW && 
HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1), 2)
+                               && 
HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
+                               && 
HopRewriteUtils.isEqualMatrixSize((BinaryOp)bop.getInput().get(0)))
                        {
-                           Hop lleft = 
bop.getInput().get(0).getInput().get(0); 
-                           Hop lright = 
bop.getInput().get(0).getInput().get(1); 
-                
-                           //a) sum ((X - W * (U %*% t(V))) ^ 2)
-                           int wuvIndex = -1;
-                           if( lright instanceof BinaryOp && 
lright.getInput().get(1) instanceof AggBinaryOp ){
-                               wuvIndex = 1;
-                           }
-                           //b) sum ((W * (U %*% t(V)) - X) ^ 2)
-                           else if( lleft instanceof BinaryOp && 
lleft.getInput().get(1) instanceof AggBinaryOp ){
-                               wuvIndex = 0;
-                           }
-                           
-                           if( wuvIndex >= 0 ) //rewrite match
-                           {
-                               Hop X = 
bop.getInput().get(0).getInput().get((wuvIndex==0)?1:0);
-                               Hop tmp = 
bop.getInput().get(0).getInput().get(wuvIndex); //(W * (U %*% t(V)))
-                               
-                               if( ((BinaryOp)tmp).getOp()==OpOp2.MULT
-                                       && tmp.getInput().get(0).getDataType() 
== DataType.MATRIX       
-                                       && 
HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) 
//prevent mv
-                                       && 
HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) 
//BLOCKSIZE CONSTRAINT
-                               {
-                                       Hop W = tmp.getInput().get(0); 
-                                       Hop U = 
tmp.getInput().get(1).getInput().get(0);
-                                       Hop V = 
tmp.getInput().get(1).getInput().get(1);
-                                       
-                                       if( 
!HopRewriteUtils.isTransposeOperation(V) ) { 
-                                               V = 
HopRewriteUtils.createTranspose(V);
-                                       }
-                                       else {
-                                               V = V.getInput().get(0);
-                                       }
-                                       
-                                       hnew = new QuaternaryOp(hi.getName(), 
DataType.SCALAR, ValueType.DOUBLE, 
-                                                         OpOp4.WSLOSS, X, U, 
V, W, false);
-                                       
HopRewriteUtils.setOutputParametersForScalar(hnew);
-    
-                                       appliedPattern = true;
-                                       LOG.debug("Applied 
simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")");     
-                               }
-                           }
+                               Hop lleft = 
bop.getInput().get(0).getInput().get(0); 
+                               Hop lright = 
bop.getInput().get(0).getInput().get(1); 
+
+                               //a) sum ((X - W * (U %*% t(V))) ^ 2)
+                               int wuvIndex = -1;
+                               if( lright instanceof BinaryOp && 
lright.getInput().get(1) instanceof AggBinaryOp ){
+                                       wuvIndex = 1;
+                               }
+                               //b) sum ((W * (U %*% t(V)) - X) ^ 2)
+                               else if( lleft instanceof BinaryOp && 
lleft.getInput().get(1) instanceof AggBinaryOp ){
+                                       wuvIndex = 0;
+                               }
+
+                               if( wuvIndex >= 0 ) //rewrite match
+                               {
+                                       Hop X = 
bop.getInput().get(0).getInput().get((wuvIndex==0)?1:0);
+                                       Hop tmp = 
bop.getInput().get(0).getInput().get(wuvIndex); //(W * (U %*% t(V)))
+                                       
+                                       if( ((BinaryOp)tmp).getOp()==OpOp2.MULT
+                                               && 
tmp.getInput().get(0).getDataType() == DataType.MATRIX
+                                               && 
HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) 
//prevent mv
+                                               && 
HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) 
//BLOCKSIZE CONSTRAINT
+                                       {
+                                               Hop W = tmp.getInput().get(0);
+                                               Hop U = 
tmp.getInput().get(1).getInput().get(0);
+                                               Hop V = 
tmp.getInput().get(1).getInput().get(1);
+                                               V = 
!HopRewriteUtils.isTransposeOperation(V) ?
+                                                       
HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
+                                               hnew = new 
QuaternaryOp(hi.getName(), DataType.SCALAR,
+                                                       ValueType.DOUBLE, 
OpOp4.WSLOSS, X, U, V, W, false);
+                                               
HopRewriteUtils.setOutputParametersForScalar(hnew);
+                                               appliedPattern = true;
+                                               LOG.debug("Applied 
simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")");     
+                                       }
+                               }
                        }
-                       
+
                        //Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
                        //alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
                        if( !appliedPattern
-                               && bop.getOp()==OpOp2.POW && 
bop.getInput().get(1) instanceof LiteralOp
-                               && 
HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2
-                               && 
HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)         
-                               && 
bop.getInput().get(0).getDataType()==DataType.MATRIX 
-                               && 
HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), 
bop.getInput().get(0).getInput().get(1)) //prevent mv
-                               && 
bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
+                               && bop.getOp()==OpOp2.POW && 
HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1), 2)
+                               && 
HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) 
+                               && 
HopRewriteUtils.isEqualMatrixSize((BinaryOp)bop.getInput().get(0))) //prevent mv
                        {
                                Hop lleft = 
bop.getInput().get(0).getInput().get(0);
                                Hop lright = 
bop.getInput().get(0).getInput().get(1);
-                
+
                                //a) sum ((X - (U %*% t(V))) ^ 2)
                                int uvIndex = -1;
-                               if( lright instanceof AggBinaryOp //ba 
gurantees matrices
-                                       && 
HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) )  //BLOCKSIZE 
CONSTRAINT
-                               {
+                               if( lright instanceof AggBinaryOp //ba 
guarantees matrices
+                                       && 
HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) { //BLOCKSIZE 
CONSTRAINT
                                        uvIndex = 1;
                                }
                                //b) sum (((U %*% t(V)) - X) ^ 2)
-                               else if( lleft instanceof AggBinaryOp //ba 
gurantees matrices
-                                               && 
HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) )  //BLOCKSIZE 
CONSTRAINT
-                               {
+                               else if( lleft instanceof AggBinaryOp //ba 
guarantees matrices
+                                               && 
HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) { //BLOCKSIZE 
CONSTRAINT
                                        uvIndex = 0;
                                }
-                           
-                               if( uvIndex >= 0 ) //rewrite match
-                               {
+
+                               if( uvIndex >= 0 ) { //rewrite match
                                        Hop X = 
bop.getInput().get(0).getInput().get((uvIndex==0)?1:0);
                                        Hop tmp = 
bop.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V))
                                        Hop W = new LiteralOp(1); //no 
weighting 
                                        Hop U = tmp.getInput().get(0);
                                        Hop V = tmp.getInput().get(1);
-       
-                                       if( 
!HopRewriteUtils.isTransposeOperation(V) ) { 
-                                               V = 
HopRewriteUtils.createTranspose(V);
-                                       }
-                                       else {
-                                               V = V.getInput().get(0);
-                                       }
-                                       
-                                       hnew = new QuaternaryOp(hi.getName(), 
DataType.SCALAR, ValueType.DOUBLE, 
-                                                         OpOp4.WSLOSS, X, U, 
V, W, false);
+                                       V = 
!HopRewriteUtils.isTransposeOperation(V) ?
+                                               
HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
+                                       hnew = new QuaternaryOp(hi.getName(), 
DataType.SCALAR,
+                                               ValueType.DOUBLE, OpOp4.WSLOSS, 
X, U, V, W, false);
                                        
HopRewriteUtils.setOutputParametersForScalar(hnew);
-
                                        appliedPattern = true;
-                                       LOG.debug("Applied 
simplifyWeightedSquaredLoss3"+uvIndex+" (line "+hi.getBeginLine()+")");      
+                                       
+                                       LOG.debug("Applied 
simplifyWeightedSquaredLoss3"+uvIndex+" (line "+hi.getBeginLine()+")");
                                }
-                       }                       
+                       }
+               }
+               
+               //Pattern 4) sumSq (X - U %*% t(V)) (no weighting)
+               //alternative pattern: sumSq (U %*% t(V) - X)
+               if( !appliedPattern
+                       && HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM_SQ, 
Direction.RowCol)
+                       && HopRewriteUtils.isBinary(hi.getInput().get(0), 
OpOp2.MINUS) 
+                       && 
HopRewriteUtils.isEqualMatrixSize((BinaryOp)hi.getInput().get(0))) //prevent mv
+               {
+                       Hop lleft = hi.getInput().get(0).getInput().get(0);
+                       Hop lright = hi.getInput().get(0).getInput().get(1);
+
+                       //a) sumSq (X - U %*% t(V))
+                       int uvIndex = -1;
+                       if( lright instanceof AggBinaryOp //ba guarantees 
matrices
+                               && 
HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) { //BLOCKSIZE 
CONSTRAINT
+                               uvIndex = 1;
+                       }
+                       //b) sumSq (U %*% t(V) - X)
+                       else if( lleft instanceof AggBinaryOp //ba guarantees 
matrices
+                                       && 
HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) { //BLOCKSIZE 
CONSTRAINT
+                               uvIndex = 0;
+                       }
+
+                       if( uvIndex >= 0 ) { //rewrite match
+                               Hop X = 
hi.getInput().get(0).getInput().get((uvIndex==0)?1:0);
+                               Hop tmp = 
hi.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V))
+                               Hop W = new LiteralOp(1); //no weighting 
+                               Hop U = tmp.getInput().get(0);
+                               Hop V = tmp.getInput().get(1);
+                               V = !HopRewriteUtils.isTransposeOperation(V) ?
+                                       HopRewriteUtils.createTranspose(V) : 
V.getInput().get(0);
+                               hnew = new QuaternaryOp(hi.getName(), 
DataType.SCALAR,
+                                       ValueType.DOUBLE, OpOp4.WSLOSS, X, U, 
V, W, false);
+                               
HopRewriteUtils.setOutputParametersForScalar(hnew);
+                               appliedPattern = true;
+                               
+                               LOG.debug("Applied 
simplifyWeightedSquaredLoss4"+uvIndex+" (line "+hi.getBeginLine()+")");
+                       }
                }
                
                //relink new hop into original position
@@ -2168,7 +2176,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                        HopRewriteUtils.cleanupUnreferenced(hi, 
sumInput);
                                        hi = sumSq;
                                        
-                                       LOG.debug("Applied fuseSumSquared.");
+                                       LOG.debug("Applied fuseSumSquared (line 
" +hi.getBeginLine()+").");
                                }
                        }
                }

Reply via email to