You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2023/08/08 12:51:14 UTC
[systemds] branch main updated: [SYSTEMDS-3572] CommonThreadPool Reuse ThreadLocal Pools
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 68c2c17db5 [SYSTEMDS-3572] CommonThreadPool Reuse ThreadLocal Pools
68c2c17db5 is described below
commit 68c2c17db56c8c9ea5b1985bbf3b525d4ac8d022
Author: baunsgaard <ba...@tu-berlin.de>
AuthorDate: Mon Aug 7 13:40:42 2023 +0200
[SYSTEMDS-3572] CommonThreadPool Reuse ThreadLocal Pools
This commit allows reuse of custom-size thread pools, but only for the
main thread. Restricting reuse to the main thread avoids problems with
parfor spawning worker threads that would otherwise share the same pool.
I initially tried using ThreadLocal to solve this problem, but while that
worked in practice, it did not work with our testing framework.
This implementation is a compromise that works with the test framework
without introducing too much code.
Closes #1873
---
src/main/java/org/apache/sysds/api/DMLScript.java | 2 +-
.../runtime/controlprogram/ParForProgramBlock.java | 23 +-
.../instructions/cp/BroadcastCPInstruction.java | 7 +-
.../instructions/cp/PrefetchCPInstruction.java | 6 +-
.../spark/AggregateUnarySPInstruction.java | 12 +-
.../spark/CheckpointSPInstruction.java | 6 +-
.../instructions/spark/CpmmSPInstruction.java | 16 +-
.../instructions/spark/MapmmSPInstruction.java | 7 +-
.../instructions/spark/TsmmSPInstruction.java | 12 +-
.../instructions/spark/ZipmmSPInstruction.java | 14 +-
.../runtime/lineage/LineageSparkCacheEviction.java | 13 +-
.../sysds/runtime/util/CommonThreadPool.java | 181 ++++++---
.../sysds/test/component/misc/ThreadPool.java | 408 +++++++++++++++++++++
13 files changed, 593 insertions(+), 114 deletions(-)
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index ff70b330ae..ddc5ee2517 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -573,7 +573,7 @@ public class DMLScript
FederatedData.clearFederatedWorkers();
- //0) shutdown prefetch/broadcast thread pool if necessary
- CommonThreadPool.shutdownAsyncRDDPool();
+ //0) shutdown async (prefetch/broadcast) and cached custom-size thread pools if necessary
+ CommonThreadPool.shutdownAsyncPools();
//1) cleanup scratch space (everything for current uuid)
//(required otherwise export to hdfs would skip assumed unnecessary writes if same name)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 2fc12c4c26..94bbaf2545 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.controlprogram;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
@@ -118,8 +120,9 @@ import java.util.stream.Stream;
* TODO: papply(A,1:2,FUN) language construct (compiled to ParFOR) via DML function repository => modules OK, but second-order functions required
*
*/
-public class ParForProgramBlock extends ForProgramBlock
-{
+public class ParForProgramBlock extends ForProgramBlock {
+ /** Log object */
+ protected static final Log LOG = LogFactory.getLog(ParForProgramBlock.class.getName());
// execution modes
public enum PExecMode {
LOCAL, //local (master) multi-core execution mode
@@ -759,7 +762,7 @@ public class ParForProgramBlock extends ForProgramBlock
LocalTaskQueue<Task> queue = new LocalTaskQueue<>();
Thread[] threads = new Thread[_numThreads];
LocalParWorker[] workers = new LocalParWorker[_numThreads];
- IntStream.range(0, _numThreads).parallel().forEach(i -> {
+ IntStream.range(0, _numThreads).forEach(i -> {
workers[i] = createParallelWorker( _pwIDs[i], queue, ec, i);
threads[i] = new Thread( workers[i] );
threads[i].setPriority(Thread.MAX_PRIORITY);
@@ -1430,9 +1433,14 @@ public class ParForProgramBlock extends ForProgramBlock
}
}
- private void consolidateAndCheckResults(ExecutionContext ec, long expIters, long expTasks, long numIters, long numTasks, LocalVariableMap [] results)
- {
+ private void consolidateAndCheckResults(ExecutionContext ec, final long expIters, final long expTasks,
+ final long numIters, final long numTasks, LocalVariableMap[] results) {
Timing time = new Timing(true);
+
+ //check expected counters
+ if( numTasks != expTasks || numIters !=expIters ) //consistency check
+ throw new DMLRuntimeException("PARFOR: Number of executed tasks does not match the number of created tasks: tasks "+numTasks+"/"+expTasks+", iters "+numIters+"/"+expIters+".");
+
//result merge
if( checkParallelRemoteResultMerge() )
@@ -1531,10 +1539,7 @@ public class ParForProgramBlock extends ForProgramBlock
if( CREATE_UNSCOPED_RESULTVARS && sb != null && ec.getVariables() != null ) //sb might be null for nested parallelism
createEmptyUnscopedVariables( ec.getVariables(), sb );
- //check expected counters
- if( numTasks != expTasks || numIters !=expIters ) //consistency check
- throw new DMLRuntimeException("PARFOR: Number of executed tasks does not match the number of created tasks: tasks "+numTasks+"/"+expTasks+", iters "+numIters+"/"+expIters+".");
-
+
if( DMLScript.STATISTICS )
ParForStatistics.incrementMergeTime((long) time.stop());
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
index aa0be7daec..58378bf7e4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
@@ -19,8 +19,6 @@
package org.apache.sysds.runtime.instructions.cp;
-import java.util.concurrent.Executors;
-
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -43,9 +41,6 @@ public class BroadcastCPInstruction extends UnaryCPInstruction {
@Override
public void processInstruction(ExecutionContext ec) {
ec.setVariable(output.getName(), ec.getMatrixObject(input1));
-
- if (CommonThreadPool.triggerRemoteOPsPool == null)
- CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
- CommonThreadPool.triggerRemoteOPsPool.submit(new TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
+ CommonThreadPool.getDynamicPool().submit(new TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index 96e4b7afe2..233509a5b8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -19,8 +19,6 @@
package org.apache.sysds.runtime.instructions.cp;
-import java.util.concurrent.Executors;
-
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
@@ -53,10 +51,8 @@ public class PrefetchCPInstruction extends UnaryCPInstruction {
// If the next instruction which takes this output as an input comes before
// the prefetch thread triggers, that instruction will start the operations.
// In that case this Prefetch instruction will act like a NOOP.
- if (CommonThreadPool.triggerRemoteOPsPool == null)
- CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
// Saving the lineage item inside the matrix object will replace the pre-attached
// lineage item (e.g. mapmm). Hence, passing separately.
- CommonThreadPool.triggerRemoteOPsPool.submit(new TriggerPrefetchTask(ec.getMatrixObject(output), li));
+ CommonThreadPool.getDynamicPool().submit(new TriggerPrefetchTask(ec.getMatrixObject(output), li));
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 50816aefe4..32b80a2360 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.instructions.spark;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
@@ -49,11 +52,8 @@ import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
public class AggregateUnarySPInstruction extends UnarySPInstruction {
private SparkAggType _aggtype = null;
@@ -115,10 +115,8 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction {
//Trigger the chain of Spark operations and maintain a future to the result
//TODO: Make memory for the future matrix block
try {
- if(CommonThreadPool.triggerRemoteOPsPool == null)
- CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
RDDAggregateTask task = new RDDAggregateTask(_optr, _aop, in, mc);
- Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+ Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : null;
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
index 2be663bdbc..4ee56a7eca 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
@@ -51,8 +51,6 @@ import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;
-import java.util.concurrent.Executors;
-
public class CheckpointSPInstruction extends UnarySPInstruction {
// default storage level
private StorageLevel _level = null;
@@ -86,9 +84,7 @@ public class CheckpointSPInstruction extends UnarySPInstruction {
// TODO: Synchronize. Avoid double execution
ec.setVariable(output.getName(), ec.getCacheableData(input1));
- if (CommonThreadPool.triggerRemoteOPsPool == null)
- CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
- CommonThreadPool.triggerRemoteOPsPool.submit(new TriggerCheckpointTask(ec.getMatrixObject(output)));
+ CommonThreadPool.getDynamicPool().submit(new TriggerCheckpointTask(ec.getMatrixObject(output)));
return;
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 6425613583..602d74a275 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.instructions.spark;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
@@ -47,11 +50,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
/**
* Cpmm: cross-product matrix multiplication operation (distributed matrix multiply
@@ -112,10 +112,8 @@ public class CpmmSPInstruction extends AggregateBinarySPInstruction {
{
if (ConfigurationManager.isMaxPrallelizeEnabled()) {
try {
- if(CommonThreadPool.triggerRemoteOPsPool == null)
- CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
CpmmMatrixVectorTask task = new CpmmMatrixVectorTask(in1, in2);
- Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+ Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : null;
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
}
@@ -147,10 +145,8 @@ public class CpmmSPInstruction extends AggregateBinarySPInstruction {
{
if (ConfigurationManager.isMaxPrallelizeEnabled()) {
try {
- if(CommonThreadPool.triggerRemoteOPsPool == null)
- CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
CpmmMatrixMatrixTask task = new CpmmMatrixMatrixTask(in1, in2, numPartJoin);
- Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+ Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
sec.setMatrixOutput(output.getName(), future_out);
}
catch(Exception ex) { throw new DMLRuntimeException(ex); }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
index b0285b1bba..080de52d23 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.instructions.spark;
import java.util.Iterator;
import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
@@ -59,8 +58,8 @@ import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
-
import org.apache.sysds.runtime.util.CommonThreadPool;
+
import scala.Tuple2;
public class MapmmSPInstruction extends AggregateBinarySPInstruction {
@@ -144,10 +143,8 @@ public class MapmmSPInstruction extends AggregateBinarySPInstruction {
{
if (ConfigurationManager.isMaxPrallelizeEnabled()) {
try {
- if(CommonThreadPool.triggerRemoteOPsPool == null)
- CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
RDDMapmmTask task = new RDDMapmmTask(in1, in2, type);
- Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+ Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : null;
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
index acba784bf2..dd6ddb526d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
@@ -20,6 +20,9 @@
package org.apache.sysds.runtime.instructions.spark;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
@@ -37,11 +40,8 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
public class TsmmSPInstruction extends UnarySPInstruction {
private MMTSJType _type = null;
@@ -72,10 +72,8 @@ public class TsmmSPInstruction extends UnarySPInstruction {
if (ConfigurationManager.isMaxPrallelizeEnabled()) {
try {
- if (CommonThreadPool.triggerRemoteOPsPool == null)
- CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
TsmmTask task = new TsmmTask(in, _type);
- Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+ Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : null;
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
index de7922e25a..18d88178a3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.instructions.spark;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
@@ -42,11 +45,8 @@ import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
public class ZipmmSPInstruction extends BinarySPInstruction {
// internal flag to apply left-transpose rewrite or not
@@ -86,10 +86,8 @@ public class ZipmmSPInstruction extends BinarySPInstruction {
if (ConfigurationManager.isMaxPrallelizeEnabled()) {
try {
- if (CommonThreadPool.triggerRemoteOPsPool == null)
- CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
ZipmmTask task = new ZipmmTask(in1, in2, _tRewrite);
- Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+ Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : null;
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
index 65648508f8..84bb8598c0 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
@@ -19,6 +19,10 @@
package org.apache.sysds.runtime.lineage;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.TreeSet;
+
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
@@ -27,11 +31,6 @@ import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
import org.apache.sysds.runtime.util.CommonThreadPool;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.TreeSet;
-import java.util.concurrent.Executors;
-
public class LineageSparkCacheEviction
{
private static long SPARK_STORAGE_LIMIT = 0; //60% (upper limit of Spark unified memory)
@@ -212,9 +211,7 @@ public class LineageSparkCacheEviction
int localHitCount = RDDHitCountLocal.get(e._key);
if (localHitCount > 3) {
RDDHitCountLocal.remove(e._key);
- if (CommonThreadPool.triggerRemoteOPsPool == null)
- CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
- CommonThreadPool.triggerRemoteOPsPool.submit(new TriggerRemoteTask(e.getRDDObject().getRDD()));
+ CommonThreadPool.getDynamicPool().submit(new TriggerRemoteTask(e.getRDDObject().getRDD()));
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
index f96c4cc4af..cc6483d258 100644
--- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
@@ -30,44 +30,138 @@ import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
-import org.apache.commons.lang3.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
/**
- * This common thread pool provides an abstraction to obtain a shared
- * thread pool, specifically the ForkJoinPool.commonPool, for all requests
- * of the maximum degree of parallelism. If pools of different size are
- * requested, we create new pool instances of FixedThreadPool.
+ * This common thread pool provides an abstraction to obtain a shared thread pool.
+ *
+ * If the number of logical cores is requested, the ForkJoinPool.commonPool is returned on all requests.
+ *
+ * If pools of different size are requested, we create new pool instances of FixedThreadPool, unless we currently are on
+ * the main thread; then we return a shared instance sized by the first requested number of threads.
+ *
+ * Additionally, the class contains a dynamic thread pool that is intended for asynchronous long running tasks with
+ * low compute overhead, such as broadcast and collect from federated workers.
*/
-public class CommonThreadPool implements ExecutorService
-{
- //shared thread pool used system-wide, potentially by concurrent parfor workers
- //we use the ForkJoinPool.commonPool() to avoid explicit cleanup, including
- //unnecessary initialization (e.g., problematic in jmlc) and because this commonPool
- //resulted in better performance than a dedicated fixed thread pool.
+public class CommonThreadPool implements ExecutorService {
+ /** Log object */
+ protected static final Log LOG = LogFactory.getLog(CommonThreadPool.class.getName());
+
+ /** The number of threads of the machine */
private static final int size = InfrastructureAnalyzer.getLocalParallelism();
+ /**
+ * Shared thread pool used system-wide, potentially by concurrent parfor workers
+ *
+ * we use the ForkJoinPool.commonPool() to avoid explicit cleanup, including unnecessary initialization (e.g.,
+ * problematic in jmlc) and because this commonPool resulted in better performance than a dedicated fixed thread
+ * pool.
+ */
private static final ExecutorService shared = ForkJoinPool.commonPool();
+ /** A secondary executor with a custom number of threads, cached only for requests from the main thread */
+ private static volatile CommonThreadPool shared2 = null;
+ /** The number of threads used in the custom secondary executor (-1 if none is cached) */
+ private static volatile int shared2K = -1;
+ /** Dynamic thread pool, that dynamically allocates threads as tasks come in. */
+ private static volatile ExecutorService asyncPool = null;
+ /** This common thread pool */
private final ExecutorService _pool;
- public static ExecutorService triggerRemoteOPsPool = null;
+ /**
+ * Constructor of the threadPool.
+ * This is intended not to be used except for tests.
+ * Please use the static constructors.
+ *
+ * @param pool The thread pool instance to use.
+ */
public CommonThreadPool(ExecutorService pool) {
- _pool = pool;
- }
-
+ this._pool = pool;
+ }
+
+ /**
+ * Get the shared Executor thread pool, that has the number of threads of the host system
+ *
+ * @return An ExecutorService
+ */
+ public static ExecutorService get() {
+ return shared;
+ }
+
+ /**
+ * Get an Executor thread pool, that has the number of threads specified in k.
+ *
+ * If k equals the number of logical cores, the shared common pool is returned. On the main thread, the first
+ * custom-size pool is cached and returned again for later requests of the same size; all other requests get a new
+ * FixedThreadPool (wrapped in a CommonThreadPool) that should be shut down after use.
+ *
+ * @param k The number of threads wanted
+ * @return The executor with specified parallelism
+ */
public static ExecutorService get(int k) {
- return new CommonThreadPool( (size==k) ?
- shared : Executors.newFixedThreadPool(k));
- }
-
+ if(size == k)
+ return shared;
+ else if(isMainThread())
+ return getMainThreadPool(k);
+ else
+ return new CommonThreadPool(Executors.newFixedThreadPool(k));
+ }
+
+ /**
+ * Get the cached custom-size pool of the main thread, creating it on the first request. If a pool with a different
+ * parallelism is already cached, a new uncached pool is returned instead.
+ *
+ * Synchronized with shutdownAsyncPools, which resets the cached pool.
+ *
+ * @param k The number of threads wanted
+ * @return The cached pool if its parallelism is k, otherwise a new pool
+ */
+ private static synchronized ExecutorService getMainThreadPool(int k) {
+ if(shared2 == null) {
+ shared2 = new CommonThreadPool(Executors.newFixedThreadPool(k));
+ shared2K = k;
+ }
+ return shared2K == k ? shared2 : new CommonThreadPool(Executors.newFixedThreadPool(k));
+ }
+
+ /**
+ * Only the main thread reuses the cached custom-size pool, which avoids parfor workers sharing one pool. The main
+ * thread is identified by its name.
+ *
+ * @return True if the current thread is named "main"
+ */
+ private static boolean isMainThread() {
+ return "main".equals(Thread.currentThread().getName());
+ }
+
+ /**
+ * Check if a request for parallelism k is served by a shared pool; also true while no custom-size pool is cached.
+ *
+ * Note that the calling thread is not considered, while get(int) only reuses the cached pool on the main thread.
+ *
+ * @param k the parallelism
+ * @return True if k matches the host parallelism or the cached pool's parallelism, or if no pool is cached yet.
+ */
+ public static boolean isSharedTPThreads(int k) {
+ return size == k || shared2K == k || shared2K == -1;
+ }
+
+ /**
+ * Invoke the collection of tasks and shutdown the pool upon job termination.
+ *
+ * @param <T> The type of class to return from the job
+ * @param pool The pool to execute in
+ * @param tasks The tasks to execute
+ */
public static <T> void invokeAndShutdown(ExecutorService pool, Collection<? extends Callable<T>> tasks) {
try {
- //execute tasks
+ // execute tasks
List<Future<T>> ret = pool.invokeAll(tasks);
- //check for errors and exceptions
- for( Future<T> r : ret )
+ // check for errors and exceptions
+ for(Future<T> r : ret)
r.get();
- //shutdown pool
+ // shutdown pool
pool.shutdown();
}
catch(Exception ex) {
@@ -75,28 +169,52 @@ public class CommonThreadPool implements ExecutorService
}
}
- public static void shutdownShared() {
- shared.shutdownNow();
+ /**
+ * Get a dynamic thread pool that allocates threads as the requests are made. This pool is intended for async remote
+ * calls that do not depend on local compute.
+ *
+ * @return A dynamic thread pool.
+ */
+ public static synchronized ExecutorService getDynamicPool() {
+ if(asyncPool != null)
+ return asyncPool;
+ else {
+ asyncPool = Executors.newCachedThreadPool();
+ return asyncPool;
+ }
}
- public static void shutdownAsyncRDDPool() {
- if (triggerRemoteOPsPool != null) {
- //shutdown prefetch/broadcast thread pool
- triggerRemoteOPsPool.shutdown();
- triggerRemoteOPsPool = null;
+ /**
+ * Shut down the dynamic pool and the cached custom-size pool of the main thread.
+ */
+ public static synchronized void shutdownAsyncPools() {
+ if(asyncPool != null) {
+ // shutdown prefetch/broadcast thread pool
+ asyncPool.shutdown();
+ asyncPool = null;
+ }
+ if(shared2 != null) {
+ // shut down the wrapped pool directly, because shutdown() is a no-op on cached pools
+ shared2._pool.shutdown();
+ shared2 = null;
+ shared2K = -1;
}
}
+ /** @return True if this pool is shared (common or cached custom-size pool), and thus not shut down by shutdown() */
+ public final boolean isCached() {
+ return _pool.equals(shared) || this.equals(shared2);
+ }
+
@Override
public void shutdown() {
- if( _pool != shared )
+ if(!isCached())
_pool.shutdown();
}
@Override
public List<Runnable> shutdownNow() {
- return ( _pool != shared ) ?
- _pool.shutdownNow() : null;
+ return !isCached() ? _pool.shutdownNow() : null;
}
@Override
@@ -106,10 +224,10 @@ public class CommonThreadPool implements ExecutorService
@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
- throws InterruptedException {
+ throws InterruptedException {
return _pool.invokeAll(tasks, timeout, unit);
}
-
+
@Override
public void execute(Runnable command) {
_pool.execute(command);
@@ -130,31 +248,30 @@ public class CommonThreadPool implements ExecutorService
return _pool.submit(task);
}
-
- //unnecessary methods required for API compliance
+ // Shared (cached) pools ignore shutdown(), hence they report themselves as shut down and terminated.
@Override
public boolean isShutdown() {
- throw new NotImplementedException();
+ return isCached() || _pool.isShutdown();
}
@Override
public boolean isTerminated() {
- throw new NotImplementedException();
+ return isCached() || _pool.isTerminated();
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
- throw new NotImplementedException();
+ return isCached() || _pool.awaitTermination(timeout, unit);
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, ExecutionException {
- throw new NotImplementedException();
+ return _pool.invokeAny(tasks);
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
- throws InterruptedException, ExecutionException, TimeoutException {
- throw new NotImplementedException();
+ throws InterruptedException, ExecutionException, TimeoutException {
+ return _pool.invokeAny(tasks, timeout, unit);
}
}
diff --git a/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java b/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java
new file mode 100644
index 0000000000..ca79e8800b
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.misc;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.junit.Test;
+
+public class ThreadPool {
+ protected static final Log LOG = LogFactory.getLog(ThreadPool.class.getName());
+
+ @Test
+ public void testGetTheSame() {
+ CommonThreadPool.shutdownAsyncPools();
+ ExecutorService x = CommonThreadPool.get();
+ ExecutorService y = CommonThreadPool.get();
+ x.shutdown();
+ y.shutdown();
+
+ assertEquals(x, y);
+ CommonThreadPool.shutdownAsyncPools();
+ CommonThreadPool.shutdownAsyncPools();
+
+ }
+
+ @Test
+ public void testGetSameCustomThreadCount() {
+ CommonThreadPool.shutdownAsyncPools();
+ // choosing 7 because the machine is unlikely to have 7 logical cores.
+ String name = Thread.currentThread().getName();
+ Thread.currentThread().setName("main");
+ ExecutorService x = CommonThreadPool.get(7);
+ ExecutorService y = CommonThreadPool.get(7);
+ x.shutdown();
+ y.shutdown();
+
+ Thread.currentThread().setName(name);
+ assertEquals(x, y);
+ CommonThreadPool.shutdownAsyncPools();
+ CommonThreadPool.shutdownAsyncPools();
+
+ }
+
+ @Test
+ public void testGetSameCustomThreadCountExecute() throws InterruptedException, ExecutionException {
+ // choosing 7 because the machine is unlikely to have 7 logical cores.
+ CommonThreadPool.shutdownAsyncPools();
+ String name = Thread.currentThread().getName();
+ Thread.currentThread().setName("main");
+ ExecutorService x = CommonThreadPool.get(7);
+ ExecutorService y = CommonThreadPool.get(7);
+ assertEquals(x, y);
+ int v = x.submit(() -> 5).get();
+ x.shutdown();
+ int v2 = y.submit(() -> 5).get();
+ y.shutdown();
+
+ Thread.currentThread().setName(name);
+ assertEquals(x, y);
+ assertEquals(v, v2);
+ CommonThreadPool.shutdownAsyncPools();
+ }
+
+ @Test
+ public void testGetSameCustomThreadCountExecuteV2() throws InterruptedException, ExecutionException {
+ // choosing 7 because the machine is unlikely to have 7 logical cores.
+ String name = Thread.currentThread().getName();
+ Thread.currentThread().setName("main");
+ ExecutorService x = CommonThreadPool.get(7);
+ ExecutorService y = CommonThreadPool.get(7);
+ assertEquals(x, y);
+ int v = x.submit(() -> 5).get();
+ int v2 = y.submit(() -> 5).get();
+ x.shutdown();
+ y.shutdown();
+
+ Thread.currentThread().setName(name);
+ assertEquals(x, y);
+ assertEquals(v, v2);
+ CommonThreadPool.shutdownAsyncPools();
+ }
+
+ @Test
+ public void testGetSameCustomThreadCountExecuteV3() throws InterruptedException, ExecutionException {
+ // choosing 7 because the machine is unlikely to have 7 logical cores.
+ String name = Thread.currentThread().getName();
+ Thread.currentThread().setName("main");
+ ExecutorService x = CommonThreadPool.get(7);
+ ExecutorService y = CommonThreadPool.get(7);
+ assertEquals(x, y);
+ x.shutdown();
+ y.shutdown();
+ int v = x.submit(() -> 5).get();
+ int v2 = y.submit(() -> 5).get();
+
+ Thread.currentThread().setName(name);
+ assertEquals(x, y);
+ assertEquals(v, v2);
+ CommonThreadPool.shutdownAsyncPools();
+ }
+
+ @Test
+ public void testGetSameCustomThreadCountExecuteV4() throws InterruptedException, ExecutionException {
+ // choosing 7 because the machine is unlikely to have 7 logical cores.
+ String name = Thread.currentThread().getName();
+ Thread.currentThread().setName("main");
+ CommonThreadPool.shutdownAsyncPools();
+ ExecutorService x = CommonThreadPool.get(5);
+ ExecutorService y = CommonThreadPool.get(7);
+ assertNotEquals(x, y);
+ x.shutdown();
+ int v = x.submit(() -> 5).get();
+ int v2 = y.submit(() -> 5).get();
+ y.shutdown();
+
+ Thread.currentThread().setName(name);
+ assertEquals(v, v2);
+ CommonThreadPool.shutdownAsyncPools();
+ }
+
+ @Test
+ public void testFromOtherThread() throws InterruptedException, ExecutionException {
+ CommonThreadPool.shutdownAsyncPools();
+ ExecutorService x = CommonThreadPool.get(5);
+ Future<ExecutorService> a = x.submit(() -> CommonThreadPool.get(5));
+ ExecutorService y = a.get();
+ assertNotEquals(x, y);
+ CommonThreadPool.shutdownAsyncPools();
+ }
+
+ @Test
+ public void testFromOtherThreadInfrastructureParallelism() throws InterruptedException, ExecutionException {
+ CommonThreadPool.shutdownAsyncPools();
+ final int k = InfrastructureAnalyzer.getLocalParallelism();
+ ExecutorService x = CommonThreadPool.get(k);
+ Future<ExecutorService> a = x.submit(() -> CommonThreadPool.get(k));
+ ExecutorService y = a.get();
+ assertEquals(x, y);
+ CommonThreadPool.shutdownAsyncPools();
+ }
+
+ @Test
+ public void dynamic() throws InterruptedException, ExecutionException {
+ CommonThreadPool.shutdownAsyncPools();
+ final int k = InfrastructureAnalyzer.getLocalParallelism();
+ ExecutorService x = CommonThreadPool.getDynamicPool();
+ Future<ExecutorService> a = x.submit(() -> CommonThreadPool.get(k));
+ ExecutorService y = a.get();
+ assertNotEquals(x, y);
+ CommonThreadPool.shutdownAsyncPools();
+ }
+
+ @Test
+ public void dynamicSame() throws InterruptedException, ExecutionException {
+ CommonThreadPool.shutdownAsyncPools();
+ ExecutorService x = CommonThreadPool.getDynamicPool();
+ ExecutorService y = CommonThreadPool.getDynamicPool();
+ assertEquals(x, y);
+ CommonThreadPool.shutdownAsyncPools();
+ }
+
+ @Test
+ public void isSharedTPThreads() throws InterruptedException, ExecutionException {
+ CommonThreadPool.shutdownAsyncPools();
+ for(int i = 0; i < 10; i++)
+ assertTrue(CommonThreadPool.isSharedTPThreads(i));
+
+ CommonThreadPool.shutdownAsyncPools();
+ }
+
+ @Test
+ public void isSharedTPThreadsCommonSize() throws InterruptedException, ExecutionException {
+ CommonThreadPool.shutdownAsyncPools();
+ assertTrue(CommonThreadPool.isSharedTPThreads(InfrastructureAnalyzer.getLocalParallelism()));
+ CommonThreadPool.shutdownAsyncPools();
+ }
+
+ @Test
+ public void isSharedTPThreadsFalse() throws InterruptedException, ExecutionException {
+ CommonThreadPool.shutdownAsyncPools();
+ String name = Thread.currentThread().getName();
+ Thread.currentThread().setName("main");
+ CommonThreadPool.get(18);
+ for(int i = 1; i < 10; i++)
+ if(i != InfrastructureAnalyzer.getLocalParallelism())
+ assertFalse("" + i, CommonThreadPool.isSharedTPThreads(i));
+ assertTrue(CommonThreadPool.isSharedTPThreads(18));
+ assertFalse(CommonThreadPool.isSharedTPThreads(19));
+
+ Thread.currentThread().setName(name);
+ CommonThreadPool.shutdownAsyncPools();
+ }
+
+ @Test
+ public void justWorks() throws InterruptedException, ExecutionException {
+
+ String name = Thread.currentThread().getName();
+ Thread.currentThread().setName("main");
+ for(int j = 0; j < 2; j++) {
+ for(int i = 4; i < 17; i++) {
+ ExecutorService p = CommonThreadPool.get(i);
+ final Integer l = i;
+ assertEquals(l, p.submit(() -> l).get());
+ p.shutdown();
+ }
+ }
+ Thread.currentThread().setName(name);
+ }
+
+ @Test
+ public void justWorksNotMain() throws InterruptedException, ExecutionException {
+
+ for(int j = 0; j < 2; j++) {
+
+ for(int i = 4; i < 10; i++) {
+ ExecutorService p = CommonThreadPool.get(i);
+ final Integer l = i;
+ assertEquals(l, p.submit(() -> l).get());
+ p.shutdown();
+
+ }
+ }
+ }
+
+ @Test
+ public void justWorksShutdownNow() throws InterruptedException, ExecutionException {
+
+ String name = Thread.currentThread().getName();
+ Thread.currentThread().setName("main");
+ for(int j = 0; j < 2; j++) {
+
+ for(int i = 4; i < 16; i++) {
+ ExecutorService p = CommonThreadPool.get(i);
+ final Integer l = i;
+ assertEquals(l, p.submit(() -> l).get());
+ p.shutdownNow();
+
+ }
+ }
+ Thread.currentThread().setName(name);
+ }
+
+ @Test
+ public void justWorksShutdownNowNotMain() throws InterruptedException, ExecutionException {
+
+ for(int j = 0; j < 2; j++) {
+
+ for(int i = 4; i < 16; i++) {
+ ExecutorService p = CommonThreadPool.get(i);
+ final Integer l = i;
+ assertEquals(l, p.submit(() -> l).get());
+ p.shutdownNow();
+
+ }
+ }
+ }
+
+ @Test
+ public void mock1() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException,
+ InterruptedException, ExecutionException, TimeoutException {
+
+ ExecutorService p = mock(ExecutorService.class);
+ ExecutorService c = new CommonThreadPool(p);
+
+ when(p.shutdownNow()).thenReturn(null);
+ assertNull(c.shutdownNow());
+
+ Collection<Callable<Integer>> cc = (Collection<Callable<Integer>>) null;
+ when(p.invokeAll(cc)).thenReturn(null);
+ assertNull(c.invokeAll(cc));
+ when(p.invokeAll(cc, 1L, TimeUnit.DAYS)).thenReturn(null);
+ assertNull(c.invokeAll(cc, 1, TimeUnit.DAYS));
+ doNothing().when(p).execute((Runnable) null);
+ c.execute((Runnable) null);
+
+ when(p.submit((Callable<Integer>) null)).thenReturn(null);
+ assertNull(c.submit((Callable<Integer>) null));
+
+ when(p.submit((Runnable) null, null)).thenReturn(null);
+ assertNull(c.submit((Runnable) null, null));
+ // when(tp.pool()).thenReturn(p);
+
+ when(p.submit((Runnable) null)).thenReturn(null);
+ assertNull(c.submit((Runnable) null));
+
+ when(p.isShutdown()).thenReturn(false);
+ assertFalse(c.isShutdown());
+ when(p.isShutdown()).thenReturn(true);
+ assertTrue(c.isShutdown());
+
+ when(p.isTerminated()).thenReturn(false);
+ assertFalse(c.isTerminated());
+ when(p.isTerminated()).thenReturn(true);
+ assertTrue(c.isTerminated());
+
+ when(p.awaitTermination(10, TimeUnit.DAYS)).thenReturn(false);
+ assertFalse(c.awaitTermination(10, TimeUnit.DAYS));
+ when(p.awaitTermination(10, TimeUnit.DAYS)).thenReturn(true);
+ assertTrue(c.awaitTermination(10, TimeUnit.DAYS));
+
+ when(p.invokeAny(cc)).thenReturn(null);
+ assertNull(c.invokeAny(cc));
+ when(p.invokeAny(cc, 1L, TimeUnit.DAYS)).thenReturn(null);
+ assertNull(c.invokeAny(cc, 1, TimeUnit.DAYS));
+ doNothing().when(p).execute((Runnable) null);
+ c.execute((Runnable) null);
+
+ }
+
+ @Test
+ public void mock2() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException,
+ InterruptedException, ExecutionException, TimeoutException {
+
+ CommonThreadPool p = mock(CommonThreadPool.class);
+ when(p.isShutdown()).thenCallRealMethod();
+ when(p.isTerminated()).thenCallRealMethod();
+ when(p.awaitTermination(10, TimeUnit.DAYS)).thenCallRealMethod();
+ when(p.isCached()).thenReturn(true);
+ assertTrue(p.isShutdown());
+ assertTrue(p.isTerminated());
+ assertTrue(p.awaitTermination(10, TimeUnit.DAYS));
+ }
+
+ @Test
+ public void coverEdge() {
+ ExecutorService a = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
+ assertTrue(new CommonThreadPool(a).isCached());
+ }
+
+ @Test(expected = DMLRuntimeException.class)
+ public void invokeAndShutdownException() throws InterruptedException {
+ ExecutorService p = mock(ExecutorService.class);
+ ExecutorService c = new CommonThreadPool(p);
+
+ when(p.invokeAll(null)).thenThrow(new RuntimeException("Test"));
+
+ CommonThreadPool.invokeAndShutdown(p, null);
+
+ }
+
+ @Test
+ public void invokeAndShutdown() throws InterruptedException {
+
+ ExecutorService p = mock(ExecutorService.class);
+ ExecutorService c = new CommonThreadPool(p);
+
+ Collection<Callable<Integer>> cc = (Collection<Callable<Integer>>) null;
+ when(p.invokeAll(cc)).thenReturn(new ArrayList<Future<Integer>>());
+
+ CommonThreadPool.invokeAndShutdown(c, null);
+
+ }
+
+ @Test
+ @SuppressWarnings("all")
+ public void invokeAndShutdownV2() throws InterruptedException{
+
+ ExecutorService p = mock(ExecutorService.class);
+ ExecutorService c = new CommonThreadPool(p);
+
+ Collection<Callable<Integer>> cc = (Collection<Callable<Integer>>) null;
+ List<Future<Integer>> f = new ArrayList<Future<Integer>>();
+ f.add(mock(Future.class));
+ when(p.invokeAll(cc)).thenReturn(f );
+
+ CommonThreadPool.invokeAndShutdown(c, null);
+
+ }
+}