Repository: systemml
Updated Branches:
  refs/heads/master 9ea71d534 -> 945f3ccf9


[SYSTEMML-1313] Runtime support for remote parfor broadcast inputs

Closes #759.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/945f3ccf
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/945f3ccf
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/945f3ccf

Branch: refs/heads/master
Commit: 945f3ccf9f167ee43c5946b2ea871af5b56a72be
Parents: 9ea71d5
Author: EdgarLGB <[email protected]>
Authored: Sat Apr 28 16:02:37 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Apr 28 16:02:41 2018 -0700

----------------------------------------------------------------------
 .../controlprogram/ParForProgramBlock.java      |  11 +-
 .../context/ExecutionContext.java               |   5 +
 .../context/SparkExecutionContext.java          | 121 +++++++++++++------
 .../parfor/CachedReuseVariables.java            |  33 ++++-
 .../parfor/RemoteParForSpark.java               |  47 ++++++-
 .../parfor/RemoteParForSparkWorker.java         |  14 ++-
 .../spark/data/BroadcastObject.java             |  71 +++++++----
 7 files changed, 227 insertions(+), 75 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
index c46e4dd..3d17e2f 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
@@ -308,15 +308,16 @@ public class ParForProgramBlock extends ForProgramBlock
        public static final boolean FORCE_CP_ON_REMOTE_MR       = true; // 
compile body to CP if exec type forced to MR
        public static final boolean LIVEVAR_AWARE_EXPORT        = true; // 
export only read variables according to live variable analysis
        public static final boolean RESET_RECOMPILATION_FLAGs   = true;
-       
-       public static final String PARFOR_FNAME_PREFIX          = "/parfor/"; 
+       public static final boolean ALLOW_BROADCAST_INPUTS      = false; // 
enables broadcasting of inputs for remote_spark
+       
+       public static final String PARFOR_FNAME_PREFIX          = "/parfor/"; 
        public static final String PARFOR_MR_TASKS_TMP_FNAME    = 
PARFOR_FNAME_PREFIX + "%ID%_MR_taskfile"; 
        public static final String PARFOR_MR_RESULT_TMP_FNAME   = 
PARFOR_FNAME_PREFIX + "%ID%_MR_results"; 
        public static final String PARFOR_MR_RESULTMERGE_FNAME  = 
PARFOR_FNAME_PREFIX + "%ID%_resultmerge%VAR%"; 
        public static final String PARFOR_DATAPARTITIONS_FNAME  = 
PARFOR_FNAME_PREFIX + "%ID%_datapartitions%VAR%"; 
        
        public static final String PARFOR_COUNTER_GROUP_NAME    = "SystemML 
ParFOR Counters";
-       
+
        // static ID generator sequences
        private final static IDSequence _pfIDSeq = new IDSequence();
        private final static IDSequence _pwIDSeq = new IDSequence();
@@ -1042,8 +1043,8 @@ public class ParForProgramBlock extends ForProgramBlock
                exportMatricesToHDFS(ec);
                
                // Step 3) submit Spark parfor job (no lazy evaluation, since 
collect on result)
-               //MatrixObject colocatedDPMatrixObj = 
(_colocatedDPMatrix!=null)? (MatrixObject)ec.getVariable(_colocatedDPMatrix) : 
null;
-               RemoteParForJobReturn ret = RemoteParForSpark.runJob(_ID, 
program, clsMap, tasks, ec, _enableCPCaching, _numThreads);
+               RemoteParForJobReturn ret = RemoteParForSpark.runJob(_ID, 
program,
+                       clsMap, tasks, ec, _resultVars, _enableCPCaching, 
_numThreads);
                
                if( _monitor ) 
                        StatisticMonitor.putPFStat(_ID, 
Stat.PARFOR_WAIT_EXEC_T, time.stop());

http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index dfc3dc5..f339efb 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -195,6 +195,11 @@ public class ExecutionContext {
                
                return (MatrixObject) dat;
        }
