This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new b6640d9 [SYSTEMDS-2550] Extended parameter server (validation
function, stats)
b6640d9 is described below
commit b6640d93011e9bd1fa986c3e92ca7b2a9d8a276b
Author: Tobias Rieger <[email protected]>
AuthorDate: Sat Jan 30 22:08:50 2021 +0100
[SYSTEMDS-2550] Extended parameter server (validation function, stats)
Closes #1154.
---
.../ParameterizedBuiltinFunctionExpression.java | 3 +-
.../java/org/apache/sysds/parser/Statement.java | 2 +-
.../paramserv/FederatedPSControlThread.java | 94 +++++++++-------
.../controlprogram/paramserv/LocalParamServer.java | 16 ++-
.../runtime/controlprogram/paramserv/PSWorker.java | 1 -
.../controlprogram/paramserv/ParamServer.java | 125 +++++++++++++++++++--
.../runtime/controlprogram/parfor/stat/Timing.java | 23 ++--
.../cp/ParamservBuiltinCPInstruction.java | 57 ++++++++--
.../java/org/apache/sysds/utils/Statistics.java | 66 ++++++++---
.../paramserv/FederatedParamservTest.java | 40 ++++---
.../scripts/functions/federated/paramserv/CNN.dml | 48 +++++---
.../federated/paramserv/FederatedParamservTest.dml | 14 ++-
.../functions/federated/paramserv/TwoNN.dml | 18 ++-
13 files changed, 372 insertions(+), 135 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 05bfc48..583c643 100644
---
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -288,7 +288,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
//check for invalid parameters
Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL,
Statement.PS_FEATURES, Statement.PS_LABELS,
Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS,
Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
- Statement.PS_MODE, Statement.PS_UPDATE_TYPE,
Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
+ Statement.PS_VAL_FUN, Statement.PS_MODE,
Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM,
Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
Statement.PS_FED_WEIGHING, Statement.PS_HYPER_PARAMS,
Statement.PS_CHECKPOINTING, Statement.PS_SEED);
checkInvalidParameters(getOpCode(), getVarParams(), valid);
@@ -301,6 +301,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
checkDataValueType(true, fname, Statement.PS_VAL_LABELS,
DataType.MATRIX, ValueType.FP64, conditional);
checkDataValueType(false, fname, Statement.PS_UPDATE_FUN,
DataType.SCALAR, ValueType.STRING, conditional);
checkDataValueType(false, fname, Statement.PS_AGGREGATION_FUN,
DataType.SCALAR, ValueType.STRING, conditional);
+ checkDataValueType(true, fname, Statement.PS_VAL_FUN,
DataType.SCALAR, ValueType.STRING, conditional);
checkStringParam(true, fname, Statement.PS_MODE, conditional);
checkStringParam(true, fname, Statement.PS_UPDATE_TYPE,
conditional);
checkStringParam(true, fname, Statement.PS_FREQUENCY,
conditional);
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java
b/src/main/java/org/apache/sysds/parser/Statement.java
index 9104246..38d16cd 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -66,6 +66,7 @@ public abstract class Statement implements ParseInfo
public static final String PS_LABELS = "labels";
public static final String PS_VAL_FEATURES = "val_features";
public static final String PS_VAL_LABELS = "val_labels";
+ public static final String PS_VAL_FUN = "val";
public static final String PS_UPDATE_FUN = "upd";
public static final String PS_AGGREGATION_FUN = "agg";
public static final String PS_MODE = "mode";
@@ -117,7 +118,6 @@ public abstract class Statement implements ParseInfo
public static final String PS_FED_NAMESPACE = "1701-NCC-namespace";
public static final String PS_FED_GRADIENTS_FNAME =
"1701-NCC-gradients_fname";
public static final String PS_FED_AGGREGATION_FNAME =
"1701-NCC-aggregation_fname";
- public static final String PS_FED_BATCHCOUNTER_VARID =
"1701-NCC-batchcounter_varid";
public static final String PS_FED_MODEL_VARID = "1701-NCC-model_varid";
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 98bc91a..13e029c 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -23,6 +23,7 @@ import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.Statement.PSFrequency;
@@ -45,6 +46,7 @@ import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.cp.ListObject;
@@ -53,9 +55,10 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.util.ProgramConverter;
+import org.apache.sysds.utils.Statistics;
import java.util.ArrayList;
-import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
@@ -69,16 +72,15 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
private FederatedData _featuresData;
private FederatedData _labelsData;
- private final long _localStartBatchNumVarID;
private final long _modelVarID;
// runtime balancing
- private PSRuntimeBalancing _runtimeBalancing;
+ private final PSRuntimeBalancing _runtimeBalancing;
private int _numBatchesPerEpoch;
private int _possibleBatchesPerLocalEpoch;
- private boolean _weighing;
+ private final boolean _weighing;
private double _weighingFactor = 1;
- private boolean _cycleStartAt0 = false;
+ private final boolean _cycleStartAt0 = false;
public FederatedPSControlThread(int workerID, String updFunc,
Statement.PSFrequency freq,
PSRuntimeBalancing runtimeBalancing, boolean weighing, int
epochs, long batchSize,
@@ -89,8 +91,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
_numBatchesPerEpoch = numBatchesPerGlobalEpoch;
_runtimeBalancing = runtimeBalancing;
_weighing = weighing;
- // generate the IDs for model and batch counter. These get
overwritten on the federated worker each time
- _localStartBatchNumVarID = FederationUtils.getNextFedDataID();
+ // generate the ID for the model
_modelVarID = FederationUtils.getNextFedDataID();
}
@@ -100,6 +101,8 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
* @param weighingFactor Gradients from this worker will be multiplied
by this factor if weighing is enabled
*/
public void setup(double weighingFactor) {
+ incWorkerNumber();
+
// prepare features and labels
_featuresData = (FederatedData)
_features.getFedMapping().getMap().values().toArray()[0];
_labelsData = (FederatedData)
_labels.getFedMapping().getMap().values().toArray()[0];
@@ -125,9 +128,11 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
_numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;
}
- LOG.info("Setup config for worker " + this.getWorkerName());
- LOG.info("Batch size: " + _batchSize + " possible batches: " +
_possibleBatchesPerLocalEpoch
- + " batches to run: " + _numBatchesPerEpoch + "
weighing factor: " + _weighingFactor);
+ if( LOG.isInfoEnabled() ) {
+ LOG.info("Setup config for worker " +
this.getWorkerName());
+ LOG.info("Batch size: " + _batchSize + " possible
batches: " + _possibleBatchesPerLocalEpoch
+ + " batches to run: " +
_numBatchesPerEpoch + " weighing factor: " + _weighingFactor);
+ }
// serialize program
// create program blocks for the instruction filtering
@@ -135,12 +140,12 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
ArrayList<ProgramBlock> pbs = new ArrayList<>();
BasicProgramBlock gradientProgramBlock = new
BasicProgramBlock(_ec.getProgram());
- gradientProgramBlock.setInstructions(new
ArrayList<>(Arrays.asList(_inst)));
+ gradientProgramBlock.setInstructions(new
ArrayList<>(Collections.singletonList(_inst)));
pbs.add(gradientProgramBlock);
if(_freq == PSFrequency.EPOCH) {
BasicProgramBlock aggProgramBlock = new
BasicProgramBlock(_ec.getProgram());
- aggProgramBlock.setInstructions(new
ArrayList<>(Arrays.asList(_ps.getAggInst())));
+ aggProgramBlock.setInstructions(new
ArrayList<>(Collections.singletonList(_ps.getAggInst())));
pbs.add(aggProgramBlock);
}
@@ -160,7 +165,6 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
_inst.getFunctionName(),
_ps.getAggInst().getFunctionName(),
_ec.getListObject("hyperparams"),
- _localStartBatchNumVarID,
_modelVarID
)
));
@@ -188,12 +192,11 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
private final String _gradientsFunctionName;
private final String _aggregationFunctionName;
private final ListObject _hyperParams;
- private final long _batchCounterVarID;
private final long _modelVarID;
protected SetupFederatedWorker(long batchSize, long dataSize,
int possibleBatchesPerLocalEpoch,
String programString, String namespace, String
gradientsFunctionName, String aggregationFunctionName,
- ListObject hyperParams, long batchCounterVarID, long
modelVarID)
+ ListObject hyperParams, long modelVarID)
{
super(new long[]{});
_batchSize = batchSize;
@@ -204,7 +207,6 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
_gradientsFunctionName = gradientsFunctionName;
_aggregationFunctionName = aggregationFunctionName;
_hyperParams = hyperParams;
- _batchCounterVarID = batchCounterVarID;
_modelVarID = modelVarID;
}
@@ -221,7 +223,6 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
ec.setVariable(Statement.PS_FED_GRADIENTS_FNAME, new
StringObject(_gradientsFunctionName));
ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new
StringObject(_aggregationFunctionName));
ec.setVariable(Statement.PS_HYPER_PARAMS, _hyperParams);
- ec.setVariable(Statement.PS_FED_BATCHCOUNTER_VARID, new
IntObject(_batchCounterVarID));
ec.setVariable(Statement.PS_FED_MODEL_VARID, new
IntObject(_modelVarID));
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
@@ -272,7 +273,6 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
ec.removeVariable(Statement.PS_FED_NAMESPACE);
ec.removeVariable(Statement.PS_FED_GRADIENTS_FNAME);
ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME);
- ec.removeVariable(Statement.PS_FED_BATCHCOUNTER_VARID);
ec.removeVariable(Statement.PS_FED_MODEL_VARID);
ParamservUtils.cleanupListObject(ec,
Statement.PS_HYPER_PARAMS);
@@ -319,9 +319,10 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
return _ps.pull(_workerID);
}
- protected void scaleAndPushGradients(ListObject gradients) {
+ protected void weighAndPushGradients(ListObject gradients) {
// scale gradients - must only include MatrixObjects
if(_weighing && _weighingFactor != 1) {
+ Timing tWeighing = DMLScript.STATISTICS ? new
Timing(true) : null;
gradients.getData().parallelStream().forEach((matrix)
-> {
MatrixObject matrixObject = (MatrixObject)
matrix;
MatrixBlock input =
matrixObject.acquireReadAndRelease().scalarOperations(
@@ -329,6 +330,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
matrixObject.acquireModify(input);
matrixObject.release();
});
+ accFedPSGradientWeighingTime(tWeighing);
}
// Push the gradients to ps
@@ -350,12 +352,10 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
int localStartBatchNum =
getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
ListObject model = pullModel();
ListObject gradients =
computeGradientsForNBatches(model, 1, localStartBatchNum);
- scaleAndPushGradients(gradients);
+ weighAndPushGradients(gradients);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
- LOG.info("[+] " + this.getWorkerName() + "
completed BATCH " + localStartBatchNum);
}
- LOG.info("[+] " + this.getWorkerName() + " ---
completed EPOCH " + epochCounter);
}
}
@@ -376,9 +376,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
// Pull the global parameters from ps
ListObject model = pullModel();
ListObject gradients =
computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum,
true);
- scaleAndPushGradients(gradients);
-
- LOG.info("[+] " + this.getWorkerName() + " ---
completed EPOCH " + epochCounter);
+ weighAndPushGradients(gradients);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
}
@@ -401,15 +399,13 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
protected ListObject computeGradientsForNBatches(ListObject model,
int numBatchesToCompute, int localStartBatchNum, boolean
localUpdate)
{
- // put local start batch num on federated worker
- Future<FederatedResponse> putBatchCounterResponse =
_featuresData.executeFederatedOperation(
- new FederatedRequest(RequestType.PUT_VAR,
_localStartBatchNumVarID, new IntObject(localStartBatchNum)));
+ Timing tFedCommunication = DMLScript.STATISTICS ? new
Timing(true) : null;
// put current model on federated worker
Future<FederatedResponse> putParamsResponse =
_featuresData.executeFederatedOperation(
new FederatedRequest(RequestType.PUT_VAR, _modelVarID,
model));
try {
- if(!putParamsResponse.get().isSuccessful() ||
!putBatchCounterResponse.get().isSuccessful())
+ if(!putParamsResponse.get().isSuccessful())
throw new
DMLRuntimeException("FederatedLocalPSThread: put was not successful");
}
catch(Exception e) {
@@ -420,14 +416,22 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
Future<FederatedResponse> udfResponse =
_featuresData.executeFederatedOperation(
new FederatedRequest(RequestType.EXEC_UDF,
_featuresData.getVarID(),
new federatedComputeGradientsForNBatches(new
long[]{_featuresData.getVarID(), _labelsData.getVarID(),
- _localStartBatchNumVarID, _modelVarID},
numBatchesToCompute,localUpdate)
+ _modelVarID}, numBatchesToCompute, localUpdate,
localStartBatchNum)
));
try {
Object[] responseData = udfResponse.get().getData();
+ if(DMLScript.STATISTICS) {
+ long total = (long) tFedCommunication.stop();
+ long workerComputing = ((DoubleObject)
responseData[1]).getLongValue();
+
Statistics.accFedPSWorkerComputing(workerComputing);
+ Statistics.accFedPSCommunicationTime(total -
workerComputing);
+ }
return (ListObject) responseData[0];
}
catch(Exception e) {
+ if(DMLScript.STATISTICS)
+ tFedCommunication.stop();
throw new DMLRuntimeException("FederatedLocalPSThread:
failed to execute UDF" + e.getMessage());
}
}
@@ -439,20 +443,22 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
private static final long serialVersionUID =
-3075901536748794832L;
int _numBatchesToCompute;
boolean _localUpdate;
+ int _localStartBatchNum;
- protected federatedComputeGradientsForNBatches(long[] inIDs,
int numBatchesToCompute, boolean localUpdate) {
+ protected federatedComputeGradientsForNBatches(long[] inIDs,
int numBatchesToCompute, boolean localUpdate, int localStartBatchNum) {
super(inIDs);
_numBatchesToCompute = numBatchesToCompute;
_localUpdate = localUpdate;
+ _localStartBatchNum = localStartBatchNum;
}
@Override
public FederatedResponse execute(ExecutionContext ec, Data...
data) {
+ Timing tGradients = new Timing(true);
// read in data by varid
MatrixObject features = (MatrixObject) data[0];
MatrixObject labels = (MatrixObject) data[1];
- int localStartBatchNum = (int) ((IntObject)
data[2]).getLongValue();
- ListObject model = (ListObject) data[3];
+ ListObject model = (ListObject) data[2];
// get data from execution context
long batchSize = ((IntObject)
ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
@@ -493,7 +499,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
}
ListObject accGradients = null;
- int currentLocalBatchNumber = localStartBatchNum;
+ int currentLocalBatchNumber = _localStartBatchNum;
// prepare execution context
ec.setVariable(Statement.PS_MODEL, model);
for (int batchCounter = 0; batchCounter <
_numBatchesToCompute; batchCounter++) {
@@ -534,14 +540,14 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
ParamservUtils.cleanupListObject(ec,
gradientsOutput.getName());
ParamservUtils.cleanupData(ec,
Statement.PS_FEATURES);
ParamservUtils.cleanupData(ec,
Statement.PS_LABELS);
-
ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
}
// model clean up
ParamservUtils.cleanupListObject(ec,
ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
ParamservUtils.cleanupListObject(ec,
Statement.PS_MODEL);
-
- return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, accGradients);
+ // stop timing
+ DoubleObject gradientsTime = new
DoubleObject(tGradients.stop());
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new
Object[]{accGradients, gradientsTime});
}
@Override
@@ -551,6 +557,11 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
}
// Statistics methods
+ protected void accFedPSGradientWeighingTime(Timing time) {
+ if (DMLScript.STATISTICS && time != null)
+ Statistics.accFedPSGradientWeighingTime((long)
time.stop());
+ }
+
@Override
public String getWorkerName() {
return String.format("Federated worker_%d", _workerID);
@@ -558,21 +569,22 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
@Override
protected void incWorkerNumber() {
-
+ if (DMLScript.STATISTICS)
+ Statistics.incWorkerNumber();
}
@Override
protected void accLocalModelUpdateTime(Timing time) {
-
+ throw new NotImplementedException();
}
@Override
protected void accBatchIndexingTime(Timing time) {
-
+ throw new NotImplementedException();
}
@Override
protected void accGradientComputeTime(Timing time) {
-
+ throw new NotImplementedException();
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
index 29193b0..7bd96f2 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.controlprogram.paramserv;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.ListObject;
@@ -30,12 +31,19 @@ public class LocalParamServer extends ParamServer {
super();
}
- public static LocalParamServer create(ListObject model, String aggFunc,
Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
- return new LocalParamServer(model, aggFunc, updateType, ec,
workerNum);
+ public static LocalParamServer create(ListObject model, String aggFunc,
Statement.PSUpdateType updateType,
+ Statement.PSFrequency freq, ExecutionContext ec, int workerNum,
String valFunc, int numBatchesPerEpoch,
+ MatrixObject valFeatures, MatrixObject valLabels)
+ {
+ return new LocalParamServer(model, aggFunc, updateType, freq,
ec,
+ workerNum, valFunc, numBatchesPerEpoch, valFeatures,
valLabels);
}
- private LocalParamServer(ListObject model, String aggFunc,
Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
- super(model, aggFunc, updateType, ec, workerNum);
+ private LocalParamServer(ListObject model, String aggFunc,
Statement.PSUpdateType updateType,
+ Statement.PSFrequency freq, ExecutionContext ec, int workerNum,
String valFunc, int numBatchesPerEpoch,
+ MatrixObject valFeatures, MatrixObject valLabels)
+ {
+ super(model, aggFunc, updateType, freq, ec, workerNum, valFunc,
numBatchesPerEpoch, valFeatures, valLabels);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index 701e45c..c0389f3 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -35,7 +35,6 @@ import
org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
-// TODO use the validate features and labels to calculate the model precision
when training
public abstract class PSWorker implements Serializable
{
private static final long serialVersionUID = -3510485051178200118L;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index e420ed8..6315ef9 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -38,9 +38,11 @@ import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.utils.Statistics;
@@ -57,15 +59,30 @@ public abstract class ParamServer
//aggregation service
protected ExecutionContext _ec;
private Statement.PSUpdateType _updateType;
+ private Statement.PSFrequency _freq;
private FunctionCallCPInstruction _inst;
private String _outputName;
private boolean[] _finishedStates; // Workers' finished states
private ListObject _accGradients = null;
+ private boolean _validationPossible;
+ private FunctionCallCPInstruction _valInst;
+ private String _lossOutput;
+ private String _accuracyOutput;
+
+ private int _syncCounter = 0;
+ private int _epochCounter = 0 ;
+ private int _numBatchesPerEpoch;
+
+ private int _numWorkers;
+
protected ParamServer() {}
- protected ParamServer(ListObject model, String aggFunc,
Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
+ protected ParamServer(ListObject model, String aggFunc,
Statement.PSUpdateType updateType,
+ Statement.PSFrequency freq, ExecutionContext ec, int workerNum,
String valFunc,
+ int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject
valLabels)
+ {
// init worker queues and global model
_modelMap = new HashMap<>(workerNum);
IntStream.range(0, workerNum).forEach(i -> {
@@ -77,8 +94,15 @@ public abstract class ParamServer
// init aggregation service
_ec = ec;
_updateType = updateType;
+ _freq = freq;
_finishedStates = new boolean[workerNum];
setupAggFunc(_ec, aggFunc);
+
+ if(valFunc != null && numBatchesPerEpoch > 0) {
+ setupValFunc(_ec, valFunc, valFeatures, valLabels);
+ }
+ _numBatchesPerEpoch = numBatchesPerEpoch;
+ _numWorkers = workerNum;
// broadcast initial model
broadcastModel(true);
@@ -110,6 +134,39 @@ public abstract class ParamServer
func.getInputParamNames(), outputNames, "aggregate
function");
}
+ protected void setupValFunc(ExecutionContext ec, String valFunc,
MatrixObject valFeatures, MatrixObject valLabels) {
+ String[] cfn = DMLProgram.splitFunctionKey(valFunc);
+ String ns = cfn[0];
+ String fname = cfn[1];
+ FunctionProgramBlock func =
ec.getProgram().getFunctionProgramBlock(ns, fname, false);
+ ArrayList<DataIdentifier> inputs = func.getInputParams();
+ ArrayList<DataIdentifier> outputs = func.getOutputParams();
+
+ // Check the output of the validate function
+ if (outputs.size() != 2) {
+ throw new DMLRuntimeException(String.format("The output
of the '%s' function should provide the loss and the accuracy in that order",
valFunc));
+ }
+ if (outputs.get(0).getDataType() != DataType.SCALAR ||
outputs.get(1).getDataType() != DataType.SCALAR) {
+ throw new DMLRuntimeException(String.format("The
outputs of the '%s' function should both be scalars", valFunc));
+ }
+ _lossOutput = outputs.get(0).getName();
+ _accuracyOutput = outputs.get(1).getName();
+
+ CPOperand[] boundInputs = inputs.stream()
+ .map(input -> new CPOperand(input.getName(),
input.getValueType(), input.getDataType()))
+ .toArray(CPOperand[]::new);
+ ArrayList<String> outputNames =
outputs.stream().map(DataIdentifier::getName)
+ .collect(Collectors.toCollection(ArrayList::new));
+ _valInst = new FunctionCallCPInstruction(ns, fname, false,
boundInputs,
+ func.getInputParamNames(), outputNames, "validate
function");
+
+ // write validation data to execution context. hyper params are
already in ec
+ _ec.setVariable(Statement.PS_VAL_FEATURES, valFeatures);
+ _ec.setVariable(Statement.PS_VAL_LABELS, valLabels);
+
+ _validationPossible = true;
+ }
+
public abstract void push(int workerID, ListObject value);
public abstract ListObject pull(int workerID);
@@ -119,7 +176,7 @@ public abstract class ParamServer
// so we could return directly the result model
return _model;
}
-
+
protected synchronized void updateGlobalModel(int workerID, ListObject
gradients) {
try {
if (LOG.isDebugEnabled()) {
@@ -143,6 +200,22 @@ public abstract class ParamServer
updateGlobalModel(_accGradients);
_accGradients = null;
}
+
+ 			// This if has grown to be
quite complex, but its function is rather simple: validate at the end of each epoch.
+ // In the BSP batch case that
occurs after the sync counter reaches the number of batches and in the
+ // BSP epoch case every time
+ if ((_freq ==
Statement.PSFrequency.EPOCH ||
+ (_freq ==
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+
+ if(LOG.isInfoEnabled())
+ LOG.info("[+]
PARAMSERV: completed EPOCH " + _epochCounter);
+
+ if(_validationPossible)
+ validate();
+
+ _epochCounter++;
+ _syncCounter = 0;
+ }
// Broadcast the updated model
resetFinishedStates();
@@ -154,6 +227,21 @@ public abstract class ParamServer
}
case ASP: {
updateGlobalModel(gradients);
+ 				// This if works similarly to the one
for BSP, but divides the sync counter by the number of workers,
+ creating "Pseudo Epochs"
+ if ((_freq ==
Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
+ (_freq ==
Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float)
_numBatchesPerEpoch == 0)) {
+
+ if(LOG.isInfoEnabled())
+ LOG.info("[+]
PARAMSERV: completed PSEUDO EPOCH (ASP) " + _epochCounter);
+
+ if(_validationPossible)
+ validate();
+
+ _epochCounter++;
+ _syncCounter = 0;
+ }
+
broadcastModel(workerID);
break;
}
@@ -162,14 +250,14 @@ public abstract class ParamServer
}
}
catch (Exception e) {
- throw new DMLRuntimeException("Aggregation service
failed: ", e);
+ throw new DMLRuntimeException("Aggregation or
validation service failed: ", e);
}
}
private void updateGlobalModel(ListObject gradients) {
Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
_model = updateLocalModel(_ec, gradients, _model);
- if (DMLScript.STATISTICS)
+ if (DMLScript.STATISTICS && tAgg != null)
Statistics.accPSAggregationTime((long) tAgg.stop());
}
@@ -226,14 +314,37 @@ public abstract class ParamServer
private void broadcastModel(int workerID) throws InterruptedException {
Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null;
-
//broadcast copy of model to specific worker, cleaned up by
worker
_modelMap.get(workerID).put(ParamservUtils.copyList(_model,
false));
-
- if (DMLScript.STATISTICS)
+ if (DMLScript.STATISTICS && tBroad != null)
Statistics.accPSModelBroadcastTime((long)
tBroad.stop());
}
+ /**
+ * Checks the current model against the validation set
+ */
+ private synchronized void validate() {
+ Timing tValidate = DMLScript.STATISTICS ? new Timing(true) :
null;
+ _ec.setVariable(Statement.PS_MODEL, _model);
+
+ // Invoke the validation function
+ _valInst.processInstruction(_ec);
+
+ // Get the validation results
+ double loss = ((DoubleObject)
_ec.getVariable(_lossOutput)).getDoubleValue();
+ double accuracy = ((DoubleObject)
_ec.getVariable(_accuracyOutput)).getDoubleValue();
+
+ // cleanup
+ ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
+
+ // Log validation results
+ if(LOG.isInfoEnabled())
+ LOG.info("[+] PARAMSERV: validation-loss: " + loss + "
validation-accuracy: " + accuracy);
+
+ if(tValidate != null)
+ Statistics.accPSValidationTime((long) tValidate.stop());
+ }
+
public FunctionCallCPInstruction getAggInst() {
return _inst;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/Timing.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/Timing.java
index aec117a..79830ce 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/Timing.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/Timing.java
@@ -26,27 +26,22 @@ package org.apache.sysds.runtime.controlprogram.parfor.stat;
*/
public class Timing
{
-
-
private long _start = -1;
public Timing() {
//default constructor
}
- public Timing(boolean start)
- {
+ public Timing(boolean start) {
//init and start the timer
- if( start ){
+ if( start )
start();
- }
}
/**
* Starts the time measurement.
*/
- public void start()
- {
+ public void start() {
_start = System.nanoTime();
}
@@ -56,16 +51,15 @@ public class Timing
*
* @return duration between start and stop
*/
- public double stop()
- {
+ public double stop() {
if( _start == -1 )
throw new RuntimeException("Stop time measurement
without prior start is invalid.");
- long end = System.nanoTime();
+ long end = System.nanoTime();
double duration = ((double)(end-_start))/1000000;
//carry end time over
- _start = end;
+ _start = end;
return duration;
}
@@ -73,11 +67,8 @@ public class Timing
* Measures and returns the time since the last start() or stop()
invocation,
* restarts the measurement, and prints the last measurement to STDOUT.
*/
- public void stopAndPrint()
- {
+ public void stopAndPrint() {
double tmp = stop();
-
System.out.println("PARFOR: time = "+tmp+"ms");
}
-
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index a66e039..785915d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -46,6 +46,9 @@ import static
org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING;
import static org.apache.sysds.parser.Statement.PS_FED_WEIGHING;
import static org.apache.sysds.parser.Statement.PS_SEED;
+import static org.apache.sysds.parser.Statement.PS_VAL_FEATURES;
+import static org.apache.sysds.parser.Statement.PS_VAL_LABELS;
+import static org.apache.sysds.parser.Statement.PS_VAL_FUN;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.commons.logging.Log;
@@ -123,12 +126,15 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
}
private void runFederated(ExecutionContext ec) {
+ Timing tExecutionTime = DMLScript.STATISTICS ? new Timing(true)
: null;
+ Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
LOG.info("PARAMETER SERVER");
LOG.info("[+] Running in federated mode");
// get inputs
String updFunc = getParam(PS_UPDATE_FUN);
String aggFunc = getParam(PS_AGGREGATION_FUN);
+ String valFunc = getValFunction();
PSUpdateType updateType = getUpdateType();
PSFrequency freq = getFrequency();
FederatedPSScheme federatedPSScheme = getFederatedScheme();
@@ -144,17 +150,24 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
LOG.info("[+] Weighing: " + weighing);
LOG.info("[+] Seed: " + seed);
}
-
+ if (tSetup != null)
+ Statistics.accPSSetupTime((long) tSetup.stop());
+
// partition federated data
+ Timing tDataPartitioning = DMLScript.STATISTICS ? new
Timing(true) : null;
DataPartitionFederatedScheme.Result result = new
FederatedDataPartitioner(federatedPSScheme, seed)
.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)),
ec.getMatrixObject(getParam(PS_LABELS)));
int workerNum = result._workerNum;
+ if (DMLScript.STATISTICS)
+ Statistics.accFedPSDataPartitioningTime((long)
tDataPartitioning.stop());
+
+ if (DMLScript.STATISTICS)
+ tSetup.start();
// setup threading
BasicThreadFactory factory = new BasicThreadFactory.Builder()
.namingPattern("workers-pool-thread-%d").build();
ExecutorService es = Executors.newFixedThreadPool(workerNum,
factory);
-
// Get the compiled execution context
LocalVariableMap newVarsMap = createVarsMap(ec);
// Level of par is -1 so each federated worker can scale to its
cpu cores
@@ -165,24 +178,25 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
ExecutionContext aggServiceEC =
ParamservUtils.copyExecutionContext(newEC, 1).get(0);
// Create the parameter server
ListObject model = ec.getListObject(getParam(PS_MODEL));
- ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc,
updateType, workerNum, model, aggServiceEC);
+ ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc,
updateType, freq, workerNum, model, aggServiceEC, valFunc,
+ getNumBatchesPerEpoch(runtimeBalancing,
result._balanceMetrics), ec.getMatrixObject(getParam(PS_VAL_FEATURES)),
ec.getMatrixObject(getParam(PS_VAL_LABELS)));
// Create the local workers
int finalNumBatchesPerEpoch =
getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
List<FederatedPSControlThread> threads = IntStream.range(0,
workerNum)
.mapToObj(i -> new FederatedPSControlThread(i, updFunc,
freq, runtimeBalancing, weighing,
getEpochs(), getBatchSize(),
finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
.collect(Collectors.toList());
-
if(workerNum != threads.size()) {
throw new
DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning
does not match threads!");
}
-
// Set features and lables for the control threads and write
the program and instructions and hyperparams to the federated workers
for (int i = 0; i < threads.size(); i++) {
threads.get(i).setFeatures(result._pFeatures.get(i));
threads.get(i).setLabels(result._pLabels.get(i));
threads.get(i).setup(result._weighingFactors.get(i));
}
+ if (DMLScript.STATISTICS)
+ Statistics.accPSSetupTime((long) tSetup.stop());
try {
// Launch the worker threads and wait for completion
@@ -190,6 +204,8 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
ret.get(); //error handling
// Fetch the final model from ps
ec.setVariable(output.getName(), ps.getResult());
+ if (DMLScript.STATISTICS)
+ Statistics.accPSExecutionTime((long)
tExecutionTime.stop());
} catch (InterruptedException | ExecutionException e) {
throw new
DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e);
} finally {
@@ -215,7 +231,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
// Create the parameter server
ListObject model = sec.getListObject(getParam(PS_MODEL));
- ParamServer ps = createPS(mode, aggFunc, getUpdateType(),
workerNum, model, aggServiceEC);
+ ParamServer ps = createPS(mode, aggFunc, getUpdateType(),
getFrequency(), workerNum, model, aggServiceEC);
// Get driver host
String host =
sec.getSparkContext().getConf().get("spark.driver.host");
@@ -299,7 +315,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
// Create the parameter server
ListObject model = ec.getListObject(getParam(PS_MODEL));
- ParamServer ps = createPS(mode, aggFunc, updateType, workerNum,
model, aggServiceEC);
+ ParamServer ps = createPS(mode, aggFunc, updateType, freq,
workerNum, model, aggServiceEC);
// Create the local workers
List<LocalPSWorker> workers = IntStream.range(0, workerNum)
@@ -436,14 +452,24 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
*
* @return parameter server
*/
- private static ParamServer createPS(PSModeType mode, String aggFunc,
PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) {
+ private static ParamServer createPS(PSModeType mode, String aggFunc,
PSUpdateType updateType,
+ PSFrequency freq, int workerNum, ListObject model,
ExecutionContext ec)
+ {
+ return createPS(mode, aggFunc, updateType, freq, workerNum,
model, ec, null, -1, null, null);
+ }
+
+ // When this creation is used the parameter server is able to validate
after each epoch
+ private static ParamServer createPS(PSModeType mode, String aggFunc,
PSUpdateType updateType,
+ PSFrequency freq, int workerNum, ListObject model,
ExecutionContext ec, String valFunc,
+ int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject
valLabels)
+ {
switch (mode) {
case FEDERATED:
case LOCAL:
case REMOTE_SPARK:
- return LocalParamServer.create(model, aggFunc,
updateType, ec, workerNum);
+ return LocalParamServer.create(model, aggFunc,
updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures,
valLabels);
default:
- throw new DMLRuntimeException("Unsupported
parameter server: "+mode.name());
+ throw new DMLRuntimeException("Unsupported
parameter server: " + mode.name());
}
}
@@ -518,7 +544,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
* @return numBatchesPerEpoch
*/
private int getNumBatchesPerEpoch(PSRuntimeBalancing runtimeBalancing,
DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) {
- int numBatchesPerEpoch = 0;
+ int numBatchesPerEpoch;
if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
numBatchesPerEpoch = (int)
Math.ceil(balanceMetrics._minRows / (float) getBatchSize());
} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG
@@ -526,6 +552,8 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
numBatchesPerEpoch = (int)
Math.ceil(balanceMetrics._avgRows / (float) getBatchSize());
} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
numBatchesPerEpoch = (int)
Math.ceil(balanceMetrics._maxRows / (float) getBatchSize());
+ } else {
+ numBatchesPerEpoch = (int)
Math.ceil(balanceMetrics._avgRows / (float) getBatchSize());
}
return numBatchesPerEpoch;
}
@@ -534,6 +562,13 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
return getParameterMap().containsKey(PS_FED_WEIGHING) &&
Boolean.parseBoolean(getParam(PS_FED_WEIGHING));
}
+ private String getValFunction() {
+ if (getParameterMap().containsKey(PS_VAL_FUN)) {
+ return getParam(PS_VAL_FUN);
+ }
+ return null;
+ }
+
private int getSeed() {
return (getParameterMap().containsKey(PS_SEED)) ?
Integer.parseInt(getParam(PS_SEED)) : (int) System.currentTimeMillis();
}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index b9059d9..b40e905 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -117,6 +117,7 @@ public class Statistics
private static final LongAdder sparkBroadcastCount = new LongAdder();
// Paramserv function stats (time is in milli sec)
+ private static final LongAdder psExecutionTime = new LongAdder();
private static final LongAdder psNumWorkers = new LongAdder();
private static final LongAdder psSetupTime = new LongAdder();
private static final LongAdder psGradientComputeTime = new LongAdder();
@@ -125,6 +126,12 @@ public class Statistics
private static final LongAdder psModelBroadcastTime = new LongAdder();
private static final LongAdder psBatchIndexTime = new LongAdder();
private static final LongAdder psRpcRequestTime = new LongAdder();
+ private static final LongAdder psValidationTime = new LongAdder();
+ // Federated parameter server specifics (time is in milli sec)
+ private static final LongAdder fedPSDataPartitioningTime = new
LongAdder();
+ private static final LongAdder fedPSWorkerComputingTime = new
LongAdder();
+ private static final LongAdder fedPSGradientWeighingTime = new
LongAdder();
+ private static final LongAdder fedPSCommunicationTime = new LongAdder();
//PARFOR optimization stats (low frequency updates)
private static long parforOptTime = 0; //in milli sec
@@ -562,6 +569,10 @@ public class Statistics
psNumWorkers.add(n);
}
+ public static void accPSExecutionTime(long n) {
+ psExecutionTime.add(n);
+ }
+
public static void accPSSetupTime(long t) {
psSetupTime.add(t);
}
@@ -590,6 +601,24 @@ public class Statistics
psRpcRequestTime.add(t);
}
+ public static void accPSValidationTime(long t) {
+ psValidationTime.add(t);
+ }
+
+ public static void accFedPSDataPartitioningTime(long t) {
+ fedPSDataPartitioningTime.add(t);
+ }
+
+ public static void accFedPSWorkerComputing(long t) {
+ fedPSWorkerComputingTime.add(t);
+ }
+
+ public static void accFedPSGradientWeighingTime(long t) {
+ fedPSGradientWeighingTime.add(t);
+ }
+
+ public static void accFedPSCommunicationTime(long t) {
fedPSCommunicationTime.add(t);}
+
public static String getCPHeavyHitterCode( Instruction inst )
{
String opcode = null;
@@ -758,13 +787,13 @@ public class Statistics
if(wrapIter == 0) {
// Display instruction count
sb.append(String.format(
- " %" + maxNumLen + "d
%-" + maxInstLen + "s %" + maxTimeSLen + "s %" + maxCountLen + "d",
- (i + 1), instStr,
timeSString, count));
+ " %" + maxNumLen + "d %-" +
maxInstLen + "s %" + maxTimeSLen + "s %" + maxCountLen + "d",
+ (i + 1), instStr, timeSString,
count));
}
else {
sb.append(String.format(
- " %" + maxNumLen + "s
%-" + maxInstLen + "s %" + maxTimeSLen + "s %" + maxCountLen + "s",
- "", instStr, "", ""));
+ " %" + maxNumLen + "s %-" +
maxInstLen + "s %" + maxTimeSLen + "s %" + maxCountLen + "s",
+ "", instStr, "", ""));
}
sb.append("\n");
}
@@ -795,8 +824,8 @@ public class Statistics
maxNameLength = Math.max(maxNameLength, "Object".length());
StringBuilder res = new StringBuilder();
- res.append(String.format(" %-" + numPadLen + "s" + " %-" +
maxNameLength + "s" + " %s\n",
- "#", "Object", "Memory"));
+ res.append(String.format(" %-" + numPadLen + "s" + " %-"
+ + maxNameLength + "s" + " %s\n", "#", "Object",
"Memory"));
for (int ix = 1; ix <= numHittersToDisplay; ix++) {
String objName = entries[ix-1].getKey();
@@ -831,8 +860,7 @@ public class Statistics
public static long getJITCompileTime(){
long ret = -1; //unsupported
CompilationMXBean cmx =
ManagementFactory.getCompilationMXBean();
- if( cmx.isCompilationTimeMonitoringSupported() )
- {
+ if( cmx.isCompilationTimeMonitoringSupported() ) {
ret = cmx.getTotalCompilationTime();
ret += jitCompileTime; //add from remote processes
}
@@ -1011,14 +1039,26 @@ public class Statistics
sparkCollect.longValue()*1e-9));
}
if (psNumWorkers.longValue() > 0) {
+ sb.append(String.format("Paramserv total
execution time:\t%.3f secs.\n", psExecutionTime.doubleValue() / 1000));
sb.append(String.format("Paramserv total num
workers:\t%d.\n", psNumWorkers.longValue()));
sb.append(String.format("Paramserv setup
time:\t\t%.3f secs.\n", psSetupTime.doubleValue() / 1000));
- sb.append(String.format("Paramserv grad compute
time:\t%.3f secs.\n", psGradientComputeTime.doubleValue() / 1000));
- sb.append(String.format("Paramserv model update
time:\t%.3f/%.3f secs.\n",
+
+ if(fedPSDataPartitioningTime.doubleValue() > 0)
{ //if data partitioning happens this is the federated case
+ sb.append(String.format("PS fed data
partitioning time:\t%.3f secs.\n", fedPSDataPartitioningTime.doubleValue() /
1000));
+ sb.append(String.format("PS fed comm
time (cum):\t\t%.3f secs.\n", fedPSCommunicationTime.doubleValue() / 1000));
+ sb.append(String.format("PS fed worker
comp time (cum):\t%.3f secs.\n", fedPSWorkerComputingTime.doubleValue() /
1000));
+ sb.append(String.format("PS fed grad
weigh time (cum):\t%.3f secs.\n", fedPSGradientWeighingTime.doubleValue() /
1000));
+ sb.append(String.format("PS fed global
model agg time:\t%.3f secs.\n", psAggregationTime.doubleValue() / 1000));
+ }
+ else {
+ sb.append(String.format("Paramserv grad
compute time:\t%.3f secs.\n", psGradientComputeTime.doubleValue() / 1000));
+ sb.append(String.format("Paramserv
model update time:\t%.3f/%.3f secs.\n",
psLocalModelUpdateTime.doubleValue() / 1000, psAggregationTime.doubleValue() /
1000));
- sb.append(String.format("Paramserv model
broadcast time:\t%.3f secs.\n", psModelBroadcastTime.doubleValue() / 1000));
- sb.append(String.format("Paramserv batch slice
time:\t%.3f secs.\n", psBatchIndexTime.doubleValue() / 1000));
- sb.append(String.format("Paramserv RPC request
time:\t%.3f secs.\n", psRpcRequestTime.doubleValue() / 1000));
+ sb.append(String.format("Paramserv
model broadcast time:\t%.3f secs.\n", psModelBroadcastTime.doubleValue() /
1000));
+ sb.append(String.format("Paramserv
batch slice time:\t%.3f secs.\n", psBatchIndexTime.doubleValue() / 1000));
+ sb.append(String.format("Paramserv RPC
request time:\t%.3f secs.\n", psRpcRequestTime.doubleValue() / 1000));
+ }
+			sb.append(String.format("Paramserv validation
time:\t%.3f secs.\n", psValidationTime.doubleValue() / 1000));
}
if( parforOptCount>0 ){
sb.append("ParFor loops optimized:\t\t" +
getParforOptCount() + ".\n");
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index a00e8dc..5d7c7e2 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -64,29 +64,32 @@ public class FederatedParamservTest extends
AutomatedTestBase {
return Arrays.asList(new Object[][] {
// Network type, number of federated workers, data set
size, batch size, epochs, learning rate, update type, update frequency
// basic functionality
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200},
+ //{"TwoNN", 4, 60000, 32, 4, 0.01, "BSP", "BATCH",
"KEEP_DATA_ON_WORKER", "NONE" , "false","BALANCED",
200},
+
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED",
200},
{"CNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "SHUFFLE", "NONE" ,
"true", "IMBALANCED", 200},
{"CNN", 2, 4, 1, 4, 0.01, "ASP",
"BATCH", "REPLICATE_TO_MAX", "RUN_MIN" , "true", "IMBALANCED", 200},
{"TwoNN", 2, 4, 1, 4, 0.01, "ASP",
"EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX" , "true", "IMBALANCED",
200},
{"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH",
"KEEP_DATA_ON_WORKER", "NONE" , "true", "BALANCED",
200},
- /* // runtime balancing
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED", 200},
-
- // data partitioning
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "SHUFFLE", "CYCLE_AVG" , "true",
"IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "REPLICATE_TO_MAX", "NONE" , "true",
"IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "true",
"IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "BALANCE_TO_AVG", "NONE" , "true",
"IMBALANCED", 200},
-
- // balanced tests
- {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH",
"KEEP_DATA_ON_WORKER", "NONE" , "true", "BALANCED",
200} */
-
+ /*
+ // runtime balancing
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED",
200},
+
+ // data partitioning
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "SHUFFLE", "CYCLE_AVG" , "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "REPLICATE_TO_MAX", "NONE" , "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "BALANCE_TO_AVG", "NONE" , "true",
"IMBALANCED", 200},
+
+ // balanced tests
+ {"CNN", 5, 1000, 100, 2, 0.01, "BSP",
"EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "true", "BALANCED",
200}
+ */
});
}
@@ -125,6 +128,7 @@ public class FederatedParamservTest extends
AutomatedTestBase {
}
private void federatedParamserv(ExecMode mode) {
+		// Warning: Statistics accumulate across unit tests
// config
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml
b/src/test/scripts/functions/federated/paramserv/CNN.dml
index 0f9ae63..79628ef 100644
--- a/src/test/scripts/functions/federated/paramserv/CNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -65,13 +65,10 @@ source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
* - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
* - b4: 4th layer biases vector, of shape (1, K)
*/
-train = function(matrix[double] X, matrix[double] y,
- matrix[double] X_val, matrix[double] y_val,
- int epochs, int batch_size, double eta,
- int C, int Hin, int Win,
- int seed = -1)
- return (list[unknown] model) {
-
+train = function(matrix[double] X, matrix[double] y, matrix[double] X_val,
+ matrix[double] y_val, int epochs, int batch_size, double eta, int C, int Hin,
+ int Win, int seed = -1) return (list[unknown] model)
+{
N = nrow(X)
K = ncol(y)
@@ -162,12 +159,11 @@ train = function(matrix[double] X, matrix[double] y,
* - b4: 4th layer biases vector, of shape (1, K)
*/
train_paramserv = function(matrix[double] X, matrix[double] y,
- matrix[double] X_val, matrix[double] y_val,
- int num_workers, int epochs, string utype, string freq, int
batch_size, string scheme, string runtime_balancing, string weighing,
- double eta, int C, int Hin, int Win,
- int seed = -1)
- return (list[unknown] model) {
-
+ matrix[double] X_val, matrix[double] y_val, int num_workers, int epochs,
+ string utype, string freq, int batch_size, string scheme, string
runtime_balancing,
+ string weighing, double eta, int C, int Hin, int Win, int seed = -1)
+ return (list[unknown] model)
+{
N = nrow(X)
K = ncol(y)
@@ -210,6 +206,7 @@ train_paramserv = function(matrix[double] X, matrix[double]
y,
model = paramserv(model=model, features=X, labels=y, val_features=X_val,
val_labels=y_val,
upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients",
agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
+ val="./src/test/scripts/functions/federated/paramserv/CNN.dml::validate",
k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing,
hyperparams=hyperparams, seed=seed)
}
@@ -267,7 +264,7 @@ predict = function(matrix[double] X, int C, int Hin, int
Win, int batch_size, li
# Compute predictions over mini-batches
probs = matrix(0, rows=N, cols=K)
iters = ceil(N / batch_size)
-  parfor(i in 1:iters, check=0) {
+  for(i in 1:iters) {
# Get next batch
beg = ((i-1) * batch_size) %% N + 1
end = min(N, beg + batch_size - 1)
@@ -320,6 +317,25 @@ eval = function(matrix[double] probs, matrix[double] y)
accuracy = mean(correct_pred)
}
+/*
+ * Gives the accuracy and loss for a model and given feature and label matrices
+ *
+ * This function is a combination of the predict and eval function used for
validation.
+ * For inputs see eval and predict.
+ *
+ * Outputs:
+ * - loss: Scalar loss, of shape (1).
+ * - accuracy: Scalar accuracy, of shape (1).
+ */
+validate = function(matrix[double] val_features, matrix[double] val_labels,
+ list[unknown] model, list[unknown] hyperparams)
+ return (double loss, double accuracy)
+{
+ [loss, accuracy] = eval(predict(val_features,
as.integer(as.scalar(hyperparams["C"])),
+ as.integer(as.scalar(hyperparams["Hin"])),
as.integer(as.scalar(hyperparams["Win"])),
+ 32, model), val_labels)
+}
+
# Should always use 'features' (batch features), 'labels' (batch labels),
# 'hyperparams', 'model' as the arguments
# and return the gradients of type list
@@ -371,7 +387,7 @@ gradients = function(list[unknown] model,
# Compute loss & accuracy for training data
loss = cross_entropy_loss::forward(probs, labels)
accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
- print("[+] Completed forward pass on batch: train loss: " + loss + ", train
accuracy: " + accuracy)
+ # print("[+] Completed forward pass on batch: train loss: " + loss + ",
train accuracy: " + accuracy)
# Compute data backward pass
## loss
@@ -452,4 +468,4 @@ aggregation = function(list[unknown] model,
[b4, vb4] = sgd_nesterov::update(b4, db4, learning_rate, mu, vb4)
model_result = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4,
vb1, vb2, vb3, vb4)
-}
\ No newline at end of file
+}
diff --git
a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
index 5176cca..c7ad305 100644
--- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -26,12 +26,16 @@
source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
features = read($features)
labels = read($labels)
-print($weighing)
-
if($network_type == "TwoNN") {
- model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0),
matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme,
$runtime_balancing, $weighing, $eta, $seed)
+ model = TwoNN::train_paramserv(features, labels, matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq,
$batch_size, $scheme, $runtime_balancing, $weighing, $eta, $seed)
+ print("Test results:")
+ [loss_test, accuracy_test] = TwoNN::validate(matrix(0, rows=100, cols=784),
matrix(0, rows=100, cols=10), model, list())
+ print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test +
"\n")
}
else {
- model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0),
matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme,
$runtime_balancing, $weighing, $eta, $channels, $hin, $win, $seed)
+ model = CNN::train_paramserv(features, labels, matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq,
$batch_size, $scheme, $runtime_balancing, $weighing, $eta, $channels, $hin,
$win, $seed)
+ print("Test results:")
+ hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
+ [loss_test, accuracy_test] = CNN::validate(matrix(0, rows=100, cols=784),
matrix(0, rows=100, cols=10), model, hyperparams)
+ print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test +
"\n")
}
-print(toString(model))
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
index a6dc6f2..e7fc6d9 100644
--- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -154,6 +154,7 @@ train_paramserv = function(matrix[double] X, matrix[double]
y,
model = paramserv(model=model, features=X, labels=y, val_features=X_val,
val_labels=y_val,
upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
+ val="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::validate",
k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing,
hyperparams=hyperparams, seed=seed)
}
@@ -214,6 +215,21 @@ eval = function(matrix[double] probs, matrix[double] y)
accuracy = mean(correct_pred)
}
+/*
+ * Gives the accuracy and loss for a model and given feature and label matrices
+ *
+ * This function is a combination of the predict and eval function used for
validation.
+ * For inputs see eval and predict.
+ *
+ * Outputs:
+ * - loss: Scalar loss, of shape (1).
+ * - accuracy: Scalar accuracy, of shape (1).
+ */
+validate = function(matrix[double] val_features, matrix[double] val_labels,
list[unknown] model, list[unknown] hyperparams)
+ return (double loss, double accuracy) {
+ [loss, accuracy] = eval(predict(val_features, model), val_labels)
+}
+
# Should always use 'features' (batch features), 'labels' (batch labels),
# 'hyperparams', 'model' as the arguments
# and return the gradients of type list
@@ -242,7 +258,7 @@ gradients = function(list[unknown] model,
# Compute loss & accuracy for training data
loss = cross_entropy_loss::forward(probs, labels)
accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
- print("[+] Completed forward pass on batch: train loss: " + loss + ", train
accuracy: " + accuracy)
+ # print("[+] Completed forward pass on batch: train loss: " + loss + ",
train accuracy: " + accuracy)
# Compute data backward pass
dprobs = cross_entropy_loss::backward(probs, labels)