[SYSTEMML-1714] Extended codegen row template (multiple matrix inputs)
Given the recent generalization of vector primitives for scalar-vector and sparse-unsafe operations, this patch now enables codegen row-wise operations over multiple input matrices, which helps reduce the number of intermediates caused by template switches between row-wise and cell-wise templates.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e42133fe Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e42133fe Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e42133fe Branch: refs/heads/master Commit: e42133fecacc4c5b7e4192533e93a647abbb58b1 Parents: c17b8a8 Author: Matthias Boehm <[email protected]> Authored: Sat Jun 24 13:41:17 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jun 24 13:50:37 2017 -0700 ---------------------------------------------------------------------- .../template/PlanSelectionFuseCostBased.java | 13 +++++++++++- .../hops/codegen/template/TemplateBase.java | 2 +- .../hops/codegen/template/TemplateRow.java | 21 +++++++++++--------- 3 files changed, 25 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/e42133fe/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java index 742f4d6..5cc18ea 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java @@ -441,7 +441,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection for( Long hopID : R ) { MemoTableEntry me = memo.getBest(hopID, TemplateType.RowTpl); if( me.type == TemplateType.RowTpl && memo.contains(hopID, TemplateType.CellTpl) - && rIsRowTemplateWithoutAgg(memo, memo._hopRefs.get(hopID), new HashSet<Long>())) { + && isRowTemplateWithoutAgg(memo, memo._hopRefs.get(hopID), new HashSet<Long>())) { List<MemoTableEntry> blacklist = memo.get(hopID, 
TemplateType.RowTpl); memo.remove(memo._hopRefs.get(hopID), new HashSet<MemoTableEntry>(blacklist)); if( LOG.isTraceEnabled() ) { @@ -523,6 +523,17 @@ public class PlanSelectionFuseCostBased extends PlanSelection } } + private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { + //consider all aggregations other than root operation + MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.RowTpl); + boolean ret = true; + for(int i=0; i<3; i++) + if( me.isPlanRef(i) ) + ret &= rIsRowTemplateWithoutAgg(memo, + current.getInput().get(i), visited); + return ret; + } + private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { if( visited.contains(current.getHopID()) ) return true; http://git-wip-us.apache.org/repos/asf/systemml/blob/e42133fe/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java index f5527f5..f0fe3fa 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java @@ -28,8 +28,8 @@ public abstract class TemplateBase public enum TemplateType { //ordering specifies type preferences MultiAggTpl, - RowTpl, OuterProdTpl, + RowTpl, CellTpl; public int getRank() { return this.ordinal(); http://git-wip-us.apache.org/repos/asf/systemml/blob/e42133fe/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index b3b4b8f..601d664 100644 --- 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -71,8 +71,8 @@ public class TemplateRow extends TemplateBase @Override public boolean open(Hop hop) { - return (hop instanceof BinaryOp && hop.dimsKnown() && hop.getInput().get(0).getDim2()>1 - && hop.getInput().get(1).getDim2()==1 && TemplateCell.isValidOperation(hop)) + return (hop instanceof BinaryOp && hop.dimsKnown() && isValidBinaryOperation(hop) + && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) || (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1 && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol @@ -83,10 +83,7 @@ public class TemplateRow extends TemplateBase @Override public boolean fuse(Hop hop, Hop input) { return !isClosed() && - ( (hop instanceof BinaryOp && TemplateUtils.isOperationSupported(hop) - && (HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) - || HopRewriteUtils.isBinaryMatrixScalarOperation(hop) - || HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) ) + ( (hop instanceof BinaryOp && isValidBinaryOperation(hop) ) || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().indexOf(input)==0 && input.getDim2()==1 && hop.getInput().get(1).getDim2()==1 && HopRewriteUtils.isEmpty(hop.getInput().get(1))) @@ -104,9 +101,7 @@ public class TemplateRow extends TemplateBase public boolean merge(Hop hop, Hop input) { //merge rowagg tpl with cell tpl if input is a vector return !isClosed() && - ((hop instanceof BinaryOp && TemplateUtils.isOperationSupported(hop) - && (input.getDim2()==1 //matrix-scalar/vector-vector ops ) - || HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop))) + ((hop instanceof BinaryOp && isValidBinaryOperation(hop)) ||(hop instanceof AggBinaryOp && input.getDim2()==1 && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))); } @@ -121,6 +116,14 @@ public class TemplateRow extends TemplateBase else return CloseType.OPEN; } + + private boolean isValidBinaryOperation(Hop hop) { + //exclude unsupported and matrix-rowvector ops + return TemplateUtils.isOperationSupported(hop) + && (HopRewriteUtils.isBinaryMatrixScalarOperation(hop) + || HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) + || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)); + } @Override public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