+
+       public boolean isFrameObject(String varname) {
+               Data dat = getVariable(varname);
+               return (dat!= null && dat instanceof FrameObject);
+       }
        
        public FrameObject getFrameObject(CPOperand input) {
                return getFrameObject(input.getName());

http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 4bb1f3a..8e7d888 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -505,10 +505,48 @@ public class SparkExecutionContext extends 
ExecutionContext
 
                return rdd;
        }
-       
+
+       public Broadcast<CacheBlock> 
broadcastVariable(CacheableData<CacheBlock> cd) {
+               long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
+               Broadcast<CacheBlock> brBlock = null;
+
+               // reuse existing non partitioned broadcast handle
+               if (cd.getBroadcastHandle() != null && 
cd.getBroadcastHandle().isNonPartitionedBroadcastValid()) {
+                       brBlock = 
cd.getBroadcastHandle().getNonPartitionedBroadcast();
+               }
+
+               if (brBlock == null) {
+                       //create new broadcast handle (never created, evicted)
+                       // account for overwritten invalid broadcast (e.g., 
evicted)
+                       if (cd.getBroadcastHandle() != null)
+                               
CacheableData.addBroadcastSize(-cd.getBroadcastHandle().getNonPartitionedBroadcastSize());
+
+                       // read the matrix block
+                       CacheBlock cb = cd.acquireRead();
+                       cd.release();
+
+                       // broadcast a non-empty frame whose size is smaller 
than 2G
+                       if (cb.getExactSerializedSize() > 0 && 
cb.getExactSerializedSize() <= Integer.MAX_VALUE) {
+                               brBlock = getSparkContext().broadcast(cb);
+                               // create the broadcast handle if the matrix or 
frame has never been broadcasted
+                               if (cd.getBroadcastHandle() == null) {
+                                       cd.setBroadcastHandle(new 
BroadcastObject<>());
+                               }
+                               
cd.getBroadcastHandle().setNonPartitionedBroadcast(brBlock,
+                                       
OptimizerUtils.estimateSize(cd.getMatrixCharacteristics()));
+                               
CacheableData.addBroadcastSize(cd.getBroadcastHandle().getNonPartitionedBroadcastSize());
+
+                               if (DMLScript.STATISTICS) {
+                                       
Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
+                                       Statistics.incSparkBroadcastCount(1);
+                               }
+                       }
+               }
+               return brBlock;
+       }
+
        @SuppressWarnings("unchecked")
