Repository: systemml
Updated Branches:
  refs/heads/master 8a1f98e1b -> ec024661a


[SYSTEMML-1990] New rewrites for rand, outer products, and cbind/rbind

This patch generalizes two existing rewrites and introduces a new
rewrite for nary cbind/rbind:

(1) Generalized rand-binary fusion: Previously, we only fused binary
operations with scalar literals into rand operations. For the special
cases of multiply and add, we now also fuse binary operations with
scalar variables, e.g., rand(min=0,max=1)*s -> rand(min=0,max=s).

(2) Generalized outer-product rewrites: Outer products used for
replication, followed by comparison operations, are rewritten to binary
outer operations. This patch generalizes the detection of such patterns
to inputs with partially unknown dimensions.

(3) New rbind/cbind folding: Exploiting the recently added nary
cbind/rbind operations, we now repeatedly fold nested cbind/rbind
inputs with single consumers into a single nary cbind/rbind, e.g.,
cbind(X,cbind(Y,Z)) -> cbind(X,Y,Z).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/578a9869
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/578a9869
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/578a9869

Branch: refs/heads/master
Commit: 578a98697e6219947b748337f9a121acf54afe53
Parents: 8a1f98e
Author: Matthias Boehm <[email protected]>
Authored: Mon Nov 6 13:14:14 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Wed Nov 8 13:22:13 2017 -0800

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/DataGenOp.java   |  15 +-
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  21 +++
 .../RewriteAlgebraicSimplificationDynamic.java  |   5 +-
 .../RewriteAlgebraicSimplificationStatic.java   | 140 +++++++++++++++----
 .../java/org/apache/sysml/utils/Statistics.java |   3 +-
 .../test/integration/AutomatedTestBase.java     |  15 ++
 .../functions/misc/RewriteFoldRCBindTest.java   | 101 +++++++++++++
 .../functions/misc/RewriteFusedRandTest.java    |  62 ++++----
 .../scripts/functions/misc/RewriteFoldCBind.dml |  28 ++++
 .../scripts/functions/misc/RewriteFoldRBind.dml |  28 ++++
 .../scripts/functions/misc/RewriteFusedRand.dml |  29 ----
 .../functions/misc/RewriteFusedRandLit.dml      |  29 ++++
 .../functions/misc/RewriteFusedRandVar1.dml     |  28 ++++
 .../functions/misc/RewriteFusedRandVar2.dml     |  28 ++++
 .../functions/misc/ZPackageSuite.java           |   1 +
 15 files changed, 446 insertions(+), 87 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/main/java/org/apache/sysml/hops/DataGenOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DataGenOp.java 
b/src/main/java/org/apache/sysml/hops/DataGenOp.java
index 69f35ed..500bbd9 100644
--- a/src/main/java/org/apache/sysml/hops/DataGenOp.java
+++ b/src/main/java/org/apache/sysml/hops/DataGenOp.java
@@ -45,7 +45,6 @@ import org.apache.sysml.runtime.util.UtilFunctions;
  */
 public class DataGenOp extends Hop implements MultiThreadedHop
 {
-       
        public static final long UNSPECIFIED_SEED = -1;
        
         // defines the specific data generation method
@@ -366,15 +365,21 @@ public class DataGenOp extends Hop implements 
MultiThreadedHop
        }
        
 
