[SYSTEMML-1943] Fix codegen fuse_all optimizer and consolidation

This patch fixes special cases of row operations that caused the
fuse_all optimizer to fail on Kmeans. Furthermore, it includes a
cleanup that consolidates the fuse-all plan selection used by the
fuse_all optimizer and both cost-based optimizers.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c27c488b
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c27c488b
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c27c488b

Branch: refs/heads/master
Commit: c27c488bef54887d549792c4cf6532d95c3f5c58
Parents: 8ed2516
Author: Matthias Boehm <[email protected]>
Authored: Sun Oct 1 20:04:39 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Oct 2 00:39:21 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/conf/DMLConfig.java   |  2 +-
 .../apache/sysml/hops/codegen/SpoofFusedOp.java | 11 +++++
 .../sysml/hops/codegen/cplan/CNodeRow.java      |  3 +-
 .../sysml/hops/codegen/opt/PlanSelection.java   | 46 +++++++++++++++++++
 .../hops/codegen/opt/PlanSelectionFuseAll.java  | 47 +-------------------
 .../codegen/opt/PlanSelectionFuseCostBased.java | 45 +------------------
 .../opt/PlanSelectionFuseCostBasedV2.java       | 47 +-------------------
 .../hops/codegen/template/TemplateUtils.java    |  2 +
 .../sysml/runtime/codegen/SpoofRowwise.java     | 25 ++++++-----
 9 files changed, 79 insertions(+), 149 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/conf/DMLConfig.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java 
b/src/main/java/org/apache/sysml/conf/DMLConfig.java
index 6a331a6..9835b4d 100644
--- a/src/main/java/org/apache/sysml/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java
@@ -127,7 +127,7 @@ public class DMLConfig
                _defaultVals.put(COMPRESSED_LINALG,      
Compression.CompressConfig.AUTO.name() );
                _defaultVals.put(CODEGEN,                "false" );
                _defaultVals.put(CODEGEN_COMPILER,       
CompilerType.AUTO.name() );
-               _defaultVals.put(CODEGEN_COMPILER,       
PlanSelector.FUSE_COST_BASED_V2.name() );
+               _defaultVals.put(CODEGEN_OPTIMIZER,      
PlanSelector.FUSE_COST_BASED_V2.name() );
                _defaultVals.put(CODEGEN_PLANCACHE,      "true" );
                _defaultVals.put(CODEGEN_LITERALS,       "1" );
                _defaultVals.put(NATIVE_BLAS,            "none" );

http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
index 81b226d..56bfb61 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
@@ -42,6 +42,7 @@ public class SpoofFusedOp extends Hop implements 
MultiThreadedHop
                ROW_DIMS,
                COLUMN_DIMS_ROWS,
                COLUMN_DIMS_COLS,
+               RANK_DIMS_COLS,
                SCALAR,
                MULTI_SCALAR,
                ROW_RANK_DIMS, // right wdivmm, row mm
@@ -163,6 +164,12 @@ public class SpoofFusedOp extends Hop implements 
MultiThreadedHop
                                case COLUMN_DIMS_COLS:
                                        ret = new long[]{1, mc.getCols(), -1};
                                        break;
+                               case RANK_DIMS_COLS: {
+                                       MatrixCharacteristics mc2 = 
memo.getAllInputStats(getInput().get(1));
+                                       if( mc2.dimsKnown() )
+                                               ret = new long[]{1, 
mc2.getCols(), -1};
+                                       break;
+                               }
                                case INPUT_DIMS:
                                        ret = new long[]{mc.getRows(), 
mc.getCols(), -1};
                                        break;
@@ -219,6 +226,10 @@ public class SpoofFusedOp extends Hop implements 
MultiThreadedHop
                                setDim1(1);
                                setDim2(getInput().get(0).getDim2());
                                break;
+                       case RANK_DIMS_COLS:
+                               setDim1(1);
+                               setDim2(getInput().get(1).getDim2());
+                               break;
                        case INPUT_DIMS:
                                setDim1(getInput().get(0).getDim1());
                                setDim2(getInput().get(0).getDim2());

http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
index 07822d9..9235216 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
@@ -158,7 +158,8 @@ public class CNodeRow extends CNodeTpl
                        case COL_AGG:      return 
SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector
                        case COL_AGG_T:    return 
SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector
                        case COL_AGG_B1:   return 
