[SYSTEMML-1741,1536,1296] New cost-based codegen optimizer (V2)

This patch introduces the second generation of the cost-based codegen
optimizer for selecting operator fusion plans. The core idea is the
notion of 'interesting points': instead of reasoning about materialized
outputs, we now reason about materialization points per consumer,
template switches (for sparsity exploitation), and multi-aggregate plans
in a holistic cost-based optimization framework. This fine-grained
reasoning allows the optimizer to recognize many new patterns and to
decide correctly on patterns where we previously made wrong decisions.
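
For intuition, the enumeration space can be viewed as follows: a fusion
plan for one partition is a boolean assignment over its interesting
points, so n points span 2^n candidate plans that are compared by cost.
The sketch below is illustrative only (class name and cost model are
made-up stubs, not the actual optimizer code in
PlanSelectionFuseCostBasedV2):

// Illustrative sketch: a plan is a boolean assignment over interesting points;
// costPlan() is a stub standing in for the optimizer's DAG cost model.
public class InterestingPointsSketch {
    static double costPlan(boolean[] plan) {
        double cost = 100;                 // base compute/IO cost (stub)
        for (boolean materialize : plan)
            cost += materialize ? 10 : 2;  // pretend materializing is pricier
        return cost;
    }

    // exhaustive enumeration of all 2^n assignments (no pruning yet)
    static double enumerate(boolean[] plan, int pos, double best) {
        if (pos == plan.length)
            return Math.min(best, costPlan(plan));
        plan[pos] = false;                 // fuse across this point
        best = enumerate(plan, pos + 1, best);
        plan[pos] = true;                  // materialize at this point
        return enumerate(plan, pos + 1, best);
    }

    public static void main(String[] args) {
        double best = enumerate(new boolean[12], 0, Double.POSITIVE_INFINITY);
        System.out.println("best plan cost: " + best);  // 12 points -> 4,096 plans
    }
}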

This patch includes a major refactoring of the existing fusion
heuristics and cost-based optimizer, which remain available, but the new
optimizer is now used by default. Common code has been consolidated in
PlanAnalyzer and related classes.

Due to the fine-grained reasoning about interesting points, the
exponential search space grows substantially. For example, on GLM
binomial probit and Lenet, the maximum number of materialization points
increased from 7 to 18 and from 5 to 12, respectively, and the optimizer
now considers 262,516 and 7,719,381 plans overall. Accordingly, the new
optimizer also introduces two high-impact pruning techniques, which
reduce the number of enumerated plans to 7,731 and 473,840 while still
ensuring plan optimality:

(a) Pruning by lower/upper bound cost: We maintain the cost of the best
plan seen so far as the upper bound and compute lower bound costs based
on minimum partition reads, writes, compute, and the already assigned
materialization points. Whenever the lower bound exceeds the upper
bound, we can prune entire areas of the search space (a sketch of both
pruning techniques follows below).

(b) Pruning by conditionally independent structure: We also build a
reachability graph of materialization points in order to exploit
conditionally independent substructures. Certain assigned
materialization points form, conditionally on their assignment,
boundaries that split the remaining graph into independent subproblems.
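
The following simplified sketch illustrates both techniques (the cost
and lower-bound functions are made-up stubs; the real cost model
accounts for partition reads, writes, compute, and assigned
materialization points):

// (a) Branch-and-bound: prune a subtree once its optimistic lower bound
//     reaches the cost of the best complete plan seen so far (upper bound).
// (b) Conditional independence: fixed boundary points split the remaining
//     interesting points into subproblems whose optimal costs simply add up,
//     reducing 2^(n1+n2) assignments to 2^n1 + 2^n2.
public class PruningSketch {
    static double cost(boolean[] plan) {                // exact cost (stub)
        double c = 100;
        for (boolean m : plan) c += m ? 10 : 2;
        return c;
    }
    static double lowerBound(boolean[] plan, int pos) {
        double c = 100;                                  // assigned prefix, exact
        for (int i = 0; i < pos; i++) c += plan[i] ? 10 : 2;
        return c + 2 * (plan.length - pos);              // best case for open suffix
    }
    static double enumerate(boolean[] plan, int pos, double best) {
        if (lowerBound(plan, pos) >= best) return best;  // prune entire subtree
        if (pos == plan.length) return cost(plan);
        plan[pos] = false; best = enumerate(plan, pos + 1, best);
        plan[pos] = true;  best = enumerate(plan, pos + 1, best);
        return best;
    }
    public static void main(String[] args) {
        double left  = enumerate(new boolean[7],  0, Double.POSITIVE_INFINITY);
        double right = enumerate(new boolean[11], 0, Double.POSITIVE_INFINITY);
        System.out.println("combined optimal cost: " + (left + right));
    }
}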

Additionally, this patch extends the statistics output to report the
number of enumerated plans (with pruning) and fixes a couple of minor
codegen compiler and runtime issues.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7b4a3418
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7b4a3418
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7b4a3418

Branch: refs/heads/master
Commit: 7b4a3418a3aa5995a17f1e749bfa0f6fa0c65548
Parents: 4ec6f08
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sat Jul 15 12:43:01 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sat Aug 5 00:56:23 2017 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/Hop.java    |    4 +
 .../sysml/hops/codegen/SpoofCompiler.java       |   25 +-
 .../hops/codegen/opt/InterestingPoint.java      |  101 ++
 .../sysml/hops/codegen/opt/PlanAnalyzer.java    |  311 +++++
 .../sysml/hops/codegen/opt/PlanPartition.java   |   82 ++
 .../sysml/hops/codegen/opt/PlanSelection.java   |  167 +++
 .../hops/codegen/opt/PlanSelectionFuseAll.java  |   94 ++
 .../codegen/opt/PlanSelectionFuseCostBased.java |  892 ++++++++++++++
 .../opt/PlanSelectionFuseCostBasedV2.java       | 1100 ++++++++++++++++++
 .../opt/PlanSelectionFuseNoRedundancy.java      |  108 ++
 .../hops/codegen/opt/ReachabilityGraph.java     |  398 +++++++
 .../hops/codegen/template/CPlanMemoTable.java   |   60 +-
 .../hops/codegen/template/PlanSelection.java    |  122 --
 .../codegen/template/PlanSelectionFuseAll.java  |   93 --
 .../template/PlanSelectionFuseCostBased.java    | 1009 ----------------
 .../template/PlanSelectionFuseNoRedundancy.java |  107 --
 .../hops/codegen/template/TemplateBase.java     |   12 +-
 .../hops/codegen/template/TemplateCell.java     |   12 +-
 .../hops/codegen/template/TemplateMultiAgg.java |    6 +-
 .../codegen/template/TemplateOuterProduct.java  |   12 +-
 .../hops/codegen/template/TemplateRow.java      |    6 +-
 .../hops/codegen/template/TemplateUtils.java    |   26 +-
 .../runtime/codegen/LibSpoofPrimitives.java     |    6 +
 .../sysml/runtime/codegen/SpoofCellwise.java    |    5 +-
 .../sysml/runtime/codegen/SpoofRowwise.java     |   15 +-
 .../java/org/apache/sysml/utils/Statistics.java |   19 +-
 .../functions/codegen/MiscPatternTest.java      |  151 +++
 .../scripts/functions/codegen/miscPattern1.R    |   35 +
 .../scripts/functions/codegen/miscPattern1.dml  |   35 +
 .../scripts/functions/codegen/miscPattern2.R    |   36 +
 .../scripts/functions/codegen/miscPattern2.dml  |   33 +
 .../functions/codegen/ZPackageSuite.java        |    1 +
 32 files changed, 3693 insertions(+), 1390 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java 
b/src/main/java/org/apache/sysml/hops/Hop.java
index 7aa58b7..be4c7e4 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -955,6 +955,10 @@ public abstract class Hop
                _dim2 = dim2;
        }
        
+       public double getSparsity() {
+               return OptimizerUtils.getSparsity(_dim1, _dim2, _nnz);
+       }
+       
        protected void setOutputDimensions(Lop lop) 
                throws HopsException
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index dadb318..5029802 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -43,16 +43,17 @@ import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct;
 import org.apache.sysml.hops.codegen.cplan.CNodeRow;
 import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
 import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType;
+import org.apache.sysml.hops.codegen.opt.PlanSelection;
+import org.apache.sysml.hops.codegen.opt.PlanSelectionFuseAll;
+import org.apache.sysml.hops.codegen.opt.PlanSelectionFuseCostBased;
+import org.apache.sysml.hops.codegen.opt.PlanSelectionFuseCostBasedV2;
+import org.apache.sysml.hops.codegen.opt.PlanSelectionFuseNoRedundancy;
 import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysml.hops.codegen.template.TemplateBase;
 import org.apache.sysml.hops.codegen.template.TemplateBase.CloseType;
 import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
 import org.apache.sysml.hops.codegen.template.CPlanCSERewriter;
 import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
-import org.apache.sysml.hops.codegen.template.PlanSelection;
-import org.apache.sysml.hops.codegen.template.PlanSelectionFuseCostBased;
-import org.apache.sysml.hops.codegen.template.PlanSelectionFuseAll;
-import org.apache.sysml.hops.codegen.template.PlanSelectionFuseNoRedundancy;
 import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
 import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntrySet;
 import org.apache.sysml.hops.codegen.template.TemplateUtils;
@@ -109,7 +110,7 @@ public class SpoofCompiler
        public static final boolean PRUNE_REDUNDANT_PLANS = true;
        public static PlanCachePolicy PLAN_CACHE_POLICY   = PlanCachePolicy.CSLH;
        public static final int PLAN_CACHE_SIZE           = 1024; //max 1K classes 
-       public static final PlanSelector PLAN_SEL_POLICY  = PlanSelector.FUSE_COST_BASED; 
+       public static final PlanSelector PLAN_SEL_POLICY  = PlanSelector.FUSE_COST_BASED_V2; 
 
        public enum CompilerType {
                JAVAC,
@@ -124,7 +125,9 @@ public class SpoofCompiler
        public enum PlanSelector {
                FUSE_ALL,             //maximal fusion, possible w/ redundant compute
                FUSE_NO_REDUNDANCY,   //fusion without redundant compute 
-               FUSE_COST_BASED;      //cost-based decision on materialization points
+               FUSE_COST_BASED,      //cost-based decision on materialization points
+               FUSE_COST_BASED_V2;   //cost-based decisions on materialization points per consumer, multi aggregates,
+                                     //sparsity exploitation, template types, local/distributed operations, constraints
                public boolean isHeuristic() {
                        return this == FUSE_ALL
                                || this == FUSE_NO_REDUNDANCY;
@@ -458,6 +461,8 @@ public class SpoofCompiler
                                return new PlanSelectionFuseNoRedundancy();
                        case FUSE_COST_BASED:
                                return new PlanSelectionFuseCostBased();
+                       case FUSE_COST_BASED_V2:
+                               return new PlanSelectionFuseCostBasedV2();
                        default:        
                                throw new RuntimeException("Unsupported "
                                        + "plan selector: "+PLAN_SEL_POLICY);
@@ -530,8 +535,10 @@ public class SpoofCompiler
                }
                
                //prune subsumed / redundant plans
-               if( PRUNE_REDUNDANT_PLANS )
-                       memo.pruneRedundant(hop.getHopID());
+               if( PRUNE_REDUNDANT_PLANS ) {
+                       memo.pruneRedundant(hop.getHopID(),
+                               PLAN_SEL_POLICY.isHeuristic(), null);
+               }
                
                //mark visited even if no plans found (e.g., unsupported ops)
                memo.addHop(hop);
@@ -542,7 +549,7 @@ public class SpoofCompiler
                for(int k=0; k<hop.getInput().size(); k++) {
                        Hop input2 = hop.getInput().get(k);
                        if( input2 != c && tpl.merge(hop, input2) 
-                               && memo.contains(input2.getHopID(), true, tpl.getType(), TemplateType.CellTpl))
+                               && memo.contains(input2.getHopID(), true, tpl.getType(), TemplateType.CELL))
                                P.crossProduct(k, -1L, input2.getHopID());
                }
                return P;

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/InterestingPoint.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/InterestingPoint.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/InterestingPoint.java
new file mode 100644
index 0000000..1e4deac
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/InterestingPoint.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.runtime.util.UtilFunctions;
+
+/**
+ * Interesting decision point with regard to materialization of intermediates.
+ * These points are defined by a type, as well as hop ID for consumer-producer
+ * relationships. Equivalence is defined solely on the hop IDs, to simplify
+ * their processing and avoid redundant enumeration.
+ *  
+ */
+public class InterestingPoint 
+{
+       public enum DecisionType {
+               MULTI_CONSUMER,
+               TEMPLATE_CHANGE,
+       }
+       
+       private final DecisionType _type;
+       public final long _fromHopID; //consumers
+       public final long _toHopID; //producers
+       
+       public InterestingPoint(DecisionType type, long fromHopID, long 
toHopID) {
+               _type = type;
+               _fromHopID = fromHopID;
+               _toHopID = toHopID;
+       }
+       
+       public DecisionType getType() {
+               return _type;
+       }
+       
+       public long getFromHopID() {
+               return _fromHopID;
+       }
+       
+       public long getToHopID() {
+               return _toHopID;
+       }
+       
+       public static boolean isMatPoint(InterestingPoint[] list, long from, 
MemoTableEntry me, boolean[] plan) {
+               for(int i=0; i<plan.length; i++) {
+                       if( !plan[i] ) continue;
+                       InterestingPoint p = list[i];
+                       if( p._fromHopID!=from ) continue;
+                       for( int j=0; j<3; j++ )
+                               if( p._toHopID==me.input(j) )
+                                       return true;
+               }
+               return false;
+       }
+       
+       public static boolean isMatPoint(InterestingPoint[] list, long from, 
long to) {
+               for(int i=0; i<list.length; i++) {
+                       InterestingPoint p = list[i];
+                       if( p._fromHopID==from && p._toHopID==to )
+                               return true;
+               }
+               return false;
+       }
+       
+       @Override
+       public int hashCode() {
+               return UtilFunctions.longHashCode(_fromHopID, _toHopID);
+       }
+       
+       @Override
+       public boolean equals(Object o) {
+               if( !(o instanceof InterestingPoint) )
+                       return false;
+               InterestingPoint that = (InterestingPoint) o;
+               return _fromHopID == that._fromHopID
+                       && _toHopID == that._toHopID;
+       }
+       
+       @Override
+       public String toString() {
+               String stype = (_type==DecisionType.MULTI_CONSUMER) ? "M" : "T";
+               return "(" + stype+ ":" + _fromHopID + "->" + _toHopID + ")"; 
+       }
+}
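
A brief usage note on the equality contract described in the class
comment above: equality and hashing are defined solely on the from/to
hop IDs, not the decision type, so duplicate decisions over the same
consumer-producer edge collapse during enumeration. The hop IDs in this
small illustration are hypothetical:

import org.apache.sysml.hops.codegen.opt.InterestingPoint;
import org.apache.sysml.hops.codegen.opt.InterestingPoint.DecisionType;

public class InterestingPointEqualityExample {
    public static void main(String[] args) {
        InterestingPoint a = new InterestingPoint(DecisionType.MULTI_CONSUMER, 7L, 3L);
        InterestingPoint b = new InterestingPoint(DecisionType.TEMPLATE_CHANGE, 7L, 3L);
        System.out.println(a.equals(b));    // true: same consumer/producer hop IDs
        System.out.println(a + " vs " + b); // (M:7->3) vs (T:7->3)
    }
}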

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java
new file mode 100644
index 0000000..9ff6986
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map.Entry;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.opt.InterestingPoint.DecisionType;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+
+/**
+ * Utility functions to extract structural information from the memo table,
+ * including connected components (aka partitions) of partial fusion plans, 
+ * materialization points of partitions, and root nodes of partitions.
+ * 
+ */
+public class PlanAnalyzer 
+{
+       private static final Log LOG = 
LogFactory.getLog(PlanAnalyzer.class.getName());
+       
+       public static Collection<PlanPartition> 
analyzePlanPartitions(CPlanMemoTable memo, ArrayList<Hop> roots, boolean ext) {
+               //determine connected sub graphs of plans
+               Collection<HashSet<Long>> parts = getConnectedSubGraphs(memo, 
roots);
+               
+               //determine roots and materialization points
+               Collection<PlanPartition> ret = new ArrayList<>();
+               for( HashSet<Long> partition : parts ) {
+                       HashSet<Long> R = getPartitionRootNodes(memo, 
partition);
+                       HashSet<Long> I = getPartitionInputNodes(R, partition, 
memo);
+                       ArrayList<Long> M = getMaterializationPoints(R, 
partition, memo);
+                       HashSet<Long> Pnpc = 
getNodesWithNonPartitionConsumers(R, partition, memo);
+                       InterestingPoint[] Mext = !ext ? null : 
+                               getMaterializationPointsExt(R, partition, M, 
memo);
+                       ret.add(new PlanPartition(partition, R, I, Pnpc, M, 
Mext));
+               }
+               
+               return ret;
+       }
+       
+       private static Collection<HashSet<Long>> 
getConnectedSubGraphs(CPlanMemoTable memo, ArrayList<Hop> roots) 
+       {
+               //build inverted index for 'referenced by' relationship 
+               HashMap<Long, HashSet<Long>> refBy = new HashMap<Long, 
HashSet<Long>>();
+               for( Entry<Long, List<MemoTableEntry>> e : 
memo.getPlans().entrySet() )
+                       for( MemoTableEntry me : e.getValue() ) 
+                               for( int i=0; i<3; i++ )
+                                       if( me.isPlanRef(i) ) {
+                                               if( 
!refBy.containsKey(me.input(i)) )
+                                                       refBy.put(me.input(i), 
new HashSet<Long>());
+                                               
refBy.get(me.input(i)).add(e.getKey());
+                                       }
+               
+               //create a single partition per root node, if reachable over 
refBy of 
+               //other root node the resulting partition is empty and can be 
discarded
+               ArrayList<HashSet<Long>> parts = new ArrayList<HashSet<Long>>();
+               HashSet<Long> visited = new HashSet<Long>();
+               for( Entry<Long, List<MemoTableEntry>> e : 
memo.getPlans().entrySet() )
+                       if( !refBy.containsKey(e.getKey()) ) { //root node
+                               HashSet<Long> part = 
rGetConnectedSubGraphs(e.getKey(), 
+                                               memo, refBy, visited, new 
HashSet<Long>());
+                               if( !part.isEmpty() )
+                                       parts.add(part);
+                       }
+               
+               if( LOG.isTraceEnabled() )
+                       LOG.trace("Connected sub graphs: "+parts.size());
+               
+               return parts;
+       }
+       
+       private static HashSet<Long> getPartitionRootNodes(CPlanMemoTable memo, 
HashSet<Long> partition) 
+       {
+               //build inverted index of references entries 
+               HashSet<Long> ix = new HashSet<Long>();
+               for( Long hopID : partition )
+                       if( memo.contains(hopID) )
+                               for( MemoTableEntry me : memo.get(hopID) ) {
+                                       ix.add(me.input1); 
+                                       ix.add(me.input2); 
+                                       ix.add(me.input3);
+                               }
+               
+               HashSet<Long> roots = new HashSet<Long>();
+               for( Long hopID : partition )
+                       if( !ix.contains(hopID) )
+                               roots.add(hopID);
+               
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Partition root points: "
+                               + Arrays.toString(roots.toArray(new Long[0])));
+               }
+               
+               return roots;
+       }
+       
+       private static ArrayList<Long> getMaterializationPoints(HashSet<Long> 
roots, 
+                       HashSet<Long> partition, CPlanMemoTable memo) 
+       {
+               //collect materialization points bottom-up
+               ArrayList<Long> ret = new ArrayList<Long>();
+               HashSet<Long> visited = new HashSet<Long>();
+               for( Long hopID : roots )
+                       
rCollectMaterializationPoints(memo.getHopRefs().get(hopID), 
+                                       visited, partition, roots, ret);
+               
+               //remove special-case materialization points
+               //(root nodes w/ multiple consumers, tsmm input if consumed in 
partition)
+               ret.removeIf(hopID -> roots.contains(hopID)
+                       || 
HopRewriteUtils.isTsmmInput(memo.getHopRefs().get(hopID)));
+               
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Partition materialization points: "
+                               + Arrays.toString(ret.toArray(new Long[0])));
+               }
+               
+               return ret;
+       }
+       
+       private static void rCollectMaterializationPoints(Hop current, 
HashSet<Long> visited, 
+                       HashSet<Long> partition, HashSet<Long> R, 
ArrayList<Long> M) 
+       {
+               //memoization (not via hops because in middle of dag)
+               if( visited.contains(current.getHopID()) )
+                       return;
+               
+               //process children recursively
+               for( Hop c : current.getInput() )
+                       rCollectMaterializationPoints(c, visited, partition, R, 
M);
+               
+               //collect materialization point
+               if( isMaterializationPointCandidate(current, partition, R) )
+                       M.add(current.getHopID());
+               
+               visited.add(current.getHopID());
+       }
+       
+       private static boolean isMaterializationPointCandidate(Hop hop, 
HashSet<Long> partition, HashSet<Long> R) {
+               return hop.getParent().size()>=2 
+                       && hop.getDataType().isMatrix()
+                       && partition.contains(hop.getHopID())
+                       && !R.contains(hop.getHopID());
+       }
+       
+       private static HashSet<Long> getPartitionInputNodes(HashSet<Long> 
roots, 
+                       HashSet<Long> partition, CPlanMemoTable memo)
+       {
+               HashSet<Long> ret = new HashSet<>();
+               HashSet<Long> visited = new HashSet<>();
+               for( Long hopID : roots ) {
+                       Hop current = memo.getHopRefs().get(hopID);
+                       rCollectPartitionInputNodes(current, visited, 
partition, ret);
+               }
+               return ret;
+       }
+       
+       private static void rCollectPartitionInputNodes(Hop current, 
HashSet<Long> visited, 
+               HashSet<Long> partition, HashSet<Long> I) 
+       {
+               //memoization (not via hops because in middle of dag)
+               if( visited.contains(current.getHopID()) )
+                       return;
+               
+               //process children recursively
+               for( Hop c : current.getInput() )
+                       if( partition.contains(c.getHopID()) )
+                               rCollectPartitionInputNodes(c, visited, 
partition, I);
+                       else
+                               I.add(c.getHopID());
+               
+               visited.add(current.getHopID());
+       }
+       
+       private static HashSet<Long> getNodesWithNonPartitionConsumers(
+               HashSet<Long> roots, HashSet<Long> partition, CPlanMemoTable 
memo)
+       {
+               HashSet<Long> ret = new HashSet<>();
+               for( Long hopID : partition ) {
+                       Hop hop = memo.getHopRefs().get(hopID);
+                       if( hasNonPartitionConsumer(hop, partition) 
+                               && !roots.contains(hopID))
+                               ret.add(hopID);
+               }
+               return ret;
+       }
+       
+       private static boolean hasNonPartitionConsumer(Hop hop, HashSet<Long> 
partition) {
+               boolean ret = false;
+               for( Hop p : hop.getParent() )
+                       ret |= !partition.contains(p.getHopID());
+               return ret;
+       }
+       
+       private static InterestingPoint[] 
getMaterializationPointsExt(HashSet<Long> roots, 
+                       HashSet<Long> partition, ArrayList<Long> M, 
CPlanMemoTable memo) 
+       {
+               //collect categories of interesting points
+               ArrayList<InterestingPoint> tmp = new ArrayList<>();
+               tmp.addAll(getMaterializationPointConsumers(M, partition, 
memo));
+               tmp.addAll(getTemplateChangePoints(partition, memo));
+               
+               //reduce to distinct hop->hop pairs (see equals of interesting 
points)
+               InterestingPoint[] ret = tmp.stream().distinct()
+                       .toArray(InterestingPoint[]::new);
+               
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Partition materialization points (extended): 
"
+                               + Arrays.toString(ret));
+               }
+               
+               return ret;
+       }
+       
+       private static ArrayList<InterestingPoint> 
getMaterializationPointConsumers(
+               ArrayList<Long> M, HashSet<Long> partition, CPlanMemoTable 
memo) 
+       {
+               //collect all materialization point consumers
+               ArrayList<InterestingPoint> ret = new ArrayList<>();
+               for( Long hopID : M )
+                       for( Hop parent : 
memo.getHopRefs().get(hopID).getParent() )
+                               if( partition.contains(parent.getHopID()) )
+                                       ret.add(new InterestingPoint(
+                                               DecisionType.MULTI_CONSUMER, 
parent.getHopID(), hopID));
+
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Partition materialization point consumers: "
+                               + Arrays.toString(ret.toArray(new 
InterestingPoint[0])));
+               }
+               
+               return ret;
+       }
+       
+       private static ArrayList<InterestingPoint> getTemplateChangePoints (
+               HashSet<Long> partition, CPlanMemoTable memo) 
+       {
+               //collect all template change points 
+               ArrayList<InterestingPoint> ret = new ArrayList<>();
+               for( Long hopID : partition ) {
+                       long[] refs = memo.getAllRefs(hopID);
+                       for( int i=0; i<3; i++ ) {
+                               if( refs[i] < 0 ) continue;
+                               List<TemplateType> tmp = 
memo.getDistinctTemplateTypes(hopID, i);
+                               
+                               if( memo.containsNotIn(refs[i], tmp, true, 
true) )
+                                       ret.add(new 
InterestingPoint(DecisionType.TEMPLATE_CHANGE, hopID, refs[i]));
+                       }
+               }
+               
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Partition template change points: "
+                               + Arrays.toString(ret.toArray(new 
InterestingPoint[0])));
+               }
+               
+               return ret;
+       }
+       
+       private static HashSet<Long> rGetConnectedSubGraphs(long hopID, 
CPlanMemoTable memo, 
+                       HashMap<Long, HashSet<Long>> refBy, HashSet<Long> 
visited, HashSet<Long> partition) 
+       {
+               if( visited.contains(hopID) )
+                       return partition;
+               
+               //process node itself w/ memoization
+               if( memo.contains(hopID) ) {
+                       partition.add(hopID);
+                       visited.add(hopID);     
+               }
+               
+               //recursively process parents
+               if( refBy.containsKey(hopID) )
+                       for( Long ref : refBy.get(hopID) )
+                               rGetConnectedSubGraphs(ref, memo, refBy, 
visited, partition);
+               
+               //recursively process children
+               if( memo.contains(hopID) ) {
+                       long[] refs = memo.getAllRefs(hopID);
+                       for( int i=0; i<3; i++ )
+                               if( refs[i] != -1 )
+                                       rGetConnectedSubGraphs(refs[i], memo, 
refBy, visited, partition);
+               }
+               
+               return partition;
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java
new file mode 100644
index 0000000..fbf7d9f
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+
+public class PlanPartition 
+{
+       //nodes of partition (hop IDs)
+       private final HashSet<Long> _nodes;
+       
+       //root nodes of partition (hop IDs)
+       private final HashSet<Long> _roots;
+       
+       //partition inputs 
+       private final HashSet<Long> _inputs;
+       
+       //nodes with non-partition consumers 
+       private final HashSet<Long> _nodesNpc;
+       
+       //materialization points (hop IDs)
+       private final ArrayList<Long> _matPoints;
+       
+       //interesting operator dependencies
+       private InterestingPoint[] _matPointsExt;
+       
+       
+       public PlanPartition(HashSet<Long> P, HashSet<Long> R, HashSet<Long> I, 
HashSet<Long> Pnpc, ArrayList<Long> M, InterestingPoint[] Mext) {
+               _nodes = P;
+               _roots = R;
+               _inputs = I;
+               _nodesNpc = Pnpc;
+               _matPoints = M;
+               _matPointsExt = Mext;
+       }
+       
+       public HashSet<Long> getPartition() {
+               return _nodes;
+       }
+       
+       public HashSet<Long> getRoots() {
+               return _roots;
+       }
+       
+       public HashSet<Long> getInputs() {
+               return _inputs;
+       }
+       
+       public HashSet<Long> getExtConsumed() {
+               return _nodesNpc;
+       } 
+       
+       public ArrayList<Long> getMatPoints() {
+               return _matPoints;
+       }
+       
+       public InterestingPoint[] getMatPointsExt() {
+               return _matPointsExt;
+       }
+
+       public void setMatPointsExt(InterestingPoint[] points) {
+               _matPointsExt = points;
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
new file mode 100644
index 0000000..23cc128
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import org.apache.sysml.runtime.util.UtilFunctions;
+
+public abstract class PlanSelection 
+{
+       private final HashMap<Long, List<MemoTableEntry>> _bestPlans = 
+                       new HashMap<Long, List<MemoTableEntry>>();
+       private final HashSet<VisitMark> _visited = new HashSet<VisitMark>();
+       
+       /**
+        * Given a HOP DAG G, and a set of partial fusions plans P, find the 
set of optimal, 
+        * non-conflicting fusion plans P' that applied to G minimizes costs C 
with
+        * P' = \argmin_{p \subseteq P} C(G, p) s.t. Z \vDash p, where Z is a 
set of 
+        * constraints such as memory budgets and block size restrictions per 
fused operator.
+        * 
+        * @param memo partial fusion plans P
+        * @param roots entry points of HOP DAG G
+        */
+       public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> 
roots);    
+       
+       /**
+        * Determines if the given partial fusion plan is valid.
+        * 
+        * @param me memo table entry
+        * @param hop current hop
+        * @return true if entry is valid as top-level plan
+        */
+       public static boolean isValid(MemoTableEntry me, Hop hop) {
+               return (me.type != TemplateType.OUTER //ROW, CELL, MAGG
+                       || (me.closed || 
HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)));
+       }
+       
+       protected void addBestPlan(long hopID, MemoTableEntry me) {
+               if( me == null ) return;
+               if( !_bestPlans.containsKey(hopID) )
+                       _bestPlans.put(hopID, new ArrayList<MemoTableEntry>());
+               _bestPlans.get(hopID).add(me);
+       }
+       
+       protected HashMap<Long, List<MemoTableEntry>> getBestPlans() {
+               return _bestPlans;
+       }
+       
+       protected boolean isVisited(long hopID, TemplateType type) {
+               return _visited.contains(new VisitMark(hopID, type));
+       }
+       
+       protected void setVisited(long hopID, TemplateType type) {
+               _visited.add(new VisitMark(hopID, type));
+       }
+       
+       /**
+        * Basic plan comparator to compare memo table entries with regard to
+        * a pre-defined template preference order and the number of references.
+        */
+       protected static class BasicPlanComparator implements 
Comparator<MemoTableEntry> {
+               @Override
+               public int compare(MemoTableEntry o1, MemoTableEntry o2) {
+                       return icompare(o1, o2);
+               }
+               
+               public static int icompare(MemoTableEntry o1, MemoTableEntry 
o2) {
+                       if( o2 == null ) return -1;
+                       
+                       //for different types, select preferred type
+                       if( o1.type != o2.type )
+                               return Integer.compare(o1.type.getRank(), 
o2.type.getRank());
+                       
+                       //for same type, prefer plan with more refs
+                       return Integer.compare(
+                               3-o1.countPlanRefs(), 3-o2.countPlanRefs());
+               }
+       }
+       
+       protected static class TypedPlanComparator implements 
Comparator<MemoTableEntry> {
+               private TemplateType _type;
+               
+               public void setType(TemplateType type) {
+                       _type = type;
+               }
+               
+               @Override
+               public int compare(MemoTableEntry o1, MemoTableEntry o2) {
+                       return icompare(o1, o2, _type);
+               }
+               
+               public static int icompare(MemoTableEntry o1, MemoTableEntry 
o2, TemplateType type) {
+                       if( o2 == null ) return -1;
+                       int score1 = 7 - ((o1.type==type)?4:0) - 
o1.countPlanRefs();
+                       int score2 = 7 - ((o2.type==type)?4:0) - 
o2.countPlanRefs();
+                       return Integer.compare(score1, score2);
+               }
+       }
+       
+       protected static class VisitMark {
+               private final long _hopID;
+               private final TemplateType _type;
+               
+               public VisitMark(long hopID, TemplateType type) {
+                       _hopID = hopID;
+                       _type = type;
+               }
+               @Override
+               public int hashCode() {
+                       return UtilFunctions.longHashCode(
+                               _hopID, (_type!=null)?_type.hashCode():0);
+               }
+               @Override 
+               public boolean equals(Object o) {
+                       return (o instanceof VisitMark
+                               && _hopID == ((VisitMark)o)._hopID
+                               && _type == ((VisitMark)o)._type);
+               }
+       }
+       
+       public static class VisitMarkCost {
+               private final long _hopID;
+               private final long _costID;
+               
+               public VisitMarkCost(long hopID, long costID) {
+                       _hopID = hopID;
+                       _costID = costID;
+               }
+               @Override
+               public int hashCode() {
+                       return UtilFunctions.longHashCode(
+                               _hopID, _costID);
+               }
+               @Override 
+               public boolean equals(Object o) {
+                       return (o instanceof VisitMarkCost
+                               && _hopID == ((VisitMarkCost)o)._hopID
+                               && _costID == ((VisitMarkCost)o)._costID);
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java
new file mode 100644
index 0000000..8636bea
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Map.Entry;
+import java.util.HashSet;
+import java.util.List;
+
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+
+/**
+ * This plan selection heuristic aims for maximal fusion, which
+ * potentially leads to overlapping fused operators and thus,
+ * redundant computation but with a minimal number of materialized
+ * intermediate results.
+ * 
+ */
+public class PlanSelectionFuseAll extends PlanSelection
+{      
+       @Override
+       public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
+               //pruning and collection pass
+               for( Hop hop : roots )
+                       rSelectPlans(memo, hop, null);
+               
+               //take all distinct best plans
+               for( Entry<Long, List<MemoTableEntry>> e : 
getBestPlans().entrySet() )
+                       memo.setDistinct(e.getKey(), e.getValue());
+       }
+       
+       private void rSelectPlans(CPlanMemoTable memo, Hop current, 
TemplateType currentType) 
+       {       
+               if( isVisited(current.getHopID(), currentType) )
+                       return;
+               
+               //step 1: prune subsumed plans of same type
+               if( memo.contains(current.getHopID()) ) {
+                       HashSet<MemoTableEntry> rmSet = new 
HashSet<MemoTableEntry>();
+                       List<MemoTableEntry> hopP = 
memo.get(current.getHopID());
+                       for( MemoTableEntry e1 : hopP )
+                               for( MemoTableEntry e2 : hopP )
+                                       if( e1 != e2 && e1.subsumes(e2) )
+                                               rmSet.add(e2);
+                       memo.remove(current, rmSet);
+               }
+               
+               //step 2: select plan for current path
+               MemoTableEntry best = null;
+               if( memo.contains(current.getHopID()) ) {
+                       if( currentType == null ) {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> isValid(p, current))
+                                       .min(new 
BasicPlanComparator()).orElse(null);
+                       }
+                       else {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> p.type==currentType || 
p.type==TemplateType.CELL)
+                                       .min(Comparator.comparing(p -> 
7-((p.type==currentType)?4:0)-p.countPlanRefs()))
+                                       .orElse(null);
+                       }
+                       addBestPlan(current.getHopID(), best);
+               }
+               
+               //step 3: recursively process children
+               for( int i=0; i< current.getInput().size(); i++ ) {
+                       TemplateType pref = (best!=null && best.isPlanRef(i))? 
best.type : null;
+                       rSelectPlans(memo, current.getInput().get(i), pref);
+               }
+               
+               setVisited(current.getHopID(), currentType);
+       }       
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
new file mode 100644
index 0000000..985cc0f
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
@@ -0,0 +1,892 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map.Entry;
+import java.util.stream.Collectors;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.IndexingOp;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.TernaryOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.hops.codegen.template.TemplateOuterProduct;
+import org.apache.sysml.hops.codegen.template.TemplateRow;
+import org.apache.sysml.hops.codegen.template.TemplateUtils;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * This cost-based plan selection algorithm chooses fused operators
+ * based on the DAG structure and resulting overall costs. This primarily
+ * includes decisions on materialization points, but also heuristics for 
+ * template types, and composed multi output templates. 
+ * 
+ */
+public class PlanSelectionFuseCostBased extends PlanSelection
+{      
+       private static final Log LOG = 
LogFactory.getLog(PlanSelectionFuseCostBased.class.getName());
+       
+       //common bandwidth characteristics, with a conservative write bandwidth 
in order 
+       //to cover result allocation, write into main memory, and potential 
evictions
+       private static final double WRITE_BANDWIDTH = 2d*1024*1024*1024;  
//2GB/s
+       private static final double READ_BANDWIDTH = 32d*1024*1024*1024;  
//32GB/s
+       private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 
//2GFLOPs/core
+               * InfrastructureAnalyzer.getLocalParallelism();
+       
+       private static final IDSequence COST_ID = new IDSequence();
+       private static final TemplateRow ROW_TPL = new TemplateRow();
+       
+       @Override
+       public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) 
+       {
+               //step 1: analyze connected partitions (nodes, roots, mat 
points)
+               Collection<PlanPartition> parts = 
PlanAnalyzer.analyzePlanPartitions(memo, roots, false);
+               
+               //step 2: optimize individual plan partitions
+               for( PlanPartition part : parts ) {
+                       //create composite templates (within the partition)
+                       createAndAddMultiAggPlans(memo, part.getPartition(), 
part.getRoots());
+                       
+                       //plan enumeration and plan selection
+                       selectPlans(memo, part.getPartition(), part.getRoots(), 
part.getMatPoints());
+               }
+               
+               //step 3: add composite templates (across partitions)
+               createAndAddMultiAggPlans(memo, roots);
+       
+               //take all distinct best plans
+               for( Entry<Long, List<MemoTableEntry>> e : 
getBestPlans().entrySet() )
+                       memo.setDistinct(e.getKey(), e.getValue());
+       }
+       
+       //within-partition multi-agg templates
+       private static void createAndAddMultiAggPlans(CPlanMemoTable memo, 
HashSet<Long> partition, HashSet<Long> R)
+       {
+               //create index of plans that reference full aggregates to avoid 
circular dependencies
+               HashSet<Long> refHops = new HashSet<Long>();
+               for( Entry<Long, List<MemoTableEntry>> e : 
memo.getPlans().entrySet() )
+                       if( !e.getValue().isEmpty() ) {
+                               Hop hop = memo.getHopRefs().get(e.getKey());
+                               for( Hop c : hop.getInput() )
+                                       refHops.add(c.getHopID());
+                       }
+               
+               //find all full aggregations (the fact that they are in the 
same partition guarantees 
+               //that they also have common subexpressions, also full 
aggregations are by def root nodes)
+               ArrayList<Long> fullAggs = new ArrayList<Long>();
+               for( Long hopID : R ) {
+                       Hop root = memo.getHopRefs().get(hopID);
+                       if( !refHops.contains(hopID) && 
isMultiAggregateRoot(root) )
+                               fullAggs.add(hopID);
+               }
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Found within-partition ua(RC) aggregations: 
" +
+                               Arrays.toString(fullAggs.toArray(new Long[0])));
+               }
+               
+               //construct and add multiagg template plans (w/ max 3 
aggregations)
+               for( int i=0; i<fullAggs.size(); i+=3 ) {
+                       int ito = Math.min(i+3, fullAggs.size());
+                       if( ito-i >= 2 ) {
+                               MemoTableEntry me = new 
MemoTableEntry(TemplateType.MAGG,
+                                       fullAggs.get(i), fullAggs.get(i+1), 
((ito-i)==3)?fullAggs.get(i+2):-1, ito-i);
+                               if( isValidMultiAggregate(memo, me) ) {
+                                       for( int j=i; j<ito; j++ ) {
+                                               
memo.add(memo.getHopRefs().get(fullAggs.get(j)), me);
+                                               if( LOG.isTraceEnabled() )
+                                                       LOG.trace("Added 
multiagg plan: "+fullAggs.get(j)+" "+me);
+                                       }
+                               }
+                               else if( LOG.isTraceEnabled() ) {
+                                       LOG.trace("Removed invalid multiagg 
plan: "+me);
+                               }
+                       }
+               }
+       }
+       
+       //across-partition multi-agg templates with shared reads
+       private void createAndAddMultiAggPlans(CPlanMemoTable memo, 
ArrayList<Hop> roots)
+       {
+               //collect full aggregations as initial set of candidates
+               HashSet<Long> fullAggs = new HashSet<Long>();
+               Hop.resetVisitStatus(roots);
+               for( Hop hop : roots )
+                       rCollectFullAggregates(hop, fullAggs);
+               Hop.resetVisitStatus(roots);
+
+               //remove operators with assigned multi-agg plans
+               fullAggs.removeIf(p -> memo.contains(p, TemplateType.MAGG));
+       
+               //check applicability for further analysis
+               if( fullAggs.size() <= 1 )
+                       return;
+       
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Found across-partition ua(RC) aggregations: 
" +
+                               Arrays.toString(fullAggs.toArray(new Long[0])));
+               }
+               
+               //collect information for all candidates 
+               //(subsumed aggregations, and inputs to fused operators) 
+               List<AggregateInfo> aggInfos = new ArrayList<AggregateInfo>();
+               for( Long hopID : fullAggs ) {
+                       Hop aggHop = memo.getHopRefs().get(hopID);
+                       AggregateInfo tmp = new AggregateInfo(aggHop);
+                       for( int i=0; i<aggHop.getInput().size(); i++ ) {
+                               Hop c = 
HopRewriteUtils.isMatrixMultiply(aggHop) && i==0 ? 
+                                       
aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i);
+                               rExtractAggregateInfo(memo, c, tmp, 
TemplateType.CELL);
+                       }
+                       if( tmp._fusedInputs.isEmpty() ) {
+                               if( HopRewriteUtils.isMatrixMultiply(aggHop) ) {
+                                       
tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
+                                       
tmp.addFusedInput(aggHop.getInput().get(1).getHopID());
+                               }
+                               else    
+                                       
tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
+                       }
+                       aggInfos.add(tmp);      
+               }
+               
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Extracted across-partition ua(RC) 
aggregation info: ");
+                       for( AggregateInfo info : aggInfos )
+                               LOG.trace(info);
+               }
+               
+               //sort aggregations by num dependencies to simplify merging
+               //clusters of aggregations with parallel dependencies
+               aggInfos = aggInfos.stream()
+                       .sorted(Comparator.comparing(a -> a._inputAggs.size()))
+                       .collect(Collectors.toList());
+               
+               //greedy grouping of multi-agg candidates
+               boolean converged = false;
+               while( !converged ) {
+                       AggregateInfo merged = null;
+                       for( int i=0; i<aggInfos.size(); i++ ) {
+                               AggregateInfo current = aggInfos.get(i);
+                               for( int j=i+1; j<aggInfos.size(); j++ ) {
+                                       AggregateInfo that = aggInfos.get(j);
+                                       if( current.isMergable(that) ) {
+                                               merged = current.merge(that);
+                                               aggInfos.remove(j); j--;
+                                       }
+                               }
+                       }
+                       converged = (merged == null);
+               }
+               
+               if( LOG.isTraceEnabled() ) {
+                       LOG.trace("Merged across-partition ua(RC) aggregation info: ");
+                       for( AggregateInfo info : aggInfos )
+                               LOG.trace(info);
+               }
+               
+               //construct and add multiagg template plans (w/ max 3 aggregations)
+               for( AggregateInfo info : aggInfos ) {
+                       if( info._aggregates.size()<=1 )
+                               continue;
+                       Long[] aggs = info._aggregates.keySet().toArray(new Long[0]);
+                       MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG,
+                               aggs[0], aggs[1], (aggs.length>2)?aggs[2]:-1, aggs.length);
+                       for( int i=0; i<aggs.length; i++ ) {
+                               memo.add(memo.getHopRefs().get(aggs[i]), me);
+                               addBestPlan(aggs[i], me);
+                               if( LOG.isTraceEnabled() )
+                                       LOG.trace("Added multiagg* plan: "+aggs[i]+" "+me);
+                               
+                       }
+               }
+       }
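For intuition on the grouping above (expressions chosen for illustration only): two single-output aggregates such as sum(X*Y) and sum(X*Z) share a read of X, are independent of each other, and have inputs of matching size, so isMergable combines their AggregateInfo entries; the final loop then packs up to three such aggregates into one multi-aggregate (MAGG) memo entry, so the shared input is scanned only once by the fused operator.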
+       
+       private static boolean isMultiAggregateRoot(Hop root) {
+               return (HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX)
+                               && ((AggUnaryOp)root).getDirection()==Direction.RowCol)
+                       || (root instanceof AggBinaryOp && root.getDim1()==1 && root.getDim2()==1
+                               && HopRewriteUtils.isTransposeOperation(root.getInput().get(0)));
+       }
+       
+       private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) {
+               //ensure input consistent sizes (otherwise potential for incorrect results)
+               boolean ret = true;
+               Hop refSize = memo.getHopRefs().get(me.input1).getInput().get(0);
+               for( int i=1; ret && i<3; i++ ) {
+                       if( me.isPlanRef(i) )
+                               ret &= HopRewriteUtils.isEqualSize(refSize,
+                                       memo.getHopRefs().get(me.input(i)).getInput().get(0));
+               }
+               }
+               
+               //ensure that aggregates are independent of each other, i.e.,
+               //they do not have potentially transitive parent-child references
+               for( int i=0; ret && i<3; i++ ) 
+                       if( me.isPlanRef(i) ) {
+                               HashSet<Long> probe = new HashSet<Long>();
+                               for( int j=0; j<3; j++ )
+                                       if( i != j )
+                                               probe.add(me.input(j));
+                               ret &= rCheckMultiAggregate(memo.getHopRefs().get(me.input(i)), probe);
+                       }
+               return ret;
+       }
+       
+       private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe) {
+               boolean ret = true;
+               for( Hop c : current.getInput() )
+                       ret &= rCheckMultiAggregate(c, probe);
+               ret &= !probe.contains(current.getHopID());
+               return ret;
+       }
+       
+       private static void rCollectFullAggregates(Hop current, HashSet<Long> aggs) {
+               if( current.isVisited() )
+                       return;
+               
+               //collect all applicable full aggregations per read
+               if( isMultiAggregateRoot(current) )
+                       aggs.add(current.getHopID());
+               
+               //recursively process children
+               for( Hop c : current.getInput() )
+                       rCollectFullAggregates(c, aggs);
+               
+               current.setVisited();
+       }
+       
+       private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo, TemplateType type) {
+               //collect input aggregates (dependents)
+               if( isMultiAggregateRoot(current) )
+                       aggInfo.addInputAggregate(current.getHopID());
+               
+               //recursively process children
+               MemoTableEntry me = (type!=null) ? memo.getBest(current.getHopID()) : null;
+               for( int i=0; i<current.getInput().size(); i++ ) {
+                       Hop c = current.getInput().get(i);
+                       if( me != null && me.isPlanRef(i) )
+                               rExtractAggregateInfo(memo, c, aggInfo, type);
+                       else {
+                               if( type != null && c.getDataType().isMatrix() ) //add fused input
+                                       aggInfo.addFusedInput(c.getHopID());
+                               rExtractAggregateInfo(memo, c, aggInfo, null);
+                       }
+               }
+       }
+       
+       private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M)
+       {
+               //prune row aggregates with pure cellwise operations
+               for( Long hopID : R ) {
+                       MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW);
+                       if( me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL)
+                               && isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) {
+                               List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW);
+                               memo.remove(memo.getHopRefs().get(hopID), new HashSet<MemoTableEntry>(blacklist));
+                               if( LOG.isTraceEnabled() ) {
+                                       LOG.trace("Removed row memo table entries w/o aggregation: "
+                                               + Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
+                               }
+                       }
+               }
+               
+               //prune suboptimal outer product plans that are dominated by outer product plans w/ same number of
+               //references but better fusion properties (e.g., for the patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))),
+               //we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would unnecessarily destroy a fusion pattern.
+               for( Long hopID : partition ) {
+                       if( memo.countEntries(hopID, TemplateType.OUTER) == 2 ) {
+                               List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OUTER);
+                               MemoTableEntry me1 = entries.get(0);
+                               MemoTableEntry me2 = entries.get(1);
+                               MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
+                               if( rmEntry != null ) {
+                                       memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry));
+                                       memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
+                                       if( LOG.isTraceEnabled() )
+                                               LOG.trace("Removed dominated outer product memo table entry: " + rmEntry);
+                               }
+                       }
+               }
+               
+               //if no materialization points, use basic fuse-all w/ partition awareness
+               if( M == null || M.isEmpty() ) {
+                       for( Long hopID : R )
+                               rSelectPlansFuseAll(memo,
+                                       memo.getHopRefs().get(hopID), null, partition);
+               }
+               else {
+                       //TODO branch and bound pruning, right now we use exhaustive enum for early experiments
+                       //via skip ahead in below enumeration algorithm
+                       
+                       //obtain hop compute costs per cell once
+                       HashMap<Long, Double> computeCosts = new HashMap<Long, Double>();
+                       for( Long hopID : R )
+                               rGetComputeCosts(memo.getHopRefs().get(hopID), partition, computeCosts);
+                       
+                       //scan linearized search space, w/ skips for branch and bound pruning
+                       int len = (int)Math.pow(2, M.size());
+                       boolean[] bestPlan = null;
+                       double bestC = Double.MAX_VALUE;
+                       
+                       for( int i=0; i<len; i++ ) {
+                               //construct assignment
+                               boolean[] plan = createAssignment(M.size(), i);
+                               
+                               //cost assignment on hops
+                               double C = getPlanCost(memo, partition, R, M, plan, computeCosts);
+                               if( DMLScript.STATISTICS )
+                                       Statistics.incrementCodegenFPlanCompile(1);
+                               if( LOG.isTraceEnabled() )
+                                       LOG.trace("Enum: "+Arrays.toString(plan)+" -> "+C);
+                               
+                               //cost comparisons
+                               if( bestPlan == null || C < bestC ) {
+                                       bestC = C;
+                                       bestPlan = plan;
+                                       if( LOG.isTraceEnabled() )
+                                               LOG.trace("Enum: Found new best plan.");
+                               }
+                       }
+                       
+                       //prune memo table wrt best plan and select plans
+                       HashSet<Long> visited = new HashSet<Long>();
+                       for( Long hopID : R )
+                               rPruneSuboptimalPlans(memo, memo.getHopRefs().get(hopID),
+                                       visited, partition, M, bestPlan);
+                       HashSet<Long> visited2 = new HashSet<Long>();
+                       for( Long hopID : R )
+                               rPruneInvalidPlans(memo, memo.getHopRefs().get(hopID),
+                                       visited2, partition, M, bestPlan);
+                       
+                       for( Long hopID : R )
+                               rSelectPlansFuseAll(memo,
+                                       memo.getHopRefs().get(hopID), null, partition);
+               }
+       }
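In the fallback enumeration above, the linearized search space has one boolean decision per materialization point, so a partition with |M| candidate points yields 2^|M| assignments to cost: for example, |M|=5 gives 32 plans and |M|=10 already gives 1,024. Each assignment is costed via getPlanCost, and the cheapest one is used to prune the memo table before final plan selection.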
+       
+       private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
+               //consider all aggregations other than root operation
+               MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW);
+               boolean ret = true;
+               for(int i=0; i<3; i++)
+                       if( me.isPlanRef(i) )
+                               ret &= rIsRowTemplateWithoutAgg(memo, 
+                                       current.getInput().get(i), visited);
+               return ret;
+       }
+       
+       private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
+               if( visited.contains(current.getHopID()) )
+                       return true;
+               
+               boolean ret = true;
+               MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW);
+               for(int i=0; i<3; i++)
+                       if( me.isPlanRef(i) )
+                               ret &= rIsRowTemplateWithoutAgg(memo, current.getInput().get(i), visited);
+               ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp);
+               
+               visited.add(current.getHopID());
+               return ret;
+       }
+       
+       private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) {
+               //memoization (not via hops because in middle of dag)
+               if( visited.contains(current.getHopID()) )
+                       return;
+               
+               //remove memo table entries if necessary
+               long hopID = current.getHopID();
+               if( partition.contains(hopID) && memo.contains(hopID) ) {
+                       Iterator<MemoTableEntry> iter = memo.get(hopID).iterator();
+                       while( iter.hasNext() ) {
+                               MemoTableEntry me = iter.next();
+                               if( !hasNoRefToMaterialization(me, M, plan) && me.type!=TemplateType.OUTER ){
+                                       iter.remove();
+                                       if( LOG.isTraceEnabled() )
+                                               LOG.trace("Removed memo table entry: "+me);
+                               }
+                       }
+               }
+               
+               //process children recursively
+               for( Hop c : current.getInput() )
+                       rPruneSuboptimalPlans(memo, c, visited, partition, M, plan);
+               
+               visited.add(current.getHopID());                
+       }
+       
+       private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) {
+               //memoization (not via hops because in middle of dag)
+               if( visited.contains(current.getHopID()) )
+                       return;
+               
+               //process children recursively
+               for( Hop c : current.getInput() )
+                       rPruneInvalidPlans(memo, c, visited, partition, M, plan);
+               
+               //find invalid row aggregate leaf nodes (see TemplateRow.open) w/o matrix inputs,
+               //i.e., plans that become invalid after the previous pruning step
+               long hopID = current.getHopID();
+               if( partition.contains(hopID) && memo.contains(hopID, TemplateType.ROW) ) {
+                       for( MemoTableEntry me : memo.get(hopID) ) {
+                               if( me.type==TemplateType.ROW ) {
+                                       //convert leaf node with pure vector inputs
+                                       if( !me.hasPlanRef() && !TemplateUtils.hasMatrixInput(current) ) {
+                                               me.type = TemplateType.CELL;
+                                               if( LOG.isTraceEnabled() )
+                                                       LOG.trace("Converted leaf memo table entry from row to cell: "+me);
+                                       }
+                                       
+                                       //convert inner node without row template input
+                                       if( me.hasPlanRef() && !ROW_TPL.open(current) ) {
+                                               boolean hasRowInput = false;
+                                               for( int i=0; i<3; i++ )
+                                                       if( me.isPlanRef(i) )
+                                                               hasRowInput |= memo.contains(me.input(i), TemplateType.ROW);
+                                               if( !hasRowInput ) {
+                                                       me.type = TemplateType.CELL;
+                                                       if( LOG.isTraceEnabled() )
+                                                               LOG.trace("Converted inner memo table entry from row to cell: "+me);
+                                               }
+                                       }
+                                       
+                               }
+                       }
+               }
+               
+               visited.add(current.getHopID());                
+       }
+       
+       private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition)
+       {       
+               if( isVisited(current.getHopID(), currentType) 
+                       || !partition.contains(current.getHopID()) )
+                       return;
+               
+               //step 1: prune subsumed plans of same type
+               if( memo.contains(current.getHopID()) ) {
+                       HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
+                       List<MemoTableEntry> hopP = memo.get(current.getHopID());
+                       for( MemoTableEntry e1 : hopP )
+                               for( MemoTableEntry e2 : hopP )
+                                       if( e1 != e2 && e1.subsumes(e2) )
+                                               rmSet.add(e2);
+                       memo.remove(current, rmSet);
+               }
+               
+               //step 2: select plan for current path
+               MemoTableEntry best = null;
+               if( memo.contains(current.getHopID()) ) {
+                       if( currentType == null ) {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> isValid(p, current))
+                                       .min(new BasicPlanComparator()).orElse(null);
+                       }
+                       else {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> p.type==currentType || p.type==TemplateType.CELL)
+                                       .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs()))
+                                       .orElse(null);
+                       }
+                       addBestPlan(current.getHopID(), best);
+               }
+               
+               //step 3: recursively process children
+               for( int i=0; i< current.getInput().size(); i++ ) {
+                       TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null;
+                       rSelectPlansFuseAll(memo, current.getInput().get(i), pref, partition);
+               }
+               
+               setVisited(current.getHopID(), currentType);
+       }       
+       
+       private static boolean[] createAssignment(int len, int pos) {
+               boolean[] ret = new boolean[len]; 
+               int tmp = pos;
+               for( int i=0; i<len; i++ ) {
+                       ret[i] = (tmp < (int)Math.pow(2, len-i-1));
+                       tmp %= Math.pow(2, len-i-1);
+               }
+               return ret;     
+       }
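The linearized enumeration relies on createAssignment mapping the loop counter to one boolean decision per materialization point; ret[i] is true exactly when bit (len-1-i) of pos is 0, so iterating pos from 0 to 2^len-1 covers every assignment exactly once. A minimal standalone sketch of this behavior (illustration only, not part of the patch; the wrapper class and main method are hypothetical):

       public class AssignmentEnumSketch {
               //same logic as createAssignment above
               private static boolean[] createAssignment(int len, int pos) {
                       boolean[] ret = new boolean[len];
                       int tmp = pos;
                       for( int i=0; i<len; i++ ) {
                               ret[i] = (tmp < (int)Math.pow(2, len-i-1));
                               tmp %= Math.pow(2, len-i-1);
                       }
                       return ret;
               }
               public static void main(String[] args) {
                       int len = 3; //e.g., three materialization points
                       for( int pos=0; pos<(int)Math.pow(2, len); pos++ ) {
                               boolean[] plan = createAssignment(len, pos);
                               //equivalent bit view: plan[i] == (((pos >> (len-1-i)) & 1) == 0)
                               System.out.println(pos + " -> " + java.util.Arrays.toString(plan));
                       }
               }
       }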
+       
+       /////////////////////////////////////////////////////////
+       // Cost model fused operators w/ materialization points
+       //////////
+       
+       private static double getPlanCost(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R,
+                       ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts)
+       {
+               //high level heuristic: every hop or fused operator has the following cost:
+               //WRITE + max(COMPUTE, READ), where WRITE costs are given by the output size,
+               //READ costs by the input sizes, and COMPUTE by operation specific FLOP
+               //counts times number of cells of main input, disregarding sparsity for now.
+               
+               HashSet<Pair<Long,Long>> visited = new HashSet<Pair<Long,Long>>();
+               double costs = 0;
+               for( Long hopID : R )
+                       costs += rGetPlanCosts(memo, memo.getHopRefs().get(hopID),
+                                       visited, partition, M, plan, computeCosts, null, null);
+               return costs;
+       }
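To make the heuristic above concrete, consider a fused operator over one dense 1,000 x 1,000 input that produces a 1,000 x 1 output with an accumulated per-cell compute cost of 5, and assume purely illustrative constants of 32 GB/s read/write bandwidth and 2 GFLOP/s compute bandwidth (the actual WRITE_BANDWIDTH, READ_BANDWIDTH, and COMPUTE_BANDWIDTH values used by the code are not shown in this excerpt):
cost = 1,000*8B / 32GB/s + max(5*10^6 / 2GFLOP/s, 10^6*8B / 32GB/s) = 0.25 us + max(2.5 ms, 0.25 ms) ~= 2.5 ms,
i.e., this operator is estimated as compute-bound and the read/write terms are negligible.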
+       
+       private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<Pair<Long,Long>> visited, HashSet<Long> partition,
+                       ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts, CostVector costsCurrent, TemplateType currentType)
+       {
+               //memoization per hop id and cost vector to account for redundant
+               //computation without double counting materialized results or compute
+               //costs of complex operation DAGs within a single fused operator
+               Pair<Long,Long> tag = Pair.of(current.getHopID(), 
+                       (costsCurrent==null)?0:costsCurrent.ID);
+               if( visited.contains(tag) )
+                       return 0; 
+               visited.add(tag);       
+               
+               //open template if necessary, including memoization
+               //under awareness of current plan choice
+               MemoTableEntry best = null;
+               boolean opened = false;
+               if( memo.contains(current.getHopID()) ) {
+                       if( currentType == null ) {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> isValid(p, current))
+                                       .filter(p -> hasNoRefToMaterialization(p, M, plan))
+                                       .min(new BasicPlanComparator()).orElse(null);
+                               opened = true;
+                       }
+                       else {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> p.type==currentType || p.type==TemplateType.CELL)
+                                       .filter(p -> hasNoRefToMaterialization(p, M, plan))
+                                       .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs()))
+                                       .orElse(null);
+                       }
+               }
+               
+               //create new cost vector if opened, initialized with write costs
+               CostVector costVect = !opened ? costsCurrent :
+                       new CostVector(Math.max(current.getDim1(),1)*Math.max(current.getDim2(),1));
+               
+               //add compute costs of current operator to costs vector
+               if( partition.contains(current.getHopID()) )
+                       costVect.computeCosts += computeCosts.get(current.getHopID());
+               
+               //process children recursively
+               double costs = 0;
+               for( int i=0; i< current.getInput().size(); i++ ) {
+                       Hop c = current.getInput().get(i);
+                       if( best!=null && best.isPlanRef(i) )
+                               costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, costVect, best.type);
+                       else if( best!=null && isImplicitlyFused(current, i, best.type) )
+                               costVect.addInputSize(c.getInput().get(0).getHopID(), Math.max(c.getDim1(),1)*Math.max(c.getDim2(),1));
+                       else { //include children and I/O costs
+                               costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, null, null);
+                               if( costVect != null && c.getDataType().isMatrix() )
+                                       costVect.addInputSize(c.getHopID(), Math.max(c.getDim1(),1)*Math.max(c.getDim2(),1));
+                       }                               
+               }       
+               
+               //add costs for opened fused operator
+               if( partition.contains(current.getHopID()) ) {
+                       if( opened ) {
+                               if( LOG.isTraceEnabled() )
+                                       LOG.trace("Cost vector for fused operator (hop "+current.getHopID()+"): "+costVect);
+                               costs += costVect.outSize * 8 / WRITE_BANDWIDTH; //time for output write
+                               costs += Math.max(
+                                               costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH,
+                                               costVect.getSumInputSizes() * 8 / READ_BANDWIDTH);
+                       }
+                       //add costs for non-partition read in the middle of fused operator
+                       else if( hasNonPartitionConsumer(current, partition) ) {
+                               costs += rGetPlanCosts(memo, current, visited, partition, M, plan, computeCosts, null, null);
+                       }
+               }
+               
+               //sanity check non-negative costs
+               if( costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs) )
+                       throw new RuntimeException("Wrong cost estimate: "+costs);
+               
+               return costs;
+       }
+       
+       private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts)
+       {
+               if( computeCosts.containsKey(current.getHopID()) )
+                       return;
+               
+               //recursively process children
+               for( Hop c : current.getInput() )
+                       rGetComputeCosts(c, partition, computeCosts);
+               
+               //get costs for given hop
+               double costs = 1;
+               if( current instanceof UnaryOp ) {
+                       switch( ((UnaryOp)current).getOp() ) {
+                               case ABS:   
+                               case ROUND:
+                               case CEIL:
+                               case FLOOR:
+                               case SIGN:
+                               case SELP:    costs = 1; break; 
+                               case SPROP:
+                               case SQRT:    costs = 2; break;
+                               case EXP:     costs = 18; break;
+                               case SIGMOID: costs = 21; break;
+                               case LOG:    
+                               case LOG_NZ:  costs = 32; break;
+                               case NCOL:
+                               case NROW:
+                               case PRINT:
+                               case CAST_AS_BOOLEAN:
+                               case CAST_AS_DOUBLE:
+                               case CAST_AS_INT:
+                               case CAST_AS_MATRIX:
+                               case CAST_AS_SCALAR: costs = 1; break;
+                               case SIN:     costs = 18; break;
+                               case COS:     costs = 22; break;
+                               case TAN:     costs = 42; break;
+                               case ASIN:    costs = 93; break;
+                               case ACOS:    costs = 103; break;
+                               case ATAN:    costs = 40; break;
+                               case CUMSUM:
+                               case CUMMIN:
+                               case CUMMAX:
+                               case CUMPROD: costs = 1; break;
+                               default:
+                                       LOG.warn("Cost model not "
+                                               + "implemented yet for: "+((UnaryOp)current).getOp());
+                       }
+               }
+               else if( current instanceof BinaryOp ) {
+                       switch( ((BinaryOp)current).getOp() ) {
+                               case MULT: 
+                               case PLUS:
+                               case MINUS:
+                               case MIN:
+                               case MAX: 
+                               case AND:
+                               case OR:
+                               case EQUAL:
+                               case NOTEQUAL:
+                               case LESS:
+                               case LESSEQUAL:
+                               case GREATER:
+                               case GREATEREQUAL: 
+                               case CBIND:
+                               case RBIND:   costs = 1; break;
+                               case INTDIV:  costs = 6; break;
+                               case MODULUS: costs = 8; break;
+                               case DIV:     costs = 22; break;
+                               case LOG:
+                               case LOG_NZ:  costs = 32; break;
+                               case POW:     costs = (HopRewriteUtils.isLiteralOfValue(
+                                               current.getInput().get(1), 2) ? 1 : 16); break;
+                               case MINUS_NZ:
+                               case MINUS1_MULT: costs = 2; break;
+                               case CENTRALMOMENT:
+                                       int type = (int) (current.getInput().get(1) instanceof LiteralOp ?
+                                               HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
+                                       switch( type ) {
+                                               case 0: costs = 1; break; //count
+                                               case 1: costs = 8; break; //mean
+                                               case 2: costs = 16; break; //cm2
+                                               case 3: costs = 31; break; //cm3
+                                               case 4: costs = 51; break; //cm4
+                                               case 5: costs = 16; break; //variance
+                                       }
+                                       break;
+                               case COVARIANCE: costs = 23; break;
+                               default:
+                                       LOG.warn("Cost model not "
+                                               + "implemented yet for: "+((BinaryOp)current).getOp());
+                       }
+               }
+               else if( current instanceof TernaryOp ) {
+                       switch( ((TernaryOp)current).getOp() ) {
+                               case PLUS_MULT: 
+                               case MINUS_MULT: costs = 2; break;
+                               case CTABLE:     costs = 3; break;
+                               case CENTRALMOMENT:
+                                       int type = (int) (current.getInput().get(1) instanceof LiteralOp ?
+                                               HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
+                                       switch( type ) {
+                                               case 0: costs = 2; break; //count
+                                               case 1: costs = 9; break; //mean
+                                               case 2: costs = 17; break; //cm2
+                                               case 3: costs = 32; break; //cm3
+                                               case 4: costs = 52; break; //cm4
+                                               case 5: costs = 17; break; //variance
+                                       }
+                                       break;
+                               case COVARIANCE: costs = 23; break;
+                               default:
+                                       LOG.warn("Cost model not "
+                                               + "implemented yet for: "+((TernaryOp)current).getOp());
+                       }
+               }
+               else if( current instanceof ParameterizedBuiltinOp ) {
+                       costs = 1;
+               }
+               else if( current instanceof IndexingOp ) {
+                       costs = 1;
+               }
+               else if( current instanceof ReorgOp ) {
+                       costs = 1;
+               }
+               else if( current instanceof AggBinaryOp ) {
+                       costs = 2; //matrix vector
+               }
+               else if( current instanceof AggUnaryOp) {
+                       switch(((AggUnaryOp)current).getOp()) {
+                       case SUM:    costs = 4; break; 
+                       case SUM_SQ: costs = 5; break;
+                       case MIN:
+                       case MAX:    costs = 1; break;
+                       default:
+                               LOG.warn("Cost model not "
+                                       + "implemented yet for: "+((AggUnaryOp)current).getOp());
+                       }
+               }
+               
+               computeCosts.put(current.getHopID(), costs);
+       }
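As an illustrative composition of these per-cell estimates (expression chosen for illustration only): a cell template fusing exp(7*X) accumulates a compute cost of 1 (MULT) + 18 (EXP) = 19 per cell, which rGetPlanCosts then multiplies by the number of cells of the largest input and divides by the compute bandwidth when costing the fused operator.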
+       
+       private static boolean hasNoRefToMaterialization(MemoTableEntry me, ArrayList<Long> M, boolean[] plan) {
+               boolean ret = true;
+               for( int i=0; ret && i<3; i++ )
+                       ret &= (!M.contains(me.input(i)) || !plan[M.indexOf(me.input(i))]);
+               return ret;
+       }
+       
+       private static boolean hasNonPartitionConsumer(Hop hop, HashSet<Long> partition) {
+               boolean ret = false;
+               for( Hop p : hop.getParent() )
+                       ret |= !partition.contains(p.getHopID());
+               return ret;
+       }
+       
+       private static boolean isImplicitlyFused(Hop hop, int index, TemplateType type) {
+               return type == TemplateType.ROW
+                       && HopRewriteUtils.isMatrixMultiply(hop) && index==0 
+                       && HopRewriteUtils.isTransposeOperation(hop.getInput().get(index));
+       }
+       
+       private static class CostVector {
+               public final long ID;
+               public final double outSize; 
+               public double computeCosts = 0;
+               public final HashMap<Long, Double> inSizes = new HashMap<Long, Double>();
+               
+               public CostVector(double outputSize) {
+                       ID = COST_ID.getNextID();
+                       outSize = outputSize;
+               }
+               public void addInputSize(long hopID, double inputSize) {
+                       //ensures that input sizes are not double counted
+                       inSizes.put(hopID, inputSize);
+               }
+               public double getSumInputSizes() {
+                       return inSizes.values().stream()
+                               .mapToDouble(d -> d.doubleValue()).sum();
+               }
+               public double getMaxInputSize() {
+                       return inSizes.values().stream()
+                               .mapToDouble(d -> d.doubleValue()).max().orElse(0);
+               }
+               @Override
+               public String toString() {
+                       return "["+outSize+", "+computeCosts+", {"
+                               +Arrays.toString(inSizes.keySet().toArray(new Long[0]))+", "
+                               +Arrays.toString(inSizes.values().toArray(new Double[0]))+"}]";
+               }
+       }
+       
+       private static class AggregateInfo {
+               public final HashMap<Long,Hop> _aggregates;
+               public final HashSet<Long> _inputAggs = new HashSet<Long>();
+               public final HashSet<Long> _fusedInputs = new HashSet<Long>();
+               public AggregateInfo(Hop aggregate) {
+                       _aggregates = new HashMap<Long, Hop>();
+                       _aggregates.put(aggregate.getHopID(), aggregate);
+               }
+               public void addInputAggregate(long hopID) {
+                       _inputAggs.add(hopID);
+               }
+               public void addFusedInput(long hopID) {
+                       _fusedInputs.add(hopID);
+               }
+               public boolean isMergable(AggregateInfo that) {
+                       //check independence
+                       boolean ret = _aggregates.size()<3 
+                               && _aggregates.size()+that._aggregates.size()<=3;
+                       for( Long hopID : that._aggregates.keySet() )
+                               ret &= !_inputAggs.contains(hopID);
+                       for( Long hopID : _aggregates.keySet() )
+                               ret &= !that._inputAggs.contains(hopID);
+                       //check partial shared reads
+                       ret &= !CollectionUtils.intersection(
+                               _fusedInputs, that._fusedInputs).isEmpty();
+                       //check consistent sizes (result correctness)
+                       Hop in1 = _aggregates.values().iterator().next();
+                       Hop in2 = that._aggregates.values().iterator().next();
+                       return ret && HopRewriteUtils.isEqualSize(
+                               in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1)?1:0),
+                               in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2)?1:0));
+               }
+               public AggregateInfo merge(AggregateInfo that) {
+                       _aggregates.putAll(that._aggregates);
+                       _inputAggs.addAll(that._inputAggs);
+                       _fusedInputs.addAll(that._fusedInputs);
+                       return this;
+               }
+               @Override
+               public String toString() {
+                       return "["+Arrays.toString(_aggregates.keySet().toArray(new Long[0]))+": "
+                               +"{"+Arrays.toString(_inputAggs.toArray(new Long[0]))+"},"
+                               +"{"+Arrays.toString(_fusedInputs.toArray(new Long[0]))+"}]";
+               }
+       }
+}
