Repository: systemml
Updated Branches:
  refs/heads/master 46f4d9207 -> a5a4d4d33


[SYSTEMML-2328] Improved codegen row template tmp vector allocation

The row template uses a preallocated ring buffer of temporary vectors
per thread to avoid excessive GC. So far, the size of this ring buffer
was determined in a best-effort manner by handling unary pipelines and
otherwise counting the number of vector intermediates per operator.
However, for large fused operators this is wasteful because most temporary
vectors are only used within a very localized scope. This patch introduces
a primitive to determine the exact maximum number of live vector
intermediates needed for a fused operator. Overall, this reduces GC
overheads and improves cache locality.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a5a4d4d3
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a5a4d4d3
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a5a4d4d3

Branch: refs/heads/master
Commit: a5a4d4d335a8a2440d236e70b5b92cd8ca8f28e0
Parents: 46f4d92
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu May 17 16:30:11 2018 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Thu May 17 17:22:01 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       | 22 ++++---
 .../hops/codegen/template/TemplateRow.java      |  3 +-
 .../hops/codegen/template/TemplateUtils.java    | 66 ++++++++++++++++++--
 3 files changed, 77 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a5a4d4d3/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index 6ef8dc4..368dc94 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -102,14 +102,15 @@ public class SpoofCompiler
        private static final Log LOG = 
LogFactory.getLog(SpoofCompiler.class.getName());
        
        //internal configuration flags
-       public static boolean LDEBUG                      = false;
-       public static CompilerType JAVA_COMPILER          = 
CompilerType.JANINO; 
-       public static PlanSelector PLAN_SEL_POLICY        = 
PlanSelector.FUSE_COST_BASED_V2; 
-       public static IntegrationType INTEGRATION         = 
IntegrationType.RUNTIME;
-       public static final boolean RECOMPILE_CODEGEN     = true;
-       public static final boolean PRUNE_REDUNDANT_PLANS = true;
-       public static PlanCachePolicy PLAN_CACHE_POLICY   = 
PlanCachePolicy.CSLH;
-       public static final int PLAN_CACHE_SIZE           = 1024; //max 1K 
classes 
+       public static boolean LDEBUG                       = false;
+       public static CompilerType JAVA_COMPILER           = 
CompilerType.JANINO; 
+       public static PlanSelector PLAN_SEL_POLICY         = 
PlanSelector.FUSE_COST_BASED_V2; 
+       public static IntegrationType INTEGRATION          = 
IntegrationType.RUNTIME;
+       public static final boolean RECOMPILE_CODEGEN      = true;
+       public static final boolean PRUNE_REDUNDANT_PLANS  = true;
+       public static PlanCachePolicy PLAN_CACHE_POLICY    = 
PlanCachePolicy.CSLH;
+       public static final int PLAN_CACHE_SIZE            = 1024; //max 1K 
classes
+       public static final RegisterAlloc REG_ALLOC_POLICY = 
RegisterAlloc.EXACT;
        
        public enum CompilerType {
                AUTO,
@@ -148,6 +149,11 @@ public class SpoofCompiler
                }
        }
        
+       public enum RegisterAlloc {
+               HEURISTIC,
+               EXACT,
+       }
+       
        static {
                // for internal debugging only
                if( LDEBUG ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/a5a4d4d3/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index fb153de..95be74a 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -235,7 +235,8 @@ public class TemplateRow extends TemplateBase
                if( tpl.getRowType().isConstDim2(n2) )
                        tpl.setConstDim2(n2);
                tpl.setNumVectorIntermediates(TemplateUtils
-                       .determineMinVectorIntermediates(output));
+                       .determineMinVectorIntermediates(output,
+                       inputs.isEmpty() ? null : inputs.get(0)));
                tpl.getOutput().resetVisitStatus();
                tpl.rReorderCommutativeBinaryOps(tpl.getOutput(), 
sinHops[0].getHopID());
                tpl.setBeginLine(hop.getBeginLine());

http://git-wip-us.apache.org/repos/asf/systemml/blob/a5a4d4d3/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 4a61678..3a0b1ed 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -21,8 +21,12 @@ package org.apache.sysml.hops.codegen.template;
 
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
 
 import org.apache.commons.lang.ArrayUtils;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
@@ -34,6 +38,7 @@ import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.codegen.SpoofCompiler;
 import org.apache.sysml.hops.codegen.cplan.CNode;
 import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
@@ -353,12 +358,23 @@ public class TemplateUtils
                return ret;
        }
        
