Repository: systemml
Updated Branches:
  refs/heads/master 92034e64f -> 9481bef4e


[SYSTEMML-2157] Fix codegen optimizer (suboptimal plans after row2cell)

The cost-based codegen optimizer converts all partial row fusion plans
into cell plans if none of the operations requires access to entire
rows. However, the existing implementation of this pre-processing step
led to suboptimal plans for special cases. This patch completely reworks
this analysis step, which also improves its performance by using a
single pass over the sub-DAG of each fusion partition. We now also
properly track all operations and plans for which this row2cell conversion
is inapplicable. Finally, the row template has been extended to allow
unary operations in opening conditions (unless these operations work
over row vectors).

Together, these modifications led to a runtime improvement for the
autoencoder over mnist1m from 446s to 373s (~600s without codegen).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/14ea51be
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/14ea51be
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/14ea51be

Branch: refs/heads/master
Commit: 14ea51be70ede04dfd3d351205b5ab19f1109d91
Parents: 92034e6
Author: Matthias Boehm <[email protected]>
Authored: Wed Feb 21 21:12:50 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Wed Feb 21 21:12:50 2018 -0800

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeUnary.java    |   6 +-
 .../opt/PlanSelectionFuseCostBasedV2.java       | 100 +++++++++++++------
 .../hops/codegen/template/CPlanMemoTable.java   |   9 ++
 .../hops/codegen/template/TemplateRow.java      |  10 +-
 4 files changed, 89 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/14ea51be/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index a1401c3..d7721a1 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -216,10 +216,12 @@ public class CNodeUnary extends CNode
                String varj = _inputs.get(0).getVarname();
                
                //replace sparse and dense inputs
+               boolean vectIn = varj.startsWith("b") && 
!_type.isScalarLookup();
                tmp = tmp.replace("%IN1v%", varj+"vals");
                tmp = tmp.replace("%IN1i%", varj+"ix");
-               tmp = tmp.replace("%IN1%", varj.startsWith("b") && 
!_type.isScalarLookup()
-                       && TemplateUtils.isMatrix(_inputs.get(0)) ? varj + 
".values(rix)" : varj );
+               tmp = tmp.replace("%IN1%", 
+                       (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? 
varj + ".values(rix)" :
+                       (vectIn && TemplateUtils.isRowVector(_inputs.get(0)) ? 
varj + ".values(0)" : varj));
                
                //replace start position of main input
                String spos = (_inputs.get(0) instanceof CNodeData 

http://git-wip-us.apache.org/repos/asf/systemml/blob/14ea51be/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 84e4b4c..6ed562a 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -34,6 +34,7 @@ import java.util.stream.Collectors;
 
 import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
@@ -46,6 +47,8 @@ import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.DataOpTypes;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOpN;
+import org.apache.sysml.hops.Hop.ReOrgOp;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.OptimizerUtils;
@@ -635,37 +638,71 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                }
        }
        
-       private static HashSet<Long> getRowAggOpsWithRowRef(CPlanMemoTable 
memo, PlanPartition part) {
-               HashSet<Long> refAggs = new HashSet<>();
-               for( Long hopID : part.getPartition() ) {
-                       if( !memo.contains(hopID, TemplateType.ROW) ) continue;
-                       MemoTableEntry me = memo.getBest(hopID, 
TemplateType.ROW);
-                       for(int i=0; i<3; i++)
-                               if( me.isPlanRef(i) && 
memo.contains(me.input(i), TemplateType.ROW) 
-                                       && 
isRowAggOp(memo.getHopRefs().get(me.input(i))))
-                                       refAggs.add(me.input(i));
+       private static HashSet<Long> collectIrreplaceableRowOps(CPlanMemoTable 
memo, PlanPartition part) {
+               //get row entries that are (a) reachable from rowwise ops (top 
down) other than
+               //operator root nodes, or dependent upon row-wise ops (bottom 
up)
+               HashSet<Long> blacklist = new HashSet<>();
+               HashSet<Pair<Long, Integer>> visited = new HashSet<>();
+               for( Long hopID : part.getRoots() ) {
+                       rCollectDependentRowOps(memo.getHopRefs().get(hopID),
+                               memo, part, blacklist, visited, null, false);
                }
-               return refAggs;
+               return blacklist;
        }
        
