This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 2931f6e [SYSTEMDS-2550] Improved parameter server epoch timing/logging
2931f6e is described below
commit 2931f6ec82798e4e71281ac8ca9cc47a55381266
Author: Tobias Rieger <[email protected]>
AuthorDate: Sat Feb 20 18:13:41 2021 +0100
[SYSTEMDS-2550] Improved parameter server epoch timing/logging
Closes #1176.
---
.../ParameterizedBuiltinFunctionExpression.java | 4 +-
.../java/org/apache/sysds/parser/Statement.java | 5 +--
.../paramserv/FederatedPSControlThread.java | 44 +++++++++----------
.../controlprogram/paramserv/ParamServer.java | 37 +++++++++++++---
.../paramserv/dp/BalanceToAvgFederatedScheme.java | 4 +-
.../paramserv/dp/DataPartitionFederatedScheme.java | 14 +++----
.../dp/KeepDataOnWorkerFederatedScheme.java | 4 +-
.../dp/ReplicateToMaxFederatedScheme.java | 4 +-
.../paramserv/dp/ShuffleFederatedScheme.java | 4 +-
.../dp/SubsampleToMinFederatedScheme.java | 4 +-
.../cp/ParamservBuiltinCPInstruction.java | 49 ++++++++++++++--------
.../java/org/apache/sysds/utils/Statistics.java | 22 ++++++++--
.../paramserv/FederatedParamservTest.java | 40 +++++++++---------
.../scripts/functions/federated/paramserv/CNN.dml | 4 +-
.../federated/paramserv/FederatedParamservTest.dml | 4 +-
.../functions/federated/paramserv/TwoNN.dml | 4 +-
16 files changed, 150 insertions(+), 97 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 583c643..4d111b0 100644
---
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -290,7 +290,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS,
Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
Statement.PS_VAL_FUN, Statement.PS_MODE,
Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM,
Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
- Statement.PS_FED_WEIGHING, Statement.PS_HYPER_PARAMS,
Statement.PS_CHECKPOINTING, Statement.PS_SEED);
+ Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS,
Statement.PS_CHECKPOINTING, Statement.PS_SEED);
checkInvalidParameters(getOpCode(), getVarParams(), valid);
// check existence and correctness of parameters
@@ -310,7 +310,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
checkDataValueType(true, fname, Statement.PS_PARALLELISM,
DataType.SCALAR, ValueType.INT64, conditional);
checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
checkStringParam(true, fname,
Statement.PS_FED_RUNTIME_BALANCING, conditional);
- checkStringParam(true, fname, Statement.PS_FED_WEIGHING,
conditional);
+ checkStringParam(true, fname, Statement.PS_FED_WEIGHTING,
conditional);
checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS,
DataType.LIST, ValueType.UNKNOWN, conditional);
checkStringParam(true, fname, Statement.PS_CHECKPOINTING,
conditional);
checkDataValueType(true, fname, Statement.PS_SEED,
DataType.SCALAR, ValueType.INT64, conditional);
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java
b/src/main/java/org/apache/sysds/parser/Statement.java
index eb32865..d15fd44 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -89,10 +89,10 @@ public abstract class Statement implements ParseInfo
public enum PSFrequency {
BATCH, EPOCH
}
- public static final String PS_FED_WEIGHING = "weighing";
+ public static final String PS_FED_WEIGHTING = "weighting";
public static final String PS_FED_RUNTIME_BALANCING =
"runtime_balancing";
public enum PSRuntimeBalancing {
- NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH
+ NONE, BASELINE, CYCLE_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH
}
public static final String PS_EPOCHS = "epochs";
public static final String PS_BATCH_SIZE = "batchsize";
@@ -101,7 +101,6 @@ public abstract class Statement implements ParseInfo
public enum PSScheme {
DISJOINT_CONTIGUOUS, DISJOINT_ROUND_ROBIN, DISJOINT_RANDOM,
OVERLAP_RESHUFFLE
}
- public static final String PS_FED_SCHEME = "fed_scheme";
public enum FederatedPSScheme {
KEEP_DATA_ON_WORKER, SHUFFLE, REPLICATE_TO_MAX,
SUBSAMPLE_TO_MIN, BALANCE_TO_AVG
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 13e029c..10fee56 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -78,19 +78,19 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
private final PSRuntimeBalancing _runtimeBalancing;
private int _numBatchesPerEpoch;
private int _possibleBatchesPerLocalEpoch;
- private final boolean _weighing;
- private double _weighingFactor = 1;
- private final boolean _cycleStartAt0 = false;
+ private final boolean _weighting;
+ private double _weightingFactor = 1;
+ private boolean _cycleStartAt0 = false;
public FederatedPSControlThread(int workerID, String updFunc,
Statement.PSFrequency freq,
- PSRuntimeBalancing runtimeBalancing, boolean weighing, int
epochs, long batchSize,
+ PSRuntimeBalancing runtimeBalancing, boolean weighting, int
epochs, long batchSize,
int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer
ps)
{
super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
_numBatchesPerEpoch = numBatchesPerGlobalEpoch;
_runtimeBalancing = runtimeBalancing;
- _weighing = weighing;
+ _weighting = weighting;
// generate the ID for the model
_modelVarID = FederationUtils.getNextFedDataID();
}
@@ -98,40 +98,42 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
/**
* Sets up the federated worker and control thread
*
- * @param weighingFactor Gradients from this worker will be multiplied
by this factor if weighing is enabled
+ * @param weightingFactor Gradients from this worker will be multiplied
by this factor if weighting is enabled
*/
- public void setup(double weighingFactor) {
+ public void setup(double weightingFactor) {
incWorkerNumber();
// prepare features and labels
_featuresData = (FederatedData)
_features.getFedMapping().getMap().values().toArray()[0];
_labelsData = (FederatedData)
_labels.getFedMapping().getMap().values().toArray()[0];
- // weighing factor is always set, but only used when weighing
is specified
- _weighingFactor = weighingFactor;
+ // weighting factor is always set, but only used when weighting
is specified
+ _weightingFactor = weightingFactor;
// different runtime balancing calculations
long dataSize = _features.getNumRows();
// calculate scaled batch size if balancing via batch size.
// In some cases there will be some cycling
- if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
+ if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH)
_batchSize = (int) Math.ceil((double) dataSize /
_numBatchesPerEpoch);
- }
// Calculate possible batches with batch size
_possibleBatchesPerLocalEpoch = (int) Math.ceil((double)
dataSize / _batchSize);
// If no runtime balancing is specified, just run possible
number of batches
// WARNING: Will get stuck on miss match
- if(_runtimeBalancing == PSRuntimeBalancing.NONE) {
+ if(_runtimeBalancing == PSRuntimeBalancing.NONE)
_numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;
- }
+
+ // If running in baseline mode, set _cycleStartAt0 to true (presumably so each cycle restarts at batch 0)
+ if(_runtimeBalancing == PSRuntimeBalancing.BASELINE)
+ _cycleStartAt0 = true;
if( LOG.isInfoEnabled() ) {
LOG.info("Setup config for worker " +
this.getWorkerName());
LOG.info("Batch size: " + _batchSize + " possible
batches: " + _possibleBatchesPerLocalEpoch
- + " batches to run: " +
_numBatchesPerEpoch + " weighing factor: " + _weighingFactor);
+ + " batches to run: " +
_numBatchesPerEpoch + " weighting factor: " + _weightingFactor);
}
// serialize program
@@ -321,16 +323,16 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
protected void weighAndPushGradients(ListObject gradients) {
// scale gradients - must only include MatrixObjects
- if(_weighing && _weighingFactor != 1) {
- Timing tWeighing = DMLScript.STATISTICS ? new
Timing(true) : null;
+ if(_weighting && _weightingFactor != 1) {
+ Timing tWeighting = DMLScript.STATISTICS ? new
Timing(true) : null;
gradients.getData().parallelStream().forEach((matrix)
-> {
MatrixObject matrixObject = (MatrixObject)
matrix;
MatrixBlock input =
matrixObject.acquireReadAndRelease().scalarOperations(
- new
RightScalarOperator(Multiply.getMultiplyFnObject(), _weighingFactor), new
MatrixBlock());
+ new
RightScalarOperator(Multiply.getMultiplyFnObject(), _weightingFactor), new
MatrixBlock());
matrixObject.acquireModify(input);
matrixObject.release();
});
- accFedPSGradientWeighingTime(tWeighing);
+ accFedPSGradientWeightingTime(tWeighting);
}
// Push the gradients to ps
@@ -342,7 +344,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
}
/**
- * Computes all epochs and updates after each batch
+ * Computes all epochs and updates after each batch
*/
protected void computeWithBatchUpdates() {
for (int epochCounter = 0; epochCounter < _epochs;
epochCounter++) {
@@ -557,9 +559,9 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
}
// Statistics methods
- protected void accFedPSGradientWeighingTime(Timing time) {
+ protected void accFedPSGradientWeightingTime(Timing time) {
if (DMLScript.STATISTICS && time != null)
- Statistics.accFedPSGradientWeighingTime((long)
time.stop());
+ Statistics.accFedPSGradientWeightingTime((long)
time.stop());
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 6315ef9..4fe072c 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -98,7 +98,7 @@ public abstract class ParamServer
_finishedStates = new boolean[workerNum];
setupAggFunc(_ec, aggFunc);
- if(valFunc != null && numBatchesPerEpoch > 0) {
+ if(valFunc != null && numBatchesPerEpoch > 0 && valFeatures !=
null && valLabels != null) {
setupValFunc(_ec, valFunc, valFeatures, valLabels);
}
_numBatchesPerEpoch = numBatchesPerEpoch;
@@ -204,12 +204,15 @@ public abstract class ParamServer
// This if has grown to be
quite complex its function is rather simple. Validate at the end of each epoch
// In the BSP batch case that
occurs after the sync counter reaches the number of batches and in the
// BSP epoch case every time
- if ((_freq ==
Statement.PSFrequency.EPOCH ||
+ if (_numBatchesPerEpoch != -1 &&
+ (_freq ==
Statement.PSFrequency.EPOCH ||
(_freq ==
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
if(LOG.isInfoEnabled())
LOG.info("[+]
PARAMSERV: completed EPOCH " + _epochCounter);
+ time_epoch();
+
if(_validationPossible)
validate();
@@ -229,12 +232,15 @@ public abstract class ParamServer
updateGlobalModel(gradients);
// This if works similarly to the one
for BSP, but divides the sync couter through the number of workers,
// creating "Pseudo Epochs"
- if ((_freq ==
Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
- (_freq ==
Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float)
_numBatchesPerEpoch == 0)) {
+ if (_numBatchesPerEpoch != -1 &&
+ ((_freq ==
Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
+ (_freq ==
Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float)
_numBatchesPerEpoch == 0))) {
if(LOG.isInfoEnabled())
LOG.info("[+]
PARAMSERV: completed PSEUDO EPOCH (ASP) " + _epochCounter);
+ time_epoch();
+
if(_validationPossible)
validate();
@@ -321,9 +327,28 @@ public abstract class ParamServer
}
/**
+ * Logs the cumulative paramserv execution time (minus validation time) at the end of an epoch; does nothing unless statistics are enabled
+ */
+ private void time_epoch() {
+ if (DMLScript.STATISTICS) {
+ //TODO double check correctness with multiple,
potentially concurrent paramserv invocation
+ Statistics.accPSExecutionTime((long)
Statistics.getPSExecutionTimer().stop());
+ double current_total_execution_time =
Statistics.getPSExecutionTime();
+ double current_total_validation_time =
Statistics.getPSValidationTime();
+ double time_to_epoch = current_total_execution_time -
current_total_validation_time;
+
+ if (LOG.isInfoEnabled())
+ if(_validationPossible)
+ LOG.info("[+] PARAMSERV: epoch timer
(excl. validation): " + time_to_epoch / 1000 + " secs.");
+ else
+ LOG.info("[+] PARAMSERV: epoch timer: "
+ time_to_epoch / 1000 + " secs.");
+ }
+ }
+
+ /**
* Checks the current model against the validation set
*/
- private synchronized void validate() {
+ private void validate() {
Timing tValidate = DMLScript.STATISTICS ? new Timing(true) :
null;
_ec.setVariable(Statement.PS_MODEL, _model);
@@ -338,7 +363,7 @@ public abstract class ParamServer
ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
// Log validation results
- if(LOG.isInfoEnabled())
+ if (LOG.isInfoEnabled())
LOG.info("[+] PARAMSERV: validation-loss: " + loss + "
validation-accuracy: " + accuracy);
if(tValidate != null)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
index e3daf60..9c90767 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
@@ -52,7 +52,7 @@ public class BalanceToAvgFederatedScheme extends
DataPartitionFederatedScheme {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
BalanceMetrics balanceMetricsBefore =
getBalanceMetrics(pFeatures);
- List<Double> weighingFactors = getWeighingFactors(pFeatures,
balanceMetricsBefore);
+ List<Double> weightingFactors = getWeightingFactors(pFeatures,
balanceMetricsBefore);
int average_num_rows = (int) balanceMetricsBefore._avgRows;
@@ -79,7 +79,7 @@ public class BalanceToAvgFederatedScheme extends
DataPartitionFederatedScheme {
pLabels.get(i).updateDataCharacteristics(update);
}
- return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures), weighingFactors);
+ return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures), weightingFactors);
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
index e00923e..c6429b4 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -45,15 +45,15 @@ public abstract class DataPartitionFederatedScheme {
public final List<MatrixObject> _pLabels;
public final int _workerNum;
public final BalanceMetrics _balanceMetrics;
- public final List<Double> _weighingFactors;
+ public final List<Double> _weightingFactors;
- public Result(List<MatrixObject> pFeatures, List<MatrixObject>
pLabels, int workerNum, BalanceMetrics balanceMetrics, List<Double>
weighingFactors) {
+ public Result(List<MatrixObject> pFeatures, List<MatrixObject>
pLabels, int workerNum, BalanceMetrics balanceMetrics, List<Double>
weightingFactors) {
_pFeatures = pFeatures;
_pLabels = pLabels;
_workerNum = workerNum;
_balanceMetrics = balanceMetrics;
- _weighingFactors = weighingFactors;
+ _weightingFactors = weightingFactors;
}
}
@@ -125,12 +125,12 @@ public abstract class DataPartitionFederatedScheme {
return new BalanceMetrics(minRows, sum / slices.size(),
maxRows);
}
- static List<Double> getWeighingFactors(List<MatrixObject> pFeatures,
BalanceMetrics balanceMetrics) {
- List<Double> weighingFactors = new ArrayList<>();
+ static List<Double> getWeightingFactors(List<MatrixObject> pFeatures,
BalanceMetrics balanceMetrics) {
+ List<Double> weightingFactors = new ArrayList<>();
pFeatures.forEach((feature) -> {
- weighingFactors.add((double) feature.getNumRows() /
balanceMetrics._avgRows);
+ weightingFactors.add((double) feature.getNumRows() /
balanceMetrics._avgRows);
});
- return weighingFactors;
+ return weightingFactors;
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
index afbaf4d..ae8d874 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
@@ -35,7 +35,7 @@ public class KeepDataOnWorkerFederatedScheme extends
DataPartitionFederatedSchem
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
- List<Double> weighingFactors = getWeighingFactors(pFeatures,
balanceMetrics);
- return new Result(pFeatures, pLabels, pFeatures.size(),
balanceMetrics, weighingFactors);
+ List<Double> weightingFactors = getWeightingFactors(pFeatures,
balanceMetrics);
+ return new Result(pFeatures, pLabels, pFeatures.size(),
balanceMetrics, weightingFactors);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
index e9c1b50..77b2287 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
@@ -52,7 +52,7 @@ public class ReplicateToMaxFederatedScheme extends
DataPartitionFederatedScheme
public Result partition(MatrixObject features, MatrixObject labels, int
seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
- List<Double> weighingFactors = getWeighingFactors(pFeatures,
getBalanceMetrics(pFeatures));
+ List<Double> weightingFactors = getWeightingFactors(pFeatures,
getBalanceMetrics(pFeatures));
int max_rows = 0;
for (MatrixObject pFeature : pFeatures) {
@@ -82,7 +82,7 @@ public class ReplicateToMaxFederatedScheme extends
DataPartitionFederatedScheme
pLabels.get(i).updateDataCharacteristics(update);
}
- return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures), weighingFactors);
+ return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures), weightingFactors);
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index 365554d..af95270 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -51,7 +51,7 @@ public class ShuffleFederatedScheme extends
DataPartitionFederatedScheme {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
- List<Double> weighingFactors = getWeighingFactors(pFeatures,
balanceMetrics);
+ List<Double> weightingFactors = getWeightingFactors(pFeatures,
balanceMetrics);
for(int i = 0; i < pFeatures.size(); i++) {
// Works, because the map contains a single entry
@@ -71,7 +71,7 @@ public class ShuffleFederatedScheme extends
DataPartitionFederatedScheme {
}
}
- return new Result(pFeatures, pLabels, pFeatures.size(),
balanceMetrics, weighingFactors);
+ return new Result(pFeatures, pLabels, pFeatures.size(),
balanceMetrics, weightingFactors);
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
index e55b92e..369b3dd 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
@@ -52,7 +52,7 @@ public class SubsampleToMinFederatedScheme extends
DataPartitionFederatedScheme
public Result partition(MatrixObject features, MatrixObject labels, int
seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
- List<Double> weighingFactors = getWeighingFactors(pFeatures,
getBalanceMetrics(pFeatures));
+ List<Double> weightingFactors = getWeightingFactors(pFeatures,
getBalanceMetrics(pFeatures));
int min_rows = Integer.MAX_VALUE;
for (MatrixObject pFeature : pFeatures) {
@@ -82,7 +82,7 @@ public class SubsampleToMinFederatedScheme extends
DataPartitionFederatedScheme
pLabels.get(i).updateDataCharacteristics(update);
}
- return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures), weighingFactors);
+ return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures), weightingFactors);
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index a99e8ee..e64fdf8 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -41,11 +41,10 @@ import static org.apache.sysds.parser.Statement.PS_MODE;
import static org.apache.sysds.parser.Statement.PS_MODEL;
import static org.apache.sysds.parser.Statement.PS_PARALLELISM;
import static org.apache.sysds.parser.Statement.PS_SCHEME;
-import static org.apache.sysds.parser.Statement.PS_FED_SCHEME;
import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING;
-import static org.apache.sysds.parser.Statement.PS_FED_WEIGHING;
+import static org.apache.sysds.parser.Statement.PS_FED_WEIGHTING;
import static org.apache.sysds.parser.Statement.PS_SEED;
import static org.apache.sysds.parser.Statement.PS_VAL_FEATURES;
import static org.apache.sysds.parser.Statement.PS_VAL_LABELS;
@@ -127,7 +126,9 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
}
private void runFederated(ExecutionContext ec) {
- Timing tExecutionTime = DMLScript.STATISTICS ? new Timing(true)
: null;
+ if(DMLScript.STATISTICS)
+ Statistics.getPSExecutionTimer().start();
+
Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
LOG.info("PARAMETER SERVER");
LOG.info("[+] Running in federated mode");
@@ -135,12 +136,11 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
// get inputs
String updFunc = getParam(PS_UPDATE_FUN);
String aggFunc = getParam(PS_AGGREGATION_FUN);
- String valFunc = getValFunction();
PSUpdateType updateType = getUpdateType();
PSFrequency freq = getFrequency();
FederatedPSScheme federatedPSScheme = getFederatedScheme();
PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
- boolean weighing = getWeighing();
+ boolean weighting = getWeighting();
int seed = getSeed();
if( LOG.isInfoEnabled() ) {
@@ -148,7 +148,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
LOG.info("[+] Frequency: " + freq);
LOG.info("[+] Data Partitioning: " + federatedPSScheme);
LOG.info("[+] Runtime Balancing: " + runtimeBalancing);
- LOG.info("[+] Weighing: " + weighing);
+ LOG.info("[+] Weighting: " + weighting);
LOG.info("[+] Seed: " + seed);
}
if (tSetup != null)
@@ -179,12 +179,14 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
ExecutionContext aggServiceEC =
ParamservUtils.copyExecutionContext(newEC, 1).get(0);
// Create the parameter server
ListObject model = ec.getListObject(getParam(PS_MODEL));
- ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc,
updateType, freq, workerNum, model, aggServiceEC, valFunc,
- getNumBatchesPerEpoch(runtimeBalancing,
result._balanceMetrics), ec.getMatrixObject(getParam(PS_VAL_FEATURES)),
ec.getMatrixObject(getParam(PS_VAL_LABELS)));
+ MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null)
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
+ MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ?
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
+ ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc,
updateType, freq, workerNum, model, aggServiceEC, getValFunction(),
+ getNumBatchesPerEpoch(runtimeBalancing,
result._balanceMetrics), val_features, val_labels);
// Create the local workers
int finalNumBatchesPerEpoch =
getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
List<FederatedPSControlThread> threads = IntStream.range(0,
workerNum)
- .mapToObj(i -> new FederatedPSControlThread(i, updFunc,
freq, runtimeBalancing, weighing,
+ .mapToObj(i -> new FederatedPSControlThread(i, updFunc,
freq, runtimeBalancing, weighting,
getEpochs(), getBatchSize(),
finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
.collect(Collectors.toList());
if(workerNum != threads.size()) {
@@ -194,7 +196,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
for (int i = 0; i < threads.size(); i++) {
threads.get(i).setFeatures(result._pFeatures.get(i));
threads.get(i).setLabels(result._pLabels.get(i));
- threads.get(i).setup(result._weighingFactors.get(i));
+ threads.get(i).setup(result._weightingFactors.get(i));
}
if (DMLScript.STATISTICS)
Statistics.accPSSetupTime((long) tSetup.stop());
@@ -206,7 +208,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
// Fetch the final model from ps
ec.setVariable(output.getName(), ps.getResult());
if (DMLScript.STATISTICS)
- Statistics.accPSExecutionTime((long)
tExecutionTime.stop());
+ Statistics.accPSExecutionTime((long)
Statistics.getPSExecutionTimer().stop());
} catch (InterruptedException | ExecutionException e) {
throw new
DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e);
} finally {
@@ -293,6 +295,9 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
}
private void runLocally(ExecutionContext ec, PSModeType mode) {
+ if(DMLScript.STATISTICS)
+ Statistics.getPSExecutionTimer().start();
+
Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
int workerNum = getWorkerNum(mode);
BasicThreadFactory factory = new BasicThreadFactory.Builder()
@@ -314,9 +319,15 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
PSFrequency freq = getFrequency();
PSUpdateType updateType = getUpdateType();
+ double rows_per_worker = Math.ceil((float)
ec.getMatrixObject(getParam(PS_FEATURES)).getNumRows() / workerNum);
+ int num_batches_per_epoch = (int) Math.ceil(rows_per_worker /
getBatchSize());
+
// Create the parameter server
ListObject model = ec.getListObject(getParam(PS_MODEL));
- ParamServer ps = createPS(mode, aggFunc, updateType, freq,
workerNum, model, aggServiceEC);
+ MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null)
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
+ MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ?
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
+ ParamServer ps = createPS(mode, aggFunc, updateType, freq,
workerNum, model, aggServiceEC, getValFunction(),
+ num_batches_per_epoch, val_features,
val_labels);
// Create the local workers
List<LocalPSWorker> workers = IntStream.range(0, workerNum)
@@ -344,6 +355,8 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
ret.get(); //error handling
// Fetch the final model from ps
ec.setVariable(output.getName(), ps.getResult());
+ if (DMLScript.STATISTICS)
+ Statistics.accPSExecutionTime((long)
Statistics.getPSExecutionTimer().stop());
} catch (InterruptedException | ExecutionException e) {
throw new
DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
} finally {
@@ -529,11 +542,11 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
private FederatedPSScheme getFederatedScheme() {
FederatedPSScheme federated_scheme = DEFAULT_FEDERATED_SCHEME;
- if (getParameterMap().containsKey(PS_FED_SCHEME)) {
+ if (getParameterMap().containsKey(PS_SCHEME)) {
try {
- federated_scheme =
FederatedPSScheme.valueOf(getParam(PS_FED_SCHEME));
+ federated_scheme =
FederatedPSScheme.valueOf(getParam(PS_SCHEME));
} catch (IllegalArgumentException e) {
- throw new
DMLRuntimeException(String.format("Paramserv function in federated mode: not
support data partition scheme '%s'", getParam(PS_FED_SCHEME)));
+ throw new
DMLRuntimeException(String.format("Paramserv function in federated mode: not
support data partition scheme '%s'", getParam(PS_SCHEME)));
}
}
return federated_scheme;
@@ -548,7 +561,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
*/
private int getNumBatchesPerEpoch(PSRuntimeBalancing runtimeBalancing,
DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) {
int numBatchesPerEpoch;
- if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
+ if(runtimeBalancing == PSRuntimeBalancing.CYCLE_MIN ||
runtimeBalancing == PSRuntimeBalancing.BASELINE) {
numBatchesPerEpoch = (int)
Math.ceil(balanceMetrics._minRows / (float) getBatchSize());
} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG
|| runtimeBalancing ==
PSRuntimeBalancing.SCALE_BATCH) {
@@ -561,8 +574,8 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
return numBatchesPerEpoch;
}
- private boolean getWeighing() {
- return getParameterMap().containsKey(PS_FED_WEIGHING) &&
Boolean.parseBoolean(getParam(PS_FED_WEIGHING));
+ private boolean getWeighting() {
+ return getParameterMap().containsKey(PS_FED_WEIGHTING) &&
Boolean.parseBoolean(getParam(PS_FED_WEIGHTING));
}
private String getValFunction() {
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 320f610..8fcdf02 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -38,6 +38,7 @@ import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
@@ -117,6 +118,7 @@ public class Statistics
private static final LongAdder sparkBroadcastCount = new LongAdder();
// Paramserv function stats (time is in milli sec)
+ private static final Timing psExecutionTimer = new Timing(false);
private static final LongAdder psExecutionTime = new LongAdder();
private static final LongAdder psNumWorkers = new LongAdder();
private static final LongAdder psSetupTime = new LongAdder();
@@ -130,7 +132,7 @@ public class Statistics
// Federated parameter server specifics (time is in milli sec)
private static final LongAdder fedPSDataPartitioningTime = new
LongAdder();
private static final LongAdder fedPSWorkerComputingTime = new
LongAdder();
- private static final LongAdder fedPSGradientWeighingTime = new
LongAdder();
+ private static final LongAdder fedPSGradientWeightingTime = new
LongAdder();
private static final LongAdder fedPSCommunicationTime = new LongAdder();
//PARFOR optimization stats (low frequency updates)
@@ -571,6 +573,14 @@ public class Statistics
psNumWorkers.add(n);
}
+ public static Timing getPSExecutionTimer() {
+ return psExecutionTimer;
+ }
+
+ public static double getPSExecutionTime() {
+ return psExecutionTime.doubleValue();
+ }
+
public static void accPSExecutionTime(long n) {
psExecutionTime.add(n);
}
@@ -603,6 +613,10 @@ public class Statistics
psRpcRequestTime.add(t);
}
+ public static double getPSValidationTime() {
+ return psValidationTime.doubleValue();
+ }
+
public static void accPSValidationTime(long t) {
psValidationTime.add(t);
}
@@ -615,8 +629,8 @@ public class Statistics
fedPSWorkerComputingTime.add(t);
}
- public static void accFedPSGradientWeighingTime(long t) {
- fedPSGradientWeighingTime.add(t);
+ public static void accFedPSGradientWeightingTime(long t) {
+ fedPSGradientWeightingTime.add(t);
}
public static void accFedPSCommunicationTime(long t) {
fedPSCommunicationTime.add(t);}
@@ -1049,7 +1063,7 @@ public class Statistics
sb.append(String.format("PS fed data
partitioning time:\t%.3f secs.\n", fedPSDataPartitioningTime.doubleValue() /
1000));
sb.append(String.format("PS fed comm
time (cum):\t\t%.3f secs.\n", fedPSCommunicationTime.doubleValue() / 1000));
sb.append(String.format("PS fed worker
comp time (cum):\t%.3f secs.\n", fedPSWorkerComputingTime.doubleValue() /
1000));
- sb.append(String.format("PS fed grad
weigh time (cum):\t%.3f secs.\n", fedPSGradientWeighingTime.doubleValue() /
1000));
+ sb.append(String.format("PS fed grad.
weigh. time (cum):\t%.3f secs.\n", fedPSGradientWeightingTime.doubleValue() /
1000));
sb.append(String.format("PS fed global
model agg time:\t%.3f secs.\n", psAggregationTime.doubleValue() / 1000));
}
else {
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index 5d7c7e2..9221a53 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -54,7 +54,7 @@ public class FederatedParamservTest extends AutomatedTestBase
{
private final String _freq;
private final String _scheme;
private final String _runtime_balancing;
- private final String _weighing;
+ private final String _weighting;
private final String _data_distribution;
private final int _seed;
@@ -66,35 +66,35 @@ public class FederatedParamservTest extends
AutomatedTestBase {
// basic functionality
//{"TwoNN", 4, 60000, 32, 4, 0.01, "BSP", "BATCH",
"KEEP_DATA_ON_WORKER", "NONE" , "false","BALANCED",
200},
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED",
200},
- {"CNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "SHUFFLE", "NONE" ,
"true", "IMBALANCED", 200},
- {"CNN", 2, 4, 1, 4, 0.01, "ASP",
"BATCH", "REPLICATE_TO_MAX", "RUN_MIN" , "true", "IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 4, 0.01, "ASP",
"EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX" , "true", "IMBALANCED",
200},
- {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH",
"KEEP_DATA_ON_WORKER", "NONE" , "true", "BALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "BASELINE", "true", "IMBALANCED",
200},
+ {"CNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "SHUFFLE", "NONE",
"true", "IMBALANCED", 200},
+ {"CNN", 2, 4, 1, 4, 0.01, "ASP",
"BATCH", "REPLICATE_TO_MAX", "CYCLE_MIN", "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "ASP",
"EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX", "true", "IMBALANCED",
200},
+ {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH",
"KEEP_DATA_ON_WORKER", "NONE", "true", "BALANCED",
200},
/*
// runtime balancing
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED",
200},
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED",
200},
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED",
200},
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED",
200},
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED",
200},
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MIN", "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MIN", "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX", "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX", "true", "IMBALANCED",
200},
// data partitioning
- {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "SHUFFLE", "CYCLE_AVG" , "true",
"IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "REPLICATE_TO_MAX", "NONE" , "true",
"IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "true",
"IMBALANCED", 200},
- {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "BALANCE_TO_AVG", "NONE" , "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "SHUFFLE", "CYCLE_AVG", "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "REPLICATE_TO_MAX", "NONE", "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE", "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "BALANCE_TO_AVG", "NONE", "true",
"IMBALANCED", 200},
// balanced tests
- {"CNN", 5, 1000, 100, 2, 0.01, "BSP",
"EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "true", "BALANCED",
200}
+ {"CNN", 5, 1000, 100, 2, 0.01, "BSP",
"EPOCH", "KEEP_DATA_ON_WORKER", "NONE", "true", "BALANCED",
200}
*/
});
}
public FederatedParamservTest(String networkType, int
numFederatedWorkers, int dataSetSize, int batch_size,
- int epochs, double eta, String utype, String freq, String
scheme, String runtime_balancing, String weighing, String data_distribution,
int seed) {
+ int epochs, double eta, String utype, String freq, String
scheme, String runtime_balancing, String weighting, String data_distribution,
int seed) {
_networkType = networkType;
_numFederatedWorkers = numFederatedWorkers;
@@ -106,7 +106,7 @@ public class FederatedParamservTest extends
AutomatedTestBase {
_freq = freq;
_scheme = scheme;
_runtime_balancing = runtime_balancing;
- _weighing = weighing;
+ _weighting = weighting;
_data_distribution = data_distribution;
_seed = seed;
}
@@ -192,7 +192,7 @@ public class FederatedParamservTest extends
AutomatedTestBase {
"freq=" + _freq,
"scheme=" + _scheme,
"runtime_balancing=" +
_runtime_balancing,
- "weighing=" + _weighing,
+ "weighting=" + _weighting,
"network_type=" + _networkType,
"channels=" + C,
"hin=" + Hin,
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml
b/src/test/scripts/functions/federated/paramserv/CNN.dml
index 79628ef..6663ca6 100644
--- a/src/test/scripts/functions/federated/paramserv/CNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -161,7 +161,7 @@ train = function(matrix[double] X, matrix[double] y,
matrix[double] X_val,
train_paramserv = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val, int num_workers, int epochs,
string utype, string freq, int batch_size, string scheme, string
runtime_balancing,
- string weighing, double eta, int C, int Hin, int Win, int seed = -1)
+ string weighting, double eta, int C, int Hin, int Win, int seed = -1)
return (list[unknown] model)
{
N = nrow(X)
@@ -208,7 +208,7 @@ train_paramserv = function(matrix[double] X, matrix[double]
y,
agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
val="./src/test/scripts/functions/federated/paramserv/CNN.dml::validate",
k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
- scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing,
hyperparams=hyperparams, seed=seed)
+ scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting,
hyperparams=hyperparams, seed=seed)
}
/*
diff --git
a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
index c7ad305..7efd588 100644
--- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -27,13 +27,13 @@ features = read($features)
labels = read($labels)
if($network_type == "TwoNN") {
- model = TwoNN::train_paramserv(features, labels, matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq,
$batch_size, $scheme, $runtime_balancing, $weighing, $eta, $seed)
+ model = TwoNN::train_paramserv(features, labels, matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq,
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed)
print("Test results:")
[loss_test, accuracy_test] = TwoNN::validate(matrix(0, rows=100, cols=784),
matrix(0, rows=100, cols=10), model, list())
print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test +
"\n")
}
else {
- model = CNN::train_paramserv(features, labels, matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq,
$batch_size, $scheme, $runtime_balancing, $weighing, $eta, $channels, $hin,
$win, $seed)
+ model = CNN::train_paramserv(features, labels, matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq,
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin,
$win, $seed)
print("Test results:")
hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
[loss_test, accuracy_test] = CNN::validate(matrix(0, rows=100, cols=784),
matrix(0, rows=100, cols=10), model, hyperparams)
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
index e7fc6d9..f42cdca 100644
--- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -125,7 +125,7 @@ train = function(matrix[double] X, matrix[double] y,
*/
train_paramserv = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val,
- int num_workers, int epochs, string utype, string freq, int
batch_size, string scheme, string runtime_balancing, string weighing,
+ int num_workers, int epochs, string utype, string freq, int
batch_size, string scheme, string runtime_balancing, string weighting,
double eta, int seed = -1)
return (list[unknown] model) {
@@ -156,7 +156,7 @@ train_paramserv = function(matrix[double] X, matrix[double]
y,
agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
val="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::validate",
k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
- scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing,
hyperparams=hyperparams, seed=seed)
+ scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting,
hyperparams=hyperparams, seed=seed)
}
/*