You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2023/08/08 12:51:14 UTC

[systemds] branch main updated: [SYSTEMDS-3572] CommonThreadPool Reuse ThreadLocal Pools

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 68c2c17db5 [SYSTEMDS-3572] CommonThreadPool Reuse ThreadLocal Pools
68c2c17db5 is described below

commit 68c2c17db56c8c9ea5b1985bbf3b525d4ac8d022
Author: baunsgaard <ba...@tu-berlin.de>
AuthorDate: Mon Aug 7 13:40:42 2023 +0200

    [SYSTEMDS-3572] CommonThreadPool Reuse ThreadLocal Pools
    
    This commit allows reuse of custom size thread pools, but only for the
    main thread. Only allowing the main thread to reuse a pool avoids
    problems with parfor spawning threads that use the same shared pool.
    I tried using ThreadLocal to solve this problem initially, but this
    did not work with our testing framework while it did work in practice.
    This implementation is a compromise to work with the test framework,
    while not introducing too much code.
    
    Closes #1873
---
 src/main/java/org/apache/sysds/api/DMLScript.java  |   2 +-
 .../runtime/controlprogram/ParForProgramBlock.java |  23 +-
 .../instructions/cp/BroadcastCPInstruction.java    |   7 +-
 .../instructions/cp/PrefetchCPInstruction.java     |   6 +-
 .../spark/AggregateUnarySPInstruction.java         |  12 +-
 .../spark/CheckpointSPInstruction.java             |   6 +-
 .../instructions/spark/CpmmSPInstruction.java      |  16 +-
 .../instructions/spark/MapmmSPInstruction.java     |   7 +-
 .../instructions/spark/TsmmSPInstruction.java      |  12 +-
 .../instructions/spark/ZipmmSPInstruction.java     |  14 +-
 .../runtime/lineage/LineageSparkCacheEviction.java |  13 +-
 .../sysds/runtime/util/CommonThreadPool.java       | 181 ++++++---
 .../sysds/test/component/misc/ThreadPool.java      | 408 +++++++++++++++++++++
 13 files changed, 593 insertions(+), 114 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index ff70b330ae..ddc5ee2517 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -573,7 +573,7 @@ public class DMLScript
 		FederatedData.clearFederatedWorkers();
 		
 		//0) shutdown prefetch/broadcast thread pool if necessary
-		CommonThreadPool.shutdownAsyncRDDPool();
+		CommonThreadPool.shutdownAsyncPools();
 
 		//1) cleanup scratch space (everything for current uuid)
 		//(required otherwise export to hdfs would skip assumed unnecessary writes if same name)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 2fc12c4c26..94bbaf2545 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysds.runtime.controlprogram;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.DataType;
@@ -89,6 +91,7 @@ import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.CollectionUtils;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.ProgramConverter;
 import org.apache.sysds.runtime.util.UtilFunctions;
 import org.apache.sysds.utils.stats.ParForStatistics;
@@ -118,8 +121,8 @@ import java.util.stream.Stream;
  * TODO: papply(A,1:2,FUN) language construct (compiled to ParFOR) via DML function repository =&gt; modules OK, but second-order functions required
  *
  */