-       public PartitionedBroadcast<MatrixBlock> getBroadcastForVariable( 
String varname )
-       {
+       public PartitionedBroadcast<MatrixBlock> getBroadcastForVariable(String 
varname) {
                //NOTE: The memory consumption of this method is the in-memory 
size of the 
                //matrix object plus the partitioned size in 1k-1k blocks. 
Since the call
                //to broadcast happens after the matrix object has been 
released, the memory
@@ -524,18 +562,15 @@ public class SparkExecutionContext extends 
ExecutionContext
                PartitionedBroadcast<MatrixBlock> bret = null;
 
                //reuse existing broadcast handle
-               if( mo.getBroadcastHandle()!=null
-                       && mo.getBroadcastHandle().isValid() )
-               {
-                       bret = mo.getBroadcastHandle().getBroadcast();
+               if (mo.getBroadcastHandle() != null && 
mo.getBroadcastHandle().isPartitionedBroadcastValid()) {
+                       bret = 
mo.getBroadcastHandle().getPartitionedBroadcast();
                }
 
                //create new broadcast handle (never created, evicted)
-               if( bret == null )
-               {
+               if (bret == null) {
                        //account for overwritten invalid broadcast (e.g., 
evicted)
-                       if( mo.getBroadcastHandle()!=null )
-                               
CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getSize());
+                       if (mo.getBroadcastHandle() != null)
+                               
CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getPartitionedBroadcastSize());
 
                        //obtain meta data for matrix
                        int brlen = (int) mo.getNumRowsPerBlock();
@@ -548,24 +583,26 @@ public class SparkExecutionContext extends 
ExecutionContext
 
                        //determine coarse-grained partitioning
                        int numPerPart = 
PartitionedBroadcast.computeBlocksPerPartition(mo.getNumRows(), 
mo.getNumColumns(), brlen, bclen);
-                       int numParts = (int) 
Math.ceil((double)pmb.getNumRowBlocks()*pmb.getNumColumnBlocks() / numPerPart);
+                       int numParts = (int) Math.ceil((double) 
pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() / numPerPart);
                        Broadcast<PartitionedBlock<MatrixBlock>>[] ret = new 
Broadcast[numParts];
 
                        //create coarse-grained partitioned broadcasts
                        if (numParts > 1) {
                                Arrays.parallelSetAll(ret, i -> 
createPartitionedBroadcast(pmb, numPerPart, i));
-                       } 
-                       else { //single partition
+                       } else { //single partition
                                ret[0] = getSparkContext().broadcast(pmb);
                                if (!isLocalMaster())
                                        pmb.clearBlocks();
                        }
                        
                        bret = new PartitionedBroadcast<>(ret, 
mo.getMatrixCharacteristics());
-                       BroadcastObject<MatrixBlock> bchandle = new 
BroadcastObject<>(bret,
-                                       
OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getMatrixCharacteristics()));
-                       mo.setBroadcastHandle(bchandle);
-                       CacheableData.addBroadcastSize(bchandle.getSize());
+                       // create the broadcast handle if the matrix or frame 
has never been broadcasted
+                       if (mo.getBroadcastHandle() == null) {
+                               mo.setBroadcastHandle(new 
BroadcastObject<MatrixBlock>());
+                       }
+                       mo.getBroadcastHandle().setPartitionedBroadcast(bret,
+                               
OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getMatrixCharacteristics()));
+                       
CacheableData.addBroadcastSize(mo.getBroadcastHandle().getPartitionedBroadcastSize());
                }
 
                if (DMLScript.STATISTICS) {
@@ -577,8 +614,7 @@ public class SparkExecutionContext extends ExecutionContext
        }
 
        @SuppressWarnings("unchecked")