-       private static boolean rIsRowTemplateWithoutAggOrVects(CPlanMemoTable 
memo, Hop current, HashSet<Long> visited, boolean inclRoot) {
-               if( visited.contains(current.getHopID()) )
-                       return true;
+       private static void rCollectDependentRowOps(Hop hop, CPlanMemoTable 
memo, PlanPartition part,
+               HashSet<Long> blacklist, HashSet<Pair<Long, Integer>> visited, 
TemplateType type, boolean foundRowOp) 
+       {
+               //avoid redundant evaluation of processed and non-partition 
nodes
+               Pair<Long, Integer> key = Pair.of(hop.getHopID(),
+                       (foundRowOp?Short.MAX_VALUE:0) + 
((type!=null)?type.ordinal()+1:0));
+               if( visited.contains(key) || 
!part.getPartition().contains(hop.getHopID()) ) {
+                       return;
+               }
+               
+               //process node itself (top-down)
+               MemoTableEntry me = (type == null) ? 
memo.getBest(hop.getHopID()) :
+                       memo.getBest(hop.getHopID(), type);
+               boolean inRow = (me != null && me.type == TemplateType.ROW && 
type == TemplateType.ROW);
+               boolean diffPlans = part.getMatPointsExt().length > 0 //guard 
against plan differences
+                       && memo.contains(hop.getHopID(), TemplateType.ROW)
+                       && !memo.hasOnlyExactMatches(hop.getHopID(), 
TemplateType.ROW, TemplateType.CELL);
+               if( inRow && foundRowOp )
+                       blacklist.add(hop.getHopID());
+               if( isRowAggOp(hop, inRow) || diffPlans ) { 
+                       blacklist.add(hop.getHopID());
+                       foundRowOp = true;
+               }
                
-               MemoTableEntry me = memo.getBest(current.getHopID(), 
TemplateType.ROW);
-               boolean ret = !inclRoot || !isRowAggOp(current);
-               for(int i=0; i<3 && ret; i++)
-                       if( me!=null && me.isPlanRef(i) )
-                               ret &= rIsRowTemplateWithoutAggOrVects(memo, 
-                                       current.getInput().get(i), visited, 
true);
+               //process children recursively
+               for( int i=0; i<hop.getInput().size(); i++ ) {
+                       boolean lfoundRowOp = foundRowOp && me != null 
+                               && (me.isPlanRef(i) || isImplicitlyFused(hop, 
i, me.type));
+                       rCollectDependentRowOps(hop.getInput().get(i), memo,
+                               part, blacklist, visited, 
me!=null?me.type:null, lfoundRowOp);
+               }
                
-               visited.add(current.getHopID());
-               return ret;
+               //process node itself (bottom-up)
+               if( !blacklist.contains(hop.getHopID()) ) {
+                       for( int i=0; i<hop.getInput().size(); i++ )
+                               if( me != null && me.type == TemplateType.ROW
+                                       && (me.isPlanRef(i) || 
isImplicitlyFused(hop, i, me.type))
+                                       && 
blacklist.contains(hop.getInput().get(i).getHopID()) ) {
+                                       blacklist.add(hop.getHopID());
+                               }
+               }
+               
+               visited.add(key);
        }
        
-       private static boolean isRowAggOp(Hop hop){
-               return (hop instanceof AggUnaryOp || hop instanceof AggBinaryOp
-                       || HopRewriteUtils.isBinary(hop, OpOp2.CBIND));
+       private static boolean isRowAggOp(Hop hop, boolean inRow) {
+               return HopRewriteUtils.isBinary(hop, OpOp2.CBIND)
+                       || HopRewriteUtils.isNary(hop, OpOpN.CBIND)
+                       || (hop instanceof AggBinaryOp && (inRow || 
!hop.dimsKnown()
+                               || (hop.getDim1()!=1 && hop.getDim2()!=1)))
+                       || (HopRewriteUtils.isReorg(hop, ReOrgOp.TRANSPOSE) 
+                               && (hop.getDim1()!=1 && hop.getDim2()!=1))
+                       || (hop instanceof AggUnaryOp && inRow);
        }
        
        private static boolean isValidRow2CellOp(Hop hop) {
@@ -704,16 +741,19 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                }
                
                //prune row aggregates with pure cellwise operations
-               HashSet<Long> refAggs = getRowAggOpsWithRowRef(memo, part);
+               //(we determine a blacklist of all operators in a partition 
that either
+               //depend upon row aggregates or on which row aggregates depend)
+               HashSet<Long> blacklist = collectIrreplaceableRowOps(memo, 
part);
                for( Long hopID : part.getPartition() ) {
+                       if( blacklist.contains(hopID) ) continue;
                        MemoTableEntry me = memo.getBest(hopID, 
TemplateType.ROW);
-                       if( me != null && me.type == TemplateType.ROW && 
memo.contains(hopID, me, TemplateType.CELL)
-                               && rIsRowTemplateWithoutAggOrVects(memo, 
memo.getHopRefs().get(hopID), new HashSet<Long>(), refAggs.contains(hopID)) ) {
-                               List<MemoTableEntry> blacklist = 
memo.get(hopID, TemplateType.ROW); 
-                               memo.remove(memo.getHopRefs().get(hopID), new 
HashSet<>(blacklist));
+                       if( me != null && me.type == TemplateType.ROW
+                               && memo.hasOnlyExactMatches(hopID, 
TemplateType.ROW, TemplateType.CELL) ) {
+                               List<MemoTableEntry> rmList = memo.get(hopID, 
TemplateType.ROW); 
+                               memo.remove(memo.getHopRefs().get(hopID), new 
HashSet<>(rmList));
                                if( LOG.isTraceEnabled() ) {
                                        LOG.trace("Removed row memo table 
entries w/o aggregation: "
-                                               + 
Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
+                                               + 
Arrays.toString(rmList.toArray(new MemoTableEntry[0])));
                                }
                        }
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/14ea51be/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index 0c3bb90..5c90ca0 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -107,6 +107,15 @@ public class CPlanMemoTable
                                && p.isValid() && !types.contains(p.type));
        }
        
+       public boolean hasOnlyExactMatches(long hopID, TemplateType type1, 
TemplateType type2) {
+               List<MemoTableEntry> l1 = get(hopID, type1);
+               List<MemoTableEntry> l2 = get(hopID, type2);
+               boolean ret = l1.size() == l2.size();
+               for( MemoTableEntry me : l1 )
+                       ret &= l2.stream().anyMatch(p -> p.equalPlanRefs(me));
+               return ret;
+       }
+       
        public int countEntries(long hopID) {
                return get(hopID).size();
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/14ea51be/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index d54cf63..6c141ed 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -82,6 +82,8 @@ public class TemplateRow extends TemplateBase
        public boolean open(Hop hop) {
                return (hop instanceof BinaryOp && hop.dimsKnown() && 
isValidBinaryOperation(hop)
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
+                       || ((hop instanceof UnaryOp || hop instanceof 
ParameterizedBuiltinOp)
+                               && TemplateCell.isValidOperation(hop) && 
hop.getDim1() > 1)
                        || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
                        || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
                        || (hop instanceof AggBinaryOp && hop.dimsKnown() && 
hop.getDim2()==1 //MV
@@ -95,7 +97,7 @@ public class TemplateRow extends TemplateBase
                                && hop.getParent().get(0) instanceof 
AggBinaryOp && hop.getParent().get(0).dimsKnown()
                                && 
hop.getParent().get(0).getInput().indexOf(hop) == 0
                                && 
isFuseSkinnyMatrixMult(hop.getParent().get(0)))
-                       || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol 
+                       || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1
                                && HopRewriteUtils.isAggUnaryOp(hop, 
SUPPORTED_ROW_AGG))
                        || (hop instanceof IndexingOp && 
hop.getInput().get(0).getDim2() >= 0
@@ -337,7 +339,7 @@ public class TemplateRow extends TemplateBase
                        CNode cdata1 = 
tmp.get(hop.getInput().get(0).getHopID());
                        
                        // if one input is a matrix then we need to do vector 
by scalar operations
-                       if(hop.getInput().get(0).getDim1() > 1 && 
hop.getInput().get(0).getDim2() > 1 
+                       if(hop.getInput().get(0).getDim1() >= 1 && 
hop.getInput().get(0).getDim2() > 1 
                                || (!hop.dimsKnown() && 
cdata1.getDataType()==DataType.MATRIX ) ) 
                        {
                                if( HopRewriteUtils.isUnary(hop, 
SUPPORTED_VECT_UNARY) ) {
@@ -381,8 +383,8 @@ public class TemplateRow extends TemplateBase
                        CNode cdata2 = 
tmp.get(hop.getInput().get(1).getHopID());
                        
                        // if one input is a matrix then we need to do vector 
by scalar operations
-                       if( (hop.getInput().get(0).getDim1() > 1 && 
hop.getInput().get(0).getDim2() > 1)
-                               || (hop.getInput().get(1).getDim1() > 1 && 
hop.getInput().get(1).getDim2() > 1)
+                       if( (hop.getInput().get(0).getDim1() >= 1 && 
hop.getInput().get(0).getDim2() > 1)
+                               || (hop.getInput().get(1).getDim1() >= 1 && 
hop.getInput().get(1).getDim2() > 1)
                                || (!(hop.dimsKnown() && 
hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown())
                                        && (hop.getDim2() != 1) //not a known 
vector output
                                        && (cdata1.getDataType().isMatrix() || 
cdata2.getDataType().isMatrix())))

Reply via email to