[SYSTEMML-833] Additional cleanup rewrites (unnecess. cast, reorg, agg)

This patch adds several additional cleanup rewrites in order to simplify
debugging. In detail, it includes:

(1) Unnecessary data type casts (e.g., as.scalar(as.matrix(s)) -> s)
(2) Unnecessary reorg operations (e.g., t(X) -> X, iff X has 1x1 dims)
(3) Unnecessary aggregations (e.g., sum(X) -> as.scalar(X), iff X has 1x1 dims)
(4) Pushdown of scalar casts (e.g., as.scalar(X*s) -> as.scalar(X)*s)

Note that these rewrites enable each other; e.g., once (2), (3), and (4)
have been applied, the unnecessary casts (1) can be removed as well, which
collapses long chains of unnecessary operations such as
sum(t(as.matrix(t(X))*7)) -> as.scalar(X)*7.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/11a85775
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/11a85775
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/11a85775

Branch: refs/heads/master
Commit: 11a85775f11e4490d957fe4f9fab4bfd8ea7a138
Parents: 461184a
Author: Matthias Boehm <[email protected]>
Authored: Sat Jul 30 23:48:51 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sun Jul 31 19:16:59 2016 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  3 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     | 11 ++--
 .../RewriteAlgebraicSimplificationDynamic.java  | 67 ++++++++++++++++----
 .../RewriteAlgebraicSimplificationStatic.java   | 47 ++++++++++++++
 .../rewrite/RewriteRemoveUnnecessaryCasts.java  | 21 +++++-
 .../cp/ArithmeticBinaryCPInstruction.java       | 14 ++--
 6 files changed, 134 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index f7e4656..3bfdcb5 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -512,7 +512,8 @@ public class HopRewriteUtils
        public static UnaryOp createUnary(Hop input, OpOp1 type) 
                throws HopsException
        {
-               UnaryOp unary = new UnaryOp(input.getName(), 
input.getDataType(), input.getValueType(), type, input);
+               DataType dt = (type==OpOp1.CAST_AS_SCALAR) ? DataType.SCALAR : 
input.getDataType();
+               UnaryOp unary = new UnaryOp(input.getName(), dt, 
input.getValueType(), type, input);
                HopRewriteUtils.setOutputBlocksizes(unary, 
input.getRowsInBlock(), input.getColsInBlock());
                HopRewriteUtils.copyLineNumbers(input, unary);
                unary.refreshSizeInformation(); 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 8e645dc..e7b03c4 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -130,12 +130,13 @@ public class ProgramRewriter
                                _dagRuleSet.add( new 
RewriteAlgebraicSimplificationDynamic()      ); //dependencies: cse
                                _dagRuleSet.add( new 
RewriteAlgebraicSimplificationStatic()       ); //dependencies: cse
                        }
-                       
-                       //reapply cse after rewrites because (1) applied 
rewrites on operators w/ multiple parents, and
-                       //(2) newly introduced operators potentially created 
redundancy (incl leaf merge to allow for cse)
-                       if( 
OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )             
-                               _dagRuleSet.add( new 
RewriteCommonSubexpressionElimination(true) ); //dependency: simplifications    
                   
                }
+               
+               // cleanup after all rewrites applied 
+               // (newly introduced operators, introduced redundancy after 
rewrites w/ multiple parents) 
+               _dagRuleSet.add(     new RewriteRemoveUnnecessaryCasts()        
     );         
+               if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )     
        
