Repository: systemml Updated Branches: refs/heads/master 2869d53ce -> 92ecae34d
[SYSTEMML-1662] Fix rewrite issues discovered by extended hop validator This fix pack addresses various issues discovered by the extended HOP DAG validator: (1) Forced visit status reset after static/dynamic rewrites (some of these simplification rewrites create "holes" of new operators with visited=false, which is problematic if the operators above and below are already marked as visited, because a subsequent reset will stop at the "hole" - we now force a full DAG traversal after these rewrites) (2) Fix visit status handling in common subexpression elimination (the merge of leaf and inner nodes was incorrect as it did not mark merged operators as visited, again leading to "holes" in the DAG) (3) Fix visit status handling in left indexing vectorization (similar to 1 and 2, but for this rewrite the problem originated from replaced operators not being considered in the recursive traversal) (4) Fix value type handling of left indexing expressions at parser (validation) level, which already created corrupted hops for left indexing. There are additional issues related to data and value type handling in the mlcontext and jmlc APIs, but these will be addressed once the HOP DAG validator is actually pushed to master. 
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/92ecae34 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/92ecae34 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/92ecae34 Branch: refs/heads/master Commit: 92ecae34db63559131bf50f747650866c25ff4e1 Parents: 2869d53 Author: Matthias Boehm <[email protected]> Authored: Thu Jun 8 23:12:10 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Jun 9 12:17:41 2017 -0700 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/Hop.java | 23 +++++++- .../RewriteAlgebraicSimplificationDynamic.java | 6 +-- .../RewriteAlgebraicSimplificationStatic.java | 22 +++----- .../RewriteCommonSubexpressionElimination.java | 3 ++ .../rewrite/RewriteIndexingVectorization.java | 57 +++++++++++++------- .../org/apache/sysml/parser/DMLTranslator.java | 2 +- .../functions/codegen/AlgorithmL2SVM.java | 2 +- 7 files changed, 75 insertions(+), 40 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 31e3aa6..742f22a 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -21,6 +21,7 @@ package org.apache.sysml.hops; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -865,13 +866,33 @@ public abstract class Hop hopRoot.resetVisitStatus(); } + public static void resetVisitStatus( ArrayList<Hop> hops, boolean force ) { + if( !force ) + resetVisitStatus(hops); + else { + HashSet<Long> memo = new 
HashSet<Long>(); + if( hops != null ) + for( Hop hopRoot : hops ) + hopRoot.resetVisitStatusForced(memo); + } + } + public void resetVisitStatus() { if( !isVisited() ) return; - for( Hop h : this.getInput() ) + for( Hop h : getInput() ) h.resetVisitStatus(); setVisited(false); } + + public void resetVisitStatusForced(HashSet<Long> memo) { + if( memo.contains(getHopID()) ) + return; + for( Hop h : getInput() ) + h.resetVisitStatusForced(memo); + setVisited(false); + memo.add(getHopID()); + } public static void resetRecompilationFlag( ArrayList<Hop> hops, ExecType et ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 74e832b..ad80c05 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -94,13 +94,13 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule //one pass rewrite-descend (rewrite created pattern) for( Hop h : roots ) rule_AlgebraicSimplification( h, false ); + Hop.resetVisitStatus(roots, true); - Hop.resetVisitStatus(roots); - //one pass descend-rewrite (for rollup) for( Hop h : roots ) rule_AlgebraicSimplification( h, true ); - + Hop.resetVisitStatus(roots, true); + return roots; } http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 17f1ace..d6b4559 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -81,12 +81,12 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //one pass rewrite-descend (rewrite created pattern) for( Hop h : roots ) rule_AlgebraicSimplification( h, false ); - - Hop.resetVisitStatus(roots); + Hop.resetVisitStatus(roots, true); //one pass descend-rewrite (for rollup) for( Hop h : roots ) rule_AlgebraicSimplification( h, true ); + Hop.resetVisitStatus(roots, true); return roots; } @@ -498,22 +498,16 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule if( bop.getOp()==OpOp2.PLUS ) //X+X -> X*2 { bop.setOp(OpOp2.MULT); - LiteralOp tmp = new LiteralOp(2); - bop.getInput().remove(1); - right.getParent().remove(bop); - HopRewriteUtils.addChildReference(hi, tmp, 1); - - LOG.debug("Applied simplifyBinaryToUnaryOperation1"); + HopRewriteUtils.replaceChildReference(hi, right, new LiteralOp(2), 1); + + LOG.debug("Applied simplifyBinaryToUnaryOperation1 (line "+hi.getBeginLine()+")."); } else if ( bop.getOp()==OpOp2.MULT ) //X*X -> X^2 { bop.setOp(OpOp2.POW); - LiteralOp tmp = new LiteralOp(2); - bop.getInput().remove(1); - right.getParent().remove(bop); - HopRewriteUtils.addChildReference(hi, tmp, 1); + HopRewriteUtils.replaceChildReference(hi, right, new LiteralOp(2), 1); - LOG.debug("Applied simplifyBinaryToUnaryOperation2"); + LOG.debug("Applied simplifyBinaryToUnaryOperation2 (line "+hi.getBeginLine()+")."); } } //patterns: (X>0)-(X<0) -> sign(X) @@ -531,7 +525,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule HopRewriteUtils.cleanupUnreferenced(hi, left, right); hi = uop; - LOG.debug("Applied simplifyBinaryToUnaryOperation3"); + LOG.debug("Applied 
simplifyBinaryToUnaryOperation3 (line "+hi.getBeginLine()+")."); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java index 72957dd..5379dfe 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java @@ -131,6 +131,7 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule Hop tmp = dataops.get(hi.getName()); if( tmp != hi ) { //if required tmp.getParent().add(hop); + tmp.setVisited(); hop.getInput().set(i, tmp); ret++; } @@ -142,6 +143,7 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule //replace child node ref if( tmp != hi ){ //if required tmp.getParent().add(hop); + tmp.setVisited(); hop.getInput().set(i, tmp); ret++; } @@ -200,6 +202,7 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule { p.getInput().set(k, h1); h1.getParent().add(p); + h1.setVisited(); } //replace h2 w/ h1 in h2-input parents http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java index 4ce1d43..b797e00 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java @@ -50,7 +50,7 @@ public class 
RewriteIndexingVectorization extends HopRewriteRule { if( roots == null ) return roots; - + for( Hop h : roots ) rule_IndexingVectorization( h ); @@ -84,9 +84,9 @@ public class RewriteIndexingVectorization extends HopRewriteRule //MB: disabled right indexing rewrite because (1) piggybacked in MR anyway, (2) usually //not too much overhead, and (3) makes literal replacement more difficult //vectorizeRightIndexing( hi ); //e.g., multiple rightindexing X[i,1], X[i,3] -> X[i,]; - vectorizeLeftIndexing( hi ); //e.g., multiple left indexing X[i,1], X[i,3] -> X[i,]; + hi = vectorizeLeftIndexing( hi ); //e.g., multiple left indexing X[i,1], X[i,3] -> X[i,]; - //process childs recursively after rewrites + //process childs recursively after rewrites rule_IndexingVectorization( hi ); } @@ -189,9 +189,11 @@ public class RewriteIndexingVectorization extends HopRewriteRule } @SuppressWarnings("unchecked") - private void vectorizeLeftIndexing( Hop hop ) + private Hop vectorizeLeftIndexing( Hop hop ) throws HopsException - { + { + Hop ret = hop; + if( hop instanceof LeftIndexingOp ) //left indexing { LeftIndexingOp ihop0 = (LeftIndexingOp) hop; @@ -224,11 +226,14 @@ public class RewriteIndexingVectorization extends HopRewriteRule Hop rowExpr = ihop0.getInput().get(2); //keep before reset //new row indexing operator - IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, - rowExpr, rowExpr, new LiteralOp(1), - HopRewriteUtils.createValueHop(input, false), true, false); + IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(), + input, rowExpr, rowExpr, new LiteralOp(1), + HopRewriteUtils.createValueHop(input, false), true, false); HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1); newRix.refreshSizeInformation(); + //reset visit status of copied hops (otherwise hidden by left indexing) + for( Hop c : newRix.getInput() ) + c.resetVisitStatus(); //rewrite bottom 
left indexing operator HopRewriteUtils.removeChildReference(current, input); //input data @@ -253,11 +258,14 @@ public class RewriteIndexingVectorization extends HopRewriteRule ihop0parentsPos.add(posp); } - LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, - rowExpr, rowExpr, new LiteralOp(1), - HopRewriteUtils.createValueHop(input, false), true, false); + LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), + input, ihop0, rowExpr, rowExpr, new LiteralOp(1), + HopRewriteUtils.createValueHop(input, false), true, false); HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1); newLix.refreshSizeInformation(); + //reset visit status of copied hops (otherwise hidden by left indexing) + for( Hop c : newLix.getInput() ) + c.resetVisitStatus(); for( int i=0; i<ihop0parentsPos.size(); i++ ) { Hop parent = ihop0parents.get(i); @@ -266,7 +274,8 @@ public class RewriteIndexingVectorization extends HopRewriteRule } appliedRow = true; - LOG.debug("Applied vectorizeLeftIndexingRow"); + ret = newLix; + LOG.debug("Applied vectorizeLeftIndexingRow for hop "+hop.getHopID()); } } @@ -296,11 +305,14 @@ public class RewriteIndexingVectorization extends HopRewriteRule Hop colExpr = ihop0.getInput().get(4); //keep before reset //new row indexing operator - IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, - new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), - colExpr, colExpr, false, true); + IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(), + input, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), + colExpr, colExpr, false, true); HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1); newRix.refreshSizeInformation(); + //reset visit status of copied hops (otherwise hidden by left 
indexing) + for( Hop c : newRix.getInput() ) + c.resetVisitStatus(); //rewrite bottom left indexing operator HopRewriteUtils.removeChildReference(current, input); //input data @@ -325,11 +337,14 @@ public class RewriteIndexingVectorization extends HopRewriteRule ihop0parentsPos.add(posp); } - LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, - new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), - colExpr, colExpr, false, true); + LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), + input, ihop0, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), + colExpr, colExpr, false, true); HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1); newLix.refreshSizeInformation(); + //reset visit status of copied hops (otherwise hidden by left indexing) + for( Hop c : newLix.getInput() ) + c.resetVisitStatus(); for( int i=0; i<ihop0parentsPos.size(); i++ ) { Hop parent = ihop0parents.get(i); @@ -337,10 +352,12 @@ public class RewriteIndexingVectorization extends HopRewriteRule HopRewriteUtils.addChildReference(parent, newLix, posp); } - appliedRow = true; - LOG.debug("Applied vectorizeLeftIndexingCol"); + ret = newLix; + LOG.debug("Applied vectorizeLeftIndexingCol for hop "+hop.getHopID()); } } } + + return ret; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 75f11b5..2b4bfa0 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -1547,7 +1547,7 @@ public class DMLTranslator if( sourceOp.getDataType() == DataType.MATRIX && source.getOutput().getDataType() 
== DataType.SCALAR ) sourceOp.setDataType(DataType.SCALAR); - Hop leftIndexOp = new LeftIndexingOp(target.getName(), target.getDataType(), target.getValueType(), + Hop leftIndexOp = new LeftIndexingOp(target.getName(), target.getDataType(), ValueType.DOUBLE, targetOp, sourceOp, rowLowerHops, rowUpperHops, colLowerHops, colUpperHops, target.getRowLowerEqualsUpper(), target.getColLowerEqualsUpper()); http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java index 2e03ce6..788c73d 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java @@ -43,7 +43,7 @@ public class AlgorithmL2SVM extends AutomatedTestBase private final static double eps = 1e-5; - private final static int rows = 1468; + private final static int rows = 3468; private final static int cols1 = 1007; private final static int cols2 = 987;
