[SYSTEMML-2444] Fix robustness of wsloss rewrite order (Kmeans-predict). SystemML 1.2 includes several fixes for correct constant propagation. These modifications, however, led to a performance regression for Kmeans-predict 10Mx1K: due to partial unknowns, the sum-sq rewrite was applied before the wsloss rewrite, which destroyed the wsloss pattern. We now make the overlapping patterns more robust by matching against both sum(tmp^2) and sumSq(tmp).
On the perftest Kmeans-predict 10Mx1K (80GB) scenario, this patch improved end-to-end performance (incl read and spark ctx creation) from 252s to 63s. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0aaf11d8 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0aaf11d8 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0aaf11d8 Branch: refs/heads/master Commit: 0aaf11d82b5680dbf4f21d1195f397e66edc22a2 Parents: 56c81cb Author: Matthias Boehm <[email protected]> Authored: Mon Jul 16 19:36:16 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Jul 16 19:36:44 2018 -0700 ---------------------------------------------------------------------- .../sysml/hops/rewrite/HopRewriteUtils.java | 11 + .../RewriteAlgebraicSimplificationDynamic.java | 224 ++++++++++--------- 2 files changed, 127 insertions(+), 108 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/0aaf11d8/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 48b95cc..025f98a 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -777,6 +777,13 @@ public class HopRewriteUtils return ( hop.getNnz()==0 ); } + public static boolean isEqualMatrixSize(BinaryOp hop) { + return hop.getDataType().isMatrix() + && hop.getInput().get(0).getDataType().isMatrix() + && hop.getInput().get(1).getDataType().isMatrix() + && isEqualSize(hop.getInput().get(0), hop.getInput().get(1)); + } + public static boolean isEqualSize( Hop hop1, Hop hop2 ) { return (hop1.dimsKnown() && hop2.dimsKnown() && hop1.getDim1() == 
hop2.getDim1() @@ -1016,6 +1023,10 @@ public class HopRewriteUtils return hop instanceof AggBinaryOp && ((AggBinaryOp)hop).isMatrixMultiply(); } + public static boolean isAggUnaryOp(Hop hop, AggOp op, Direction dir) { + return isAggUnaryOp(hop, op) && ((AggUnaryOp)hop).getDirection()==dir; + } + public static boolean isAggUnaryOp(Hop hop, AggOp...op) { if( !(hop instanceof AggUnaryOp) ) return false; http://git-wip-us.apache.org/repos/asf/systemml/blob/0aaf11d8/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 062da2f..4f0ef51 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -1200,6 +1200,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting) * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting) * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting) + * 4) sumSq (X - U %*% t(V)) (no weighting sumSq) * * NOTE: We include transpose into the pattern because during runtime we need to compute * U%*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation @@ -1220,69 +1221,59 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule //NOTE: there might be also a general simplification without custom operator //via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2 Hop hnew = null; + boolean appliedPattern = false; - if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol - && ((AggUnaryOp)hi).getOp() == AggOp.SUM //all patterns rooted by sum() + if( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, 
Direction.RowCol) //all patterns rooted by sum() && hi.getInput().get(0) instanceof BinaryOp //all patterns subrooted by binary op && hi.getInput().get(0).getDim2() > 1 ) //not applied for vector-vector mult { BinaryOp bop = (BinaryOp) hi.getInput().get(0); - boolean appliedPattern = false; //Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting) //alternative pattern: sum (W * (U %*% t(V) - X) ^ 2) if( bop.getOp()==OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW) - && bop.getInput().get(0).getDataType()==DataType.MATRIX + && bop.getInput().get(0).getDataType()==DataType.MATRIX && HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) //prevent mv - && bop.getInput().get(1).getInput().get(1) instanceof LiteralOp - && HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1).getInput().get(1))==2) + && HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1).getInput().get(1), 2) ) { Hop W = bop.getInput().get(0); Hop tmp = bop.getInput().get(1).getInput().get(0); //(X - U %*% t(V)) if( HopRewriteUtils.isBinary(tmp, OpOp2.MINUS) - && HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv + && HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv && tmp.getInput().get(0).getDataType() == DataType.MATRIX ) { //a) sum (W * (X - U %*% t(V)) ^ 2) int uvIndex = -1; if( tmp.getInput().get(1) instanceof AggBinaryOp //ba gurantees matrices - && HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT - { - uvIndex = 1; + && HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) { //BLOCKSIZE CONSTRAINT + uvIndex = 1; } //b) sum (W * (U %*% t(V) - X) ^ 2) else if(tmp.getInput().get(0) instanceof AggBinaryOp //ba gurantees matrices - && HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT - { + && 
HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0),true)) { //BLOCKSIZE CONSTRAINT uvIndex = 0; - } - - if( uvIndex >= 0 ) //rewrite match - { + } + + if( uvIndex >= 0 ) { //rewrite match Hop X = tmp.getInput().get((uvIndex==0)?1:0); Hop U = tmp.getInput().get(uvIndex).getInput().get(0); Hop V = tmp.getInput().get(uvIndex).getInput().get(1); - - if( !HopRewriteUtils.isTransposeOperation(V) ) { - V = HopRewriteUtils.createTranspose(V); - } - else{ - V = V.getInput().get(0); - } - + V = !HopRewriteUtils.isTransposeOperation(V) ? + HopRewriteUtils.createTranspose(V) : V.getInput().get(0); + //handle special case of post_nz if( HopRewriteUtils.isNonZeroIndicator(W, X) ){ W = new LiteralOp(1); } - + //construct quaternary hop - hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, - OpOp4.WSLOSS, X, U, V, W, true); + hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, + ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true); HopRewriteUtils.setOutputParametersForScalar(hnew); - + appliedPattern = true; - LOG.debug("Applied simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")"); + LOG.debug("Applied simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")"); } } } @@ -1290,107 +1281,124 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule //Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting) //alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2) if( !appliedPattern - && bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp - && HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2 - && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) - && bop.getInput().get(0).getDataType()==DataType.MATRIX - && HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv - && bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX) + && bop.getOp()==OpOp2.POW && 
HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1), 2) + && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) + && HopRewriteUtils.isEqualMatrixSize((BinaryOp)bop.getInput().get(0))) { - Hop lleft = bop.getInput().get(0).getInput().get(0); - Hop lright = bop.getInput().get(0).getInput().get(1); - - //a) sum ((X - W * (U %*% t(V))) ^ 2) - int wuvIndex = -1; - if( lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp ){ - wuvIndex = 1; - } - //b) sum ((W * (U %*% t(V)) - X) ^ 2) - else if( lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp ){ - wuvIndex = 0; - } - - if( wuvIndex >= 0 ) //rewrite match - { - Hop X = bop.getInput().get(0).getInput().get((wuvIndex==0)?1:0); - Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex); //(W * (U %*% t(V))) - - if( ((BinaryOp)tmp).getOp()==OpOp2.MULT - && tmp.getInput().get(0).getDataType() == DataType.MATRIX - && HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv - && HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT - { - Hop W = tmp.getInput().get(0); - Hop U = tmp.getInput().get(1).getInput().get(0); - Hop V = tmp.getInput().get(1).getInput().get(1); - - if( !HopRewriteUtils.isTransposeOperation(V) ) { - V = HopRewriteUtils.createTranspose(V); - } - else { - V = V.getInput().get(0); - } - - hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, - OpOp4.WSLOSS, X, U, V, W, false); - HopRewriteUtils.setOutputParametersForScalar(hnew); - - appliedPattern = true; - LOG.debug("Applied simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")"); - } - } + Hop lleft = bop.getInput().get(0).getInput().get(0); + Hop lright = bop.getInput().get(0).getInput().get(1); + + //a) sum ((X - W * (U %*% t(V))) ^ 2) + int wuvIndex = -1; + if( lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp ){ + wuvIndex = 1; + } + //b) sum ((W * (U 
%*% t(V)) - X) ^ 2) + else if( lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp ){ + wuvIndex = 0; + } + + if( wuvIndex >= 0 ) //rewrite match + { + Hop X = bop.getInput().get(0).getInput().get((wuvIndex==0)?1:0); + Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex); //(W * (U %*% t(V))) + + if( ((BinaryOp)tmp).getOp()==OpOp2.MULT + && tmp.getInput().get(0).getDataType() == DataType.MATRIX + && HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv + && HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT + { + Hop W = tmp.getInput().get(0); + Hop U = tmp.getInput().get(1).getInput().get(0); + Hop V = tmp.getInput().get(1).getInput().get(1); + V = !HopRewriteUtils.isTransposeOperation(V) ? + HopRewriteUtils.createTranspose(V) : V.getInput().get(0); + hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, + ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false); + HopRewriteUtils.setOutputParametersForScalar(hnew); + appliedPattern = true; + LOG.debug("Applied simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")"); + } + } } - + //Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting) //alternative pattern: sum (((U %*% t(V)) - X) ^ 2) if( !appliedPattern - && bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp - && HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2 - && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) - && bop.getInput().get(0).getDataType()==DataType.MATRIX - && HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv - && bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX) + && bop.getOp()==OpOp2.POW && HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1), 2) + && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) + && 
HopRewriteUtils.isEqualMatrixSize((BinaryOp)bop.getInput().get(0))) //prevent mv { Hop lleft = bop.getInput().get(0).getInput().get(0); Hop lright = bop.getInput().get(0).getInput().get(1); - + //a) sum ((X - (U %*% t(V))) ^ 2) int uvIndex = -1; - if( lright instanceof AggBinaryOp //ba gurantees matrices - && HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT - { + if( lright instanceof AggBinaryOp //ba guarantees matrices + && HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) { //BLOCKSIZE CONSTRAINT uvIndex = 1; } //b) sum (((U %*% t(V)) - X) ^ 2) - else if( lleft instanceof AggBinaryOp //ba gurantees matrices - && HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT - { + else if( lleft instanceof AggBinaryOp //ba guarantees matrices + && HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) { //BLOCKSIZE CONSTRAINT uvIndex = 0; } - - if( uvIndex >= 0 ) //rewrite match - { + + if( uvIndex >= 0 ) { //rewrite match Hop X = bop.getInput().get(0).getInput().get((uvIndex==0)?1:0); Hop tmp = bop.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V)) Hop W = new LiteralOp(1); //no weighting Hop U = tmp.getInput().get(0); Hop V = tmp.getInput().get(1); - - if( !HopRewriteUtils.isTransposeOperation(V) ) { - V = HopRewriteUtils.createTranspose(V); - } - else { - V = V.getInput().get(0); - } - - hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, - OpOp4.WSLOSS, X, U, V, W, false); + V = !HopRewriteUtils.isTransposeOperation(V) ? 
+ HopRewriteUtils.createTranspose(V) : V.getInput().get(0); + hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, + ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false); HopRewriteUtils.setOutputParametersForScalar(hnew); - appliedPattern = true; - LOG.debug("Applied simplifyWeightedSquaredLoss3"+uvIndex+" (line "+hi.getBeginLine()+")"); + + LOG.debug("Applied simplifyWeightedSquaredLoss3"+uvIndex+" (line "+hi.getBeginLine()+")"); } - } + } + } + + //Pattern 4) sumSq (X - U %*% t(V)) (no weighting) + //alternative pattern: sumSq (U %*% t(V) - X) + if( !appliedPattern + && HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM_SQ, Direction.RowCol) + && HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.MINUS) + && HopRewriteUtils.isEqualMatrixSize((BinaryOp)hi.getInput().get(0))) //prevent mv + { + Hop lleft = hi.getInput().get(0).getInput().get(0); + Hop lright = hi.getInput().get(0).getInput().get(1); + + //a) sumSq (X - U %*% t(V)) + int uvIndex = -1; + if( lright instanceof AggBinaryOp //ba guarantees matrices + && HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) { //BLOCKSIZE CONSTRAINT + uvIndex = 1; + } + //b) sumSq (U %*% t(V) - X) + else if( lleft instanceof AggBinaryOp //ba guarantees matrices + && HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) { //BLOCKSIZE CONSTRAINT + uvIndex = 0; + } + + if( uvIndex >= 0 ) { //rewrite match + Hop X = hi.getInput().get(0).getInput().get((uvIndex==0)?1:0); + Hop tmp = hi.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V)) + Hop W = new LiteralOp(1); //no weighting + Hop U = tmp.getInput().get(0); + Hop V = tmp.getInput().get(1); + V = !HopRewriteUtils.isTransposeOperation(V) ? 
+ HopRewriteUtils.createTranspose(V) : V.getInput().get(0); + hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, + ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false); + HopRewriteUtils.setOutputParametersForScalar(hnew); + appliedPattern = true; + + LOG.debug("Applied simplifyWeightedSquaredLoss4"+uvIndex+" (line "+hi.getBeginLine()+")"); + } } //relink new hop into original position @@ -2168,7 +2176,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule HopRewriteUtils.cleanupUnreferenced(hi, sumInput); hi = sumSq; - LOG.debug("Applied fuseSumSquared."); + LOG.debug("Applied fuseSumSquared (line " +hi.getBeginLine()+")."); } } }
