Repository: systemml Updated Branches: refs/heads/master 46f4d9207 -> a5a4d4d33
[SYSTEMML-2328] Improved codegen row template tmp vector allocation The row template uses a preallocated ring buffer of temporary vectors per thread to avoid excessive GC. So far, the size of this ring buffer was determined in a best-effort manner by handling unary pipelines and otherwise counting the number of vector intermediates per operation. However, for large fused operators this is wasteful as most temporary vectors are only used for a very localized scope. This patch introduces a primitive to determine the exact maximum number of live vector intermediates needed for a fused operator. Overall, this reduces GC overhead and improves cache locality. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a5a4d4d3 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a5a4d4d3 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a5a4d4d3 Branch: refs/heads/master Commit: a5a4d4d335a8a2440d236e70b5b92cd8ca8f28e0 Parents: 46f4d92 Author: Matthias Boehm <mboe...@gmail.com> Authored: Thu May 17 16:30:11 2018 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Thu May 17 17:22:01 2018 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 22 ++++--- .../hops/codegen/template/TemplateRow.java | 3 +- .../hops/codegen/template/TemplateUtils.java | 66 ++++++++++++++++++-- 3 files changed, 77 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/a5a4d4d3/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 6ef8dc4..368dc94 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -102,14 +102,15 @@ public class SpoofCompiler private static final Log LOG = LogFactory.getLog(SpoofCompiler.class.getName()); //internal configuration flags - public static boolean LDEBUG = false; - public static CompilerType JAVA_COMPILER = CompilerType.JANINO; - public static PlanSelector PLAN_SEL_POLICY = PlanSelector.FUSE_COST_BASED_V2; - public static IntegrationType INTEGRATION = IntegrationType.RUNTIME; - public static final boolean RECOMPILE_CODEGEN = true; - public static final boolean PRUNE_REDUNDANT_PLANS = true; - public static PlanCachePolicy PLAN_CACHE_POLICY = PlanCachePolicy.CSLH; - public static final int PLAN_CACHE_SIZE = 1024; //max 1K classes + public static boolean LDEBUG = false; + public static CompilerType JAVA_COMPILER = CompilerType.JANINO; + public static PlanSelector PLAN_SEL_POLICY = PlanSelector.FUSE_COST_BASED_V2; + public static IntegrationType INTEGRATION = IntegrationType.RUNTIME; + public static final boolean RECOMPILE_CODEGEN = true; + public static final boolean PRUNE_REDUNDANT_PLANS = true; + public static PlanCachePolicy PLAN_CACHE_POLICY = PlanCachePolicy.CSLH; + public static final int PLAN_CACHE_SIZE = 1024; //max 1K classes + public static final RegisterAlloc REG_ALLOC_POLICY = RegisterAlloc.EXACT; public enum CompilerType { AUTO, @@ -148,6 +149,11 @@ public class SpoofCompiler } } + public enum RegisterAlloc { + HEURISTIC, + EXACT, + } + static { // for internal debugging only if( LDEBUG ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/a5a4d4d3/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index fb153de..95be74a 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -235,7 +235,8 @@ public class TemplateRow extends TemplateBase if( tpl.getRowType().isConstDim2(n2) ) tpl.setConstDim2(n2); tpl.setNumVectorIntermediates(TemplateUtils - .determineMinVectorIntermediates(output)); + .determineMinVectorIntermediates(output, + inputs.isEmpty() ? null : inputs.get(0))); tpl.getOutput().resetVisitStatus(); tpl.rReorderCommutativeBinaryOps(tpl.getOutput(), sinHops[0].getHopID()); tpl.setBeginLine(hop.getBeginLine()); http://git-wip-us.apache.org/repos/asf/systemml/blob/a5a4d4d3/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 4a61678..3a0b1ed 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -21,8 +21,12 @@ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; import org.apache.commons.lang.ArrayUtils; +import org.apache.commons.lang3.tuple.Pair; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; @@ -34,6 +38,7 @@ import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.UnaryOp; +import org.apache.sysml.hops.codegen.SpoofCompiler; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeBinary; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; @@ -353,12 +358,23 @@ public class TemplateUtils return ret; } - public static int 
determineMinVectorIntermediates(CNode node) { + public static int determineMinVectorIntermediates(CNode node, CNode main) { node.resetVisitStatus(); - boolean unaryPipe = isUnaryOperatorPipeline(node); - node.resetVisitStatus(); - int count = unaryPipe ? getMaxVectorIntermediates(node) : - countVectorIntermediates(node); + int count = -1; + switch( SpoofCompiler.REG_ALLOC_POLICY ) { + case HEURISTIC: + boolean unaryPipe = isUnaryOperatorPipeline(node); + node.resetVisitStatus(); + count = unaryPipe ? getMaxVectorIntermediates(node) : + countVectorIntermediates(node); + break; + case EXACT: + Map<Long, Set<Long>> parents = getAllParents(node); + node.resetVisitStatus(); + count = getMaxLiveVectorIntermediates( + node, main, parents, new HashSet<>()); + break; + } node.resetVisitStatus(); return count; } @@ -411,6 +427,46 @@ public class TemplateUtils && ((CNodeTernary)node).getType().isVectorPrimitive()) ? 1 : 0; return ret + cntBin + cntUn + cntTn; } + + + + public static int getMaxLiveVectorIntermediates(CNode node, CNode main, Map<Long, Set<Long>> parents, Set<Pair<Long, Long>> stack) { + if( node.isVisited() ) + return -1; + //recursively process inputs + int max = -1; + for( CNode c : node.getInput() ) + max = Math.max(max, getMaxLiveVectorIntermediates(c, main, parents, stack)); + // add current node consumers + if( !node.getDataType().isScalar() && parents.containsKey(node.getID()) + && node != main ) { + for( Long pID : parents.get(node.getID()) ) + stack.add(Pair.of(pID, node.getID())); + } + //get current maximum (distinct dep targets) + max = Math.max(max, (int)stack.stream() + .map(p -> p.getValue()).distinct().count()); + //remove input dependencies + for( CNode c : node.getInput() ) + stack.remove(Pair.of(node.getID(), c.getID())); + node.setVisited(); + return max; + } + + public static Map<Long, Set<Long>> getAllParents(CNode node) { + Map<Long, Set<Long>> ret = new HashMap<>(); + getAllParents(node, ret); + return ret; + } + + public static void 
getAllParents(CNode node, Map<Long, Set<Long>> parents) { + for( CNode c : node.getInput() ) { + if( !parents.containsKey(c.getID()) ) //map is keyed by node ID, not by CNode + parents.put(c.getID(), new HashSet<>()); + parents.get(c.getID()).add(node.getID()); + getAllParents(c, parents); + } + } public static boolean isType(TemplateType type, TemplateType... validTypes) { return ArrayUtils.contains(validTypes, type);