[SYSTEMML-1933] Generalized codegen cbind handling (part 2), cleanups

This patch finalizes the codegen cbind generalization. We now fuse
cbinds not just with constant vectors but with arbitrary vector inputs.
This significantly extends its applicability and also revealed a number
of smaller robustness issues that needed fixing (e.g., row type selection,
row indexing on the main input, and the switch from row to cell template).

On GLM-probit (100M x 10, 20/10 iterations), this patch improved
end-to-end performance (with codegen) from 337s to 185s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/328e8a00
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/328e8a00
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/328e8a00

Branch: refs/heads/master
Commit: 328e8a0020c17c072f13d9a1bc9334af968b9c2b
Parents: 682fc44
Author: Matthias Boehm <[email protected]>
Authored: Tue Sep 26 22:24:54 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Tue Sep 26 23:38:54 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       |  3 +--
 .../sysml/hops/codegen/cplan/CNodeBinary.java   |  3 ++-
 .../sysml/hops/codegen/cplan/CNodeUnary.java    |  4 ++-
 .../sysml/hops/codegen/opt/PlanSelection.java   | 12 ++++++---
 .../opt/PlanSelectionFuseCostBasedV2.java       | 14 ++++++-----
 .../hops/codegen/template/TemplateRow.java      | 26 ++++++++++++++------
 .../hops/codegen/template/TemplateUtils.java    |  3 ++-
 7 files changed, 43 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index 1db2910..a4a68bb 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -718,8 +718,7 @@ public class SpoofCompiler
                        //remove invalid row templates (e.g., unsatisfied 
blocksize constraint)
                        if( tpl instanceof CNodeRow ) {
                                //check for invalid row cplan over column vector
-                               if(in1.getNumCols() == 1 || 
(((CNodeRow)tpl).getRowType()==RowType.NO_AGG
-                                       && 
tpl.getOutput().getDataType().isScalar()) ) {
+                               if( 
((CNodeRow)tpl).getRowType()==RowType.NO_AGG && 
tpl.getOutput().getDataType().isScalar() ) {
                                        cplans2.remove(e.getKey());
                                        if( LOG.isTraceEnabled() )
                                                LOG.trace("Removed invalid row 
cplan w/o agg on column vector.");

http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index c2b5644..42a36ac 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -270,7 +270,8 @@ public class CNodeBinary extends CNode
                
                //generate binary operation (use sparse template, if data input)
                boolean lsparse = sparse && (_inputs.get(0) instanceof 
CNodeData 
-                       && _inputs.get(0).getVarname().startsWith("a")
+                       && (_inputs.get(0).getVarname().startsWith("a")
+                               || _inputs.get(1).getVarname().startsWith("a"))
                        && !_inputs.get(0).isLiteral());
                boolean scalarInput = _inputs.get(0).getDataType().isScalar();
                boolean scalarVector = (_inputs.get(0).getDataType().isScalar()

http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index b3720dd..860d35a 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -87,7 +87,9 @@ public class CNodeUnary extends CNode
                                case EXP:
                                        return "    double %TMP% = 
FastMath.exp(%IN1%);\n";
                            case LOOKUP_R:
-                               return "    double %TMP% = getValue(%IN1%, 
rowIndex);\n";
+                               return sparse ?
+                                       "    double %TMP% = getValue(%IN1v%, 
%IN1i%, ai, alen, 0);\n" :
+                                       "    double %TMP% = getValue(%IN1%, 
rowIndex);\n";
                            case LOOKUP_C:
                                return "    double %TMP% = getValue(%IN1%, n, 
0, colIndex);\n";
                            case LOOKUP_RC:

http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
index d18d156..21f4fd3 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
@@ -47,18 +47,22 @@ public abstract class PlanSelection
         * @param memo partial fusion plans P
         * @param roots entry points of HOP DAG G
         */
-       public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> 
roots);    
+       public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> 
roots);
        
        /**
-        * Determines if the given partial fusion plan is valid.
+        * Determines if the given partial fusion plan is a valid entry point
+        * of a fused operator.
         * 
         * @param me memo table entry
         * @param hop current hop
         * @return true if entry is valid as top-level plan
         */
        public static boolean isValid(MemoTableEntry me, Hop hop) {
-               return (me.type != TemplateType.OUTER //ROW, CELL, MAGG
-                       || (me.closed || 
HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)));
+               return (me.type == TemplateType.CELL)
+                       || (me.type == TemplateType.MAGG)
+                       || (me.type == TemplateType.ROW && 
!HopRewriteUtils.isTransposeOperation(hop))
+                       || (me.type == TemplateType.OUTER 
+                               && (me.closed || 
HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)));
        }
        
        protected void addBestPlan(long hopID, MemoTableEntry me) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 7c27dcf..8d1c4c0 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -43,6 +43,7 @@ import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.OptimizerUtils;
