Repository: systemml
Updated Branches:
  refs/heads/master 2869d53ce -> 92ecae34d


[SYSTEMML-1662] Fix rewrite issues discovered by extended hop validator

This fix pack addresses various issues discovered by the extended HOP
DAG validator:

(1) Forced visit status reset after static/dynamic rewrites (some of
these simplification rewrites create "holes" of new operators with
visited=false, which is problematic if the operators above and below are
already marked as visited because a subsequent reset will stop at the
"hole" - we now force a full dag traversal after these rewrites)

(2) Fix visit status handling in common subexpression elimination (the
merge of leaf and inner nodes was incorrect as it did not mark merged
operators as visited, leading again to "holes" in the dag)

(3) Fix visit status handling in left indexing vectorization (similar to
1 and 2, but for this rewrite it originated from not considering replaced
operators for recursive traversal)

(4) Fix value type handling of left indexing expressions at parser
(validation) level, which previously created corrupted hops for left
indexing.

There are additional issues related to data and value type handling in
the mlcontext and jmlc APIs but these will be addressed once the HOP DAG
validator is actually pushed to master.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/92ecae34
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/92ecae34
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/92ecae34

Branch: refs/heads/master
Commit: 92ecae34db63559131bf50f747650866c25ff4e1
Parents: 2869d53
Author: Matthias Boehm <[email protected]>
Authored: Thu Jun 8 23:12:10 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Jun 9 12:17:41 2017 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/Hop.java    | 23 +++++++-
 .../RewriteAlgebraicSimplificationDynamic.java  |  6 +--
 .../RewriteAlgebraicSimplificationStatic.java   | 22 +++-----
 .../RewriteCommonSubexpressionElimination.java  |  3 ++
 .../rewrite/RewriteIndexingVectorization.java   | 57 +++++++++++++-------
 .../org/apache/sysml/parser/DMLTranslator.java  |  2 +-
 .../functions/codegen/AlgorithmL2SVM.java       |  2 +-
 7 files changed, 75 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java 
b/src/main/java/org/apache/sysml/hops/Hop.java
index 31e3aa6..742f22a 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops;
 
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.HashSet;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -865,13 +866,33 @@ public abstract class Hop
                                hopRoot.resetVisitStatus();
        }
        
+       public static void resetVisitStatus( ArrayList<Hop> hops, boolean force 
) {
+               if( !force )
+                       resetVisitStatus(hops);
+               else {
+                       HashSet<Long> memo = new HashSet<Long>();
+                       if( hops != null )
+                               for( Hop hopRoot : hops )
+                                       hopRoot.resetVisitStatusForced(memo);
+               }
+       }
+       
        public void resetVisitStatus()  {
                if( !isVisited() )
                        return;
-               for( Hop h : this.getInput() )
+               for( Hop h : getInput() )
                        h.resetVisitStatus();           
                setVisited(false);
        }
+       
+       public void resetVisitStatusForced(HashSet<Long> memo) {
+               if( memo.contains(getHopID()) )
+                       return;
+               for( Hop h : getInput() )
+                       h.resetVisitStatusForced(memo);
+               setVisited(false);
+               memo.add(getHopID());
+       }
 
        public static void resetRecompilationFlag( ArrayList<Hop> hops, 
ExecType et )
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 74e832b..ad80c05 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -94,13 +94,13 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                //one pass rewrite-descend (rewrite created pattern)
                for( Hop h : roots )
                        rule_AlgebraicSimplification( h, false );
+               Hop.resetVisitStatus(roots, true);
 
-               Hop.resetVisitStatus(roots);
-               
                //one pass descend-rewrite (for rollup) 
                for( Hop h : roots )
                        rule_AlgebraicSimplification( h, true );
-
+               Hop.resetVisitStatus(roots, true);
+               
                return roots;
        }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 17f1ace..d6b4559 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -81,12 +81,12 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                //one pass rewrite-descend (rewrite created pattern)
                for( Hop h : roots )
                        rule_AlgebraicSimplification( h, false );
-
-               Hop.resetVisitStatus(roots);
+               Hop.resetVisitStatus(roots, true);
                
                //one pass descend-rewrite (for rollup) 
                for( Hop h : roots )
                        rule_AlgebraicSimplification( h, true );
+               Hop.resetVisitStatus(roots, true);
                
                return roots;
        }
