This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 05b474c [SYSTEMDS-3098] Add synchronization to async. broadcast
05b474c is described below
commit 05b474c74cb8d8bd1ee1680d92bc65b3ef176220
Author: arnabp <[email protected]>
AuthorDate: Sun Sep 19 12:42:29 2021 +0200
[SYSTEMDS-3098] Add synchronization to async. broadcast
This patch wraps the code that creates the partitioned broadcast handle
in a synchronized block. The CP thread and the new early-broadcast thread
therefore no longer both partition the same matrix redundantly.
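Moreover, this patch fixes a bug in the broadcast statistics:
- Broadcast time and count are now recorded only when a new broadcast is
  actually created, not when an existing broadcast handle is reused.
- The broadcast time and count statistics are now reset together with the
  other Spark statistics.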
Closes #1393
---
.../context/SparkExecutionContext.java | 108 ++++++++++-----------
.../instructions/cp/TriggerBroadcastTask.java | 4 +-
.../java/org/apache/sysds/utils/Statistics.java | 2 +
3 files changed, 57 insertions(+), 57 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index bb95fe0..880f31f 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -607,6 +607,7 @@ public class SparkExecutionContext extends ExecutionContext
brBlock =
cd.getBroadcastHandle().getNonPartitionedBroadcast();
}
+ //TODO: synchronize
if (brBlock == null) {
//create new broadcast handle (never created, evicted)
// account for overwritten invalid broadcast (e.g.,
evicted)
@@ -651,54 +652,55 @@ public class SparkExecutionContext extends
ExecutionContext
PartitionedBroadcast<MatrixBlock> bret = null;
- //reuse existing broadcast handle
- if (mo.getBroadcastHandle() != null &&
mo.getBroadcastHandle().isPartitionedBroadcastValid()) {
- bret =
mo.getBroadcastHandle().getPartitionedBroadcast();
- }
-
- //create new broadcast handle (never created, evicted)
- if (bret == null) {
- //account for overwritten invalid broadcast (e.g.,
evicted)
- if (mo.getBroadcastHandle() != null)
-
CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getSize());
-
- //obtain meta data for matrix
- int blen = (int) mo.getBlocksize();
-
- //create partitioned matrix block and release memory
consumed by input
- MatrixBlock mb = mo.acquireRead();
- PartitionedBlock<MatrixBlock> pmb = new
PartitionedBlock<>(mb, blen);
- mo.release();
+ synchronized (mo) { //synchronize with the async. broadcast
thread
+ //reuse existing broadcast handle
+ if (mo.getBroadcastHandle() != null &&
mo.getBroadcastHandle().isPartitionedBroadcastValid()) {
+ bret =
mo.getBroadcastHandle().getPartitionedBroadcast();
+ }
- //determine coarse-grained partitioning
- int numPerPart =
PartitionedBroadcast.computeBlocksPerPartition(mo.getNumRows(),
mo.getNumColumns(), blen);
- int numParts = (int) Math.ceil((double)
pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() / numPerPart);
- Broadcast<PartitionedBlock<MatrixBlock>>[] ret = new
Broadcast[numParts];
+ //create new broadcast handle (never created, evicted)
+ if (bret == null) {
+ //account for overwritten invalid broadcast
(e.g., evicted)
+ if (mo.getBroadcastHandle() != null)
+
CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getSize());
+
+ //obtain meta data for matrix
+ int blen = (int) mo.getBlocksize();
+
+ //create partitioned matrix block and release
memory consumed by input
+ MatrixBlock mb = mo.acquireRead();
+ PartitionedBlock<MatrixBlock> pmb = new
PartitionedBlock<>(mb, blen);
+ mo.release();
+
+ //determine coarse-grained partitioning
+ int numPerPart =
PartitionedBroadcast.computeBlocksPerPartition(mo.getNumRows(),
mo.getNumColumns(), blen);
+ int numParts = (int) Math.ceil((double)
pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() / numPerPart);
+ Broadcast<PartitionedBlock<MatrixBlock>>[] ret
= new Broadcast[numParts];
+
+ //create coarse-grained partitioned broadcasts
+ if (numParts > 1) {
+ Arrays.parallelSetAll(ret, i ->
createPartitionedBroadcast(pmb, numPerPart, i));
+ } else { //single partition
+ ret[0] =
getSparkContext().broadcast(pmb);
+ if (!isLocalMaster())
+ pmb.clearBlocks();
+ }
+
+ bret = new PartitionedBroadcast<>(ret,
mo.getDataCharacteristics());
+ // create the broadcast handle if the matrix or
frame has never been broadcasted
+ if (mo.getBroadcastHandle() == null) {
+ mo.setBroadcastHandle(new
BroadcastObject<MatrixBlock>());
+ }
+
mo.getBroadcastHandle().setPartitionedBroadcast(bret,
+
OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getDataCharacteristics()));
+
CacheableData.addBroadcastSize(mo.getBroadcastHandle().getSize());
- //create coarse-grained partitioned broadcasts
- if (numParts > 1) {
- Arrays.parallelSetAll(ret, i ->
createPartitionedBroadcast(pmb, numPerPart, i));
- } else { //single partition
- ret[0] = getSparkContext().broadcast(pmb);
- if (!isLocalMaster())
- pmb.clearBlocks();
- }
-
- bret = new PartitionedBroadcast<>(ret,
mo.getDataCharacteristics());
- // create the broadcast handle if the matrix or frame
has never been broadcasted
- if (mo.getBroadcastHandle() == null) {
- mo.setBroadcastHandle(new
BroadcastObject<MatrixBlock>());
+ if (DMLScript.STATISTICS) {
+
Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
+ Statistics.incSparkBroadcastCount(1);
+ }
}
- mo.getBroadcastHandle().setPartitionedBroadcast(bret,
-
OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getDataCharacteristics()));
-
CacheableData.addBroadcastSize(mo.getBroadcastHandle().getSize());
- }
-
- if (DMLScript.STATISTICS) {
- Statistics.accSparkBroadCastTime(System.nanoTime() -
t0);
- Statistics.incSparkBroadcastCount(1);
}
-
return bret;
}
@@ -753,13 +755,12 @@ public class SparkExecutionContext extends
ExecutionContext
to.getBroadcastHandle().setPartitionedBroadcast(bret,
OptimizerUtils.estimatePartitionedSizeExactSparsity(to.getDataCharacteristics()));
CacheableData.addBroadcastSize(to.getBroadcastHandle().getSize());
- }
- if (DMLScript.STATISTICS) {
- Statistics.accSparkBroadCastTime(System.nanoTime() -
t0);
- Statistics.incSparkBroadcastCount(1);
+ if (DMLScript.STATISTICS) {
+
Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
+ Statistics.incSparkBroadcastCount(1);
+ }
}
-
return bret;
}
@@ -820,13 +821,12 @@ public class SparkExecutionContext extends
ExecutionContext
fo.getBroadcastHandle().setPartitionedBroadcast(bret,
OptimizerUtils.estimatePartitionedSizeExactSparsity(fo.getDataCharacteristics()));
CacheableData.addBroadcastSize(fo.getBroadcastHandle().getSize());
- }
- if (DMLScript.STATISTICS) {
- Statistics.accSparkBroadCastTime(System.nanoTime() -
t0);
- Statistics.incSparkBroadcastCount(1);
+ if (DMLScript.STATISTICS) {
+
Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
+ Statistics.incSparkBroadcastCount(1);
+ }
}
-
return bret;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
index cc1187b..122648e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
@@ -36,9 +36,6 @@ public class TriggerBroadcastTask implements Runnable {
@Override
public void run() {
- // TODO: Synchronization. Although it is harmless if to threads
create separate
- // broadcast handles as only one will stay with the
MatrixObject. However, redundant
- // partitioning increases untraced memory usage.
try {
SparkExecutionContext sec = (SparkExecutionContext)_ec;
sec.setBroadcastHandle(_broadcastMO);
@@ -47,6 +44,7 @@ public class TriggerBroadcastTask implements Runnable {
e.printStackTrace();
}
+ //TODO: Count only if successful (owned lock)
if (DMLScript.STATISTICS)
Statistics.incSparkAsyncBroadcastCount(1);
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index d91d9c5..b97ae61 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -511,6 +511,8 @@ public class Statistics
parforMergeTime = 0;
sparkCtxCreateTime = 0;
+ sparkBroadcast.reset();
+ sparkBroadcastCount.reset();
sparkAsyncPrefetchCount.reset();
sparkAsyncBroadcastCount.reset();