You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2021/09/19 10:43:52 UTC
[systemds] branch master updated: [SYSTEMDS-3098] Add synchronization to async. broadcast
This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 05b474c [SYSTEMDS-3098] Add synchronization to async. broadcast
05b474c is described below
commit 05b474c74cb8d8bd1ee1680d92bc65b3ef176220
Author: arnabp <ar...@tugraz.at>
AuthorDate: Sun Sep 19 12:42:29 2021 +0200
[SYSTEMDS-3098] Add synchronization to async. broadcast
This patch wraps the code that creates the partitioned broadcast handle
in a synchronized block, so that the CP and the new early-broadcast
thread do not redundantly partition the same object.
Moreover, this patch fixes a bug in the collection of the broadcast count statistic.
Closes #1393
---
.../context/SparkExecutionContext.java | 108 ++++++++++-----------
.../instructions/cp/TriggerBroadcastTask.java | 4 +-
.../java/org/apache/sysds/utils/Statistics.java | 2 +
3 files changed, 57 insertions(+), 57 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index bb95fe0..880f31f 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -607,6 +607,7 @@ public class SparkExecutionContext extends ExecutionContext
brBlock = cd.getBroadcastHandle().getNonPartitionedBroadcast();
}
+ //TODO: synchronize
if (brBlock == null) {
//create new broadcast handle (never created, evicted)
// account for overwritten invalid broadcast (e.g., evicted)
@@ -651,54 +652,55 @@ public class SparkExecutionContext extends ExecutionContext
PartitionedBroadcast<MatrixBlock> bret = null;
- //reuse existing broadcast handle
- if (mo.getBroadcastHandle() != null && mo.getBroadcastHandle().isPartitionedBroadcastValid()) {
- bret = mo.getBroadcastHandle().getPartitionedBroadcast();
- }
-
- //create new broadcast handle (never created, evicted)
- if (bret == null) {
- //account for overwritten invalid broadcast (e.g., evicted)
- if (mo.getBroadcastHandle() != null)
- CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getSize());
-
- //obtain meta data for matrix
- int blen = (int) mo.getBlocksize();
-
- //create partitioned matrix block and release memory consumed by input
- MatrixBlock mb = mo.acquireRead();
- PartitionedBlock<MatrixBlock> pmb = new PartitionedBlock<>(mb, blen);
- mo.release();
+ synchronized (mo) { //synchronize with the async. broadcast thread
+ //reuse existing broadcast handle
+ if (mo.getBroadcastHandle() != null && mo.getBroadcastHandle().isPartitionedBroadcastValid()) {
+ bret = mo.getBroadcastHandle().getPartitionedBroadcast();
+ }
- //determine coarse-grained partitioning
- int numPerPart = PartitionedBroadcast.computeBlocksPerPartition(mo.getNumRows(), mo.getNumColumns(), blen);
- int numParts = (int) Math.ceil((double) pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() / numPerPart);
- Broadcast<PartitionedBlock<MatrixBlock>>[] ret = new Broadcast[numParts];
+ //create new broadcast handle (never created, evicted)
+ if (bret == null) {
+ //account for overwritten invalid broadcast (e.g., evicted)
+ if (mo.getBroadcastHandle() != null)
+ CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getSize());
+
+ //obtain meta data for matrix
+ int blen = (int) mo.getBlocksize();
+
+ //create partitioned matrix block and release memory consumed by input
+ MatrixBlock mb = mo.acquireRead();
+ PartitionedBlock<MatrixBlock> pmb = new PartitionedBlock<>(mb, blen);
+ mo.release();
+
+ //determine coarse-grained partitioning
+ int numPerPart = PartitionedBroadcast.computeBlocksPerPartition(mo.getNumRows(), mo.getNumColumns(), blen);
+ int numParts = (int) Math.ceil((double) pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() / numPerPart);
+ Broadcast<PartitionedBlock<MatrixBlock>>[] ret = new Broadcast[numParts];
+
+ //create coarse-grained partitioned broadcasts
+ if (numParts > 1) {
+ Arrays.parallelSetAll(ret, i -> createPartitionedBroadcast(pmb, numPerPart, i));
+ } else { //single partition
+ ret[0] = getSparkContext().broadcast(pmb);
+ if (!isLocalMaster())
+ pmb.clearBlocks();
+ }
+
+ bret = new PartitionedBroadcast<>(ret, mo.getDataCharacteristics());
+ // create the broadcast handle if the matrix or frame has never been broadcasted
+ if (mo.getBroadcastHandle() == null) {
+ mo.setBroadcastHandle(new BroadcastObject<MatrixBlock>());
+ }
+ mo.getBroadcastHandle().setPartitionedBroadcast(bret,
+ OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getDataCharacteristics()));
+ CacheableData.addBroadcastSize(mo.getBroadcastHandle().getSize());
- //create coarse-grained partitioned broadcasts
- if (numParts > 1) {
- Arrays.parallelSetAll(ret, i -> createPartitionedBroadcast(pmb, numPerPart, i));
- } else { //single partition
- ret[0] = getSparkContext().broadcast(pmb);
- if (!isLocalMaster())
- pmb.clearBlocks();
- }
-
- bret = new PartitionedBroadcast<>(ret, mo.getDataCharacteristics());
- // create the broadcast handle if the matrix or frame has never been broadcasted
- if (mo.getBroadcastHandle() == null) {
- mo.setBroadcastHandle(new BroadcastObject<MatrixBlock>());
+ if (DMLScript.STATISTICS) {
+ Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
+ Statistics.incSparkBroadcastCount(1);
+ }
}
- mo.getBroadcastHandle().setPartitionedBroadcast(bret,
- OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getDataCharacteristics()));
- CacheableData.addBroadcastSize(mo.getBroadcastHandle().getSize());
- }
-
- if (DMLScript.STATISTICS) {
- Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
- Statistics.incSparkBroadcastCount(1);
}
-
return bret;
}
@@ -753,13 +755,12 @@ public class SparkExecutionContext extends ExecutionContext
to.getBroadcastHandle().setPartitionedBroadcast(bret,
OptimizerUtils.estimatePartitionedSizeExactSparsity(to.getDataCharacteristics()));
CacheableData.addBroadcastSize(to.getBroadcastHandle().getSize());
- }
- if (DMLScript.STATISTICS) {
- Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
- Statistics.incSparkBroadcastCount(1);
+ if (DMLScript.STATISTICS) {
+ Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
+ Statistics.incSparkBroadcastCount(1);
+ }
}
-
return bret;
}
@@ -820,13 +821,12 @@ public class SparkExecutionContext extends ExecutionContext
fo.getBroadcastHandle().setPartitionedBroadcast(bret,
OptimizerUtils.estimatePartitionedSizeExactSparsity(fo.getDataCharacteristics()));
CacheableData.addBroadcastSize(fo.getBroadcastHandle().getSize());
- }
- if (DMLScript.STATISTICS) {
- Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
- Statistics.incSparkBroadcastCount(1);
+ if (DMLScript.STATISTICS) {
+ Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
+ Statistics.incSparkBroadcastCount(1);
+ }
}
-
return bret;
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
index cc1187b..122648e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
@@ -36,9 +36,6 @@ public class TriggerBroadcastTask implements Runnable {
@Override
public void run() {
- // TODO: Synchronization. Although it is harmless if to threads create separate
- // broadcast handles as only one will stay with the MatrixObject. However, redundant
- // partitioning increases untraced memory usage.
try {
SparkExecutionContext sec = (SparkExecutionContext)_ec;
sec.setBroadcastHandle(_broadcastMO);
@@ -47,6 +44,7 @@ public class TriggerBroadcastTask implements Runnable {
e.printStackTrace();
}
+ //TODO: Count only if successful (owned lock)
if (DMLScript.STATISTICS)
Statistics.incSparkAsyncBroadcastCount(1);
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index d91d9c5..b97ae61 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -511,6 +511,8 @@ public class Statistics
parforMergeTime = 0;
sparkCtxCreateTime = 0;
+ sparkBroadcast.reset();
+ sparkBroadcastCount.reset();
sparkAsyncPrefetchCount.reset();
sparkAsyncBroadcastCount.reset();