@@ -498,22 +498,16 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                if( bop.getOp()==OpOp2.PLUS ) //X+X -> X*2
                                {
                                        bop.setOp(OpOp2.MULT);
-                                       LiteralOp tmp = new LiteralOp(2);
-                                       bop.getInput().remove(1);
-                                       right.getParent().remove(bop);
-                                       HopRewriteUtils.addChildReference(hi, 
tmp, 1);
-
-                                       LOG.debug("Applied 
simplifyBinaryToUnaryOperation1");
+                                       
HopRewriteUtils.replaceChildReference(hi, right, new LiteralOp(2), 1);
+                                       
+                                       LOG.debug("Applied 
simplifyBinaryToUnaryOperation1 (line "+hi.getBeginLine()+").");
                                }
                                else if ( bop.getOp()==OpOp2.MULT ) //X*X -> X^2
                                {
                                        bop.setOp(OpOp2.POW);
-                                       LiteralOp tmp = new LiteralOp(2);
-                                       bop.getInput().remove(1);
-                                       right.getParent().remove(bop);
-                                       HopRewriteUtils.addChildReference(hi, 
tmp, 1);
+                                       
HopRewriteUtils.replaceChildReference(hi, right, new LiteralOp(2), 1);
                                        
-                                       LOG.debug("Applied 
simplifyBinaryToUnaryOperation2");
+                                       LOG.debug("Applied 
simplifyBinaryToUnaryOperation2 (line "+hi.getBeginLine()+").");
                                }
                        }
                        //patterns: (X>0)-(X<0) -> sign(X)
@@ -531,7 +525,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                HopRewriteUtils.cleanupUnreferenced(hi, left, 
right);
                                hi = uop;
                                
-                               LOG.debug("Applied 
simplifyBinaryToUnaryOperation3");
+                               LOG.debug("Applied 
simplifyBinaryToUnaryOperation3 (line "+hi.getBeginLine()+").");
                        }
                }
                

http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java
index 72957dd..5379dfe 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java
@@ -131,6 +131,7 @@ public class RewriteCommonSubexpressionElimination extends 
HopRewriteRule
                                        Hop tmp = dataops.get(hi.getName());
                                        if( tmp != hi ) { //if required
                                                tmp.getParent().add(hop);
+                                               tmp.setVisited();
                                                hop.getInput().set(i, tmp);
                                                ret++;
                                        }
@@ -142,6 +143,7 @@ public class RewriteCommonSubexpressionElimination extends 
HopRewriteRule
                                        //replace child node ref
                                        if( tmp != hi ){ //if required
                                                tmp.getParent().add(hop);
+                                               tmp.setVisited();
                                                hop.getInput().set(i, tmp);
                                                ret++;
                                        }
@@ -200,6 +202,7 @@ public class RewriteCommonSubexpressionElimination extends 
HopRewriteRule
                                                                {
                                                                        
p.getInput().set(k, h1);
                                                                        
h1.getParent().add(p);
+                                                                       
h1.setVisited();
                                                                }
                                                
                                                //replace h2 w/ h1 in h2-input 
parents

http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java
index 4ce1d43..b797e00 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java
@@ -50,7 +50,7 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
        {
                if( roots == null )
                        return roots;
-
+               
                for( Hop h : roots )
                        rule_IndexingVectorization( h );
                
@@ -84,9 +84,9 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
                        //MB: disabled right indexing rewrite because (1) 
piggybacked in MR anyway, (2) usually
                        //not too much overhead, and (3) makes literal 
replacement more difficult
                        //vectorizeRightIndexing( hi ); //e.g., multiple 
rightindexing X[i,1], X[i,3] -> X[i,];
-                       vectorizeLeftIndexing( hi );  //e.g., multiple left 
indexing X[i,1], X[i,3] -> X[i,]; 
+                       hi = vectorizeLeftIndexing( hi );  //e.g., multiple 
left indexing X[i,1], X[i,3] -> X[i,]; 
                        
-                       //process childs recursively after rewrites 
+                       //process childs recursively after rewrites
                        rule_IndexingVectorization( hi );
                }
 
@@ -189,9 +189,11 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
        }
        
        @SuppressWarnings("unchecked")
-       private void vectorizeLeftIndexing( Hop hop )
+       private Hop vectorizeLeftIndexing( Hop hop )
                throws HopsException
-       {               
+       {
+               Hop ret = hop;
+               
                if( hop instanceof LeftIndexingOp ) //left indexing
                {
                        LeftIndexingOp ihop0 = (LeftIndexingOp) hop;
@@ -224,11 +226,14 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
                                        Hop rowExpr = ihop0.getInput().get(2); 
//keep before reset
                                        
                                        //new row indexing operator
-                                       IndexingOp newRix = new 
IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, 
-                                                                   rowExpr, 
rowExpr, new LiteralOp(1), 
-                                                                   
HopRewriteUtils.createValueHop(input, false), true, false); 
+                                       IndexingOp newRix = new 
IndexingOp("tmp1", input.getDataType(), input.getValueType(), 
+                                               input, rowExpr, rowExpr, new 
LiteralOp(1), 
+                                               
HopRewriteUtils.createValueHop(input, false), true, false); 
                                        
HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), 
input.getColsInBlock(), -1);
                                        newRix.refreshSizeInformation();
+                                       //reset visit status of copied hops 
(otherwise hidden by left indexing)
+                                       for( Hop c : newRix.getInput() )
+                                               c.resetVisitStatus();
                                        
                                        //rewrite bottom left indexing operator
                                        
HopRewriteUtils.removeChildReference(current, input); //input data
@@ -253,11 +258,14 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
                                                ihop0parentsPos.add(posp);
                                        }
                                        
-                                       LeftIndexingOp newLix = new 
LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, 
-                                                                               
                        rowExpr, rowExpr, new LiteralOp(1), 
-                                                                               
                        HopRewriteUtils.createValueHop(input, false), true, 
false); 
+                                       LeftIndexingOp newLix = new 
LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), 
+                                               input, ihop0, rowExpr, rowExpr, 
new LiteralOp(1), 
+                                               
HopRewriteUtils.createValueHop(input, false), true, false); 
                                        
HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), 
input.getColsInBlock(), -1);
                                        newLix.refreshSizeInformation();
+                                       //reset visit status of copied hops 
(otherwise hidden by left indexing)
+                                       for( Hop c : newLix.getInput() )
+                                               c.resetVisitStatus();
                                        
                                        for( int i=0; i<ihop0parentsPos.size(); 
i++ ) {
                                                Hop parent = 
ihop0parents.get(i);
@@ -266,7 +274,8 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
                                        }
                                        
                                        appliedRow = true;
-                                       LOG.debug("Applied 
vectorizeLeftIndexingRow");
+                                       ret = newLix;
+                                       LOG.debug("Applied 
vectorizeLeftIndexingRow for hop "+hop.getHopID());
                                }
                        }
                        
@@ -296,11 +305,14 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
                                        Hop colExpr = ihop0.getInput().get(4); 
//keep before reset
                                        
                                        //new row indexing operator
-                                       IndexingOp newRix = new 
IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, 
-                                                               new 
LiteralOp(1), HopRewriteUtils.createValueHop(input, true),            
-                                                                       
colExpr, colExpr, false, true); 
+                                       IndexingOp newRix = new 
IndexingOp("tmp1", input.getDataType(), input.getValueType(), 
+                                               input, new LiteralOp(1), 
HopRewriteUtils.createValueHop(input, true),            
+                                               colExpr, colExpr, false, true); 
                                        
HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), 
input.getColsInBlock(), -1);
                                        newRix.refreshSizeInformation();
+                                       //reset visit status of copied hops 
(otherwise hidden by left indexing)
+                                       for( Hop c : newRix.getInput() )
+                                               c.resetVisitStatus();
                                        
                                        //rewrite bottom left indexing operator
                                        
HopRewriteUtils.removeChildReference(current, input); //input data
@@ -325,11 +337,14 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
                                                ihop0parentsPos.add(posp);
                                        }
                                        
-                                       LeftIndexingOp newLix = new 
LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, 
-                                                                               
new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), 
-                                                                               
                        colExpr, colExpr, false, true); 
+                                       LeftIndexingOp newLix = new 
LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), 
+                                               input, ihop0, new LiteralOp(1), 
HopRewriteUtils.createValueHop(input, true), 
+                                               colExpr, colExpr, false, true); 
                                        
HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), 
input.getColsInBlock(), -1);
                                        newLix.refreshSizeInformation();
+                                       //reset visit status of copied hops 
(otherwise hidden by left indexing)
+                                       for( Hop c : newLix.getInput() )
+                                               c.resetVisitStatus();
                                        
                                        for( int i=0; i<ihop0parentsPos.size(); 
i++ ) {
                                                Hop parent = 
ihop0parents.get(i);
@@ -337,10 +352,12 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
                                                
HopRewriteUtils.addChildReference(parent, newLix, posp);
                                        }
                                        
-                                       appliedRow = true;
-                                       LOG.debug("Applied 
vectorizeLeftIndexingCol");
+                                       ret = newLix;
+                                       LOG.debug("Applied 
vectorizeLeftIndexingCol for hop "+hop.getHopID());
                                }
                        }
                }
+               
+               return ret;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 75f11b5..2b4bfa0 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -1547,7 +1547,7 @@ public class DMLTranslator
                if( sourceOp.getDataType() == DataType.MATRIX && 
source.getOutput().getDataType() == DataType.SCALAR )
                        sourceOp.setDataType(DataType.SCALAR);
                
-               Hop leftIndexOp = new LeftIndexingOp(target.getName(), 
target.getDataType(), target.getValueType(), 
+               Hop leftIndexOp = new LeftIndexingOp(target.getName(), 
target.getDataType(), ValueType.DOUBLE, 
                                targetOp, sourceOp, rowLowerHops, rowUpperHops, 
colLowerHops, colUpperHops, 
                                target.getRowLowerEqualsUpper(), 
target.getColLowerEqualsUpper());
                

http://git-wip-us.apache.org/repos/asf/systemml/blob/92ecae34/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java
index 2e03ce6..788c73d 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmL2SVM.java
@@ -43,7 +43,7 @@ public class AlgorithmL2SVM extends AutomatedTestBase
        
        private final static double eps = 1e-5;
        
-       private final static int rows = 1468;
+       private final static int rows = 3468;
        private final static int cols1 = 1007;
        private final static int cols2 = 987;
        

Reply via email to