Repository: systemml Updated Branches: refs/heads/master 92034e64f -> 9481bef4e
[SYSTEMML-2157] Fix codegen optimizer (suboptimal plans after row2cell) The cost-based codegen optimizer converts all partial row fusion plans into cell plans if none of the operations requires access to entire rows. However, the existing implementation of this pre-processing step led to suboptimal plans for special cases. This patch completely reworks this analysis step, which also improves its performance by using a single pass over the sub-DAG of each fusion partition. We now also properly track all operations and plans for which this row2cell conversion is inapplicable. Finally, the row template has been extended to allow unary operations in opening conditions (unless these operations work over row vectors). Together, these modifications led to a runtime improvement for autoencoder over mnist1m from 446s to 373s (~600s without codegen). Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/14ea51be Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/14ea51be Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/14ea51be Branch: refs/heads/master Commit: 14ea51be70ede04dfd3d351205b5ab19f1109d91 Parents: 92034e6 Author: Matthias Boehm <[email protected]> Authored: Wed Feb 21 21:12:50 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Wed Feb 21 21:12:50 2018 -0800 ---------------------------------------------------------------------- .../sysml/hops/codegen/cplan/CNodeUnary.java | 6 +- .../opt/PlanSelectionFuseCostBasedV2.java | 100 +++++++++++++------ .../hops/codegen/template/CPlanMemoTable.java | 9 ++ .../hops/codegen/template/TemplateRow.java | 10 +- 4 files changed, 89 insertions(+), 36 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/14ea51be/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index a1401c3..d7721a1 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -216,10 +216,12 @@ public class CNodeUnary extends CNode String varj = _inputs.get(0).getVarname(); //replace sparse and dense inputs + boolean vectIn = varj.startsWith("b") && !_type.isScalarLookup(); tmp = tmp.replace("%IN1v%", varj+"vals"); tmp = tmp.replace("%IN1i%", varj+"ix"); - tmp = tmp.replace("%IN1%", varj.startsWith("b") && !_type.isScalarLookup() - && TemplateUtils.isMatrix(_inputs.get(0)) ? varj + ".values(rix)" : varj ); + tmp = tmp.replace("%IN1%", + (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? varj + ".values(rix)" : + (vectIn && TemplateUtils.isRowVector(_inputs.get(0)) ? varj + ".values(0)" : varj)); //replace start position of main input String spos = (_inputs.get(0) instanceof CNodeData http://git-wip-us.apache.org/repos/asf/systemml/blob/14ea51be/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 84e4b4c..6ed562a 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -34,6 +34,7 @@ import java.util.stream.Collectors; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import 
org.apache.commons.logging.LogFactory; import org.apache.sysml.api.DMLScript; @@ -46,6 +47,8 @@ import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.hops.Hop.OpOpN; +import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.OptimizerUtils; @@ -635,37 +638,71 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection } } - private static HashSet<Long> getRowAggOpsWithRowRef(CPlanMemoTable memo, PlanPartition part) { - HashSet<Long> refAggs = new HashSet<>(); - for( Long hopID : part.getPartition() ) { - if( !memo.contains(hopID, TemplateType.ROW) ) continue; - MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW); - for(int i=0; i<3; i++) - if( me.isPlanRef(i) && memo.contains(me.input(i), TemplateType.ROW) - && isRowAggOp(memo.getHopRefs().get(me.input(i)))) - refAggs.add(me.input(i)); + private static HashSet<Long> collectIrreplaceableRowOps(CPlanMemoTable memo, PlanPartition part) { + //get row entries that are (a) reachable from rowwise ops (top down) other than + //operator root nodes, or dependent upon row-wise ops (bottom up) + HashSet<Long> blacklist = new HashSet<>(); + HashSet<Pair<Long, Integer>> visited = new HashSet<>(); + for( Long hopID : part.getRoots() ) { + rCollectDependentRowOps(memo.getHopRefs().get(hopID), + memo, part, blacklist, visited, null, false); } - return refAggs; + return blacklist; } - private static boolean rIsRowTemplateWithoutAggOrVects(CPlanMemoTable memo, Hop current, HashSet<Long> visited, boolean inclRoot) { - if( visited.contains(current.getHopID()) ) - return true; + private static void rCollectDependentRowOps(Hop hop, CPlanMemoTable memo, PlanPartition part, + HashSet<Long> blacklist, HashSet<Pair<Long, Integer>> visited, TemplateType type, boolean foundRowOp) + { + //avoid 
redundant evaluation of processed and non-partition nodes + Pair<Long, Integer> key = Pair.of(hop.getHopID(), + (foundRowOp?Short.MAX_VALUE:0) + ((type!=null)?type.ordinal()+1:0)); + if( visited.contains(key) || !part.getPartition().contains(hop.getHopID()) ) { + return; + } + + //process node itself (top-down) + MemoTableEntry me = (type == null) ? memo.getBest(hop.getHopID()) : + memo.getBest(hop.getHopID(), type); + boolean inRow = (me != null && me.type == TemplateType.ROW && type == TemplateType.ROW); + boolean diffPlans = part.getMatPointsExt().length > 0 //guard against plan differences + && memo.contains(hop.getHopID(), TemplateType.ROW) + && !memo.hasOnlyExactMatches(hop.getHopID(), TemplateType.ROW, TemplateType.CELL); + if( inRow && foundRowOp ) + blacklist.add(hop.getHopID()); + if( isRowAggOp(hop, inRow) || diffPlans ) { + blacklist.add(hop.getHopID()); + foundRowOp = true; + } - MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW); - boolean ret = !inclRoot || !isRowAggOp(current); - for(int i=0; i<3 && ret; i++) - if( me!=null && me.isPlanRef(i) ) - ret &= rIsRowTemplateWithoutAggOrVects(memo, - current.getInput().get(i), visited, true); + //process children recursively + for( int i=0; i<hop.getInput().size(); i++ ) { + boolean lfoundRowOp = foundRowOp && me != null + && (me.isPlanRef(i) || isImplicitlyFused(hop, i, me.type)); + rCollectDependentRowOps(hop.getInput().get(i), memo, + part, blacklist, visited, me!=null?me.type:null, lfoundRowOp); + } - visited.add(current.getHopID()); - return ret; + //process node itself (bottom-up) + if( !blacklist.contains(hop.getHopID()) ) { + for( int i=0; i<hop.getInput().size(); i++ ) + if( me != null && me.type == TemplateType.ROW + && (me.isPlanRef(i) || isImplicitlyFused(hop, i, me.type)) + && blacklist.contains(hop.getInput().get(i).getHopID()) ) { + blacklist.add(hop.getHopID()); + } + } + + visited.add(key); } - private static boolean isRowAggOp(Hop hop){ - return (hop instanceof 
AggUnaryOp || hop instanceof AggBinaryOp - || HopRewriteUtils.isBinary(hop, OpOp2.CBIND)); + private static boolean isRowAggOp(Hop hop, boolean inRow) { + return HopRewriteUtils.isBinary(hop, OpOp2.CBIND) + || HopRewriteUtils.isNary(hop, OpOpN.CBIND) + || (hop instanceof AggBinaryOp && (inRow || !hop.dimsKnown() + || (hop.getDim1()!=1 && hop.getDim2()!=1))) + || (HopRewriteUtils.isReorg(hop, ReOrgOp.TRANSPOSE) + && (hop.getDim1()!=1 && hop.getDim2()!=1)) + || (hop instanceof AggUnaryOp && inRow); } private static boolean isValidRow2CellOp(Hop hop) { @@ -704,16 +741,19 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection } //prune row aggregates with pure cellwise operations - HashSet<Long> refAggs = getRowAggOpsWithRowRef(memo, part); + //(we determine a blacklist of all operators in a partition that either + //depend upon row aggregates or on which row aggregates depend) + HashSet<Long> blacklist = collectIrreplaceableRowOps(memo, part); for( Long hopID : part.getPartition() ) { + if( blacklist.contains(hopID) ) continue; MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW); - if( me != null && me.type == TemplateType.ROW && memo.contains(hopID, me, TemplateType.CELL) - && rIsRowTemplateWithoutAggOrVects(memo, memo.getHopRefs().get(hopID), new HashSet<Long>(), refAggs.contains(hopID)) ) { - List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW); - memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(blacklist)); + if( me != null && me.type == TemplateType.ROW + && memo.hasOnlyExactMatches(hopID, TemplateType.ROW, TemplateType.CELL) ) { + List<MemoTableEntry> rmList = memo.get(hopID, TemplateType.ROW); + memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(rmList)); if( LOG.isTraceEnabled() ) { LOG.trace("Removed row memo table entries w/o aggregation: " - + Arrays.toString(blacklist.toArray(new MemoTableEntry[0]))); + + Arrays.toString(rmList.toArray(new MemoTableEntry[0]))); } } } 
http://git-wip-us.apache.org/repos/asf/systemml/blob/14ea51be/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index 0c3bb90..5c90ca0 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -107,6 +107,15 @@ public class CPlanMemoTable && p.isValid() && !types.contains(p.type)); } + public boolean hasOnlyExactMatches(long hopID, TemplateType type1, TemplateType type2) { + List<MemoTableEntry> l1 = get(hopID, type1); + List<MemoTableEntry> l2 = get(hopID, type2); + boolean ret = l1.size() == l2.size(); + for( MemoTableEntry me : l1 ) + ret &= l2.stream().anyMatch(p -> p.equalPlanRefs(me)); + return ret; + } + public int countEntries(long hopID) { return get(hopID).size(); } http://git-wip-us.apache.org/repos/asf/systemml/blob/14ea51be/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index d54cf63..6c141ed 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -82,6 +82,8 @@ public class TemplateRow extends TemplateBase public boolean open(Hop hop) { return (hop instanceof BinaryOp && hop.dimsKnown() && isValidBinaryOperation(hop) && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) + || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) + && TemplateCell.isValidOperation(hop) && hop.getDim1() > 1) || 
(HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) || (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1 //MV @@ -95,7 +97,7 @@ public class TemplateRow extends TemplateBase && hop.getParent().get(0) instanceof AggBinaryOp && hop.getParent().get(0).dimsKnown() && hop.getParent().get(0).getInput().indexOf(hop) == 0 && isFuseSkinnyMatrixMult(hop.getParent().get(0))) - || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol + || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1 && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) || (hop instanceof IndexingOp && hop.getInput().get(0).getDim2() >= 0 @@ -337,7 +339,7 @@ public class TemplateRow extends TemplateBase CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); // if one input is a matrix then we need to do vector by scalar operations - if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 + if(hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1 || (!hop.dimsKnown() && cdata1.getDataType()==DataType.MATRIX ) ) { if( HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY) ) { @@ -381,8 +383,8 @@ public class TemplateRow extends TemplateBase CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); // if one input is a matrix then we need to do vector by scalar operations - if( (hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1) - || (hop.getInput().get(1).getDim1() > 1 && hop.getInput().get(1).getDim2() > 1) + if( (hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1) + || (hop.getInput().get(1).getDim1() >= 1 && hop.getInput().get(1).getDim2() > 1) || (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && 
hop.getInput().get(1).dimsKnown()) && (hop.getDim2() != 1) //not a known vector output && (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix())))
