Repository: systemml
Updated Branches:
  refs/heads/master 07bab605e -> e1efd844d


[SYSTEMML-2185] Improved dag-split rewrite (avoid redundant cuts)

This patch removes redundant artificial dag cuts (i.e., unnecessary
twrite/tread pairs) that were created when multiple operators consume
the same input and that input is split into another dag. For example,
on stratstats, this patch reduces the number of cuts by more than 2x.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/bcaa140d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/bcaa140d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/bcaa140d

Branch: refs/heads/master
Commit: bcaa140d5b43bf75fd97ae036f15863a154021b0
Parents: 07bab60
Author: Matthias Boehm <[email protected]>
Authored: Thu Mar 15 13:57:20 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Mar 15 13:57:20 2018 -0700

----------------------------------------------------------------------
 .../RewriteSplitDagDataDependentOperators.java  | 71 +++++++++++---------
 1 file changed, 38 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/bcaa140d/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index ebf275b..a3e037f 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 
@@ -112,9 +113,9 @@ public class RewriteSplitDagDataDependentOperators extends 
StatementBlockRewrite
                                        //unless there are transient reads w/ 
the same variable name in the current dag which can
                                        //lead to invalid reordering if 
variable consumers are not feeding into the candidate op.
                                        boolean hasTWrites = 
hasTransientWriteParents(c);
-                                       boolean moveTWrite = hasTWrites ? 
HopRewriteUtils.rHasSimpleReadChain(c, 
-                                                       
getFirstTransientWriteParent(c).getName()) : false;
-                                                       
+                                       boolean moveTWrite = hasTWrites ? 
HopRewriteUtils.rHasSimpleReadChain(
+                                               c, 
getFirstTransientWriteParent(c).getName()) : false;
+                                       
                                        String varname = null;
                                        long rlen = c.getDim1();
                                        long clen = c.getDim2();
@@ -124,7 +125,7 @@ public class RewriteSplitDagDataDependentOperators extends 
StatementBlockRewrite
                                        long bclen = c.getColsInBlock();
                                        
                                        if( hasTWrites && moveTWrite) //reuse 
existing transient_write
-                                       {               
+                                       {
                                                Hop twrite = 
getFirstTransientWriteParent(c);
                                                varname = twrite.getName();
                                                
@@ -170,12 +171,12 @@ public class RewriteSplitDagDataDependentOperators 
extends StatementBlockRewrite
                                                }
                                                
                                                //add data-dependent operator 
sub dag to first statement block
-                                               DataOp twrite = new 
DataOp(varname, c.getDataType(), c.getValueType(),
-                                                                               
   c, DataOpTypes.TRANSIENTWRITE, null);
+                                               DataOp twrite = new 
DataOp(varname, c.getDataType(),
+                                                       c.getValueType(), c, 
DataOpTypes.TRANSIENTWRITE, null);
                                                twrite.setVisited();
                                                twrite.setOutputParams(rlen, 
clen, nnz, update, brlen, bclen);
                                                
HopRewriteUtils.copyLineNumbers(c, twrite);
-                                               sb1hops.add(twrite);    
+                                               sb1hops.add(twrite);
                                        }
                                        
                                        //update live in and out of new 
statement block (for piggybacking)
@@ -349,39 +350,43 @@ public class RewriteSplitDagDataDependentOperators 
extends StatementBlockRewrite
                for( Hop h : rootsSB2 )
                        rProbeAndAddHopsToCandidateSet(h, probeSet, candSet);
                
-               //step 3: create additional cuts
-               for( Pair<Hop,Hop> p : candSet ) 
-               {
-                       String varname = createCutVarName(false);
-                       
+               //step 3: create additional cuts with reuse for common 
references
+               HashMap<Long, DataOp> reuseTRead = new HashMap<>();
+               for( Pair<Hop,Hop> p : candSet ) {
                        Hop hop = p.getKey();
                        Hop c = p.getValue();
-
-                       DataOp tread = new DataOp(varname, c.getDataType(), 
c.getValueType(), DataOpTypes.TRANSIENTREAD, 
-                                       null, c.getDim1(), c.getDim2(), 
c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock());
-                       tread.setVisited();
-                       HopRewriteUtils.copyLineNumbers(c, tread);
-
-                       DataOp twrite = new DataOp(varname, c.getDataType(), 
c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null);
-                       twrite.setVisited();
-                       twrite.setOutputParams(c.getDim1(), c.getDim2(), 
c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock());
-                       HopRewriteUtils.copyLineNumbers(c, twrite);
+                       
+                       DataOp tread = reuseTRead.get(c.getHopID());
+                       if( tread == null ) {
+                               String varname = createCutVarName(false);
+                               
+                               tread = new DataOp(varname, c.getDataType(), 
c.getValueType(), DataOpTypes.TRANSIENTREAD, null,
+                                       c.getDim1(), c.getDim2(), c.getNnz(), 
c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock());
+                               tread.setVisited();
+                               HopRewriteUtils.copyLineNumbers(c, tread);
+                               reuseTRead.put(c.getHopID(), tread);
+                               
+                               DataOp twrite = new DataOp(varname, 
c.getDataType(), c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null);
+                               twrite.setVisited();
+                               twrite.setOutputParams(c.getDim1(), 
c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), 
c.getColsInBlock());
+                               HopRewriteUtils.copyLineNumbers(c, twrite);
+                               
+                               //update live in and out of new statement block 
(for piggybacking)
+                               DataIdentifier diVar = new 
DataIdentifier(varname);
+                               diVar.setDimensions(c.getDim1(), c.getDim2());
+                               diVar.setBlockDimensions(c.getRowsInBlock(), 
c.getColsInBlock());
+                               diVar.setDataType(c.getDataType());
+                               diVar.setValueType(c.getValueType());
+                               sb1out.addVariable(varname, new 
DataIdentifier(diVar));
+                               sb2in.addVariable(varname, new 
DataIdentifier(diVar));
+                               
+                               rootsSB1.add(twrite);
+                       }
                        
                        //create additional cut by rewriting both hop dags 
                        int pos = HopRewriteUtils.getChildReferencePos(hop, c);
                        HopRewriteUtils.removeChildReferenceByPos(hop, c, pos);
                        HopRewriteUtils.addChildReference(hop, tread, pos);
-               
-                       //update live in and out of new statement block (for 
piggybacking)
-                       DataIdentifier diVar = new DataIdentifier(varname);
-                       diVar.setDimensions(c.getDim1(), c.getDim2());
-                       diVar.setBlockDimensions(c.getRowsInBlock(), 
c.getColsInBlock());
-                       diVar.setDataType(c.getDataType());
-                       diVar.setValueType(c.getValueType());
-                       sb1out.addVariable(varname, new DataIdentifier(diVar));
-                       sb2in.addVariable(varname, new DataIdentifier(diVar));
-                       
-                       rootsSB1.add(twrite);
                }
        }
 

Reply via email to