This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 553d30c482 [SYSTEMDS-3781] New constant conjunction/disjunction
rewrites
553d30c482 is described below
commit 553d30c48272ef4fb10a73e85b602324c5e99ac4
Author: aarna <[email protected]>
AuthorDate: Fri Nov 8 17:27:19 2024 +0100
[SYSTEMDS-3781] New constant conjunction/disjunction rewrites
Closes #2134.
---
.../RewriteAlgebraicSimplificationStatic.java | 1095 ++++++++++----------
.../RewriteConstantConjunctionDisjunctionTest.java | 80 ++
.../RewriteBooleanSimplificationTestAnd.dml | 27 +
.../rewrite/RewriteBooleanSimplificationTestOr.dml | 27 +
4 files changed, 704 insertions(+), 525 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index a18a2b7466..056770dceb 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -62,26 +62,26 @@ import org.apache.sysds.common.Types.ValueType;
* estimate, in MR this allows map-only operations and hence prevents
* unnecessary shuffle and sort) and (2) remove binary operations that
* are in itself are unnecessary (e.g., *1 and /1).
- *
+ *
*/
public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
{
//valid aggregation operation types for rowOp to colOp conversions and
vice versa
private static final AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new
AggOp[] {
- AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN,
AggOp.VAR};
-
+ AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX,
AggOp.MEAN, AggOp.VAR};
+
//valid binary operations for distributive and associate reorderings
- private static final OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new
OpOp2[] {OpOp2.PLUS, OpOp2.MINUS};
+ private static final OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new
OpOp2[] {OpOp2.PLUS, OpOp2.MINUS};
private static final OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new
OpOp2[] {OpOp2.PLUS, OpOp2.MULT};
-
+
//valid binary operations for scalar operations
- private static final OpOp2[] LOOKUP_VALID_SCALAR_BINARY = new OpOp2[]
{OpOp2.AND, OpOp2.DIV,
- OpOp2.EQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.INTDIV,
OpOp2.LESS, OpOp2.LESSEQUAL,
- OpOp2.LOG, OpOp2.MAX, OpOp2.MIN, OpOp2.MINUS, OpOp2.MODULUS,
OpOp2.MULT, OpOp2.NOTEQUAL,
- OpOp2.OR, OpOp2.PLUS, OpOp2.POW};
-
+ private static final OpOp2[] LOOKUP_VALID_SCALAR_BINARY = new OpOp2[]
{OpOp2.AND, OpOp2.DIV,
+ OpOp2.EQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL,
OpOp2.INTDIV, OpOp2.LESS, OpOp2.LESSEQUAL,
+ OpOp2.LOG, OpOp2.MAX, OpOp2.MIN, OpOp2.MINUS,
OpOp2.MODULUS, OpOp2.MULT, OpOp2.NOTEQUAL,
+ OpOp2.OR, OpOp2.PLUS, OpOp2.POW};
+
@Override
- public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots,
ProgramRewriteStatus state)
+ public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots,
ProgramRewriteStatus state)
{
if( roots == null )
return roots;
@@ -90,32 +90,32 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
for( Hop h : roots )
rule_AlgebraicSimplification( h, false );
Hop.resetVisitStatus(roots, true);
-
+
//one pass descend-rewrite (for rollup)
for( Hop h : roots )
rule_AlgebraicSimplification( h, true );
Hop.resetVisitStatus(roots, true);
-
+
//cleanup remove (twrite <- tread) pairs (unless checkpointing)
removeTWriteTReadPairs(roots);
-
+
return roots;
}
+	/**
+	 * Applies the static algebraic simplification rewrites to the DAG rooted at
+	 * the given hop in two passes: first rewrite-then-descend (so that patterns
+	 * created by a rewrite are visited), then descend-then-rewrite (for roll-up).
+	 * NOTE(review): unlike rewriteHopDAGs, the visit status is not reset after the
+	 * second pass here; confirm callers do not rely on a clean visit status.
+	 *
+	 * @param root root of the hop DAG, may be null
+	 * @param state program rewrite status (not used by this method)
+	 * @return the given root (null if root is null)
+	 */
@Override
- public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
+ public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
{
if( root == null )
return root;
-
+
//one pass rewrite-descend (rewrite created pattern)
rule_AlgebraicSimplification( root, false );
root.resetVisitStatus();
-
+
//one pass descend-rewrite (for rollup)
rule_AlgebraicSimplification( root, true );
-
+
return root;
}
@@ -125,24 +125,24 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
* (1) the results would be not exactly the same (2 rounds instead of
1) and (2) it should
* come before constant folding while the other simplifications should
come after constant
* folding. Hence, not applied yet.
- *
+ *
* @param hop high-level operator
* @param descendFirst if process children recursively first
*/
- private void rule_AlgebraicSimplification(Hop hop, boolean
descendFirst)
+ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
{
if(hop.isVisited())
return;
-
+
//recursively process children
for( int i=0; i<hop.getInput().size(); i++)
{
Hop hi = hop.getInput().get(i);
-
+
//process childs recursively first (to allow roll-up)
if( descendFirst )
rule_AlgebraicSimplification(hi, descendFirst);
//see below
-
+
//apply actual simplification rewrites (of childs incl
checks)
hi = removeUnnecessaryVectorizeOperation(hi);
//e.g., matrix(1,nrow(X),ncol(X))/X -> 1/X
hi = removeUnnecessaryBinaryOperation(hop, hi, i);
//e.g., X*1 -> X (dep: should come after rm unnecessary vectorize)
@@ -153,6 +153,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = canonicalizeMatrixMultScalarAdd(hi);
//e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps)
hi = simplifyCTableWithConstMatrixInputs(hi);
//e.g., table(X, matrix(1,...)) -> table(X, 1)
hi = removeUnnecessaryCTable(hop, hi, i);
//e.g., sum(table(X, 1)) -> nrow(X) and sum(table(1, Y)) -> nrow(Y) and
sum(table(X, Y)) -> nrow(X)
+ hi = simplifyConstantConjunction(hop, hi, i);
//e.g., a & !a -> FALSE
hi = simplifyReverseOperation(hop, hi, i);
//e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
hi = simplifyMultiBinaryToBinaryOperation(hi);
//e.g., 1-X*Y -> X 1-* Y
@@ -169,11 +170,11 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = simplifyTransposedAppend(hop, hi, i);
//e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
hi = fuseBinarySubDAGToUnaryOperation(hop, hi,
i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) ->
selp(X)
- hi = simplifyTraceMatrixMult(hop, hi, i);
//e.g., trace(X%*%Y)->sum(X*t(Y));
+ hi = simplifyTraceMatrixMult(hop, hi, i);
//e.g., trace(X%*%Y)->sum(X*t(Y));
hi = simplifySlicedMatrixMult(hop, hi, i);
//e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
hi = simplifyListIndexing(hi);
//e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
- hi = simplifyConstantSort(hop, hi, i);
//e.g., order(matrix())->matrix/seq;
- hi = simplifyOrderedSort(hop, hi, i);
//e.g., order(matrix())->seq;
+ hi = simplifyConstantSort(hop, hi, i);
//e.g., order(matrix())->matrix/seq;
+ hi = simplifyOrderedSort(hop, hi, i);
//e.g., order(matrix())->seq;
hi = fuseOrderOperationChain(hi);
//e.g., order(order(X,2),1) -> order(X,(12))
hi = removeUnnecessaryReorgOperation(hop, hi, i);
//e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
hi = removeUnnecessaryRemoveEmpty(hop, hi, i);
//e.g., nrow(removeEmpty(A)) -> nnz(A) iff col vector
@@ -187,15 +188,15 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = fuseLogNzBinaryOperation(hop, hi, i);
//e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
}
hi = simplifyOuterSeqExpand(hop, hi, i);
//e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true,
cast=false)
- hi = simplifyBinaryComparisonChain(hop, hi, i);
//e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 ->
outer(v1,v2,"!="),
+ hi = simplifyBinaryComparisonChain(hop, hi, i);
//e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 ->
outer(v1,v2,"!="),
hi = simplifyCumsumColOrFullAggregates(hi);
//e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
hi = simplifyCumsumReverse(hop, hi, i);
//e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
- hi = simplifyNotOverComparisons(hop, hi, i);
//e.g., !(A>B) -> (A<=B)
+ hi = simplifyNotOverComparisons(hop, hi, i);
//e.g., !(A>B) -> (A<=B)
//hi = removeUnecessaryPPred(hop, hi, i);
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
hi = fixNonScalarPrint(hop, hi, i);
//e.g., print(m) -> print(toString(m))
-
+
//process childs recursively after rewrites (to
investigate pattern newly created by rewrites)
if( !descendFirst )
rule_AlgebraicSimplification(hi, descendFirst);
@@ -203,21 +204,21 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hop.setVisited();
}
-
+
private static Hop removeUnnecessaryVectorizeOperation(Hop hi)
{
- //applies to all binary matrix operations, if one input is
unnecessarily vectorized
- if( hi instanceof BinaryOp &&
hi.getDataType()==DataType.MATRIX
- && ((BinaryOp)hi).supportsMatrixScalarOperations() )
+ //applies to all binary matrix operations, if one input is
unnecessarily vectorized
+ if( hi instanceof BinaryOp &&
hi.getDataType()==DataType.MATRIX
+ &&
((BinaryOp)hi).supportsMatrixScalarOperations() )
{
BinaryOp bop = (BinaryOp)hi;
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
-
- //NOTE: these rewrites of binary cell operations need
to be aware that right is
+
+ //NOTE: these rewrites of binary cell operations need
to be aware that right is
//potentially a vector but the result is of the size of
left
//TODO move to dynamic rewrites (since size dependent
to account for mv binary cell and outer operations)
-
+
if( !(left.getDim1()>1 && left.getDim2()==1 &&
right.getDim1()==1 && right.getDim2()>1) ) // no outer
{
//check and remove right vectorized scalar
@@ -229,7 +230,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
Hop drightIn =
dright.getInput().get(dright.getParamIndex(DataExpression.RAND_MIN));
HopRewriteUtils.replaceChildReference(bop, dright, drightIn, 1);
HopRewriteUtils.cleanupUnreferenced(dright);
-
+
LOG.debug("Applied
removeUnnecessaryVectorizeOperation1");
}
}
@@ -238,50 +239,50 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
{
DataGenOp dleft = (DataGenOp) left;
if( dleft.getOp()==OpOpDG.RAND &&
dleft.hasConstantValue()
- && (left.getDim2()==1 ||
right.getDim2()>1)
- && (left.getDim1()==1 ||
right.getDim1()>1))
+ && (left.getDim2()==1
|| right.getDim2()>1)
+ && (left.getDim1()==1
|| right.getDim1()>1))
{
Hop dleftIn =
dleft.getInput().get(dleft.getParamIndex(DataExpression.RAND_MIN));
HopRewriteUtils.replaceChildReference(bop, dleft, dleftIn, 0);
HopRewriteUtils.cleanupUnreferenced(dleft);
-
+
LOG.debug("Applied
removeUnnecessaryVectorizeOperation2");
}
}
//Note: we applied this rewrite to at most one
side in order to keep the
//output semantically equivalent. However,
future extensions might consider
- //to remove vectors from both side, compute the
binary op on scalars and
+ //to remove vectors from both side, compute the
binary op on scalars and
//finally feed it into a datagenop of the
original dimensions.
}
}
-
+
return hi;
}
-
-
+
+
/**
* handle removal of unnecessary binary operations
- *
+ *
* X/1 or X*1 or 1*X or X-0 -> X
* -1*X or X*-1-> -X
- *
+ *
* @param parent parent high-level operator
* @param hi high-level operator
* @param pos position
* @return high-level operator
*/
- private static Hop removeUnnecessaryBinaryOperation( Hop parent, Hop
hi, int pos )
+ private static Hop removeUnnecessaryBinaryOperation( Hop parent, Hop
hi, int pos )
{
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
- //X/1 or X*1 -> X
- if( left.getDataType()==DataType.MATRIX
- && right instanceof LiteralOp &&
right.getValueType().isNumeric()
- && ((LiteralOp)right).getDoubleValue()==1.0 )
+ //X/1 or X*1 -> X
+ if( left.getDataType()==DataType.MATRIX
+ && right instanceof LiteralOp &&
right.getValueType().isNumeric()
+ &&
((LiteralOp)right).getDoubleValue()==1.0 )
{
if( bop.getOp()==OpOp2.DIV ||
bop.getOp()==OpOp2.MULT )
{
@@ -291,8 +292,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
LOG.debug("Applied
removeUnnecessaryBinaryOperation1 (line "+bop.getBeginLine()+")");
}
}
- //X-0 -> X
- else if( left.getDataType()==DataType.MATRIX
+ //X-0 -> X
+ else if( left.getDataType()==DataType.MATRIX
&& right instanceof LiteralOp &&
right.getValueType().isNumeric()
&&
((LiteralOp)right).getDoubleValue()==0.0 )
{
@@ -305,7 +306,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
}
}
//1*X -> X
- else if( right.getDataType()==DataType.MATRIX
+ else if( right.getDataType()==DataType.MATRIX
&& left instanceof LiteralOp &&
left.getValueType().isNumeric()
&&
((LiteralOp)left).getDoubleValue()==1.0 )
{
@@ -318,9 +319,9 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
}
}
//-1*X -> -X
- //note: this rewrite is necessary since the new antlr
parser always converts
+ //note: this rewrite is necessary since the new antlr
parser always converts
//-X to -1*X due to mechanical reasons
- else if( right.getDataType()==DataType.MATRIX
+ else if( right.getDataType()==DataType.MATRIX
&& left instanceof LiteralOp &&
left.getValueType().isNumeric()
&&
((LiteralOp)left).getDoubleValue()==-1.0 )
{
@@ -334,7 +335,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
}
}
//X*-1 -> -X (see comment above)
- else if( left.getDataType()==DataType.MATRIX
+ else if( left.getDataType()==DataType.MATRIX
&& right instanceof LiteralOp &&
right.getValueType().isNumeric()
&&
((LiteralOp)right).getDoubleValue()==-1.0 )
{
@@ -344,85 +345,129 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
HopRewriteUtils.removeChildReferenceByPos(bop, right, 1);
HopRewriteUtils.addChildReference(bop,
new LiteralOp(0), 0);
hi = bop;
-
+
LOG.debug("Applied
removeUnnecessaryBinaryOperation5 (line "+bop.getBeginLine()+")");
}
}
}
-
+
+ return hi;
+ }
+
+ public static Hop simplifyConstantConjunction(Hop parent, Hop hi, int
pos) {
+ if (hi instanceof BinaryOp) {
+ BinaryOp bop = (BinaryOp) hi;
+ Hop left = hi.getInput().get(0);
+ Hop right = hi.getInput().get(1);
+
+ // Patterns: a & !a --> FALSE / !a & a --> FALSE
+ if (bop.getOp() == OpOp2.AND
+ && ((HopRewriteUtils.isUnary(right, OpOp1.NOT)
&& left == right.getInput().get(0))
+ || (HopRewriteUtils.isUnary(left, OpOp1.NOT) &&
left.getInput().get(0) == right)))
+ {
+ LiteralOp falseOp = new LiteralOp(false);
+
+ // Ensure parent has the input before
attempting replacement
+ if (parent != null && parent.getInput().size()
> pos) {
+
HopRewriteUtils.replaceChildReference(parent, hi, falseOp, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi,
left, right);
+ hi = falseOp;
+ }
+
+ LOG.debug("Applied simplifyBooleanRewrite1
(line " + hi.getBeginLine() + ").");
+ }
+ // Pattern: a | !a --> TRUE
+ else if (bop.getOp() == OpOp2.OR
+ && ((HopRewriteUtils.isUnary(right, OpOp1.NOT)
&& left == right.getInput().get(0))
+ || (HopRewriteUtils.isUnary(left, OpOp1.NOT) &&
left.getInput().get(0) == right)))
+ {
+ LiteralOp trueOp = new LiteralOp(true);
+
+ // Ensure parent has the input before
attempting replacement
+ if (parent != null && parent.getInput().size()
> pos) {
+
HopRewriteUtils.replaceChildReference(parent, hi, trueOp, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi,
left, right);
+ hi = trueOp;
+ }
+
+ LOG.debug("Applied simplifyBooleanRewrite2
(line " + hi.getBeginLine() + ").");
+ }
+ }
+
return hi;
}
-
+
+
/**
* Handle removal of unnecessary binary operations over rand data
- *
+ *
* rand*7 -> rand(min*7,max*7); rand+7 -> rand(min+7,max+7); rand-7 ->
rand(min+(-7),max+(-7))
* 7*rand -> rand(min*7,max*7); 7+rand -> rand(min+7,max+7);
- *
+ *
* @param hi high-order operation
* @return high-level operator
*/
@SuppressWarnings("incomplete-switch")
- private static Hop fuseDatagenAndBinaryOperation( Hop hi )
+ private static Hop fuseDatagenAndBinaryOperation( Hop hi )
{
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
-
+
//NOTE: rewrite not applied if more than one datagen
consumer because this would lead to
//the creation of multiple datagen ops and thus
potentially different results if seed not specified)
-
+
//left input rand and hence output matrix double, right
scalar literal
if( HopRewriteUtils.isDataGenOp(left, OpOpDG.RAND) &&
- right instanceof LiteralOp &&
left.getParent().size()==1 )
+ right instanceof LiteralOp &&
left.getParent().size()==1 )
{
DataGenOp inputGen = (DataGenOp)left;
Hop pdf =
inputGen.getInput(DataExpression.RAND_PDF);
Hop min =
inputGen.getInput(DataExpression.RAND_MIN);
Hop max =
inputGen.getInput(DataExpression.RAND_MAX);
double sval =
((LiteralOp)right).getDoubleValue();
- boolean pdfUniform = pdf instanceof LiteralOp
- &&
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue());
-
+ boolean pdfUniform = pdf instanceof LiteralOp
+ &&
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue());
+
if( HopRewriteUtils.isBinary(bop, OpOp2.MULT,
OpOp2.PLUS, OpOp2.MINUS, OpOp2.DIV)
- && min instanceof LiteralOp && max
instanceof LiteralOp && pdfUniform )
+ && min instanceof LiteralOp &&
max instanceof LiteralOp && pdfUniform )
{
//create fused data gen operator
DataGenOp gen = null;
switch( bop.getOp() ) { //fuse via
scale and shift
case MULT: gen =
HopRewriteUtils.copyDataGenOp(inputGen, sval, 0); break;
case PLUS:
- case MINUS: gen =
HopRewriteUtils.copyDataGenOp(inputGen,
- 1, sval *
((bop.getOp()==OpOp2.MINUS)?-1:1)); break;
+ case MINUS: gen =
HopRewriteUtils.copyDataGenOp(inputGen,
+ 1, sval *
((bop.getOp()==OpOp2.MINUS)?-1:1)); break;
case DIV: gen =
HopRewriteUtils.copyDataGenOp(inputGen, 1/sval, 0); break;
}
-
+
//rewire all parents (avoid anomalies
with replicated datagen)
List<Hop> parents = new
ArrayList<>(bop.getParent());
for( Hop p : parents )
HopRewriteUtils.replaceChildReference(p, bop, gen);
-
+
hi = gen;
LOG.debug("Applied
fuseDatagenAndBinaryOperation1 "
- + "("+bop.getFilename()+", line
"+bop.getBeginLine()+").");
+ +
"("+bop.getFilename()+", line "+bop.getBeginLine()+").");
}
}
//right input rand and hence output matrix double, left
scalar literal
else if( right instanceof DataGenOp &&
((DataGenOp)right).getOp()==OpOpDG.RAND &&
- left instanceof LiteralOp &&
right.getParent().size()==1 )
+ left instanceof LiteralOp &&
right.getParent().size()==1 )
{
DataGenOp inputGen = (DataGenOp)right;
Hop pdf =
inputGen.getInput(DataExpression.RAND_PDF);
Hop min =
inputGen.getInput(DataExpression.RAND_MIN);
Hop max =
inputGen.getInput(DataExpression.RAND_MAX);
double sval =
((LiteralOp)left).getDoubleValue();
- boolean pdfUniform = pdf instanceof LiteralOp
- &&
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue());
-
+ boolean pdfUniform = pdf instanceof LiteralOp
+ &&
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue());
+
if( (bop.getOp()==OpOp2.MULT ||
bop.getOp()==OpOp2.PLUS)
- && min instanceof LiteralOp && max
instanceof LiteralOp && pdfUniform )
+ && min instanceof LiteralOp &&
max instanceof LiteralOp && pdfUniform )
{
//create fused data gen operator
DataGenOp gen = null;
@@ -431,32 +476,32 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
else { //OpOp2.PLUS
gen =
HopRewriteUtils.copyDataGenOp(inputGen, 1, sval);
}
-
+
//rewire all parents (avoid anomalies
with replicated datagen)
List<Hop> parents = new
ArrayList<>(bop.getParent());
for( Hop p : parents )
HopRewriteUtils.replaceChildReference(p, bop, gen);
-
+
hi = gen;
LOG.debug("Applied
fuseDatagenAndBinaryOperation2 "
- + "("+bop.getFilename()+", line
"+bop.getBeginLine()+").");
+ +
"("+bop.getFilename()+", line "+bop.getBeginLine()+").");
}
}
//left input rand and hence output matrix double, right
scalar variable
- else if( HopRewriteUtils.isDataGenOp(left, OpOpDG.RAND)
- && right.getDataType().isScalar() &&
left.getParent().size()==1 )
+ else if( HopRewriteUtils.isDataGenOp(left, OpOpDG.RAND)
+ && right.getDataType().isScalar() &&
left.getParent().size()==1 )
{
DataGenOp gen = (DataGenOp)left;
Hop min = gen.getInput(DataExpression.RAND_MIN);
Hop max = gen.getInput(DataExpression.RAND_MAX);
Hop pdf = gen.getInput(DataExpression.RAND_PDF);
- boolean pdfUniform = pdf instanceof LiteralOp
- &&
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue());
-
-
+ boolean pdfUniform = pdf instanceof LiteralOp
+ &&
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue());
+
+
if( HopRewriteUtils.isBinary(bop, OpOp2.PLUS)
- &&
HopRewriteUtils.isLiteralOfValue(min, 0)
- &&
HopRewriteUtils.isLiteralOfValue(max, 0) )
+ &&
HopRewriteUtils.isLiteralOfValue(min, 0)
+ &&
HopRewriteUtils.isLiteralOfValue(max, 0) )
{
gen.setInput(DataExpression.RAND_MIN,
right, true);
gen.setInput(DataExpression.RAND_MAX,
right, true);
@@ -466,12 +511,12 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
HopRewriteUtils.replaceChildReference(p, bop, gen);
hi = gen;
LOG.debug("Applied
fuseDatagenAndBinaryOperation3a "
- + "("+bop.getFilename()+", line
"+bop.getBeginLine()+").");
+ +
"("+bop.getFilename()+", line "+bop.getBeginLine()+").");
}
else if( HopRewriteUtils.isBinary(bop,
OpOp2.MULT)
- &&
((HopRewriteUtils.isLiteralOfValue(min, 0) && pdfUniform)
+ &&
((HopRewriteUtils.isLiteralOfValue(min, 0) && pdfUniform)
||
HopRewriteUtils.isLiteralOfValue(min, 1))
- &&
HopRewriteUtils.isLiteralOfValue(max, 1) )
+ &&
HopRewriteUtils.isLiteralOfValue(max, 1) )
{
if(
HopRewriteUtils.isLiteralOfValue(min, 1) )
gen.setInput(DataExpression.RAND_MIN, right, true);
@@ -482,24 +527,24 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
HopRewriteUtils.replaceChildReference(p, bop, gen);
hi = gen;
LOG.debug("Applied
fuseDatagenAndBinaryOperation3b "
- + "("+bop.getFilename()+", line
"+bop.getBeginLine()+").");
+ +
"("+bop.getFilename()+", line "+bop.getBeginLine()+").");
}
}
}
-
+
return hi;
}
-
- private static Hop fuseDatagenAndMinusOperation( Hop hi )
+
+ private static Hop fuseDatagenAndMinusOperation( Hop hi )
{
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
-
+
if( right instanceof DataGenOp &&
((DataGenOp)right).getOp()==OpOpDG.RAND &&
- left instanceof LiteralOp &&
((LiteralOp)left).getDoubleValue()==0.0 )
+ left instanceof LiteralOp &&
((LiteralOp)left).getDoubleValue()==0.0 )
{
DataGenOp inputGen = (DataGenOp)right;
HashMap<String,Integer> params =
inputGen.getParamIndexMap();
@@ -508,55 +553,55 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
int ixMax = params.get(DataExpression.RAND_MAX);
Hop min = right.getInput().get(ixMin);
Hop max = right.getInput().get(ixMax);
-
+
//apply rewrite under additional conditions
(for simplicity)
- if( inputGen.getParent().size()==1
- && min instanceof LiteralOp && max
instanceof LiteralOp && pdf instanceof LiteralOp
- &&
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) )
+ if( inputGen.getParent().size()==1
+ && min instanceof LiteralOp &&
max instanceof LiteralOp && pdf instanceof LiteralOp
+ &&
DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) )
{
//exchange and *-1 (special case 0
stays 0 instead of -0 for consistency)
double newMinVal =
(((LiteralOp)max).getDoubleValue()==0)?0:(-1 *
((LiteralOp)max).getDoubleValue());
double newMaxVal =
(((LiteralOp)min).getDoubleValue()==0)?0:(-1 *
((LiteralOp)min).getDoubleValue());
Hop newMin = new LiteralOp(newMinVal);
Hop newMax = new LiteralOp(newMaxVal);
-
+
HopRewriteUtils.removeChildReferenceByPos(inputGen, min, ixMin);
HopRewriteUtils.addChildReference(inputGen, newMin, ixMin);
HopRewriteUtils.removeChildReferenceByPos(inputGen, max, ixMax);
HopRewriteUtils.addChildReference(inputGen, newMax, ixMax);
-
+
//rewire all parents (avoid anomalies
with replicated datagen)
List<Hop> parents = new
ArrayList<>(bop.getParent());
for( Hop p : parents )
HopRewriteUtils.replaceChildReference(p, bop, inputGen);
-
+
hi = inputGen;
LOG.debug("Applied
fuseDatagenAndMinusOperation (line "+bop.getBeginLine()+").");
}
}
}
-
+
return hi;
}
-
- private static Hop foldMultipleAppendOperations(Hop hi)
+
+ private static Hop foldMultipleAppendOperations(Hop hi)
{
if( hi.getDataType().isMatrix() //no string appends or frames
- && (HopRewriteUtils.isBinary(hi, OpOp2.CBIND,
OpOp2.RBIND)
- || HopRewriteUtils.isNary(hi, OpOpN.CBIND,
OpOpN.RBIND)) )
+ && (HopRewriteUtils.isBinary(hi, OpOp2.CBIND,
OpOp2.RBIND)
+ || HopRewriteUtils.isNary(hi, OpOpN.CBIND,
OpOpN.RBIND)) )
{
OpOp2 bop = (hi instanceof BinaryOp) ?
((BinaryOp)hi).getOp() :
- OpOp2.valueOf(((NaryOp)hi).getOp().name());
+
OpOp2.valueOf(((NaryOp)hi).getOp().name());
OpOpN nop = (hi instanceof NaryOp) ?
((NaryOp)hi).getOp() :
- OpOpN.valueOf(((BinaryOp)hi).getOp().name());
-
+
OpOpN.valueOf(((BinaryOp)hi).getOp().name());
+
boolean converged = false;
while( !converged ) {
//get first matching cbind or rbind
Hop first = hi.getInput().stream()
- .filter(h ->
HopRewriteUtils.isBinary(h, bop) || HopRewriteUtils.isNary(h, nop))
- .findFirst().orElse(null);
-
+ .filter(h ->
HopRewriteUtils.isBinary(h, bop) || HopRewriteUtils.isNary(h, nop))
+ .findFirst().orElse(null);
+
//replace current op with new nary cbind/rbind
if( first != null &&
first.getParent().size()==1 ) {
//construct new list of inputs (in
original order)
@@ -582,29 +627,29 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
}
}
}
-
+
return hi;
}
-
+
/**
* Handle simplification of binary operations (relies on previous
common subexpression elimination).
* At the same time this servers as a canonicalization for more complex
rewrites.
- *
+ *
* X+X -> X*2, X*X -> X^2, (X>0)-(X<0) -> sign(X)
- *
+ *
* @param parent parent high-level operator
* @param hi high-level operator
* @param pos position
* @return high-level operator
*/
- private static Hop simplifyBinaryToUnaryOperation( Hop parent, Hop hi,
int pos )
+ private static Hop simplifyBinaryToUnaryOperation( Hop parent, Hop hi,
int pos )
{
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
Hop left = hi.getInput().get(0);
Hop right = hi.getInput().get(1);
-
+
//patterns: X+X -> X*2, X*X -> X^2,
if( left == right &&
left.getDataType()==DataType.MATRIX )
{
@@ -614,48 +659,48 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
{
bop.setOp(OpOp2.MULT);
HopRewriteUtils.replaceChildReference(hi, right, new LiteralOp(2), 1);
-
+
LOG.debug("Applied
simplifyBinaryToUnaryOperation1 (line "+hi.getBeginLine()+").");
}
else if ( bop.getOp()==OpOp2.MULT ) //X*X -> X^2
{
bop.setOp(OpOp2.POW);
HopRewriteUtils.replaceChildReference(hi, right, new LiteralOp(2), 1);
-
+
LOG.debug("Applied
simplifyBinaryToUnaryOperation2 (line "+hi.getBeginLine()+").");
}
}
//patterns: (X>0)-(X<0) -> sign(X)
- else if( bop.getOp() == OpOp2.MINUS
- && HopRewriteUtils.isBinary(left,
OpOp2.GREATER)
- && HopRewriteUtils.isBinary(right, OpOp2.LESS)
- && left.getInput().get(0) ==
right.getInput().get(0)
- && left.getInput().get(1) instanceof LiteralOp
- &&
HopRewriteUtils.getDoubleValue((LiteralOp)left.getInput().get(1))==0
- && right.getInput().get(1) instanceof LiteralOp
- &&
HopRewriteUtils.getDoubleValue((LiteralOp)right.getInput().get(1))==0 )
+ else if( bop.getOp() == OpOp2.MINUS
+ && HopRewriteUtils.isBinary(left,
OpOp2.GREATER)
+ && HopRewriteUtils.isBinary(right,
OpOp2.LESS)
+ && left.getInput().get(0) ==
right.getInput().get(0)
+ && left.getInput().get(1) instanceof
LiteralOp
+ &&
HopRewriteUtils.getDoubleValue((LiteralOp)left.getInput().get(1))==0
+ && right.getInput().get(1) instanceof
LiteralOp
+ &&
HopRewriteUtils.getDoubleValue((LiteralOp)right.getInput().get(1))==0 )
{
UnaryOp uop =
HopRewriteUtils.createUnary(left.getInput().get(0), OpOp1.SIGN);
HopRewriteUtils.replaceChildReference(parent,
hi, uop, pos);
HopRewriteUtils.cleanupUnreferenced(hi, left,
right);
hi = uop;
-
+
LOG.debug("Applied
simplifyBinaryToUnaryOperation3 (line "+hi.getBeginLine()+").");
}
}
-
+
return hi;
}
-
+
/**
* Rewrite to canonicalize all patterns like U%*%V+eps, eps+U%*%V, and
* U%*%V-eps into the common representation U%*%V+s which simplifies
* subsequent rewrites (e.g., wdivmm or wcemm with epsilon).
- *
+ *
* @param hi high-level operator
* @return high-level operator
*/
- private static Hop canonicalizeMatrixMultScalarAdd( Hop hi )
+ private static Hop canonicalizeMatrixMultScalarAdd( Hop hi )
{
//pattern: binary operation (+ or -) of matrix mult and scalar
if( hi instanceof BinaryOp )
@@ -663,10 +708,10 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
BinaryOp bop = (BinaryOp)hi;
Hop left = hi.getInput().get(0);
Hop right = hi.getInput().get(1);
-
+
//pattern: (eps + U%*%V) -> (U%*%V+eps)
if( left.getDataType().isScalar() && right instanceof
AggBinaryOp
- && bop.getOp()==OpOp2.PLUS )
+ && bop.getOp()==OpOp2.PLUS )
{
HopRewriteUtils.removeAllChildReferences(bop);
HopRewriteUtils.addChildReference(bop, right,
0);
@@ -683,11 +728,11 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
LOG.debug("Applied
canonicalizeMatrixMultScalarAdd2 (line "+hi.getBeginLine()+").");
}
}
-
+
return hi;
}
-
- private static Hop simplifyCTableWithConstMatrixInputs( Hop hi )
+
+ private static Hop simplifyCTableWithConstMatrixInputs( Hop hi )
{
//pattern: table(X, matrix(1,...), matrix(7, ...)) -> table(X,
1, 7)
if( HopRewriteUtils.isTernary(hi, OpOp3.CTABLE) ) {
@@ -698,7 +743,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
Hop inNew =
((DataGenOp)inCurr).getInput(DataExpression.RAND_MIN);
HopRewriteUtils.replaceChildReference(hi, inCurr, inNew, i);
LOG.debug("Applied
simplifyCTableWithConstMatrixInputs"
- + i + " (line
"+hi.getBeginLine()+").");
+ + i + " (line
"+hi.getBeginLine()+").");
}
}
}
@@ -706,9 +751,9 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
}
private static Hop removeUnnecessaryCTable( Hop parent, Hop hi, int pos
) {
- if ( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM,
Direction.RowCol)
- && HopRewriteUtils.isTernary(hi.getInput().get(0),
OpOp3.CTABLE)
- &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0).getInput().get(2), 1.0))
+ if ( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM,
Direction.RowCol)
+ &&
HopRewriteUtils.isTernary(hi.getInput().get(0), OpOp3.CTABLE)
+ &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0).getInput().get(2), 1.0))
{
Hop matrixInput =
hi.getInput().get(0).getInput().get(0);
OpOp1 opcode = matrixInput.getDim2() == 1 ? OpOp1.NROW
: OpOp1.LENGTH;
@@ -724,67 +769,67 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
* NOTE: this would be by definition a dynamic rewrite; however, we
apply it as a static
* rewrite in order to apply it before splitting dags which would hide
the table information
* if dimensions are not specified.
- *
+ *
* @param parent parent high-level operator
* @param hi high-level operator
* @param pos position
* @return high-level operator
*/
- private static Hop simplifyReverseOperation( Hop parent, Hop hi, int
pos )
+ private static Hop simplifyReverseOperation( Hop parent, Hop hi, int
pos )
{
- if( hi instanceof AggBinaryOp
- && hi.getInput().get(0) instanceof TernaryOp )
+ if( hi instanceof AggBinaryOp
+ && hi.getInput().get(0) instanceof TernaryOp )
{
TernaryOp top = (TernaryOp) hi.getInput().get(0);
-
+
+			//pattern: table(seq(1,n), seq(n,1)) %*% X, where the ctable is the left
+			//matrix-mult input and both sequences have the same number of rows
+			//(assumes isBasic1NSequence/isBasicN1Sequence match seq(1,n)/seq(n,1))
if( top.getOp()==OpOp3.CTABLE
- &&
HopRewriteUtils.isBasic1NSequence(top.getInput().get(0))
- &&
HopRewriteUtils.isBasicN1Sequence(top.getInput().get(1))
- &&
top.getInput().get(0).getDim1()==top.getInput().get(1).getDim1())
+ &&
HopRewriteUtils.isBasic1NSequence(top.getInput().get(0))
+ &&
HopRewriteUtils.isBasicN1Sequence(top.getInput().get(1))
+ &&
top.getInput().get(0).getDim1()==top.getInput().get(1).getDim1())
{
+				//such a ctable has ones at (i, n-i+1), i.e., it is the anti-diagonal
+				//permutation matrix; multiplying the right input X by it reverses the
+				//rows of X, so the matmult is replaced by rev(X)
ReorgOp rop =
HopRewriteUtils.createReorg(hi.getInput().get(1), ReOrgOp.REV);
HopRewriteUtils.replaceChildReference(parent,
hi, rop, pos);
HopRewriteUtils.cleanupUnreferenced(hi, top);
hi = rop;
-
+
LOG.debug("Applied simplifyReverseOperation.");
}
}
-
+
return hi;
}
-
+
private static Hop simplifyMultiBinaryToBinaryOperation( Hop hi )
{
//pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate)
if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS)
- && hi.getDataType() == DataType.MATRIX
- && hi.getInput().get(0) instanceof LiteralOp
- &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)hi.getInput().get(0))==1
- && HopRewriteUtils.isBinary(hi.getInput().get(1),
OpOp2.MULT)
- && hi.getInput().get(1).getParent().size() == 1 )
//single consumer
+ && hi.getDataType() == DataType.MATRIX
+ && hi.getInput().get(0) instanceof LiteralOp
+ &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)hi.getInput().get(0))==1
+ &&
HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT)
+ && hi.getInput().get(1).getParent().size() == 1
) //single consumer
{
BinaryOp bop = (BinaryOp)hi;
Hop left = hi.getInput().get(1).getInput().get(0);
Hop right = hi.getInput().get(1).getInput().get(1);
-
+
//set new binaryop type and rewire inputs
bop.setOp(OpOp2.MINUS1_MULT);
HopRewriteUtils.removeAllChildReferences(hi);
HopRewriteUtils.addChildReference(bop, left);
HopRewriteUtils.addChildReference(bop, right);
-
+
LOG.debug("Applied
simplifyMultiBinaryToBinaryOperation.");
}
-
+
return hi;
}
-
+
/**
* (X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X
* (X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X
- *
- *
+ *
+ *
* @param parent parent high-level operator
* @param hi high-level operator
* @param pos position
@@ -797,21 +842,21 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
BinaryOp bop = (BinaryOp)hi;
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
-
+
//(X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X
//(X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X
boolean applied = false;
- if( left.getDataType()==DataType.MATRIX &&
right.getDataType()==DataType.MATRIX
- && HopRewriteUtils.isValidOp(bop.getOp(),
LOOKUP_VALID_DISTRIBUTIVE_BINARY) )
+ if( left.getDataType()==DataType.MATRIX &&
right.getDataType()==DataType.MATRIX
+ &&
HopRewriteUtils.isValidOp(bop.getOp(), LOOKUP_VALID_DISTRIBUTIVE_BINARY) )
{
Hop X = null; Hop Y = null;
if( HopRewriteUtils.isBinary(left, OpOp2.MULT)
) //(Y*X-X) -> (Y-1)*X
{
Hop leftC1 = left.getInput().get(0);
Hop leftC2 = left.getInput().get(1);
-
+
if(
leftC1.getDataType()==DataType.MATRIX && leftC2.getDataType()==DataType.MATRIX
&&
- (right == leftC1 || right ==
leftC2) && leftC1 !=leftC2 ){ //any mult order
+ (right == leftC1 ||
right == leftC2) && leftC1 !=leftC2 ){ //any mult order
X = right;
Y = ( right == leftC1 ) ?
leftC2 : leftC1;
}
@@ -823,17 +868,17 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
HopRewriteUtils.cleanupUnreferenced(hi, left);
hi = mult;
applied = true;
-
+
LOG.debug("Applied
simplifyDistributiveBinaryOperation1 (line "+hi.getBeginLine()+").");
}
}
-
+
if( !applied && HopRewriteUtils.isBinary(right,
OpOp2.MULT) ) //(X-Y*X) -> (1-Y)*X
{
Hop rightC1 = right.getInput().get(0);
Hop rightC2 = right.getInput().get(1);
if(
rightC1.getDataType()==DataType.MATRIX &&
rightC2.getDataType()==DataType.MATRIX &&
- (left == rightC1 || left ==
rightC2) && rightC1 !=rightC2 ){ //any mult order
+ (left == rightC1 ||
left == rightC2) && rightC1 !=rightC2 ){ //any mult order
X = left;
Y = ( left == rightC1 ) ?
rightC2 : rightC1;
}
@@ -847,21 +892,21 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
LOG.debug("Applied
simplifyDistributiveBinaryOperation2 (line "+hi.getBeginLine()+").");
}
- }
+ }
}
}
-
+
return hi;
}
-
+
/**
* t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%(X*Y)*(Z%*%v)
* t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
- *
+ *
* Note: Restriction ba() at leaf and root instead of data at leaf to
not reorganize too
* eagerly, which would loose additional rewrite potential. This
rewrite has two goals
* (1) enable XtwXv, and increase piggybacking potential by creating
bushy trees.
- *
+ *
* @param parent parent high-level operator
* @param hi high-level operator
* @param pos position
@@ -875,46 +920,46 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
OpOp2 op = bop.getOp();
-
+
if( left.getDataType()==DataType.MATRIX &&
right.getDataType()==DataType.MATRIX &&
- HopRewriteUtils.isValidOp(op,
LOOKUP_VALID_ASSOCIATIVE_BINARY) )
+ HopRewriteUtils.isValidOp(op,
LOOKUP_VALID_ASSOCIATIVE_BINARY) )
{
boolean applied = false;
-
+
if( right instanceof BinaryOp )
{
BinaryOp bop2 = (BinaryOp)right;
Hop left2 = bop2.getInput().get(0);
Hop right2 = bop2.getInput().get(1);
OpOp2 op2 = bop2.getOp();
-
- if( op==op2 &&
right2.getDataType()==DataType.MATRIX
- && (right2 instanceof
AggBinaryOp) )
+
+ if( op==op2 &&
right2.getDataType()==DataType.MATRIX
+ && (right2 instanceof
AggBinaryOp) )
{
//(X*(Y*op()) -> (X*Y)*op()
BinaryOp bop3 =
HopRewriteUtils.createBinary(left, left2, op);
BinaryOp bop4 =
HopRewriteUtils.createBinary(bop3, right2, op);
-
HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos);
+
HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos);
HopRewriteUtils.cleanupUnreferenced(bop, bop2);
hi = bop4;
-
+
applied = true;
-
+
LOG.debug("Applied
simplifyBushyBinaryOperation1");
}
}
-
+
if( !applied && left instanceof BinaryOp )
{
BinaryOp bop2 = (BinaryOp)left;
Hop left2 = bop2.getInput().get(0);
Hop right2 = bop2.getInput().get(1);
OpOp2 op2 = bop2.getOp();
-
- if( op==op2 &&
left2.getDataType()==DataType.MATRIX
- && (left2 instanceof
AggBinaryOp)
- && (right2.getDim2() > 1 ||
right.getDim2() == 1) //X not vector, or Y vector
- && (right2.getDim1() > 1 ||
right.getDim1() == 1) ) //X not vector, or Y vector
+
+ if( op==op2 &&
left2.getDataType()==DataType.MATRIX
+ && (left2 instanceof
AggBinaryOp)
+ && (right2.getDim2() >
1 || right.getDim2() == 1) //X not vector, or Y vector
+ && (right2.getDim1() >
1 || right.getDim1() == 1) ) //X not vector, or Y vector
{
//((op()*X)*Y) -> op()*(X*Y)
BinaryOp bop3 =
HopRewriteUtils.createBinary(right2, right, op);
@@ -922,39 +967,39 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos);
HopRewriteUtils.cleanupUnreferenced(bop, bop2);
hi = bop4;
-
+
LOG.debug("Applied
simplifyBushyBinaryOperation2");
}
}
}
-
+
}
-
+
return hi;
}
-
+
private static Hop simplifyUnaryAggReorgOperation( Hop parent, Hop hi,
int pos )
{
if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getDirection()==Direction.RowCol //full uagg
- && hi.getInput().get(0) instanceof ReorgOp ) //reorg
operation
+ && hi.getInput().get(0) instanceof ReorgOp )
//reorg operation
{
ReorgOp rop = (ReorgOp)hi.getInput().get(0);
if( (rop.getOp()==ReOrgOp.TRANS ||
rop.getOp()==ReOrgOp.RESHAPE
|| rop.getOp() == ReOrgOp.REV )
//valid reorg
- && rop.getParent().size()==1 )
//uagg only reorg consumer
+ && rop.getParent().size()==1 )
//uagg only reorg consumer
{
Hop input = rop.getInput().get(0);
HopRewriteUtils.removeAllChildReferences(hi);
HopRewriteUtils.removeAllChildReferences(rop);
HopRewriteUtils.addChildReference(hi, input);
-
+
LOG.debug("Applied
simplifyUnaryAggReorgOperation");
}
}
-
+
return hi;
}
-
+
private static Hop removeUnnecessaryAggregates(Hop hi)
{
//sum(rowSums(X)) -> sum(X), sum(colSums(X)) -> sum(X)
@@ -962,44 +1007,44 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//max(rowMaxs(X)) -> max(X), max(colMaxs(X)) -> max(X)
//sum(rowSums(X^2)) -> sum(X), sum(colSums(X^2)) -> sum(X)
if( hi instanceof AggUnaryOp && hi.getInput().get(0) instanceof
AggUnaryOp
- && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
- && hi.getInput().get(0).getParent().size()==1 )
+ &&
((AggUnaryOp)hi).getDirection()==Direction.RowCol
+ && hi.getInput().get(0).getParent().size()==1 )
{
AggUnaryOp au1 = (AggUnaryOp) hi;
AggUnaryOp au2 = (AggUnaryOp) hi.getInput().get(0);
- if( (au1.getOp()==AggOp.SUM && (au2.getOp()==AggOp.SUM
|| au2.getOp()==AggOp.SUM_SQ))
- || (au1.getOp()==AggOp.MIN &&
au2.getOp()==AggOp.MIN)
- || (au1.getOp()==AggOp.MAX &&
au2.getOp()==AggOp.MAX) )
+ if( (au1.getOp()==AggOp.SUM && (au2.getOp()==AggOp.SUM
|| au2.getOp()==AggOp.SUM_SQ))
+ || (au1.getOp()==AggOp.MIN &&
au2.getOp()==AggOp.MIN)
+ || (au1.getOp()==AggOp.MAX &&
au2.getOp()==AggOp.MAX) )
{
Hop input = au2.getInput().get(0);
HopRewriteUtils.removeAllChildReferences(au2);
HopRewriteUtils.replaceChildReference(au1, au2,
input);
if( au2.getOp() == AggOp.SUM_SQ )
au1.setOp(AggOp.SUM_SQ);
-
+
LOG.debug("Applied removeUnnecessaryAggregates
(line "+hi.getBeginLine()+").");
}
}
-
+
return hi;
}
-
- private static Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop
hi, int pos )
+
+ private static Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop
hi, int pos )
{
// Note: This rewrite is not applicable for all binary
operations because some of them
// are undefined over scalars. We explicitly exclude potential
conflicting matrix-scalar binary
// operations; other operations like cbind/rbind will never
occur as matrix-scalar operations.
-
- if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)
- && hi.getInput().get(0) instanceof BinaryOp
- && HopRewriteUtils.isBinary(hi.getInput().get(0),
LOOKUP_VALID_SCALAR_BINARY))
+
+ if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)
+ && hi.getInput().get(0) instanceof BinaryOp
+ &&
HopRewriteUtils.isBinary(hi.getInput().get(0), LOOKUP_VALID_SCALAR_BINARY))
{
BinaryOp bin = (BinaryOp) hi.getInput().get(0);
BinaryOp bout = null;
-
+
//as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)
- if(
bin.getInput().get(0).getDataType()==DataType.MATRIX
- &&
bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
+ if( bin.getInput().get(0).getDataType()==DataType.MATRIX
+ &&
bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
UnaryOp cast1 =
HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR);
UnaryOp cast2 =
HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
bout = HopRewriteUtils.createBinary(cast1,
cast2, bin.getOp());
@@ -1014,86 +1059,86 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
UnaryOp cast =
HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
bout =
HopRewriteUtils.createBinary(bin.getInput().get(0), cast, bin.getOp());
}
-
+
if( bout != null ) {
HopRewriteUtils.replaceChildReference(parent,
hi, bout, pos);
-
+
LOG.debug("Applied
simplifyBinaryMatrixScalarOperation.");
}
}
-
+
return hi;
}
-
+
private static Hop pushdownUnaryAggTransposeOperation( Hop parent, Hop
hi, int pos )
{
- if( hi instanceof AggUnaryOp && hi.getParent().size()==1
- && (((AggUnaryOp) hi).getDirection()==Direction.Row ||
((AggUnaryOp) hi).getDirection()==Direction.Col)
- &&
HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1)
- && HopRewriteUtils.isValidOp(((AggUnaryOp) hi).getOp(),
LOOKUP_VALID_ROW_COL_AGGREGATE) )
+ if( hi instanceof AggUnaryOp && hi.getParent().size()==1
+ && (((AggUnaryOp)
hi).getDirection()==Direction.Row || ((AggUnaryOp)
hi).getDirection()==Direction.Col)
+ &&
HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1)
+ && HopRewriteUtils.isValidOp(((AggUnaryOp)
hi).getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) )
{
AggUnaryOp uagg = (AggUnaryOp) hi;
-
+
//get input rewire existing operators (remove inner
transpose)
Hop input = uagg.getInput().get(0).getInput().get(0);
HopRewriteUtils.removeAllChildReferences(hi.getInput().get(0));
HopRewriteUtils.removeAllChildReferences(hi);
HopRewriteUtils.removeChildReferenceByPos(parent, hi,
pos);
-
+
//pattern 1: row-aggregate to col aggregate, e.g.,
rowSums(t(X))->t(colSums(X))
if( uagg.getDirection()==Direction.Row ) {
- uagg.setDirection(Direction.Col);
- LOG.debug("Applied
pushdownUnaryAggTransposeOperation1 (line "+hi.getBeginLine()+").");
+ uagg.setDirection(Direction.Col);
+ LOG.debug("Applied
pushdownUnaryAggTransposeOperation1 (line "+hi.getBeginLine()+").");
}
//pattern 2: col-aggregate to row aggregate, e.g.,
colSums(t(X))->t(rowSums(X))
else if( uagg.getDirection()==Direction.Col ) {
- uagg.setDirection(Direction.Row);
+ uagg.setDirection(Direction.Row);
LOG.debug("Applied
pushdownUnaryAggTransposeOperation2 (line "+hi.getBeginLine()+").");
}
-
+
//create outer transpose operation and rewire operators
HopRewriteUtils.addChildReference(uagg, input);
uagg.refreshSizeInformation();
Hop trans = HopRewriteUtils.createTranspose(uagg);
//incl refresh size
HopRewriteUtils.addChildReference(parent, trans, pos);
//by def, same size
-
- hi = trans;
+
+ hi = trans;
}
-
+
return hi;
}
-
+
private static Hop pushdownCSETransposeScalarOperation( Hop parent, Hop
hi, int pos )
{
// a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
// probed at root node of b in above example
// (with support for left or right scalar operations)
- if( HopRewriteUtils.isTransposeOperation(hi, 1)
- &&
HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput().get(0))
- && hi.getInput().get(0).getParent().size()==1)
+ if( HopRewriteUtils.isTransposeOperation(hi, 1)
+ &&
HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput().get(0))
+ && hi.getInput().get(0).getParent().size()==1)
{
int Xpos =
hi.getInput().get(0).getInput().get(0).getDataType().isMatrix() ? 0 : 1;
Hop X = hi.getInput().get(0).getInput().get(Xpos);
BinaryOp binary = (BinaryOp) hi.getInput().get(0);
-
- if(
HopRewriteUtils.containsTransposeOperation(X.getParent())
- && !HopRewriteUtils.isValidOp(binary.getOp(),
new OpOp2[]{OpOp2.MOMENT, OpOp2.QUANTILE}))
+
+ if(
HopRewriteUtils.containsTransposeOperation(X.getParent())
+ &&
!HopRewriteUtils.isValidOp(binary.getOp(), new OpOp2[]{OpOp2.MOMENT,
OpOp2.QUANTILE}))
{
//clear existing wiring
-
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
HopRewriteUtils.removeChildReference(hi,
binary);
HopRewriteUtils.removeChildReference(binary, X);
-
+
//re-wire operators
HopRewriteUtils.addChildReference(parent,
binary, pos);
HopRewriteUtils.addChildReference(binary, hi,
Xpos);
HopRewriteUtils.addChildReference(hi, X);
//note: common subexpression later eliminated
by dedicated rewrite
-
+
hi = binary;
LOG.debug("Applied
pushdownCSETransposeScalarOperation (line "+hi.getBeginLine()+").");
- }
+ }
}
-
+
return hi;
}
@@ -1103,20 +1148,20 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
&& ((AggUnaryOp)hi).getOp()==AggOp.SUM // only
one parent which is the sum
&&
HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.MULT, 1)
&&
((hi.getInput().get(0).getInput().get(0).getDataType()==DataType.SCALAR &&
hi.getInput().get(0).getInput().get(1).getDataType()==DataType.MATRIX)
-
||(hi.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX &&
hi.getInput().get(0).getInput().get(1).getDataType()==DataType.SCALAR)))
+
||(hi.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX &&
hi.getInput().get(0).getInput().get(1).getDataType()==DataType.SCALAR)))
{
- Hop operand1 = hi.getInput().get(0).getInput().get(0);
+ Hop operand1 = hi.getInput().get(0).getInput().get(0);
Hop operand2 = hi.getInput().get(0).getInput().get(1);
//check which operand is the Scalar and which is the
matrix
- Hop lamda = (operand1.getDataType()==DataType.SCALAR) ?
operand1 : operand2;
- Hop matrix = (operand1.getDataType()==DataType.MATRIX)
? operand1 : operand2;
+ Hop lamda = (operand1.getDataType()==DataType.SCALAR) ?
operand1 : operand2;
+ Hop matrix = (operand1.getDataType()==DataType.MATRIX)
? operand1 : operand2;
AggUnaryOp
aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.RowCol);
Hop bop = HopRewriteUtils.createBinary(lamda, aggOp,
OpOp2.MULT);
-
+
HopRewriteUtils.replaceChildReference(parent, hi, bop,
pos);
-
+
LOG.debug("Applied pushdownSumBinaryMult.");
return bop;
}
@@ -1125,89 +1170,89 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
private static Hop pullupAbs(Hop parent, Hop hi, int pos ) {
if( HopRewriteUtils.isBinary(hi, OpOp2.MULT)
- && HopRewriteUtils.isUnary(hi.getInput(0), OpOp1.ABS)
- && hi.getInput(0).getParent().size()==1
- && HopRewriteUtils.isUnary(hi.getInput(1), OpOp1.ABS)
- && hi.getInput(1).getParent().size()==1)
+ && HopRewriteUtils.isUnary(hi.getInput(0),
OpOp1.ABS)
+ && hi.getInput(0).getParent().size()==1
+ && HopRewriteUtils.isUnary(hi.getInput(1),
OpOp1.ABS)
+ && hi.getInput(1).getParent().size()==1)
{
Hop operand1 = hi.getInput(0).getInput(0);
Hop operand2 = hi.getInput(1).getInput(0);
Hop bop = HopRewriteUtils.createBinary(operand1,
operand2, OpOp2.MULT);
Hop uop = HopRewriteUtils.createUnary(bop, OpOp1.ABS);
HopRewriteUtils.replaceChildReference(parent, hi, uop,
pos);
-
+
LOG.debug("Applied pullupAbs (line
"+hi.getBeginLine()+").");
return uop;
}
return hi;
}
-
+
private static Hop simplifyUnaryPPredOperation( Hop parent, Hop hi, int
pos )
{
if( hi instanceof UnaryOp && hi.getDataType()==DataType.MATRIX
//unaryop
- && hi.getInput().get(0) instanceof BinaryOp
//binaryop - ppred
- && ((BinaryOp)hi.getInput().get(0)).isPPredOperation() )
+ && hi.getInput().get(0) instanceof BinaryOp
//binaryop - ppred
+ &&
((BinaryOp)hi.getInput().get(0)).isPPredOperation() )
{
UnaryOp uop = (UnaryOp) hi; //valid unary op
if( uop.getOp()==OpOp1.ABS || uop.getOp()==OpOp1.SIGN
- || uop.getOp()==OpOp1.CEIL ||
uop.getOp()==OpOp1.FLOOR || uop.getOp()==OpOp1.ROUND )
+ || uop.getOp()==OpOp1.CEIL ||
uop.getOp()==OpOp1.FLOOR || uop.getOp()==OpOp1.ROUND )
{
//clear link unary-binary
Hop input = uop.getInput().get(0);
HopRewriteUtils.replaceChildReference(parent,
hi, input, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
hi = input;
-
- LOG.debug("Applied
simplifyUnaryPPredOperation.");
+
+ LOG.debug("Applied
simplifyUnaryPPredOperation.");
}
}
-
+
return hi;
}
-
+
private static Hop simplifyTransposedAppend( Hop parent, Hop hi, int
pos )
{
//e.g., t(cbind(t(A),t(B))) --> rbind(A,B), t(rbind(t(A),t(B)))
--> cbind(A,B)
if( HopRewriteUtils.isTransposeOperation(hi) //t() rooted
- && hi.getInput().get(0) instanceof BinaryOp
- && (((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.CBIND
//append (cbind/rbind)
- || ((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.RBIND)
- && hi.getInput().get(0).getParent().size() == 1 ) //single
consumer of append
+ && hi.getInput().get(0) instanceof BinaryOp
+ &&
(((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.CBIND //append (cbind/rbind)
+ ||
((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.RBIND)
+ && hi.getInput().get(0).getParent().size() == 1
) //single consumer of append
{
BinaryOp bop = (BinaryOp)hi.getInput().get(0);
//both inputs transpose ops, where transpose is single
consumer
- if(
HopRewriteUtils.isTransposeOperation(bop.getInput().get(0), 1)
- &&
HopRewriteUtils.isTransposeOperation(bop.getInput().get(1), 1) )
+ if(
HopRewriteUtils.isTransposeOperation(bop.getInput().get(0), 1)
+ &&
HopRewriteUtils.isTransposeOperation(bop.getInput().get(1), 1) )
{
Hop left =
bop.getInput().get(0).getInput().get(0);
Hop right =
bop.getInput().get(1).getInput().get(0);
-
+
//create new subdag (no in-place dag update to
prevent anomalies with
//multiple consumers during rewrite process)
OpOp2 binop = (bop.getOp()==OpOp2.CBIND) ?
OpOp2.RBIND : OpOp2.CBIND;
BinaryOp bopnew =
HopRewriteUtils.createBinary(left, right, binop);
HopRewriteUtils.replaceChildReference(parent,
hi, bopnew, pos);
-
+
hi = bopnew;
LOG.debug("Applied simplifyTransposedAppend
(line "+hi.getBeginLine()+").");
}
}
-
+
return hi;
}
-
+
/**
* handle simplification of more complex sub DAG to unary operation.
- *
+ *
* X*(1-X) -> sprop(X)
* (1-X)*X -> sprop(X)
* 1/(1+exp(-X)) -> sigmoid(X)
- *
+ *
* @param parent parent high-level operator
* @param hi high-level operator
* @param pos position
*/
- private static Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop
hi, int pos )
+ private static Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop
hi, int pos )
{
if( hi instanceof BinaryOp )
{
@@ -1215,7 +1260,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
Hop left = hi.getInput().get(0);
Hop right = hi.getInput().get(1);
boolean applied = false;
-
+
//sample proportion (sprop) operator
if( bop.getOp() == OpOp2.MULT &&
left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
{
@@ -1223,97 +1268,97 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//note: if there are multiple consumers on the
intermediate,
//we follow the heuristic that redundant
computation is more beneficial,
//i.e., we still fuse but leave the
intermediate for the other consumers
-
+
if( left instanceof BinaryOp ) //(1-X)*X
{
BinaryOp bleft = (BinaryOp)left;
Hop left1 = bleft.getInput().get(0);
- Hop left2 = bleft.getInput().get(1);
-
+ Hop left2 = bleft.getInput().get(1);
+
if( left1 instanceof LiteralOp &&
-
HopRewriteUtils.getDoubleValue((LiteralOp)left1)==1 &&
- left2 == right && bleft.getOp()
== OpOp2.MINUS )
+
HopRewriteUtils.getDoubleValue((LiteralOp)left1)==1 &&
+ left2 == right &&
bleft.getOp() == OpOp2.MINUS )
{
UnaryOp unary =
HopRewriteUtils.createUnary(right, OpOp1.SPROP);
HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
HopRewriteUtils.cleanupUnreferenced(bop, left);
hi = unary;
applied = true;
-
+
LOG.debug("Applied
fuseBinarySubDAGToUnaryOperation-sprop1");
}
- }
+ }
if( !applied && right instanceof BinaryOp )
//X*(1-X)
{
BinaryOp bright = (BinaryOp)right;
Hop right1 = bright.getInput().get(0);
- Hop right2 = bright.getInput().get(1);
-
+ Hop right2 = bright.getInput().get(1);
+
if( right1 instanceof LiteralOp &&
-
HopRewriteUtils.getDoubleValue((LiteralOp)right1)==1 &&
- right2 == left &&
bright.getOp() == OpOp2.MINUS )
+
HopRewriteUtils.getDoubleValue((LiteralOp)right1)==1 &&
+ right2 == left &&
bright.getOp() == OpOp2.MINUS )
{
UnaryOp unary =
HopRewriteUtils.createUnary(left, OpOp1.SPROP);
HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
HopRewriteUtils.cleanupUnreferenced(bop, left);
hi = unary;
applied = true;
-
+
LOG.debug("Applied
fuseBinarySubDAGToUnaryOperation-sprop2");
}
}
}
-
+
//sigmoid operator
if( !applied && bop.getOp() == OpOp2.DIV &&
left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX
- && left instanceof LiteralOp &&
HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp)
+ && left instanceof LiteralOp &&
HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp)
{
//note: if there are multiple consumers on the
intermediate,
//we follow the heuristic that redundant
computation is more beneficial,
//i.e., we still fuse but leave the
intermediate for the other consumers
-
+
BinaryOp bop2 = (BinaryOp)right;
Hop left2 = bop2.getInput().get(0);
Hop right2 = bop2.getInput().get(1);
-
+
if( bop2.getOp() == OpOp2.PLUS &&
left2.getDataType()==DataType.SCALAR && right2.getDataType()==DataType.MATRIX
- && left2 instanceof LiteralOp &&
HopRewriteUtils.getDoubleValue((LiteralOp)left2)==1 && right2 instanceof
UnaryOp)
+ && left2 instanceof LiteralOp
&& HopRewriteUtils.getDoubleValue((LiteralOp)left2)==1 && right2 instanceof
UnaryOp)
{
UnaryOp uop = (UnaryOp) right2;
Hop uopin = uop.getInput().get(0);
-
- if( uop.getOp()==OpOp1.EXP )
+
+ if( uop.getOp()==OpOp1.EXP )
{
UnaryOp unary = null;
-
+
//Pattern 1: (1/(1 + exp(-X))
if(
HopRewriteUtils.isBinary(uopin, OpOp2.MINUS) ) {
BinaryOp bop3 =
(BinaryOp) uopin;
Hop left3 =
bop3.getInput().get(0);
Hop right3 =
bop3.getInput().get(1);
-
+
if( left3 instanceof
LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left3)==0 )
unary =
HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID);
- }
+ }
//Pattern 2: (1/(1 + exp(X)),
e.g., where -(-X) has been removed by
//the 'remove unnecessary
minus' rewrite --> reintroduce the minus
else {
BinaryOp minus =
HopRewriteUtils.createBinaryMinus(uopin);
unary =
HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID);
- }
-
+ }
+
if( unary != null ) {
HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop);
hi = unary;
applied = true;
-
+
LOG.debug("Applied
fuseBinarySubDAGToUnaryOperation-sigmoid1");
- }
+ }
}
- }
+ }
}
-
+
//select positive (selp) operator (note: same initial
pattern as sprop)
if( !applied && bop.getOp() == OpOp2.MULT &&
left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
{
@@ -1325,17 +1370,17 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
BinaryOp bleft = (BinaryOp)left;
Hop left1 = bleft.getInput().get(0);
Hop left2 = bleft.getInput().get(1);
-
+
if( left2 instanceof LiteralOp &&
-
HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 &&
- left1 == right &&
(bleft.getOp() == OpOp2.GREATER ) )
+
HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 &&
+ left1 == right &&
(bleft.getOp() == OpOp2.GREATER ) )
{
BinaryOp binary =
HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX);
HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
HopRewriteUtils.cleanupUnreferenced(bop, left);
hi = binary;
applied = true;
-
+
LOG.debug("Applied
fuseBinarySubDAGToUnaryOperation-max0a");
}
}
@@ -1344,23 +1389,23 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
BinaryOp bright = (BinaryOp)right;
Hop right1 = bright.getInput().get(0);
Hop right2 = bright.getInput().get(1);
-
+
if( right2 instanceof LiteralOp &&
-
HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 &&
- right1 == left &&
bright.getOp() == OpOp2.GREATER )
+
HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 &&
+ right1 == left &&
bright.getOp() == OpOp2.GREATER )
{
BinaryOp binary =
HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX);
HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
HopRewriteUtils.cleanupUnreferenced(bop, left);
hi = binary;
applied= true;
-
+
LOG.debug("Applied
fuseBinarySubDAGToUnaryOperation-max0b");
}
}
}
}
-
+
return hi;
}
@@ -1373,68 +1418,68 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
{
Hop left = hi2.getInput().get(0);
Hop right = hi2.getInput().get(1);
-
+
//create new operators (incl refresh size
inside for transpose)
ReorgOp trans =
HopRewriteUtils.createTranspose(right);
BinaryOp mult =
HopRewriteUtils.createBinary(left, trans, OpOp2.MULT);
AggUnaryOp sum =
HopRewriteUtils.createSum(mult);
-
+
//rehang new subdag under parent node
HopRewriteUtils.replaceChildReference(parent,
hi, sum, pos);
HopRewriteUtils.cleanupUnreferenced(hi, hi2);
hi = sum;
-
+
LOG.debug("Applied simplifyTraceMatrixMult");
- }
+ }
}
-
+
return hi;
}
-
- private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int
pos)
+
+ private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos)
{
//e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1]
- if( hi instanceof IndexingOp
- && ((IndexingOp)hi).isRowLowerEqualsUpper()
- && ((IndexingOp)hi).isColLowerEqualsUpper()
- && hi.getInput().get(0).getParent().size()==1 //rix is
single mm consumer
- &&
HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0)) )
+ if( hi instanceof IndexingOp
+ && ((IndexingOp)hi).isRowLowerEqualsUpper()
+ && ((IndexingOp)hi).isColLowerEqualsUpper()
+ && hi.getInput().get(0).getParent().size()==1
//rix is single mm consumer
+ &&
HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0)) )
{
Hop mm = hi.getInput().get(0);
Hop X = mm.getInput().get(0);
Hop Y = mm.getInput().get(1);
Hop rowExpr = hi.getInput().get(1); //rl==ru
Hop colExpr = hi.getInput().get(3); //cl==cu
-
+
HopRewriteUtils.removeAllChildReferences(mm);
-
+
//create new indexing operations
- IndexingOp ix1 = new IndexingOp("tmp1",
DataType.MATRIX, ValueType.FP64, X,
+ IndexingOp ix1 = new IndexingOp("tmp1",
DataType.MATRIX, ValueType.FP64, X,
rowExpr, rowExpr, new LiteralOp(1),
HopRewriteUtils.createValueHop(X, false), true, false);
ix1.setBlocksize(X.getBlocksize());
ix1.refreshSizeInformation();
- IndexingOp ix2 = new IndexingOp("tmp2",
DataType.MATRIX, ValueType.FP64, Y,
+ IndexingOp ix2 = new IndexingOp("tmp2",
DataType.MATRIX, ValueType.FP64, Y,
new LiteralOp(1),
HopRewriteUtils.createValueHop(Y, true), colExpr, colExpr, false, true);
ix2.setBlocksize(Y.getBlocksize());
ix2.refreshSizeInformation();
-
+
//rewire matrix mult over ix1 and ix2
HopRewriteUtils.addChildReference(mm, ix1, 0);
HopRewriteUtils.addChildReference(mm, ix2, 1);
mm.refreshSizeInformation();
-
+
hi = mm;
-
+
LOG.debug("Applied simplifySlicedMatrixMult");
}
-
+
return hi;
}
-
+
private static Hop simplifyListIndexing(Hop hi) {
//e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
if( hi instanceof IndexingOp && hi.getDataType().isList()
- && !(hi.getInput(4) instanceof LiteralOp) )
+ && !(hi.getInput(4) instanceof LiteralOp) )
{
HopRewriteUtils.replaceChildReference(hi,
hi.getInput(4), new LiteralOp(1));
LOG.debug("Applied simplifyListIndexing (line
"+hi.getBeginLine()+").");
@@ -1442,17 +1487,17 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
return hi;
}
- private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos)
+ private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos)
{
//order(matrix(7), indexreturn=FALSE) -> matrix(7)
//order(matrix(7), indexreturn=TRUE) -> seq(1,nrow(X),1)
if( hi instanceof ReorgOp &&
((ReorgOp)hi).getOp()==ReOrgOp.SORT ) //order
{
Hop hi2 = hi.getInput().get(0);
-
+
if( hi2 instanceof DataGenOp &&
((DataGenOp)hi2).getOp()==OpOpDG.RAND
- && ((DataGenOp)hi2).hasConstantValue()
- && hi.getInput().get(3) instanceof LiteralOp )
//known indexreturn
+ && ((DataGenOp)hi2).hasConstantValue()
+ && hi.getInput().get(3) instanceof
LiteralOp ) //known indexreturn
{
if(
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3)) )
{
@@ -1462,7 +1507,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
hi = seq;
-
+
LOG.debug("Applied
simplifyConstantSort1.");
}
else
@@ -1471,30 +1516,30 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
hi = hi2;
-
+
LOG.debug("Applied
simplifyConstantSort2.");
}
- }
+ }
}
-
+
return hi;
}
-
- private static Hop simplifyOrderedSort(Hop parent, Hop hi, int pos)
+
+ private static Hop simplifyOrderedSort(Hop parent, Hop hi, int pos)
{
//order(seq(2,N+1,1), indexreturn=FALSE) -> matrix(7)
//order(seq(2,N+1,1), indexreturn=TRUE) ->
seq(1,N,1)/seq(N,1,-1)
if( hi instanceof ReorgOp &&
((ReorgOp)hi).getOp()==ReOrgOp.SORT ) //order
{
Hop hi2 = hi.getInput().get(0);
-
+
if( hi2 instanceof DataGenOp &&
((DataGenOp)hi2).getOp()==OpOpDG.SEQ )
{
Hop incr =
hi2.getInput().get(((DataGenOp)hi2).getParamIndex(Statement.SEQ_INCR));
//check for known ascending ordering and known
indexreturn
if( incr instanceof LiteralOp &&
HopRewriteUtils.getDoubleValue((LiteralOp)incr)==1
- && hi.getInput().get(2) instanceof
LiteralOp //decreasing
- && hi.getInput().get(3) instanceof
LiteralOp ) //indexreturn
+ && hi.getInput().get(2)
instanceof LiteralOp //decreasing
+ && hi.getInput().get(3)
instanceof LiteralOp ) //indexreturn
{
if(
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3)) ) //IXRET,
ASC/DESC
{
@@ -1505,7 +1550,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
hi = seq;
-
+
LOG.debug("Applied
simplifyOrderedSort1.");
}
else if(
!HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)) ) //DATA, ASC
@@ -1514,44 +1559,44 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
hi = hi2;
-
+
LOG.debug("Applied
simplifyOrderedSort2.");
}
}
}
}
-
+
return hi;
}
- private static Hop fuseOrderOperationChain(Hop hi)
+ private static Hop fuseOrderOperationChain(Hop hi)
{
//order(order(X,2),1) -> order(X, (12)),
if( HopRewriteUtils.isReorg(hi, ReOrgOp.SORT)
- && hi.getInput().get(1) instanceof LiteralOp //scalar by
- && hi.getInput().get(2) instanceof LiteralOp //scalar
desc
- &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) ) //not ixret
- {
+ && hi.getInput().get(1) instanceof LiteralOp
//scalar by
+ && hi.getInput().get(2) instanceof LiteralOp
//scalar desc
+ &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) ) //not ixret
+ {
LiteralOp by = (LiteralOp) hi.getInput().get(1);
boolean desc =
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2));
-
+
//find chain of order operations with same desc/ixret
configuration and single consumers
Set<String> probe = new HashSet<>();
ArrayList<LiteralOp> byList = new ArrayList<>();
byList.add(by); probe.add(by.getStringValue());
Hop input = hi.getInput().get(0);
while( HopRewriteUtils.isReorg(input, ReOrgOp.SORT)
- && input.getInput().get(1) instanceof LiteralOp
//scalar by
- &&
!probe.contains(input.getInput().get(1).getName())
- &&
HopRewriteUtils.isLiteralOfValue(input.getInput().get(2), desc)
- &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false)
- && input.getParent().size() == 1 )
+ && input.getInput().get(1) instanceof
LiteralOp //scalar by
+ &&
!probe.contains(input.getInput().get(1).getName())
+ &&
HopRewriteUtils.isLiteralOfValue(input.getInput().get(2), desc)
+ &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false)
+ && input.getParent().size() == 1 )
{
byList.add((LiteralOp)input.getInput().get(1));
probe.add(input.getInput().get(1).getName());
input = input.getInput().get(0);
}
-
+
//merge order chain if at least two instances
if( byList.size() >= 2 ) {
//create new order operations
@@ -1561,7 +1606,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
inputs.add(new LiteralOp(desc));
inputs.add(new LiteralOp(false));
Hop hnew = HopRewriteUtils.createReorg(inputs,
ReOrgOp.SORT);
-
+
//cleanup references recursively
Hop current = hi;
while(current != input ) {
@@ -1569,86 +1614,86 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
HopRewriteUtils.removeAllChildReferences(current);
current = tmp;
}
-
+
//rewire all parents (avoid anomalies with
replicated datagen)
List<Hop> parents = new
ArrayList<>(hi.getParent());
for( Hop p : parents )
HopRewriteUtils.replaceChildReference(p, hi, hnew);
-
+
hi = hnew;
LOG.debug("Applied fuseOrderOperationChain
(line "+hi.getBeginLine()+").");
}
}
-
+
return hi;
}
-
+
/**
* Patterns: t(t(A)%*%t(B)+C) -> B%*%A+t(C)
- *
+ *
* @param parent parent high-level operator
* @param hi high-level operator
* @param pos position
* @return high-level operator
*/
- private static Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop
hi, int pos)
+ private static Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop
hi, int pos)
{
if( HopRewriteUtils.isTransposeOperation(hi)
- && hi.getInput().get(0) instanceof BinaryOp
//basic binary
- &&
((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations())
+ && hi.getInput().get(0) instanceof BinaryOp
//basic binary
+ &&
((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations())
{
Hop left = hi.getInput().get(0).getInput().get(0);
Hop C = hi.getInput().get(0).getInput().get(1);
-
+
//check matrix mult and both inputs transposes w/
single consumer
if( left instanceof AggBinaryOp &&
C.getDataType().isMatrix()
- &&
HopRewriteUtils.isTransposeOperation(left.getInput().get(0))
- && left.getInput().get(0).getParent().size()==1
- &&
HopRewriteUtils.isTransposeOperation(left.getInput().get(1))
- && left.getInput().get(1).getParent().size()==1
)
+ &&
HopRewriteUtils.isTransposeOperation(left.getInput().get(0))
+ &&
left.getInput().get(0).getParent().size()==1
+ &&
HopRewriteUtils.isTransposeOperation(left.getInput().get(1))
+ &&
left.getInput().get(1).getParent().size()==1 )
{
Hop A =
left.getInput().get(0).getInput().get(0);
Hop B =
left.getInput().get(1).getInput().get(0);
-
+
AggBinaryOp abop =
HopRewriteUtils.createMatrixMultiply(B, A);
ReorgOp rop =
HopRewriteUtils.createTranspose(C);
BinaryOp bop =
HopRewriteUtils.createBinary(abop, rop, OpOp2.PLUS);
-
+
HopRewriteUtils.replaceChildReference(parent,
hi, bop, pos);
-
+
hi = bop;
LOG.debug("Applied
simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+").");
}
}
-
+
return hi;
}
-
+
// Patterns: X + (X==0) * s -> replace(X, 0, s)
- private static Hop simplifyReplaceZeroOperation(Hop parent, Hop hi, int
pos)
+ private static Hop simplifyReplaceZeroOperation(Hop parent, Hop hi, int
pos)
{
if( HopRewriteUtils.isBinary(hi, OpOp2.PLUS) &&
hi.getInput().get(0).isMatrix()
- && HopRewriteUtils.isBinary(hi.getInput().get(1),
OpOp2.MULT)
- && hi.getInput().get(1).getInput().get(1).isScalar()
- &&
HopRewriteUtils.isBinaryMatrixScalar(hi.getInput().get(1).getInput().get(0),
OpOp2.EQUAL, 0)
- &&
hi.getInput().get(1).getInput().get(0).getInput().contains(hi.getInput().get(0))
)
+ &&
HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT)
+ &&
hi.getInput().get(1).getInput().get(1).isScalar()
+ &&
HopRewriteUtils.isBinaryMatrixScalar(hi.getInput().get(1).getInput().get(0),
OpOp2.EQUAL, 0)
+ &&
hi.getInput().get(1).getInput().get(0).getInput().contains(hi.getInput().get(0))
)
{
LinkedHashMap<String, Hop> args = new LinkedHashMap<>();
args.put("target", hi.getInput().get(0));
args.put("pattern", new LiteralOp(0));
args.put("replacement",
hi.getInput().get(1).getInput().get(1));
Hop replace =
HopRewriteUtils.createParameterizedBuiltinOp(
- hi.getInput().get(0), args,
ParamBuiltinOp.REPLACE);
+ hi.getInput().get(0), args,
ParamBuiltinOp.REPLACE);
HopRewriteUtils.replaceChildReference(parent, hi,
replace, pos);
hi = replace;
LOG.debug("Applied simplifyReplaceZeroOperation (line
"+hi.getBeginLine()+").");
}
return hi;
}
-
+
/**
* Pattners: t(t(X)) -> X, rev(rev(X)) -> X
- *
+ *
* @param parent parent high-level operator
* @param hi high-level operator
* @param pos position
@@ -1657,7 +1702,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi,
int pos)
{
ReOrgOp[] lookup = new ReOrgOp[]{ReOrgOp.TRANS, ReOrgOp.REV};
-
+
if( hi instanceof ReorgOp &&
HopRewriteUtils.isValidOp(((ReorgOp)hi).getOp(), lookup) ) //first reorg
{
ReOrgOp firstOp = ((ReorgOp)hi).getOp();
@@ -1669,15 +1714,15 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
HopRewriteUtils.replaceChildReference(parent,
hi, hi3, pos);
HopRewriteUtils.cleanupUnreferenced(hi, hi2);
hi = hi3;
-
+
LOG.debug("Applied
removeUnecessaryReorgOperation.");
}
}
-
+
return hi;
}
-
- /*
+
+ /*
* Eliminate RemoveEmpty for SUM, SUM_SQ, and NNZ (number of non-zeros)
*/
private static Hop removeUnnecessaryRemoveEmpty(Hop parent, Hop hi, int
pos)
@@ -1688,14 +1733,14 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//rowSums(removeEmpty(target=X,margin="cols")) -> rowSums(X)
//colSums(removeEmpty(target=X,margin="rows")) -> colSums(X)
if( (HopRewriteUtils.isSum(hi) || HopRewriteUtils.isSumSq(hi))
- && HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0))
- && hi.getInput().get(0).getParent().size() == 1 )
+ &&
HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0))
+ && hi.getInput().get(0).getParent().size() == 1
)
{
AggUnaryOp agg = (AggUnaryOp)hi;
ParameterizedBuiltinOp rmEmpty =
(ParameterizedBuiltinOp) hi.getInput().get(0);
boolean needRmEmpty = (agg.getDirection() ==
Direction.Row && HopRewriteUtils.isRemoveEmpty(rmEmpty, true))
- || (agg.getDirection() == Direction.Col &&
HopRewriteUtils.isRemoveEmpty(rmEmpty, false));
-
+ || (agg.getDirection() == Direction.Col
&& HopRewriteUtils.isRemoveEmpty(rmEmpty, false));
+
if (rmEmpty.getParameterHop("select") == null &&
!needRmEmpty) {
Hop input = rmEmpty.getTargetHop();
if( input != null ) {
@@ -1704,11 +1749,11 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
}
}
}
-
+
//check if nrow is called on the output of removeEmpty
if( HopRewriteUtils.isUnary(hi, OpOp1.NROW)
- && HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0),
true)
- && hi.getInput().get(0).getParent().size() == 1 )
+ &&
HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0), true)
+ && hi.getInput().get(0).getParent().size() == 1
)
{
ParameterizedBuiltinOp rm = (ParameterizedBuiltinOp)
hi.getInput().get(0);
//obtain optional select vector or input if col vector
@@ -1718,9 +1763,9 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
//NOTE: part of static rewrites despite size dependence
for phase
//ordering before rewrite for DAG splits after
table/removeEmpty
Hop input = (rm.getParameterHop("select") != null) ?
- rm.getParameterHop("select") :
- (rm.getDim2() == 1) ? rm.getTargetHop() : null;
-
+ rm.getParameterHop("select") :
+ (rm.getDim2() == 1) ? rm.getTargetHop()
: null;
+
//create new expression w/o rmEmpty if applicable
if( input != null ) {
HopRewriteUtils.removeAllChildReferences(rm);
@@ -1734,32 +1779,32 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
}
}
}
-
+
return hi;
}
- private static Hop removeUnnecessaryMinus(Hop parent, Hop hi, int pos)
+ private static Hop removeUnnecessaryMinus(Hop parent, Hop hi, int pos)
{
- if( hi.getDataType() == DataType.MATRIX && hi instanceof
BinaryOp
- && ((BinaryOp)hi).getOp()==OpOp2.MINUS
//first minus
- && hi.getInput().get(0) instanceof LiteralOp &&
((LiteralOp)hi.getInput().get(0)).getDoubleValue()==0 )
+ if( hi.getDataType() == DataType.MATRIX && hi instanceof
BinaryOp
+ && ((BinaryOp)hi).getOp()==OpOp2.MINUS
//first minus
+ && hi.getInput().get(0) instanceof LiteralOp &&
((LiteralOp)hi.getInput().get(0)).getDoubleValue()==0 )
{
Hop hi2 = hi.getInput().get(1);
- if( hi2.getDataType() == DataType.MATRIX && hi2
instanceof BinaryOp
- && ((BinaryOp)hi2).getOp()==OpOp2.MINUS
//second minus
- && hi2.getInput().get(0) instanceof LiteralOp
&& ((LiteralOp)hi2.getInput().get(0)).getDoubleValue()==0 )
-
+ if( hi2.getDataType() == DataType.MATRIX && hi2
instanceof BinaryOp
+ && ((BinaryOp)hi2).getOp()==OpOp2.MINUS
//second minus
+ && hi2.getInput().get(0) instanceof
LiteralOp && ((LiteralOp)hi2.getInput().get(0)).getDoubleValue()==0 )
+
{
Hop hi3 = hi2.getInput().get(1);
//remove unnecessary chain of -(-())
HopRewriteUtils.replaceChildReference(parent,
hi, hi3, pos);
HopRewriteUtils.cleanupUnreferenced(hi, hi2);
hi = hi3;
-
+
LOG.debug("Applied removeUnecessaryMinus");
}
}
-
+
return hi;
}
@@ -1768,148 +1813,148 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
if( hi instanceof ParameterizedBuiltinOp &&
((ParameterizedBuiltinOp)hi).getOp()==ParamBuiltinOp.GROUPEDAGG ) //aggregate
{
ParameterizedBuiltinOp phi = (ParameterizedBuiltinOp)hi;
-
+
if( phi.isCountFunction() //aggregate(fn="count")
- && phi.getTargetHop().getDim2()==1 ) //only for
vector
+ && phi.getTargetHop().getDim2()==1 )
//only for vector
{
HashMap<String, Integer> params =
phi.getParamIndexMap();
int ix1 = params.get(Statement.GAGG_TARGET);
int ix2 = params.get(Statement.GAGG_GROUPS);
-
+
//check for unnecessary memory consumption for
"count"
- if( ix1 != ix2 &&
phi.getInput().get(ix1)!=phi.getInput().get(ix2) )
+ if( ix1 != ix2 &&
phi.getInput().get(ix1)!=phi.getInput().get(ix2) )
{
Hop th = phi.getInput().get(ix1);
Hop gh = phi.getInput().get(ix2);
-
+
HopRewriteUtils.replaceChildReference(hi, th, gh, ix1);
-
- LOG.debug("Applied
simplifyGroupedAggregateCount");
+
+ LOG.debug("Applied
simplifyGroupedAggregateCount");
}
}
}
-
+
return hi;
}
-
- private static Hop fuseMinusNzBinaryOperation(Hop parent, Hop hi, int
pos)
+
+ private static Hop fuseMinusNzBinaryOperation(Hop parent, Hop hi, int
pos)
{
//pattern X - (s * ppred(X,0,!=)) -> X -nz s
//note: this is done as a hop rewrite in order to significantly
reduce the
//memory estimate for X - tmp if X is sparse
if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS)
- && hi.getInput().get(0).getDataType()==DataType.MATRIX
- && hi.getInput().get(1).getDataType()==DataType.MATRIX
- && HopRewriteUtils.isBinary(hi.getInput().get(1),
OpOp2.MULT) )
+ &&
hi.getInput().get(0).getDataType()==DataType.MATRIX
+ &&
hi.getInput().get(1).getDataType()==DataType.MATRIX
+ &&
HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) )
{
Hop X = hi.getInput().get(0);
Hop s = hi.getInput().get(1).getInput().get(0);
Hop pred = hi.getInput().get(1).getInput().get(1);
-
+
if( s.getDataType()==DataType.SCALAR &&
pred.getDataType()==DataType.MATRIX
- && HopRewriteUtils.isBinary(pred,
OpOp2.NOTEQUAL)
- && pred.getInput().get(0) == X //depend on
common subexpression elimination
- && pred.getInput().get(1) instanceof LiteralOp
- &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 )
+ && HopRewriteUtils.isBinary(pred,
OpOp2.NOTEQUAL)
+ && pred.getInput().get(0) == X //depend
on common subexpression elimination
+ && pred.getInput().get(1) instanceof
LiteralOp
+ &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 )
{
- Hop hnew = HopRewriteUtils.createBinary(X, s,
OpOp2.MINUS_NZ);
-
+ Hop hnew = HopRewriteUtils.createBinary(X, s,
OpOp2.MINUS_NZ);
+
//relink new hop into original position
HopRewriteUtils.replaceChildReference(parent,
hi, hnew, pos);
hi = hnew;
-
- LOG.debug("Applied fuseMinusNzBinaryOperation
(line "+hi.getBeginLine()+")");
- }
+
+ LOG.debug("Applied fuseMinusNzBinaryOperation
(line "+hi.getBeginLine()+")");
+ }
}
-
+
return hi;
}
-
- private static Hop fuseLogNzUnaryOperation(Hop parent, Hop hi, int pos)
+
+ private static Hop fuseLogNzUnaryOperation(Hop parent, Hop hi, int pos)
{
//pattern ppred(X,0,"!=")*log(X) -> log_nz(X)
//note: this is done as a hop rewrite in order to significantly
reduce the
//memory estimate and to prevent dense intermediates if X is
ultra sparse
if( HopRewriteUtils.isBinary(hi, OpOp2.MULT)
- && hi.getInput().get(0).getDataType()==DataType.MATRIX
- && hi.getInput().get(1).getDataType()==DataType.MATRIX
- && HopRewriteUtils.isUnary(hi.getInput().get(1),
OpOp1.LOG) )
+ &&
hi.getInput().get(0).getDataType()==DataType.MATRIX
+ &&
hi.getInput().get(1).getDataType()==DataType.MATRIX
+ &&
HopRewriteUtils.isUnary(hi.getInput().get(1), OpOp1.LOG) )
{
Hop pred = hi.getInput().get(0);
Hop X = hi.getInput().get(1).getInput().get(0);
-
+
if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL)
- && pred.getInput().get(0) == X //depend on
common subexpression elimination
- && pred.getInput().get(1) instanceof LiteralOp
- &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 )
+ && pred.getInput().get(0) == X //depend
on common subexpression elimination
+ && pred.getInput().get(1) instanceof
LiteralOp
+ &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 )
{
Hop hnew = HopRewriteUtils.createUnary(X,
OpOp1.LOG_NZ);
-
+
//relink new hop into original position
HopRewriteUtils.replaceChildReference(parent,
hi, hnew, pos);
hi = hnew;
-
- LOG.debug("Applied fuseLogNzUnaryOperation
(line "+hi.getBeginLine()+").");
- }
+
+ LOG.debug("Applied fuseLogNzUnaryOperation
(line "+hi.getBeginLine()+").");
+ }
}
-
+
return hi;
}
- private static Hop fuseLogNzBinaryOperation(Hop parent, Hop hi, int
pos)
+ private static Hop fuseLogNzBinaryOperation(Hop parent, Hop hi, int pos)
{
//pattern ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
//note: this is done as a hop rewrite in order to significantly
reduce the
//memory estimate and to prevent dense intermediates if X is
ultra sparse
if( HopRewriteUtils.isBinary(hi, OpOp2.MULT)
- && hi.getInput().get(0).getDataType()==DataType.MATRIX
- && hi.getInput().get(1).getDataType()==DataType.MATRIX
- && HopRewriteUtils.isBinary(hi.getInput().get(1),
OpOp2.LOG) )
+ &&
hi.getInput().get(0).getDataType()==DataType.MATRIX
+ &&
hi.getInput().get(1).getDataType()==DataType.MATRIX
+ &&
HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.LOG) )
{
Hop pred = hi.getInput().get(0);
Hop X = hi.getInput().get(1).getInput().get(0);
Hop log = hi.getInput().get(1).getInput().get(1);
-
+
if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL)
- && pred.getInput().get(0) == X //depend on
common subexpression elimination
- && pred.getInput().get(1) instanceof LiteralOp
- &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 )
+ && pred.getInput().get(0) == X //depend
on common subexpression elimination
+ && pred.getInput().get(1) instanceof
LiteralOp
+ &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 )
{
Hop hnew = HopRewriteUtils.createBinary(X, log,
OpOp2.LOG_NZ);
-
+
//relink new hop into original position
HopRewriteUtils.replaceChildReference(parent,
hi, hnew, pos);
hi = hnew;
-
- LOG.debug("Applied fuseLogNzBinaryOperation
(line "+hi.getBeginLine()+")");
- }
+
+ LOG.debug("Applied fuseLogNzBinaryOperation
(line "+hi.getBeginLine()+")");
+ }
}
-
+
return hi;
}
- private static Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos)
+ private static Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos)
{
//pattern: outer(v, t(seq(1,m)), "==") -> rexpand(v, max=m,
dir=row, ignore=true, cast=false)
//note: this rewrite supports both left/right sequence
-
+
if( HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) &&
((BinaryOp)hi).isOuter() )
{
if( (
HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)) //pattern a:
outer(v, t(seq(1,m)), "==")
- &&
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0)))
- ||
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0))) //pattern b:
outer(seq(1,m), t(v) "==")
+ &&
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0)))
+ ||
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0))) //pattern b:
outer(seq(1,m), t(v) "==")
{
//determine variable parameters for pattern a/b
boolean isPatternB =
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0));
boolean isTransposeRight =
HopRewriteUtils.isTransposeOperation(hi.getInput().get(1));
- Hop trgt = isPatternB ? (isTransposeRight ?
+ Hop trgt = isPatternB ? (isTransposeRight ?
hi.getInput().get(1).getInput().get(0) : //get v from t(v)
HopRewriteUtils.createTranspose(hi.getInput().get(1)) ) : //create v via t(v')
hi.getInput().get(0);
//get v directly
Hop seq = isPatternB ?
hi.getInput().get(0) :
hi.getInput().get(1).getInput().get(0);
String direction =
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) ? "rows" : "cols";
-
+
//setup input parameter hops
LinkedHashMap<String,Hop> inputargs = new
LinkedHashMap<>();
inputargs.put("target", trgt);
@@ -1917,34 +1962,34 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
inputargs.put("dir", new LiteralOp(direction));
inputargs.put("ignore", new LiteralOp(true));
inputargs.put("cast", new LiteralOp(false));
-
+
//create new hop
ParameterizedBuiltinOp pbop = HopRewriteUtils
- .createParameterizedBuiltinOp(trgt,
inputargs, ParamBuiltinOp.REXPAND);
-
+
.createParameterizedBuiltinOp(trgt, inputargs, ParamBuiltinOp.REXPAND);
+
//relink new hop into original position
HopRewriteUtils.replaceChildReference(parent,
hi, pbop, pos);
hi = pbop;
-
+
LOG.debug("Applied simplifyOuterSeqExpand (line
"+hi.getBeginLine()+")");
}
}
-
+
return hi;
}
-
+
private static Hop simplifyBinaryComparisonChain(Hop parent, Hop hi,
int pos) {
- if( HopRewriteUtils.isBinaryPPred(hi)
- &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 0d, 1d)
- && HopRewriteUtils.isBinaryPPred(hi.getInput().get(0)) )
+ if( HopRewriteUtils.isBinaryPPred(hi)
+ &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 0d, 1d)
+ &&
HopRewriteUtils.isBinaryPPred(hi.getInput().get(0)) )
{
BinaryOp bop = (BinaryOp) hi;
BinaryOp bop2 = (BinaryOp) hi.getInput().get(0);
boolean one =
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 1);
-
+
//pattern: outer(v1,v2,"!=") == 1 -> outer(v1,v2,"!=")
if( (one && bop.getOp() == OpOp2.EQUAL)
- || (!one && bop.getOp() == OpOp2.NOTEQUAL) )
+ || (!one && bop.getOp() ==
OpOp2.NOTEQUAL) )
{
HopRewriteUtils.replaceChildReference(parent,
bop, bop2, pos);
HopRewriteUtils.cleanupUnreferenced(bop);
@@ -1955,62 +2000,62 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
else if( !one && bop.getOp() == OpOp2.EQUAL ) {
OpOp2 optr = bop2.getComplementPPredOperation();
BinaryOp tmp =
HopRewriteUtils.createBinary(bop2.getInput().get(0),
- bop2.getInput().get(1), optr,
bop2.isOuter());
+ bop2.getInput().get(1), optr,
bop2.isOuter());
HopRewriteUtils.replaceChildReference(parent,
bop, tmp, pos);
HopRewriteUtils.cleanupUnreferenced(bop, bop2);
hi = tmp;
LOG.debug("Applied
simplifyBinaryComparisonChain0 (line "+hi.getBeginLine()+")");
}
}
-
+
return hi;
}
-
+
private static Hop simplifyCumsumColOrFullAggregates(Hop hi) {
//pattern: colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
if( (HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.Col)
- || HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM,
Direction.RowCol))
- && HopRewriteUtils.isUnary(hi.getInput().get(0),
OpOp1.CUMSUM)
- && hi.getInput().get(0).getParent().size()==1)
+ || HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM,
Direction.RowCol))
+ &&
HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM)
+ && hi.getInput().get(0).getParent().size()==1)
{
Hop cumsumX = hi.getInput().get(0);
Hop X = cumsumX.getInput().get(0);
Hop mult = HopRewriteUtils.createBinary(X,
- HopRewriteUtils.createSeqDataGenOp(X, false),
OpOp2.MULT);
+ HopRewriteUtils.createSeqDataGenOp(X,
false), OpOp2.MULT);
HopRewriteUtils.replaceChildReference(hi, cumsumX,
mult);
HopRewriteUtils.removeAllChildReferences(cumsumX);
LOG.debug("Applied simplifyCumsumColOrFullAggregates
(line "+hi.getBeginLine()+")");
}
return hi;
}
-
+
private static Hop simplifyCumsumReverse(Hop parent, Hop hi, int pos) {
//pattern: rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
if( HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
- && HopRewriteUtils.isUnary(hi.getInput().get(0),
OpOp1.CUMSUM)
- && hi.getInput().get(0).getParent().size()==1
- &&
HopRewriteUtils.isReorg(hi.getInput().get(0).getInput().get(0), ReOrgOp.REV)
- &&
hi.getInput().get(0).getInput().get(0).getParent().size()==1)
+ &&
HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM)
+ && hi.getInput().get(0).getParent().size()==1
+ &&
HopRewriteUtils.isReorg(hi.getInput().get(0).getInput().get(0), ReOrgOp.REV)
+ &&
hi.getInput().get(0).getInput().get(0).getParent().size()==1)
{
Hop cumsumX = hi.getInput().get(0);
Hop revX = cumsumX.getInput().get(0);
Hop X = revX.getInput().get(0);
Hop plus = HopRewriteUtils.createBinary(X,
HopRewriteUtils
- .createAggUnaryOp(X, AggOp.SUM, Direction.Col),
OpOp2.PLUS);
+ .createAggUnaryOp(X, AggOp.SUM,
Direction.Col), OpOp2.PLUS);
Hop minus = HopRewriteUtils.createBinary(plus,
- HopRewriteUtils.createUnary(X, OpOp1.CUMSUM),
OpOp2.MINUS);
+ HopRewriteUtils.createUnary(X,
OpOp1.CUMSUM), OpOp2.MINUS);
HopRewriteUtils.replaceChildReference(parent, hi,
minus, pos);
HopRewriteUtils.cleanupUnreferenced(hi, cumsumX, revX);
-
+
hi = minus;
LOG.debug("Applied simplifyCumsumReverse (line
"+hi.getBeginLine()+")");
}
return hi;
}
-
+
private static Hop simplifyNotOverComparisons(Hop parent, Hop hi, int
pos){
if(HopRewriteUtils.isUnary(hi, OpOp1.NOT) && hi.getInput(0)
instanceof BinaryOp
- && hi.getInput(0).getParent().size() == 1) //NOT is
only consumer
+ && hi.getInput(0).getParent().size() == 1)
//NOT is only consumer
{
Hop binaryOperator = hi.getInput(0);
Hop A = binaryOperator.getInput(0);
@@ -2041,66 +2086,66 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
return hi;
}
-
+
private static Hop fixNonScalarPrint(Hop parent, Hop hi, int pos) {
if(HopRewriteUtils.isUnary(parent, OpOp1.PRINT) &&
!hi.getDataType().isScalar()) {
LinkedHashMap<String, Hop> args = new LinkedHashMap<>();
args.put("target", hi);
Hop newHop =
HopRewriteUtils.createParameterizedBuiltinOp(
- hi, args, ParamBuiltinOp.TOSTRING);
+ hi, args, ParamBuiltinOp.TOSTRING);
HopRewriteUtils.replaceChildReference(parent, hi,
newHop, pos);
hi = newHop;
LOG.debug("Applied fixNonScalarPrint (line " +
hi.getBeginLine() + ")");
}
-
+
return hi;
}
-
+
/**
* NOTE: currently disabled since this rewrite is INVALID in the
* presence of NaNs (because (NaN!=NaN) is true).
- *
+ *
* @param parent parent high-level operator
* @param hi high-level operator
* @param pos position
* @return high-level operator
*/
@SuppressWarnings("unused")
- private static Hop removeUnecessaryPPred(Hop parent, Hop hi, int pos)
+ private static Hop removeUnecessaryPPred(Hop parent, Hop hi, int pos)
{
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
-
+
Hop datagen = null;
-
+
//ppred(X,X,"==") -> matrix(1,
rows=nrow(X),cols=nrow(Y))
if( left==right && bop.getOp()==OpOp2.EQUAL ||
bop.getOp()==OpOp2.GREATEREQUAL || bop.getOp()==OpOp2.LESSEQUAL )
datagen = HopRewriteUtils.createDataGenOp(left,
1);
-
+
//ppred(X,X,"!=") -> matrix(0,
rows=nrow(X),cols=nrow(Y))
if( left==right && bop.getOp()==OpOp2.NOTEQUAL ||
bop.getOp()==OpOp2.GREATER || bop.getOp()==OpOp2.LESS )
datagen = HopRewriteUtils.createDataGenOp(left,
0);
-
+
if( datagen != null ) {
HopRewriteUtils.replaceChildReference(parent,
hi, datagen, pos);
hi = datagen;
}
}
-
+
return hi;
}
-
+
private static void removeTWriteTReadPairs(ArrayList<Hop> roots) {
Iterator<Hop> iter = roots.iterator();
while(iter.hasNext()) {
Hop root = iter.next();
if( HopRewriteUtils.isData(root,
OpOpData.TRANSIENTWRITE)
- && HopRewriteUtils.isData(root.getInput(0),
OpOpData.TRANSIENTREAD)
- &&
root.getName().equals(root.getInput(0).getName())
- && !root.getInput(0).requiresCheckpoint())
+ &&
HopRewriteUtils.isData(root.getInput(0), OpOpData.TRANSIENTREAD)
+ &&
root.getName().equals(root.getInput(0).getName())
+ &&
!root.getInput(0).requiresCheckpoint())
{
iter.remove();
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteConstantConjunctionDisjunctionTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteConstantConjunctionDisjunctionTest.java
new file mode 100644
index 0000000000..c42c7e0992
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteConstantConjunctionDisjunctionTest.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+public class RewriteConstantConjunctionDisjunctionTest extends
AutomatedTestBase {
+
+ private static final String TEST_NAME_AND =
"RewriteBooleanSimplificationTestAnd";
+ private static final String TEST_NAME_OR =
"RewriteBooleanSimplificationTestOr";
+ private static final String TEST_DIR = "functions/rewrite/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
RewriteConstantConjunctionDisjunctionTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME_AND, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_AND, new String[] {"R"}));
+ addTestConfiguration(TEST_NAME_OR, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_OR, new String[] {"R"}));
+ }
+
+ @Test
+ public void testBooleanRewriteAnd() {
+ testRewriteBooleanSimplification(TEST_NAME_AND, ExecType.CP,
0.0);
+ }
+
+ @Test
+ public void testBooleanRewriteOr() {
+ testRewriteBooleanSimplification(TEST_NAME_OR, ExecType.CP,
1.0);
+ }
+
+ private void testRewriteBooleanSimplification(String testname, ExecType
et, double expected) {
+ ExecMode platformOld = setExecMode(et);
+
+ try {
+ TestConfiguration config =
getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[]{"-stats", "-explain",
"-args", output("R")};
+
+ runTest(true, false, null, -1);
+
+ Double ret = readDMLMatrixFromOutputDir("R").get(new
CellIndex(1,1));
+ if( ret == null )
+ ret = 0d;
+ Assert.assertEquals(
+ "Expected boolean simplification result does
not match",
+ expected, ret, 0.0001);
+ Assert.assertFalse(heavyHittersContainsString("!"));
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+}
diff --git
a/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml
b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml
new file mode 100644
index 0000000000..16a6001f12
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Constant conjunction test: a contradiction (a AND NOT a) is FALSE for any 0/1 value a.
+a = as.scalar(rand(rows=1,cols=1,seed=1)<0.5);  # 0/1 scalar from comparing one seeded random value with 0.5
+result1 = as.matrix(!a & a);  # negated operand on the left
+result2 = as.matrix(a & !a);  # negated operand on the right
+result = result1 & result2;  # driver expects 0; AND masks a single wrong operand, so the driver also checks no "!" ran
+write(result, $1)  # $1 is the output path passed via -args by the JUnit driver
+
diff --git
a/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml
b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml
new file mode 100644
index 0000000000..7aada615ca
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Constant disjunction test: a tautology (a OR NOT a) is TRUE for any 0/1 value a.
+a = as.scalar(rand(rows=1,cols=1,seed=1)<0.5);  # 0/1 scalar from comparing one seeded random value with 0.5
+result1 = as.matrix(!a | a);  # negated operand on the left
+result2 = as.matrix(a | !a);  # negated operand on the right
+result = result1 & result2;  # 1 only if both patterns yield TRUE; the JUnit driver expects 1
+write(result, $1)  # $1 is the output path passed via -args by the JUnit driver
+