-       public static int determineMinVectorIntermediates(CNode node) {
+       public static int determineMinVectorIntermediates(CNode node, CNode 
main) {
                node.resetVisitStatus();
-               boolean unaryPipe = isUnaryOperatorPipeline(node);
-               node.resetVisitStatus();
-               int count = unaryPipe ? getMaxVectorIntermediates(node) :
-                       countVectorIntermediates(node);
+               int count = -1;
+               switch( SpoofCompiler.REG_ALLOC_POLICY ) {
+                       case HEURISTIC:
+                               boolean unaryPipe = 
isUnaryOperatorPipeline(node);
+                               node.resetVisitStatus();
+                               count = unaryPipe ? 
getMaxVectorIntermediates(node) :
+                                       countVectorIntermediates(node);
+                               break;
+                       case EXACT:
+                               Map<Long, Set<Long>> parents = 
getAllParents(node);
+                               node.resetVisitStatus();
+                               count = getMaxLiveVectorIntermediates(
+                                       node, main, parents, new HashSet<>());
+                               break;
+               }
                node.resetVisitStatus();
                return count;
        }
@@ -411,6 +427,46 @@ public class TemplateUtils
                                && 
((CNodeTernary)node).getType().isVectorPrimitive()) ? 1 : 0;
                return ret + cntBin + cntUn + cntTn;
        }
+       //Returns the max number of simultaneously live non-scalar intermediates over a
+       //post-order walk of the DAG rooted at node (main excluded; stack holds pending
+       //(parentID, childID) edges), or -1 if node was already visited.
+       public static int getMaxLiveVectorIntermediates(CNode node, CNode main, 
Map<Long, Set<Long>> parents, Set<Pair<Long, Long>> stack) {
+               if( node.isVisited() )
+                       return -1; //shared input already processed via another parent
+               //post-order: process all inputs before this node
+               int max = -1;
+               for( CNode c : node.getInput() )
+                       max = Math.max(max, getMaxLiveVectorIntermediates(c, 
main, parents, stack));
+               //non-scalar output (except main) stays live until every parent consumed it
+               if( !node.getDataType().isScalar() && 
parents.containsKey(node.getID())
+                       && node != main ) {
+                       for( Long pID : parents.get(node.getID()) )
+                               stack.add(Pair.of(pID, node.getID()));
+               }
+               //live intermediates = distinct child IDs (pair values) with pending edges
+               max = Math.max(max, (int)stack.stream()
+                       .map(p -> p.getValue()).distinct().count());
+               //this node consumed its inputs: release the (node, input) edges
+               for( CNode c : node.getInput() )
+                       stack.remove(Pair.of(node.getID(), c.getID()));
+               node.setVisited();
+               return max;
+       }
+       
+       public static Map<Long, Set<Long>> getAllParents(CNode node) {
+               final Map<Long, Set<Long>> parentsById = new HashMap<>();
+               getAllParents(node, parentsById); //fill child ID -> parent IDs
+               return parentsById;
+       }
+       
+       public static void getAllParents(CNode node, Map<Long, Set<Long>> parents) {
+               for( CNode c : node.getInput() ) {
+                       boolean first = !parents.containsKey(c.getID()); //key by ID, not CNode
+                       parents.computeIfAbsent(c.getID(), k -> new HashSet<>()).add(node.getID());
+                       if( first ) //shared inputs: subtree already collected on first visit
+                               getAllParents(c, parents);
+               }
+       }
 
        public static boolean isType(TemplateType type, TemplateType... 
validTypes) {
                return ArrayUtils.contains(validTypes, type);

Reply via email to