Repository: systemml Updated Branches: refs/heads/master 07bab605e -> e1efd844d
[SYSTEMML-2185] Improved dag-split rewrite (avoid redundant cuts) This patch removes redundancy (i.e., unnecessary artificial dag-cut twrite/tread pairs) that arises when multiple operators consume the same input and that input is split into another dag. For example, on stratstats, this patch reduced the number of cuts by more than 2x. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/bcaa140d Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/bcaa140d Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/bcaa140d Branch: refs/heads/master Commit: bcaa140d5b43bf75fd97ae036f15863a154021b0 Parents: 07bab60 Author: Matthias Boehm <[email protected]> Authored: Thu Mar 15 13:57:20 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Mar 15 13:57:20 2018 -0700 ---------------------------------------------------------------------- .../RewriteSplitDagDataDependentOperators.java | 71 +++++++++++--------- 1 file changed, 38 insertions(+), 33 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/bcaa140d/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java index ebf275b..a3e037f 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java @@ -21,6 +21,7 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -112,9 +113,9 @@ public class 
RewriteSplitDagDataDependentOperators extends StatementBlockRewrite //unless there are transient reads w/ the same variable name in the current dag which can //lead to invalid reordering if variable consumers are not feeding into the candidate op. boolean hasTWrites = hasTransientWriteParents(c); - boolean moveTWrite = hasTWrites ? HopRewriteUtils.rHasSimpleReadChain(c, - getFirstTransientWriteParent(c).getName()) : false; - + boolean moveTWrite = hasTWrites ? HopRewriteUtils.rHasSimpleReadChain( + c, getFirstTransientWriteParent(c).getName()) : false; + String varname = null; long rlen = c.getDim1(); long clen = c.getDim2(); @@ -124,7 +125,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite long bclen = c.getColsInBlock(); if( hasTWrites && moveTWrite) //reuse existing transient_write - { + { Hop twrite = getFirstTransientWriteParent(c); varname = twrite.getName(); @@ -170,12 +171,12 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite } //add data-dependent operator sub dag to first statement block - DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), - c, DataOpTypes.TRANSIENTWRITE, null); + DataOp twrite = new DataOp(varname, c.getDataType(), + c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null); twrite.setVisited(); twrite.setOutputParams(rlen, clen, nnz, update, brlen, bclen); HopRewriteUtils.copyLineNumbers(c, twrite); - sb1hops.add(twrite); + sb1hops.add(twrite); } //update live in and out of new statement block (for piggybacking) @@ -349,39 +350,43 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite for( Hop h : rootsSB2 ) rProbeAndAddHopsToCandidateSet(h, probeSet, candSet); - //step 3: create additional cuts - for( Pair<Hop,Hop> p : candSet ) - { - String varname = createCutVarName(false); - + //step 3: create additional cuts with reuse for common references + HashMap<Long, DataOp> reuseTRead = new HashMap<>(); + for( Pair<Hop,Hop> p : 
candSet ) { Hop hop = p.getKey(); Hop c = p.getValue(); - - DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), DataOpTypes.TRANSIENTREAD, - null, c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock()); - tread.setVisited(); - HopRewriteUtils.copyLineNumbers(c, tread); - - DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null); - twrite.setVisited(); - twrite.setOutputParams(c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock()); - HopRewriteUtils.copyLineNumbers(c, twrite); + + DataOp tread = reuseTRead.get(c.getHopID()); + if( tread == null ) { + String varname = createCutVarName(false); + + tread = new DataOp(varname, c.getDataType(), c.getValueType(), DataOpTypes.TRANSIENTREAD, null, + c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock()); + tread.setVisited(); + HopRewriteUtils.copyLineNumbers(c, tread); + reuseTRead.put(c.getHopID(), tread); + + DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null); + twrite.setVisited(); + twrite.setOutputParams(c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock()); + HopRewriteUtils.copyLineNumbers(c, twrite); + + //update live in and out of new statement block (for piggybacking) + DataIdentifier diVar = new DataIdentifier(varname); + diVar.setDimensions(c.getDim1(), c.getDim2()); + diVar.setBlockDimensions(c.getRowsInBlock(), c.getColsInBlock()); + diVar.setDataType(c.getDataType()); + diVar.setValueType(c.getValueType()); + sb1out.addVariable(varname, new DataIdentifier(diVar)); + sb2in.addVariable(varname, new DataIdentifier(diVar)); + + rootsSB1.add(twrite); + } //create additional cut by rewriting both hop dags int pos = HopRewriteUtils.getChildReferencePos(hop, c); HopRewriteUtils.removeChildReferenceByPos(hop, c, pos); 
HopRewriteUtils.addChildReference(hop, tread, pos); - - //update live in and out of new statement block (for piggybacking) - DataIdentifier diVar = new DataIdentifier(varname); - diVar.setDimensions(c.getDim1(), c.getDim2()); - diVar.setBlockDimensions(c.getRowsInBlock(), c.getColsInBlock()); - diVar.setDataType(c.getDataType()); - diVar.setValueType(c.getValueType()); - sb1out.addVariable(varname, new DataIdentifier(diVar)); - sb2in.addVariable(varname, new DataIdentifier(diVar)); - - rootsSB1.add(twrite); } }
