This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 68c2c17db5 [SYSTEMDS-3572] CommonThreadPool Reuse ThreadLocal Pools
68c2c17db5 is described below

commit 68c2c17db56c8c9ea5b1985bbf3b525d4ac8d022
Author: baunsgaard <[email protected]>
AuthorDate: Mon Aug 7 13:40:42 2023 +0200

    [SYSTEMDS-3572] CommonThreadPool Reuse ThreadLocal Pools
    
    This commit allows reuse of custom-size thread pools, but only for the
    main thread. Restricting reuse to the main thread avoids problems with
    parfor workers spawning threads that would otherwise share the same pool.
    I initially tried using ThreadLocal to solve this problem, but while it
    worked in practice, it did not work with our testing framework.
    This implementation is a compromise that works with the test framework
    without introducing too much code.
    
    Closes #1873
---
 src/main/java/org/apache/sysds/api/DMLScript.java  |   2 +-
 .../runtime/controlprogram/ParForProgramBlock.java |  23 +-
 .../instructions/cp/BroadcastCPInstruction.java    |   7 +-
 .../instructions/cp/PrefetchCPInstruction.java     |   6 +-
 .../spark/AggregateUnarySPInstruction.java         |  12 +-
 .../spark/CheckpointSPInstruction.java             |   6 +-
 .../instructions/spark/CpmmSPInstruction.java      |  16 +-
 .../instructions/spark/MapmmSPInstruction.java     |   7 +-
 .../instructions/spark/TsmmSPInstruction.java      |  12 +-
 .../instructions/spark/ZipmmSPInstruction.java     |  14 +-
 .../runtime/lineage/LineageSparkCacheEviction.java |  13 +-
 .../sysds/runtime/util/CommonThreadPool.java       | 181 ++++++---
 .../sysds/test/component/misc/ThreadPool.java      | 408 +++++++++++++++++++++
 13 files changed, 593 insertions(+), 114 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java 
b/src/main/java/org/apache/sysds/api/DMLScript.java
index ff70b330ae..ddc5ee2517 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -573,7 +573,7 @@ public class DMLScript
                FederatedData.clearFederatedWorkers();
                
                //0) shutdown prefetch/broadcast thread pool if necessary
-               CommonThreadPool.shutdownAsyncRDDPool();
+               CommonThreadPool.shutdownAsyncPools();
 
                //1) cleanup scratch space (everything for current uuid)
                //(required otherwise export to hdfs would skip assumed 
unnecessary writes if same name)
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 2fc12c4c26..94bbaf2545 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysds.runtime.controlprogram;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.DataType;
@@ -89,6 +91,7 @@ import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.CollectionUtils;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.ProgramConverter;
 import org.apache.sysds.runtime.util.UtilFunctions;
 import org.apache.sysds.utils.stats.ParForStatistics;
@@ -118,8 +121,8 @@ import java.util.stream.Stream;
  * TODO: papply(A,1:2,FUN) language construct (compiled to ParFOR) via DML 
function repository =&gt; modules OK, but second-order functions required
  *
  */
