Repository: incubator-systemml Updated Branches: refs/heads/master 4316efeba -> 2f7fa8d73
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 41459b4..2ae27c8 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -199,11 +199,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule if( dright.getOp()==DataGenMethod.RAND && dright.hasConstantValue() ) { Hop drightIn = dright.getInput().get(dright.getParamIndex(DataExpression.RAND_MIN)); - HopRewriteUtils.removeChildReference(bop, dright); - HopRewriteUtils.addChildReference(bop, drightIn, 1); - //cleanup if only consumer of intermediate - if( dright.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences( dright ); + HopRewriteUtils.replaceChildReference(bop, dright, drightIn, 1); + HopRewriteUtils.cleanupUnreferenced(dright); LOG.debug("Applied removeUnnecessaryVectorizeOperation1"); } @@ -217,11 +214,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule && (left.getDim1()==1 || right.getDim1()>1)) { Hop dleftIn = dleft.getInput().get(dleft.getParamIndex(DataExpression.RAND_MIN)); - HopRewriteUtils.removeChildReference(bop, dleft); - HopRewriteUtils.addChildReference(bop, dleftIn, 0); - //cleanup if only consumer of intermediate - if( dleft.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences( dleft ); + HopRewriteUtils.replaceChildReference(bop, dleft, dleftIn, 0); + HopRewriteUtils.cleanupUnreferenced(dleft); LOG.debug("Applied removeUnnecessaryVectorizeOperation2"); } @@ -264,8 +258,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { if( bop.getOp()==OpOp2.DIV || bop.getOp()==OpOp2.MULT ) { - HopRewriteUtils.removeChildReference(parent, bop); - HopRewriteUtils.addChildReference(parent, left, pos); + HopRewriteUtils.replaceChildReference(parent, bop, left, pos); hi = left; LOG.debug("Applied removeUnnecessaryBinaryOperation1 (line "+bop.getBeginLine()+")"); @@ -277,8 +270,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { if( bop.getOp()==OpOp2.MINUS ) { - HopRewriteUtils.removeChildReference(parent, bop); - HopRewriteUtils.addChildReference(parent, left, pos); + HopRewriteUtils.replaceChildReference(parent, bop, left, pos); hi = left; LOG.debug("Applied removeUnnecessaryBinaryOperation2 (line "+bop.getBeginLine()+")"); @@ -290,8 +282,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { if( bop.getOp()==OpOp2.MULT ) { - HopRewriteUtils.removeChildReference(parent, bop); - HopRewriteUtils.addChildReference(parent, right, pos); + HopRewriteUtils.replaceChildReference(parent, bop, right, pos); hi = right; LOG.debug("Applied removeUnnecessaryBinaryOperation3 (line "+bop.getBeginLine()+")"); @@ -306,8 +297,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule if( bop.getOp()==OpOp2.MULT ) { bop.setOp(OpOp2.MINUS); - HopRewriteUtils.removeChildReferenceByPos(bop, left, 0); - HopRewriteUtils.addChildReference(bop, new LiteralOp(0), 0); + HopRewriteUtils.replaceChildReference(bop, left, new LiteralOp(0), 0); hi = bop; LOG.debug("Applied removeUnnecessaryBinaryOperation4 (line "+bop.getBeginLine()+")"); @@ -380,13 +370,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //rewire all parents (avoid anomalies with replicated datagen) List<Hop> parents = new ArrayList<Hop>(bop.getParent()); - for( Hop p : parents ) { - int cpos = HopRewriteUtils.getChildReferencePos(p, bop); - HopRewriteUtils.removeChildReferenceByPos(p, bop, cpos); - HopRewriteUtils.addChildReference(p, gen, cpos); - //propagate potentially updated nnz=0 - p.refreshSizeInformation(); - } + for( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, bop, gen); hi = gen; LOG.debug("Applied fuseDatagenAndBinaryOperation1 (line "+bop.getBeginLine()+")."); @@ -417,13 +402,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //rewire all parents (avoid anomalies with replicated datagen) List<Hop> parents = new ArrayList<Hop>(bop.getParent()); - for( Hop p : parents ) { - int cpos = HopRewriteUtils.getChildReferencePos(p, bop); - HopRewriteUtils.removeChildReferenceByPos(p, bop, cpos); - HopRewriteUtils.addChildReference(p, gen, cpos); - //propagate potentially updated nnz=0 - p.refreshSizeInformation(); - } + for( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, bop, gen); hi = gen; LOG.debug("Applied fuseDatagenAndBinaryOperation2 (line "+bop.getBeginLine()+")."); @@ -472,13 +452,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //rewire all parents (avoid anomalies with replicated datagen) List<Hop> parents = new ArrayList<Hop>(bop.getParent()); - for( Hop p : parents ) { - int cpos = HopRewriteUtils.getChildReferencePos(p, bop); - HopRewriteUtils.removeChildReferenceByPos(p, bop, cpos); - HopRewriteUtils.addChildReference(p, inputGen, cpos); - //propagate potentially updated nnz=0 - p.refreshSizeInformation(); - } + for( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, bop, inputGen); hi = inputGen; LOG.debug("Applied fuseDatagenAndMinusOperation (line "+bop.getBeginLine()+")."); @@ -538,8 +513,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule } //patterns: (X>0)-(X<0) -> sign(X) else if( bop.getOp() == OpOp2.MINUS - && left instanceof BinaryOp && right instanceof BinaryOp - && ((BinaryOp)left).getOp()==OpOp2.GREATER && ((BinaryOp)right).getOp()==OpOp2.LESS + && HopRewriteUtils.isBinary(left, OpOp2.GREATER) + && HopRewriteUtils.isBinary(right, OpOp2.LESS) && left.getInput().get(0) == right.getInput().get(0) && left.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left.getInput().get(1))==0 @@ -547,15 +522,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule && HopRewriteUtils.getDoubleValue((LiteralOp)right.getInput().get(1))==0 ) { UnaryOp uop = HopRewriteUtils.createUnary(left.getInput().get(0), OpOp1.SIGN); - - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.removeAllChildReferences(hi); - HopRewriteUtils.addChildReference(parent, uop, pos); - if( left.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(left); - if( right.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(right); - + HopRewriteUtils.replaceChildReference(parent, hi, uop, pos); + HopRewriteUtils.cleanupUnreferenced(hi, left, right); hi = uop; LOG.debug("Applied simplifyBinaryToUnaryOperation3"); @@ -598,9 +566,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule && bop.getOp() == OpOp2.MINUS ) { bop.setOp(OpOp2.PLUS); - HopRewriteUtils.removeChildReferenceByPos(bop, right, 1); - HopRewriteUtils.addChildReference(bop, - HopRewriteUtils.createBinary(new LiteralOp(0), right, OpOp2.MINUS), 1); + HopRewriteUtils.replaceChildReference(bop, right, + HopRewriteUtils.createBinaryMinus(right), 1); LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line "+hi.getBeginLine()+")."); } } @@ -633,14 +600,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule && top.getInput().get(0).getDim1()==top.getInput().get(1).getDim1()) { ReorgOp rop = HopRewriteUtils.createReorg(hi.getInput().get(1), ReOrgOp.REV); - - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, rop, pos); - if( hi.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(hi); - if( top.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(top); - + HopRewriteUtils.replaceChildReference(parent, hi, rop, pos); + HopRewriteUtils.cleanupUnreferenced(hi, top); hi = rop; LOG.debug("Applied simplifyReverseOperation."); @@ -653,12 +614,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule private Hop simplifyMultiBinaryToBinaryOperation( Hop hi ) { //pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate) - if( hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MINUS + if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS) && hi.getDataType() == DataType.MATRIX && hi.getInput().get(0) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)hi.getInput().get(0))==1 - && hi.getInput().get(1) instanceof BinaryOp - && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT + && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) && hi.getInput().get(1).getParent().size() == 1 ) //single consumer { BinaryOp bop = (BinaryOp)hi; @@ -703,7 +663,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule && HopRewriteUtils.isValidOp(bop.getOp(), LOOKUP_VALID_DISTRIBUTIVE_BINARY) ) { Hop X = null; Hop Y = null; - if( left instanceof BinaryOp && ((BinaryOp)left).getOp()==OpOp2.MULT ) //(Y*X-X) -> (Y-1)*X + if( HopRewriteUtils.isBinary(left, OpOp2.MULT) ) //(Y*X-X) -> (Y-1)*X { Hop leftC1 = left.getInput().get(0); Hop leftC2 = left.getInput().get(1); @@ -717,10 +677,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule if( X != null ){ //rewrite 'binary +/-' HopRewriteUtils.removeChildReference(parent, hi); LiteralOp literal = new LiteralOp(1); - BinaryOp plus = new BinaryOp(right.getName(), right.getDataType(), right.getValueType(), bop.getOp(), Y, literal); - HopRewriteUtils.refreshOutputParameters(plus, right); - BinaryOp mult = new BinaryOp(left.getName(), left.getDataType(), left.getValueType(), OpOp2.MULT, plus, X); - HopRewriteUtils.refreshOutputParameters(mult, left); + BinaryOp plus = HopRewriteUtils.createBinary(Y, literal, bop.getOp()); + BinaryOp mult = HopRewriteUtils.createBinary(plus, X, OpOp2.MULT); HopRewriteUtils.addChildReference(parent, mult, pos); hi = mult; @@ -730,7 +688,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule } } - if( !applied && right instanceof BinaryOp && ((BinaryOp)right).getOp()==OpOp2.MULT ) //(X-Y*X) -> (1-Y)*X + if( !applied && HopRewriteUtils.isBinary(right, OpOp2.MULT) ) //(X-Y*X) -> (1-Y)*X { Hop rightC1 = right.getInput().get(0); Hop rightC2 = right.getInput().get(1); @@ -740,14 +698,10 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule Y = ( left == rightC1 ) ? rightC2 : rightC1; } if( X != null ){ //rewrite '+/- binary' - HopRewriteUtils.removeChildReference(parent, hi); LiteralOp literal = new LiteralOp(1); - BinaryOp plus = new BinaryOp(left.getName(), left.getDataType(), left.getValueType(), bop.getOp(), literal, Y); - HopRewriteUtils.refreshOutputParameters(plus, left); - BinaryOp mult = new BinaryOp(right.getName(), right.getDataType(), right.getValueType(), OpOp2.MULT, plus, X); - HopRewriteUtils.refreshOutputParameters(mult, right); - - HopRewriteUtils.addChildReference(parent, mult, pos); + BinaryOp plus = HopRewriteUtils.createBinary(literal, Y, bop.getOp()); + BinaryOp mult = HopRewriteUtils.createBinary(plus, X, OpOp2.MULT); + HopRewriteUtils.replaceChildReference(parent, hi, mult, pos); hi = mult; LOG.debug("Applied simplifyDistributiveBinaryOperation2"); @@ -797,14 +751,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule && (right2 instanceof AggBinaryOp) ) { //(X*(Y*op()) -> (X*Y)*op() - HopRewriteUtils.removeChildReference(parent, bop); - - BinaryOp bop3 = new BinaryOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, op, left, left2); - HopRewriteUtils.refreshOutputParameters(bop3, bop); - BinaryOp bop4 = new BinaryOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, op, bop3, right2); - HopRewriteUtils.refreshOutputParameters(bop4, bop2); - - HopRewriteUtils.addChildReference(parent, bop4, pos); + BinaryOp bop3 = HopRewriteUtils.createBinary(left, left2, op); + BinaryOp bop4 = HopRewriteUtils.createBinary(bop3, right2, op); + HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos); hi = bop4; applied = true; @@ -828,10 +777,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //((op()*X)*Y) -> op()*(X*Y) HopRewriteUtils.removeChildReference(parent, bop); - BinaryOp bop3 = new BinaryOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, op, right2, right); - HopRewriteUtils.refreshOutputParameters(bop3, bop2); - BinaryOp bop4 = new BinaryOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, op, left2, bop3); - HopRewriteUtils.refreshOutputParameters(bop4, bop); + BinaryOp bop3 = HopRewriteUtils.createBinary(right2, right, op); + BinaryOp bop4 = HopRewriteUtils.createBinary(left2, bop3, op); HopRewriteUtils.addChildReference(parent, bop4, pos); hi = bop4; @@ -871,7 +818,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos ) throws HopsException { - if( hi instanceof UnaryOp && ((UnaryOp)hi).getOp()==OpOp1.CAST_AS_SCALAR + if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR) && hi.getInput().get(0) instanceof BinaryOp ) { BinaryOp bin = (BinaryOp) hi.getInput().get(0); @@ -896,8 +843,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule } if( bout != null ) { - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, bout, pos); + HopRewriteUtils.replaceChildReference(parent, hi, bout, pos); LOG.debug("Applied simplifyBinaryMatrixScalarOperation."); } @@ -910,8 +856,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { if( hi instanceof AggUnaryOp && hi.getParent().size()==1 && (((AggUnaryOp) hi).getDirection()==Direction.Row || ((AggUnaryOp) hi).getDirection()==Direction.Col) - && hi.getInput().get(0) instanceof ReorgOp && hi.getInput().get(0).getParent().size()==1 - && ((ReorgOp)hi.getInput().get(0)).getOp()==ReOrgOp.TRANSPOSE + && HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1) && HopRewriteUtils.isValidOp(((AggUnaryOp) hi).getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) ) { AggUnaryOp uagg = (AggUnaryOp) hi; @@ -949,7 +894,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule // a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X) // probed at root node of b in above example // (with support for left or right scalar operations) - if( HopRewriteUtils.isTransposeOperation(hi) && hi.getParent().size()==1 + if( HopRewriteUtils.isTransposeOperation(hi, 1) && HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput().get(0)) && hi.getInput().get(0).getParent().size()==1) { @@ -982,10 +927,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule private Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) throws HopsException { //pattern: sum(lamda*X) -> lamda*sum(X) if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol - && ((AggUnaryOp)hi).getOp()==Hop.AggOp.SUM - && ((AggUnaryOp)hi).getInput().get(0) instanceof BinaryOp - && ((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.MULT - && hi.getInput().get(0).getParent().size() == 1 // only one parent which is the sum + && ((AggUnaryOp)hi).getOp()==Hop.AggOp.SUM // only one parent which is the sum + && HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.MULT, 1) && ((hi.getInput().get(0).getInput().get(0).getDataType()==DataType.SCALAR && hi.getInput().get(0).getInput().get(1).getDataType()==DataType.MATRIX) ||(hi.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(0).getInput().get(1).getDataType()==DataType.SCALAR))) { @@ -999,8 +942,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule AggUnaryOp aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.RowCol); Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, OpOp2.MULT); - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, bop, pos); + HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); LOG.debug("Applied pushdownSumBinaryMult."); return bop; @@ -1021,10 +963,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { //clear link unary-binary Hop input = uop.getInput().get(0); - HopRewriteUtils.removeAllChildReferences(hi); - - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, input, pos); + HopRewriteUtils.replaceChildReference(parent, hi, input, pos); + HopRewriteUtils.cleanupUnreferenced(hi); hi = input; LOG.debug("Applied simplifyUnaryPPredOperation."); @@ -1037,29 +977,25 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule private Hop simplifyTransposedAppend( Hop parent, Hop hi, int pos ) { //e.g., t(cbind(t(A),t(B))) --> rbind(A,B), t(rbind(t(A),t(B))) --> cbind(A,B) - if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.TRANSPOSE //t() rooted + if( HopRewriteUtils.isTransposeOperation(hi) //t() rooted && hi.getInput().get(0) instanceof BinaryOp && (((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.CBIND //append (cbind/rbind) || ((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.RBIND) && hi.getInput().get(0).getParent().size() == 1 ) //single consumer of append { BinaryOp bop = (BinaryOp)hi.getInput().get(0); - if( bop.getInput().get(0) instanceof ReorgOp //both inputs transpose ops - && ((ReorgOp)bop.getInput().get(0)).getOp()==ReOrgOp.TRANSPOSE - && bop.getInput().get(0).getParent().size() == 1 //single consumer of transpose - && bop.getInput().get(1) instanceof ReorgOp - && ((ReorgOp)bop.getInput().get(1)).getOp()==ReOrgOp.TRANSPOSE - && bop.getInput().get(1).getParent().size() == 1 ) //single consumer of transpose + //both inputs transpose ops, where transpose is single consumer + if( HopRewriteUtils.isTransposeOperation(bop.getInput().get(0), 1) + && HopRewriteUtils.isTransposeOperation(bop.getInput().get(1), 1) ) { Hop left = bop.getInput().get(0).getInput().get(0); Hop right = bop.getInput().get(1).getInput().get(0); //create new subdag (no in-place dag update to prevent anomalies with //multiple consumers during rewrite process) - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); OpOp2 binop = (bop.getOp()==OpOp2.CBIND) ? OpOp2.RBIND : OpOp2.CBIND; BinaryOp bopnew = HopRewriteUtils.createBinary(left, right, binop); - HopRewriteUtils.addChildReference(parent, bopnew, pos); + HopRewriteUtils.replaceChildReference(parent, hi, bopnew, pos); hi = bopnew; LOG.debug("Applied simplifyTransposedAppend (line "+hi.getBeginLine()+")."); @@ -1109,15 +1045,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule left2 == right && bleft.getOp() == OpOp2.MINUS ) { UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP); - HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); - HopRewriteUtils.addChildReference(parent, unary, pos); - - //cleanup if only consumer of intermediate - if( bop.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(bop); - if( left.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(left); - + HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); + HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1"); @@ -1134,15 +1063,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule right2 == left && bright.getOp() == OpOp2.MINUS ) { UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP); - HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); - HopRewriteUtils.addChildReference(parent, unary, pos); - - //cleanup if only consumer of intermediate - if( bop.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(bop); - if( left.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(right); - + HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); + HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2"); @@ -1172,37 +1094,24 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule UnaryOp unary = null; //Pattern 1: (1/(1 + exp(-X)) - if( uopin instanceof BinaryOp && ((BinaryOp)uopin).getOp()==OpOp2.MINUS ) - { + if( HopRewriteUtils.isBinary(uopin, OpOp2.MINUS) ) { BinaryOp bop3 = (BinaryOp) uopin; Hop left3 = bop3.getInput().get(0); Hop right3 = bop3.getInput().get(1); - if( left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left3)==0 ) { + if( left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left3)==0 ) unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID); - } } //Pattern 2: (1/(1 + exp(X)), e.g., where -(-X) has been removed by //the 'remove unnecessary minus' rewrite --> reintroduce the minus - else - { - BinaryOp minus = HopRewriteUtils.createMinus(uopin); + else { + BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin); unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID); } - if( unary != null ) - { - HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); - HopRewriteUtils.addChildReference(parent, unary, pos); - - //cleanup if only consumer of intermediate - if( bop.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(bop); - if( bop2.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(bop2); - if( uop.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(uop); - + if( unary != null ) { + HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); + HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop); hi = unary; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1"); @@ -1229,14 +1138,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule left1 == right && bleft.getOp() == OpOp2.GREATER ) { UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP); - HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); - HopRewriteUtils.addChildReference(parent, unary, pos); - - //cleanup if only consumer of intermediate - if( bop.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(bop); - if( left.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(left); + HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); + HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; applied = true; @@ -1255,14 +1158,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule right1 == left && bright.getOp() == OpOp2.GREATER ) { UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SELP); - HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); - HopRewriteUtils.addChildReference(parent, unary, pos); - - //cleanup if only consumer of intermediate - if( bop.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(bop); - if( left.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(right); + HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); + HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; applied= true; @@ -1277,12 +1174,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule && right instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)right)==0 ) { UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SELP); - HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); - HopRewriteUtils.addChildReference(parent, unary, pos); - - //cleanup if only consumer of intermediate - if( bop.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(bop); + HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); + HopRewriteUtils.cleanupUnreferenced(bop); hi = unary; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp3"); @@ -1293,12 +1186,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==0 ) { UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP); - HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); - HopRewriteUtils.addChildReference(parent, unary, pos); - - //cleanup if only consumer of intermediate - if( bop.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences(bop); + HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); + HopRewriteUtils.cleanupUnreferenced(bop); hi = unary; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp4"); @@ -1313,36 +1202,19 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.TRACE ) //trace() { Hop hi2 = hi.getInput().get(0); - if( hi2 instanceof AggBinaryOp && ((AggBinaryOp)hi2).isMatrixMultiply() ) //X%*%Y + if( HopRewriteUtils.isMatrixMultiply(hi2) ) //X%*%Y { Hop left = hi2.getInput().get(0); Hop right = hi2.getInput().get(1); - //remove link from parent to diag - HopRewriteUtils.removeChildReference(parent, hi); - - //remove links to inputs to matrix mult - //removeChildReference(hi2, left); - //removeChildReference(hi2, right); - //create new operators (incl refresh size inside for transpose) ReorgOp trans = HopRewriteUtils.createTranspose(right); - BinaryOp mult = new BinaryOp(right.getName(), right.getDataType(), right.getValueType(), OpOp2.MULT, left, trans); - mult.setRowsInBlock(right.getRowsInBlock()); - mult.setColsInBlock(right.getColsInBlock()); - mult.refreshSizeInformation(); - AggUnaryOp sum = new AggUnaryOp(right.getName(), DataType.SCALAR, right.getValueType(), AggOp.SUM, Direction.RowCol, mult); - sum.refreshSizeInformation(); + BinaryOp mult = HopRewriteUtils.createBinary(left, trans, OpOp2.MULT); + AggUnaryOp sum = HopRewriteUtils.createSum(mult); //rehang new subdag under parent node - HopRewriteUtils.addChildReference(parent, sum, pos); - - //cleanup if only consumer of intermediate - if( hi.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences( hi ); - if( hi2.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences( hi2 ); - + HopRewriteUtils.replaceChildReference(parent, hi, sum, pos); + HopRewriteUtils.cleanupUnreferenced(hi, hi2); hi = sum; LOG.debug("Applied simplifyTraceMatrixMult"); @@ -1360,8 +1232,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule && ((IndexingOp)hi).getRowLowerEqualsUpper() && ((IndexingOp)hi).getColLowerEqualsUpper() && hi.getInput().get(0).getParent().size()==1 //rix is single mm consumer - && hi.getInput().get(0) instanceof AggBinaryOp - && ((AggBinaryOp)hi.getInput().get(0)).isMatrixMultiply() ) + && HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0)) ) { Hop mm = hi.getInput().get(0); Hop X = mm.getInput().get(0); @@ -1374,11 +1245,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //create new indexing operations IndexingOp ix1 = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, X, rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(X, false), true, false); - HopRewriteUtils.setOutputBlocksizes(ix1, X.getRowsInBlock(), X.getColsInBlock()); + ix1.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock()); ix1.refreshSizeInformation(); IndexingOp ix2 = new IndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, Y, new LiteralOp(1), HopRewriteUtils.createValueHop(Y, true), colExpr, colExpr, false, true); - HopRewriteUtils.setOutputBlocksizes(ix2, Y.getRowsInBlock(), Y.getColsInBlock()); + ix2.setOutputBlocksizes(Y.getRowsInBlock(), Y.getColsInBlock()); ix2.refreshSizeInformation(); //rewire matrix mult over ix1 and ix2 @@ -1410,12 +1281,10 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule if( HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3)) ) { //order(matrix(7), indexreturn=TRUE) -> seq(1,nrow(X),1) - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); Hop seq = HopRewriteUtils.createSeqDataGenOp(hi2); seq.refreshSizeInformation(); - HopRewriteUtils.addChildReference(parent, seq, pos); - if( hi.getParent().isEmpty() ) - HopRewriteUtils.removeChildReference(hi, hi2); + HopRewriteUtils.replaceChildReference(parent, hi, seq, pos); + HopRewriteUtils.cleanupUnreferenced(hi); hi = seq; LOG.debug("Applied simplifyConstantSort1."); @@ -1423,10 +1292,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule else { //order(matrix(7), indexreturn=FALSE) -> matrix(7) - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, hi2, pos); - if( hi.getParent().isEmpty() ) - HopRewriteUtils.removeChildReference(hi, hi2); + HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos); + HopRewriteUtils.cleanupUnreferenced(hi); hi = hi2; LOG.debug("Applied simplifyConstantSort2."); @@ -1458,12 +1325,10 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { //order(seq(2,N+1,1), indexreturn=TRUE) -> seq(1,N,1)/seq(N,1,-1) boolean desc = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)); - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); Hop seq = HopRewriteUtils.createSeqDataGenOp(hi2, !desc); seq.refreshSizeInformation(); - HopRewriteUtils.addChildReference(parent, seq, pos); - if( hi.getParent().isEmpty() ) - HopRewriteUtils.removeChildReference(hi, hi2); + HopRewriteUtils.replaceChildReference(parent, hi, seq, pos); + HopRewriteUtils.cleanupUnreferenced(hi); hi = seq; LOG.debug("Applied simplifyOrderedSort1."); @@ -1471,10 +1336,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule else if( !HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)) ) //DATA, ASC { //order(seq(2,N+1,1), indexreturn=FALSE) -> seq(2,N+1,1) - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, hi2, pos); - if( hi.getParent().isEmpty() ) - HopRewriteUtils.removeChildReference(hi, hi2); + HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos); + HopRewriteUtils.cleanupUnreferenced(hi); hi = hi2; LOG.debug("Applied simplifyOrderedSort2."); @@ -1498,7 +1361,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule private Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop hi, int pos) throws HopsException { - if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.TRANSPOSE //transpose + if( HopRewriteUtils.isTransposeOperation(hi) && hi.getInput().get(0) instanceof BinaryOp //basic binary && ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations()) { @@ -1507,10 +1370,10 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //check matrix mult and both inputs transposes w/ single consumer if( left instanceof AggBinaryOp && C.getDataType().isMatrix() - && left.getInput().get(0).getParent().size()==1 && left.getInput().get(0) instanceof ReorgOp - && ((ReorgOp)left.getInput().get(0)).getOp()==ReOrgOp.TRANSPOSE - && left.getInput().get(1).getParent().size()==1 && left.getInput().get(1) instanceof ReorgOp - && ((ReorgOp)left.getInput().get(1)).getOp()==ReOrgOp.TRANSPOSE ) + && HopRewriteUtils.isTransposeOperation(left.getInput().get(0)) + && left.getInput().get(0).getParent().size()==1 + && HopRewriteUtils.isTransposeOperation(left.getInput().get(1)) + && left.getInput().get(1).getParent().size()==1 ) { Hop A = left.getInput().get(0).getInput().get(0); Hop B = left.getInput().get(1).getInput().get(0); @@ -1519,8 +1382,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule ReorgOp rop = HopRewriteUtils.createTranspose(C); BinaryOp bop = HopRewriteUtils.createBinary(abop, rop, OpOp2.PLUS); - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, bop, pos); + HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); hi = bop; LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+")."); @@ -1550,16 +1412,10 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { Hop hi3 = hi2.getInput().get(0); //remove unnecessary chain of t(t()) - HopRewriteUtils.removeChildReference(parent, hi); - HopRewriteUtils.addChildReference(parent, hi3, pos); + HopRewriteUtils.replaceChildReference(parent, hi, hi3, pos); + HopRewriteUtils.cleanupUnreferenced(hi, hi2); hi = hi3; - //cleanup if only consumer of intermediate - if( hi.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences( hi ); - if( hi2.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences( hi2 ); - LOG.debug("Applied removeUnecessaryReorgOperation."); } } @@ -1582,16 +1438,10 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { Hop hi3 = hi2.getInput().get(1); //remove unnecessary chain of -(-()) - HopRewriteUtils.removeChildReference(parent, hi); - HopRewriteUtils.addChildReference(parent, hi3, pos); + HopRewriteUtils.replaceChildReference(parent, hi, hi3, pos); + HopRewriteUtils.cleanupUnreferenced(hi, hi2); hi = hi3; - //cleanup if only consumer of intermediate - if( hi.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences( hi ); - if( hi2.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences( hi2 ); - LOG.debug("Applied removeUnecessaryMinus"); } } @@ -1618,8 +1468,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule Hop th = phi.getInput().get(ix1); Hop gh = phi.getInput().get(ix2); - HopRewriteUtils.removeChildReference(hi, th); - HopRewriteUtils.addChildReference(hi, gh, ix1); + HopRewriteUtils.replaceChildReference(hi, th, gh, ix1); LOG.debug("Applied simplifyGroupedAggregateCount"); } @@ -1635,29 +1484,25 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //pattern X - (s * ppred(X,0,!=)) -> X -nz s //note: this is done as a hop rewrite in order to significantly reduce the //memory estimate for X - tmp if X is sparse - if( hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MINUS + if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS) && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1).getDataType()==DataType.MATRIX - && hi.getInput().get(1) instanceof BinaryOp - && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT ) + && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) ) { Hop X = hi.getInput().get(0); Hop s = hi.getInput().get(1).getInput().get(0); Hop pred = hi.getInput().get(1).getInput().get(1); if( s.getDataType()==DataType.SCALAR && pred.getDataType()==DataType.MATRIX - && pred instanceof BinaryOp && ((BinaryOp)pred).getOp()==OpOp2.NOTEQUAL + && HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) && pred.getInput().get(0) == X //depend on common subexpression elimination && pred.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { - Hop hnew = new BinaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, OpOp2.MINUS_NZ, X, s); - HopRewriteUtils.setOutputBlocksizes(hnew, hi.getRowsInBlock(), hi.getColsInBlock()); - hnew.refreshSizeInformation(); - + Hop hnew = HopRewriteUtils.createBinary(X, s, OpOp2.MINUS_NZ); + //relink new hop into original position - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, hnew, pos); + HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied fuseMinusNzBinaryOperation (line "+hi.getBeginLine()+")"); @@ -1673,27 +1518,23 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //pattern ppred(X,0,"!=")*log(X) -> log_nz(X) //note: this is done as a hop rewrite in order to significantly reduce the //memory estimate and to prevent dense intermediates if X is ultra sparse - if( hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MULT + if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1).getDataType()==DataType.MATRIX - && hi.getInput().get(1) instanceof UnaryOp - && ((UnaryOp)hi.getInput().get(1)).getOp()==OpOp1.LOG ) + && HopRewriteUtils.isUnary(hi.getInput().get(1), OpOp1.LOG) ) { Hop pred = hi.getInput().get(0); Hop X = hi.getInput().get(1).getInput().get(0); - if( pred instanceof BinaryOp && ((BinaryOp)pred).getOp()==OpOp2.NOTEQUAL + if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) && pred.getInput().get(0) == X //depend on common subexpression elimination && pred.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { - Hop hnew = new UnaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, OpOp1.LOG_NZ, X); - HopRewriteUtils.setOutputBlocksizes(hnew, hi.getRowsInBlock(), hi.getColsInBlock()); - hnew.refreshSizeInformation(); - + Hop hnew = HopRewriteUtils.createUnary(X, OpOp1.LOG_NZ); + //relink new hop into original position - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, hnew, pos); + HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied fuseLogNzUnaryOperation (line "+hi.getBeginLine()+")."); @@ -1709,28 +1550,24 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //pattern ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) //note: this is done as a hop rewrite in order to significantly reduce the //memory estimate and to prevent dense intermediates if X is ultra sparse - if( hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MULT + if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1).getDataType()==DataType.MATRIX - && hi.getInput().get(1) instanceof BinaryOp - && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.LOG ) + && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.LOG) ) { Hop pred = hi.getInput().get(0); Hop X = hi.getInput().get(1).getInput().get(0); Hop log = hi.getInput().get(1).getInput().get(1); - if( pred instanceof BinaryOp && ((BinaryOp)pred).getOp()==OpOp2.NOTEQUAL + if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) && pred.getInput().get(0) == X //depend on common subexpression elimination && pred.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { - Hop hnew = new BinaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, OpOp2.LOG_NZ, X, log); - HopRewriteUtils.setOutputBlocksizes(hnew, hi.getRowsInBlock(), hi.getColsInBlock()); - hnew.refreshSizeInformation(); - + Hop hnew = HopRewriteUtils.createBinary(X, log, OpOp2.LOG_NZ); + //relink new hop into original position - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, hnew, pos); + HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied fuseLogNzBinaryOperation (line "+hi.getBeginLine()+")"); @@ -1746,18 +1583,15 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //pattern: outer(v, t(seq(1,m)), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) //note: this rewrite supports both left/right sequence - if( hi instanceof BinaryOp && ((BinaryOp)hi).isOuterVectorOperator() - && ((BinaryOp)hi).getOp()==OpOp2.EQUAL ) + if( HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) && ((BinaryOp)hi).isOuterVectorOperator() ) { - if( ( hi.getInput().get(1) instanceof ReorgOp //pattern a: outer(v, t(seq(1,m)), "==") - && ((ReorgOp) hi.getInput().get(1)).getOp()==ReOrgOp.TRANSPOSE + if( ( HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)) //pattern a: outer(v, t(seq(1,m)), "==") && HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0))) || HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0))) //pattern b: outer(seq(1,m), t(v) "==") { //determine variable parameters for pattern a/b boolean isPatternB = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)); - boolean isTransposeRight = (hi.getInput().get(1) instanceof ReorgOp - && ((ReorgOp) hi.getInput().get(1)).getOp()==ReOrgOp.TRANSPOSE); + boolean isTransposeRight = HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)); Hop trgt = isPatternB ? (isTransposeRight ? hi.getInput().get(1).getInput().get(0) : //get v from t(v) HopRewriteUtils.createTranspose(hi.getInput().get(1)) ) : //create v via t(v') @@ -1777,12 +1611,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //create new hop ParameterizedBuiltinOp pbop = new ParameterizedBuiltinOp("tmp", DataType.MATRIX, ValueType.DOUBLE, ParamBuiltinOp.REXPAND, inputargs); - HopRewriteUtils.setOutputBlocksizes(pbop, hi.getRowsInBlock(), hi.getColsInBlock()); + pbop.setOutputBlocksizes(hi.getRowsInBlock(), hi.getColsInBlock()); pbop.refreshSizeInformation(); //relink new hop into original position - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, pbop, pos); + HopRewriteUtils.replaceChildReference(parent, hi, pbop, pos); hi = pbop; LOG.debug("Applied simplifyOuterSeqExpand (line "+hi.getBeginLine()+")"); @@ -1824,12 +1657,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //create new hop ParameterizedBuiltinOp pbop = new ParameterizedBuiltinOp("tmp", DataType.MATRIX, ValueType.DOUBLE, ParamBuiltinOp.REXPAND, inputargs); - HopRewriteUtils.setOutputBlocksizes(pbop, hi.getRowsInBlock(), hi.getColsInBlock()); + pbop.setOutputBlocksizes(hi.getRowsInBlock(), hi.getColsInBlock()); pbop.refreshSizeInformation(); //relink new hop into original position - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - HopRewriteUtils.addChildReference(parent, pbop, pos); + HopRewriteUtils.replaceChildReference(parent, hi, pbop, pos); hi = pbop; LOG.debug("Applied simplifyTableSeqExpand (line "+hi.getBeginLine()+")"); @@ -1869,10 +1701,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule if( left==right && bop.getOp()==OpOp2.NOTEQUAL || bop.getOp()==OpOp2.GREATER || bop.getOp()==OpOp2.LESS ) datagen = HopRewriteUtils.createDataGenOp(left, 0); - if( datagen != null ) - { - HopRewriteUtils.removeChildReference(parent, hi); - HopRewriteUtils.addChildReference(parent, datagen, pos); + if( datagen != null ) { + HopRewriteUtils.replaceChildReference(parent, hi, datagen, pos); hi = datagen; } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java index 82babd1..0b4faf6 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java @@ -226,11 +226,8 @@ public class RewriteConstantFolding extends HopRewriteRule ec.getVariables().removeAll(); //set literal properties (scalar) - literal.setDim1(0); - literal.setDim2(0); - literal.setRowsInBlock(-1); - literal.setColsInBlock(-1); - + HopRewriteUtils.setOutputParametersForScalar(literal); + //System.out.println("Constant folded in "+time.stop()+"ms."); return literal; @@ -278,8 +275,7 @@ public class RewriteConstantFolding extends HopRewriteRule throws HopsException { ArrayList<Hop> in = hop.getInput(); - return ( hop instanceof BinaryOp - && ((BinaryOp)hop).getOp()==OpOp2.AND + return ( HopRewriteUtils.isBinary(hop, OpOp2.AND) && ( (in.get(0) instanceof LiteralOp && !((LiteralOp)in.get(0)).getBooleanValue()) ||(in.get(1) instanceof LiteralOp && !((LiteralOp)in.get(1)).getBooleanValue())) ); } @@ -288,8 +284,7 @@ public class RewriteConstantFolding extends HopRewriteRule throws HopsException { ArrayList<Hop> in = hop.getInput(); - return ( hop instanceof BinaryOp - && ((BinaryOp)hop).getOp()==OpOp2.OR + return ( HopRewriteUtils.isBinary(hop, OpOp2.OR) && ( (in.get(0) instanceof LiteralOp && ((LiteralOp)in.get(0)).getBooleanValue()) ||(in.get(1) instanceof LiteralOp && ((LiteralOp)in.get(1)).getBooleanValue())) ); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java index 08a9599..991dedd 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java @@ -40,7 +40,6 @@ import org.apache.sysml.parser.IfStatementBlock; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.WhileStatementBlock; import org.apache.sysml.parser.Expression.DataType; -import org.apache.sysml.parser.Expression.ValueType; /** * Rule: Simplify program structure by pulling if or else statement body out @@ -181,8 +180,7 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos]; //replace cast with sum - AggUnaryOp newSum = new AggUnaryOp(cast.getName(), DataType.SCALAR, ValueType.DOUBLE, - aggOp, Direction.RowCol, ix); + AggUnaryOp newSum = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol); HopRewriteUtils.removeChildReference(cast, ix); HopRewriteUtils.removeChildReference(bop, cast); HopRewriteUtils.addChildReference(bop, newSum, leftScalar?1:0 ); @@ -191,10 +189,8 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule //NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites int index1 = rowIx ? 1 : 3; int index2 = rowIx ? 2 : 4; - HopRewriteUtils.removeChildReferenceByPos(ix, ix.getInput().get(index1), index1); - HopRewriteUtils.addChildReference(ix, from, index1); - HopRewriteUtils.removeChildReferenceByPos(ix, ix.getInput().get(index2), index2); - HopRewriteUtils.addChildReference(ix, to, index2); + HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index1), from, index1); + HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index2), to, index2); //update indexing size information if( rowIx ) @@ -288,21 +284,13 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule int index1 = rowIx ? 2 : 4; int index2 = rowIx ? 3 : 5; //modify left indexing bounds - HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index1), index1 ); - HopRewriteUtils.addChildReference(lix, from, index1); - HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index2), index2 ); - HopRewriteUtils.addChildReference(lix, to, index2); + HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(index1),from, index1); + HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(index2),to, index2); //modify both right indexing - HopRewriteUtils.removeChildReferenceByPos(rix0, rix0.getInput().get(index1-1), index1-1 ); - HopRewriteUtils.addChildReference(rix0, from, index1-1); - HopRewriteUtils.removeChildReferenceByPos(rix0, rix0.getInput().get(index2-1), index2-1 ); - HopRewriteUtils.addChildReference(rix0, to, index2-1); - HopRewriteUtils.removeChildReferenceByPos(rix1, rix1.getInput().get(index1-1), index1-1 ); - HopRewriteUtils.addChildReference(rix1, from, index1-1); - HopRewriteUtils.removeChildReferenceByPos(rix1, rix1.getInput().get(index2-1), index2-1 ); - HopRewriteUtils.addChildReference(rix1, to, index2-1); - rix0.refreshSizeInformation(); - rix1.refreshSizeInformation(); + HopRewriteUtils.replaceChildReference(rix0, rix0.getInput().get(index1-1), from, index1-1); + HopRewriteUtils.replaceChildReference(rix0, rix0.getInput().get(index2-1), to, index2-1); + HopRewriteUtils.replaceChildReference(rix1, rix1.getInput().get(index1-1), from, index1-1); + HopRewriteUtils.replaceChildReference(rix1, rix1.getInput().get(index2-1), to, index2-1); bop.refreshSizeInformation(); lix.refreshSizeInformation(); @@ -385,16 +373,11 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule int index1 = rowIx ? 2 : 4; int index2 = rowIx ? 3 : 5; //modify left indexing bounds - HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index1), index1 ); - HopRewriteUtils.addChildReference(lix, from, index1); - HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index2), index2 ); - HopRewriteUtils.addChildReference(lix, to, index2); + HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(index1), from, index1); + HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(index2), to, index2); //modify right indexing - HopRewriteUtils.removeChildReferenceByPos(rix, rix.getInput().get(index1-1), index1-1 ); - HopRewriteUtils.addChildReference(rix, from, index1-1); - HopRewriteUtils.removeChildReferenceByPos(rix, rix.getInput().get(index2-1), index2-1 ); - HopRewriteUtils.addChildReference(rix, to, index2-1); - rix.refreshSizeInformation(); + HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(index1-1), from, index1-1); + HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(index2-1), to, index2-1); uop.refreshSizeInformation(); lix.refreshSizeInformation(); @@ -405,6 +388,4 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule return ret; } - - } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java index c770644..cf5ebce 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java @@ -177,10 +177,8 @@ public class RewriteIndexingVectorization extends HopRewriteRule for( Hop c : ihops ) { HopRewriteUtils.removeChildReference(c, input); //input data HopRewriteUtils.addChildReference(c, newRix, 0); - HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(3),3); //col lower expr - HopRewriteUtils.addChildReference(c, new LiteralOp(1), 3); - HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(4),4); //col upper expr - HopRewriteUtils.addChildReference(c, new LiteralOp(1), 4); + HopRewriteUtils.replaceChildReference(c, c.getInput().get(3), new LiteralOp(1), 3); //col lower expr + HopRewriteUtils.replaceChildReference(c, c.getInput().get(4), new LiteralOp(1), 4); //col upper expr c.refreshSizeInformation(); } @@ -239,10 +237,8 @@ public class RewriteIndexingVectorization extends HopRewriteRule //reset row index all candidates and refresh sizes (bottom-up) for( int i=ihops.size()-1; i>=0; i-- ) { Hop c = ihops.get(i); - HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(2), 2); //row lower expr - HopRewriteUtils.addChildReference(c, new LiteralOp(1), 2); - HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(3), 3); //row upper expr - HopRewriteUtils.addChildReference(c, new LiteralOp(1), 3); + HopRewriteUtils.replaceChildReference(c, c.getInput().get(2), new LiteralOp(1), 2); //row lower expr + HopRewriteUtils.replaceChildReference(c, c.getInput().get(3), new LiteralOp(1), 3); //row upper expr ((LeftIndexingOp)c).setRowLowerEqualsUpper(true); c.refreshSizeInformation(); } @@ -313,10 +309,8 @@ public class RewriteIndexingVectorization extends HopRewriteRule //reset col index all candidates and refresh sizes (bottom-up) for( int i=ihops.size()-1; i>=0; i-- ) { Hop c = ihops.get(i); - HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(4), 4); //col lower expr - HopRewriteUtils.addChildReference(c, new LiteralOp(1), 4); - HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(5), 5); //col upper expr - HopRewriteUtils.addChildReference(c, new LiteralOp(1), 5); + HopRewriteUtils.replaceChildReference(c, c.getInput().get(4), new LiteralOp(1), 4); //col lower expr + HopRewriteUtils.replaceChildReference(c, c.getInput().get(5), new LiteralOp(1), 5); //col upper expr ((LeftIndexingOp)c).setColLowerEqualsUpper(true); c.refreshSizeInformation(); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimization.java index bb87fe8..9445fcb 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimization.java @@ -96,7 +96,7 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule if(hop.getVisited() == Hop.VisitStatus.DONE) return; - if ( hop instanceof AggBinaryOp && ((AggBinaryOp) hop).isMatrixMultiply() + if ( HopRewriteUtils.isMatrixMultiply(hop) && !((AggBinaryOp)hop).hasLeftPMInput() && hop.getVisited() != Hop.VisitStatus.DONE ) { @@ -159,7 +159,7 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule * (either within chain or outside the chain) */ - if ( h instanceof AggBinaryOp && ((AggBinaryOp) h).isMatrixMultiply() + if ( HopRewriteUtils.isMatrixMultiply(h) && !((AggBinaryOp)hop).hasLeftPMInput() && h.getVisited() != Hop.VisitStatus.DONE ) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveReadAfterWrite.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveReadAfterWrite.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveReadAfterWrite.java index 61cc5e7..9348088 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveReadAfterWrite.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveReadAfterWrite.java @@ -69,11 +69,8 @@ public class RewriteRemoveReadAfterWrite extends HopRewriteRule //rewire read consumers to write input Hop input = writes.get(rfname).getInput().get(0); ArrayList<Hop> parents = (ArrayList<Hop>) rhop.getParent().clone(); - for( Hop p : parents ) { - int pos = HopRewriteUtils.getChildReferencePos(p, rhop); - HopRewriteUtils.removeChildReferenceByPos(p, rhop, pos); - HopRewriteUtils.addChildReference(p, input, pos); - } + for( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, rhop, input); } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java index 0ff154c..95a1214 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java @@ -123,11 +123,8 @@ public class RewriteRemoveUnnecessaryCasts extends HopRewriteRule Hop input = uop2.getInput().get(0); //rewire parents ArrayList<Hop> parents = (ArrayList<Hop>) hop.getParent().clone(); - for( Hop p : parents ) { - int ix = HopRewriteUtils.getChildReferencePos(p, hop); - HopRewriteUtils.removeChildReference(p, hop); - HopRewriteUtils.addChildReference(p, input, ix); - } + for( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hop, input); } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java index 69cabc5..c4e6caa 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java @@ -37,7 +37,6 @@ import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.ReorgOp; import org.apache.sysml.hops.TernaryOp; -import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.StatementBlock; @@ -132,13 +131,9 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite for( int i=0; i<parents.size(); i++ ) { //prevent concurrent modification by index access Hop parent = parents.get(i); - if( !candChilds.contains(parent) ) //anomaly filter - { - if( parent != twrite ) { - int pos = HopRewriteUtils.getChildReferencePos(parent, c); - HopRewriteUtils.removeChildReferenceByPos(parent, c, pos); - HopRewriteUtils.addChildReference(parent, tread, pos); - } + if( !candChilds.contains(parent) ) { //anomaly filter + if( parent != twrite ) + HopRewriteUtils.replaceChildReference(parent, c, tread); else sb.get_hops().remove(parent); } @@ -163,11 +158,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite //prevent concurrent modification by index access Hop parent = parents.get(i); if( !candChilds.contains(parent) ) //anomaly filter - { - int pos = HopRewriteUtils.getChildReferencePos(parent, c); - HopRewriteUtils.removeChildReferenceByPos(parent, c, pos); - HopRewriteUtils.addChildReference(parent, tread, pos); - } + HopRewriteUtils.replaceChildReference(parent, c, tread); } //add data-dependent operator sub dag to first statement block @@ -258,7 +249,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite for( Hop p : hop.getParent() ) { //list of operators without need for empty blocks to be extended as needed noEmptyBlocks &= ( p instanceof AggBinaryOp && hop == p.getInput().get(0) - || p instanceof UnaryOp && ((UnaryOp)p).getOp()==OpOp1.NROW); + || HopRewriteUtils.isUnary(p, OpOp1.NROW) ); onlyPMM &= (p instanceof AggBinaryOp && hop == p.getInput().get(0)); } pbhop.setOutputEmptyBlocks(!noEmptyBlocks); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java index e0a6590..692762e 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java @@ -87,9 +87,7 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule for( int i=0; i<parents.size(); i++ ) { Hop parent = parents.get(i); - int pos = HopRewriteUtils.getChildReferencePos(parent, reblock); - HopRewriteUtils.removeChildReferenceByPos(parent, reblock, pos); - HopRewriteUtils.addChildReference(parent, tread, pos); + HopRewriteUtils.replaceChildReference(parent, reblock, tread); } //add reblock sub dag to first statement block http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index e3533b7..0063997 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -2899,9 +2899,9 @@ public class DMLTranslator } private void setBlockSizeAndRefreshSizeInfo(Hop in, Hop out) { - HopRewriteUtils.setOutputBlocksizes(out, in.getRowsInBlock(), in.getColsInBlock()); - HopRewriteUtils.copyLineNumbers(in, out); + out.setOutputBlocksizes(in.getRowsInBlock(), in.getColsInBlock()); out.refreshSizeInformation(); + HopRewriteUtils.copyLineNumbers(in, out); } private ArrayList<Hop> getALHopsForConvOpPoolingCOL2IM(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java index 66065bf..bbe5bf7 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java @@ -2589,7 +2589,7 @@ public class OptimizerRuleBased extends Optimizer ret = true; sharedVars.add(ch.getName()); } - else if( ch instanceof ReorgOp && ((ReorgOp)ch).getOp()==ReOrgOp.TRANSPOSE + else if( HopRewriteUtils.isTransposeOperation(ch) && ch.getInput().get(0) instanceof DataOp && ch.getInput().get(0).getDataType() == DataType.MATRIX && inputVars.contains(ch.getInput().get(0).getName()) ) //&& !partitionedVars.contains(ch.getInput().get(0).getName())) @@ -2707,8 +2707,7 @@ public class OptimizerRuleBased extends Optimizer for( Hop in : h.getInput() ) { if( in instanceof DataOp ) cand.add( in.getName() ); - else if( in instanceof ReorgOp - && ((ReorgOp)in).getOp()==ReOrgOp.TRANSPOSE + else if( HopRewriteUtils.isTransposeOperation(in) && in.getInput().get(0) instanceof DataOp ) cand.add( in.getInput().get(0).getName() ); }
