Repository: systemml Updated Branches: refs/heads/master 31610e36d -> 51154f17b
[SYSTEMML-2480] Fix register allocation codegen row template This patch fixes potential correctness issues in codegen row templates caused by overly aggressive reuse of vectors. Specifically, so far we determined the number of required live tmp vectors, which is only valid for dynamically managed vector pools (w/ alloc/free). However, for efficiency reasons we use a static array buffer, where reuse depends not only on the number of live objects but also on their sequence. This issue did not show up before because, in the case of row templates with multiple different vector sizes, we conservatively allocate the number of types times the number of registers. However, in special cases such as the attached test, it led to incorrect results. We now use a proof-based algorithm for determining the number of registers. Starting from a heuristic allocation (a guaranteed upper bound), we systematically decrease the count and verify that no invalid reuse occurs. In the attached test, the heuristic allocated 5 vectors (per thread), the incorrect count based on live objects is 2, and our new algorithm finds the true valid minimum of 3. 
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ef50f8a5 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ef50f8a5 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ef50f8a5 Branch: refs/heads/master Commit: ef50f8a5f99a63524e9ce86c6ef6a6a70e988634 Parents: 31610e3 Author: Matthias Boehm <[email protected]> Authored: Wed Aug 1 22:21:11 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Aug 1 22:21:11 2018 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 7 ++- .../hops/codegen/template/TemplateUtils.java | 63 ++++++++++++++++++-- .../controlprogram/parfor/util/IDSequence.java | 24 ++++---- .../functions/codegen/RowAggTmplTest.java | 20 ++++++- .../scripts/functions/codegen/rowAggPattern45.R | 35 +++++++++++ .../functions/codegen/rowAggPattern45.dml | 33 ++++++++++ 6 files changed, 163 insertions(+), 19 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index fd012ec..9fc9ac4 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -110,7 +110,7 @@ public class SpoofCompiler public static final boolean PRUNE_REDUNDANT_PLANS = true; public static PlanCachePolicy PLAN_CACHE_POLICY = PlanCachePolicy.CSLH; public static final int PLAN_CACHE_SIZE = 1024; //max 1K classes - public static final RegisterAlloc REG_ALLOC_POLICY = RegisterAlloc.EXACT; + public static final RegisterAlloc REG_ALLOC_POLICY = RegisterAlloc.EXACT_STATIC_BUFF; 
public enum CompilerType { AUTO, @@ -150,8 +150,9 @@ public class SpoofCompiler } public enum RegisterAlloc { - HEURISTIC, - EXACT, + HEURISTIC, //max vector intermediates, special handling pipelines (always safe) + EXACT_DYNAMIC_BUFF, //min number of live vector intermediates, assuming dynamic pooling + EXACT_STATIC_BUFF, //min number of live vector intermediates, assuming static array ring buffer } static { http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 438eb56..f2ad7e6 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Set; import org.apache.commons.lang.ArrayUtils; +import org.apache.commons.lang3.mutable.MutableInt; import org.apache.commons.lang3.tuple.Pair; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; @@ -59,6 +60,7 @@ import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType; import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType; +import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysml.runtime.util.UtilFunctions; public class TemplateUtils @@ -371,18 +373,34 @@ public class TemplateUtils node.resetVisitStatus(); int count = -1; switch( SpoofCompiler.REG_ALLOC_POLICY ) { - case HEURISTIC: + case HEURISTIC: { boolean unaryPipe = isUnaryOperatorPipeline(node); node.resetVisitStatus(); count = unaryPipe ? 
getMaxVectorIntermediates(node) : countVectorIntermediates(node); break; - case EXACT: + } + case EXACT_DYNAMIC_BUFF: { Map<Long, Set<Long>> parents = getAllParents(node); node.resetVisitStatus(); count = getMaxLiveVectorIntermediates( node, main, parents, new HashSet<>()); break; + } + case EXACT_STATIC_BUFF: { + //init with basic heuristic + boolean unaryPipe = isUnaryOperatorPipeline(node); + node.resetVisitStatus(); + count = unaryPipe ? getMaxVectorIntermediates(node) : + countVectorIntermediates(node); + //reduce count and proof validity + Map<Long, Set<Long>> parents = getAllParents(node); + Map<Long, Pair<Long, MutableInt>> inUse = new HashMap<>(); //node ID, vector ID, num Refs + Set<Long> inUse2 = new HashSet<>(); //for fast probes + while( count > 0 && isValidNumVectorIntermediates(node, main, parents, inUse, inUse2, count-1) ) + count--; + break; + } } node.resetVisitStatus(); return count; @@ -437,8 +455,6 @@ public class TemplateUtils return ret + cntBin + cntUn + cntTn; } - - public static int getMaxLiveVectorIntermediates(CNode node, CNode main, Map<Long, Set<Long>> parents, Set<Pair<Long, Long>> stack) { if( node.isVisited() ) return -1; @@ -462,6 +478,45 @@ public class TemplateUtils return max; } + public static boolean isValidNumVectorIntermediates(CNode node, CNode main, Map<Long, Set<Long>> parents, Map<Long, Pair<Long, MutableInt>> inUse, Set<Long> inUse2, int count) { + IDSequence buff = new IDSequence(true, count-1); //zero based + inUse.clear(); inUse2.clear(); + node.resetVisitStatus(); + return rIsValidNumVectorIntermediates(node, main, parents, inUse, inUse2, buff); + } + + public static boolean rIsValidNumVectorIntermediates(CNode node, CNode main, Map<Long, Set<Long>> parents, + Map<Long, Pair<Long, MutableInt>> inUse, Set<Long> inUse2, IDSequence buff) { + if( node.isVisited() ) + return true; + //recursively process inputs + for( CNode c : node.getInput() ) + if( !rIsValidNumVectorIntermediates(c, main, parents, inUse, inUse2, buff) 
) + return false; + // add current node consumers for vectors + if( !node.getDataType().isScalar() && parents.containsKey(node.getID()) && node != main ) { + long vectID = buff.getNextID(); + if( inUse2.contains(vectID) ) + return false; //CONFLICT detected + inUse.put(node.getID(), Pair.of(vectID, + new MutableInt(parents.get(node.getID()).size()))); + inUse2.add(vectID); + } + //remove input dependencies + for( CNode c : node.getInput() ) { + Pair<Long, MutableInt> tmp = inUse.get(c.getID()); + if( tmp != null ) { + tmp.getValue().decrement(); + if( tmp.getValue().intValue() <= 0 ) { + inUse.remove(c.getID()); + inUse2.remove(tmp.getKey()); + } + } + } + node.setVisited(); + return true; + } + public static Map<Long, Set<Long>> getAllParents(CNode node) { Map<Long, Set<Long>> ret = new HashMap<>(); getAllParents(node, ret); http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/IDSequence.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/IDSequence.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/IDSequence.java index cdbbc3f..e5bc741 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/IDSequence.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/IDSequence.java @@ -28,32 +28,36 @@ import java.util.concurrent.atomic.AtomicLong; public class IDSequence { private final AtomicLong _current; - private final boolean _wrapAround; + private final boolean _cyclic; + private final long _cycleLen; public IDSequence() { - this(false); + this(false, -1); } - public IDSequence(boolean wrapAround) { + public IDSequence(boolean cyclic) { + this(cyclic, Long.MAX_VALUE); + } + + public IDSequence(boolean cyclic, long cycleLen) { _current = new AtomicLong(-1); - _wrapAround = wrapAround; + _cyclic = cyclic; + _cycleLen = cycleLen; 
} + /** * Creates the next ID, if overflow a RuntimeException is thrown. * * @return ID */ - public long getNextID() - { + public long getNextID() { long val = _current.incrementAndGet(); - - if( val == Long.MAX_VALUE ) { - if( !_wrapAround ) + if( val == _cycleLen ) { + if( !_cyclic ) throw new RuntimeException("WARNING: IDSequence will produced numeric overflow."); reset(); } - return val; } http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java index 48555ae..d0a417d 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java @@ -81,6 +81,7 @@ public class RowAggTmplTest extends AutomatedTestBase private static final String TEST_NAME42 = TEST_NAME+"42"; //X/rowSums(min(X, Y, Z)) private static final String TEST_NAME43 = TEST_NAME+"43"; //bias_add(X,B) + bias_mult(X,B) private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - mean(X)) + 7; + private static final String TEST_NAME45 = TEST_NAME+"45"; //vector allocation; private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/"; @@ -92,7 +93,7 @@ public class RowAggTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=44; i++) + for(int i=1; i<=45; i++) addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); } @@ -755,6 +756,21 @@ public class RowAggTmplTest extends AutomatedTestBase public 
void testCodegenRowAgg44SP() { testCodegenIntegration( TEST_NAME44, false, ExecType.SPARK ); } + + @Test + public void testCodegenRowAggRewrite45CP() { + testCodegenIntegration( TEST_NAME45, true, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg45CP() { + testCodegenIntegration( TEST_NAME45, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg45SP() { + testCodegenIntegration( TEST_NAME45, false, ExecType.SPARK ); + } private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { @@ -799,7 +815,7 @@ public class RowAggTmplTest extends AutomatedTestBase Assert.assertTrue(!heavyHittersContainsSubString("uark+")); if( testname.equals(TEST_NAME17) ) Assert.assertTrue(!heavyHittersContainsSubString(RightIndex.OPCODE)); - if( testname.equals(TEST_NAME28) ) + if( testname.equals(TEST_NAME28) || testname.equals(TEST_NAME45) ) Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2) && !heavyHittersContainsSubString("sp_spoofRA", 2)); if( testname.equals(TEST_NAME30) ) http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/test/scripts/functions/codegen/rowAggPattern45.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern45.R b/src/test/scripts/functions/codegen/rowAggPattern45.R new file mode 100644 index 0000000..fee4bbd --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern45.R @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(1, 1500, 100) * matrix(1,1500,1)%*%t(seq(1,100)); + +X0 = X - 0.5; +X1 = X / rowSums(X0)%*%matrix(1,1,100); +X2 = abs(X1 * 0.5); +X3 = X1 / rowSums(X2)%*%matrix(1,1,100); + +R = as.matrix(sum(X3)); + +writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/test/scripts/functions/codegen/rowAggPattern45.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern45.dml b/src/test/scripts/functions/codegen/rowAggPattern45.dml new file mode 100644 index 0000000..4f71886 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern45.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(1, 1500, 100) * t(seq(1,100)); +while(FALSE){} + +X0 = X - 0.5; +X1 = X / rowSums(X0); +X2 = abs(X1 * 0.5); +X3 = X1 / rowSums(X2); + +while(FALSE){} +R = as.matrix(sum(X3)); + +write(R, $1)
