Repository: systemml
Updated Branches:
  refs/heads/master c2124544d -> c170374e7


[MINOR][SYSTEMML-1741] Improved graph traversal in codegen plan costing 

This patch makes a minor improvement to the codegen graph traversal for
plan costing by stopping at partition boundaries instead of traversing
to the leaves without taking their costs into account. On LeNet this led
to a minor but consistent reduction of codegen compilation overhead from
14.4s to 13.1s (for enumerating 474,840 plans).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0221fbcd
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0221fbcd
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0221fbcd

Branch: refs/heads/master
Commit: 0221fbcd0d3c09c0aef37ee7e77863ecdc0446b4
Parents: c212454
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Mon Aug 7 13:17:32 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Wed Aug 9 13:52:49 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/opt/PlanSelection.java   |  3 +-
 .../opt/PlanSelectionFuseCostBasedV2.java       | 50 +++++++++-----------
 2 files changed, 24 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/0221fbcd/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
index 23cc128..d18d156 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
@@ -98,8 +98,7 @@ public abstract class PlanSelection
                                return Integer.compare(o1.type.getRank(), 
o2.type.getRank());
                        
                        //for same type, prefer plan with more refs
-                       return Integer.compare(
-                               3-o1.countPlanRefs(), 3-o2.countPlanRefs());
+                       return Integer.compare(-o1.countPlanRefs(), 
-o2.countPlanRefs());
                }
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/0221fbcd/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 2fa0de7..717a059 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -744,11 +744,9 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                //memoization per hop id and cost vector to account for 
redundant
                //computation without double counting materialized results or 
compute
                //costs of complex operation DAGs within a single fused operator
-               VisitMarkCost tag = new VisitMarkCost(current.getHopID(), 
-                       (costsCurrent==null || 
currentType==TemplateType.MAGG)?0:costsCurrent.ID);
-               if( visited.contains(tag) )
-                       return 0;
-               visited.add(tag);
+               if( !visited.add(new VisitMarkCost(current.getHopID(), 
+                       (costsCurrent==null || 
currentType==TemplateType.MAGG)?0:costsCurrent.ID)) )
+                       return 0; //already existing 
                
                //open template if necessary, including memoization
                //under awareness of current plan choice
@@ -792,8 +790,7 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                }
                
                //add compute costs of current operator to costs vector
-               if( part.getPartition().contains(current.getHopID()) )
-                       costVect.computeCosts += 
computeCosts.get(current.getHopID());
+               costVect.computeCosts += computeCosts.get(current.getHopID());
                
                //process children recursively
                for( int i=0; i< current.getInput().size(); i++ ) {
@@ -803,34 +800,33 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                        else if( best!=null && isImplicitlyFused(current, i, 
best.type) )
                                
costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c));
                        else { //include children and I/O costs
-                               costs += rGetPlanCosts(memo, c, visited, part, 
matPoints, plan, computeCosts, null, null);
+                               if( part.getPartition().contains(c.getHopID()) )
+                                       costs += rGetPlanCosts(memo, c, 
visited, part, matPoints, plan, computeCosts, null, null);
                                if( costVect != null && 
c.getDataType().isMatrix() )
                                        costVect.addInputSize(c.getHopID(), 
getSize(c));
                        }
                }
                
                //add costs for opened fused operator
-               if( part.getPartition().contains(current.getHopID()) ) {
-                       if( opened ) {
-                               if( LOG.isTraceEnabled() ) {
-                                       String type = (best !=null) ? 
best.type.name() : "HOP";
-                                       LOG.trace("Cost vector ("+type+" 
"+current.getHopID()+"): "+costVect);
-                               }
-                               double tmpCosts = costVect.outSize * 8 / 
WRITE_BANDWIDTH //time for output write
-                                       + Math.max(costVect.getSumInputSizes() 
* 8 / READ_BANDWIDTH,
-                                       
costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH);
-                               //sparsity correction for outer-product 
template (and sparse-safe cell)
-                               if( best != null && best.type == 
TemplateType.OUTER ) {
-                                       Hop driver = 
memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
-                                       tmpCosts *= driver.dimsKnown(true) ? 
driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST;
-                               }
-                               costs += tmpCosts;
+               if( opened ) {
+                       if( LOG.isTraceEnabled() ) {
+                               String type = (best !=null) ? best.type.name() 
: "HOP";
+                               LOG.trace("Cost vector ("+type+" 
"+current.getHopID()+"): "+costVect);
                        }
-                       //add costs for non-partition read in the middle of 
fused operator
-                       else if( 
part.getExtConsumed().contains(current.getHopID()) ) {
-                               costs += rGetPlanCosts(memo, current, visited,
-                                       part, matPoints, plan, computeCosts, 
null, null);
+                       double tmpCosts = costVect.outSize * 8 / 
WRITE_BANDWIDTH //time for output write
+                               + Math.max(costVect.getSumInputSizes() * 8 / 
READ_BANDWIDTH,
+                               
costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH);
+                       //sparsity correction for outer-product template (and 
sparse-safe cell)
+                       if( best != null && best.type == TemplateType.OUTER ) {
+                               Hop driver = 
memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
+                               tmpCosts *= driver.dimsKnown(true) ? 
driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST;
                        }
+                       costs += tmpCosts;
+               }
+               //add costs for non-partition read in the middle of fused 
operator
+               else if( part.getExtConsumed().contains(current.getHopID()) ) {
+                       costs += rGetPlanCosts(memo, current, visited,
+                               part, matPoints, plan, computeCosts, null, 
null);
                }
                
                //sanity check non-negative costs

Reply via email to