+                       _dagRuleSet.add( new 
RewriteCommonSubexpressionElimination(true) );                     
        }
        
        /**

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 10953f5..793bc25 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -69,6 +69,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
        //valid aggregation operation types for empty (sparse-safe) operations 
(not all operations apply)
        //AggOp.MEAN currently not due to missing count/corrections
        private static AggOp[] LOOKUP_VALID_EMPTY_AGGREGATE = new 
AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE};
+       private static AggOp[] LOOKUP_VALID_UNNECESSARY_AGGREGATE = new 
AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE};
        
        //valid unary operation types for empty (sparse-safe) operations (not 
all operations apply)
        private static OpOp1[] LOOKUP_VALID_EMPTY_UNARY = new 
OpOp1[]{OpOp1.ABS, OpOp1.SIN, OpOp1.TAN, OpOp1.SQRT, OpOp1.ROUND, 
OpOp1.CUMSUM}; 
@@ -149,13 +150,14 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        hi = removeUnnecessaryLeftIndexing(hop, hi, i);   
//e.g., X[,1]=Y -> Y, if output == input dims 
                        hi = fuseLeftIndexingChainToAppend(hop, hi, i);   
//e.g., X[,1]=A; X[,2]=B -> X=cbind(A,B), iff ncol(X)==2 and col1/2 lix
                        hi = removeUnnecessaryCumulativeOp(hop, hi, i);   
//e.g., cumsum(X) -> X, if nrow(X)==1;
-                       hi = removeUnnecessaryReorgOperation(hop, hi, i); 
//e.g., matrix(X) -> X, if output == input dims
+                       hi = removeUnnecessaryReorgOperation(hop, hi, i); 
//e.g., matrix(X) -> X, if dims(in)==dims(out); r(X)->X, if 1x1 dims
                        hi = removeUnnecessaryOuterProduct(hop, hi, i);   
//e.g., X*(Y%*%matrix(1,...) -> X*Y, if Y col vector
                        hi = fuseDatagenAndReorgOperation(hop, hi, i);    
//e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1
                        hi = simplifyColwiseAggregate(hop, hi, i);        
//e.g., colsums(X) -> sum(X) or X, if col/row vector
                        hi = simplifyRowwiseAggregate(hop, hi, i);        
//e.g., rowsums(X) -> sum(X) or X, if row/col vector
                        hi = simplifyColSumsMVMult(hop, hi, i);           
//e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector
                        hi = simplifyRowSumsMVMult(hop, hi, i);           
//e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector
+                       hi = simplifyUnnecessaryAggregate(hop, hi, i);    
//e.g., sum(X) -> as.scalar(X), if 1x1 dims
                        hi = simplifyEmptyAggregate(hop, hi, i);          
//e.g., sum(X) -> 0, if nnz(X)==0
                        hi = simplifyEmptyUnaryOperation(hop, hi, i);     
//e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0                   
                        hi = simplifyEmptyReorgOperation(hop, hi, i);     
//e.g., t(X) -> matrix(0, ncol(X), nrow(X)) 
@@ -428,22 +430,26 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
         */
        private Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
        {
-               if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp() == 
ReOrgOp.RESHAPE ) //reshape operation
+               if( hi instanceof ReorgOp ) 
                {
+                       ReorgOp rop = (ReorgOp) hi;
                        Hop input = hi.getInput().get(0); 
-
-                       if( HopRewriteUtils.isEqualSize(hi, input) ) //equal 
dims
-                       {
-                               //equal dims of reshape input and output -> no 
need for reshape because 
-                               //byrow always refers to both input/output and 
hence gives the same result
-                               
-                               //remove unnecessary right indexing
-                               HopRewriteUtils.removeChildReference(parent, 
hi);                               
+                       boolean apply = false;
+                       
+                       //equal dims of reshape input and output -> no need for 
reshape because 
+                       //byrow always refers to both input/output and hence 
gives the same result
+                       apply |= (rop.getOp()==ReOrgOp.RESHAPE && 
HopRewriteUtils.isEqualSize(hi, input));
+                       
+                       //1x1 dimensions of transpose/reshape -> no need for 
reorg      
+                       apply |= ((rop.getOp()==ReOrgOp.TRANSPOSE || 
rop.getOp()==ReOrgOp.RESHAPE) 
+                                       && rop.getDim1()==1 && 
rop.getDim2()==1);
+                       
+                       if( apply ) {
+                               
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);                     
        
                                HopRewriteUtils.addChildReference(parent, 
input, pos);
                                parent.refreshSizeInformation();
                                hi = input;
-                               
-                               LOG.debug("Applied removeUnnecessaryReshape");
+                               LOG.debug("Applied removeUnnecessaryReorg.");
                        }                       
                }
                
