[SYSTEMML-1714] Extended codegen row template (multiple matrix inputs)

Given the recent generalization of vector primitives for scalar-vector
and sparse-unsafe operations, this patch now enables codegen row-wise
operations over multiple input matrices, which helps reduce the
number of intermediates caused by template switches between row-wise and
cell-wise templates.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e42133fe
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e42133fe
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e42133fe

Branch: refs/heads/master
Commit: e42133fecacc4c5b7e4192533e93a647abbb58b1
Parents: c17b8a8
Author: Matthias Boehm <[email protected]>
Authored: Sat Jun 24 13:41:17 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jun 24 13:50:37 2017 -0700

----------------------------------------------------------------------
 .../template/PlanSelectionFuseCostBased.java    | 13 +++++++++++-
 .../hops/codegen/template/TemplateBase.java     |  2 +-
 .../hops/codegen/template/TemplateRow.java      | 21 +++++++++++---------
 3 files changed, 25 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/e42133fe/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
 
b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
index 742f4d6..5cc18ea 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
@@ -441,7 +441,7 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                for( Long hopID : R ) {
                        MemoTableEntry me = memo.getBest(hopID, 
TemplateType.RowTpl);
                        if( me.type == TemplateType.RowTpl && 
memo.contains(hopID, TemplateType.CellTpl)
-                               && rIsRowTemplateWithoutAgg(memo, 
memo._hopRefs.get(hopID), new HashSet<Long>())) {
+                               && isRowTemplateWithoutAgg(memo, 
memo._hopRefs.get(hopID), new HashSet<Long>())) {
                                List<MemoTableEntry> blacklist = 
memo.get(hopID, TemplateType.RowTpl); 
                                memo.remove(memo._hopRefs.get(hopID), new 
HashSet<MemoTableEntry>(blacklist));
                                if( LOG.isTraceEnabled() ) {
@@ -523,6 +523,17 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                }
        }
        
+       private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop 
current, HashSet<Long> visited) {
+               //consider all aggregations other than root operation
+               MemoTableEntry me = memo.getBest(current.getHopID(), 
TemplateType.RowTpl);
+               boolean ret = true;
+               for(int i=0; i<3; i++)
+                       if( me.isPlanRef(i) )
+                               ret &= rIsRowTemplateWithoutAgg(memo, 
+                                       current.getInput().get(i), visited);
+               return ret;
+       }
+       
        private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, 
Hop current, HashSet<Long> visited) {
                if( visited.contains(current.getHopID()) )
                        return true;

http://git-wip-us.apache.org/repos/asf/systemml/blob/e42133fe/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java
index f5527f5..f0fe3fa 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java
@@ -28,8 +28,8 @@ public abstract class TemplateBase
        public enum TemplateType {
                //ordering specifies type preferences
                MultiAggTpl,
-               RowTpl,
                OuterProdTpl,
+               RowTpl,
                CellTpl;
                public int getRank() {
                        return this.ordinal();

http://git-wip-us.apache.org/repos/asf/systemml/blob/e42133fe/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index b3b4b8f..601d664 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -71,8 +71,8 @@ public class TemplateRow extends TemplateBase
        
        @Override
        public boolean open(Hop hop) {
-               return (hop instanceof BinaryOp && hop.dimsKnown() && 
hop.getInput().get(0).getDim2()>1 
-                               && hop.getInput().get(1).getDim2()==1 && 
TemplateCell.isValidOperation(hop)) 
+               return (hop instanceof BinaryOp && hop.dimsKnown() && 
isValidBinaryOperation(hop)
+                               && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
                        || (hop instanceof AggBinaryOp && hop.dimsKnown() && 
hop.getDim2()==1
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
                        || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol 
@@ -83,10 +83,7 @@ public class TemplateRow extends TemplateBase
        @Override
        public boolean fuse(Hop hop, Hop input) {
                return !isClosed() && 
-                       (  (hop instanceof BinaryOp && 
TemplateUtils.isOperationSupported(hop) 
-                               && 
(HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
-                                       || 
HopRewriteUtils.isBinaryMatrixScalarOperation(hop)
-                                       || 
HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) ) 
+                       (  (hop instanceof BinaryOp && 
isValidBinaryOperation(hop) ) 
                        || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().indexOf(input)==0
                                && input.getDim2()==1 && 
hop.getInput().get(1).getDim2()==1
                                && 
HopRewriteUtils.isEmpty(hop.getInput().get(1)))
@@ -104,9 +101,7 @@ public class TemplateRow extends TemplateBase
        public boolean merge(Hop hop, Hop input) {
                //merge rowagg tpl with cell tpl if input is a vector
                return !isClosed() &&
-                       ((hop instanceof BinaryOp && 
TemplateUtils.isOperationSupported(hop)
-                               && (input.getDim2()==1 
//matrix-scalar/vector-vector ops )
-                                       || 
HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)))
+                       ((hop instanceof BinaryOp && 
isValidBinaryOperation(hop))
                         ||(hop instanceof AggBinaryOp && input.getDim2()==1
                                && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
        }
@@ -121,6 +116,14 @@ public class TemplateRow extends TemplateBase
                else
                        return CloseType.OPEN;
        }
+       
+       private boolean isValidBinaryOperation(Hop hop) {
+               //exclude unsupported and matrix-rowvector ops
+               return TemplateUtils.isOperationSupported(hop)
+                       && (HopRewriteUtils.isBinaryMatrixScalarOperation(hop)
+                       || HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
+                       || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop));
+       }
 
        @Override
        public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable 
memo, boolean compileLiterals) {

Reply via email to