[SYSTEMML-1943] Fix codegen fuse_all optimizer and consolidation This patch fixes special cases of row operations that caused the fuse_all optimizer to fail on Kmeans. Furthermore, it also includes a cleanup for consolidating the fuse-all selection of plans as used in the fuse_all and both cost-based optimizers.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c27c488b Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c27c488b Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c27c488b Branch: refs/heads/master Commit: c27c488bef54887d549792c4cf6532d95c3f5c58 Parents: 8ed2516 Author: Matthias Boehm <[email protected]> Authored: Sun Oct 1 20:04:39 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Oct 2 00:39:21 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/conf/DMLConfig.java | 2 +- .../apache/sysml/hops/codegen/SpoofFusedOp.java | 11 +++++ .../sysml/hops/codegen/cplan/CNodeRow.java | 3 +- .../sysml/hops/codegen/opt/PlanSelection.java | 46 +++++++++++++++++++ .../hops/codegen/opt/PlanSelectionFuseAll.java | 47 +------------------- .../codegen/opt/PlanSelectionFuseCostBased.java | 45 +------------------ .../opt/PlanSelectionFuseCostBasedV2.java | 47 +------------------- .../hops/codegen/template/TemplateUtils.java | 2 + .../sysml/runtime/codegen/SpoofRowwise.java | 25 ++++++----- 9 files changed, 79 insertions(+), 149 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/conf/DMLConfig.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java b/src/main/java/org/apache/sysml/conf/DMLConfig.java index 6a331a6..9835b4d 100644 --- a/src/main/java/org/apache/sysml/conf/DMLConfig.java +++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java @@ -127,7 +127,7 @@ public class DMLConfig _defaultVals.put(COMPRESSED_LINALG, Compression.CompressConfig.AUTO.name() ); _defaultVals.put(CODEGEN, "false" ); _defaultVals.put(CODEGEN_COMPILER, CompilerType.AUTO.name() ); - _defaultVals.put(CODEGEN_COMPILER, 
PlanSelector.FUSE_COST_BASED_V2.name() ); + _defaultVals.put(CODEGEN_OPTIMIZER, PlanSelector.FUSE_COST_BASED_V2.name() ); _defaultVals.put(CODEGEN_PLANCACHE, "true" ); _defaultVals.put(CODEGEN_LITERALS, "1" ); _defaultVals.put(NATIVE_BLAS, "none" ); http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java index 81b226d..56bfb61 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java @@ -42,6 +42,7 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop ROW_DIMS, COLUMN_DIMS_ROWS, COLUMN_DIMS_COLS, + RANK_DIMS_COLS, SCALAR, MULTI_SCALAR, ROW_RANK_DIMS, // right wdivmm, row mm @@ -163,6 +164,12 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop case COLUMN_DIMS_COLS: ret = new long[]{1, mc.getCols(), -1}; break; + case RANK_DIMS_COLS: { + MatrixCharacteristics mc2 = memo.getAllInputStats(getInput().get(1)); + if( mc2.dimsKnown() ) + ret = new long[]{1, mc2.getCols(), -1}; + break; + } case INPUT_DIMS: ret = new long[]{mc.getRows(), mc.getCols(), -1}; break; @@ -219,6 +226,10 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop setDim1(1); setDim2(getInput().get(0).getDim2()); break; + case RANK_DIMS_COLS: + setDim1(1); + setDim2(getInput().get(1).getDim2()); + break; case INPUT_DIMS: setDim1(getInput().get(0).getDim1()); setDim2(getInput().get(0).getDim2()); http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java index 07822d9..9235216 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java @@ -158,7 +158,8 @@ public class CNodeRow extends CNodeTpl case COL_AGG: return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector case COL_AGG_T: return SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector case COL_AGG_B1: return SpoofOutputDimsType.COLUMN_RANK_DIMS; - case COL_AGG_B1_T: return SpoofOutputDimsType.COLUMN_RANK_DIMS_T; + case COL_AGG_B1_T: return SpoofOutputDimsType.COLUMN_RANK_DIMS_T; + case COL_AGG_B1R: return SpoofOutputDimsType.RANK_DIMS_COLS; default: throw new RuntimeException("Unsupported row type: "+_type.toString()); } http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java index 21f4fd3..4cf56c4 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java @@ -34,6 +34,9 @@ import org.apache.sysml.runtime.util.UtilFunctions; public abstract class PlanSelection { + private static final BasicPlanComparator BASE_COMPARE = new BasicPlanComparator(); + private final TypedPlanComparator _typedCompare = new TypedPlanComparator(); + private final HashMap<Long, List<MemoTableEntry>> _bestPlans = new HashMap<Long, List<MemoTableEntry>>(); private final HashSet<VisitMark> _visited = new HashSet<VisitMark>(); @@ -84,6 +87,49 @@ public abstract class PlanSelection _visited.add(new VisitMark(hopID, type)); } + protected void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition) + { + if( 
isVisited(current.getHopID(), currentType) + || (partition!=null && !partition.contains(current.getHopID())) ) + return; + + //step 1: prune subsumed plans of same type + if( memo.contains(current.getHopID()) ) { + HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); + List<MemoTableEntry> hopP = memo.get(current.getHopID()); + for( MemoTableEntry e1 : hopP ) + for( MemoTableEntry e2 : hopP ) + if( e1 != e2 && e1.subsumes(e2) ) + rmSet.add(e2); + memo.remove(current, rmSet); + } + + //step 2: select plan for current path + MemoTableEntry best = null; + if( memo.contains(current.getHopID()) ) { + if( currentType == null ) { + best = memo.get(current.getHopID()).stream() + .filter(p -> isValid(p, current)) + .min(BASE_COMPARE).orElse(null); + } + else { + _typedCompare.setType(currentType); + best = memo.get(current.getHopID()).stream() + .filter(p -> p.type==currentType || p.type==TemplateType.CELL) + .min(_typedCompare).orElse(null); + } + addBestPlan(current.getHopID(), best); + } + + //step 3: recursively process children + for( int i=0; i< current.getInput().size(); i++ ) { + TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null; + rSelectPlansFuseAll(memo, current.getInput().get(i), pref, partition); + } + + setVisited(current.getHopID(), currentType); + } + /** * Basic plan comparator to compare memo table entries with regard to * a pre-defined template preference order and the number of references. 
http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java index 8636bea..3e0561d 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java @@ -20,15 +20,12 @@ package org.apache.sysml.hops.codegen.opt; import java.util.ArrayList; -import java.util.Comparator; import java.util.Map.Entry; -import java.util.HashSet; import java.util.List; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.codegen.template.CPlanMemoTable; import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; -import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; /** * This plan selection heuristic aims for maximal fusion, which @@ -43,52 +40,10 @@ public class PlanSelectionFuseAll extends PlanSelection public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) { //pruning and collection pass for( Hop hop : roots ) - rSelectPlans(memo, hop, null); + rSelectPlansFuseAll(memo, hop, null, null); //take all distinct best plans for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() ) memo.setDistinct(e.getKey(), e.getValue()); } - - private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType) - { - if( isVisited(current.getHopID(), currentType) ) - return; - - //step 1: prune subsumed plans of same type - if( memo.contains(current.getHopID()) ) { - HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); - List<MemoTableEntry> hopP = memo.get(current.getHopID()); - for( MemoTableEntry e1 : hopP ) - for( MemoTableEntry e2 : hopP ) - if( e1 != e2 && e1.subsumes(e2) ) - rmSet.add(e2); - 
memo.remove(current, rmSet); - } - - //step 2: select plan for current path - MemoTableEntry best = null; - if( memo.contains(current.getHopID()) ) { - if( currentType == null ) { - best = memo.get(current.getHopID()).stream() - .filter(p -> isValid(p, current)) - .min(new BasicPlanComparator()).orElse(null); - } - else { - best = memo.get(current.getHopID()).stream() - .filter(p -> p.type==currentType || p.type==TemplateType.CELL) - .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs())) - .orElse(null); - } - addBestPlan(current.getHopID(), best); - } - - //step 3: recursively process children - for( int i=0; i< current.getInput().size(); i++ ) { - TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null; - rSelectPlans(memo, current.getInput().get(i), pref); - } - - setVisited(current.getHopID(), currentType); - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java index acb90e2..f67604d 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java @@ -507,52 +507,9 @@ public class PlanSelectionFuseCostBased extends PlanSelection } } - visited.add(current.getHopID()); + visited.add(current.getHopID()); } - private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition) - { - if( isVisited(current.getHopID(), currentType) - || !partition.contains(current.getHopID()) ) - return; - - //step 1: prune subsumed plans of same type - if( memo.contains(current.getHopID()) ) { - HashSet<MemoTableEntry> rmSet = new 
HashSet<MemoTableEntry>(); - List<MemoTableEntry> hopP = memo.get(current.getHopID()); - for( MemoTableEntry e1 : hopP ) - for( MemoTableEntry e2 : hopP ) - if( e1 != e2 && e1.subsumes(e2) ) - rmSet.add(e2); - memo.remove(current, rmSet); - } - - //step 2: select plan for current path - MemoTableEntry best = null; - if( memo.contains(current.getHopID()) ) { - if( currentType == null ) { - best = memo.get(current.getHopID()).stream() - .filter(p -> isValid(p, current)) - .min(new BasicPlanComparator()).orElse(null); - } - else { - best = memo.get(current.getHopID()).stream() - .filter(p -> p.type==currentType || p.type==TemplateType.CELL) - .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs())) - .orElse(null); - } - addBestPlan(current.getHopID(), best); - } - - //step 3: recursively process children - for( int i=0; i< current.getInput().size(); i++ ) { - TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null; - rSelectPlansFuseAll(memo, current.getInput().get(i), pref, partition); - } - - setVisited(current.getHopID(), currentType); - } - private static boolean[] createAssignment(int len, int pos) { boolean[] ret = new boolean[len]; int tmp = pos; http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 8d1c4c0..31e8427 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -98,8 +98,6 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection private static final IDSequence COST_ID = new IDSequence(); private static final TemplateRow ROW_TPL = new 
TemplateRow(); - private static final BasicPlanComparator BASE_COMPARE = new BasicPlanComparator(); - private final TypedPlanComparator _typedCompare = new TypedPlanComparator(); @Override public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) @@ -726,50 +724,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection } } - visited.add(current.getHopID()); - } - - private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition) - { - if( isVisited(current.getHopID(), currentType) - || !partition.contains(current.getHopID()) ) - return; - - //step 1: prune subsumed plans of same type - if( memo.contains(current.getHopID()) ) { - HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); - List<MemoTableEntry> hopP = memo.get(current.getHopID()); - for( MemoTableEntry e1 : hopP ) - for( MemoTableEntry e2 : hopP ) - if( e1 != e2 && e1.subsumes(e2) ) - rmSet.add(e2); - memo.remove(current, rmSet); - } - - //step 2: select plan for current path - MemoTableEntry best = null; - if( memo.contains(current.getHopID()) ) { - if( currentType == null ) { - best = memo.get(current.getHopID()).stream() - .filter(p -> isValid(p, current)) - .min(BASE_COMPARE).orElse(null); - } - else { - _typedCompare.setType(currentType); - best = memo.get(current.getHopID()).stream() - .filter(p -> p.type==currentType || p.type==TemplateType.CELL) - .min(_typedCompare).orElse(null); - } - addBestPlan(current.getHopID(), best); - } - - //step 3: recursively process children - for( int i=0; i< current.getInput().size(); i++ ) { - TemplateType pref = (best!=null && best.isPlanRef(i))? 
best.type : null; - rSelectPlansFuseAll(memo, current.getInput().get(i), pref, partition); - } - - setVisited(current.getHopID(), currentType); + visited.add(current.getHopID()); } ///////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 06d83bd..4dc0bf2 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -204,6 +204,8 @@ public class TemplateUtils return RowType.COL_AGG_B1_T; else if( B1 != null && output.getDim1()==B1.getDim2() && output.getDim2()==X.getDim2()) return RowType.COL_AGG_B1; + else if( B1 != null && output.getDim1()==1 && B1.getDim2() == output.getDim2() ) + return RowType.COL_AGG_B1R; else if( X.getDim1() == output.getDim1() && X.getDim2() != output.getDim2() ) return RowType.NO_AGG_CONST; else http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java index 8b12e7e..311c27f 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java @@ -47,22 +47,25 @@ public abstract class SpoofRowwise extends SpoofOperator private static final long serialVersionUID = 6242910797139642998L; public enum RowType { - NO_AGG, //no aggregation - NO_AGG_B1, //no aggregation w/ matrix mult B1 + NO_AGG, //no aggregation + 
NO_AGG_B1, //no aggregation w/ matrix mult B1 NO_AGG_CONST, //no aggregation w/ expansion/contraction - FULL_AGG, //full row/col aggregation - ROW_AGG, //row aggregation (e.g., rowSums() or X %*% v) - COL_AGG, //col aggregation (e.g., colSums() or t(y) %*% X) - COL_AGG_T, //transposed col aggregation (e.g., t(X) %*% y) + FULL_AGG, //full row/col aggregation + ROW_AGG, //row aggregation (e.g., rowSums() or X %*% v) + COL_AGG, //col aggregation (e.g., colSums() or t(y) %*% X) + COL_AGG_T, //transposed col aggregation (e.g., t(X) %*% y) COL_AGG_B1, //col aggregation w/ matrix mult B1 - COL_AGG_B1_T; //transposed col aggregation w/ matrix mult B1 + COL_AGG_B1_T, //transposed col aggregation w/ matrix mult B1 + COL_AGG_B1R; //col aggregation w/ matrix mult B1 to row vector public boolean isColumnAgg() { - return (this == COL_AGG || this == COL_AGG_T) - || (this == COL_AGG_B1) || (this == COL_AGG_B1_T); + return this == COL_AGG || this == COL_AGG_T + || this == COL_AGG_B1 || this == COL_AGG_B1_T + || this == COL_AGG_B1R; } public boolean isRowTypeB1() { - return (this == NO_AGG_B1) || (this == COL_AGG_B1) || (this == COL_AGG_B1_T); + return this == NO_AGG_B1 || this == COL_AGG_B1 + || this == COL_AGG_B1_T || this == COL_AGG_B1R; } public boolean isRowTypeB1ColumnAgg() { return (this == COL_AGG_B1) || (this == COL_AGG_B1_T); @@ -268,7 +271,7 @@ public abstract class SpoofRowwise extends SpoofOperator case COL_AGG_T: out.reset(n, 1, false); break; case COL_AGG_B1: out.reset(n2, n, false); break; case COL_AGG_B1_T: out.reset(n, n2, false); break; - + case COL_AGG_B1R: out.reset(1, n2, false); break; } out.allocateDenseBlock(); }
