[SYSTEMML-1288] Extended code generator (multi-agg across partitions) Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/9820f4c5 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/9820f4c5 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/9820f4c5
Branch: refs/heads/master Commit: 9820f4c5293c69873f68544748507b6473948f12 Parents: 7c15339 Author: Matthias Boehm <[email protected]> Authored: Thu Apr 6 20:57:07 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Apr 6 21:16:31 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/cplan/CNodeMultiAgg.java | 5 +- .../template/PlanSelectionFuseCostBased.java | 54 +++++++++++++++++++- .../hops/codegen/template/TemplateCell.java | 3 +- 3 files changed, 59 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9820f4c5/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java index 7ec07a6..d9502be 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java @@ -106,7 +106,10 @@ public class CNodeMultiAgg extends CNodeTpl for( int i=0; i<_outputs.size(); i++ ) { CNode out = _outputs.get(i); String tmpOut = getAggTemplate(i); - tmpOut = tmpOut.replace("%IN%", out.getVarname()); + //get variable name (w/ handling of direct consumption of inputs) + String varName = (out instanceof CNodeData && ((CNodeData)out).getHopID()== + ((CNodeData)_inputs.get(0)).getHopID()) ? "a" : out.getVarname(); + tmpOut = tmpOut.replace("%IN%", varName); tmpOut = tmpOut.replace("%IX%", String.valueOf(i)); sb.append(tmpOut); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9820f4c5/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java index 50d6ff1..151dab2 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java @@ -34,7 +34,9 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.ParameterizedBuiltinOp; @@ -82,12 +84,15 @@ public class PlanSelectionFuseCostBased extends PlanSelection if( LOG.isTraceEnabled() ) LOG.trace("Partition materialization points: "+Arrays.toString(M.toArray(new Long[0]))); - //step 3: create composite templates entries + //step 3: create composite templates (within the partition) createAndAddMultiAggPlans(memo, partition, R); //step 4: plan enumeration and plan selection selectPlans(memo, partition, R, M); } + + //step 5: add composite templates (across partitions) + createAndAddMultiAggPlans(memo, roots); //take all distinct best plans for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() ) @@ -217,6 +222,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection && partition.contains(hop.getHopID()); } + //within-partition multi-agg templates private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) { //create index of plans that reference full aggregates to avoid circular dependencies @@ -262,6 +268,30 @@ public class PlanSelectionFuseCostBased extends PlanSelection } } + //across-partition multi-agg templates + private static void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) + { + //#1: collect full aggregations over shared inputs (otherwise never fused) + HashMap<Long, ArrayList<Long>> fullAggs = new HashMap<Long, ArrayList<Long>>(); + Hop.resetVisitStatus(roots); + for( Hop hop : roots ) + rCollectAggregatesSharedRead(hop, fullAggs); + + //construct and add multiagg template plans (w/ max 3 aggregations) + for( Entry<Long, ArrayList<Long>> e : fullAggs.entrySet() ) { + if( e.getValue().size()<=1 ) + continue; + ArrayList<Long> aggs = e.getValue(); + MemoTableEntry me = new MemoTableEntry(TemplateType.MultiAggTpl, + aggs.get(0), aggs.get(1), (aggs.size()>2)?aggs.get(2):-1); + for( int i=0; i<aggs.size(); i++ ) { + memo.add(memo._hopRefs.get(aggs.get(i)), me); + if( LOG.isTraceEnabled() ) + LOG.trace("Added multiagg* plan: "+aggs.get(i)+" "+me); + } + } + } + private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) { //ensure that aggregates are independent of each other, i.e., //they to not have potentially transitive parent child references @@ -285,6 +315,28 @@ public class PlanSelectionFuseCostBased extends PlanSelection return ret; } + private static void rCollectAggregatesSharedRead(Hop current, HashMap<Long, ArrayList<Long>> aggs) { + if( current.isVisited() ) + return; + + //collect all applicable full aggregations per read + if( HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX) + && ((AggUnaryOp)current).getDirection()==Direction.RowCol + && current.getInput().get(0) instanceof DataOp ) + { + Hop input = current.getInput().get(0); + if( !aggs.containsKey(input.getHopID()) ) + aggs.put(input.getHopID(), new ArrayList<Long>()); + aggs.get(input.getHopID()).add(current.getHopID()); + } + + //recursively process children + for( Hop c : current.getInput() ) + rCollectAggregatesSharedRead(c, aggs); + + current.setVisited(); + } + private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M) { //if no materialization points, use basic fuse-all w/ partition awareness http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9820f4c5/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index e3c12d5..885d3db 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -29,6 +29,7 @@ import java.util.stream.Collectors; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.AggOp; @@ -149,7 +150,7 @@ public class TemplateCell extends TemplateBase MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CellTpl); for( int i=0; i<hop.getInput().size(); i++ ) { Hop c = hop.getInput().get(i); - if( me.isPlanRef(i) ) + if( me!=null && me.isPlanRef(i) && !(c instanceof DataOp) ) rConstructCplan(c, memo, tmp, inHops, compileLiterals); else { CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
