Repository: systemml
Updated Branches:
  refs/heads/master 31610e36d -> 51154f17b


[SYSTEMML-2480] Fix register allocation in codegen row templates

This patch fixes potential correctness issues in codegen row templates
caused by overly aggressive reuse of vectors. Specifically, so far we
determined the number of required temporary vectors as the maximum
number of simultaneously live vectors, which is only valid for
dynamically managed vector pools (with alloc/free). However, for
efficiency reasons we use a static array ring buffer, where safe reuse
depends not only on the number of live objects but also on their
allocation sequence.

This issue did not show up before because, for row templates with
multiple different vector sizes, we conservatively allocate the number
of vector types times the number of registers. However, in special cases
such as the attached test, it led to incorrect results.

We now use a proof-based algorithm for determining the number of
registers. Starting from a heuristic allocation (a guaranteed upper
bound), we systematically decrease the count and verify that no invalid
reuse occurs. In the attached test, the heuristic allocated 5 vectors
(per thread), the (incorrect) maximum number of live objects is 2, and
our new algorithm finds the true valid minimum of 3.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ef50f8a5
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ef50f8a5
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ef50f8a5

Branch: refs/heads/master
Commit: ef50f8a5f99a63524e9ce86c6ef6a6a70e988634
Parents: 31610e3
Author: Matthias Boehm <[email protected]>
Authored: Wed Aug 1 22:21:11 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Aug 1 22:21:11 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       |  7 ++-
 .../hops/codegen/template/TemplateUtils.java    | 63 ++++++++++++++++++--
 .../controlprogram/parfor/util/IDSequence.java  | 24 ++++----
 .../functions/codegen/RowAggTmplTest.java       | 20 ++++++-
 .../scripts/functions/codegen/rowAggPattern45.R | 35 +++++++++++
 .../functions/codegen/rowAggPattern45.dml       | 33 ++++++++++
 6 files changed, 163 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index fd012ec..9fc9ac4 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -110,7 +110,7 @@ public class SpoofCompiler
        public static final boolean PRUNE_REDUNDANT_PLANS  = true;
        public static PlanCachePolicy PLAN_CACHE_POLICY    = 
PlanCachePolicy.CSLH;
        public static final int PLAN_CACHE_SIZE            = 1024; //max 1K 
classes
-       public static final RegisterAlloc REG_ALLOC_POLICY = 
RegisterAlloc.EXACT;
+       public static final RegisterAlloc REG_ALLOC_POLICY = 
RegisterAlloc.EXACT_STATIC_BUFF;
        
        public enum CompilerType {
                AUTO,
@@ -150,8 +150,9 @@ public class SpoofCompiler
        }
        
        public enum RegisterAlloc {
-               HEURISTIC,
-               EXACT,
+               HEURISTIC,           //max vector intermediates, special 
handling pipelines (always safe)
+               EXACT_DYNAMIC_BUFF,  //min number of live vector intermediates, 
assuming dynamic pooling
+               EXACT_STATIC_BUFF,   //min number of live vector intermediates, 
assuming static array ring buffer
        }
        
        static {

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 438eb56..f2ad7e6 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -26,6 +26,7 @@ import java.util.Map;
 import java.util.Set;
 
 import org.apache.commons.lang.ArrayUtils;
+import org.apache.commons.lang3.mutable.MutableInt;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.AggUnaryOp;
@@ -59,6 +60,7 @@ import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
 import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType;
 import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
 public class TemplateUtils 
@@ -371,18 +373,34 @@ public class TemplateUtils
                node.resetVisitStatus();
                int count = -1;
                switch( SpoofCompiler.REG_ALLOC_POLICY ) {
-                       case HEURISTIC:
+                       case HEURISTIC: {
                                boolean unaryPipe = 
isUnaryOperatorPipeline(node);
                                node.resetVisitStatus();
                                count = unaryPipe ? 
getMaxVectorIntermediates(node) :
                                        countVectorIntermediates(node);
                                break;
-                       case EXACT:
+                       }
+                       case EXACT_DYNAMIC_BUFF: {
                                Map<Long, Set<Long>> parents = 
getAllParents(node);
                                node.resetVisitStatus();
                                count = getMaxLiveVectorIntermediates(
                                        node, main, parents, new HashSet<>());
                                break;
+                       }
+                       case EXACT_STATIC_BUFF: {
+                               //init with basic heuristic
+                               boolean unaryPipe = 
isUnaryOperatorPipeline(node);
+                               node.resetVisitStatus();
+                               count = unaryPipe ? 
getMaxVectorIntermediates(node) :
+                                       countVectorIntermediates(node);
+                               //reduce count and proof validity
+                               Map<Long, Set<Long>> parents = 
getAllParents(node);
+                               Map<Long, Pair<Long, MutableInt>> inUse = new 
HashMap<>(); //node ID, vector ID, num Refs
+                               Set<Long> inUse2 = new HashSet<>(); //for fast 
probes
+                               while( count > 0 && 
isValidNumVectorIntermediates(node, main, parents, inUse, inUse2, count-1) )
+                                       count--;
+                               break;
+                       }
                }
                node.resetVisitStatus();
                return count;
@@ -437,8 +455,6 @@ public class TemplateUtils
                return ret + cntBin + cntUn + cntTn;
        }
        