@@ -841,6 +847,43 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
         * @return
         * @throws HopsException
         */
+       private Hop simplifyUnnecessaryAggregate(Hop parent, Hop hi, int pos) 
+               throws HopsException
+       {
+               //e.g., sum(X) -> as.scalar(X) if 1x1 (applies to sum, min, 
max, prod, trace)
+               if( hi instanceof AggUnaryOp && 
((AggUnaryOp)hi).getDirection()==Direction.RowCol  ) 
+               {
+                       AggUnaryOp uhi = (AggUnaryOp)hi;
+                       Hop input = uhi.getInput().get(0);
+                       
+                       if( HopRewriteUtils.isValidOp(uhi.getOp(), 
LOOKUP_VALID_UNNECESSARY_AGGREGATE) ){               
+                               
+                               if( input.getDim1()==1 && input.getDim2()==1 )
+                               {
+                                       UnaryOp cast = 
HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_SCALAR);
+                                       
+                                       //remove unnecessary aggregation 
+                                       
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+                                       
HopRewriteUtils.addChildReference(parent, cast, pos);
+                                       parent.refreshSizeInformation();
+                                       hi = cast;
+                                       
+                                       LOG.debug("Applied 
simplifyUnncessaryAggregate");
+                               }
+                       }                       
+               }
+               
+               return hi;
+       }
+       
+       /**
+        * 
+        * @param parent
+        * @param hi
+        * @param pos
+        * @return
+        * @throws HopsException
+        */
        private Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos) 
                throws HopsException
        {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index f23686c..784d678 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -144,6 +144,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyDistributiveBinaryOperation(hop, hi, 
i);//e.g., (X-Y*X) -> (1-Y)*X
                        hi = simplifyBushyBinaryOperation(hop, hi, i);       
//e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
                        hi = simplifyUnaryAggReorgOperation(hop, hi, i);     
//e.g., sum(t(X)) -> sum(X)
+                       hi = simplifyBinaryMatrixScalarOperation(hop, hi, 
i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
                        hi = pushdownUnaryAggTransposeOperation(hop, hi, i); 
//e.g., colSums(t(X)) -> t(rowSums(X))
                        hi = pushdownSumBinaryMult(hop, hi, i);              
//e.g., sum(lamda*X) -> lamda*sum(X)
                        hi = simplifyUnaryPPredOperation(hop, hi, i);        
//e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
@@ -890,6 +891,52 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
         * @param hi
         * @param pos
         * @return
+        * @throws HopsException
+        */
+       private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, 
int pos ) 
+               throws HopsException
+       {
+               if(   hi instanceof UnaryOp && 
((UnaryOp)hi).getOp()==OpOp1.CAST_AS_SCALAR  
+                  && hi.getInput().get(0) instanceof BinaryOp ) 
+               {
+                       BinaryOp bin = (BinaryOp) hi.getInput().get(0);
+                       BinaryOp bout = null;
+                       
+                       //as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)
+                       if( 
bin.getInput().get(0).getDataType()==DataType.MATRIX 
+                               && 
bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
+                               UnaryOp cast1 = 
HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR);
+                               UnaryOp cast2 = 
HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
+                               bout = HopRewriteUtils.createBinary(cast1, 
cast2, bin.getOp());
+                       }
+                       //as.scalar(X*s) -> as.scalar(X) * s
+                       else if( 
bin.getInput().get(0).getDataType()==DataType.MATRIX ) {
+                               UnaryOp cast = 
HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR);
+                               bout = HopRewriteUtils.createBinary(cast, 
bin.getInput().get(1), bin.getOp());
+                       }
+                       //as.scalar(s*X) -> s * as.scalar(X)
+                       else if ( 
bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
+                               UnaryOp cast = 
HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
+                               bout = 
HopRewriteUtils.createBinary(bin.getInput().get(0), cast, bin.getOp());
+                       }
+                       
+                       if( bout != null ) {
+                               
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+                               HopRewriteUtils.addChildReference(parent, bout, 
pos);
+                               
+                               LOG.debug("Applied 
simplifyBinaryMatrixScalarOperation.");
+                       }
+               }
+               
+               return hi;
+       }
+       
+       /**
+        * 
+        * @param parent
+        * @param hi
+        * @param pos
+        * @return
         */
        private Hop pushdownUnaryAggTransposeOperation( Hop parent, Hop hi, int 
pos )
        {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
index 36d8712..a8001f8 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.Hop.VisitStatus;
 import org.apache.sysml.hops.UnaryOp;
@@ -73,6 +74,7 @@ public class RewriteRemoveUnnecessaryCasts extends 
HopRewriteRule
         * 
         * @param hop
         */
+       @SuppressWarnings("unchecked")
        private void rule_RemoveUnnecessaryCasts( Hop hop )
        {
                //check mark processed
@@ -84,7 +86,7 @@ public class RewriteRemoveUnnecessaryCasts extends 
HopRewriteRule
                for( int i=0; i<inputs.size(); i++ )
                        rule_RemoveUnnecessaryCasts( inputs.get(i) );
                
-               //remove cast if unnecessary
+               //remove unnecessary value type cast 
                if( hop instanceof UnaryOp && 
HopRewriteUtils.isValueTypeCast(((UnaryOp)hop).getOp()) )
                {
                        Hop in = hop.getInput().get(0);
@@ -116,6 +118,23 @@ public class RewriteRemoveUnnecessaryCasts extends 
HopRewriteRule
                        }
                }
                
+               //remove unnecessary data type casts
+               if( hop instanceof UnaryOp && hop.getInput().get(0) instanceof 
UnaryOp ) {
+                       UnaryOp uop1 = (UnaryOp) hop;
+                       UnaryOp uop2 = (UnaryOp) hop.getInput().get(0);
+                       if( (uop1.getOp()==OpOp1.CAST_AS_MATRIX && 
uop2.getOp()==OpOp1.CAST_AS_SCALAR) 
+                               || (uop1.getOp()==OpOp1.CAST_AS_SCALAR && 
uop2.getOp()==OpOp1.CAST_AS_MATRIX) ) {
+                               Hop input = uop2.getInput().get(0);
+                               //rewire parents
+                               ArrayList<Hop> parents = (ArrayList<Hop>) 
hop.getParent().clone();
+                               for( Hop p : parents ) {
+                                       int ix = 
HopRewriteUtils.getChildReferencePos(p, hop);
+                                       HopRewriteUtils.removeChildReference(p, 
hop);
+                                       HopRewriteUtils.addChildReference(p, 
input, ix);
+                               }
+                       }
+               }
+               
                //mark processed
                hop.setVisited( VisitStatus.DONE );
        }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
