[SYSTEMML-1430] Robust broadcast memory handling (track pinned sizes)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/a429e2df
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/a429e2df
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/a429e2df

Branch: refs/heads/master
Commit: a429e2df9287b709edd245c6a3211d62ecbf9517
Parents: f380b52
Author: Matthias Boehm <[email protected]>
Authored: Thu Mar 23 13:37:44 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Mar 23 13:39:39 2017 -0700

----------------------------------------------------------------------
 .../controlprogram/caching/CacheableData.java   | 16 +++++++++++++++
 .../controlprogram/caching/MatrixObject.java    |  2 +-
 .../context/SparkExecutionContext.java          | 21 ++++++++++++++++----
 .../spark/data/BroadcastObject.java             | 13 ++++++++----
 4 files changed, 43 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a429e2df/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index c36b8ca..054b333 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -22,6 +22,7 @@ package org.apache.sysml.runtime.controlprogram.caching;
 import java.io.File;
 import java.io.IOException;
 import java.lang.ref.SoftReference;
+import java.util.concurrent.atomic.AtomicLong;
 
 import org.apache.commons.lang.mutable.MutableBoolean;
 import org.apache.commons.logging.Log;
@@ -115,6 +116,12 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
        private static ThreadLocal<Long> sizePinned = new ThreadLocal<Long>() {
         @Override protected Long initialValue() { return 0L; }
     };
+
+       //current size of live broadcast objects (because Spark's ContextCleaner maintains 
+       //a buffer with references to prevent eager cleanup by GC); note that this is an 
+       //overestimate, because we maintain partitioned broadcasts as soft references, which 
+       //might be collected by the GC and subsequently cleaned up by Spark's ContextCleaner.
+       private static AtomicLong _refBCs = new AtomicLong(0);  
     
        static {
                _seq = new IDSequence();
@@ -1213,6 +1220,14 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
                return sizePinned.get();
        }
        
+       public static void addBroadcastSize(long size) {
+               _refBCs.addAndGet(size);
+       }
+       
+       public static long getBroadcastSize() {
+               return _refBCs.longValue();
+       }
+       
        // --------- STATIC CACHE INIT/CLEANUP OPERATIONS ----------
 
        public synchronized static void cleanupCacheDir() {
@@ -1285,6 +1300,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
        
                //init write-ahead buffer
                LazyWriteBuffer.init();
+               _refBCs.set(0);
                
                _activeFlag = true; //turn on caching
        }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a429e2df/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
index dc3cfd1..4e560c8 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
@@ -482,7 +482,7 @@ public class MatrixObject extends CacheableData<MatrixBlock>
                        
                        //guarded rdd collect 
                        if( ii == InputInfo.BinaryBlockInputInfo && //guarded collect not for binary cell
-                               !OptimizerUtils.checkSparkCollectMemoryBudget(rlen, clen, brlen, bclen, nnz, getPinnedSize()) ) {
+                               !OptimizerUtils.checkSparkCollectMemoryBudget(mc, getPinnedSize()+getBroadcastSize()) ) {
                                //write RDD to hdfs and read to prevent invalid collect mem consumption 
                                //note: lazy, partition-at-a-time collect (toLocalIterator) was significantly slower
                                if( !MapReduceTool.existsFileOnHDFS(_hdfsFileName) ) { //prevent overwrite existing file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a429e2df/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 106acc0..c2e3dd0 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -526,6 +526,10 @@ public class SparkExecutionContext extends ExecutionContext
                //create new broadcast handle (never created, evicted)
                if( bret == null ) 
                {
+                       //account for overwritten invalid broadcast (e.g., evicted)
+                       if( mo.getBroadcastHandle()!=null )
+                               CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getSize());
+                       
                        //obtain meta data for matrix 
                        int brlen = (int) mo.getNumRowsPerBlock();
                        int bclen = (int) mo.getNumColumnsPerBlock();
@@ -550,12 +554,14 @@ public class SparkExecutionContext extends ExecutionContext
                                }
                        }
                        else { //single partition
-                               ret[0] = getSparkContext().broadcast( pmb);
+                               ret[0] = getSparkContext().broadcast(pmb);
                        }
                
                        bret = new PartitionedBroadcast<MatrixBlock>(ret);