-       
-       
        public static int getMaxLiveVectorIntermediates(CNode node, CNode main, 
Map<Long, Set<Long>> parents, Set<Pair<Long, Long>> stack) {
                if( node.isVisited() )
                        return -1;
@@ -462,6 +478,45 @@ public class TemplateUtils
                return max;
        }
        
+       public static boolean isValidNumVectorIntermediates(CNode node, CNode 
main, Map<Long, Set<Long>> parents, Map<Long, Pair<Long, MutableInt>> inUse, 
Set<Long> inUse2, int count) {
+               IDSequence buff = new IDSequence(true, count-1); //zero based
+               inUse.clear(); inUse2.clear();
+               node.resetVisitStatus();
+               return rIsValidNumVectorIntermediates(node, main, parents, 
inUse, inUse2, buff);
+       }
+       
+       public static boolean rIsValidNumVectorIntermediates(CNode node, CNode 
main, Map<Long, Set<Long>> parents,
+                       Map<Long, Pair<Long, MutableInt>> inUse, Set<Long> 
inUse2, IDSequence buff) {
+               if( node.isVisited() )
+                       return true;
+               //recursively process inputs
+               for( CNode c : node.getInput() )
+                       if( !rIsValidNumVectorIntermediates(c, main, parents, 
inUse, inUse2, buff) )
+                               return false;
+               // add current node consumers for vectors
+               if( !node.getDataType().isScalar() && 
parents.containsKey(node.getID()) && node != main ) {
+                       long vectID = buff.getNextID();
+                       if( inUse2.contains(vectID) )
+                               return false; //CONFLICT detected
+                       inUse.put(node.getID(), Pair.of(vectID,
+                               new 
MutableInt(parents.get(node.getID()).size())));
+                       inUse2.add(vectID);
+               }
+               //remove input dependencies
+               for( CNode c : node.getInput() ) {
+                       Pair<Long, MutableInt> tmp = inUse.get(c.getID());
+                       if( tmp != null ) {
+                               tmp.getValue().decrement();
+                               if( tmp.getValue().intValue() <= 0 ) {
+                                       inUse.remove(c.getID());
+                                       inUse2.remove(tmp.getKey());
+                               }
+                       }
+               }
+               node.setVisited();
+               return true;
+       }
+       
        public static Map<Long, Set<Long>> getAllParents(CNode node) {
                Map<Long, Set<Long>> ret = new HashMap<>();
                getAllParents(node, ret);

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/IDSequence.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/IDSequence.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/IDSequence.java
index cdbbc3f..e5bc741 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/IDSequence.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/IDSequence.java
@@ -28,32 +28,36 @@ import java.util.concurrent.atomic.AtomicLong;
 public class IDSequence 
 {
        private final AtomicLong _current;
-       private final boolean _wrapAround;
+       private final boolean _cyclic;
+       private final long _cycleLen;
        
        public IDSequence() {
-               this(false);
+               this(false, -1);
        }
        
-       public IDSequence(boolean wrapAround) {
+       public IDSequence(boolean cyclic) {
+               this(cyclic, Long.MAX_VALUE);
+       }
+       
+       public IDSequence(boolean cyclic, long cycleLen) {
                _current = new AtomicLong(-1);
-               _wrapAround = wrapAround;
+               _cyclic = cyclic;
+               _cycleLen = cycleLen;
        }
        
+       
        /**
         * Creates the next ID, if overflow a RuntimeException is thrown.
         * 
         * @return ID
         */
-       public long getNextID()
-       {
+       public long getNextID() {
                long val = _current.incrementAndGet();
-               
-               if( val == Long.MAX_VALUE ) {
-                       if( !_wrapAround )
+               if( val == _cycleLen ) {
+                       if( !_cyclic )
                                throw new RuntimeException("WARNING: IDSequence 
will produced numeric overflow.");
                        reset();
                }
-               
                return val;
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index 48555ae..d0a417d 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -81,6 +81,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        private static final String TEST_NAME42 = TEST_NAME+"42"; 
//X/rowSums(min(X, Y, Z))
        private static final String TEST_NAME43 = TEST_NAME+"43"; 
//bias_add(X,B) + bias_mult(X,B)
        private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - 
mean(X)) + 7;
+       private static final String TEST_NAME45 = TEST_NAME+"45"; //vector 
allocation;
        
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RowAggTmplTest.class.getSimpleName() + "/";
@@ -92,7 +93,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for(int i=1; i<=44; i++)
+               for(int i=1; i<=45; i++)
                        addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) 
}) );
        }
        
@@ -755,6 +756,21 @@ public class RowAggTmplTest extends AutomatedTestBase
        public void testCodegenRowAgg44SP() {
                testCodegenIntegration( TEST_NAME44, false, ExecType.SPARK );
        }
+       
+       @Test
+       public void testCodegenRowAggRewrite45CP() {
+               testCodegenIntegration( TEST_NAME45, true, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenRowAgg45CP() {
+               testCodegenIntegration( TEST_NAME45, false, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenRowAgg45SP() {
+               testCodegenIntegration( TEST_NAME45, false, ExecType.SPARK );
+       }
 
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
        {
@@ -799,7 +815,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                                
Assert.assertTrue(!heavyHittersContainsSubString("uark+"));
                        if( testname.equals(TEST_NAME17) )
                                
Assert.assertTrue(!heavyHittersContainsSubString(RightIndex.OPCODE));
-                       if( testname.equals(TEST_NAME28) )
+                       if( testname.equals(TEST_NAME28) || 
testname.equals(TEST_NAME45) )
                                
Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2)
                                        && 
!heavyHittersContainsSubString("sp_spoofRA", 2));
                        if( testname.equals(TEST_NAME30) )

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/test/scripts/functions/codegen/rowAggPattern45.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern45.R 
b/src/test/scripts/functions/codegen/rowAggPattern45.R
new file mode 100644
index 0000000..fee4bbd
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern45.R
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = matrix(1, 1500, 100) * matrix(1,1500,1)%*%t(seq(1,100));
+
+X0 = X - 0.5;
+X1 = X / rowSums(X0)%*%matrix(1,1,100);
+X2 = abs(X1 * 0.5);
+X3 = X1 / rowSums(X2)%*%matrix(1,1,100);
+
+R = as.matrix(sum(X3));
+
+writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""));

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef50f8a5/src/test/scripts/functions/codegen/rowAggPattern45.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern45.dml 
b/src/test/scripts/functions/codegen/rowAggPattern45.dml
new file mode 100644
index 0000000..4f71886
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern45.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(1, 1500, 100) * t(seq(1,100));
+while(FALSE){}
+
+X0 = X - 0.5;
+X1 = X / rowSums(X0);
+X2 = abs(X1 * 0.5);
+X3 = X1 / rowSums(X2);
+
+while(FALSE){}
+R = as.matrix(sum(X3));
+
+write(R, $1)

Reply via email to