-public class ParForProgramBlock extends ForProgramBlock 
-{	
+public class ParForProgramBlock extends ForProgramBlock {	
+	protected static final Log LOG = LogFactory.getLog(CommonThreadPool.class.getName());
 	// execution modes
 	public enum PExecMode {
 		LOCAL,          //local (master) multi-core execution mode
@@ -759,7 +762,7 @@ public class ParForProgramBlock extends ForProgramBlock
 			LocalTaskQueue<Task> queue = new LocalTaskQueue<>();
 			Thread[] threads         = new Thread[_numThreads];
 			LocalParWorker[] workers = new LocalParWorker[_numThreads];
-			IntStream.range(0, _numThreads).parallel().forEach(i -> {
+			IntStream.range(0, _numThreads).forEach(i -> {
 				workers[i] = createParallelWorker( _pwIDs[i], queue, ec, i);
 				threads[i] = new Thread( workers[i] );
 				threads[i].setPriority(Thread.MAX_PRIORITY);
@@ -1430,9 +1433,14 @@ public class ParForProgramBlock extends ForProgramBlock
 		}
 	}
 
-	private void consolidateAndCheckResults(ExecutionContext ec, long expIters, long expTasks, long numIters, long numTasks, LocalVariableMap [] results) 
-	{
+	private void consolidateAndCheckResults(ExecutionContext ec, final long expIters, final long expTasks,
+		final long numIters, final long numTasks, LocalVariableMap[] results) {
 		Timing time = new Timing(true);
+
+		//check expected counters
+		if( numTasks != expTasks || numIters !=expIters ) //consistency check
+			throw new DMLRuntimeException("PARFOR: Number of executed tasks does not match the number of created tasks: tasks "+numTasks+"/"+expTasks+", iters "+numIters+"/"+expIters+".");
+	
 		
 		//result merge
 		if( checkParallelRemoteResultMerge() )
@@ -1531,10 +1539,7 @@ public class ParForProgramBlock extends ForProgramBlock
 		if( CREATE_UNSCOPED_RESULTVARS && sb != null && ec.getVariables() != null ) //sb might be null for nested parallelism
 			createEmptyUnscopedVariables( ec.getVariables(), sb );
 		
-		//check expected counters
-		if( numTasks != expTasks || numIters !=expIters ) //consistency check
-			throw new DMLRuntimeException("PARFOR: Number of executed tasks does not match the number of created tasks: tasks "+numTasks+"/"+expTasks+", iters "+numIters+"/"+expIters+".");
-	
+			
 		if( DMLScript.STATISTICS )
 			ParForStatistics.incrementMergeTime((long) time.stop());
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
index aa0be7daec..58378bf7e4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import java.util.concurrent.Executors;
-
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -43,9 +41,6 @@ public class BroadcastCPInstruction extends UnaryCPInstruction {
 	@Override
 	public void processInstruction(ExecutionContext ec) {
 		ec.setVariable(output.getName(), ec.getMatrixObject(input1));
-
-		if (CommonThreadPool.triggerRemoteOPsPool == null)
-			CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
-		CommonThreadPool.triggerRemoteOPsPool.submit(new TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
+		CommonThreadPool.getDynamicPool().submit(new TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index 96e4b7afe2..233509a5b8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import java.util.concurrent.Executors;
-
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
@@ -53,10 +51,8 @@ public class PrefetchCPInstruction extends UnaryCPInstruction {
 		// If the next instruction which takes this output as an input comes before
 		// the prefetch thread triggers, that instruction will start the operations.
 		// In that case this Prefetch instruction will act like a NOOP. 
-		if (CommonThreadPool.triggerRemoteOPsPool == null)
-			CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
 		// Saving the lineage item inside the matrix object will replace the pre-attached
 		// lineage item (e.g. mapmm). Hence, passing separately.
-		CommonThreadPool.triggerRemoteOPsPool.submit(new TriggerPrefetchTask(ec.getMatrixObject(output), li));
+		CommonThreadPool.getDynamicPool().submit(new TriggerPrefetchTask(ec.getMatrixObject(output), li));
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 50816aefe4..32b80a2360 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -19,6 +19,9 @@
 
 package org.apache.sysds.runtime.instructions.spark;
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
@@ -49,11 +52,8 @@ import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
 
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
 
 public class AggregateUnarySPInstruction extends UnarySPInstruction {
 	private SparkAggType _aggtype = null;
@@ -115,10 +115,8 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction {
 				//Trigger the chain of Spark operations and maintain a future to the result
 				//TODO: Make memory for the future matrix block
 				try {
-					if(CommonThreadPool.triggerRemoteOPsPool == null)
-						CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
 					RDDAggregateTask task = new RDDAggregateTask(_optr, _aop, in, mc);
-					Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+					Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
 					LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : null;
 					sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
 				}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
index 2be663bdbc..4ee56a7eca 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
@@ -51,8 +51,6 @@ import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.UtilFunctions;
 import org.apache.sysds.utils.Statistics;
 
-import java.util.concurrent.Executors;
-
 public class CheckpointSPInstruction extends UnarySPInstruction {
 	// default storage level
 	private StorageLevel _level = null;
@@ -86,9 +84,7 @@ public class CheckpointSPInstruction extends UnarySPInstruction {
 			// TODO: Synchronize. Avoid double execution
 			ec.setVariable(output.getName(), ec.getCacheableData(input1));
 
-			if (CommonThreadPool.triggerRemoteOPsPool == null)
-				CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
-			CommonThreadPool.triggerRemoteOPsPool.submit(new TriggerCheckpointTask(ec.getMatrixObject(output)));
+			CommonThreadPool.getDynamicPool().submit(new TriggerCheckpointTask(ec.getMatrixObject(output)));
 			return;
 		}
 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 6425613583..602d74a275 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -19,6 +19,9 @@
 
 package org.apache.sysds.runtime.instructions.spark;
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
@@ -47,11 +50,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
 
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
 
 /**
  * Cpmm: cross-product matrix multiplication operation (distributed matrix multiply
@@ -112,10 +112,8 @@ public class CpmmSPInstruction extends AggregateBinarySPInstruction {
 		{
 			if (ConfigurationManager.isMaxPrallelizeEnabled()) {
 				try {
-					if(CommonThreadPool.triggerRemoteOPsPool == null)
-						CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
 					CpmmMatrixVectorTask task = new CpmmMatrixVectorTask(in1, in2);
-					Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+					Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
 					LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : null;
 					sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
 				}
@@ -147,10 +145,8 @@ public class CpmmSPInstruction extends AggregateBinarySPInstruction {
 			{
 				if (ConfigurationManager.isMaxPrallelizeEnabled()) {
 					try {
-						if(CommonThreadPool.triggerRemoteOPsPool == null)
-							CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
 						CpmmMatrixMatrixTask task = new CpmmMatrixMatrixTask(in1, in2, numPartJoin);
-						Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+						Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
 						sec.setMatrixOutput(output.getName(), future_out);
 					}
 					catch(Exception ex) { throw new DMLRuntimeException(ex); }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
index b0285b1bba..080de52d23 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.instructions.spark;
 
 import java.util.Iterator;
 import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.stream.IntStream;
 
@@ -59,8 +58,8 @@ import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
-
 import org.apache.sysds.runtime.util.CommonThreadPool;
+
 import scala.Tuple2;
 
 public class MapmmSPInstruction extends AggregateBinarySPInstruction {
@@ -144,10 +143,8 @@ public class MapmmSPInstruction extends AggregateBinarySPInstruction {
 		{
 			if (ConfigurationManager.isMaxPrallelizeEnabled()) {
 				try {
-					if(CommonThreadPool.triggerRemoteOPsPool == null)
-						CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
 					RDDMapmmTask task = new  RDDMapmmTask(in1, in2, type);
-					Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+					Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
 					LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : null;
 					sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
 				}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
index acba784bf2..dd6ddb526d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
@@ -20,6 +20,9 @@
 package org.apache.sysds.runtime.instructions.spark;
 
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
@@ -37,11 +40,8 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
 
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
 
 public class TsmmSPInstruction extends UnarySPInstruction {
 	private MMTSJType _type = null;
@@ -72,10 +72,8 @@ public class TsmmSPInstruction extends UnarySPInstruction {
 
 		if (ConfigurationManager.isMaxPrallelizeEnabled()) {
 			try {
-				if (CommonThreadPool.triggerRemoteOPsPool == null)
-					CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
 				TsmmTask task = new TsmmTask(in, _type);
-				Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+				Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
 				LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : null;
 				sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
index de7922e25a..18d88178a3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
@@ -19,6 +19,9 @@
 
 package org.apache.sysds.runtime.instructions.spark;
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
@@ -42,11 +45,8 @@ import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
 
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
 
 public class ZipmmSPInstruction extends BinarySPInstruction {
 	// internal flag to apply left-transpose rewrite or not
@@ -86,10 +86,8 @@ public class ZipmmSPInstruction extends BinarySPInstruction {
 
 		if (ConfigurationManager.isMaxPrallelizeEnabled()) {
 			try {
-				if (CommonThreadPool.triggerRemoteOPsPool == null)
-					CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
-				ZipmmTask task = new ZipmmTask(in1, in2, _tRewrite);
-				Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+					ZipmmTask task = new ZipmmTask(in1, in2, _tRewrite);
+				Future<MatrixBlock> future_out = CommonThreadPool.getDynamicPool().submit(task);
 				LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : null;
 				sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
index 65648508f8..84bb8598c0 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
@@ -19,6 +19,10 @@
 
 package org.apache.sysds.runtime.lineage;
 
+import java.util.HashMap;
+import java.util.Map;
+import java.util.TreeSet;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
@@ -27,11 +31,6 @@ import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 
-import java.util.HashMap;
-import java.util.Map;
-import java.util.TreeSet;
-import java.util.concurrent.Executors;
-
 public class LineageSparkCacheEviction
 {
 	private static long SPARK_STORAGE_LIMIT = 0; //60% (upper limit of Spark unified memory)
@@ -212,9 +211,7 @@ public class LineageSparkCacheEviction
 		int localHitCount = RDDHitCountLocal.get(e._key);
 		if (localHitCount > 3) {
 			RDDHitCountLocal.remove(e._key);
-			if (CommonThreadPool.triggerRemoteOPsPool == null)
-				CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
-			CommonThreadPool.triggerRemoteOPsPool.submit(new TriggerRemoteTask(e.getRDDObject().getRDD()));
+			CommonThreadPool.getDynamicPool().submit(new TriggerRemoteTask(e.getRDDObject().getRDD()));
 		}
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
index f96c4cc4af..cc6483d258 100644
--- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
@@ -30,44 +30,118 @@ import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
-import org.apache.commons.lang3.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 
 /**
- * This common thread pool provides an abstraction to obtain a shared
- * thread pool, specifically the ForkJoinPool.commonPool, for all requests
- * of the maximum degree of parallelism. If pools of different size are
- * requested, we create new pool instances of FixedThreadPool.
+ * This common thread pool provides an abstraction to obtain a shared thread pool.
+ * 
+ * If the number of logical cores is specified a ForkJoinPool.commonPool is returned on all requests.
+ * 
+ * If pools of a different size are requested, we create new FixedThreadPool instances, unless we are currently on
+ * the main thread, in which case we return a shared instance created for the first requested number of threads.
+ * 
+ * Alternatively, the class also contains a dynamic thread pool that is intended for asynchronous, long-running tasks
+ * with low compute overhead, such as broadcast and collect from federated workers.
  */
-public class CommonThreadPool implements ExecutorService
-{
-	//shared thread pool used system-wide, potentially by concurrent parfor workers
-	//we use the ForkJoinPool.commonPool() to avoid explicit cleanup, including
-	//unnecessary initialization (e.g., problematic in jmlc) and because this commonPool
-	//resulted in better performance than a dedicated fixed thread pool.
+public class CommonThreadPool implements ExecutorService {
+	/** Log object */
+	protected static final Log LOG = LogFactory.getLog(CommonThreadPool.class.getName());
+
+	/** The number of threads of the machine */
 	private static final int size = InfrastructureAnalyzer.getLocalParallelism();
+	/**
+	 * Shared thread pool used system-wide, potentially by concurrent parfor workers
+	 * 
+	 * we use the ForkJoinPool.commonPool() to avoid explicit cleanup, including unnecessary initialization (e.g.,
+	 * problematic in jmlc) and because this commonPool resulted in better performance than a dedicated fixed thread
+	 * pool.
+	 */
 	private static final ExecutorService shared = ForkJoinPool.commonPool();
+	/** A secondary thread local executor that use a custom number of threads */
+	private static CommonThreadPool shared2 = null;
+	/** The number of threads used in the custom secondary executor */
+	private static int shared2K = -1;
+	/** Dynamic thread pool, that dynamically allocate threads as tasks come in. */
+	private static ExecutorService asyncPool = null;
+	/** This common thread pool */
 	private final ExecutorService _pool;
-	public static ExecutorService triggerRemoteOPsPool = null;
 
+	/**
+	 * Constructor of the threadPool.
+	 * This is intended not to be used except for tests.
+	 * Please use the static constructors.
+	 * 
+	 * @param pool The thread pool instance to use.
+	 */
 	public CommonThreadPool(ExecutorService pool) {
-		_pool = pool;
-	}
-
+		this._pool = pool;
+	}
+
+	/**
+	 * Get the shared Executor thread pool, that have the number of threads of the host system
+	 * 
+	 * @return An ExecutorService
+	 */
+	public static ExecutorService get() {
+		return shared;
+	}
+
+	/**
+	 * Get a Executor thread pool, that have the number of threads specified in k.
+	 * 
+	 * The thread pool can be reused by other processes in the same host thread requesting another pool of the same
+	 * number of threads. The executor that is guaranteed ThreadLocal except if it is number of host logical cores.
+	 * 
+	 * 
+	 * @param k The number of threads wanted
+	 * @return The executor with specified parallelism
+	 */
 	public static ExecutorService get(int k) {
-		return new CommonThreadPool( (size==k) ?
-			shared : Executors.newFixedThreadPool(k));
-	}
-	
+		if(size == k)
+			return shared;
+		else if(Thread.currentThread().getName().equals("main")) {
+			if(shared2 != null && shared2K == k)
+				return shared2;
+			else if(shared2 == null) {
+				shared2 = new CommonThreadPool(Executors.newFixedThreadPool(k));
+				shared2K = k;
+				return shared2;
+			}
+			else
+				return new CommonThreadPool(Executors.newFixedThreadPool(k));
+		}
+		else
+			return new CommonThreadPool(Executors.newFixedThreadPool(k));
+	}
+
+	/**
+	 * Get if there is a current thread pool that have the given parallelism locally.
+	 * 
+	 * @param k the parallelism
+	 * @return If we have a cached thread pool.
+	 */
+	public static boolean isSharedTPThreads(int k) {
+		return size == k || shared2K == k || shared2K == -1;
+	}
+
+	/**
+	 * Invoke the collection of tasks and shutdown the pool upon job termination.
+	 * 
+	 * @param <T>   The type of class to return from the job
+	 * @param pool  The pool to execute in
+	 * @param tasks The tasks to execute
+	 */
 	public static <T> void invokeAndShutdown(ExecutorService pool, Collection<? extends Callable<T>> tasks) {
 		try {
-			//execute tasks
+			// execute tasks
 			List<Future<T>> ret = pool.invokeAll(tasks);
-			//check for errors and exceptions
-			for( Future<T> r : ret )
+			// check for errors and exceptions
+			for(Future<T> r : ret)
 				r.get();
-			//shutdown pool
+			// shutdown pool
 			pool.shutdown();
 		}
 		catch(Exception ex) {
@@ -75,28 +149,51 @@ public class CommonThreadPool implements ExecutorService
 		}
 	}
 
-	public static void shutdownShared() {
-		shared.shutdownNow();
+	/**
+	 * Get a dynamic thread pool that allocate threads as the requests are made. This pool is intended for async remote
+	 * calls that does not depend on local compute.
+	 * 
+	 * @return A dynamic thread pool.
+	 */
+	public static ExecutorService getDynamicPool() {
+		if(asyncPool != null)
+			return asyncPool;
+		else {
+			asyncPool = Executors.newCachedThreadPool();
+			return asyncPool;
+		}
 	}
 
-	public static void shutdownAsyncRDDPool() {
-		if (triggerRemoteOPsPool != null) {
-			//shutdown prefetch/broadcast thread pool
-			triggerRemoteOPsPool.shutdown();
-			triggerRemoteOPsPool = null;
+	/**
+	 * Shutdown the cached thread pools.
+	 */
+	public static void shutdownAsyncPools() {
+		if(asyncPool != null) {
+			// shutdown prefetch/broadcast thread pool
+			asyncPool.shutdown();
+			asyncPool = null;
+		}
+		if(shared2 != null) {
+			// shutdown shared custom thread count pool
+			shared2.shutdown();
+			shared2 = null;
+			shared2K = -1;
 		}
 	}
 
+	public final boolean isCached() {
+		return _pool.equals(shared) || this.equals(shared2);
+	}
+
 	@Override
 	public void shutdown() {
-		if( _pool != shared )
+		if(!isCached())
 			_pool.shutdown();
 	}
 
 	@Override
 	public List<Runnable> shutdownNow() {
-		return ( _pool != shared ) ?
-			_pool.shutdownNow() : null;
+		return !isCached() ? _pool.shutdownNow() : null;
 	}
 
 	@Override
@@ -106,10 +203,10 @@ public class CommonThreadPool implements ExecutorService
 
 	@Override
 	public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
-			throws InterruptedException {
+		throws InterruptedException {
 		return _pool.invokeAll(tasks, timeout, unit);
 	}
-	
+
 	@Override
 	public void execute(Runnable command) {
 		_pool.execute(command);
@@ -130,31 +227,29 @@ public class CommonThreadPool implements ExecutorService
 		return _pool.submit(task);
 	}
 
-	
-	//unnecessary methods required for API compliance
 	@Override
 	public boolean isShutdown() {
-		throw new NotImplementedException();
+		return isCached() || _pool.isShutdown();
 	}
 
 	@Override
 	public boolean isTerminated() {
-		throw new NotImplementedException();
+		return isCached() || _pool.isTerminated();
 	}
 
 	@Override
 	public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
-		throw new NotImplementedException();
+		return isCached() || _pool.awaitTermination(timeout, unit);
 	}
 
 	@Override
 	public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, ExecutionException {
-		throw new NotImplementedException();
+		return _pool.invokeAny(tasks);
 	}
 
 	@Override
 	public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
-			throws InterruptedException, ExecutionException, TimeoutException {
-		throw new NotImplementedException();
+		throws InterruptedException, ExecutionException, TimeoutException {
+		return _pool.invokeAny(tasks);
 	}
 }
diff --git a/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java b/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java
new file mode 100644
index 0000000000..ca79e8800b
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.misc;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.junit.Test;
+
+public class ThreadPool {
+	protected static final Log LOG = LogFactory.getLog(ThreadPool.class.getName());
+
+	@Test
+	public void testGetTheSame() {
+		CommonThreadPool.shutdownAsyncPools();
+		ExecutorService x = CommonThreadPool.get();
+		ExecutorService y = CommonThreadPool.get();
+		x.shutdown();
+		y.shutdown();
+
+		assertEquals(x, y);
+		CommonThreadPool.shutdownAsyncPools();
+		CommonThreadPool.shutdownAsyncPools();
+
+	}
+
+	@Test
+	public void testGetSameCustomThreadCount() {
+		CommonThreadPool.shutdownAsyncPools();
+		// choosing 7 because the machine is unlikely to have 7 logical cores.
+		String name = Thread.currentThread().getName();
+		Thread.currentThread().setName("main");
+		ExecutorService x = CommonThreadPool.get(7);
+		ExecutorService y = CommonThreadPool.get(7);
+		x.shutdown();
+		y.shutdown();
+
+		Thread.currentThread().setName(name);
+		assertEquals(x, y);
+		CommonThreadPool.shutdownAsyncPools();
+		CommonThreadPool.shutdownAsyncPools();
+
+	}
+
+	@Test
+	public void testGetSameCustomThreadCountExecute() throws InterruptedException, ExecutionException {
+		// choosing 7 because the machine is unlikely to have 7 logical cores.
+		CommonThreadPool.shutdownAsyncPools();
+		String name = Thread.currentThread().getName();
+		Thread.currentThread().setName("main");
+		ExecutorService x = CommonThreadPool.get(7);
+		ExecutorService y = CommonThreadPool.get(7);
+		assertEquals(x, y);
+		int v = x.submit(() -> 5).get();
+		x.shutdown();
+		int v2 = y.submit(() -> 5).get();
+		y.shutdown();
+
+		Thread.currentThread().setName(name);
+		assertEquals(x, y);
+		assertEquals(v, v2);
+		CommonThreadPool.shutdownAsyncPools();
+	}
+
+	@Test
+	public void testGetSameCustomThreadCountExecuteV2() throws InterruptedException, ExecutionException {
+		// choosing 7 because the machine is unlikely to have 7 logical cores.
+		String name = Thread.currentThread().getName();
+		Thread.currentThread().setName("main");
+		ExecutorService x = CommonThreadPool.get(7);
+		ExecutorService y = CommonThreadPool.get(7);
+		assertEquals(x, y);
+		int v = x.submit(() -> 5).get();
+		int v2 = y.submit(() -> 5).get();
+		x.shutdown();
+		y.shutdown();
+
+		Thread.currentThread().setName(name);
+		assertEquals(x, y);
+		assertEquals(v, v2);
+		CommonThreadPool.shutdownAsyncPools();
+	}
+
+	@Test
+	public void testGetSameCustomThreadCountExecuteV3() throws InterruptedException, ExecutionException {
+		// choosing 7 because the machine is unlikely to have 7 logical cores.
+		String name = Thread.currentThread().getName();
+		Thread.currentThread().setName("main");
+		ExecutorService x = CommonThreadPool.get(7);
+		ExecutorService y = CommonThreadPool.get(7);
+		assertEquals(x, y);
+		x.shutdown();
+		y.shutdown();
+		int v = x.submit(() -> 5).get();
+		int v2 = y.submit(() -> 5).get();
+
+		Thread.currentThread().setName(name);
+		assertEquals(x, y);
+		assertEquals(v, v2);
+		CommonThreadPool.shutdownAsyncPools();
+	}
+
+	@Test
+	public void testGetSameCustomThreadCountExecuteV4() throws InterruptedException, ExecutionException {
+		// choosing 7 because the machine is unlikely to have 7 logical cores.
+		String name = Thread.currentThread().getName();
+		Thread.currentThread().setName("main");
+		CommonThreadPool.shutdownAsyncPools();
+		ExecutorService x = CommonThreadPool.get(5);
+		ExecutorService y = CommonThreadPool.get(7);
+		assertNotEquals(x, y);
+		x.shutdown();
+		int v = x.submit(() -> 5).get();
+		int v2 = y.submit(() -> 5).get();
+		y.shutdown();
+
+		Thread.currentThread().setName(name);
+		assertEquals(v, v2);
+		CommonThreadPool.shutdownAsyncPools();
+	}
+
+	@Test
+	public void testFromOtherThread() throws InterruptedException, ExecutionException {
+		CommonThreadPool.shutdownAsyncPools();
+		ExecutorService x = CommonThreadPool.get(5);
+		Future<ExecutorService> a = x.submit(() -> CommonThreadPool.get(5));
+		ExecutorService y = a.get();
+		assertNotEquals(x, y);
+		CommonThreadPool.shutdownAsyncPools();
+	}
+
+	@Test
+	public void testFromOtherThreadInfrastructureParallelism() throws InterruptedException, ExecutionException {
+		CommonThreadPool.shutdownAsyncPools();
+		final int k = InfrastructureAnalyzer.getLocalParallelism();
+		ExecutorService x = CommonThreadPool.get(k);
+		Future<ExecutorService> a = x.submit(() -> CommonThreadPool.get(k));
+		ExecutorService y = a.get();
+		assertEquals(x, y);
+		CommonThreadPool.shutdownAsyncPools();
+	}
+
+	@Test
+	public void dynamic() throws InterruptedException, ExecutionException {
+		CommonThreadPool.shutdownAsyncPools();
+		final int k = InfrastructureAnalyzer.getLocalParallelism();
+		ExecutorService x = CommonThreadPool.getDynamicPool();
+		Future<ExecutorService> a = x.submit(() -> CommonThreadPool.get(k));
+		ExecutorService y = a.get();
+		assertNotEquals(x, y);
+		CommonThreadPool.shutdownAsyncPools();
+	}
+
+	@Test
+	public void dynamicSame() throws InterruptedException, ExecutionException {
+		CommonThreadPool.shutdownAsyncPools();
+		ExecutorService x = CommonThreadPool.getDynamicPool();
+		ExecutorService y = CommonThreadPool.getDynamicPool();
+		assertEquals(x, y);
+		CommonThreadPool.shutdownAsyncPools();
+	}
+
+	@Test
+	public void isSharedTPThreads() throws InterruptedException, ExecutionException {
+		CommonThreadPool.shutdownAsyncPools();
+		for(int i = 0; i < 10; i++)
+			assertTrue(CommonThreadPool.isSharedTPThreads(i));
+
+		CommonThreadPool.shutdownAsyncPools();
+	}
+
+	@Test
+	public void isSharedTPThreadsCommonSize() throws InterruptedException, ExecutionException {
+		CommonThreadPool.shutdownAsyncPools();
+		assertTrue(CommonThreadPool.isSharedTPThreads(InfrastructureAnalyzer.getLocalParallelism()));
+		CommonThreadPool.shutdownAsyncPools();
+	}
+
+	@Test
+	public void isSharedTPThreadsFalse() throws InterruptedException, ExecutionException {
+		CommonThreadPool.shutdownAsyncPools();
+		String name = Thread.currentThread().getName();
+		Thread.currentThread().setName("main");
+		CommonThreadPool.get(18);
+		for(int i = 1; i < 10; i++)
+			if(i != InfrastructureAnalyzer.getLocalParallelism())
+				assertFalse("" + i, CommonThreadPool.isSharedTPThreads(i));
+		assertTrue(CommonThreadPool.isSharedTPThreads(18));
+		assertFalse(CommonThreadPool.isSharedTPThreads(19));
+
+		Thread.currentThread().setName(name);
+		CommonThreadPool.shutdownAsyncPools();
+	}
+
+	@Test
+	public void justWorks() throws InterruptedException, ExecutionException {
+
+		String name = Thread.currentThread().getName();
+		Thread.currentThread().setName("main");
+		for(int j = 0; j < 2; j++) {
+			for(int i = 4; i < 17; i++) {
+				ExecutorService p = CommonThreadPool.get(i);
+				final Integer l = i;
+				assertEquals(l, p.submit(() -> l).get());
+				p.shutdown();
+			}
+		}
+		Thread.currentThread().setName(name);
+	}
+
+	@Test
+	public void justWorksNotMain() throws InterruptedException, ExecutionException {
+
+		for(int j = 0; j < 2; j++) {
+
+			for(int i = 4; i < 10; i++) {
+				ExecutorService p = CommonThreadPool.get(i);
+				final Integer l = i;
+				assertEquals(l, p.submit(() -> l).get());
+				p.shutdown();
+
+			}
+		}
+	}
+
+	@Test
+	public void justWorksShutdownNow() throws InterruptedException, ExecutionException {
+
+		String name = Thread.currentThread().getName();
+		Thread.currentThread().setName("main");
+		for(int j = 0; j < 2; j++) {
+
+			for(int i = 4; i < 16; i++) {
+				ExecutorService p = CommonThreadPool.get(i);
+				final Integer l = i;
+				assertEquals(l, p.submit(() -> l).get());
+				p.shutdownNow();
+
+			}
+		}
+		Thread.currentThread().setName(name);
+	}
+
+	@Test
+	public void justWorksShutdownNowNotMain() throws InterruptedException, ExecutionException {
+
+		for(int j = 0; j < 2; j++) {
+
+			for(int i = 4; i < 16; i++) {
+				ExecutorService p = CommonThreadPool.get(i);
+				final Integer l = i;
+				assertEquals(l, p.submit(() -> l).get());
+				p.shutdownNow();
+
+			}
+		}
+	}
+
+	@Test
+	public void mock1() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException,
+		InterruptedException, ExecutionException, TimeoutException {
+
+		ExecutorService p = mock(ExecutorService.class);
+		ExecutorService c = new CommonThreadPool(p);
+
+		when(p.shutdownNow()).thenReturn(null);
+		assertNull(c.shutdownNow());
+
+		Collection<Callable<Integer>> cc = (Collection<Callable<Integer>>) null;
+		when(p.invokeAll(cc)).thenReturn(null);
+		assertNull(c.invokeAll(cc));
+		when(p.invokeAll(cc, 1L, TimeUnit.DAYS)).thenReturn(null);
+		assertNull(c.invokeAll(cc, 1, TimeUnit.DAYS));
+		doNothing().when(p).execute((Runnable) null);
+		c.execute((Runnable) null);
+
+		when(p.submit((Callable<Integer>) null)).thenReturn(null);
+		assertNull(c.submit((Callable<Integer>) null));
+
+		when(p.submit((Runnable) null, null)).thenReturn(null);
+		assertNull(c.submit((Runnable) null, null));
+		// when(tp.pool()).thenReturn(p);
+
+		when(p.submit((Runnable) null)).thenReturn(null);
+		assertNull(c.submit((Runnable) null));
+
+		when(p.isShutdown()).thenReturn(false);
+		assertFalse(c.isShutdown());
+		when(p.isShutdown()).thenReturn(true);
+		assertTrue(c.isShutdown());
+
+		when(p.isTerminated()).thenReturn(false);
+		assertFalse(c.isTerminated());
+		when(p.isTerminated()).thenReturn(true);
+		assertTrue(c.isTerminated());
+
+		when(p.awaitTermination(10, TimeUnit.DAYS)).thenReturn(false);
+		assertFalse(c.awaitTermination(10, TimeUnit.DAYS));
+		when(p.awaitTermination(10, TimeUnit.DAYS)).thenReturn(true);
+		assertTrue(c.awaitTermination(10, TimeUnit.DAYS));
+
+		when(p.invokeAny(cc)).thenReturn(null);
+		assertNull(c.invokeAny(cc));
+		when(p.invokeAny(cc, 1L, TimeUnit.DAYS)).thenReturn(null);
+		assertNull(c.invokeAny(cc, 1, TimeUnit.DAYS));
+		doNothing().when(p).execute((Runnable) null);
+		c.execute((Runnable) null);
+
+	}
+
+	@Test
+	public void mock2() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException,
+		InterruptedException, ExecutionException, TimeoutException {
+
+		CommonThreadPool p = mock(CommonThreadPool.class);
+		when(p.isShutdown()).thenCallRealMethod();
+		when(p.isTerminated()).thenCallRealMethod();
+		when(p.awaitTermination(10, TimeUnit.DAYS)).thenCallRealMethod();
+		when(p.isCached()).thenReturn(true);
+		assertTrue(p.isShutdown());
+		assertTrue(p.isTerminated());
+		assertTrue(p.awaitTermination(10, TimeUnit.DAYS));
+	}
+
+	@Test
+	public void coverEdge() {
+		ExecutorService a = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
+		assertTrue(new CommonThreadPool(a).isCached());
+	}
+
+	@Test(expected = DMLRuntimeException.class)
+	public void invokeAndShutdownException() throws InterruptedException {
+		ExecutorService p = mock(ExecutorService.class);
+		ExecutorService c = new CommonThreadPool(p);
+
+		when(p.invokeAll(null)).thenThrow(new RuntimeException("Test"));
+
+		CommonThreadPool.invokeAndShutdown(c, null); // go through the wrapper so its delegation to the mock is exercised
+
+	}
+
+	@Test
+	public void invokeAndShutdown() throws InterruptedException {
+
+		ExecutorService p = mock(ExecutorService.class);
+		ExecutorService c = new CommonThreadPool(p);
+
+		Collection<Callable<Integer>> cc = (Collection<Callable<Integer>>) null;
+		when(p.invokeAll(cc)).thenReturn(new ArrayList<Future<Integer>>());
+
+		CommonThreadPool.invokeAndShutdown(c, null);
+
+	}
+
+	@Test
+	@SuppressWarnings("all")
+	public void invokeAndShutdownV2() throws InterruptedException{
+		
+		ExecutorService p = mock(ExecutorService.class);
+		ExecutorService c = new CommonThreadPool(p);
+
+		Collection<Callable<Integer>> cc = (Collection<Callable<Integer>>) null;
+		List<Future<Integer>> f = new ArrayList<Future<Integer>>();
+		f.add(mock(Future.class));
+		when(p.invokeAll(cc)).thenReturn(f );
+
+		CommonThreadPool.invokeAndShutdown(c, null);
+
+	}
+}