-       public HashMap<String, Integer> getParamIndexMap()
-       {
+       public HashMap<String, Integer> getParamIndexMap() {
                return _paramIndexMap;
        }
        
-       public int getParamIndex(String key)
-       {
+       public int getParamIndex(String key) {
                return _paramIndexMap.get(key);
        }
+       
+       public Hop getInput(String key) {
+               return getInput().get(getParamIndex(key));
+       }
+       
+       public void setInput(String key, Hop hop) {
+               getInput().set(getParamIndex(key), hop);
+       }
 
        public boolean hasConstantValue() 
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index b6db466..15cc2cb 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -40,6 +40,7 @@ import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.FileFormatTypes;
 import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.OpOp3;
+import org.apache.sysml.hops.Hop.OpOpN;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
 import org.apache.sysml.hops.Hop.ReOrgOp;
 import org.apache.sysml.hops.HopsException;
@@ -47,6 +48,7 @@ import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LeftIndexingOp;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.MemoTable;
+import org.apache.sysml.hops.NaryOp;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.ReorgOp;
@@ -603,6 +605,16 @@ public class HopRewriteUtils
                return ix;
        }
        
+       public static NaryOp createNary(OpOpN op, Hop... inputs) throws 
HopsException {
+               Hop mainInput = inputs[0];
+               NaryOp nop = new NaryOp(mainInput.getName(), 
mainInput.getDataType(),
+                       mainInput.getValueType(), op, inputs);
+               nop.setOutputBlocksizes(mainInput.getRowsInBlock(), 
mainInput.getColsInBlock());
+               copyLineNumbers(mainInput, nop);
+               nop.refreshSizeInformation();
+               return nop;
+       }
+       
        public static Hop createValueHop( Hop hop, boolean row ) 
                throws HopsException
        {
@@ -957,6 +969,15 @@ public class HopRewriteUtils
                return (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getOp()==AggOp.SUM_SQ);
        }
        
+       public static boolean isNary(Hop hop, OpOpN type) {
+               return hop instanceof NaryOp && ((NaryOp)hop).getOp()==type;
+       }
+       
+       public static boolean isNary(Hop hop, OpOpN... types) {
+               return ( hop instanceof NaryOp 
+                       && ArrayUtils.contains(types, ((NaryOp) hop).getOp()));
+       }
+       
        public static boolean isNonZeroIndicator(Hop pred, Hop hop )
        {
                if( pred instanceof BinaryOp && 
((BinaryOp)pred).getOp()==OpOp2.NOTEQUAL

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index eba06fc..0fa1aed 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -432,8 +432,9 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                        else if(HopRewriteUtils.isValidOuterBinaryOp(bop) 
                                && HopRewriteUtils.isMatrixMultiply(left)
                                && 
HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput().get(1), 1)
-                               && left.getInput().get(0).getDim2() == 1 
//column vector 
-                               && left.getDim1() != 1 && right.getDim1() == 1 
) //outer vector product 
+                               && (left.getInput().get(0).getDim2() == 1 
//outer product
+                                       || left.getInput().get(1).getDim1() == 
1)
+                               && left.getDim1() != 1 && right.getDim1() == 1 
) //outer vector binary 
                        {
                                Hop hnew = 
HopRewriteUtils.createBinary(left.getInput().get(0), right, bop, true);
                                HopRewriteUtils.replaceChildReference(parent, 
hi, hnew, pos);

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 64a37d7..2d5d881 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -38,10 +38,12 @@ import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.DataGenMethod;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp3;
+import org.apache.sysml.hops.Hop.OpOpN;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
 import org.apache.sysml.hops.Hop.ReOrgOp;
 import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.NaryOp;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
@@ -147,23 +149,24 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = removeUnnecessaryBinaryOperation(hop, hi, i);   
//e.g., X*1 -> X (dep: should come after rm unnecessary vectorize)
                        hi = fuseDatagenAndBinaryOperation(hi);              
//e.g., rand(min=-1,max=1)*7 -> rand(min=-7,max=7)
                        hi = fuseDatagenAndMinusOperation(hi);               
//e.g., -(rand(min=-2,max=1)) -> rand(min=-1,max=2)
-                       hi = simplifyBinaryToUnaryOperation(hop, hi, i);     
//e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X)
-                       hi = canonicalizeMatrixMultScalarAdd(hi);            
//e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps) 
-                       hi = simplifyReverseOperation(hop, hi, i);           
//e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
-                       if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
-                               hi = simplifyMultiBinaryToBinaryOperation(hi);  
     //e.g., 1-X*Y -> X 1-* Y
-                       hi = simplifyDistributiveBinaryOperation(hop, hi, 
i);//e.g., (X-Y*X) -> (1-Y)*X
-                       hi = simplifyBushyBinaryOperation(hop, hi, i);       
//e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
-                       hi = simplifyUnaryAggReorgOperation(hop, hi, i);     
//e.g., sum(t(X)) -> sum(X)
-                       hi = removeUnnecessaryAggregates(hi);                
//e.g., sum(rowSums(X)) -> sum(X)
-                       hi = simplifyBinaryMatrixScalarOperation(hop, hi, 
i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
-                       hi = pushdownUnaryAggTransposeOperation(hop, hi, i); 
//e.g., colSums(t(X)) -> t(rowSums(X))
-                       hi = pushdownCSETransposeScalarOperation(hop, hi, 
i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
-                       hi = pushdownSumBinaryMult(hop, hi, i);              
//e.g., sum(lamda*X) -> lamda*sum(X)
-                       hi = simplifyUnaryPPredOperation(hop, hi, i);        
//e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
-                       hi = simplifyTransposedAppend(hop, hi, i);           
//e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
-                       if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
-                               hi = fuseBinarySubDAGToUnaryOperation(hop, hi, 
i);   //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> 
selp(X)
+                       hi = foldMultipleAppendOperations(hi);               
//e.g., cbind(X,cbind(Y,Z)) -> cbind(X,Y,Z)
+                       hi = simplifyBinaryToUnaryOperation(hop, hi, i);     
//e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X)
+                       hi = canonicalizeMatrixMultScalarAdd(hi);            
//e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps) 
+                       hi = simplifyReverseOperation(hop, hi, i);           
//e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
+                       if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
+                               hi = simplifyMultiBinaryToBinaryOperation(hi);  
     //e.g., 1-X*Y -> X 1-* Y
+                       hi = simplifyDistributiveBinaryOperation(hop, hi, 
i);//e.g., (X-Y*X) -> (1-Y)*X
+                       hi = simplifyBushyBinaryOperation(hop, hi, i);       
//e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
+                       hi = simplifyUnaryAggReorgOperation(hop, hi, i);     
//e.g., sum(t(X)) -> sum(X)
+                       hi = removeUnnecessaryAggregates(hi);                
//e.g., sum(rowSums(X)) -> sum(X)
+                       hi = simplifyBinaryMatrixScalarOperation(hop, hi, 
i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
+                       hi = pushdownUnaryAggTransposeOperation(hop, hi, i); 
//e.g., colSums(t(X)) -> t(rowSums(X))
+                       hi = pushdownCSETransposeScalarOperation(hop, hi, 
i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
+                       hi = pushdownSumBinaryMult(hop, hi, i);              
//e.g., sum(lamda*X) -> lamda*sum(X)
+                       hi = simplifyUnaryPPredOperation(hop, hi, i);        
//e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
+                       hi = simplifyTransposedAppend(hop, hi, i);           
//e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
+                       if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
+                               hi = fuseBinarySubDAGToUnaryOperation(hop, hi, 
i);   //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> 
selp(X)
                        hi = simplifyTraceMatrixMult(hop, hi, i);            
//e.g., trace(X%*%Y)->sum(X*t(Y));  
                        hi = simplifySlicedMatrixMult(hop, hi, i);           
//e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
                        hi = simplifyConstantSort(hop, hi, i);               
//e.g., order(matrix())->matrix/seq; 
@@ -358,14 +361,13 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        //the creation of multiple datagen ops and thus 
potentially different results if seed not specified)
                        
                        //left input rand and hence output matrix double, right 
scalar literal
-                       if( left instanceof DataGenOp && 
((DataGenOp)left).getOp()==DataGenMethod.RAND &&
+                       if( HopRewriteUtils.isDataGenOp(left, 
DataGenMethod.RAND) &&
                                right instanceof LiteralOp && 
left.getParent().size()==1 )
                        {
                                DataGenOp inputGen = (DataGenOp)left;
-                               HashMap<String,Integer> params = 
inputGen.getParamIndexMap();
-                               Hop pdf = 
left.getInput().get(params.get(DataExpression.RAND_PDF));
-                               Hop min = 
left.getInput().get(params.get(DataExpression.RAND_MIN));
-                               Hop max = 
left.getInput().get(params.get(DataExpression.RAND_MAX));
+                               Hop pdf = 
inputGen.getInput(DataExpression.RAND_PDF);
+                               Hop min = 
inputGen.getInput(DataExpression.RAND_MIN);
+                               Hop max = 
inputGen.getInput(DataExpression.RAND_MAX);
                                double sval = 
((LiteralOp)right).getDoubleValue();
                                
                                if( HopRewriteUtils.isBinary(bop, OpOp2.MULT, 
OpOp2.PLUS, OpOp2.MINUS, OpOp2.DIV)
@@ -396,10 +398,9 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                left instanceof LiteralOp && 
right.getParent().size()==1 )
                        {
                                DataGenOp inputGen = (DataGenOp)right;
-                               HashMap<String,Integer> params = 
inputGen.getParamIndexMap();
-                               Hop pdf = 
right.getInput().get(params.get(DataExpression.RAND_PDF));
-                               Hop min = 
right.getInput().get(params.get(DataExpression.RAND_MIN));
-                               Hop max = 
right.getInput().get(params.get(DataExpression.RAND_MAX));
+                               Hop pdf = 
inputGen.getInput(DataExpression.RAND_PDF);
+                               Hop min = 
inputGen.getInput(DataExpression.RAND_MIN);
+                               Hop max = 
inputGen.getInput(DataExpression.RAND_MAX);
                                double sval = 
((LiteralOp)left).getDoubleValue();
                                
                                if( (bop.getOp()==OpOp2.MULT || 
bop.getOp()==OpOp2.PLUS)
@@ -423,6 +424,44 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                        LOG.debug("Applied 
fuseDatagenAndBinaryOperation2 (line "+bop.getBeginLine()+").");
                                }
                        }
+                       //left input rand and hence output matrix double, right 
scalar variable
+                       else if( HopRewriteUtils.isDataGenOp(left, 
DataGenMethod.RAND) 
+                               && right.getDataType().isScalar() && 
left.getParent().size()==1 )
+                       {
+                               DataGenOp gen = (DataGenOp)left;
+                               Hop min = gen.getInput(DataExpression.RAND_MIN);
+                               Hop max = gen.getInput(DataExpression.RAND_MAX);
+                               
+                               if( HopRewriteUtils.isBinary(bop, OpOp2.PLUS)
+                                       && 
HopRewriteUtils.isLiteralOfValue(min, 0)
+                                       && 
HopRewriteUtils.isLiteralOfValue(max, 0) )
+                               {
+                                       gen.setInput(DataExpression.RAND_MIN, 
right);
+                                       gen.setInput(DataExpression.RAND_MAX, 
right);
+                                       //rewire all parents (avoid anomalies 
with replicated datagen)
+                                       List<Hop> parents = new 
ArrayList<>(bop.getParent());
+                                       for( Hop p : parents )
+                                               
HopRewriteUtils.replaceChildReference(p, bop, gen);
+                                       hi = gen;
+                                       LOG.debug("Applied 
fuseDatagenAndBinaryOperation3a (line "+bop.getBeginLine()+").");
+                               }
+                               else if( HopRewriteUtils.isBinary(bop, 
OpOp2.MULT)
+                                       && 
(HopRewriteUtils.isLiteralOfValue(min, 0)
+                                               || 
HopRewriteUtils.isLiteralOfValue(min, 1))
+                                       && 
HopRewriteUtils.isLiteralOfValue(max, 1) )
+                               {
+                                       if( 
HopRewriteUtils.isLiteralOfValue(min, 1) )
+                                               
gen.setInput(DataExpression.RAND_MIN, right);
+                                       gen.setInput(DataExpression.RAND_MAX, 
right);
+                                       //rewire all parents (avoid anomalies 
with replicated datagen)
+                                       List<Hop> parents = new 
ArrayList<>(bop.getParent());
+                                       for( Hop p : parents )
+                                               
HopRewriteUtils.replaceChildReference(p, bop, gen);
+                                       hi = gen;
+                                       LOG.debug("Applied 
fuseDatagenAndBinaryOperation3b (line "+bop.getBeginLine()+").");
+                               }
+                       }
+                       
                }
                
                return hi;
@@ -478,6 +517,55 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                return hi;
        }
        
+       private static Hop foldMultipleAppendOperations(Hop hi) 
+               throws HopsException
+       {
+               if( hi.getDataType().isMatrix() //no string appends
+                       && (HopRewriteUtils.isBinary(hi, OpOp2.CBIND, 
OpOp2.RBIND) 
+                       || HopRewriteUtils.isNary(hi, OpOpN.CBIND, OpOpN.RBIND))
+                       && !OptimizerUtils.isHadoopExecutionMode() )
+               {
+                       OpOp2 bop = (hi instanceof BinaryOp) ? 
((BinaryOp)hi).getOp() :
+                               OpOp2.valueOf(((NaryOp)hi).getOp().name());
+                       OpOpN nop = (hi instanceof NaryOp) ? 
((NaryOp)hi).getOp() :
+                               OpOpN.valueOf(((BinaryOp)hi).getOp().name());
+                       
+                       boolean converged = false;
+                       while( !converged ) {
+                               //get first matching cbind or rbind
+                               Hop first = hi.getInput().stream()
+                                       .filter(h -> 
HopRewriteUtils.isBinary(h, bop) || HopRewriteUtils.isNary(h, nop))
+                                       .findFirst().orElse(null);
+                               
+                               //replace current op with new nary cbind/rbind
+                               if( first != null && 
first.getParent().size()==1 ) {
+                                       //construct new list of inputs (in 
original order)
+                                       ArrayList<Hop> linputs = new 
ArrayList<>();
+                                       for(Hop in : hi.getInput())
+                                               if( in == first )
+                                                       
linputs.addAll(first.getInput());
+                                               else
+                                                       linputs.add(in);
+                                       Hop hnew = 
HopRewriteUtils.createNary(nop, linputs.toArray(new Hop[0]));
+                                       //clear dangling references
+                                       
HopRewriteUtils.removeAllChildReferences(hi);
+                                       
HopRewriteUtils.removeAllChildReferences(first);
+                                       //rewire all parents (avoid anomalies 
with refs to hi)
+                                       List<Hop> parents = new 
ArrayList<>(hi.getParent());
+                                       for( Hop p : parents )
+                                               
HopRewriteUtils.replaceChildReference(p, hi, hnew);
+                                       hi = hnew;
+                                       LOG.debug("Applied 
foldMultipleAppendOperations (line "+hi.getBeginLine()+").");
+                               }
+                               else {
+                                       converged = true;
+                               }
+                       }
+               }
+               
+               return hi;
+       }
+       
        /**
         * Handle simplification of binary operations (relies on previous 
common subexpression elimination).
         * At the same time this servers as a canonicalization for more complex 
rewrites. 

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java 
b/src/main/java/org/apache/sysml/utils/Statistics.java
index 555493c..5cc0650 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -545,7 +545,8 @@ public class Statistics
        }
        
        public static long getCPHeavyHitterCount(String opcode) {
-               return _cpInstCounts.get(opcode);
+               Long tmp = _cpInstCounts.get(opcode);
+               return (tmp != null) ? tmp : 0;
        }
 
        /**

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java 
b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
index d39e513..3bd78b0 100644
--- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
@@ -1808,6 +1808,21 @@ public abstract class AutomatedTestBase
                return writeInputFrame(name, data, false, schema, oi);
        }
 
+       protected boolean heavyHittersContainsString(String... str) {
+               for( String opcode : Statistics.getCPHeavyHitterOpCodes())
+                       for( String s : str )
+                               if(opcode.equals(s))
+                                       return true;
+               return false;
+       }
+       
+       protected boolean heavyHittersContainsString(String str, int minCount) {
+               int count = 0;
+               for( String opcode : Statistics.getCPHeavyHitterOpCodes())
+                       count += opcode.equals(str) ? 1 : 0;
+               return (count >= minCount);
+       }
+       
        protected boolean heavyHittersContainsSubString(String... str) {
                for( String opcode : Statistics.getCPHeavyHitterOpCodes())
                        for( String s : str )

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldRCBindTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldRCBindTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldRCBindTest.java
new file mode 100644
index 0000000..83050bd
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldRCBindTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+
+public class RewriteFoldRCBindTest extends AutomatedTestBase 
+{
+       private static final String TEST_NAME1 = "RewriteFoldCBind";
+       private static final String TEST_NAME2 = "RewriteFoldRBind";
+       
+       private static final String TEST_DIR = "functions/misc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteFoldRCBindTest.class.getSimpleName() + "/";
+       
+       private static final int rows = 1932;
+       private static final int cols = 14;
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+       }
+
+       @Test
+       public void testRewriteFoldCBindNoRewrite() {
+               testRewriteFoldRCBind( TEST_NAME1, false );
+       }
+       
+       @Test
+       public void testRewriteFoldCBindRewrite() {
+               testRewriteFoldRCBind( TEST_NAME1, true );
+       }
+       
+       @Test
+       public void testRewriteFoldRBindNoRewrite() {
+               testRewriteFoldRCBind( TEST_NAME2, false );
+       }
+       
+       @Test
+       public void testRewriteFoldRBindRewrite() {
+               testRewriteFoldRCBind( TEST_NAME2, true );
+       }
+
+       private void testRewriteFoldRCBind( String testname, boolean rewrites )
+       {       
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               
+               try {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{ "-stats", "-args", 
String.valueOf(rows), 
+                                       String.valueOf(cols), output("R") };
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+
+                       //run performance tests
+                       runTest(true, false, null, -1); 
+                       
+                       //compare matrices 
+                       Double ret = readDMLMatrixFromHDFS("R").get(new 
CellIndex(1,1));
+                       Assert.assertEquals("Wrong result", new 
Double(5*rows*cols), ret);
+                       
+                       //check for applied rewrites
+                       if( rewrites ) {
+                               
Assert.assertTrue(!heavyHittersContainsString("append")
+                                       && 
Statistics.getCPHeavyHitterCount("cbind") <= 1
+                                       && 
Statistics.getCPHeavyHitterCount("rbind") <= 1);
+                       }
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
index 1491aa6..d7fe902 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysml.test.integration.functions.misc;
 
-import java.util.HashMap;
-
 import org.junit.Assert;
 import org.junit.Test;
 import org.apache.sysml.hops.OptimizerUtils;
@@ -29,13 +27,12 @@ import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
 
-/**
- * 
- * 
- */
 public class RewriteFusedRandTest extends AutomatedTestBase 
-{      
-       private static final String TEST_NAME1 = "RewriteFusedRand";
+{
+       private static final String TEST_NAME1 = "RewriteFusedRandLit";
+       private static final String TEST_NAME2 = "RewriteFusedRandVar1";
+       private static final String TEST_NAME3 = "RewriteFusedRandVar2";
+       
        private static final String TEST_DIR = "functions/misc/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteFusedRandTest.class.getSimpleName() + "/";
        
@@ -47,44 +44,50 @@ public class RewriteFusedRandTest extends AutomatedTestBase
        public void setUp() {
                TestUtils.clearAssertionInformation();
                addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
        }
 
        @Test
-       public void testRewriteFusedRandUniformNoRewrite()  {
+       public void testRewriteFusedRandUniformNoRewrite() {
                testRewriteFusedRand( TEST_NAME1, "uniform", false );
        }
        
        @Test
-       public void testRewriteFusedRandNormalNoRewrite()  {
+       public void testRewriteFusedRandNormalNoRewrite() {
                testRewriteFusedRand( TEST_NAME1, "normal", false );
        }
        
        @Test
-       public void testRewriteFusedRandPoissonNoRewrite()  {
+       public void testRewriteFusedRandPoissonNoRewrite() {
                testRewriteFusedRand( TEST_NAME1, "poisson", false );
        }
        
        @Test
-       public void testRewriteFusedRandUniform()  {
+       public void testRewriteFusedRandUniform() {
                testRewriteFusedRand( TEST_NAME1, "uniform", true );
        }
        
        @Test
-       public void testRewriteFusedRandNormal()  {
+       public void testRewriteFusedRandNormal() {
                testRewriteFusedRand( TEST_NAME1, "normal", true );
        }
        
        @Test
-       public void testRewriteFusedRandPoisson()  {
+       public void testRewriteFusedRandPoisson() {
                testRewriteFusedRand( TEST_NAME1, "poisson", true );
        }
        
-       /**
-        * 
-        * @param condition
-        * @param branchRemoval
-        * @param IPA
-        */
+       @Test
+       public void testRewriteFusedZerosPlusVar() {
+               testRewriteFusedRand( TEST_NAME2, "uniform", true );
+       }
+       
+       @Test
+       public void testRewriteFusedOnesMultVar() {
+               testRewriteFusedRand( TEST_NAME3, "uniform", true );
+       }
+       
        private void testRewriteFusedRand( String testname, String pdf, boolean 
rewrites )
        {       
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -95,7 +98,7 @@ public class RewriteFusedRandTest extends AutomatedTestBase
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[]{ "-args", 
String.valueOf(rows), 
+                       programArgs = new String[]{ "-stats", "-args", 
String.valueOf(rows), 
                                        String.valueOf(cols), pdf, 
String.valueOf(seed), output("R") };
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
 
@@ -103,11 +106,22 @@ public class RewriteFusedRandTest extends AutomatedTestBase
                        runTest(true, false, null, -1); 
                        
                        //compare matrices 
-                       HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("R");
-                       Assert.assertEquals("Wrong result, expected: "+rows, 
new Double(rows), dmlfile.get(new CellIndex(1,1)));
+                       Double ret = readDMLMatrixFromHDFS("R").get(new 
CellIndex(1,1));
+                       if( testname.equals(TEST_NAME1) )
+                               Assert.assertEquals("Wrong result", new 
Double(rows), ret);
+                       else if( testname.equals(TEST_NAME2) )
+                               Assert.assertEquals("Wrong result", new 
Double(Math.pow(rows*cols, 2)), ret);
+                       else if( testname.equals(TEST_NAME3) )
+                               Assert.assertEquals("Wrong result", new 
Double(Math.pow(rows*cols, 2)), ret);
+                       
+                       //check for applied rewrites
+                       if( rewrites && pdf.equals("uniform") ) {
+                               
Assert.assertTrue(!heavyHittersContainsString("+")
+                                       && !heavyHittersContainsString("*"));
+                       }
                }
                finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
                }
        }       
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFoldCBind.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFoldCBind.dml b/src/test/scripts/functions/misc/RewriteFoldCBind.dml
new file mode 100644
index 0000000..47733fe
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFoldCBind.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(1, $1, $2)
+while(FALSE){}
+Y = cbind(cbind(X,X),cbind(X,X,X))
+while(FALSE){}
+R = as.matrix(sum(Y))
+
+write(R, $3);

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFoldRBind.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFoldRBind.dml b/src/test/scripts/functions/misc/RewriteFoldRBind.dml
new file mode 100644
index 0000000..c489f7c
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFoldRBind.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(1, $1, $2)
+while(FALSE){}
+Y = rbind(rbind(X,X),rbind(X,X,X))
+while(FALSE){}
+R = as.matrix(sum(Y))
+
+write(R, $3);

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFusedRand.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFusedRand.dml b/src/test/scripts/functions/misc/RewriteFusedRand.dml
deleted file mode 100644
index ab00f04..0000000
--- a/src/test/scripts/functions/misc/RewriteFusedRand.dml
+++ /dev/null
@@ -1,29 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-# 
-#   http://www.apache.org/licenses/LICENSE-2.0
-# 
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFusedRandLit.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFusedRandLit.dml b/src/test/scripts/functions/misc/RewriteFusedRandLit.dml
new file mode 100644
index 0000000..ab00f04
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFusedRandLit.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
+
+while(FALSE){} #prevent cse
+
+X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
+
+R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
+write(R, $5);
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFusedRandVar1.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFusedRandVar1.dml b/src/test/scripts/functions/misc/RewriteFusedRandVar1.dml
new file mode 100644
index 0000000..f37557f
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFusedRandVar1.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(1, $1, $2)
+while(FALSE){}
+Y = matrix(0, $1, $2) + sum(X);
+while(FALSE){}
+R = as.matrix(sum(Y))
+
+write(R, $5);

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFusedRandVar2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFusedRandVar2.dml b/src/test/scripts/functions/misc/RewriteFusedRandVar2.dml
new file mode 100644
index 0000000..d98ef08
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFusedRandVar2.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(1, $1, $2)
+while(FALSE){}
+Y = matrix(1, $1, $2) * sum(X);
+while(FALSE){}
+R = as.matrix(sum(Y))
+
+write(R, $5);

http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index a453cbd..ae4f820 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -56,6 +56,7 @@ import org.junit.runners.Suite;
        RewriteCTableToRExpandTest.class,
        RewriteElementwiseMultChainOptimizationTest.class,
        RewriteEliminateAggregatesTest.class,
+       RewriteFoldRCBindTest.class,
        RewriteFuseBinaryOpChainTest.class,
        RewriteFusedRandTest.class,
        RewriteLoopVectorization.class,

Reply via email to