[SYSTEMML-1418,1293,1424] New cost-based codegen plan selection

This patch introduces a new cost-based plan selector for code generation,
including an initial cost model; this selector will become the default plan
selector in the future. The core ideas are (1) to determine partitions,
i.e., maximally connected subgraphs of memo table entries, (2) to
determine potential materialization points (with >=2 consumers) per
partition, and (3) to enumerate and cost alternative plans per partition
via a branch-and-bound algorithm.

Initial end-to-end experiments on L2SVM, MLogreg, and GLM are very
promising: the resulting performance is equivalent to or better than that
of both existing plan selection heuristics (fuse-all, fuse-no-redundancy).
However, this selector is not yet applied by default because the cost
model needs further validation and the branch-and-bound pruning
conditions for skipping subsets of plans are not integrated yet (plans
are currently enumerated exhaustively).


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/c45bb41f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/c45bb41f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/c45bb41f

Branch: refs/heads/master
Commit: c45bb41ffd3ca7c788eb7f66594906849fe9822a
Parents: ffc9120
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sun Apr 2 23:34:52 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sun Apr 2 23:46:20 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       |  28 +-
 .../hops/codegen/template/CPlanMemoTable.java   |  25 +-
 .../template/PlanSelectionFuseCostBased.java    | 576 +++++++++++++++++++
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  15 +
 4 files changed, 628 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c45bb41f/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index 988c13b..179a06a 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -29,6 +29,8 @@ import java.util.Map.Entry;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.codegen.cplan.CNode;
@@ -45,6 +47,7 @@ import 
org.apache.sysml.hops.codegen.template.TemplateBase.CloseType;
 import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
 import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
 import org.apache.sysml.hops.codegen.template.PlanSelection;
+import org.apache.sysml.hops.codegen.template.PlanSelectionFuseCostBased;
 import org.apache.sysml.hops.codegen.template.PlanSelectionFuseAll;
 import org.apache.sysml.hops.codegen.template.PlanSelectionFuseNoRedundancy;
 import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
@@ -111,6 +114,14 @@ public class SpoofCompiler
                }
        }
        
+       static {
+               // for internal debugging only
+               if( LDEBUG ) {
+                       Logger.getLogger("org.apache.sysml.hops.codegen")
+                                 .setLevel((Level) Level.TRACE);
+               }
+       }
+       
        //plan cache for cplan->compiled source to avoid unnecessary 
codegen/source code compile
        //for equal operators from (1) different hop dags and (2) repeated 
recompilation 
        //note: if PLAN_CACHE_SIZE is exceeded, we evict the 
least-recently-used plan (LRU policy)
@@ -242,8 +253,8 @@ public class SpoofCompiler
                        cplans = cleanupCPlans(cplans);
                        
                        //explain before modification
