[SYSTEMML-1933] Generalized codegen cbind handling (part 2), cleanups. This patch finalizes the codegen cbind generalization. We now fuse cbinds not only w/ constant vectors but also w/ arbitrary vector inputs. This significantly extends the applicability of cbind fusion and also revealed a number of smaller robustness issues that needed fixing (e.g., row type selection, row indexing on main input, switch from row to cell template).
On GLM-probit (100M x 10, 20/10 iterations) this patch improved end-to-end performance (w/ codegen) from 337s to 185s. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/328e8a00 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/328e8a00 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/328e8a00 Branch: refs/heads/master Commit: 328e8a0020c17c072f13d9a1bc9334af968b9c2b Parents: 682fc44 Author: Matthias Boehm <[email protected]> Authored: Tue Sep 26 22:24:54 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Tue Sep 26 23:38:54 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 3 +-- .../sysml/hops/codegen/cplan/CNodeBinary.java | 3 ++- .../sysml/hops/codegen/cplan/CNodeUnary.java | 4 ++- .../sysml/hops/codegen/opt/PlanSelection.java | 12 ++++++--- .../opt/PlanSelectionFuseCostBasedV2.java | 14 ++++++----- .../hops/codegen/template/TemplateRow.java | 26 ++++++++++++++------ .../hops/codegen/template/TemplateUtils.java | 3 ++- 7 files changed, 43 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 1db2910..a4a68bb 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -718,8 +718,7 @@ public class SpoofCompiler //remove invalid row templates (e.g., unsatisfied blocksize constraint) if( tpl instanceof CNodeRow ) { //check for invalid row cplan over column vector - if(in1.getNumCols() == 1 || 
(((CNodeRow)tpl).getRowType()==RowType.NO_AGG - && tpl.getOutput().getDataType().isScalar()) ) { + if( ((CNodeRow)tpl).getRowType()==RowType.NO_AGG && tpl.getOutput().getDataType().isScalar() ) { cplans2.remove(e.getKey()); if( LOG.isTraceEnabled() ) LOG.trace("Removed invalid row cplan w/o agg on column vector."); http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java index c2b5644..42a36ac 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -270,7 +270,8 @@ public class CNodeBinary extends CNode //generate binary operation (use sparse template, if data input) boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData - && _inputs.get(0).getVarname().startsWith("a") + && (_inputs.get(0).getVarname().startsWith("a") + || _inputs.get(1).getVarname().startsWith("a")) && !_inputs.get(0).isLiteral()); boolean scalarInput = _inputs.get(0).getDataType().isScalar(); boolean scalarVector = (_inputs.get(0).getDataType().isScalar() http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index b3720dd..860d35a 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -87,7 +87,9 @@ public class CNodeUnary extends CNode case EXP: return " double %TMP% = FastMath.exp(%IN1%);\n"; case LOOKUP_R: - return 
" double %TMP% = getValue(%IN1%, rowIndex);\n"; + return sparse ? + " double %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" : + " double %TMP% = getValue(%IN1%, rowIndex);\n"; case LOOKUP_C: return " double %TMP% = getValue(%IN1%, n, 0, colIndex);\n"; case LOOKUP_RC: http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java index d18d156..21f4fd3 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java @@ -47,18 +47,22 @@ public abstract class PlanSelection * @param memo partial fusion plans P * @param roots entry points of HOP DAG G */ - public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots); + public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots); /** - * Determines if the given partial fusion plan is valid. + * Determines if the given partial fusion plan is a valid entry point + * of a fused operator. 
* * @param me memo table entry * @param hop current hop * @return true if entry is valid as top-level plan */ public static boolean isValid(MemoTableEntry me, Hop hop) { - return (me.type != TemplateType.OUTER //ROW, CELL, MAGG - || (me.closed || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop))); + return (me.type == TemplateType.CELL) + || (me.type == TemplateType.MAGG) + || (me.type == TemplateType.ROW && !HopRewriteUtils.isTransposeOperation(hop)) + || (me.type == TemplateType.OUTER + && (me.closed || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop))); } protected void addBestPlan(long hopID, MemoTableEntry me) { http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 7c27dcf..8d1c4c0 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -43,6 +43,7 @@ import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; +import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.OptimizerUtils; @@ -568,18 +569,18 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection } } - private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { + private static boolean isRowTemplateWithoutAggOrVects(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { //consider all aggregations other than root operation MemoTableEntry me = memo.getBest(current.getHopID(), 
TemplateType.ROW); boolean ret = true; for(int i=0; i<3; i++) if( me.isPlanRef(i) ) - ret &= rIsRowTemplateWithoutAgg(memo, + ret &= rIsRowTemplateWithoutAggOrVects(memo, current.getInput().get(i), visited); return ret; } - private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { + private static boolean rIsRowTemplateWithoutAggOrVects(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { if( visited.contains(current.getHopID()) ) return true; @@ -587,8 +588,9 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW); for(int i=0; i<3; i++) if( me!=null && me.isPlanRef(i) ) - ret &= rIsRowTemplateWithoutAgg(memo, current.getInput().get(i), visited); - ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp); + ret &= rIsRowTemplateWithoutAggOrVects(memo, current.getInput().get(i), visited); + ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp + || HopRewriteUtils.isBinary(current, OpOp2.CBIND)); visited.add(current.getHopID()); return ret; @@ -628,7 +630,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection for( Long hopID : part.getPartition() ) { MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW); if( me != null && me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL) - && isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) { + && isRowTemplateWithoutAggOrVects(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) { List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW); memo.remove(memo.getHopRefs().get(hopID), new HashSet<MemoTableEntry>(blacklist)); if( LOG.isTraceEnabled() ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index d9209be..1aaa84f 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -78,7 +78,7 @@ public class TemplateRow extends TemplateBase return (hop instanceof BinaryOp && hop.dimsKnown() && isValidBinaryOperation(hop) && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() - && HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) + && hop.dimsKnown() && TemplateUtils.isColVector(hop.getInput().get(1))) || (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1 //MV && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) || (hop instanceof AggBinaryOp && hop.dimsKnown() && LibMatrixMult.isSkinnyRightHandSide( @@ -101,9 +101,9 @@ public class TemplateRow extends TemplateBase return !isClosed() && ( (hop instanceof BinaryOp && isValidBinaryOperation(hop) ) || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().indexOf(input)==0 - && HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) + && hop.dimsKnown() && TemplateUtils.isColVector(hop.getInput().get(1))) || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) - && TemplateCell.isValidOperation(hop)) + && TemplateCell.isValidOperation(hop)) || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection() == Direction.RowCol @@ -121,7 +121,9 @@ public class TemplateRow extends TemplateBase //merge rowagg tpl with cell tpl if input is a vector return !isClosed() && ((hop instanceof BinaryOp && isValidBinaryOperation(hop) - && hop.getDim1() > 1 
&& input.getDim1()>1) + && hop.getDim1() > 1 && input.getDim1()>1) + || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() + && hop.dimsKnown() && TemplateUtils.isColVector(hop.getInput().get(1))) ||(hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)) && (input.getDim2()==1 || (input==hop.getInput().get(1) @@ -184,6 +186,7 @@ public class TemplateRow extends TemplateBase Hop[] sinHops = inHops.stream() .filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral())) .sorted(new HopInputComparator(inHops2.get("X"),inHops2.get("B1"))).toArray(Hop[]::new); + inHops2.putIfAbsent("X", sinHops[0]); //robustness special cases //construct template node ArrayList<CNode> inputs = new ArrayList<CNode>(); @@ -326,10 +329,19 @@ public class TemplateRow extends TemplateBase { //special case for cbind with zeros CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); - CNode cdata2 = TemplateUtils.createCNodeData( - HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true); + CNode cdata2 = null; + if( HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)) ) { + cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils + .getDataGenOpConstantValue(hop.getInput().get(1)), true); + inHops.remove(hop.getInput().get(1)); //rm 0-matrix + } + else { + cdata2 = tmp.get(hop.getInput().get(1).getHopID()); + cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1)); + } out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND); - inHops.remove(hop.getInput().get(1)); //rm 0-matrix + if( cdata1 instanceof CNodeData ) + inHops2.put("X", hop.getInput().get(0)); } else if(hop instanceof BinaryOp) { http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 9d7baf9..21f44b2 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -235,7 +235,8 @@ public class TemplateUtils } public static boolean isLookup(CNode node, boolean includeRC1) { - return isUnary(node, UnaryType.LOOKUP_R, UnaryType.LOOKUP_C, UnaryType.LOOKUP_RC) + return isUnary(node, UnaryType.LOOKUP_C, UnaryType.LOOKUP_RC) + || (includeRC1 && isUnary(node, UnaryType.LOOKUP_R)) || (includeRC1 && isTernary(node, TernaryType.LOOKUP_RC1)); }
