Repository: systemml Updated Branches: refs/heads/master 9ea71d534 -> 945f3ccf9
[SYSTEMML-1313] Runtime support for remote parfor broadcast inputs Closes #759. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/945f3ccf Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/945f3ccf Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/945f3ccf Branch: refs/heads/master Commit: 945f3ccf9f167ee43c5946b2ea871af5b56a72be Parents: 9ea71d5 Author: EdgarLGB <[email protected]> Authored: Sat Apr 28 16:02:37 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Apr 28 16:02:41 2018 -0700 ---------------------------------------------------------------------- .../controlprogram/ParForProgramBlock.java | 11 +- .../context/ExecutionContext.java | 5 + .../context/SparkExecutionContext.java | 121 +++++++++++++------ .../parfor/CachedReuseVariables.java | 33 ++++- .../parfor/RemoteParForSpark.java | 47 ++++++- .../parfor/RemoteParForSparkWorker.java | 14 ++- .../spark/data/BroadcastObject.java | 71 +++++++---- 7 files changed, 227 insertions(+), 75 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java index c46e4dd..3d17e2f 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java @@ -308,15 +308,16 @@ public class ParForProgramBlock extends ForProgramBlock public static final boolean FORCE_CP_ON_REMOTE_MR = true; // compile body to CP if exec type forced to MR public static final boolean LIVEVAR_AWARE_EXPORT = true; // export only read variables according to live variable analysis public static final boolean RESET_RECOMPILATION_FLAGs = true; - - public static final String PARFOR_FNAME_PREFIX = "/parfor/"; + public static final boolean ALLOW_BROADCAST_INPUTS = false; // enables to broadcast inputs for remote_spark + + public static final String PARFOR_FNAME_PREFIX = "/parfor/"; public static final String PARFOR_MR_TASKS_TMP_FNAME = PARFOR_FNAME_PREFIX + "%ID%_MR_taskfile"; public static final String PARFOR_MR_RESULT_TMP_FNAME = PARFOR_FNAME_PREFIX + "%ID%_MR_results"; public static final String PARFOR_MR_RESULTMERGE_FNAME = PARFOR_FNAME_PREFIX + "%ID%_resultmerge%VAR%"; public static final String PARFOR_DATAPARTITIONS_FNAME = PARFOR_FNAME_PREFIX + "%ID%_datapartitions%VAR%"; public static final String PARFOR_COUNTER_GROUP_NAME = "SystemML ParFOR Counters"; - + // static ID generator sequences private final static IDSequence _pfIDSeq = new IDSequence(); private final static IDSequence _pwIDSeq = new IDSequence(); @@ -1042,8 +1043,8 @@ public class ParForProgramBlock extends ForProgramBlock exportMatricesToHDFS(ec); // Step 3) submit Spark parfor job (no lazy evaluation, since collect on result) - //MatrixObject colocatedDPMatrixObj = (_colocatedDPMatrix!=null)? (MatrixObject)ec.getVariable(_colocatedDPMatrix) : null; - RemoteParForJobReturn ret = RemoteParForSpark.runJob(_ID, program, clsMap, tasks, ec, _enableCPCaching, _numThreads); + RemoteParForJobReturn ret = RemoteParForSpark.runJob(_ID, program, + clsMap, tasks, ec, _resultVars, _enableCPCaching, _numThreads); if( _monitor ) StatisticMonitor.putPFStat(_ID, Stat.PARFOR_WAIT_EXEC_T, time.stop()); http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index dfc3dc5..f339efb 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -195,6 +195,11 @@ public class ExecutionContext { return (MatrixObject) dat; } + + public boolean isFrameObject(String varname) { + Data dat = getVariable(varname); + return (dat!= null && dat instanceof FrameObject); + } public FrameObject getFrameObject(CPOperand input) { return getFrameObject(input.getName()); http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java index 4bb1f3a..8e7d888 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java @@ -505,10 +505,48 @@ public class SparkExecutionContext extends ExecutionContext return rdd; } - + + public Broadcast<CacheBlock> broadcastVariable(CacheableData<CacheBlock> cd) { + long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; + Broadcast<CacheBlock> brBlock = null; + + // reuse existing non partitioned broadcast handle + if (cd.getBroadcastHandle() != null && cd.getBroadcastHandle().isNonPartitionedBroadcastValid()) { + brBlock = cd.getBroadcastHandle().getNonPartitionedBroadcast(); + } + + if (brBlock == null) { + //create new broadcast handle (never created, evicted) + // account for overwritten invalid broadcast (e.g., evicted) + if (cd.getBroadcastHandle() != null) + CacheableData.addBroadcastSize(-cd.getBroadcastHandle().getNonPartitionedBroadcastSize()); + + // read the matrix block + CacheBlock cb = cd.acquireRead(); + cd.release(); + + // broadcast a non-empty frame whose size is smaller than 2G + if (cb.getExactSerializedSize() > 0 && cb.getExactSerializedSize() <= Integer.MAX_VALUE) { + brBlock = getSparkContext().broadcast(cb); + // create the broadcast handle if the matrix or frame has never been broadcasted + if (cd.getBroadcastHandle() == null) { + cd.setBroadcastHandle(new BroadcastObject<>()); + } + cd.getBroadcastHandle().setNonPartitionedBroadcast(brBlock, + OptimizerUtils.estimateSize(cd.getMatrixCharacteristics())); + CacheableData.addBroadcastSize(cd.getBroadcastHandle().getNonPartitionedBroadcastSize()); + + if (DMLScript.STATISTICS) { + Statistics.accSparkBroadCastTime(System.nanoTime() - t0); + Statistics.incSparkBroadcastCount(1); + } + } + } + return brBlock; + } + @SuppressWarnings("unchecked") - public PartitionedBroadcast<MatrixBlock> getBroadcastForVariable( String varname ) - { + public PartitionedBroadcast<MatrixBlock> getBroadcastForVariable(String varname) { //NOTE: The memory consumption of this method is the in-memory size of the //matrix object plus the partitioned size in 1k-1k blocks. Since the call //to broadcast happens after the matrix object has been released, the memory @@ -524,18 +562,15 @@ public class SparkExecutionContext extends ExecutionContext PartitionedBroadcast<MatrixBlock> bret = null; //reuse existing broadcast handle - if( mo.getBroadcastHandle()!=null - && mo.getBroadcastHandle().isValid() ) - { - bret = mo.getBroadcastHandle().getBroadcast(); + if (mo.getBroadcastHandle() != null && mo.getBroadcastHandle().isPartitionedBroadcastValid()) { + bret = mo.getBroadcastHandle().getPartitionedBroadcast(); } //create new broadcast handle (never created, evicted) - if( bret == null ) - { + if (bret == null) { //account for overwritten invalid broadcast (e.g., evicted) - if( mo.getBroadcastHandle()!=null ) - CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getSize()); + if (mo.getBroadcastHandle() != null) + CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getPartitionedBroadcastSize()); //obtain meta data for matrix int brlen = (int) mo.getNumRowsPerBlock(); @@ -548,24 +583,26 @@ public class SparkExecutionContext extends ExecutionContext //determine coarse-grained partitioning int numPerPart = PartitionedBroadcast.computeBlocksPerPartition(mo.getNumRows(), mo.getNumColumns(), brlen, bclen); - int numParts = (int) Math.ceil((double)pmb.getNumRowBlocks()*pmb.getNumColumnBlocks() / numPerPart); + int numParts = (int) Math.ceil((double) pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() / numPerPart); Broadcast<PartitionedBlock<MatrixBlock>>[] ret = new Broadcast[numParts]; //create coarse-grained partitioned broadcasts if (numParts > 1) { Arrays.parallelSetAll(ret, i -> createPartitionedBroadcast(pmb, numPerPart, i)); - } - else { //single partition + } else { //single partition ret[0] = getSparkContext().broadcast(pmb); if (!isLocalMaster()) pmb.clearBlocks(); } bret = new PartitionedBroadcast<>(ret, mo.getMatrixCharacteristics()); - BroadcastObject<MatrixBlock> bchandle = new BroadcastObject<>(bret, - OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getMatrixCharacteristics())); - mo.setBroadcastHandle(bchandle); - CacheableData.addBroadcastSize(bchandle.getSize()); + // create the broadcast handle if the matrix or frame has never been broadcasted + if (mo.getBroadcastHandle() == null) { + mo.setBroadcastHandle(new BroadcastObject<MatrixBlock>()); + } + mo.getBroadcastHandle().setPartitionedBroadcast(bret, + OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getMatrixCharacteristics())); + CacheableData.addBroadcastSize(mo.getBroadcastHandle().getPartitionedBroadcastSize()); } if (DMLScript.STATISTICS) { @@ -577,8 +614,7 @@ public class SparkExecutionContext extends ExecutionContext } @SuppressWarnings("unchecked") - public PartitionedBroadcast<FrameBlock> getBroadcastForFrameVariable( String varname) - { + public PartitionedBroadcast<FrameBlock> getBroadcastForFrameVariable(String varname) { long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; FrameObject fo = getFrameObject(varname); @@ -586,18 +622,15 @@ public class SparkExecutionContext extends ExecutionContext PartitionedBroadcast<FrameBlock> bret = null; //reuse existing broadcast handle - if( fo.getBroadcastHandle()!=null - && fo.getBroadcastHandle().isValid() ) - { - bret = fo.getBroadcastHandle().getBroadcast(); + if (fo.getBroadcastHandle() != null && fo.getBroadcastHandle().isPartitionedBroadcastValid()) { + bret = fo.getBroadcastHandle().getPartitionedBroadcast(); } //create new broadcast handle (never created, evicted) - if( bret == null ) - { + if (bret == null) { //account for overwritten invalid broadcast (e.g., evicted) - if( fo.getBroadcastHandle()!=null ) - CacheableData.addBroadcastSize(-fo.getBroadcastHandle().getSize()); + if (fo.getBroadcastHandle() != null) + CacheableData.addBroadcastSize(-fo.getBroadcastHandle().getPartitionedBroadcastSize()); //obtain meta data for frame int bclen = (int) fo.getNumColumns(); @@ -610,24 +643,25 @@ public class SparkExecutionContext extends ExecutionContext //determine coarse-grained partitioning int numPerPart = PartitionedBroadcast.computeBlocksPerPartition(fo.getNumRows(), fo.getNumColumns(), brlen, bclen); - int numParts = (int) Math.ceil((double)pmb.getNumRowBlocks()*pmb.getNumColumnBlocks() / numPerPart); + int numParts = (int) Math.ceil((double) pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() / numPerPart); Broadcast<PartitionedBlock<FrameBlock>>[] ret = new Broadcast[numParts]; //create coarse-grained partitioned broadcasts if (numParts > 1) { Arrays.parallelSetAll(ret, i -> createPartitionedBroadcast(pmb, numPerPart, i)); - } - else { //single partition + } else { //single partition ret[0] = getSparkContext().broadcast(pmb); if (!isLocalMaster()) pmb.clearBlocks(); } bret = new PartitionedBroadcast<>(ret, fo.getMatrixCharacteristics()); - BroadcastObject<FrameBlock> bchandle = new BroadcastObject<>(bret, - OptimizerUtils.estimatePartitionedSizeExactSparsity(fo.getMatrixCharacteristics())); - fo.setBroadcastHandle(bchandle); - CacheableData.addBroadcastSize(bchandle.getSize()); + if (fo.getBroadcastHandle() == null) { + fo.setBroadcastHandle(new BroadcastObject<FrameBlock>()); + } + fo.getBroadcastHandle().setPartitionedBroadcast(bret, + OptimizerUtils.estimatePartitionedSizeExactSparsity(fo.getMatrixCharacteristics())); + CacheableData.addBroadcastSize(fo.getBroadcastHandle().getPartitionedBroadcastSize()); } if (DMLScript.STATISTICS) { @@ -1124,11 +1158,20 @@ public class SparkExecutionContext extends ExecutionContext _parRDDs.deregisterRDD(rddID); } else if( lob instanceof BroadcastObject ) { - PartitionedBroadcast pbm = ((BroadcastObject)lob).getBroadcast(); - if( pbm != null ) //robustness for evictions - for( Broadcast<PartitionedBlock> bc : pbm.getBroadcasts() ) + BroadcastObject bob = (BroadcastObject) lob; + // clean the partitioned broadcast + if (bob.isPartitionedBroadcastValid()) { + PartitionedBroadcast pbm = bob.getPartitionedBroadcast(); + if( pbm != null ) //robustness evictions + pbm.destroy(); + } + // clean the non-partitioned broadcast + if (((BroadcastObject) lob).isNonPartitionedBroadcastValid()) { + Broadcast<CacheableData> bc = bob.getNonPartitionedBroadcast(); + if( bc != null ) //robustness evictions cleanupBroadcastVariable(bc); - CacheableData.addBroadcastSize(-((BroadcastObject)lob).getSize()); + } + CacheableData.addBroadcastSize(-bob.getNonPartitionedBroadcastSize()); } //recursively process lineage children http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/CachedReuseVariables.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/CachedReuseVariables.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/CachedReuseVariables.java index 2532cbd..3db8e7a 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/CachedReuseVariables.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/CachedReuseVariables.java @@ -23,8 +23,15 @@ import java.lang.ref.SoftReference; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.Map; +import java.util.Map.Entry; +import org.apache.spark.broadcast.Broadcast; import org.apache.sysml.runtime.controlprogram.LocalVariableMap; +import org.apache.sysml.runtime.controlprogram.ParForProgramBlock; +import org.apache.sysml.runtime.controlprogram.caching.CacheBlock; +import org.apache.sysml.runtime.controlprogram.caching.CacheableData; +import org.apache.sysml.runtime.instructions.cp.Data; public class CachedReuseVariables { @@ -33,11 +40,22 @@ public class CachedReuseVariables public CachedReuseVariables() { _data = new HashMap<>(); } + + public synchronized boolean containsVars(long pfid) { + return _data.containsKey(pfid); + } - public synchronized void reuseVariables(long pfid, LocalVariableMap vars, Collection<String> blacklist) { + @SuppressWarnings("unused") + public synchronized void reuseVariables(long pfid, LocalVariableMap vars, Collection<String> blacklist, Map<String, Broadcast<CacheBlock>> _brInputs) { + + //fetch the broadcast variables + if (ParForProgramBlock.ALLOW_BROADCAST_INPUTS && !containsVars(pfid)) { + loadBroadcastVariables(vars, _brInputs); + } + //check for existing reuse map LocalVariableMap tmp = null; - if( _data.containsKey(pfid) ) + if( containsVars(pfid) ) tmp = _data.get(pfid).get(); //build reuse map if not created yet or evicted @@ -57,4 +75,15 @@ public class CachedReuseVariables public synchronized void clearVariables(long pfid) { _data.remove(pfid); } + + @SuppressWarnings("unchecked") + private static void loadBroadcastVariables(LocalVariableMap variables, Map<String, Broadcast<CacheBlock>> brInputs) { + for( Entry<String, Broadcast<CacheBlock>> e : brInputs.entrySet() ) { + Data d = variables.get(e.getKey()); + CacheableData<CacheBlock> cdcb = (CacheableData<CacheBlock>) d; + cdcb.acquireModify(e.getValue().getValue()); + cdcb.setEmptyStatus(); // avoid eviction + cdcb.refreshMetaData(); + } + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSpark.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSpark.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSpark.java index eee65f3..7d75cb6 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSpark.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSpark.java @@ -19,14 +19,27 @@ package org.apache.sysml.runtime.controlprogram.parfor; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; import org.apache.spark.util.LongAccumulator; +import org.apache.sysml.parser.ParForStatementBlock; +import org.apache.sysml.parser.ParForStatementBlock.ResultVar; +import org.apache.sysml.runtime.controlprogram.ParForProgramBlock; +import org.apache.sysml.runtime.controlprogram.caching.CacheBlock; +import org.apache.sysml.runtime.controlprogram.caching.CacheableData; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.instructions.cp.Data; +import org.apache.sysml.runtime.instructions.cp.ScalarObject; import scala.Tuple2; import org.apache.sysml.api.DMLScript; @@ -47,7 +60,6 @@ import org.apache.sysml.utils.Statistics; * pre-aggregation by overwriting partial task results with pre-paggregated results from subsequent * iterations) * - * TODO broadcast variables if possible * TODO reducebykey on variable names */ public class RemoteParForSpark @@ -58,7 +70,7 @@ public class RemoteParForSpark private static final IDSequence _jobID = new IDSequence(); public static RemoteParForJobReturn runJob(long pfid, String prog, HashMap<String, byte[]> clsMap, - List<Task> tasks, ExecutionContext ec, boolean cpCaching, int numMappers) + List<Task> tasks, ExecutionContext ec, ArrayList<ResultVar> resultVars, boolean cpCaching, int numMappers) { String jobname = "ParFor-ESP"; long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; @@ -74,11 +86,17 @@ public class RemoteParForSpark long jobid = _jobID.getNextID(); if( InfrastructureAnalyzer.isLocalMode() ) RemoteParForSparkWorker.cleanupCachedVariables(jobid); - + + // broadcast the inputs except the result variables + Map<String, Broadcast<CacheBlock>> brInputs = null; + if (ParForProgramBlock.ALLOW_BROADCAST_INPUTS) { + brInputs = broadcastInputs(sec, resultVars); + } + //run remote_spark parfor job //(w/o lazy evaluation to fit existing parfor framework, e.g., result merge) List<Tuple2<Long,String>> out = sc.parallelize(tasks, tasks.size()) //create rdd of parfor tasks - .flatMapToPair(new RemoteParForSparkWorker(jobid, prog, clsMap, cpCaching, aTasks, aIters)) + .flatMapToPair(new RemoteParForSparkWorker(jobid, prog, clsMap, cpCaching, aTasks, aIters, brInputs)) .collect(); //execute and get output handles //de-serialize results @@ -97,4 +115,25 @@ public class RemoteParForSpark return ret; } + + @SuppressWarnings("unchecked") + private static Map<String, Broadcast<CacheBlock>> broadcastInputs(SparkExecutionContext sec, ArrayList<ParForStatementBlock.ResultVar> resultVars) { + LocalVariableMap inputs = sec.getVariables(); + // exclude the result variables + // TODO use optimizer-picked list of amenable objects (e.g., size constraints) + Set<String> retVars = resultVars.stream() + .map(v -> v._name).collect(Collectors.toSet()); + Set<String> brVars = inputs.keySet().stream() + .filter(v -> !retVars.contains(v)).collect(Collectors.toSet()); + + // construct broadcast objects + Map<String, Broadcast<CacheBlock>> result = new HashMap<>(); + for (String key : brVars) { + Data var = sec.getVariable(key); + if ((var instanceof ScalarObject) || (var instanceof MatrixObject && ((MatrixObject) var).isPartitioned())) + continue; + result.put(key, sec.broadcastVariable((CacheableData<CacheBlock>) var)); + } + return result; + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java index 7485602..b22f48d 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java @@ -23,13 +23,16 @@ import java.io.IOException; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; +import java.util.Map; import java.util.Map.Entry; import java.util.stream.Collectors; import org.apache.spark.TaskContext; import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.spark.broadcast.Broadcast; import org.apache.spark.util.LongAccumulator; import org.apache.sysml.runtime.codegen.CodegenUtils; +import org.apache.sysml.runtime.controlprogram.caching.CacheBlock; import org.apache.sysml.runtime.controlprogram.caching.CacheableData; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.controlprogram.parfor.util.IDHandler; @@ -52,8 +55,10 @@ public class RemoteParForSparkWorker extends ParWorker implements PairFlatMapFun private final LongAccumulator _aTasks; private final LongAccumulator _aIters; + + private final Map<String, Broadcast<CacheBlock>> _brInputs; - public RemoteParForSparkWorker(long jobid, String program, HashMap<String, byte[]> clsMap, boolean cpCaching, LongAccumulator atasks, LongAccumulator aiters) { + public RemoteParForSparkWorker(long jobid, String program, HashMap<String, byte[]> clsMap, boolean cpCaching, LongAccumulator atasks, LongAccumulator aiters, Map<String, Broadcast<CacheBlock>> brInputs) { _jobid = jobid; _prog = program; _clsMap = clsMap; @@ -62,6 +67,7 @@ public class RemoteParForSparkWorker extends ParWorker implements PairFlatMapFun //setup spark accumulators _aTasks = atasks; _aIters = aiters; + _brInputs = brInputs; } @Override @@ -102,13 +108,13 @@ public class RemoteParForSparkWorker extends ParWorker implements PairFlatMapFun _resultVars = body.getResultVariables(); _numTasks = 0; _numIters = 0; - + //reuse shared inputs (to read shared inputs once per process instead of once per core; //we reuse everything except result variables and partitioned input matrices) _ec.pinVariables(_ec.getVarList()); //avoid cleanup of shared inputs Collection<String> blacklist = UtilFunctions.asSet(_resultVars.stream() .map(v -> v._name).collect(Collectors.toList()), _ec.getVarListPartitioned()); - reuseVars.reuseVariables(_jobid, _ec.getVariables(), blacklist); + reuseVars.reuseVariables(_jobid, _ec.getVariables(), blacklist, _brInputs); //init and register-cleanup of buffer pool (in parfor spark, multiple tasks might //share the process-local, i.e., per executor, buffer pool; hence we synchronize @@ -137,7 +143,7 @@ public class RemoteParForSparkWorker extends ParWorker implements PairFlatMapFun //mark as initialized _initialized = true; } - + public static void cleanupCachedVariables(long pfid) { reuseVars.clearVariables(pfid); } http://git-wip-us.apache.org/repos/asf/systemml/blob/945f3ccf/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java index ff1ac4f..0517e1b 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java @@ -24,39 +24,68 @@ import java.lang.ref.SoftReference; import org.apache.spark.broadcast.Broadcast; import org.apache.sysml.runtime.controlprogram.caching.CacheBlock; -public class BroadcastObject<T extends CacheBlock> extends LineageObject -{ +public class BroadcastObject<T extends CacheBlock> extends LineageObject { //soft reference storage for graceful cleanup in case of memory pressure - protected final SoftReference<PartitionedBroadcast<T>> _bcHandle; - private final long _size; - - public BroadcastObject( PartitionedBroadcast<T> bvar, long size ) { + private SoftReference<PartitionedBroadcast<T>> _pbcRef; // partitioned broadcast object reference + private SoftReference<Broadcast<T>> _npbcRef; // non partitioned broadcast object reference + + private long _pbcSize; // partitioned broadcast size + private long _npbcSize; // non-partitioned broadcast size + + public BroadcastObject() { super(); - _bcHandle = new SoftReference<>(bvar); - _size = size; + } + + public void setNonPartitionedBroadcast(Broadcast<T> bvar, long size) { + _npbcRef = new SoftReference<>(bvar); + _npbcSize = size; + } + + public void setPartitionedBroadcast(PartitionedBroadcast<T> bvar, long size) { + _pbcRef = new SoftReference<>(bvar); + _pbcSize = size; } @SuppressWarnings("rawtypes") - public PartitionedBroadcast getBroadcast() { - return _bcHandle.get(); + public PartitionedBroadcast getPartitionedBroadcast() { + return _pbcRef.get(); + } + + public Broadcast<T> getNonPartitionedBroadcast() { + return _npbcRef.get(); + } + + public long getPartitionedBroadcastSize() { + return _pbcSize; } - - public long getSize() { - return _size; + + public long getNonPartitionedBroadcastSize() { + return _npbcSize; } - public boolean isValid() - { + public boolean isPartitionedBroadcastValid() { + return _pbcRef != null && checkPartitionedBroadcastValid(); + } + + public boolean isNonPartitionedBroadcastValid() { + return _npbcRef != null && checkNonPartitionedBroadcastValid(); + } + + private boolean checkNonPartitionedBroadcastValid() { + return _npbcRef.get() != null; + } + + private boolean checkPartitionedBroadcastValid() { //check for evicted soft reference - PartitionedBroadcast<T> pbm = _bcHandle.get(); - if( pbm == null ) + PartitionedBroadcast<T> pbm = _pbcRef.get(); + if (pbm == null) return false; - + //check for validity of individual broadcasts Broadcast<PartitionedBlock<T>>[] tmp = pbm.getBroadcasts(); - for( Broadcast<PartitionedBlock<T>> bc : tmp ) - if( !bc.isValid() ) - return false; + for (Broadcast<PartitionedBlock<T>> bc : tmp) + if (!bc.isValid()) + return false; return true; } }
