[SYSTEMML-1717] Reduced allocation of codegen row-wise vector intermediates

This patch improves the decision on the number of required thread-local vector intermediates for codegen row-wise operations (similar in spirit to register allocation in PL compilers). So far, we used the number of vector operations. Instead, we now identify unary vector pipelines, which allow the reuse of vector intermediates in the ring buffer of allocated vectors. For long unary pipelines or larger vectors, this significantly reduces L1 data cache misses.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/70f8f146 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/70f8f146 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/70f8f146 Branch: refs/heads/master Commit: 70f8f146bdcfb30b5c4f56cc03728468cc2770fc Parents: 7cc70b6 Author: Matthias Boehm <[email protected]> Authored: Thu Jun 22 13:16:40 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jun 24 13:50:35 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/cplan/CNodeBinary.java | 4 ++ .../hops/codegen/template/TemplateRow.java | 2 +- .../hops/codegen/template/TemplateUtils.java | 50 +++++++++++++++++--- 3 files changed, 48 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/70f8f146/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java index 6e72ae1..3771de5 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -183,6 +183,10 @@ public class CNodeBinary extends CNode throw new RuntimeException("Invalid binary type: "+this.toString()); } } + public boolean isVectorPrimitive() { + return isVectorScalarPrimitive() + || isVectorVectorPrimitive(); + } public boolean isVectorScalarPrimitive() { return this == VECT_DIV_SCALAR || this == VECT_MULT_SCALAR || this == VECT_MINUS_SCALAR || this == VECT_PLUS_SCALAR http://git-wip-us.apache.org/repos/asf/systemml/blob/70f8f146/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index 0a1a651..b3b4b8f 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -145,7 +145,7 @@ public class TemplateRow extends TemplateBase CNodeRow tpl = new CNodeRow(inputs, output); tpl.setRowType(TemplateUtils.getRowType(hop, sinHops[0])); tpl.setNumVectorIntermediates(TemplateUtils - .countVectorIntermediates(output, new HashSet<Long>())); + .determineMinVectorIntermediates(output)); tpl.getOutput().resetVisitStatus(); tpl.rReorderCommutativeBinaryOps(tpl.getOutput(), sinHops[0].getHopID()); tpl.setBeginLine(hop.getBeginLine()); http://git-wip-us.apache.org/repos/asf/systemml/blob/70f8f146/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 6111e9d..fca203d 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -350,19 +350,55 @@ public class TemplateUtils return ret; } - public static int countVectorIntermediates(CNode node, HashSet<Long> memo) { - //memoization to prevent double counting - if( memo.contains(node.getID()) ) + public static int determineMinVectorIntermediates(CNode node) { + node.resetVisitStatus(); + boolean unaryPipe = isUnaryOperatorPipeline(node); + node.resetVisitStatus(); + int count = unaryPipe ? 
getMaxVectorIntermediates(node) : + countVectorIntermediates(node); + node.resetVisitStatus(); + return count; + } + + public static boolean isUnaryOperatorPipeline(CNode node) { + if( node.isVisited() ) { + //second reference to vector intermediate invalidates a unary pipeline + return !((node instanceof CNodeBinary && ((CNodeBinary)node).getType().isVectorPrimitive()) + || (node instanceof CNodeUnary && ((CNodeUnary)node).getType().isVectorScalarPrimitive())); + } + boolean ret = true; + for( CNode input : node.getInput() ) + ret &= isUnaryOperatorPipeline(input); + node.setVisited(); + return ret; + } + + public static int getMaxVectorIntermediates(CNode node) { + if( node.isVisited() ) + return 0; + int max = 0; + for( CNode input : node.getInput() ) + max = Math.max(max, getMaxVectorIntermediates(input)); + max = Math.max(max, (node instanceof CNodeBinary)? + ((CNodeBinary)node).getType().isVectorVectorPrimitive() ? 3 : + ((CNodeBinary)node).getType().isVectorScalarPrimitive() ? 2 : 0 : 0); + max = Math.max(max, (node instanceof CNodeUnary + && ((CNodeUnary)node).getType().isVectorScalarPrimitive()) ? 2 : 0); + node.setVisited(); + return max; + } + + public static int countVectorIntermediates(CNode node) { + if( node.isVisited() ) return 0; - memo.add(node.getID()); + node.setVisited(); //compute vector requirements over all inputs int ret = 0; for( CNode c : node.getInput() ) - ret += countVectorIntermediates(c, memo); + ret += countVectorIntermediates(c); //compute vector requirements of current node int cntBin = (node instanceof CNodeBinary - && (((CNodeBinary)node).getType().isVectorScalarPrimitive() - || ((CNodeBinary)node).getType().isVectorVectorPrimitive())) ? 1 : 0; + && ((CNodeBinary)node).getType().isVectorPrimitive()) ? 1 : 0; int cntUn = (node instanceof CNodeUnary && ((CNodeUnary)node).getType().isVectorScalarPrimitive()) ? 1 : 0; return ret + cntBin + cntUn;