-       public PartitionedBroadcast<FrameBlock> getBroadcastForFrameVariable( 
String varname)
-       {
+       public PartitionedBroadcast<FrameBlock> 
getBroadcastForFrameVariable(String varname) {
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
 
                FrameObject fo = getFrameObject(varname);
@@ -586,18 +622,15 @@ public class SparkExecutionContext extends 
ExecutionContext
                PartitionedBroadcast<FrameBlock> bret = null;
 
                //reuse existing broadcast handle
-               if( fo.getBroadcastHandle()!=null
-                       && fo.getBroadcastHandle().isValid() )
-               {
-                       bret = fo.getBroadcastHandle().getBroadcast();
+               if (fo.getBroadcastHandle() != null && 
fo.getBroadcastHandle().isPartitionedBroadcastValid()) {
+                       bret = 
fo.getBroadcastHandle().getPartitionedBroadcast();
                }
 
                //create new broadcast handle (never created, evicted)
-               if( bret == null )
-               {
+               if (bret == null) {
                        //account for overwritten invalid broadcast (e.g., 
evicted)
-                       if( fo.getBroadcastHandle()!=null )
-                               
CacheableData.addBroadcastSize(-fo.getBroadcastHandle().getSize());
+                       if (fo.getBroadcastHandle() != null)
+                               
CacheableData.addBroadcastSize(-fo.getBroadcastHandle().getPartitionedBroadcastSize());
 
                        //obtain meta data for frame
                        int bclen = (int) fo.getNumColumns();
@@ -610,24 +643,25 @@ public class SparkExecutionContext extends 
ExecutionContext
 
                        //determine coarse-grained partitioning
                        int numPerPart = 
PartitionedBroadcast.computeBlocksPerPartition(fo.getNumRows(), 
fo.getNumColumns(), brlen, bclen);
-                       int numParts = (int) 
Math.ceil((double)pmb.getNumRowBlocks()*pmb.getNumColumnBlocks() / numPerPart);
+                       int numParts = (int) Math.ceil((double) 
pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() / numPerPart);
                        Broadcast<PartitionedBlock<FrameBlock>>[] ret = new 
Broadcast[numParts];
 
                        //create coarse-grained partitioned broadcasts
                        if (numParts > 1) {
                                Arrays.parallelSetAll(ret, i -> 
createPartitionedBroadcast(pmb, numPerPart, i));
-                       }
-                       else { //single partition
+                       } else { //single partition
                                ret[0] = getSparkContext().broadcast(pmb);
                                if (!isLocalMaster())
                                        pmb.clearBlocks();
                        }
 
                        bret = new PartitionedBroadcast<>(ret, 
fo.getMatrixCharacteristics());
-                       BroadcastObject<FrameBlock> bchandle = new 
BroadcastObject<>(bret,
-                                       
OptimizerUtils.estimatePartitionedSizeExactSparsity(fo.getMatrixCharacteristics()));
-                       fo.setBroadcastHandle(bchandle);
-                       CacheableData.addBroadcastSize(bchandle.getSize());
+                       if (fo.getBroadcastHandle() == null) {
+                               fo.setBroadcastHandle(new 
BroadcastObject<FrameBlock>());
+                       }
+                       fo.getBroadcastHandle().setPartitionedBroadcast(bret,
+                               
OptimizerUtils.estimatePartitionedSizeExactSparsity(fo.getMatrixCharacteristics()));
+                       
CacheableData.addBroadcastSize(fo.getBroadcastHandle().getPartitionedBroadcastSize());
                }
 
                if (DMLScript.STATISTICS) {
@@ -1124,11 +1158,20 @@ public class SparkExecutionContext extends 
ExecutionContext
                                _parRDDs.deregisterRDD(rddID);
                }
                else if( lob instanceof BroadcastObject ) {
-                       PartitionedBroadcast pbm = 
((BroadcastObject)lob).getBroadcast();
-                       if( pbm != null ) //robustness for evictions
-                               for( Broadcast<PartitionedBlock> bc : 
pbm.getBroadcasts() )
+                       BroadcastObject bob = (BroadcastObject) lob;
+                       // clean the partitioned broadcast
+                       if (bob.isPartitionedBroadcastValid()) {
+                               PartitionedBroadcast pbm = 
bob.getPartitionedBroadcast();
+                       if( pbm != null ) //robustness for evictions
+                                       pbm.destroy();
+                       }
+                       // clean the non-partitioned broadcast
+                       if (((BroadcastObject) 
lob).isNonPartitionedBroadcastValid()) {
+                               Broadcast<CacheableData> bc = 
bob.getNonPartitionedBroadcast();
+                       if( bc != null ) //robustness for evictions
                                        cleanupBroadcastVariable(bc);
-                       
CacheableData.addBroadcastSize(-((BroadcastObject)lob).getSize());
+                       }
+                       
CacheableData.addBroadcastSize(-bob.getNonPartitionedBroadcastSize());
                }
 
                //recursively process lineage children

http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/CachedReuseVariables.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/CachedReuseVariables.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/CachedReuseVariables.java
index 2532cbd..3db8e7a 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/CachedReuseVariables.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/CachedReuseVariables.java
@@ -23,8 +23,15 @@ import java.lang.ref.SoftReference;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Map;
+import java.util.Map.Entry;
 
+import org.apache.spark.broadcast.Broadcast;
 import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
+import org.apache.sysml.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.instructions.cp.Data;
 
 public class CachedReuseVariables
 {
@@ -33,11 +40,22 @@ public class CachedReuseVariables
        public CachedReuseVariables() {
                _data = new HashMap<>();
        }
+
+       public synchronized boolean containsVars(long pfid) {
+               return _data.containsKey(pfid);
+       }
        
-       public synchronized void reuseVariables(long pfid, LocalVariableMap 
vars, Collection<String> blacklist) {
+       @SuppressWarnings("unused")
+       public synchronized void reuseVariables(long pfid, LocalVariableMap 
vars, Collection<String> blacklist, Map<String, Broadcast<CacheBlock>> 
_brInputs) {
+
+               //fetch the broadcast variables
+               if (ParForProgramBlock.ALLOW_BROADCAST_INPUTS && 
!containsVars(pfid)) {
+                       loadBroadcastVariables(vars, _brInputs);
+               }
+
                //check for existing reuse map
                LocalVariableMap tmp = null;
-               if( _data.containsKey(pfid) )
+               if( containsVars(pfid) )
                        tmp = _data.get(pfid).get();
                
                //build reuse map if not created yet or evicted
@@ -57,4 +75,15 @@ public class CachedReuseVariables
        public synchronized void clearVariables(long pfid) {
                _data.remove(pfid);
        }
+
+       @SuppressWarnings("unchecked")
+       private static void loadBroadcastVariables(LocalVariableMap variables, 
Map<String, Broadcast<CacheBlock>> brInputs) {
+               for( Entry<String, Broadcast<CacheBlock>> e : 
brInputs.entrySet() ) {
+                       Data d = variables.get(e.getKey());
+                       CacheableData<CacheBlock> cdcb = 
(CacheableData<CacheBlock>) d;
+                       cdcb.acquireModify(e.getValue().getValue());
+                       cdcb.setEmptyStatus(); // avoid eviction
+                       cdcb.refreshMetaData();
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSpark.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSpark.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSpark.java
index eee65f3..7d75cb6 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSpark.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSpark.java
@@ -19,14 +19,27 @@
 
 package org.apache.sysml.runtime.controlprogram.parfor;
 
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.util.LongAccumulator;
 
+import org.apache.sysml.parser.ParForStatementBlock;
+import org.apache.sysml.parser.ParForStatementBlock.ResultVar;
+import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
+import org.apache.sysml.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ScalarObject;
 import scala.Tuple2;
 
 import org.apache.sysml.api.DMLScript;
@@ -47,7 +60,6 @@ import org.apache.sysml.utils.Statistics;
 * pre-aggregation by overwriting partial task results with pre-aggregated 
results from subsequent
  * iterations)
  * 
- * TODO broadcast variables if possible
  * TODO reducebykey on variable names
  */
 public class RemoteParForSpark 
@@ -58,7 +70,7 @@ public class RemoteParForSpark
        private static final IDSequence _jobID = new IDSequence();
        
        public static RemoteParForJobReturn runJob(long pfid, String prog, 
HashMap<String, byte[]> clsMap, 
-                       List<Task> tasks, ExecutionContext ec, boolean 
cpCaching, int numMappers) 
+                       List<Task> tasks, ExecutionContext ec, 
ArrayList<ResultVar> resultVars, boolean cpCaching, int numMappers) 
        {
                String jobname = "ParFor-ESP";
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
@@ -74,11 +86,17 @@ public class RemoteParForSpark
                long jobid = _jobID.getNextID();
                if( InfrastructureAnalyzer.isLocalMode() )
                        RemoteParForSparkWorker.cleanupCachedVariables(jobid);
-               
+
+               // broadcast the inputs except the result variables
+               Map<String, Broadcast<CacheBlock>> brInputs = null;
+               if (ParForProgramBlock.ALLOW_BROADCAST_INPUTS) {
+                       brInputs = broadcastInputs(sec, resultVars);
+               }
+
                //run remote_spark parfor job 
                //(w/o lazy evaluation to fit existing parfor framework, e.g., 
result merge)
                List<Tuple2<Long,String>> out = sc.parallelize(tasks, 
tasks.size()) //create rdd of parfor tasks
-                       .flatMapToPair(new RemoteParForSparkWorker(jobid, prog, 
clsMap, cpCaching, aTasks, aIters))
+                       .flatMapToPair(new RemoteParForSparkWorker(jobid, prog, 
clsMap, cpCaching, aTasks, aIters, brInputs))
                        .collect(); //execute and get output handles
                
                //de-serialize results
@@ -97,4 +115,25 @@ public class RemoteParForSpark
                
                return ret;
        }
+
+       @SuppressWarnings("unchecked")
+       private static Map<String, Broadcast<CacheBlock>> 
broadcastInputs(SparkExecutionContext sec, 
ArrayList<ParForStatementBlock.ResultVar> resultVars) {
+               LocalVariableMap inputs = sec.getVariables();
+               // exclude the result variables
+               // TODO use optimizer-picked list of amenable objects (e.g., 
size constraints)
+               Set<String> retVars = resultVars.stream()
+                       .map(v -> v._name).collect(Collectors.toSet());
+               Set<String> brVars = inputs.keySet().stream()
+                       .filter(v -> 
!retVars.contains(v)).collect(Collectors.toSet());
+               
+               // construct broadcast objects
+               Map<String, Broadcast<CacheBlock>> result = new HashMap<>();
+               for (String key : brVars) {
+                       Data var = sec.getVariable(key);
+                       if ((var instanceof ScalarObject) || (var instanceof 
MatrixObject && ((MatrixObject) var).isPartitioned()))
+                               continue;
+                       result.put(key, 
sec.broadcastVariable((CacheableData<CacheBlock>) var));
+               }
+               return result;
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
index 7485602..b22f48d 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
@@ -23,13 +23,16 @@ import java.io.IOException;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.Iterator;
+import java.util.Map;
 import java.util.Map.Entry;
 import java.util.stream.Collectors;
 
 import org.apache.spark.TaskContext;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.runtime.codegen.CodegenUtils;
+import org.apache.sysml.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
 import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.controlprogram.parfor.util.IDHandler;
@@ -52,8 +55,10 @@ public class RemoteParForSparkWorker extends ParWorker 
implements PairFlatMapFun
        
        private final LongAccumulator _aTasks;
        private final LongAccumulator _aIters;
+
+       private final Map<String, Broadcast<CacheBlock>> _brInputs;
        
-       public RemoteParForSparkWorker(long jobid, String program, 
HashMap<String, byte[]> clsMap, boolean cpCaching, LongAccumulator atasks, 
LongAccumulator aiters) {
+       public RemoteParForSparkWorker(long jobid, String program, 
HashMap<String, byte[]> clsMap, boolean cpCaching, LongAccumulator atasks, 
LongAccumulator aiters, Map<String, Broadcast<CacheBlock>> brInputs) {
                _jobid = jobid;
                _prog = program;
                _clsMap = clsMap;
@@ -62,6 +67,7 @@ public class RemoteParForSparkWorker extends ParWorker 
implements PairFlatMapFun
                //setup spark accumulators
                _aTasks = atasks;
                _aIters = aiters;
+               _brInputs = brInputs;
        }
        
        @Override 
@@ -102,13 +108,13 @@ public class RemoteParForSparkWorker extends ParWorker 
implements PairFlatMapFun
                _resultVars  = body.getResultVariables();
                _numTasks    = 0;
                _numIters    = 0;
-               
+
                //reuse shared inputs (to read shared inputs once per process 
instead of once per core; 
                //we reuse everything except result variables and partitioned 
input matrices)
                _ec.pinVariables(_ec.getVarList()); //avoid cleanup of shared 
inputs
                Collection<String> blacklist = 
UtilFunctions.asSet(_resultVars.stream()
                        .map(v -> v._name).collect(Collectors.toList()), 
_ec.getVarListPartitioned());
-               reuseVars.reuseVariables(_jobid, _ec.getVariables(), blacklist);
+               reuseVars.reuseVariables(_jobid, _ec.getVariables(), blacklist, 
_brInputs);
                
                //init and register-cleanup of buffer pool (in parfor spark, 
multiple tasks might 
                //share the process-local, i.e., per executor, buffer pool; 
hence we synchronize 
@@ -137,7 +143,7 @@ public class RemoteParForSparkWorker extends ParWorker 
implements PairFlatMapFun
                //mark as initialized
                _initialized = true;
        }
-       
+
        public static void cleanupCachedVariables(long pfid) {
                reuseVars.clearVariables(pfid);
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
index ff1ac4f..0517e1b 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
@@ -24,39 +24,68 @@ import java.lang.ref.SoftReference;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.sysml.runtime.controlprogram.caching.CacheBlock;
 
-public class BroadcastObject<T extends CacheBlock> extends LineageObject
-{
+public class BroadcastObject<T extends CacheBlock> extends LineageObject {
        //soft reference storage for graceful cleanup in case of memory pressure
-       protected final SoftReference<PartitionedBroadcast<T>> _bcHandle;
-       private final long _size;
-       
-       public BroadcastObject( PartitionedBroadcast<T> bvar, long size ) {
+       private SoftReference<PartitionedBroadcast<T>> _pbcRef; // partitioned 
broadcast object reference
+       private SoftReference<Broadcast<T>> _npbcRef; // non partitioned 
broadcast object reference
+
+       private long _pbcSize; // partitioned broadcast size
+       private long _npbcSize; // non-partitioned broadcast size
+
+       public BroadcastObject() {
                super();
-               _bcHandle = new SoftReference<>(bvar);
-               _size = size;
+       }
+
+       public void setNonPartitionedBroadcast(Broadcast<T> bvar, long size) {
+               _npbcRef = new SoftReference<>(bvar);
+               _npbcSize = size;
+       }
+
+       public void setPartitionedBroadcast(PartitionedBroadcast<T> bvar, long 
size) {
+               _pbcRef = new SoftReference<>(bvar);
+               _pbcSize = size;
        }
 
        @SuppressWarnings("rawtypes")
-       public PartitionedBroadcast getBroadcast() {
-               return _bcHandle.get();
+       public PartitionedBroadcast getPartitionedBroadcast() {
+               return _pbcRef.get();
+       }
+
+       public Broadcast<T> getNonPartitionedBroadcast() {
+               return _npbcRef.get();
+       }
+
+       public long getPartitionedBroadcastSize() {
+               return _pbcSize;
        }
-       
-       public long getSize() {
-               return _size;
+
+       public long getNonPartitionedBroadcastSize() {
+               return _npbcSize;
        }
 
-       public boolean isValid() 
-       {
+       public boolean isPartitionedBroadcastValid() {
+               return _pbcRef != null && checkPartitionedBroadcastValid();
+       }
+
+       public boolean isNonPartitionedBroadcastValid() {
+               return _npbcRef != null && checkNonPartitionedBroadcastValid();
+       }
+
+       private boolean checkNonPartitionedBroadcastValid() {
+               return _npbcRef.get() != null;
+       }
+
+       private boolean checkPartitionedBroadcastValid() {
                //check for evicted soft reference
-               PartitionedBroadcast<T> pbm = _bcHandle.get();
-               if( pbm == null )
+               PartitionedBroadcast<T> pbm = _pbcRef.get();
+               if (pbm == null)
                        return false;
-               
+
                //check for validity of individual broadcasts
                Broadcast<PartitionedBlock<T>>[] tmp = pbm.getBroadcasts();
-               for( Broadcast<PartitionedBlock<T>> bc : tmp )
-                       if( !bc.isValid() )
-                               return false;           
+               for (Broadcast<PartitionedBlock<T>> bc : tmp)
+                       if (!bc.isValid())
+                               return false;
                return true;
        }
 }

Reply via email to