This is an automated email from the ASF dual-hosted git repository. kinnerebner pushed a commit to branch paramserv_spark_fix in repository https://gitbox.apache.org/repos/asf/systemds.git
commit f586eaa8b95aefc7c67eea379b69405463632447 Author: Kevin Innerebner <[email protected]> AuthorDate: Mon Jul 11 22:47:01 2022 +0200 [MINOR] Fix Spark ParameterServer This patch fixes the Spark execution mode for the parameter server. In commit 28ff18fca2a9258168db7397d56236a5e0d9564b the handling of functions was changed; as a result, the parameter server in Spark mode no longer found the functions or sent them to the workers properly. Closes #1662 --- .../runtime/controlprogram/paramserv/ParamServer.java | 18 ++++++++++-------- .../controlprogram/paramserv/ParamservUtils.java | 3 +++ .../controlprogram/paramserv/SparkPSWorker.java | 3 +++ .../instructions/cp/ParamservBuiltinCPInstruction.java | 4 ++-- .../test/functions/paramserv/ParamservSparkNNTest.java | 5 ++++- 5 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java index 3957965988..e88a19d964 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java @@ -78,7 +78,8 @@ public abstract class ParamServer private int _numWorkers; private int _numBackupWorkers; - private boolean[] _discardWorkerRes; + // number of updates the respective worker is straggling behind + private int[] _numUpdatesStraggling; private boolean _modelAvg; private ListObject _accModels = null; @@ -109,7 +110,7 @@ public abstract class ParamServer _numBatchesPerEpoch = numBatchesPerEpoch; _numWorkers = workerNum; _numBackupWorkers = numBackupWorkers; - _discardWorkerRes = new boolean[workerNum]; + _numUpdatesStraggling = new int[workerNum]; _modelAvg = modelAvg; // broadcast initial model @@ -118,6 +119,8 @@ public abstract class ParamServer protected void setupAggFunc(ExecutionContext ec, String aggFunc) { String[] cfn = 
DMLProgram.splitFunctionKey(aggFunc); + if(cfn.length == 1) + cfn = new String[] {null, cfn[0]}; String ns = cfn[0]; String fname = cfn[1]; boolean opt = !ec.getProgram().containsFunctionProgramBlock(ns, fname, false); @@ -240,10 +243,10 @@ public abstract class ParamServer break; } case SBP: { - if(_discardWorkerRes[workerID]) { + if(_numUpdatesStraggling[workerID] > 0) { LOG.info("[+] PRAMSERV: discarding result of backup-worker/straggler " + workerID); broadcastModel(workerID); - _discardWorkerRes[workerID] = false; + _numUpdatesStraggling[workerID]--; break; } setFinishedState(workerID); @@ -255,7 +258,6 @@ public abstract class ParamServer updateGlobalModel(gradients); if(enoughFinished()) { - // set flags to throwaway backup worker results tagStragglers(); performGlobalGradientUpdate(); } @@ -300,7 +302,7 @@ public abstract class ParamServer private void tagStragglers() { for(int i = 0; i < _finishedStates.length; ++i) { if(!_finishedStates[i]) - _discardWorkerRes[i] = true; + _numUpdatesStraggling[i]++; } } @@ -371,10 +373,10 @@ public abstract class ParamServer case SBP: { // first weight the models based on number of workers ListObject weightParams = weightModels(model, _numWorkers - _numBackupWorkers); - if(_discardWorkerRes[workerID]) { + if(_numUpdatesStraggling[workerID] > 0) { LOG.info("[+] PRAMSERV: discarding result of backup-worker/straggler " + workerID); broadcastModel(workerID); - _discardWorkerRes[workerID] = false; + _numUpdatesStraggling[workerID]--; break; } setFinishedState(workerID); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java index cfc3a200a5..2a6877d89e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java @@ -268,7 +268,10 @@ public class ParamservUtils { String[] parts 
= DMLProgram.splitFunctionKey(e.getKey()); FunctionProgramBlock fpb = ProgramConverter .createDeepCopyFunctionProgramBlock(e.getValue(), new HashSet<>(), new HashSet<>()); + fpb._namespace = parts[0]; + fpb._functionName = parts[1]; newProg.addFunctionProgramBlock(parts[0], parts[1], fpb, opt); + newProg.addProgramBlock(fpb); } return newProg; } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java index 9e96b45a5b..7823d8811c 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java @@ -76,6 +76,9 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2< _nEpochs = aEpochs; _nbatches = nbatches; _modelAvg = modelAvg; + + // make SparkPSWorker serializable + _tpool = null; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 1fa83b2a8d..ef45a9c2b3 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -661,10 +661,10 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc private int getNumBackupWorkers() { if(!getParameterMap().containsKey(PS_NUM_BACKUP_WORKERS)) { - if (!getUpdateType().isSBP()) - LOG.warn("Specifying number of backup-workers without SBP mode has no effect"); return DEFAULT_NUM_BACKUP_WORKERS; } + if (!getUpdateType().isSBP()) + LOG.warn("Specifying number of backup-workers without SBP mode has no effect"); return Integer.parseInt(getParam(PS_NUM_BACKUP_WORKERS)); } diff --git 
a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java index c7f0e39dff..630c3c1ebd 100644 --- a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java +++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java @@ -29,7 +29,6 @@ import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; @net.jcip.annotations.NotThreadSafe -@Ignore public class ParamservSparkNNTest extends AutomatedTestBase { private static final String TEST_NAME1 = "paramserv-test"; @@ -77,12 +76,16 @@ public class ParamservSparkNNTest extends AutomatedTestBase { } @Test + @Ignore public void testParamservWorkerFailed() { + // FIXME: `aggregation` function can't be found (optimized away?) runDMLTest(TEST_NAME2, true, DMLRuntimeException.class, "Invalid indexing by name in unnamed list: worker_err."); } @Test + @Ignore public void testParamservAggServiceFailed() { + // FIXME: `aggregation` function can't be found (optimized away?) runDMLTest(TEST_NAME3, true, DMLRuntimeException.class, "Invalid indexing by name in unnamed list: agg_service_err."); }