SpoofOutputDimsType.COLUMN_RANK_DIMS; 
-                       case COL_AGG_B1_T: return 
SpoofOutputDimsType.COLUMN_RANK_DIMS_T; 
+                       case COL_AGG_B1_T: return 
SpoofOutputDimsType.COLUMN_RANK_DIMS_T;
+                       case COL_AGG_B1R:  return 
SpoofOutputDimsType.RANK_DIMS_COLS;
                        default:
                                throw new RuntimeException("Unsupported row 
type: "+_type.toString());
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
index 21f4fd3..4cf56c4 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
@@ -34,6 +34,9 @@ import org.apache.sysml.runtime.util.UtilFunctions;
 
 public abstract class PlanSelection 
 {
+       private static final BasicPlanComparator BASE_COMPARE = new 
BasicPlanComparator();
+       private final TypedPlanComparator _typedCompare = new 
TypedPlanComparator();
+       
        private final HashMap<Long, List<MemoTableEntry>> _bestPlans = 
                        new HashMap<Long, List<MemoTableEntry>>();
        private final HashSet<VisitMark> _visited = new HashSet<VisitMark>();
@@ -84,6 +87,49 @@ public abstract class PlanSelection
                _visited.add(new VisitMark(hopID, type));
        }
        
+       protected void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, 
TemplateType currentType, HashSet<Long> partition) 
+       {       
+               if( isVisited(current.getHopID(), currentType) 
+                       || (partition!=null && 
!partition.contains(current.getHopID())) )
+                       return;
+               
+               //step 1: prune subsumed plans of same type
+               if( memo.contains(current.getHopID()) ) {
+                       HashSet<MemoTableEntry> rmSet = new 
HashSet<MemoTableEntry>();
+                       List<MemoTableEntry> hopP = 
memo.get(current.getHopID());
+                       for( MemoTableEntry e1 : hopP )
+                               for( MemoTableEntry e2 : hopP )
+                                       if( e1 != e2 && e1.subsumes(e2) )
+                                               rmSet.add(e2);
+                       memo.remove(current, rmSet);
+               }
+               
+               //step 2: select plan for current path
+               MemoTableEntry best = null;
+               if( memo.contains(current.getHopID()) ) {
+                       if( currentType == null ) {
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> isValid(p, current))
+                                       .min(BASE_COMPARE).orElse(null);
+                       }
+                       else {
+                               _typedCompare.setType(currentType);
+                               best = memo.get(current.getHopID()).stream()
+                                       .filter(p -> p.type==currentType || 
p.type==TemplateType.CELL)
+                                       .min(_typedCompare).orElse(null);
+                       }
+                       addBestPlan(current.getHopID(), best);
+               }
+               
+               //step 3: recursively process children
+               for( int i=0; i< current.getInput().size(); i++ ) {
+                       TemplateType pref = (best!=null && best.isPlanRef(i))? 
best.type : null;
+                       rSelectPlansFuseAll(memo, current.getInput().get(i), 
pref, partition);
+               }
+               
+               setVisited(current.getHopID(), currentType);
+       }
+       
        /**
         * Basic plan comparator to compare memo table entries with regard to
         * a pre-defined template preference order and the number of references.

http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java
index 8636bea..3e0561d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java
@@ -20,15 +20,12 @@
 package org.apache.sysml.hops.codegen.opt;
 
 import java.util.ArrayList;
-import java.util.Comparator;
 import java.util.Map.Entry;
-import java.util.HashSet;
 import java.util.List;
 
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
 import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
-import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
 
 /**
  * This plan selection heuristic aims for maximal fusion, which
@@ -43,52 +40,10 @@ public class PlanSelectionFuseAll extends PlanSelection
        public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
                //pruning and collection pass
                for( Hop hop : roots )
-                       rSelectPlans(memo, hop, null);
+                       rSelectPlansFuseAll(memo, hop, null, null);
                
                //take all distinct best plans
                for( Entry<Long, List<MemoTableEntry>> e : 
getBestPlans().entrySet() )
                        memo.setDistinct(e.getKey(), e.getValue());
        }
-       
-       private void rSelectPlans(CPlanMemoTable memo, Hop current, 
TemplateType currentType) 
-       {       
-               if( isVisited(current.getHopID(), currentType) )
-                       return;
-               
-               //step 1: prune subsumed plans of same type
-               if( memo.contains(current.getHopID()) ) {
-                       HashSet<MemoTableEntry> rmSet = new 
HashSet<MemoTableEntry>();
-                       List<MemoTableEntry> hopP = 
memo.get(current.getHopID());
-                       for( MemoTableEntry e1 : hopP )
-                               for( MemoTableEntry e2 : hopP )
-                                       if( e1 != e2 && e1.subsumes(e2) )
-                                               rmSet.add(e2);
-                       memo.remove(current, rmSet);
-               }
-               
-               //step 2: select plan for current path
-               MemoTableEntry best = null;
-               if( memo.contains(current.getHopID()) ) {
-                       if( currentType == null ) {
-                               best = memo.get(current.getHopID()).stream()
-                                       .filter(p -> isValid(p, current))
-                                       .min(new 
BasicPlanComparator()).orElse(null);
-                       }
-                       else {
-                               best = memo.get(current.getHopID()).stream()
-                                       .filter(p -> p.type==currentType || 
p.type==TemplateType.CELL)
-                                       .min(Comparator.comparing(p -> 
7-((p.type==currentType)?4:0)-p.countPlanRefs()))
-                                       .orElse(null);
-                       }
-                       addBestPlan(current.getHopID(), best);
-               }
-               
-               //step 3: recursively process children
-               for( int i=0; i< current.getInput().size(); i++ ) {
-                       TemplateType pref = (best!=null && best.isPlanRef(i))? 
best.type : null;
-                       rSelectPlans(memo, current.getInput().get(i), pref);
-               }
-               
-               setVisited(current.getHopID(), currentType);
-       }       
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
index acb90e2..f67604d 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
@@ -507,52 +507,9 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                        }
                }
                
-               visited.add(current.getHopID());                
+               visited.add(current.getHopID());
        }
        
-       private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, 
TemplateType currentType, HashSet<Long> partition) 
-       {       
-               if( isVisited(current.getHopID(), currentType) 
-                       || !partition.contains(current.getHopID()) )
-                       return;
-               
-               //step 1: prune subsumed plans of same type
-               if( memo.contains(current.getHopID()) ) {
-                       HashSet<MemoTableEntry> rmSet = new 
HashSet<MemoTableEntry>();
-                       List<MemoTableEntry> hopP = 
memo.get(current.getHopID());
-                       for( MemoTableEntry e1 : hopP )
-                               for( MemoTableEntry e2 : hopP )
-                                       if( e1 != e2 && e1.subsumes(e2) )
-                                               rmSet.add(e2);
-                       memo.remove(current, rmSet);
-               }
-               
-               //step 2: select plan for current path
-               MemoTableEntry best = null;
-               if( memo.contains(current.getHopID()) ) {
-                       if( currentType == null ) {
-                               best = memo.get(current.getHopID()).stream()
-                                       .filter(p -> isValid(p, current))
-                                       .min(new 
BasicPlanComparator()).orElse(null);
-                       }
-                       else {
-                               best = memo.get(current.getHopID()).stream()
-                                       .filter(p -> p.type==currentType || 
p.type==TemplateType.CELL)
-                                       .min(Comparator.comparing(p -> 
7-((p.type==currentType)?4:0)-p.countPlanRefs()))
-                                       .orElse(null);
-                       }
-                       addBestPlan(current.getHopID(), best);
-               }
-               
-               //step 3: recursively process children
-               for( int i=0; i< current.getInput().size(); i++ ) {
-                       TemplateType pref = (best!=null && best.isPlanRef(i))? 
best.type : null;
-                       rSelectPlansFuseAll(memo, current.getInput().get(i), 
pref, partition);
-               }
-               
-               setVisited(current.getHopID(), currentType);
-       }       
-       
        private static boolean[] createAssignment(int len, int pos) {
                boolean[] ret = new boolean[len]; 
                int tmp = pos;

http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 8d1c4c0..31e8427 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -98,8 +98,6 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
        
        private static final IDSequence COST_ID = new IDSequence();
        private static final TemplateRow ROW_TPL = new TemplateRow();
-       private static final BasicPlanComparator BASE_COMPARE = new 
BasicPlanComparator();
-       private final TypedPlanComparator _typedCompare = new 
TypedPlanComparator();
        
        @Override
        public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) 
@@ -726,50 +724,7 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                        }
                }
                
-               visited.add(current.getHopID());                
-       }
-       
-       private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, 
TemplateType currentType, HashSet<Long> partition) 
-       {       
-               if( isVisited(current.getHopID(), currentType) 
-                       || !partition.contains(current.getHopID()) )
-                       return;
-               
-               //step 1: prune subsumed plans of same type
-               if( memo.contains(current.getHopID()) ) {
-                       HashSet<MemoTableEntry> rmSet = new 
HashSet<MemoTableEntry>();
-                       List<MemoTableEntry> hopP = 
memo.get(current.getHopID());
-                       for( MemoTableEntry e1 : hopP )
-                               for( MemoTableEntry e2 : hopP )
-                                       if( e1 != e2 && e1.subsumes(e2) )
-                                               rmSet.add(e2);
-                       memo.remove(current, rmSet);
-               }
-               
-               //step 2: select plan for current path
-               MemoTableEntry best = null;
-               if( memo.contains(current.getHopID()) ) {
-                       if( currentType == null ) {
-                               best = memo.get(current.getHopID()).stream()
-                                       .filter(p -> isValid(p, current))
-                                       .min(BASE_COMPARE).orElse(null);
-                       }
-                       else {
-                               _typedCompare.setType(currentType);
-                               best = memo.get(current.getHopID()).stream()
-                                       .filter(p -> p.type==currentType || 
p.type==TemplateType.CELL)
-                                       .min(_typedCompare).orElse(null);
-                       }
-                       addBestPlan(current.getHopID(), best);
-               }
-               
-               //step 3: recursively process children
-               for( int i=0; i< current.getInput().size(); i++ ) {
-                       TemplateType pref = (best!=null && best.isPlanRef(i))? 
best.type : null;
-                       rSelectPlansFuseAll(memo, current.getInput().get(i), 
pref, partition);
-               }
-               
-               setVisited(current.getHopID(), currentType);
+               visited.add(current.getHopID());
        }
        
        /////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 06d83bd..4dc0bf2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -204,6 +204,8 @@ public class TemplateUtils
                        return RowType.COL_AGG_B1_T;
                else if( B1 != null && output.getDim1()==B1.getDim2() && 
output.getDim2()==X.getDim2())
                        return RowType.COL_AGG_B1;
+               else if( B1 != null && output.getDim1()==1 && B1.getDim2() == 
output.getDim2() )
+                       return RowType.COL_AGG_B1R;
                else if( X.getDim1() == output.getDim1() && X.getDim2() != 
output.getDim2() )
                        return RowType.NO_AGG_CONST;
                else

http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java 
b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
index 8b12e7e..311c27f 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
@@ -47,22 +47,25 @@ public abstract class SpoofRowwise extends SpoofOperator
        private static final long serialVersionUID = 6242910797139642998L;
        
        public enum RowType {
-               NO_AGG,    //no aggregation
-               NO_AGG_B1, //no aggregation w/ matrix mult B1
+               NO_AGG,       //no aggregation
+               NO_AGG_B1,    //no aggregation w/ matrix mult B1
                NO_AGG_CONST, //no aggregation w/ expansion/contraction
-               FULL_AGG,  //full row/col aggregation
-               ROW_AGG,   //row aggregation (e.g., rowSums() or X %*% v)
-               COL_AGG,   //col aggregation (e.g., colSums() or t(y) %*% X)
-               COL_AGG_T, //transposed col aggregation (e.g., t(X) %*% y)
+               FULL_AGG,     //full row/col aggregation
+               ROW_AGG,      //row aggregation (e.g., rowSums() or X %*% v)
+               COL_AGG,      //col aggregation (e.g., colSums() or t(y) %*% X)
+               COL_AGG_T,    //transposed col aggregation (e.g., t(X) %*% y)
                COL_AGG_B1,   //col aggregation w/ matrix mult B1
-               COL_AGG_B1_T; //transposed col aggregation w/ matrix mult B1
+               COL_AGG_B1_T, //transposed col aggregation w/ matrix mult B1
+               COL_AGG_B1R;  //col aggregation w/ matrix mult B1 to row vector
                
                public boolean isColumnAgg() {
-                       return (this == COL_AGG || this == COL_AGG_T)
-                               || (this == COL_AGG_B1) || (this == 
COL_AGG_B1_T);
+                       return this == COL_AGG || this == COL_AGG_T
+                               || this == COL_AGG_B1 || this == COL_AGG_B1_T
+                               || this == COL_AGG_B1R;
                }
                public boolean isRowTypeB1() {
-                       return (this == NO_AGG_B1) || (this == COL_AGG_B1) || 
(this == COL_AGG_B1_T);
+                       return this == NO_AGG_B1 || this == COL_AGG_B1 
+                               || this == COL_AGG_B1_T || this == COL_AGG_B1R;
                }
                public boolean isRowTypeB1ColumnAgg() {
                        return (this == COL_AGG_B1) || (this == COL_AGG_B1_T);
@@ -268,7 +271,7 @@ public abstract class SpoofRowwise extends SpoofOperator
                        case COL_AGG_T:    out.reset(n, 1, false); break;
                        case COL_AGG_B1:   out.reset(n2, n, false); break;
                        case COL_AGG_B1_T: out.reset(n, n2, false); break;
-                       
+                       case COL_AGG_B1R:  out.reset(1, n2, false); break;
                }
                out.allocateDenseBlock();
        }

Reply via email to