-                       BroadcastObject<MatrixBlock> bchandle = new BroadcastObject<MatrixBlock>(bret, varname);
+                       BroadcastObject<MatrixBlock> bchandle = new BroadcastObject<MatrixBlock>(bret, varname, 
+                                       OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getMatrixCharacteristics()));
                        mo.setBroadcastHandle(bchandle);
+                       CacheableData.addBroadcastSize(bchandle.getSize());
                }
                
                if (DMLScript.STATISTICS) {
@@ -586,6 +592,10 @@ public class SparkExecutionContext extends ExecutionContext
                //create new broadcast handle (never created, evicted)
                if( bret == null ) 
                {
+                       //account for overwritten invalid broadcast (e.g., evicted)
+                       if( fo.getBroadcastHandle()!=null )
+                               CacheableData.addBroadcastSize(-fo.getBroadcastHandle().getSize());
+                       
                        //obtain meta data for frame 
                        int bclen = (int) fo.getNumColumns();
                        int brlen = OptimizerUtils.getDefaultFrameSize();
@@ -610,12 +620,14 @@ public class SparkExecutionContext extends ExecutionContext
                                }
                        }
                        else { //single partition
-                               ret[0] = getSparkContext().broadcast( pmb);
+                               ret[0] = getSparkContext().broadcast(pmb);
                        }
                
                        bret = new PartitionedBroadcast<FrameBlock>(ret);
-                       BroadcastObject<FrameBlock> bchandle = new BroadcastObject<FrameBlock>(bret, varname);
+                       BroadcastObject<FrameBlock> bchandle = new BroadcastObject<FrameBlock>(bret, varname,  
+                                       OptimizerUtils.estimatePartitionedSizeExactSparsity(fo.getMatrixCharacteristics()));
                        fo.setBroadcastHandle(bchandle);
+                       CacheableData.addBroadcastSize(bchandle.getSize());
                }
                
                if (DMLScript.STATISTICS) {
@@ -1136,6 +1148,7 @@ public class SparkExecutionContext extends ExecutionContext
                        if( pbm != null ) //robustness for evictions
                                for( Broadcast<PartitionedBlock> bc : pbm.getBroadcasts() )
                                        cleanupBroadcastVariable(bc);
+                       CacheableData.addBroadcastSize(-((BroadcastObject)lob).getSize());
                }
        
                //recursively process lineage children

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a429e2df/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
index 4316b15..bdd34e4 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
@@ -27,18 +27,23 @@ import org.apache.sysml.runtime.controlprogram.caching.CacheBlock;
 public class BroadcastObject<T extends CacheBlock> extends LineageObject
 {
        //soft reference storage for graceful cleanup in case of memory pressure
-       protected SoftReference<PartitionedBroadcast<T>> _bcHandle = null;
+       protected final SoftReference<PartitionedBroadcast<T>> _bcHandle;
+       private final long _size;
        
-       public BroadcastObject( PartitionedBroadcast<T> bvar, String varName ) {
+       public BroadcastObject( PartitionedBroadcast<T> bvar, String varName, long size ) {
                super(varName);
                _bcHandle = new SoftReference<PartitionedBroadcast<T>>(bvar);
+               _size = size;
        }
 
        @SuppressWarnings("rawtypes")
-       public PartitionedBroadcast getBroadcast()
-       {
+       public PartitionedBroadcast getBroadcast() {
                return _bcHandle.get();
        }
+       
+       public long getSize() {
+               return _size;
+       }
 
        public boolean isValid() 
        {

Reply via email to