-public class ParForProgramBlock extends ForProgramBlock 
-{      
+public class ParForProgramBlock extends ForProgramBlock {      
+       protected static final Log LOG = 
LogFactory.getLog(CommonThreadPool.class.getName());
        // execution modes
        public enum PExecMode {
                LOCAL,          //local (master) multi-core execution mode
@@ -759,7 +762,7 @@ public class ParForProgramBlock extends ForProgramBlock
                        LocalTaskQueue<Task> queue = new LocalTaskQueue<>();
                        Thread[] threads         = new Thread[_numThreads];
                        LocalParWorker[] workers = new 
LocalParWorker[_numThreads];
-                       IntStream.range(0, _numThreads).parallel().forEach(i -> 
{
+                       IntStream.range(0, _numThreads).forEach(i -> {
                                workers[i] = createParallelWorker( _pwIDs[i], 
queue, ec, i);
                                threads[i] = new Thread( workers[i] );
                                threads[i].setPriority(Thread.MAX_PRIORITY);
@@ -1430,9 +1433,14 @@ public class ParForProgramBlock extends ForProgramBlock
                }
        }
 
-       private void consolidateAndCheckResults(ExecutionContext ec, long 
expIters, long expTasks, long numIters, long numTasks, LocalVariableMap [] 
results) 
-       {
+       private void consolidateAndCheckResults(ExecutionContext ec, final long 
expIters, final long expTasks,
+               final long numIters, final long numTasks, LocalVariableMap[] 
results) {
                Timing time = new Timing(true);
+
+               //check expected counters
+               if( numTasks != expTasks || numIters !=expIters ) //consistency 
check
+                       throw new DMLRuntimeException("PARFOR: Number of 
executed tasks does not match the number of created tasks: tasks 
"+numTasks+"/"+expTasks+", iters "+numIters+"/"+expIters+".");
+       
                
                //result merge
                if( checkParallelRemoteResultMerge() )
@@ -1531,10 +1539,7 @@ public class ParForProgramBlock extends ForProgramBlock
                if( CREATE_UNSCOPED_RESULTVARS && sb != null && 
ec.getVariables() != null ) //sb might be null for nested parallelism
                        createEmptyUnscopedVariables( ec.getVariables(), sb );
                
-               //check expected counters
-               if( numTasks != expTasks || numIters !=expIters ) //consistency 
check
-                       throw new DMLRuntimeException("PARFOR: Number of 
executed tasks does not match the number of created tasks: tasks 
"+numTasks+"/"+expTasks+", iters "+numIters+"/"+expIters+".");
-       
+                       
                if( DMLScript.STATISTICS )
                        ParForStatistics.incrementMergeTime((long) time.stop());
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
index aa0be7daec..58378bf7e4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import java.util.concurrent.Executors;
-
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -43,9 +41,6 @@ public class BroadcastCPInstruction extends 
UnaryCPInstruction {
        @Override
        public void processInstruction(ExecutionContext ec) {
                ec.setVariable(output.getName(), ec.getMatrixObject(input1));
-
-               if (CommonThreadPool.triggerRemoteOPsPool == null)
-                       CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
-               CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
+               CommonThreadPool.getDynamicPool().submit(new 
TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index 96e4b7afe2..233509a5b8 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import java.util.concurrent.Executors;
-
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
@@ -53,10 +51,8 @@ public class PrefetchCPInstruction extends 
UnaryCPInstruction {
                // If the next instruction which takes this output as an input 
comes before
                // the prefetch thread triggers, that instruction will start 
the operations.
                // In that case this Prefetch instruction will act like a NOOP. 
-               if (CommonThreadPool.triggerRemoteOPsPool == null)
-                       CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
                // Saving the lineage item inside the matrix object will 
replace the pre-attached
                // lineage item (e.g. mapmm). Hence, passing separately.
-               CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerPrefetchTask(ec.getMatrixObject(output), li));
+               CommonThreadPool.getDynamicPool().submit(new 
TriggerPrefetchTask(ec.getMatrixObject(output), li));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 50816aefe4..32b80a2360 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -19,6 +19,9 @@
 
 package org.apache.sysds.runtime.instructions.spark;
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
@@ -49,11 +52,8 @@ import 
org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
 
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
 
 public class AggregateUnarySPInstruction extends UnarySPInstruction {
        private SparkAggType _aggtype = null;
@@ -115,10 +115,8 @@ public class AggregateUnarySPInstruction extends 
UnarySPInstruction {
                                //Trigger the chain of Spark operations and 
maintain a future to the result
                                //TODO: Make memory for the future matrix block
                                try {
-                                       
if(CommonThreadPool.triggerRemoteOPsPool == null)
-                                               
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
                                        RDDAggregateTask task = new 
RDDAggregateTask(_optr, _aop, in, mc);
-                                       Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
+                                       Future<MatrixBlock> future_out = 
CommonThreadPool.getDynamicPool().submit(task);
                                        LineageItem li = 
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : 
null;
                                        
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
                                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
index 2be663bdbc..4ee56a7eca 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
@@ -51,8 +51,6 @@ import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.UtilFunctions;
 import org.apache.sysds.utils.Statistics;
 
-import java.util.concurrent.Executors;
-
 public class CheckpointSPInstruction extends UnarySPInstruction {
        // default storage level
        private StorageLevel _level = null;
@@ -86,9 +84,7 @@ public class CheckpointSPInstruction extends 
UnarySPInstruction {
                        // TODO: Synchronize. Avoid double execution
                        ec.setVariable(output.getName(), 
ec.getCacheableData(input1));
 
-                       if (CommonThreadPool.triggerRemoteOPsPool == null)
-                               CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
-                       CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerCheckpointTask(ec.getMatrixObject(output)));
+                       CommonThreadPool.getDynamicPool().submit(new 
TriggerCheckpointTask(ec.getMatrixObject(output)));
                        return;
                }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 6425613583..602d74a275 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -19,6 +19,9 @@
 
 package org.apache.sysds.runtime.instructions.spark;
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
@@ -47,11 +50,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
 
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
 
 /**
  * Cpmm: cross-product matrix multiplication operation (distributed matrix 
multiply
@@ -112,10 +112,8 @@ public class CpmmSPInstruction extends 
AggregateBinarySPInstruction {
                {
                        if (ConfigurationManager.isMaxPrallelizeEnabled()) {
                                try {
-                                       
if(CommonThreadPool.triggerRemoteOPsPool == null)
-                                               
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
                                        CpmmMatrixVectorTask task = new 
CpmmMatrixVectorTask(in1, in2);
-                                       Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
+                                       Future<MatrixBlock> future_out = 
CommonThreadPool.getDynamicPool().submit(task);
                                        LineageItem li = 
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : 
null;
                                        
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
                                }
@@ -147,10 +145,8 @@ public class CpmmSPInstruction extends 
AggregateBinarySPInstruction {
                        {
                                if 
(ConfigurationManager.isMaxPrallelizeEnabled()) {
                                        try {
-                                               
if(CommonThreadPool.triggerRemoteOPsPool == null)
-                                                       
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
                                                CpmmMatrixMatrixTask task = new 
CpmmMatrixMatrixTask(in1, in2, numPartJoin);
-                                               Future<MatrixBlock> future_out 
= CommonThreadPool.triggerRemoteOPsPool.submit(task);
+                                               Future<MatrixBlock> future_out 
= CommonThreadPool.getDynamicPool().submit(task);
                                                
sec.setMatrixOutput(output.getName(), future_out);
                                        }
                                        catch(Exception ex) { throw new 
DMLRuntimeException(ex); }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
index b0285b1bba..080de52d23 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.instructions.spark;
 
 import java.util.Iterator;
 import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.stream.IntStream;
 
@@ -59,8 +58,8 @@ import 
org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
-
 import org.apache.sysds.runtime.util.CommonThreadPool;
+
 import scala.Tuple2;
 
 public class MapmmSPInstruction extends AggregateBinarySPInstruction {
@@ -144,10 +143,8 @@ public class MapmmSPInstruction extends 
AggregateBinarySPInstruction {
                {
                        if (ConfigurationManager.isMaxPrallelizeEnabled()) {
                                try {
-                                       
if(CommonThreadPool.triggerRemoteOPsPool == null)
-                                               
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
                                        RDDMapmmTask task = new  
RDDMapmmTask(in1, in2, type);
-                                       Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
+                                       Future<MatrixBlock> future_out = 
CommonThreadPool.getDynamicPool().submit(task);
                                        LineageItem li = 
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : 
null;
                                        
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
                                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
index acba784bf2..dd6ddb526d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
@@ -20,6 +20,9 @@
 package org.apache.sysds.runtime.instructions.spark;
 
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
@@ -37,11 +40,8 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
 
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
 
 public class TsmmSPInstruction extends UnarySPInstruction {
        private MMTSJType _type = null;
@@ -72,10 +72,8 @@ public class TsmmSPInstruction extends UnarySPInstruction {
 
                if (ConfigurationManager.isMaxPrallelizeEnabled()) {
                        try {
-                               if (CommonThreadPool.triggerRemoteOPsPool == 
null)
-                                       CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
                                TsmmTask task = new TsmmTask(in, _type);
-                               Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
+                               Future<MatrixBlock> future_out = 
CommonThreadPool.getDynamicPool().submit(task);
                                LineageItem li = 
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : 
null;
                                sec.setMatrixOutputAndLineage(output.getName(), 
future_out, li);
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
index de7922e25a..18d88178a3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
@@ -19,6 +19,9 @@
 
 package org.apache.sysds.runtime.instructions.spark;
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
@@ -42,11 +45,8 @@ import 
org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.runtime.util.CommonThreadPool;
-import scala.Tuple2;
 
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import scala.Tuple2;
 
 public class ZipmmSPInstruction extends BinarySPInstruction {
        // internal flag to apply left-transpose rewrite or not
@@ -86,10 +86,8 @@ public class ZipmmSPInstruction extends BinarySPInstruction {
 
                if (ConfigurationManager.isMaxPrallelizeEnabled()) {
                        try {
-                               if (CommonThreadPool.triggerRemoteOPsPool == 
null)
-                                       CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
-                               ZipmmTask task = new ZipmmTask(in1, in2, 
_tRewrite);
-                               Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
+                                       ZipmmTask task = new ZipmmTask(in1, 
in2, _tRewrite);
+                               Future<MatrixBlock> future_out = 
CommonThreadPool.getDynamicPool().submit(task);
                                LineageItem li = 
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : 
null;
                                sec.setMatrixOutputAndLineage(output.getName(), 
future_out, li);
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
index 65648508f8..84bb8598c0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
@@ -19,6 +19,10 @@
 
 package org.apache.sysds.runtime.lineage;
 
+import java.util.HashMap;
+import java.util.Map;
+import java.util.TreeSet;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
@@ -27,11 +31,6 @@ import 
org.apache.sysds.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 
-import java.util.HashMap;
-import java.util.Map;
-import java.util.TreeSet;
-import java.util.concurrent.Executors;
-
 public class LineageSparkCacheEviction
 {
        private static long SPARK_STORAGE_LIMIT = 0; //60% (upper limit of 
Spark unified memory)
@@ -212,9 +211,7 @@ public class LineageSparkCacheEviction
                int localHitCount = RDDHitCountLocal.get(e._key);
                if (localHitCount > 3) {
                        RDDHitCountLocal.remove(e._key);
-                       if (CommonThreadPool.triggerRemoteOPsPool == null)
-                               CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
-                       CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerRemoteTask(e.getRDDObject().getRDD()));
+                       CommonThreadPool.getDynamicPool().submit(new 
TriggerRemoteTask(e.getRDDObject().getRDD()));
                }
        }
 
diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java 
b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
index f96c4cc4af..cc6483d258 100644
--- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
@@ -30,44 +30,118 @@ import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
-import org.apache.commons.lang3.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 
 /**
- * This common thread pool provides an abstraction to obtain a shared
- * thread pool, specifically the ForkJoinPool.commonPool, for all requests
- * of the maximum degree of parallelism. If pools of different size are
- * requested, we create new pool instances of FixedThreadPool.
+ * This common thread pool provides an abstraction to obtain a shared thread 
pool.
+ * 
+ * If the number of logical cores is specified a ForkJoinPool.commonPool is 
returned on all requests.
+ * 
+ * If pools of different size are requested, we create new pool instances of 
FixedThreadPool, Unless we currently are on
+ * the main thread, Then we return a shared instance of the first requested 
number of cores.
+ * 
+ * Alternatively the class also contain a dynamic threadPool, that is intended 
for asynchronous long running tasks with
+ * low compute overhead, such as broadcast and collect from federated workers.
  */
-public class CommonThreadPool implements ExecutorService
-{
-       //shared thread pool used system-wide, potentially by concurrent parfor 
workers
-       //we use the ForkJoinPool.commonPool() to avoid explicit cleanup, 
including
-       //unnecessary initialization (e.g., problematic in jmlc) and because 
this commonPool
-       //resulted in better performance than a dedicated fixed thread pool.
+public class CommonThreadPool implements ExecutorService {
+       /** Log object */
+       protected static final Log LOG = 
LogFactory.getLog(CommonThreadPool.class.getName());
+
+       /** The number of threads of the machine */
        private static final int size = 
InfrastructureAnalyzer.getLocalParallelism();
+       /**
+        * Shared thread pool used system-wide, potentially by concurrent 
parfor workers
+        * 
+        * we use the ForkJoinPool.commonPool() to avoid explicit cleanup, 
including unnecessary initialization (e.g.,
+        * problematic in jmlc) and because this commonPool resulted in better 
performance than a dedicated fixed thread
+        * pool.
+        */
        private static final ExecutorService shared = ForkJoinPool.commonPool();
+       /** A secondary thread local executor that use a custom number of 
threads */
+       private static CommonThreadPool shared2 = null;
+       /** The number of threads used in the custom secondary executor */
+       private static int shared2K = -1;
+       /** Dynamic thread pool, that dynamically allocate threads as tasks 
come in. */
+       private static ExecutorService asyncPool = null;
+       /** This common thread pool */
        private final ExecutorService _pool;
-       public static ExecutorService triggerRemoteOPsPool = null;
 
+       /**
+        * Constructor of the threadPool.
+        * This is intended not to be used except for tests.
+        * Please use the static constructors.
+        * 
+        * @param pool The thread pool instance to use.
+        */
        public CommonThreadPool(ExecutorService pool) {
-               _pool = pool;
-       }
-
+               this._pool = pool;
+       }
+
+       /**
+        * Get the shared Executor thread pool, that have the number of threads 
of the host system
+        * 
+        * @return An ExecutorService
+        */
+       public static ExecutorService get() {
+               return shared;
+       }
+
+       /**
+        * Get a Executor thread pool, that have the number of threads 
specified in k.
+        * 
+        * The thread pool can be reused by other processes in the same host 
thread requesting another pool of the same
+        * number of threads. The executor that is guaranteed ThreadLocal 
except if it is number of host logical cores.
+        * 
+        * 
+        * @param k The number of threads wanted
+        * @return The executor with specified parallelism
+        */
        public static ExecutorService get(int k) {
-               return new CommonThreadPool( (size==k) ?
-                       shared : Executors.newFixedThreadPool(k));
-       }
-       
+               if(size == k)
+                       return shared;
+               else if(Thread.currentThread().getName().equals("main")) {
+                       if(shared2 != null && shared2K == k)
+                               return shared2;
+                       else if(shared2 == null) {
+                               shared2 = new 
CommonThreadPool(Executors.newFixedThreadPool(k));
+                               shared2K = k;
+                               return shared2;
+                       }
+                       else
+                               return new 
CommonThreadPool(Executors.newFixedThreadPool(k));
+               }
+               else
+                       return new 
CommonThreadPool(Executors.newFixedThreadPool(k));
+       }
+
+       /**
+        * Get if there is a current thread pool that have the given 
parallelism locally.
+        * 
+        * @param k the parallelism
+        * @return If we have a cached thread pool.
+        */
+       public static boolean isSharedTPThreads(int k) {
+               return size == k || shared2K == k || shared2K == -1;
+       }
+
+       /**
+        * Invoke the collection of tasks and shutdown the pool upon job 
termination.
+        * 
+        * @param <T>   The type of class to return from the job
+        * @param pool  The pool to execute in
+        * @param tasks The tasks to execute
+        */
        public static <T> void invokeAndShutdown(ExecutorService pool, 
Collection<? extends Callable<T>> tasks) {
                try {
-                       //execute tasks
+                       // execute tasks
                        List<Future<T>> ret = pool.invokeAll(tasks);
-                       //check for errors and exceptions
-                       for( Future<T> r : ret )
+                       // check for errors and exceptions
+                       for(Future<T> r : ret)
                                r.get();
-                       //shutdown pool
+                       // shutdown pool
                        pool.shutdown();
                }
                catch(Exception ex) {
@@ -75,28 +149,51 @@ public class CommonThreadPool implements ExecutorService
                }
        }
 
-       public static void shutdownShared() {
-               shared.shutdownNow();
+       /**
+        * Return the dynamic thread pool, which creates threads on demand as tasks are submitted. It is meant for
+        * asynchronous remote calls (such as prefetch and broadcast) that do not depend on local compute.
+        * The pool is created lazily on first use and reused until shutdownAsyncPools() resets it.
+        * 
+        * @return The dynamic thread pool.
+        */
+       public static ExecutorService getDynamicPool() {
+               // NOTE(review): the lazy initialization is not synchronized; concurrent first calls from several
+               // threads may each create a pool.
+               if(asyncPool == null)
+                       asyncPool = Executors.newCachedThreadPool();
+               return asyncPool;
        }
 
-       public static void shutdownAsyncRDDPool() {
-               if (triggerRemoteOPsPool != null) {
-                       //shutdown prefetch/broadcast thread pool
-                       triggerRemoteOPsPool.shutdown();
-                       triggerRemoteOPsPool = null;
+       /**
+        * Shutdown the cached thread pools: the dynamic pool from getDynamicPool() and the cached custom parallelism
+        * pool from get(k). Both fields are reset, so later requests create new pools.
+        */
+       public static void shutdownAsyncPools() {
+               if(asyncPool != null) {
+                       // shutdown prefetch/broadcast thread pool
+                       asyncPool.shutdown();
+                       asyncPool = null;
+               }
+               if(shared2 != null) {
+                       // Detach the shared custom thread count pool before shutting it down: while it is stored in
+                       // shared2, isCached() reports it as cached and shutdown() is a no-op, which would leave its
+                       // threads running. After detaching, shutdown() reaches the underlying executor.
+                       final ExecutorService pool = shared2;
+                       shared2 = null;
+                       shared2K = -1;
+                       pool.shutdown();
                }
        }
 
+       /**
+        * Check whether this instance is a pool shared between callers. This is the case if it wraps the common
+        * {@code shared} executor, or if it is the cached custom parallelism pool ({@code shared2}) handed out by
+        * get(k). Shutdown requests on such pools are ignored.
+        * 
+        * @return true if this pool is shared, false otherwise.
+        */
+       public final boolean isCached() {
+               if(_pool.equals(shared))
+                       return true;
+               return this.equals(shared2);
+       }
+
        @Override
        public void shutdown() {
-               if( _pool != shared )
+               // Cached pools (see isCached()) are handed out again to later callers, so shutdown requests on
+               // them are ignored and only pools owned by the caller are shut down.
+               if(!isCached())
                        _pool.shutdown();
        }
 
        @Override
        public List<Runnable> shutdownNow() {
-               return ( _pool != shared ) ?
-                       _pool.shutdownNow() : null;
+               // Ignored for cached pools (see isCached()), which then return null instead of a task list.
+               // NOTE(review): callers of ExecutorService.shutdownNow() usually expect a non-null list.
+               return !isCached() ? _pool.shutdownNow() : null;
        }
 
        @Override
@@ -106,10 +203,10 @@ public class CommonThreadPool implements ExecutorService
 
        @Override
        public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> 
tasks, long timeout, TimeUnit unit)
-                       throws InterruptedException {
+               throws InterruptedException {
+               // Delegates to the wrapped executor, also for cached pools.
                return _pool.invokeAll(tasks, timeout, unit);
        }
-       
+
        @Override
        public void execute(Runnable command) {
                _pool.execute(command);
@@ -130,31 +227,29 @@ public class CommonThreadPool implements ExecutorService
                return _pool.submit(task);
        }
 
-       
-       //unnecessary methods required for API compliance
        @Override
        public boolean isShutdown() {
-               throw new NotImplementedException();
+               // Cached pools always report true although they keep accepting and running tasks, because
+               // shutdown() is ignored for them. Presumably this lets callers treat them as released - TODO confirm.
+               return isCached() || _pool.isShutdown();
        }
 
        @Override
        public boolean isTerminated() {
-               throw new NotImplementedException();
+               // Cached pools always report true, even while tasks submitted to them are still running.
+               return isCached() || _pool.isTerminated();
        }
 
        @Override
        public boolean awaitTermination(long timeout, TimeUnit unit) throws 
InterruptedException {
-               throw new NotImplementedException();
+               // Cached pools return true immediately without waiting for their tasks. NOTE(review): for them,
+               // shutdown() followed by awaitTermination() is no completion barrier; wait on the Futures instead.
+               return isCached() || _pool.awaitTermination(timeout, unit);
        }
 
        @Override
        public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws 
InterruptedException, ExecutionException {
-               throw new NotImplementedException();
+               // Delegates to the wrapped executor, also for cached pools.
+               return _pool.invokeAny(tasks);
        }
 
        @Override
        public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
-                       throws InterruptedException, ExecutionException, TimeoutException {
-               throw new NotImplementedException();
+               throws InterruptedException, ExecutionException, TimeoutException {
+               // Forward timeout and unit: dropping them would let the call block indefinitely and never raise the
+               // TimeoutException this overload is declared to throw.
+               return _pool.invokeAny(tasks, timeout, unit);
        }
 }
diff --git a/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java 
b/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java
new file mode 100644
index 0000000000..ca79e8800b
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.misc;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.junit.Test;
+
+public class ThreadPool {
+       protected static final Log LOG = 
LogFactory.getLog(ThreadPool.class.getName());
+
+       @Test
+       public void testGetTheSame() {
+               CommonThreadPool.shutdownAsyncPools();
+               ExecutorService x = CommonThreadPool.get();
+               ExecutorService y = CommonThreadPool.get();
+               x.shutdown();
+               y.shutdown();
+
+               assertEquals(x, y);
+               CommonThreadPool.shutdownAsyncPools();
+               CommonThreadPool.shutdownAsyncPools();
+
+       }
+
+       @Test
+       public void testGetSameCustomThreadCount() {
+               CommonThreadPool.shutdownAsyncPools();
+               // choosing 7 because the machine is unlikely to have 7 logical 
cores.
+               String name = Thread.currentThread().getName();
+               Thread.currentThread().setName("main");
+               ExecutorService x = CommonThreadPool.get(7);
+               ExecutorService y = CommonThreadPool.get(7);
+               x.shutdown();
+               y.shutdown();
+
+               Thread.currentThread().setName(name);
+               assertEquals(x, y);
+               CommonThreadPool.shutdownAsyncPools();
+               CommonThreadPool.shutdownAsyncPools();
+
+       }
+
+       @Test
+       public void testGetSameCustomThreadCountExecute() throws 
InterruptedException, ExecutionException {
+               // choosing 7 because the machine is unlikely to have 7 logical 
cores.
+               CommonThreadPool.shutdownAsyncPools();
+               String name = Thread.currentThread().getName();
+               Thread.currentThread().setName("main");
+               ExecutorService x = CommonThreadPool.get(7);
+               ExecutorService y = CommonThreadPool.get(7);
+               assertEquals(x, y);
+               int v = x.submit(() -> 5).get();
+               x.shutdown();
+               int v2 = y.submit(() -> 5).get();
+               y.shutdown();
+
+               Thread.currentThread().setName(name);
+               assertEquals(x, y);
+               assertEquals(v, v2);
+               CommonThreadPool.shutdownAsyncPools();
+       }
+
+       @Test
+       public void testGetSameCustomThreadCountExecuteV2() throws 
InterruptedException, ExecutionException {
+               // choosing 7 because the machine is unlikely to have 7 logical 
cores.
+               String name = Thread.currentThread().getName();
+               Thread.currentThread().setName("main");
+               ExecutorService x = CommonThreadPool.get(7);
+               ExecutorService y = CommonThreadPool.get(7);
+               assertEquals(x, y);
+               int v = x.submit(() -> 5).get();
+               int v2 = y.submit(() -> 5).get();
+               x.shutdown();
+               y.shutdown();
+
+               Thread.currentThread().setName(name);
+               assertEquals(x, y);
+               assertEquals(v, v2);
+               CommonThreadPool.shutdownAsyncPools();
+       }
+
+       @Test
+       public void testGetSameCustomThreadCountExecuteV3() throws 
InterruptedException, ExecutionException {
+               // choosing 7 because the machine is unlikely to have 7 logical 
cores.
+               String name = Thread.currentThread().getName();
+               Thread.currentThread().setName("main");
+               ExecutorService x = CommonThreadPool.get(7);
+               ExecutorService y = CommonThreadPool.get(7);
+               assertEquals(x, y);
+               x.shutdown();
+               y.shutdown();
+               int v = x.submit(() -> 5).get();
+               int v2 = y.submit(() -> 5).get();
+
+               Thread.currentThread().setName(name);
+               assertEquals(x, y);
+               assertEquals(v, v2);
+               CommonThreadPool.shutdownAsyncPools();
+       }
+
+       @Test
+       public void testGetSameCustomThreadCountExecuteV4() throws 
InterruptedException, ExecutionException {
+               // choosing 7 because the machine is unlikely to have 7 logical 
cores.
+               String name = Thread.currentThread().getName();
+               Thread.currentThread().setName("main");
+               CommonThreadPool.shutdownAsyncPools();
+               ExecutorService x = CommonThreadPool.get(5);
+               ExecutorService y = CommonThreadPool.get(7);
+               assertNotEquals(x, y);
+               x.shutdown();
+               int v = x.submit(() -> 5).get();
+               int v2 = y.submit(() -> 5).get();
+               y.shutdown();
+
+               Thread.currentThread().setName(name);
+               assertEquals(v, v2);
+               CommonThreadPool.shutdownAsyncPools();
+       }
+
+       @Test
+       public void testFromOtherThread() throws InterruptedException, 
ExecutionException {
+               CommonThreadPool.shutdownAsyncPools();
+               ExecutorService x = CommonThreadPool.get(5);
+               Future<ExecutorService> a = x.submit(() -> 
CommonThreadPool.get(5));
+               ExecutorService y = a.get();
+               assertNotEquals(x, y);
+               CommonThreadPool.shutdownAsyncPools();
+       }
+
+       @Test
+       public void testFromOtherThreadInfrastructureParallelism() throws 
InterruptedException, ExecutionException {
+               CommonThreadPool.shutdownAsyncPools();
+               final int k = InfrastructureAnalyzer.getLocalParallelism();
+               ExecutorService x = CommonThreadPool.get(k);
+               Future<ExecutorService> a = x.submit(() -> 
CommonThreadPool.get(k));
+               ExecutorService y = a.get();
+               assertEquals(x, y);
+               CommonThreadPool.shutdownAsyncPools();
+       }
+
+       @Test
+       public void dynamic() throws InterruptedException, ExecutionException {
+               CommonThreadPool.shutdownAsyncPools();
+               final int k = InfrastructureAnalyzer.getLocalParallelism();
+               ExecutorService x = CommonThreadPool.getDynamicPool();
+               Future<ExecutorService> a = x.submit(() -> 
CommonThreadPool.get(k));
+               ExecutorService y = a.get();
+               assertNotEquals(x, y);
+               CommonThreadPool.shutdownAsyncPools();
+       }
+
+       @Test
+       public void dynamicSame() throws InterruptedException, 
ExecutionException {
+               CommonThreadPool.shutdownAsyncPools();
+               ExecutorService x = CommonThreadPool.getDynamicPool();
+               ExecutorService y = CommonThreadPool.getDynamicPool();
+               assertEquals(x, y);
+               CommonThreadPool.shutdownAsyncPools();
+       }
+
+       @Test
+       public void isSharedTPThreads() throws InterruptedException, 
ExecutionException {
+               CommonThreadPool.shutdownAsyncPools();
+               for(int i = 0; i < 10; i++)
+                       assertTrue(CommonThreadPool.isSharedTPThreads(i));
+
+               CommonThreadPool.shutdownAsyncPools();
+       }
+
+       @Test
+       public void isSharedTPThreadsCommonSize() throws InterruptedException, 
ExecutionException {
+               CommonThreadPool.shutdownAsyncPools();
+               
assertTrue(CommonThreadPool.isSharedTPThreads(InfrastructureAnalyzer.getLocalParallelism()));
+               CommonThreadPool.shutdownAsyncPools();
+       }
+
+       @Test
+       public void isSharedTPThreadsFalse() throws InterruptedException, 
ExecutionException {
+               CommonThreadPool.shutdownAsyncPools();
+               String name = Thread.currentThread().getName();
+               Thread.currentThread().setName("main");
+               CommonThreadPool.get(18);
+               for(int i = 1; i < 10; i++)
+                       if(i != InfrastructureAnalyzer.getLocalParallelism())
+                               assertFalse("" + i, 
CommonThreadPool.isSharedTPThreads(i));
+               assertTrue(CommonThreadPool.isSharedTPThreads(18));
+               assertFalse(CommonThreadPool.isSharedTPThreads(19));
+
+               Thread.currentThread().setName(name);
+               CommonThreadPool.shutdownAsyncPools();
+       }
+
+       @Test
+       public void justWorks() throws InterruptedException, ExecutionException 
{
+
+               String name = Thread.currentThread().getName();
+               Thread.currentThread().setName("main");
+               for(int j = 0; j < 2; j++) {
+                       for(int i = 4; i < 17; i++) {
+                               ExecutorService p = CommonThreadPool.get(i);
+                               final Integer l = i;
+                               assertEquals(l, p.submit(() -> l).get());
+                               p.shutdown();
+                       }
+               }
+               Thread.currentThread().setName(name);
+       }
+
+       @Test
+       public void justWorksNotMain() throws InterruptedException, 
ExecutionException {
+
+               for(int j = 0; j < 2; j++) {
+
+                       for(int i = 4; i < 10; i++) {
+                               ExecutorService p = CommonThreadPool.get(i);
+                               final Integer l = i;
+                               assertEquals(l, p.submit(() -> l).get());
+                               p.shutdown();
+
+                       }
+               }
+       }
+
+       @Test
+       public void justWorksShutdownNow() throws InterruptedException, 
ExecutionException {
+
+               String name = Thread.currentThread().getName();
+               Thread.currentThread().setName("main");
+               for(int j = 0; j < 2; j++) {
+
+                       for(int i = 4; i < 16; i++) {
+                               ExecutorService p = CommonThreadPool.get(i);
+                               final Integer l = i;
+                               assertEquals(l, p.submit(() -> l).get());
+                               p.shutdownNow();
+
+                       }
+               }
+               Thread.currentThread().setName(name);
+       }
+
+       @Test
+       public void justWorksShutdownNowNotMain() throws InterruptedException, 
ExecutionException {
+
+               for(int j = 0; j < 2; j++) {
+
+                       for(int i = 4; i < 16; i++) {
+                               ExecutorService p = CommonThreadPool.get(i);
+                               final Integer l = i;
+                               assertEquals(l, p.submit(() -> l).get());
+                               p.shutdownNow();
+
+                       }
+               }
+       }
+
+       @Test
+       public void mock1() throws NoSuchFieldException, SecurityException, 
IllegalArgumentException, IllegalAccessException,
+               InterruptedException, ExecutionException, TimeoutException {
+
+               ExecutorService p = mock(ExecutorService.class);
+               ExecutorService c = new CommonThreadPool(p);
+
+               when(p.shutdownNow()).thenReturn(null);
+               assertNull(c.shutdownNow());
+
+               Collection<Callable<Integer>> cc = 
(Collection<Callable<Integer>>) null;
+               when(p.invokeAll(cc)).thenReturn(null);
+               assertNull(c.invokeAll(cc));
+               when(p.invokeAll(cc, 1L, TimeUnit.DAYS)).thenReturn(null);
+               assertNull(c.invokeAll(cc, 1, TimeUnit.DAYS));
+               doNothing().when(p).execute((Runnable) null);
+               c.execute((Runnable) null);
+
+               when(p.submit((Callable<Integer>) null)).thenReturn(null);
+               assertNull(c.submit((Callable<Integer>) null));
+
+               when(p.submit((Runnable) null, null)).thenReturn(null);
+               assertNull(c.submit((Runnable) null, null));
+               // when(tp.pool()).thenReturn(p);
+
+               when(p.submit((Runnable) null)).thenReturn(null);
+               assertNull(c.submit((Runnable) null));
+
+               when(p.isShutdown()).thenReturn(false);
+               assertFalse(c.isShutdown());
+               when(p.isShutdown()).thenReturn(true);
+               assertTrue(c.isShutdown());
+
+               when(p.isTerminated()).thenReturn(false);
+               assertFalse(c.isTerminated());
+               when(p.isTerminated()).thenReturn(true);
+               assertTrue(c.isTerminated());
+
+               when(p.awaitTermination(10, TimeUnit.DAYS)).thenReturn(false);
+               assertFalse(c.awaitTermination(10, TimeUnit.DAYS));
+               when(p.awaitTermination(10, TimeUnit.DAYS)).thenReturn(true);
+               assertTrue(c.awaitTermination(10, TimeUnit.DAYS));
+
+               when(p.invokeAny(cc)).thenReturn(null);
+               assertNull(c.invokeAny(cc));
+               when(p.invokeAny(cc, 1L, TimeUnit.DAYS)).thenReturn(null);
+               assertNull(c.invokeAny(cc, 1, TimeUnit.DAYS));
+               doNothing().when(p).execute((Runnable) null);
+               c.execute((Runnable) null);
+
+       }
+
+       @Test
+       public void mock2() throws NoSuchFieldException, SecurityException, 
IllegalArgumentException, IllegalAccessException,
+               InterruptedException, ExecutionException, TimeoutException {
+
+               CommonThreadPool p = mock(CommonThreadPool.class);
+               when(p.isShutdown()).thenCallRealMethod();
+               when(p.isTerminated()).thenCallRealMethod();
+               when(p.awaitTermination(10, 
TimeUnit.DAYS)).thenCallRealMethod();
+               when(p.isCached()).thenReturn(true);
+               assertTrue(p.isShutdown());
+               assertTrue(p.isTerminated());
+               assertTrue(p.awaitTermination(10, TimeUnit.DAYS));
+       }
+
+       @Test
+       public void coverEdge() {
+               ExecutorService a = 
CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
+               assertTrue(new CommonThreadPool(a).isCached());
+       }
+
+       @Test(expected = DMLRuntimeException.class)
+       public void invokeAndShutdownException() throws InterruptedException {
+               ExecutorService p = mock(ExecutorService.class);
+               ExecutorService c = new CommonThreadPool(p);
+
+               when(p.invokeAll(null)).thenThrow(new RuntimeException("Test"));
+
+               CommonThreadPool.invokeAndShutdown(p, null);
+
+       }
+
+       @Test
+       public void invokeAndShutdown() throws InterruptedException {
+
+               ExecutorService p = mock(ExecutorService.class);
+               ExecutorService c = new CommonThreadPool(p);
+
+               Collection<Callable<Integer>> cc = 
(Collection<Callable<Integer>>) null;
+               when(p.invokeAll(cc)).thenReturn(new 
ArrayList<Future<Integer>>());
+
+               CommonThreadPool.invokeAndShutdown(c, null);
+
+       }
+
+       @Test
+       @SuppressWarnings("all")
+       public void invokeAndShutdownV2() throws InterruptedException{
+               
+               ExecutorService p = mock(ExecutorService.class);
+               ExecutorService c = new CommonThreadPool(p);
+
+               Collection<Callable<Integer>> cc = 
(Collection<Callable<Integer>>) null;
+               List<Future<Integer>> f = new ArrayList<Future<Integer>>();
+               f.add(mock(Future.class));
+               when(p.invokeAll(cc)).thenReturn(f );
+
+               CommonThreadPool.invokeAndShutdown(c, null);
+
+       }
+}


Reply via email to