[SYSTEMML-1288] Extended code generator (multi-agg across partitions)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/9820f4c5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/9820f4c5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/9820f4c5

Branch: refs/heads/master
Commit: 9820f4c5293c69873f68544748507b6473948f12
Parents: 7c15339
Author: Matthias Boehm <[email protected]>
Authored: Thu Apr 6 20:57:07 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Apr 6 21:16:31 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeMultiAgg.java |  5 +-
 .../template/PlanSelectionFuseCostBased.java    | 54 +++++++++++++++++++-
 .../hops/codegen/template/TemplateCell.java     |  3 +-
 3 files changed, 59 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9820f4c5/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
index 7ec07a6..d9502be 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
@@ -106,7 +106,10 @@ public class CNodeMultiAgg extends CNodeTpl
                for( int i=0; i<_outputs.size(); i++ ) {
                        CNode out = _outputs.get(i);
                        String tmpOut = getAggTemplate(i);
-                       tmpOut = tmpOut.replace("%IN%", out.getVarname());
+                       //get variable name (w/ handling of direct consumption 
of inputs)
+                       String varName = (out instanceof CNodeData && 
((CNodeData)out).getHopID()==
+                               ((CNodeData)_inputs.get(0)).getHopID()) ? "a" : 
out.getVarname(); 
+                       tmpOut = tmpOut.replace("%IN%", varName);
                        tmpOut = tmpOut.replace("%IX%", String.valueOf(i));
                        sb.append(tmpOut);
                }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9820f4c5/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
index 50d6ff1..151dab2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
@@ -34,7 +34,9 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
@@ -82,12 +84,15 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                        if( LOG.isTraceEnabled() )
                                LOG.trace("Partition materialization points: 
"+Arrays.toString(M.toArray(new Long[0])));
                        
-                       //step 3: create composite templates entries
+                       //step 3: create composite templates (within the 
partition)
                        createAndAddMultiAggPlans(memo, partition, R);
                        
                        //step 4: plan enumeration and plan selection
                        selectPlans(memo, partition, R, M);
                }
+               
+               //step 5: add composite templates (across partitions)
+               createAndAddMultiAggPlans(memo, roots);
        
                //take all distinct best plans
                for( Entry<Long, List<MemoTableEntry>> e : 
getBestPlans().entrySet() )
@@ -217,6 +222,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                        && partition.contains(hop.getHopID());
        }
        
+       //within-partition multi-agg templates
        private static void createAndAddMultiAggPlans(CPlanMemoTable memo, 
HashSet<Long> partition, HashSet<Long> R)
        {
                //create index of plans that reference full aggregates to avoid 
circular dependencies
@@ -262,6 +268,30 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                }
        }
        
+       //across-partition multi-agg templates
+       private static void createAndAddMultiAggPlans(CPlanMemoTable memo, 
ArrayList<Hop> roots)
+       {
+               //#1: collect full aggregations over shared inputs (otherwise 
never fused)
+               HashMap<Long, ArrayList<Long>> fullAggs = new HashMap<Long, 
ArrayList<Long>>();
+               Hop.resetVisitStatus(roots);
+               for( Hop hop : roots )
+                       rCollectAggregatesSharedRead(hop, fullAggs);
+               
+               //construct and add multiagg template plans (w/ max 3 
aggregations)
+               for( Entry<Long, ArrayList<Long>> e : fullAggs.entrySet() ) {
+                       if( e.getValue().size()<=1 )
+                               continue;
+                       ArrayList<Long> aggs = e.getValue();
+                       MemoTableEntry me = new 
MemoTableEntry(TemplateType.MultiAggTpl,
+                               aggs.get(0), aggs.get(1), 
(aggs.size()>2)?aggs.get(2):-1);
+                       for( int i=0; i<aggs.size(); i++ ) {
+                               memo.add(memo._hopRefs.get(aggs.get(i)), me);
+                               if( LOG.isTraceEnabled() )
+                                       LOG.trace("Added multiagg* plan: 
"+aggs.get(i)+" "+me);
+                       }
+               }
+       }
+       
        private static boolean isValidMultiAggregate(CPlanMemoTable memo, 
MemoTableEntry me) {
                //ensure that aggregates are independent of each other, i.e.,
                //they to not have potentially transitive parent child 
references
@@ -285,6 +315,28 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                return ret;
        }
        
+       private static void rCollectAggregatesSharedRead(Hop current, 
HashMap<Long, ArrayList<Long>> aggs) {
+               if( current.isVisited() )
+                       return;
+               
+               //collect all applicable full aggregations per read
+               if( HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, 
AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX)
+                       && 
((AggUnaryOp)current).getDirection()==Direction.RowCol
+                       && current.getInput().get(0) instanceof DataOp )
+               {
+                       Hop input = current.getInput().get(0);
+                       if( !aggs.containsKey(input.getHopID()) )
+                               aggs.put(input.getHopID(), new 
ArrayList<Long>());
+                       aggs.get(input.getHopID()).add(current.getHopID());
+               }
+               
+               //recursively process children
+               for( Hop c : current.getInput() )
+                       rCollectAggregatesSharedRead(c, aggs);
+               
+               current.setVisited();
+       }
+       
        private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, 
HashSet<Long> R, ArrayList<Long> M) 
        {
                //if no materialization points, use basic fuse-all w/ partition 
awareness

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9820f4c5/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index e3c12d5..885d3db 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -29,6 +29,7 @@ import java.util.stream.Collectors;
 import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
@@ -149,7 +150,7 @@ public class TemplateCell extends TemplateBase
                MemoTableEntry me = memo.getBest(hop.getHopID(), 
TemplateType.CellTpl);
                for( int i=0; i<hop.getInput().size(); i++ ) {
                        Hop c = hop.getInput().get(i);
-                       if( me.isPlanRef(i) )
+                       if( me!=null && me.isPlanRef(i) && !(c instanceof 
DataOp) )
                                rConstructCplan(c, memo, tmp, inHops, 
compileLiterals);
                        else {
                                CNodeData cdata = 
TemplateUtils.createCNodeData(c, compileLiterals);    

Reply via email to