-                       if( LDEBUG && !cplans.isEmpty() ) { //existing cplans
-                               LOG.info("Codegen EXPLAIN (before optimize): 
\n"+Explain.explainHops(roots));
+                       if( LOG.isTraceEnabled() && !cplans.isEmpty() ) { 
//existing cplans
+                               LOG.trace("Codegen EXPLAIN (before optimize): 
\n"+Explain.explainHops(roots));
                        }
                        
                        //source code generation for all cplans
@@ -258,12 +269,12 @@ public class SpoofCompiler
                                        String src = 
tmp.getValue().codegen(false);
                                        
                                        //explain debug output cplans or 
generated source code
-                                       if( LDEBUG || 
DMLScript.EXPLAIN.isHopsType(recompile) ) {
+                                       if( LOG.isTraceEnabled() || 
DMLScript.EXPLAIN.isHopsType(recompile) ) {
                                                LOG.info("Codegen EXPLAIN 
(generated cplan for HopID: " +  cplan.getKey() +"):");
                                                
LOG.info(tmp.getValue().getClassname()
                                                                
+Explain.explainCPlan(cplan.getValue().getValue()));
                                        }
-                                       if( LDEBUG || 
DMLScript.EXPLAIN.isRuntimeType(recompile) ) {
+                                       if( LOG.isTraceEnabled() || 
DMLScript.EXPLAIN.isRuntimeType(recompile) ) {
                                                LOG.info("Codegen EXPLAIN 
(generated code for HopID: " +  cplan.getKey() +"):");
                                                LOG.info(src);
                                        }
@@ -276,14 +287,14 @@ public class SpoofCompiler
                                        if( 
PLAN_CACHE_POLICY!=PlanCachePolicy.NONE )
                                                
planCache.putPlan(tmp.getValue(), cla);
                                }
-                               else if( LDEBUG || DMLScript.STATISTICS ) {
+                               else if( DMLScript.STATISTICS ) {
                                        
Statistics.incrementCodegenPlanCacheHits();
                                }
                                
                                //make class available and maintain hits
                                if(cla != null)
                                        clas.put(cplan.getKey(), new 
Pair<Hop[],Class<?>>(tmp.getKey(),cla));
-                               if( LDEBUG || DMLScript.STATISTICS )
+                               if( DMLScript.STATISTICS )
                                        
Statistics.incrementCodegenPlanCacheTotal();
                        }
                        
@@ -297,8 +308,8 @@ public class SpoofCompiler
                                ret = rewriteCSE.rewriteHopDAGs(ret, new 
ProgramRewriteStatus());       
                                
                                //explain after modification
-                               if( LDEBUG ) {
-                                       LOG.info("Codegen EXPLAIN (after 
optimize): \n"+Explain.explainHops(roots));
+                               if( LOG.isTraceEnabled() ) {
+                                       LOG.trace("Codegen EXPLAIN (after 
optimize): \n"+Explain.explainHops(roots));
                                }
                        }
                }
@@ -333,6 +344,7 @@ public class SpoofCompiler
                        case FUSE_NO_REDUNDANCY: 
                                return new PlanSelectionFuseNoRedundancy();
                        case FUSE_COST_BASED:
+                               return new PlanSelectionFuseCostBased();
                        default:        
                                throw new RuntimeException("Unsupported "
                                        + "plan selector: "+PLAN_SEL_POLICY);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c45bb41f/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index 1b920cd..1a82f7d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -38,7 +38,7 @@ import 
org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
 
 public class CPlanMemoTable 
 {
-       private static final Log LOG = 
LogFactory.getLog(SpoofCompiler.class.getName());
+       private static final Log LOG = 
LogFactory.getLog(CPlanMemoTable.class.getName());
        
        protected HashMap<Long, List<MemoTableEntry>> _plans;
        protected HashMap<Long, Hop> _hopRefs;
@@ -132,8 +132,8 @@ public class CPlanMemoTable
        }
 
        public void pruneSuboptimal(ArrayList<Hop> roots) {
-               if( SpoofCompiler.LDEBUG )
-                       LOG.info("#1: Memo before plan selection ("+size()+" 
plans)\n"+this);
+               if( LOG.isTraceEnabled() )
+                       LOG.trace("#1: Memo before plan selection ("+size()+" 
plans)\n"+this);
                
                //build index of referenced entries
                HashSet<Long> ix = new HashSet<Long>();
@@ -162,16 +162,16 @@ public class CPlanMemoTable
                for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() )
                        for( MemoTableEntry me : e.getValue() ) {
                                for( int i=0; i<=2; i++ )
-                                       if( me.isPlanRef(i) && 
_hopRefs.get(me.intput(i)).getParent().size()==1 )
-                                               
_plansBlacklist.add(me.intput(i));
+                                       if( me.isPlanRef(i) && 
_hopRefs.get(me.input(i)).getParent().size()==1 )
+                                               
_plansBlacklist.add(me.input(i));
                        }
                
                //core plan selection
                PlanSelection selector = SpoofCompiler.createPlanSelector();
                selector.selectPlans(this, roots);
                
-               if( SpoofCompiler.LDEBUG )
-                       LOG.info("#2: Memo after plan selection ("+size()+" 
plans)\n"+this);
+               if( LOG.isTraceEnabled() )
+                       LOG.trace("#2: Memo after plan selection ("+size()+" 
plans)\n"+this);
        }
        
        public List<MemoTableEntry> get(long hopID) {
@@ -206,6 +206,15 @@ public class CPlanMemoTable
                        p -> (p.type==pref) ? -1 : p.type.getRank()));
        }
        
+       public long[] getAllRefs(long hopID) {
+               long[] refs = new long[3];
+               for( MemoTableEntry me : get(hopID) )
+                       for( int i=0; i<3; i++ )
+                               if( me.isPlanRef(i) )
+                                       refs[i] |= me.input(i);
+               return refs;
+       }
+       
        public int size() {
                return _plans.values().stream()
                        .map(list -> list.size())
@@ -263,7 +272,7 @@ public class CPlanMemoTable
                                +  ((input2 >= 0) ? 1 : 0)
                                +  ((input3 >= 0) ? 1 : 0);
                }
-               public long intput(int index) {
+               public long input(int index) {
                        return (index==0) ? input1 : (index==1) ? input2 : 
input3;
                }
                public boolean subsumes(MemoTableEntry that) {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c45bb41f/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
 
b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
new file mode 100644
index 0000000..653f43b
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
@@ -0,0 +1,576 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.template;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map.Entry;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.IndexingOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.TernaryOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+
+
+/**
+ * This cost-based plan selection algorithm chooses fused operators
+ * based on the DAG structure and resulting overall costs. This includes
+ * decisions on materialization points, template types, and composed
+ * multi output templates. 
+ * 
+ */
+public class PlanSelectionFuseCostBased extends PlanSelection
+{      
+       private static final Log LOG = 
LogFactory.getLog(PlanSelectionFuseCostBased.class.getName());
+       
+       //common bandwidth characteristics, with a conservative write bandwidth 
in order 
+       //to cover result allocation, write into main memory, and potential 
evictions
+       private static final double WRITE_BANDWIDTH = 2d*1024*1024*1024;  
//2GB/s
+       private static final double READ_BANDWIDTH = 32d*1024*1024*1024;  
//32GB/s
+       private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 
//2GFLOPs/core
+               * InfrastructureAnalyzer.getLocalParallelism();
+       
+       @Override
+       public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) 
+       {
+               //step 1: determine connected sub graphs of plans
+               Collection<HashSet<Long>> parts = getConnectedSubGraphs(memo, 
roots);
+               if( LOG.isTraceEnabled() )
+                       LOG.trace("Connected sub graphs: "+parts.size());
+               
+               for( HashSet<Long> partition : parts ) {
+                       //step 2: determine materialization points
+                       HashSet<Long> R = getPartitionRootNodes(memo, 
partition);
+                       if( LOG.isTraceEnabled() )
+                               LOG.trace("Partition root points: 
"+Arrays.toString(R.toArray(new Long[0])));
+                       ArrayList<Long> M = getMaterializationPoints(R, 
partition, memo);
+                       if( LOG.isTraceEnabled() )
+                               LOG.trace("Partition materialization points: 
"+Arrays.toString(M.toArray(new Long[0])));
+                       
+                       //step 3: plan enumeration and plan selection
+                       selectPlans(memo, partition, R, M);
+               }
+       
+               //take all distinct best plans
+               for( Entry<Long, List<MemoTableEntry>> e : 
getBestPlans().entrySet() )
+                       memo.setDistinct(e.getKey(), e.getValue());
+       }
+       
+       private static Collection<HashSet<Long>> 
getConnectedSubGraphs(CPlanMemoTable memo, ArrayList<Hop> roots) 
+       {
+               //build inverted index for 'referenced by' relationship 
+               HashMap<Long, HashSet<Long>> refBy = new HashMap<Long, 
HashSet<Long>>();
+               for( Entry<Long, List<MemoTableEntry>> e : 
memo._plans.entrySet() )
+                       for( MemoTableEntry me : e.getValue() ) 
+                               for( int i=0; i<3; i++ )
+                                       if( me.isPlanRef(i) ) {
+                                               if( 
!refBy.containsKey(me.input(i)) )
+                                                       refBy.put(me.input(i), 
new HashSet<Long>());
+                                               
refBy.get(me.input(i)).add(e.getKey());
+                                       }
+               
+               //create a single partition per root node, if reachable over 
refBy of 
+               //other root node the resulting partition is empty and can be 
discarded
+               ArrayList<HashSet<Long>> parts = new ArrayList<HashSet<Long>>();
+               HashSet<Long> visited = new HashSet<Long>();
+               for( Entry<Long, List<MemoTableEntry>> e : 
memo._plans.entrySet() )
+                       if( !refBy.containsKey(e.getKey()) ) { //root node
+                               HashSet<Long> part = 
rGetConnectedSubGraphs(e.getKey(), 
+                                               memo, refBy, visited, new 
HashSet<Long>());
+                               if( !part.isEmpty() )
+                                       parts.add(part);
+                       }
+               
+               return parts;
+       }
+       
+       private static HashSet<Long> rGetConnectedSubGraphs(long hopID, 
CPlanMemoTable memo, 
+                       HashMap<Long, HashSet<Long>> refBy, HashSet<Long> 
visited, HashSet<Long> partition) 
+       {
+               if( visited.contains(hopID) )
+                       return partition;
+               
+               //process node itself w/ memoization
+               if( memo.contains(hopID) ) {
+                       partition.add(hopID);
+                       visited.add(hopID);     
+               }
+               
+               //recursively process parents
+               if( refBy.containsKey(hopID) )
+                       for( Long ref : refBy.get(hopID) )
+                               rGetConnectedSubGraphs(ref, memo, refBy, 
visited, partition);
+               
+               //recursively process children
+               if( memo.contains(hopID) ) {
+                       long[] refs = memo.getAllRefs(hopID);
+                       for( int i=0; i<3; i++ )
+                               if( refs[i] != -1 )
+                                       rGetConnectedSubGraphs(refs[i], memo, 
refBy, visited, partition);
+               }
+               
+               return partition;
+       }
+       
+       private static HashSet<Long> getPartitionRootNodes(CPlanMemoTable memo, 
HashSet<Long> partition) 
+       {
+               //build inverted index of references entries 
+               HashSet<Long> ix = new HashSet<Long>();
+               for( Long hopID : partition )
+                       if( memo.contains(hopID) )
+                               for( MemoTableEntry me : memo.get(hopID) ) {
+                                       ix.add(me.input1); 
+                                       ix.add(me.input2); 
+                                       ix.add(me.input3);
+                               }
+               
+               HashSet<Long> roots = new HashSet<Long>();
+               for( Long hopID : partition )
+                       if( !ix.contains(hopID) )
+                               roots.add(hopID);
+               return roots;
+       }
+       
+       private static ArrayList<Long> getMaterializationPoints(HashSet<Long> 
roots, 
+                       HashSet<Long> partition, CPlanMemoTable memo) 
+       {
+               //collect materialization points bottom-up
+               ArrayList<Long> ret = new ArrayList<Long>();
+               HashSet<Long> visited = new HashSet<Long>();
+               for( Long hopID : roots )
+                       rCollectMaterializationPoints(memo._hopRefs.get(hopID), 
+                                       visited, partition, ret);
+               
+               //remove special-case materialization points
+               Iterator<Long> iter = ret.iterator();
+               while(iter.hasNext()) {
+                       Long hopID = iter.next();
+                       //remove root nodes w/ multiple consumers
+                       if( roots.contains(hopID) )
+                               iter.remove();
+                       //remove tsmm input if consumed in partition
+                       else if( 
HopRewriteUtils.isTsmmInput(memo._hopRefs.get(hopID)))
+                               iter.remove();
+               }
+               
+               return ret;
+       }
+       
+       private static void rCollectMaterializationPoints(Hop current, 
HashSet<Long> visited, 
+                       HashSet<Long> partition, ArrayList<Long> M) 
+       {
+               //memoization (not via hops because in middle of dag)
+               if( visited.contains(current.getHopID()) )
+                       return;
+               
+               //process children recursively
+               for( Hop c : current.getInput() )
+                       rCollectMaterializationPoints(c, visited, partition, M);
+               
+               //collect materialization point
+               if( isMaterializationPointCandidate(current, partition) )
+                       M.add(current.getHopID());
+               
+               visited.add(current.getHopID());
+       }
+       
+       private static boolean isMaterializationPointCandidate(Hop hop, 
HashSet<Long> partition) {
+               return hop.getParent().size()>=2 
+                       && partition.contains(hop.getHopID());
+       }
+       
+       private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, 
HashSet<Long> R, ArrayList<Long> M) 
+       {
+               //if no materialization points, use basic fuse-all w/ partition 
awareness
+               if( M == null || M.isEmpty() ) {
+                       for( Long hopID : R )
+                               rSelectPlansFuseAll(memo, 
+                                       memo._hopRefs.get(hopID), null, 
partition);
+               }
+               else {
+                       //TODO branch and bound pruning, right now we use 
exhaustive enum for early experiments
+                       //via skip ahead in below enumeration algorithm
+                       
+                       //obtain hop compute costs per cell once
+                       HashMap<Long, Double> computeCosts = new HashMap<Long, 
Double>();
+                       for( Long hopID : R )
+                               rGetComputeCosts(memo._hopRefs.get(hopID), 
partition, computeCosts);
+                       
+                       //scan linearized search space, w/ skips for branch and 
bound pruning
+                       int len = (int)Math.pow(2, M.size());
+                       boolean[] bestPlan = null;
+                       double bestC = Double.MAX_VALUE;
+                       
+                       for( int i=0; i<len; i++ ) {
+                               //construct assignment
+                               boolean[] plan = createAssignment(M.size(), i);
+                               
+                               //cost assignment on hops
+                               double C = getPlanCost(memo, partition, R, M, 
plan, computeCosts);
+                               if( LOG.isTraceEnabled() )
+                                       LOG.trace("Enum: 
"+Arrays.toString(plan)+" -> "+C);
+                               
+                               //cost comparisons
+                               if( bestPlan == null || C < bestC ) {
+                                       bestC = C;
+                                       bestPlan = plan;
+                                       if( LOG.isTraceEnabled() )
+                                               LOG.trace("Enum: Found new best 
plan.");
+                               }
+                       }
+                       
+                       //prune memo table wrt best plan and select plans
+                       HashSet<Long> visited = new HashSet<Long>();
+                       for( Long hopID : R )
+                               rPruneSuboptimalPlans(memo, 
memo._hopRefs.get(hopID), 
+                                       visited, partition, M, bestPlan);
+                       for( Long hopID : R )
+                               rSelectPlansFuseAll(memo, 
+                                       memo._hopRefs.get(hopID), null, 
partition);
+               }
+       }
+       
+       private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop 
current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, 
boolean[] plan) {
+               //memoization (not via hops because in middle of dag)
+               if( visited.contains(current.getHopID()) )
+                       return;
+               
+               //remove memo table entries if necessary
+               long hopID = current.getHopID();
+               if( partition.contains(hopID) && memo.contains(hopID) ) {
+                       Iterator<MemoTableEntry> iter = 
memo.get(hopID).iterator();
+                       while( iter.hasNext() ) {
+                               MemoTableEntry me = iter.next();
+                               if( !hasNoRefToMaterialization(me, M, plan) ){
+                                       iter.remove();
+                                       if( LOG.isTraceEnabled() )
+                                               LOG.trace("Removed memo table 
entry: "+me);
+                               }
+                       }
+               }
+               
+               //process children recursively
+               for( Hop c : current.getInput() )
+                       rPruneSuboptimalPlans(memo, c, visited, partition, M, 
plan);
+               
+               visited.add(current.getHopID());                
+       }
+       
+       private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, 
TemplateType currentType, HashSet<Long> partition) 
+       {       
+               if( isVisited(current.getHopID(), currentType) 
+                       || !partition.contains(current.getHopID()) )
+                       return;
+               
+               //step 1: prune subsumed plans of same type
+               if( memo.contains(current.getHopID()) ) {
+                       HashSet<MemoTableEntry> rmSet = new 
HashSet<MemoTableEntry>();
+                       List<MemoTableEntry> hopP = 
memo.get(current.getHopID());
+                       for( MemoTableEntry e1 : hopP )
+                               for( MemoTableEntry e2 : hopP )
+                                       if( e1 != e2 && e1.subsumes(e2) )
+                                               rmSet.add(e2);
+                       memo.remove(current, rmSet);
+               }
+               
+               //step 2: select plan for current path
+               MemoTableEntry best = null;
+               if( memo.contains(current.getHopID()) ) {
+                       if( currentType == null ) {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> isValid(p, current))
+                                       .min(new 
BasicPlanComparator()).orElse(null);
+                       }
+                       else {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> p.type==currentType || 
p.type==TemplateType.CellTpl)
+                                       .min(Comparator.comparing(p -> 
7-((p.type==currentType)?4:0)-p.countPlanRefs()))
+                                       .orElse(null);
+                       }
+                       addBestPlan(current.getHopID(), best);
+               }
+               
+               //step 3: recursively process children
+               for( int i=0; i< current.getInput().size(); i++ ) {
+                       TemplateType pref = (best!=null && best.isPlanRef(i))? 
best.type : null;
+                       rSelectPlansFuseAll(memo, current.getInput().get(i), 
pref, partition);
+               }
+               
+               setVisited(current.getHopID(), currentType);
+       }       
+       
+       private static boolean[] createAssignment(int len, int pos) {
+               boolean[] ret = new boolean[len]; 
+               int tmp = pos;
+               for( int i=0; i<len; i++ ) {
+                       ret[i] = (tmp < (int)Math.pow(2, len-i-1));
+                       tmp %= Math.pow(2, len-i-1);
+               }
+               return ret;     
+       }
+       
+       /////////////////////////////////////////////////////////
+       // Cost model fused operators w/ materialization points
+       //////////
+       
+       private static double getPlanCost(CPlanMemoTable memo, HashSet<Long> 
partition, HashSet<Long> R, 
+                       ArrayList<Long> M, boolean[] plan, HashMap<Long, 
Double> computeCosts) 
+       {
+               //high level heuristic: every hop or fused operator has the 
following cost: 
+               //WRITE + min(COMPUTE, READ), where WRITE costs are given by 
the output size, 
+               //READ costs by the input sizes, and COMPUTE by operation 
specific FLOP
+               //counts times number of cells of main input, disregarding 
sparsity for now.
+               
+               HashSet<Long> visited = new HashSet<Long>();
+               double costs = 0;
+               for( Long hopID : R )
+                       costs += rGetPlanCosts(memo, memo._hopRefs.get(hopID), 
+                                       visited, partition, M, plan, 
computeCosts, null, null);         
+               return costs;
+       }
+       
+       /**
+        * Recursively costs the sub-DAG rooted at the given hop under the given plan.
+        * 
+        * If no template is open (currentType==null) and the hop has memo entries, the
+        * best valid entry without references to selected materialization points opens
+        * a new fused operator with its own cost vector. Otherwise, an entry of the open
+        * template type (or CellTpl) is chosen, preferring the open type and then
+        * entries with more plan references. Plan-referenced inputs are fused into the
+        * current operator; all other inputs are costed separately, and their sizes
+        * are recorded as reads of the current operator.
+        * 
+        * @return accumulated costs; hops that are already in visited contribute 0
+        * @throws RuntimeException if the estimate is negative, NaN, or infinite
+        */
+       private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, 
+                       ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts, OperatorStats costsCurrent, TemplateType currentType) 
+       {
+               if( visited.contains(current.getHopID()) )
+                       return 0; //don't double count 
+               
+               //open template if necessary, including memoization
+               //under awareness of current plan choice
+               MemoTableEntry best = null;
+               boolean opened = false;
+               if( memo.contains(current.getHopID()) ) {
+                       if( currentType == null ) {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> isValid(p, current))
+                                       .filter(p -> hasNoRefToMaterialization(p, M, plan))
+                                       .min(new BasicPlanComparator()).orElse(null);
+                               opened = true;
+                               visited.add(current.getHopID());
+                       }
+                       else {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> p.type==currentType || p.type==TemplateType.CellTpl)
+                                       .filter(p -> hasNoRefToMaterialization(p, M, plan))
+                                       .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs()))
+                                       .orElse(null);
+                       }
+               }
+               
+               //create new cost vector if opened, initialized with write costs
+               OperatorStats costVect = !opened ? costsCurrent : 
+                       new OperatorStats(Math.max(current.getDim1(),1)*Math.max(current.getDim2(),1));
+               
+               //add compute costs of current operator to costs vector 
+               if( partition.contains(current.getHopID()) )
+                       costVect.computeCosts += computeCosts.get(current.getHopID());
+               
+               //process children recursively
+               double costs = 0;
+               for( int i=0; i< current.getInput().size(); i++ ) {
+                       Hop c = current.getInput().get(i);
+                       if( best!=null && best.isPlanRef(i) )
+                               costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, costVect, best.type);
+                       else { //include children and I/O costs
+                               costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, null, null);
+                               //the read costs are given by the size of the input c
+                               //(not by the size of the consuming operator current)
+                               if( costVect != null )
+                                       costVect.addInputSize( c.getHopID(), Math.max(c.getDim1(),1)*Math.max(c.getDim2(),1));
+                       }
+               }
+               
+               //add costs for opened fused operator
+               if( partition.contains(current.getHopID()) ) {
+                       if( opened ) {
+                               if( LOG.isTraceEnabled() )
+                                       LOG.trace("Cost vector for fused operator: "+costVect);
+                               costs += costVect.outSize * 8 / WRITE_BANDWIDTH; //time for output write
+                               //NOTE(review): if compute and read overlap, a roofline-style model would
+                               //use max(COMPUTE, READ); min is kept as stated by the heuristic - confirm
+                               costs += Math.min(
+                                               costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH, 
+                                               costVect.getSumInputSizes() * 8 / READ_BANDWIDTH); 
+                       }
+                       //add costs for non-partition read in the middle of fused operator:
+                       //the intermediate is additionally costed as its own operator
+                       else if( hasNonPartitionConsumer(current, partition) ) {
+                               costs += rGetPlanCosts(memo, current, visited, partition, M, plan, computeCosts, null, null);
+                       }
+               }
+               
+               //sanity check non-negative costs
+               if( costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs) )
+                       throw new RuntimeException("Wrong cost estimate: "+costs);
+               
+               return costs;
+       }
+       
+       /**
+        * Recursively computes operation-specific compute costs (rough FLOP counts
+        * per cell, later scaled by input sizes in rGetPlanCosts) for the given hop
+        * and all of its inputs, and memoizes them by hop ID in computeCosts. Hops
+        * that already have an entry are skipped; hop types not listed below get 0.
+        * 
+        * @param current hop to cost
+        * @param partition partition of the current plan (currently not used here)
+        * @param computeCosts memo of hop ID to compute costs (input and output)
+        * @throws RuntimeException if the cost model does not cover the operation
+        */
+       private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) 
+       {
+               if( computeCosts.containsKey(current.getHopID()) )
+                       return;
+               
+               //recursively process children
+               for( Hop c : current.getInput() )
+                       rGetComputeCosts(c, partition, computeCosts);
+               
+               //get costs for given hop
+               double costs = 0;
+               if( current instanceof UnaryOp ) {
+                       switch( ((UnaryOp)current).getOp() ) {
+                               case NOT:
+                               case ABS:   
+                               case ROUND:
+                               case CEIL:
+                               case FLOOR:
+                               case SIGN:
+                               case SELP:   costs = 1; break; 
+                               case SPROP:
+                               case SQRT:   costs = 2; break;
+                               case EXP:    costs = 18; break;
+                               case LOG:    costs = 32; break;
+                               case NCOL:
+                               case NROW:
+                               case PRINT:
+                               case CAST_AS_BOOLEAN:
+                               case CAST_AS_DOUBLE:
+                               case CAST_AS_INT:
+                               case CAST_AS_MATRIX:
+                               case CAST_AS_SCALAR: costs = 1; break;
+                               case CUMSUM:
+                               case CUMMIN:
+                               case CUMMAX:
+                               case CUMPROD: costs = 1; break;
+                               default:
+                                       throw new RuntimeException("Cost model not "
+                                               + "implemented yet for: "+((UnaryOp)current).getOp());
+                       }
+               }
+               else if( current instanceof BinaryOp ) {
+                       switch( ((BinaryOp)current).getOp() ) {
+                               //cheap elementwise arithmetic, comparison, and logical ops
+                               case MULT: 
+                               case PLUS:
+                               case MINUS:
+                               case MIN:
+                               case MAX:
+                               case AND:
+                               case OR:
+                               case EQUAL:
+                               case NOTEQUAL:
+                               case LESS:
+                               case LESSEQUAL:
+                               case GREATER:
+                               case GREATEREQUAL: 
+                               case CBIND:
+                               case RBIND: costs = 1; break;
+                               case DIV:   costs = 22; break;
+                               case LOG:   costs = 32; break;
+                               //squaring (X^2) is a single multiply, general pow is expensive
+                               case POW:   costs = (HopRewriteUtils.isLiteralOfValue(
+                                               current.getInput().get(1), 2) ? 1 : 16); break;
+                               default:
+                                       throw new RuntimeException("Cost model not "
+                                               + "implemented yet for: "+((BinaryOp)current).getOp());
+                       }
+               }
+               else if( current instanceof TernaryOp ) {
+                       switch( ((TernaryOp)current).getOp() ) {
+                               case PLUS_MULT: 
+                               case MINUS_MULT: costs = 2; break;
+                               default:
+                                       throw new RuntimeException("Cost model not "
+                                               + "implemented yet for: "+((TernaryOp)current).getOp());
+                       }
+               }
+               else if( current instanceof ParameterizedBuiltinOp ) {
+                       costs = 1;
+               }
+               else if( current instanceof IndexingOp ) {
+                       costs = 1;
+               }
+               else if( current instanceof ReorgOp ) {
+                       costs = 1;
+               }
+               else if( current instanceof AggBinaryOp ) {
+                       costs = 2; //matrix vector
+               }
+               else if( current instanceof AggUnaryOp) {
+                       switch(((AggUnaryOp)current).getOp()) {
+                               case SUM:    costs = 4; break; 
+                               case SUM_SQ: costs = 5; break;
+                               case MIN:
+                               case MAX:    costs = 1; break;
+                               default:
+                                       throw new RuntimeException("Cost model not "
+                                               + "implemented yet for: "+((AggUnaryOp)current).getOp());
+                       }
+               }
+               
+               computeCosts.put(current.getHopID(), costs);
+       }
+       
+       /**
+        * Checks that none of the (up to three) inputs of the given memo entry
+        * refers to a materialization point in M that is selected in the plan.
+        */
+       private static boolean hasNoRefToMaterialization(MemoTableEntry me, ArrayList<Long> M, boolean[] plan) {
+               for( int pos=0; pos<3; pos++ ) {
+                       int mIx = M.indexOf(me.input(pos));
+                       //reject as soon as one input is a selected materialization point
+                       if( mIx >= 0 && plan[mIx] )
+                               return false;
+               }
+               return true;
+       }
+       
+       /**
+        * Indicates if at least one parent (consumer) of the given hop lies
+        * outside the given partition.
+        */
+       private static boolean hasNonPartitionConsumer(Hop hop, HashSet<Long> partition) {
+               return hop.getParent().stream()
+                       .anyMatch(p -> !partition.contains(p.getHopID()));
+       }
+       
+       private static class OperatorStats {
+               public final double outSize; 
+               public double computeCosts = 0;
+               public final HashMap<Long, Double> inSizes = new HashMap<Long, 
Double>();
+               
+               public OperatorStats(double outputSize) {
+                       outSize = outputSize;
+               }
+               public void addInputSize(long hopID, double inputSize) {
+                       //ensures that input sizes are not double counted
+                       inSizes.put(hopID, inputSize);
+               }
+               public double getSumInputSizes() {
+                       return inSizes.values().stream()
+                               .mapToDouble(d -> d.doubleValue()).sum();
+               }
+               public double getMaxInputSize() {
+                       return inSizes.values().stream()
+                               .mapToDouble(d -> 
d.doubleValue()).max().orElse(0);
+               }
+               @Override
+               public String toString() {
+                       return "["+outSize+", "+computeCosts+", {"
+                               +Arrays.toString(inSizes.values().toArray(new 
Double[0]))+"}]";
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c45bb41f/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index a13f165..21a9acb 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -186,6 +186,12 @@ public class HopRewriteUtils
                return Long.MAX_VALUE;
        }
        
+       public static boolean isLiteralOfValue( Hop hop, double val ) {
+               return (hop instanceof LiteralOp 
+                       && (hop.getValueType()==ValueType.DOUBLE || 
hop.getValueType()==ValueType.INT)
+                       && getDoubleValueSafe((LiteralOp)hop)==val);
+       }
+       
        public static ScalarObject getScalarObject( LiteralOp op )
        {
                ScalarObject ret = null;
@@ -751,6 +757,15 @@ public class HopRewriteUtils
                        || isTransposeOperation(hop2) && hop2.getInput().get(0) 
== hop1;        
        }
        
+       /**
+        * Checks whether the given hop has exactly two consumers, one of which is a
+        * matrix multiply whose two inputs are transposes of each other according
+        * to isTransposeOfItself (e.g., t(X) %*% X).
+        */
+       public static boolean isTsmmInput(Hop input) {
+               if( input.getParent().size() != 2 )
+                       return false;
+               for( Hop parent : input.getParent() ) {
+                       if( isMatrixMultiply(parent) && isTransposeOfItself(
+                               parent.getInput().get(0), parent.getInput().get(1)) )
+                               return true;
+               }
+               return false;
+       }
+       
+       /**
+        * Indicates if the given hop is a BinaryOp with the given operation type.
+        */
        public static boolean isBinary(Hop hop, OpOp2 type) {
                return hop instanceof BinaryOp && ((BinaryOp)hop).getOp()==type;
        }

Reply via email to