index 38ba9dd..c9545ac 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
@@ -60,16 +60,10 @@ public abstract class ArithmeticBinaryCPInstruction extends 
BinaryCPInstruction
                //make sure these checks belong here
                //if either input is a matrix, then output
                //has to be a matrix
-               if((dt1 == DataType.MATRIX 
-                       || dt2 == DataType.MATRIX) 
-                  && dt3 != DataType.MATRIX)
-                       throw new DMLRuntimeException("Element-wise matrix 
operations between variables "
-                                                                               
  + in1.getName()
-                                                                               
  + " and "
-                                                                               
  + in2.getName()
-                                                                               
  + " must produce a matrix, which "
-                                                                               
  + out.getName()
-                                                                               
  + "is not");
+               if((dt1 == DataType.MATRIX  || dt2 == DataType.MATRIX) && dt3 
!= DataType.MATRIX) {
+                       throw new DMLRuntimeException("Element-wise matrix 
operations between variables " + in1.getName() + 
+                                       " and " + in2.getName() + " must 
produce a matrix, which " + out.getName() + "is not");
+               }
                
                Operator operator = (dt1 != dt2) ?
                                        
InstructionUtils.parseScalarBinaryOperator(opcode, (dt1 == DataType.SCALAR)) : 

Reply via email to