@@ -568,18 +569,18 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                }
        }
        
-       private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop 
current, HashSet<Long> visited) {
+       private static boolean isRowTemplateWithoutAggOrVects(CPlanMemoTable 
memo, Hop current, HashSet<Long> visited) {
                //consider all aggregations other than root operation
                MemoTableEntry me = memo.getBest(current.getHopID(), 
TemplateType.ROW);
                boolean ret = true;
                for(int i=0; i<3; i++)
                        if( me.isPlanRef(i) )
-                               ret &= rIsRowTemplateWithoutAgg(memo, 
+                               ret &= rIsRowTemplateWithoutAggOrVects(memo, 
                                        current.getInput().get(i), visited);
                return ret;
        }
        
-       private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, 
Hop current, HashSet<Long> visited) {
+       private static boolean rIsRowTemplateWithoutAggOrVects(CPlanMemoTable 
memo, Hop current, HashSet<Long> visited) {
                if( visited.contains(current.getHopID()) )
                        return true;
                
@@ -587,8 +588,9 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                MemoTableEntry me = memo.getBest(current.getHopID(), 
TemplateType.ROW);
                for(int i=0; i<3; i++)
                        if( me!=null && me.isPlanRef(i) )
-                               ret &= rIsRowTemplateWithoutAgg(memo, 
current.getInput().get(i), visited);
-               ret &= !(current instanceof AggUnaryOp || current instanceof 
AggBinaryOp);
+                               ret &= rIsRowTemplateWithoutAggOrVects(memo, 
current.getInput().get(i), visited);
+               ret &= !(current instanceof AggUnaryOp || current instanceof 
AggBinaryOp
+                       || HopRewriteUtils.isBinary(current, OpOp2.CBIND));
                
                visited.add(current.getHopID());
                return ret;
@@ -628,7 +630,7 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                for( Long hopID : part.getPartition() ) {
                        MemoTableEntry me = memo.getBest(hopID, 
TemplateType.ROW);
                        if( me != null && me.type == TemplateType.ROW && 
memo.contains(hopID, TemplateType.CELL)
-                               && isRowTemplateWithoutAgg(memo, 
memo.getHopRefs().get(hopID), new HashSet<Long>())) {
+                               && isRowTemplateWithoutAggOrVects(memo, 
memo.getHopRefs().get(hopID), new HashSet<Long>())) {
                                List<MemoTableEntry> blacklist = 
memo.get(hopID, TemplateType.ROW); 
                                memo.remove(memo.getHopRefs().get(hopID), new 
HashSet<MemoTableEntry>(blacklist));
                                if( LOG.isTraceEnabled() ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index d9209be..1aaa84f 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -78,7 +78,7 @@ public class TemplateRow extends TemplateBase
                return (hop instanceof BinaryOp && hop.dimsKnown() && 
isValidBinaryOperation(hop)
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
                        || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix()
-                               && 
HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)))
+                               && hop.dimsKnown() && 
TemplateUtils.isColVector(hop.getInput().get(1)))
                        || (hop instanceof AggBinaryOp && hop.dimsKnown() && 
hop.getDim2()==1 //MV
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
                        || (hop instanceof AggBinaryOp && hop.dimsKnown() && 
LibMatrixMult.isSkinnyRightHandSide(
@@ -101,9 +101,9 @@ public class TemplateRow extends TemplateBase
                return !isClosed() && 
                        (  (hop instanceof BinaryOp && 
isValidBinaryOperation(hop) ) 
                        || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().indexOf(input)==0
-                               && 
HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)))
+                               && hop.dimsKnown() && 
TemplateUtils.isColVector(hop.getInput().get(1)))
                        || ((hop instanceof UnaryOp || hop instanceof 
ParameterizedBuiltinOp) 
-                                       && TemplateCell.isValidOperation(hop))
+                               && TemplateCell.isValidOperation(hop))
                        || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol
                                && HopRewriteUtils.isAggUnaryOp(hop, 
SUPPORTED_ROW_AGG))
                        || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection() == Direction.RowCol 
@@ -121,7 +121,9 @@ public class TemplateRow extends TemplateBase
                //merge rowagg tpl with cell tpl if input is a vector
                return !isClosed() &&
                        ((hop instanceof BinaryOp && isValidBinaryOperation(hop)
-                               && hop.getDim1() > 1 && input.getDim1()>1) 
+                               && hop.getDim1() > 1 && input.getDim1()>1)
+                       || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix()
+                               && hop.dimsKnown() && 
TemplateUtils.isColVector(hop.getInput().get(1)))
                         ||(hop instanceof AggBinaryOp
                                && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))
                                && (input.getDim2()==1 || 
(input==hop.getInput().get(1) 
@@ -184,6 +186,7 @@ public class TemplateRow extends TemplateBase
                Hop[] sinHops = inHops.stream()
                        .filter(h -> !(h.getDataType().isScalar() && 
tmp.get(h.getHopID()).isLiteral()))
                        .sorted(new 
HopInputComparator(inHops2.get("X"),inHops2.get("B1"))).toArray(Hop[]::new);
+               inHops2.putIfAbsent("X", sinHops[0]); //robustness special cases
                
                //construct template node
                ArrayList<CNode> inputs = new ArrayList<CNode>();
@@ -326,10 +329,19 @@ public class TemplateRow extends TemplateBase
                {
                        //special case for cbind with zeros
                        CNode cdata1 = 
tmp.get(hop.getInput().get(0).getHopID());
-                       CNode cdata2 = TemplateUtils.createCNodeData(
-                               
HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
+                       CNode cdata2 = null;
+                       if( 
HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)) ) {
+                               cdata2 = 
TemplateUtils.createCNodeData(HopRewriteUtils
+                                       
.getDataGenOpConstantValue(hop.getInput().get(1)), true);
+                               inHops.remove(hop.getInput().get(1)); //rm 
0-matrix
+                       }
+                       else {
+                               cdata2 = 
tmp.get(hop.getInput().get(1).getHopID());
+                               cdata2 = 
TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
+                       }
                        out = new CNodeBinary(cdata1, cdata2, 
BinType.VECT_CBIND);
-                       inHops.remove(hop.getInput().get(1)); //rm 0-matrix
+                       if( cdata1 instanceof CNodeData )
+                               inHops2.put("X", hop.getInput().get(0));
                }
                else if(hop instanceof BinaryOp)
                {

http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 9d7baf9..21f44b2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -235,7 +235,8 @@ public class TemplateUtils
        }
        
        public static boolean isLookup(CNode node, boolean includeRC1) {
-               return isUnary(node, UnaryType.LOOKUP_R, UnaryType.LOOKUP_C, 
UnaryType.LOOKUP_RC)
+               return isUnary(node, UnaryType.LOOKUP_C, UnaryType.LOOKUP_RC)
+                       || (includeRC1 && isUnary(node, UnaryType.LOOKUP_R))
                        || (includeRC1 && isTernary(node, 
TernaryType.LOOKUP_RC1));
        }
        

Reply via email to