[SYSTEMML-1717] Reduced allocation of codegen row-wise vector intermediates

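This patch improves how we determine the number of thread-local vector
intermediates required by codegen row-wise operators (similar in spirit
to register allocation in PL compilers). Previously, we allocated one
intermediate per vector operation. Instead, we now identify unary vector
pipelines, i.e., operator DAGs in which no vector intermediate is
referenced more than once. For such pipelines, vector intermediates can be
reused in the ring buffer of allocated vectors, so the ring only needs as
many vectors as the most demanding single operation (2 for vector-scalar,
3 for vector-vector operations). All other DAGs still use the previous
per-operation count. For long unary pipelines or larger vectors, this
significantly reduces L1 data cache misses.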


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/70f8f146
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/70f8f146
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/70f8f146

Branch: refs/heads/master
Commit: 70f8f146bdcfb30b5c4f56cc03728468cc2770fc
Parents: 7cc70b6
Author: Matthias Boehm <[email protected]>
Authored: Thu Jun 22 13:16:40 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jun 24 13:50:35 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeBinary.java   |  4 ++
 .../hops/codegen/template/TemplateRow.java      |  2 +-
 .../hops/codegen/template/TemplateUtils.java    | 50 +++++++++++++++++---
 3 files changed, 48 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/70f8f146/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index 6e72ae1..3771de5 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -183,6 +183,10 @@ public class CNodeBinary extends CNode
                                        throw new RuntimeException("Invalid 
binary type: "+this.toString());
                        }
                }
+               public boolean isVectorPrimitive() {
+                       return isVectorScalarPrimitive() 
+                               || isVectorVectorPrimitive();
+               }
                public boolean isVectorScalarPrimitive() {
                        return this == VECT_DIV_SCALAR || this == 
VECT_MULT_SCALAR 
                                || this == VECT_MINUS_SCALAR || this == 
VECT_PLUS_SCALAR

http://git-wip-us.apache.org/repos/asf/systemml/blob/70f8f146/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index 0a1a651..b3b4b8f 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -145,7 +145,7 @@ public class TemplateRow extends TemplateBase
                CNodeRow tpl = new CNodeRow(inputs, output);
                tpl.setRowType(TemplateUtils.getRowType(hop, sinHops[0]));
                tpl.setNumVectorIntermediates(TemplateUtils
-                       .countVectorIntermediates(output, new HashSet<Long>()));
+                       .determineMinVectorIntermediates(output));
                tpl.getOutput().resetVisitStatus();
                tpl.rReorderCommutativeBinaryOps(tpl.getOutput(), 
sinHops[0].getHopID());
                tpl.setBeginLine(hop.getBeginLine());

http://git-wip-us.apache.org/repos/asf/systemml/blob/70f8f146/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 6111e9d..fca203d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -350,19 +350,55 @@ public class TemplateUtils
                return ret;
        }
        
-       public static int countVectorIntermediates(CNode node, HashSet<Long> 
memo) {
-               //memoization to prevent double counting
-               if( memo.contains(node.getID()) )
+       public static int determineMinVectorIntermediates(CNode node) {
+               node.resetVisitStatus();
+               boolean unaryPipe = isUnaryOperatorPipeline(node);
+               node.resetVisitStatus();
+               int count = unaryPipe ? getMaxVectorIntermediates(node) :
+                       countVectorIntermediates(node);
+               node.resetVisitStatus();
+               return count;
+       }
+       
+       public static boolean isUnaryOperatorPipeline(CNode node) {
+               if( node.isVisited() ) {
+                       //second reference to vector intermediate invalidates a 
unary pipeline
+                       return !((node instanceof CNodeBinary && 
((CNodeBinary)node).getType().isVectorPrimitive())
+                               || (node instanceof CNodeUnary && 
((CNodeUnary)node).getType().isVectorScalarPrimitive()));
+               }
+               boolean ret = true;
+               for( CNode input : node.getInput() )
+                       ret &= isUnaryOperatorPipeline(input);
+               node.setVisited();
+               return ret;
+       }
+       
+       public static int getMaxVectorIntermediates(CNode node) {
+               if( node.isVisited() )
+                       return 0;
+               int max = 0;
+               for( CNode input : node.getInput() )
+                       max = Math.max(max, getMaxVectorIntermediates(input));
+               max = Math.max(max, (node instanceof CNodeBinary)? 
+                       ((CNodeBinary)node).getType().isVectorVectorPrimitive() 
? 3 :
+                       ((CNodeBinary)node).getType().isVectorScalarPrimitive() 
? 2 : 0 : 0);
+               max = Math.max(max, (node instanceof CNodeUnary 
+                       && 
((CNodeUnary)node).getType().isVectorScalarPrimitive()) ? 2 : 0);
+               node.setVisited();
+               return max;
+       }
+       
+       public static int countVectorIntermediates(CNode node) {
+               if( node.isVisited() )
                        return 0;
-               memo.add(node.getID());
+               node.setVisited();
                //compute vector requirements over all inputs
                int ret = 0;
                for( CNode c : node.getInput() )
-                       ret += countVectorIntermediates(c, memo);
+                       ret += countVectorIntermediates(c);
                //compute vector requirements of current node
                int cntBin = (node instanceof CNodeBinary 
-                       && 
(((CNodeBinary)node).getType().isVectorScalarPrimitive() 
-                       || 
((CNodeBinary)node).getType().isVectorVectorPrimitive())) ? 1 : 0;
+                       && ((CNodeBinary)node).getType().isVectorPrimitive()) ? 
1 : 0;
                int cntUn = (node instanceof CNodeUnary
                                && 
((CNodeUnary)node).getType().isVectorScalarPrimitive()) ? 1 : 0;
                return ret + cntBin + cntUn;

Reply via email to