This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
     new 5dec562  [SYSTEMDS-2550] Federated paramserv balancing and data partitioning
5dec562 is described below
commit 5dec5627398a7eb58facb7be2be0973a2428b345
Author: Tobias Rieger <[email protected]>
AuthorDate: Sun Dec 20 19:44:18 2020 +0100
[SYSTEMDS-2550] Federated paramserv balancing and data partitioning
Closes #1131.
---
.../ParameterizedBuiltinFunctionExpression.java | 7 +-
.../java/org/apache/sysds/parser/Statement.java | 8 +-
.../sysds/runtime/controlprogram/ProgramBlock.java | 2 +-
.../controlprogram/federated/FederationMap.java | 2 +-
.../paramserv/FederatedPSControlThread.java | 342 +++++++++------------
.../controlprogram/paramserv/ParamservUtils.java | 33 +-
.../paramserv/dp/BalanceToAvgFederatedScheme.java | 106 +++++++
.../paramserv/dp/DataPartitionFederatedScheme.java | 98 +++++-
.../paramserv/dp/FederatedDataPartitioner.java | 9 +
.../dp/KeepDataOnWorkerFederatedScheme.java | 2 +-
.../dp/ReplicateToMaxFederatedScheme.java | 103 +++++++
.../paramserv/dp/ShuffleFederatedScheme.java | 56 +++-
.../dp/SubsampleToMinFederatedScheme.java | 102 ++++++
.../cp/ParamservBuiltinCPInstruction.java | 56 +++-
.../fed/MatrixIndexingFEDInstruction.java | 2 +-
.../instructions/fed/VariableFEDInstruction.java | 4 +-
.../sysds/runtime/io/ReaderWriterFederated.java | 2 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 33 +-
.../org/apache/sysds/test/AutomatedTestBase.java | 74 +++++
.../paramserv/FederatedParamservTest.java | 145 +++++----
.../scripts/functions/federated/paramserv/CNN.dml | 91 +++---
.../federated/paramserv/FederatedParamservTest.dml | 32 +-
.../functions/federated/paramserv/TwoNN.dml | 74 ++---
23 files changed, 965 insertions(+), 418 deletions(-)
diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 2e33676..5171f21 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -286,7 +286,11 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 			raiseValidateError("Should provide more arguments for function " + fname, false, LanguageErrorCodes.INVALID_PARAMETERS);
 		}
 		//check for invalid parameters
-		Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS, Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
+		Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS,
+			Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
+			Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
+			Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_RUNTIME_BALANCING,
+			Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
 		checkInvalidParameters(getOpCode(), getVarParams(), valid);
 		// check existence and correctness of parameters
@@ -304,6 +308,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 		checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT64, conditional);
 		checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT64, conditional);
 		checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
+		checkStringParam(true, fname, Statement.PS_RUNTIME_BALANCING, conditional);
 		checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional);
 		checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional);
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java
index b61b0d6..6767d85 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -87,6 +87,10 @@ public abstract class Statement implements ParseInfo
public enum PSFrequency {
BATCH, EPOCH
}
+ public static final String PS_RUNTIME_BALANCING = "runtime_balancing";
+ public enum PSRuntimeBalancing {
+		NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH, SCALE_BATCH_AND_WEIGH
+ }
public static final String PS_EPOCHS = "epochs";
public static final String PS_BATCH_SIZE = "batchsize";
public static final String PS_PARALLELISM = "k";
@@ -95,7 +99,7 @@ public abstract class Statement implements ParseInfo
 		DISJOINT_CONTIGUOUS, DISJOINT_ROUND_ROBIN, DISJOINT_RANDOM, OVERLAP_RESHUFFLE
 	}
 	public enum FederatedPSScheme {
-		KEEP_DATA_ON_WORKER, SHUFFLE
+		KEEP_DATA_ON_WORKER, SHUFFLE, REPLICATE_TO_MAX, SUBSAMPLE_TO_MIN, BALANCE_TO_AVG
}
public static final String PS_HYPER_PARAMS = "hyperparams";
public static final String PS_CHECKPOINTING = "checkpointing";
@@ -107,7 +111,7 @@ public abstract class Statement implements ParseInfo
// prefixed with code: "1701-NCC-" to not overwrite anything
public static final String PS_FED_BATCH_SIZE = "1701-NCC-batch_size";
public static final String PS_FED_DATA_SIZE = "1701-NCC-data_size";
- public static final String PS_FED_NUM_BATCHES = "1701-NCC-num_batches";
+	public static final String PS_FED_POSS_BATCHES_LOCAL = "1701-NCC-poss_batches_local";
 	public static final String PS_FED_NAMESPACE = "1701-NCC-namespace";
 	public static final String PS_FED_GRADIENTS_FNAME = "1701-NCC-gradients_fname";
 	public static final String PS_FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname";
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
index 263ecf4..a555a5d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
@@ -376,7 +376,7 @@ public abstract class ProgramBlock implements ParseInfo
 			CacheableData<?> mo = (CacheableData<?>)dat;
 			if( mo.isFederated() ) {
-				if( mo.getFedMapping().getFedMapping().isEmpty() )
+				if( mo.getFedMapping().getMap().isEmpty() )
 					throw new DMLRuntimeException("Invalid empty FederationMap for: "+mo);
 			}
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 9590510..482ade7 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -114,7 +114,7 @@ public class FederationMap {
return _fedMap.keySet().toArray(new FederatedRange[0]);
}
- public Map<FederatedRange, FederatedData> getFedMapping() {
+ public Map<FederatedRange, FederatedData> getMap() {
return _fedMap;
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 80418ac..393b131 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.controlprogram.paramserv;
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -29,6 +32,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
@@ -53,17 +57,24 @@ import static org.apache.sysds.runtime.util.ProgramConverter.*;
 public class FederatedPSControlThread extends PSWorker implements Callable<Void> {
 	private static final long serialVersionUID = 6846648059569648791L;
+	protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());
+
+	Statement.PSRuntimeBalancing _runtimeBalancing;
 	FederatedData _featuresData;
 	FederatedData _labelsData;
-	final long _batchCounterVarID;
+	final long _localStartBatchNumVarID;
 	final long _modelVarID;
-	int _totalNumBatches;
+	int _numBatchesPerGlobalEpoch;
+	int _possibleBatchesPerLocalEpoch;
+	boolean _cycleStartAt0 = false;
 
-	public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) {
+	public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) {
 		super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
-
+
+		_numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch;
+		_runtimeBalancing = runtimeBalancing;
 		// generate the IDs for model and batch counter. These get overwritten on the federated worker each time
-		_batchCounterVarID = FederationUtils.getNextFedDataID();
+		_localStartBatchNumVarID = FederationUtils.getNextFedDataID();
 		_modelVarID = FederationUtils.getNextFedDataID();
 	}
@@ -72,18 +83,22 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	 */
 	public void setup() {
 		// prepare features and labels
-		_features.getFedMapping().forEachParallel((range, data) -> {
-			_featuresData = data;
-			return null;
-		});
-		_labels.getFedMapping().forEachParallel((range, data) -> {
-			_labelsData = data;
-			return null;
-		});
+		_featuresData = (FederatedData) _features.getFedMapping().getMap().values().toArray()[0];
+		_labelsData = (FederatedData) _labels.getFedMapping().getMap().values().toArray()[0];
 
 		// calculate number of batches and get data size
 		long dataSize = _features.getNumRows();
-		_totalNumBatches = (int) Math.ceil((double) dataSize / _batchSize);
+		_possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);
+		if(!(_runtimeBalancing == Statement.PSRuntimeBalancing.RUN_MIN
+			|| _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_AVG
+			|| _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_MAX)) {
+			_numBatchesPerGlobalEpoch = _possibleBatchesPerLocalEpoch;
+		}
+
+		if(_runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH
+			|| _runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH_AND_WEIGH) {
+			throw new NotImplementedException();
+		}
 
 		// serialize program
 		// create program blocks for the instruction filtering
@@ -112,17 +127,17 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 		programSerialized = sb.toString();
 
 		// write program and meta data to worker
-		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-			_featuresData.getVarID(),
-			new setupFederatedWorker(_batchSize,
+		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+			new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
+				new SetupFederatedWorker(_batchSize,
 					dataSize,
-					_totalNumBatches,
+					_possibleBatchesPerLocalEpoch,
 					programSerialized,
 					_inst.getNamespace(),
 					_inst.getFunctionName(),
 					_ps.getAggInst().getFunctionName(),
 					_ec.getListObject("hyperparams"),
-					_batchCounterVarID,
+					_localStartBatchNumVarID,
 					_modelVarID
 				)
 		));
@@ -140,24 +155,27 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	/**
 	 * Setup UDF executed on the federated worker
 	 */
-	private static class setupFederatedWorker extends FederatedUDF {
+	private static class SetupFederatedWorker extends FederatedUDF {
 		private static final long serialVersionUID = -3148991224792675607L;
-		long _batchSize;
-		long _dataSize;
-		long _numBatches;
-		String _programString;
-		String _namespace;
-		String _gradientsFunctionName;
-		String _aggregationFunctionName;
-		ListObject _hyperParams;
-		long _batchCounterVarID;
-		long _modelVarID;
-
-		protected setupFederatedWorker(long batchSize, long dataSize, long numBatches, String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName, ListObject hyperParams, long batchCounterVarID, long modelVarID) {
+		private final long _batchSize;
+		private final long _dataSize;
+		private final int _possibleBatchesPerLocalEpoch;
+		private final String _programString;
+		private final String _namespace;
+		private final String _gradientsFunctionName;
+		private final String _aggregationFunctionName;
+		private final ListObject _hyperParams;
+		private final long _batchCounterVarID;
+		private final long _modelVarID;
+
+		protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch,
+			String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName,
+			ListObject hyperParams, long batchCounterVarID, long modelVarID)
+		{
 			super(new long[]{});
 			_batchSize = batchSize;
 			_dataSize = dataSize;
-			_numBatches = numBatches;
+			_possibleBatchesPerLocalEpoch = possibleBatchesPerLocalEpoch;
 			_programString = programString;
 			_namespace = namespace;
 			_gradientsFunctionName = gradientsFunctionName;
@@ -175,7 +193,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			// set variables to ec
 			ec.setVariable(Statement.PS_FED_BATCH_SIZE, new IntObject(_batchSize));
 			ec.setVariable(Statement.PS_FED_DATA_SIZE, new IntObject(_dataSize));
-			ec.setVariable(Statement.PS_FED_NUM_BATCHES, new IntObject(_numBatches));
+			ec.setVariable(Statement.PS_FED_POSS_BATCHES_LOCAL, new IntObject(_possibleBatchesPerLocalEpoch));
 			ec.setVariable(Statement.PS_FED_NAMESPACE, new StringObject(_namespace));
 			ec.setVariable(Statement.PS_FED_GRADIENTS_FNAME, new StringObject(_gradientsFunctionName));
 			ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName));
@@ -192,9 +210,9 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	 */
 	public void teardown() {
 		// write program and meta data to worker
-		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-			_featuresData.getVarID(),
-			new teardownFederatedWorker()
+		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+			new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
+				new TeardownFederatedWorker()
 		));
 
 		try {
@@ -210,10 +228,10 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	/**
 	 * Teardown UDF executed on the federated worker
 	 */
-	private static class teardownFederatedWorker extends FederatedUDF {
+	private static class TeardownFederatedWorker extends FederatedUDF {
 		private static final long serialVersionUID = -153650281873318969L;
 
-		protected teardownFederatedWorker() {
+		protected TeardownFederatedWorker() {
 			super(new long[]{});
 		}
@@ -222,14 +240,13 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			// remove variables from ec
 			ec.removeVariable(Statement.PS_FED_BATCH_SIZE);
 			ec.removeVariable(Statement.PS_FED_DATA_SIZE);
-			ec.removeVariable(Statement.PS_FED_NUM_BATCHES);
+			ec.removeVariable(Statement.PS_FED_POSS_BATCHES_LOCAL);
 			ec.removeVariable(Statement.PS_FED_NAMESPACE);
 			ec.removeVariable(Statement.PS_FED_GRADIENTS_FNAME);
 			ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME);
 			ec.removeVariable(Statement.PS_FED_BATCHCOUNTER_VARID);
 			ec.removeVariable(Statement.PS_FED_MODEL_VARID);
 			ParamservUtils.cleanupListObject(ec, Statement.PS_HYPER_PARAMS);
-			ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
 			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
 		}
@@ -246,10 +263,13 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 		try {
 			switch (_freq) {
 				case BATCH:
-					computeBatch(_totalNumBatches);
+					computeWithBatchUpdates();
 					break;
+				/*case NBATCH:
+					computeWithNBatchUpdates();
+					break; */
 				case EPOCH:
-					computeEpoch();
+					computeWithEpochUpdates();
 					break;
 				default:
 					throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq));
@@ -271,154 +291,82 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 		_ps.push(_workerID, gradients);
 	}
 
+	static protected int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) {
+		return currentLocalBatchNumber % possibleBatchesPerLocalEpoch;
+	}
+
 	/**
-	 * Computes all epochs and synchronizes after each batch
-	 *
-	 * @param numBatches the number of batches per epoch
+	 * Computes all epochs and updates after each batch
 	 */
-	protected void computeBatch(int numBatches) {
+	protected void computeWithBatchUpdates() {
 		for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
-			for (int batchCounter = 0; batchCounter < numBatches; batchCounter++) {
+			int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+
+			for (int batchCounter = 0; batchCounter < _numBatchesPerGlobalEpoch; batchCounter++) {
+				int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
 				ListObject model = pullModel();
-				ListObject gradients = computeBatchGradients(model, batchCounter);
+				ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum);
 				pushGradients(gradients);
 				ParamservUtils.cleanupListObject(model);
 				ParamservUtils.cleanupListObject(gradients);
 			}
-			System.out.println("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+			if( LOG.isInfoEnabled() )
+				LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
 		}
 	}
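
The cycling above is worth a worked example: each epoch advances the local start offset by _numBatchesPerGlobalEpoch and getNextLocalBatchNum wraps it modulo the number of locally possible batches, so a worker holding fewer batches than the global epoch length still cycles through all of its data across epochs. A self-contained sketch with toy sizes (all numbers are assumed for illustration, not taken from the commit):

public class BatchCycleExample {
	// Same modulo wrap-around as getNextLocalBatchNum above.
	static int getNextLocalBatchNum(int current, int possibleBatchesPerLocalEpoch) {
		return current % possibleBatchesPerLocalEpoch;
	}

	public static void main(String[] args) {
		int numBatchesPerGlobalEpoch = 4, possibleBatchesPerLocalEpoch = 3;
		for (int epoch = 0; epoch < 3; epoch++) {
			// start offset per epoch, as in computeWithBatchUpdates
			int current = numBatchesPerGlobalEpoch * epoch % possibleBatchesPerLocalEpoch;
			for (int batch = 0; batch < numBatchesPerGlobalEpoch; batch++)
				System.out.println("epoch " + epoch + " -> local batch "
					+ getNextLocalBatchNum(current++, possibleBatchesPerLocalEpoch));
		}
	}
}
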
 	/**
-	 * Computes a single specified batch on the federated worker
-	 *
-	 * @param model the current model from the parameter server
-	 * @param batchCounter the current batch number needed for slicing the features and labels
-	 * @return the gradient vector
+	 * Computes all epochs and updates after N batches
 	 */
-	protected ListObject computeBatchGradients(ListObject model, int batchCounter) {
-		// put batch counter on federated worker
-		Future<FederatedResponse> putBatchCounterResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _batchCounterVarID, new IntObject(batchCounter)));
-
-		// put current model on federated worker
-		Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model));
-
-		try {
-			if(!putParamsResponse.get().isSuccessful() || !putBatchCounterResponse.get().isSuccessful())
-				throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful");
-		}
-		catch(Exception e) {
-			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute put" + e.getMessage());
-		}
-
-		// create and execute the udf on the remote worker
-		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-			_featuresData.getVarID(),
-			new federatedComputeBatchGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _batchCounterVarID, _modelVarID})
-		));
-
-		try {
-			Object[] responseData = udfResponse.get().getData();
-			return (ListObject) responseData[0];
-		}
-		catch(Exception e) {
-			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF" + e.getMessage());
-		}
+	protected void computeWithNBatchUpdates() {
+		throw new NotImplementedException();
 	}
 	/**
-	 * This is the code that will be executed on the federated Worker when computing a single batch
+	 * Computes all epochs and updates after each epoch
 	 */
-	private static class federatedComputeBatchGradients extends FederatedUDF {
-		private static final long serialVersionUID = -3652112393963053475L;
-
-		protected federatedComputeBatchGradients(long[] inIDs) {
-			super(inIDs);
-		}
-
-		@Override
-		public FederatedResponse execute(ExecutionContext ec, Data... data) {
-			// read in data by varid
-			MatrixObject features = (MatrixObject) data[0];
-			MatrixObject labels = (MatrixObject) data[1];
-			long batchCounter = ((IntObject) data[2]).getLongValue();
-			ListObject model = (ListObject) data[3];
-
-			// get data from execution context
-			long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
-			long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue();
-			String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
-			String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
-
-			// slice batch from feature and label matrix
-			long begin = batchCounter * batchSize + 1;
-			long end = Math.min((batchCounter + 1) * batchSize, dataSize);
-			MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end);
-			MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end);
-
-			// prepare execution context
-			ec.setVariable(Statement.PS_MODEL, model);
-			ec.setVariable(Statement.PS_FEATURES, bFeatures);
-			ec.setVariable(Statement.PS_LABELS, bLabels);
-
-			// recreate gradient instruction and output
-			FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunctionName, false);
-			ArrayList<DataIdentifier> inputs = func.getInputParams();
-			ArrayList<DataIdentifier> outputs = func.getOutputParams();
-			CPOperand[] boundInputs = inputs.stream()
-				.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
-				.toArray(CPOperand[]::new);
-			ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
-				.collect(Collectors.toCollection(ArrayList::new));
-			Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
-				func.getInputParamNames(), outputNames, "gradient function");
-			DataIdentifier gradientsOutput = outputs.get(0);
-
-			// calculate and gradients
-			gradientsInstruction.processInstruction(ec);
-			ListObject gradients = ec.getListObject(gradientsOutput.getName());
-
-			// clean up sliced batch
-			ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
-			ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
-			ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
-
-			// model clean up - doing this twice is not an issue
-			ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
-			ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
-
-			// return
-			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, gradients);
-		}
-	}
-
-	/**
-	 * Computes all epochs and synchronizes after each one
-	 */
-	protected void computeEpoch() {
+	protected void computeWithEpochUpdates() {
 		for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
+			int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+
 			// Pull the global parameters from ps
 			ListObject model = pullModel();
-			ListObject gradients = computeEpochGradients(model);
+			ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerGlobalEpoch, localStartBatchNum, true);
 			pushGradients(gradients);
-			System.out.println("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+
+			if( LOG.isInfoEnabled() )
+				LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
 			ParamservUtils.cleanupListObject(model);
 			ParamservUtils.cleanupListObject(gradients);
 		}
 	}
+	protected ListObject computeGradientsForNBatches(ListObject model, int numBatchesToCompute, int localStartBatchNum) {
+		return computeGradientsForNBatches(model, numBatchesToCompute, localStartBatchNum, false);
+	}
+
 	/**
-	 * Computes one epoch on the federated worker and updates the model local
+	 * Computes the gradients of n batches on the federated worker and is able to update the model locally.
+	 * Returns the gradients.
 	 *
 	 * @param model the current model from the parameter server
+	 * @param localStartBatchNum the batch to start from
+	 * @param localUpdate whether to update the model locally
+	 *
 	 * @return the gradient vector
 	 */
-	protected ListObject computeEpochGradients(ListObject model) {
+	protected ListObject computeGradientsForNBatches(ListObject model,
+		int numBatchesToCompute, int localStartBatchNum, boolean localUpdate)
+	{
+		// put local start batch num on federated worker
+		Future<FederatedResponse> putBatchCounterResponse = _featuresData.executeFederatedOperation(
+			new FederatedRequest(RequestType.PUT_VAR, _localStartBatchNumVarID, new IntObject(localStartBatchNum)));
+
 		// put current model on federated worker
-		Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model));
+		Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(
+			new FederatedRequest(RequestType.PUT_VAR, _modelVarID, model));
 
 		try {
-			if(!putParamsResponse.get().isSuccessful())
+			if(!putParamsResponse.get().isSuccessful() || !putBatchCounterResponse.get().isSuccessful())
 				throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful");
 		}
 		catch(Exception e) {
@@ -426,9 +374,10 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 		}
 
 		// create and execute the udf on the remote worker
-		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-			_featuresData.getVarID(),
-			new federatedComputeEpochGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _modelVarID})
+		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+			new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
+				new federatedComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID(),
+					_localStartBatchNumVarID, _modelVarID}, numBatchesToCompute, localUpdate)
 		));
 
 		try {
@@ -441,13 +390,17 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	}
 
 	/**
-	 * This is the code that will be executed on the federated Worker when computing one epoch
+	 * This is the code that will be executed on the federated Worker when computing gradients for n batches
 	 */
-	private static class federatedComputeEpochGradients extends FederatedUDF {
+	private static class federatedComputeGradientsForNBatches extends FederatedUDF {
 		private static final long serialVersionUID = -3075901536748794832L;
+		int _numBatchesToCompute;
+		boolean _localUpdate;
 
-		protected federatedComputeEpochGradients(long[] inIDs) {
+		protected federatedComputeGradientsForNBatches(long[] inIDs, int numBatchesToCompute, boolean localUpdate) {
 			super(inIDs);
+			_numBatchesToCompute = numBatchesToCompute;
+			_localUpdate = localUpdate;
 		}
 
 		@Override
@@ -455,12 +408,13 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			// read in data by varid
 			MatrixObject features = (MatrixObject) data[0];
 			MatrixObject labels = (MatrixObject) data[1];
-			ListObject model = (ListObject) data[2];
+			int localStartBatchNum = (int) ((IntObject) data[2]).getLongValue();
+			ListObject model = (ListObject) data[3];
 
 			// get data from execution context
 			long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
 			long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue();
-			long numBatches = ((IntObject) ec.getVariable(Statement.PS_FED_NUM_BATCHES)).getLongValue();
+			int possibleBatchesPerLocalEpoch = (int) ((IntObject) ec.getVariable(Statement.PS_FED_POSS_BATCHES_LOCAL)).getLongValue();
 			String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
 			String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
 			String aggregationFuctionName = ((StringObject) ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue();
@@ -478,61 +432,71 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 				func.getInputParamNames(), outputNames, "gradient function");
 			DataIdentifier gradientsOutput = outputs.get(0);
 
-			// recreate aggregation instruction and output
-			func = ec.getProgram().getFunctionProgramBlock(namespace, aggregationFuctionName, false);
-			inputs = func.getInputParams();
-			outputs = func.getOutputParams();
-			boundInputs = inputs.stream()
-				.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
-				.toArray(CPOperand[]::new);
-			outputNames = outputs.stream().map(DataIdentifier::getName)
-				.collect(Collectors.toCollection(ArrayList::new));
-			Instruction aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs,
-				func.getInputParamNames(), outputNames, "aggregation function");
-			DataIdentifier aggregationOutput = outputs.get(0);
-
+			// recreate aggregation instruction and output if needed
+			Instruction aggregationInstruction = null;
+			DataIdentifier aggregationOutput = null;
+			if(_localUpdate && _numBatchesToCompute > 1) {
+				func = ec.getProgram().getFunctionProgramBlock(namespace, aggregationFuctionName, false);
+				inputs = func.getInputParams();
+				outputs = func.getOutputParams();
+				boundInputs = inputs.stream()
+					.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+					.toArray(CPOperand[]::new);
+				outputNames = outputs.stream().map(DataIdentifier::getName)
+					.collect(Collectors.toCollection(ArrayList::new));
+				aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs,
+					func.getInputParamNames(), outputNames, "aggregation function");
+				aggregationOutput = outputs.get(0);
+			}
 
 			ListObject accGradients = null;
+			int currentLocalBatchNumber = localStartBatchNum;
 			// prepare execution context
 			ec.setVariable(Statement.PS_MODEL, model);
-			for (int batchCounter = 0; batchCounter < numBatches; batchCounter++) {
+			for (int batchCounter = 0; batchCounter < _numBatchesToCompute; batchCounter++) {
+				int localBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, possibleBatchesPerLocalEpoch);
+
 				// slice batch from feature and label matrix
-				long begin = batchCounter * batchSize + 1;
-				long end = Math.min((batchCounter + 1) * batchSize, dataSize);
+				long begin = localBatchNum * batchSize + 1;
+				long end = Math.min((localBatchNum + 1) * batchSize, dataSize);
 				MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end);
 				MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end);
 
 				// prepare execution context
 				ec.setVariable(Statement.PS_FEATURES, bFeatures);
 				ec.setVariable(Statement.PS_LABELS, bLabels);
-				boolean localUpdate = batchCounter < numBatches - 1;
 
-				// calculate intermediate gradients
+				// calculate gradients for batch
 				gradientsInstruction.processInstruction(ec);
 				ListObject gradients = ec.getListObject(gradientsOutput.getName());
 
-				// TODO: is this equivalent for momentum based and AMS prob?
+				// accrue the computed gradients - In the single batch case this is just a list copy
+				// is this equivalent for momentum based and AMS prob?
 				accGradients = ParamservUtils.accrueGradients(accGradients, gradients, false);
 
-				// Update the local model with gradients
-				if(localUpdate) {
+				// update the local model with gradients if needed
+				if(_localUpdate && batchCounter < _numBatchesToCompute - 1) {
 					// Invoke the aggregate function
+					assert aggregationInstruction != null;
 					aggregationInstruction.processInstruction(ec);
 
 					// Get the new model
 					model = ec.getListObject(aggregationOutput.getName());
 
 					// Set new model in execution context
 					ec.setVariable(Statement.PS_MODEL, model);
 
 					// clean up gradients and result
-					ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
 					ParamservUtils.cleanupListObject(ec, aggregationOutput.getName());
 				}
 
-				// clean up sliced batch
+				// clean up
+				ParamservUtils.cleanupListObject(ec, gradientsOutput.getName());
 				ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
 				ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
+				ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
+				if( LOG.isInfoEnabled() )
+					LOG.info("[+]" + " completed batch " + localBatchNum);
 			}
 
-			// model clean up - doing this twice is not an issue
+			// model clean up
 			ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
 			ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
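
One detail of the worker loop above: gradients of consecutive batches are accrued before being pushed to the parameter server. A toy stand-in for what ParamservUtils.accrueGradients does conceptually, element-wise addition into an accumulator (the real implementation operates on ListObjects of matrices; plain arrays are used here purely for illustration):

public class AccrueGradientsExample {
	public static void main(String[] args) {
		double[] acc = null;
		double[][] batchGradients = {{0.1, -0.2}, {0.3, 0.0}};
		for (double[] g : batchGradients) {
			if (acc == null)
				acc = g.clone(); // single-batch case: effectively a copy
			else
				for (int i = 0; i < acc.length; i++)
					acc[i] += g[i];
		}
		System.out.println(java.util.Arrays.toString(acc)); // [0.4, -0.2]
	}
}
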
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index e63fb14..51600d2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -210,12 +210,39 @@ public class ParamservUtils {
 		MatrixBlock sample = MatrixBlock.sampleOperations(numEntries, numEntries, false, seed);
 
 		// Combine the sequence and sample as a table
-		return seq.ctableSeqOperations(sample, 1.0,
-			new MatrixBlock(numEntries, numEntries, true));
+		return seq.ctableSeqOperations(sample, 1.0, new MatrixBlock(numEntries, numEntries, true));
+	}
+
+	/**
+	 * Generates a matrix which, when left-multiplied with the input matrix, will subsample
+	 * @param nsamples number of samples
+	 * @param nrows number of rows in the input matrix
+	 * @param seed seed used to generate random numbers
+	 * @return subsample matrix
+	 */
+	public static MatrixBlock generateSubsampleMatrix(int nsamples, int nrows, long seed) {
+		MatrixBlock seq = new MatrixBlock(nsamples, nrows, false);
+		// No replacement to preserve as much of the original data as possible
+		MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, false, seed);
+		return seq.ctableSeqOperations(sample, 1.0, new MatrixBlock(nsamples, nrows, true), false);
+	}
+
+	/**
+	 * Generates a matrix which, when left-multiplied with the input matrix, will replicate n data rows
+	 * @param nsamples number of samples
+	 * @param nrows number of rows in the input matrix
+	 * @param seed seed used to generate random numbers
+	 * @return replication matrix
+	 */
+	public static MatrixBlock generateReplicationMatrix(int nsamples, int nrows, long seed) {
+		MatrixBlock seq = new MatrixBlock(nsamples, nrows, false);
+		// Replacement set to true to provide random replication
+		MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, true, seed);
+		return seq.ctableSeqOperations(sample, 1.0, new MatrixBlock(nsamples, nrows, true), false);
 	}
 
 	public static ExecutionContext createExecutionContext(ExecutionContext ec,
-		LocalVariableMap varsMap, String updFunc, String aggFunc, int k)
+		LocalVariableMap varsMap, String updFunc, String aggFunc, int k)
 	{
 		return createExecutionContext(ec, varsMap, updFunc, aggFunc, k, false);
 	}
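
The two generators above both build a selection matrix: a sparse 0/1 matrix P whose row i has a single 1 in column sample[i], so that P %*% X picks (or repeats) rows of X. A self-contained sketch of that idea with a dense toy matrix (the ctable-based construction used in ParamservUtils is replaced here by an explicit loop for clarity):

public class SelectionMatrixExample {
	public static void main(String[] args) {
		double[][] X = {{1, 1}, {2, 2}, {3, 3}, {4, 4}}; // 4 data rows
		int[] sample = {3, 0};                           // assumed sampled row indices
		double[][] P = new double[sample.length][X.length];
		for (int i = 0; i < sample.length; i++)
			P[i][sample[i]] = 1.0;                       // one 1 per row selects a data row
		// P %*% X keeps exactly the sampled rows (here rows 3 and 0)
		for (double[] pRow : P) {
			double[] row = new double[X[0].length];
			for (int j = 0; j < X.length; j++)
				for (int c = 0; c < X[0].length; c++)
					row[c] += pRow[j] * X[j][c];
			System.out.println(java.util.Arrays.toString(row));
		}
	}
}
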
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
new file mode 100644
index 0000000..460faba
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+import java.util.List;
+import java.util.concurrent.Future;
+
+public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
+	@Override
+	public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+		List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
+		List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+
+		int average_num_rows = (int) Math.round(pFeatures.stream().map(CacheableData::getNumRows).mapToInt(Long::intValue).average().orElse(Double.NaN));
+
+		for(int i = 0; i < pFeatures.size(); i++) {
+			// Works, because the map contains a single entry
+			FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
+			FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+
+			Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+				featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, average_num_rows)));
+
+			try {
+				FederatedResponse response = udfResponse.get();
+				if(!response.isSuccessful())
+					throw new DMLRuntimeException("FederatedDataPartitioner BalanceFederatedScheme: balance UDF returned fail");
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException("FederatedDataPartitioner BalanceFederatedScheme: executing balance UDF failed" + e.getMessage());
+			}
+
+			DataCharacteristics update = pFeatures.get(i).getDataCharacteristics().setRows(average_num_rows);
+			pFeatures.get(i).updateDataCharacteristics(update);
+			update = pLabels.get(i).getDataCharacteristics().setRows(average_num_rows);
+			pLabels.get(i).updateDataCharacteristics(update);
+		}
+
+		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+	}
+
+	/**
+	 * Balance UDF executed on the federated worker
+	 */
+	private static class balanceDataOnFederatedWorker extends FederatedUDF {
+		private static final long serialVersionUID = 6631958250346625546L;
+		private final int _average_num_rows;
+
+		protected balanceDataOnFederatedWorker(long[] inIDs, int average_num_rows) {
+			super(inIDs);
+			_average_num_rows = average_num_rows;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			MatrixObject features = (MatrixObject) data[0];
+			MatrixObject labels = (MatrixObject) data[1];
+
+			if(features.getNumRows() > _average_num_rows) {
+				// generate subsampling matrix
+				MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+				subsampleTo(features, subsampleMatrixBlock);
+				subsampleTo(labels, subsampleMatrixBlock);
+			}
+			else if(features.getNumRows() < _average_num_rows) {
+				int num_rows_needed = _average_num_rows - Math.toIntExact(features.getNumRows());
+				// generate replication matrix
+				MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+				replicateTo(features, replicateMatrixBlock);
+				replicateTo(labels, replicateMatrixBlock);
+			}
+
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+		}
+	}
+}
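
The decision logic of the scheme above in isolation: workers holding more rows than the average subsample down, workers holding fewer replicate up. A toy sketch of just that branching (row counts are made up; the real scheme dispatches the matrix-generating UDFs shown in the file):

public class BalanceToAvgExample {
	public static void main(String[] args) {
		int[] workerRows = {100, 40, 160};   // assumed per-worker row counts
		int avg = (100 + 40 + 160) / 3;      // 100
		for (int rows : workerRows) {
			if (rows > avg)
				System.out.println(rows + " rows -> subsample down to " + avg);
			else if (rows < avg)
				System.out.println(rows + " rows -> replicate " + (avg - rows) + " extra rows");
			else
				System.out.println(rows + " rows -> unchanged");
		}
	}
}
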
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
index 4183372..f5c9638 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -26,6 +26,10 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.meta.MetaDataFormat;
@@ -37,14 +41,16 @@ import java.util.List;
 public abstract class DataPartitionFederatedScheme {
 
 	public static final class Result {
-		public final List<MatrixObject> pFeatures;
-		public final List<MatrixObject> pLabels;
-		public final int workerNum;
-
-		public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum) {
-			this.pFeatures = pFeatures;
-			this.pLabels = pLabels;
-			this.workerNum = workerNum;
+		public final List<MatrixObject> _pFeatures;
+		public final List<MatrixObject> _pLabels;
+		public final int _workerNum;
+		public final BalanceMetrics _balanceMetrics;
+
+		public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum, BalanceMetrics balanceMetrics) {
+			this._pFeatures = pFeatures;
+			this._pLabels = pLabels;
+			this._workerNum = workerNum;
+			this._balanceMetrics = balanceMetrics;
 		}
 	}
@@ -57,12 +63,10 @@ public abstract class DataPartitionFederatedScheme {
 	 */
 	static List<MatrixObject> sliceFederatedMatrix(MatrixObject fedMatrix) {
 		if (fedMatrix.isFederated(FederationMap.FType.ROW)) {
-
 			List<MatrixObject> slices = Collections.synchronizedList(new ArrayList<>());
 			fedMatrix.getFedMapping().forEachParallel((range, data) -> {
 				// Create sliced matrix object
 				MatrixObject slice = new MatrixObject(fedMatrix.getValueType(), Dag.getNextUniqueVarname(Types.DataType.MATRIX));
-				// Warning needs MetaDataFormat instead of MetaData
 				slice.setMetaData(new MetaDataFormat(
 					new MatrixCharacteristics(range.getSize(0), range.getSize(1)),
 					Types.FileFormat.BINARY)
@@ -85,4 +89,78 @@ public abstract class DataPartitionFederatedScheme {
 				"currently only supports row federated data");
 		}
 	}
+
+	static BalanceMetrics getBalanceMetrics(List<MatrixObject> slices) {
+		if (slices == null || slices.size() == 0)
+			return new BalanceMetrics(0, 0, 0);
+
+		long minRows = slices.get(0).getNumRows();
+		long maxRows = minRows;
+		long sum = 0;
+
+		for (MatrixObject slice : slices) {
+			if (slice.getNumRows() < minRows)
+				minRows = slice.getNumRows();
+			else if (slice.getNumRows() > maxRows)
+				maxRows = slice.getNumRows();
+
+			sum += slice.getNumRows();
+		}
+
+		return new BalanceMetrics(minRows, sum / slices.size(), maxRows);
+	}
+
+	public static final class BalanceMetrics {
+		public final long _minRows;
+		public final long _avgRows;
+		public final long _maxRows;
+
+		public BalanceMetrics(long minRows, long avgRows, long maxRows) {
+			this._minRows = minRows;
+			this._avgRows = avgRows;
+			this._maxRows = maxRows;
+		}
+	}
+
+	/**
+	 * Just a matrix multiply used to shuffle with a provided shuffle MatrixBlock
+	 *
+	 * @param m the input matrix object
+	 * @param P the permutation matrix for shuffling
+	 */
+	static void shuffle(MatrixObject m, MatrixBlock P) {
+		int k = InfrastructureAnalyzer.getLocalParallelism();
+		AggregateBinaryOperator mm = InstructionUtils.getMatMultOperator(k);
+		MatrixBlock out = P.aggregateBinaryOperations(P, m.acquireReadAndRelease(), new MatrixBlock(), mm);
+		m.acquireModify(out);
+		m.release();
+	}
+
+	/**
+	 * Takes a MatrixObject and extends it to the chosen number of rows by random replication
+	 *
+	 * @param m the input matrix object
+	 * @param R the replication matrix
+	 */
+	static void replicateTo(MatrixObject m, MatrixBlock R) {
+		int k = InfrastructureAnalyzer.getLocalParallelism();
+		AggregateBinaryOperator mm = InstructionUtils.getMatMultOperator(k);
+		MatrixBlock out = R.aggregateBinaryOperations(R, m.acquireReadAndRelease(), new MatrixBlock(), mm);
+		m.acquireModify(m.acquireReadAndRelease().append(out, new MatrixBlock(), false));
+		m.release();
+	}
+
+	/**
+	 * Just a matrix multiply used to subsample with a provided subsample MatrixBlock
+	 *
+	 * @param m the input matrix object
+	 * @param S the selection matrix for subsampling
+	 */
+	static void subsampleTo(MatrixObject m, MatrixBlock S) {
+		int k = InfrastructureAnalyzer.getLocalParallelism();
+		AggregateBinaryOperator mm = InstructionUtils.getMatMultOperator(k);
+		MatrixBlock out = S.aggregateBinaryOperations(S, m.acquireReadAndRelease(), new MatrixBlock(), mm);
+		m.acquireModify(out);
+		m.release();
+	}
 }
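
For getBalanceMetrics above, a compact standalone version of the same min/avg/max scan (toy row counts, not taken from the commit):

import java.util.List;

public class BalanceMetricsExample {
	public static void main(String[] args) {
		List<Long> rowsPerWorker = List.of(100L, 40L, 160L);
		long min = Long.MAX_VALUE, max = 0, sum = 0;
		for (long r : rowsPerWorker) {
			min = Math.min(min, r);
			max = Math.max(max, r);
			sum += r;
		}
		// integer average, as in getBalanceMetrics
		System.out.printf("min=%d avg=%d max=%d%n", min, sum / rowsPerWorker.size(), max);
	}
}
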
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
index 4cdfb95..d1ebb6c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
@@ -35,6 +35,15 @@ public class FederatedDataPartitioner {
 			case SHUFFLE:
 				_scheme = new ShuffleFederatedScheme();
 				break;
+			case REPLICATE_TO_MAX:
+				_scheme = new ReplicateToMaxFederatedScheme();
+				break;
+			case SUBSAMPLE_TO_MIN:
+				_scheme = new SubsampleToMinFederatedScheme();
+				break;
+			case BALANCE_TO_AVG:
+				_scheme = new BalanceToAvgFederatedScheme();
+				break;
 			default:
 				throw new DMLRuntimeException(String.format("FederatedDataPartitioner: not support data partition scheme '%s'", scheme));
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
index 06feded..e306f25 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
@@ -27,6 +27,6 @@ public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedSchem
 	public Result doPartitioning(MatrixObject features, MatrixObject labels) {
 		List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
 		List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
-		return new Result(pFeatures, pLabels, pFeatures.size());
+		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
new file mode 100644
index 0000000..068cfa9
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+import java.util.List;
+import java.util.concurrent.Future;
+
+public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme {
+	@Override
+	public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+		List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
+		List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+
+		int max_rows = 0;
+		for (MatrixObject pFeature : pFeatures) {
+			max_rows = (pFeature.getNumRows() > max_rows) ? Math.toIntExact(pFeature.getNumRows()) : max_rows;
+		}
+
+		for(int i = 0; i < pFeatures.size(); i++) {
+			// Works, because the map contains a single entry
+			FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
+			FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+
+			Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+				featuresData.getVarID(), new replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, max_rows)));
+
+			try {
+				FederatedResponse response = udfResponse.get();
+				if(!response.isSuccessful())
+					throw new DMLRuntimeException("FederatedDataPartitioner ReplicateFederatedScheme: replicate UDF returned fail");
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException("FederatedDataPartitioner ReplicateFederatedScheme: executing replicate UDF failed" + e.getMessage());
+			}
+
+			DataCharacteristics update = pFeatures.get(i).getDataCharacteristics().setRows(max_rows);
+			pFeatures.get(i).updateDataCharacteristics(update);
+			update = pLabels.get(i).getDataCharacteristics().setRows(max_rows);
+			pLabels.get(i).updateDataCharacteristics(update);
+		}
+
+		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+	}
+
+	/**
+	 * Replicate UDF executed on the federated worker
+	 */
+	private static class replicateDataOnFederatedWorker extends FederatedUDF {
+		private static final long serialVersionUID = -6930898456315100587L;
+		private final int _max_rows;
+
+		protected replicateDataOnFederatedWorker(long[] inIDs, int max_rows) {
+			super(inIDs);
+			_max_rows = max_rows;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			MatrixObject features = (MatrixObject) data[0];
+			MatrixObject labels = (MatrixObject) data[1];
+
+			// replicate up to the max
+			if(features.getNumRows() < _max_rows) {
+				int num_rows_needed = _max_rows - Math.toIntExact(features.getNumRows());
+				// generate replication matrix
+				MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+				replicateTo(features, replicateMatrixBlock);
+				replicateTo(labels, replicateMatrixBlock);
+			}
+
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+		}
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index d6d8cfc..65ef69d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -19,15 +19,67 @@
package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import java.util.List;
+import java.util.concurrent.Future;
public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
@Override
public Result doPartitioning(MatrixObject features, MatrixObject
labels) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
- return new Result(pFeatures, pLabels, pFeatures.size());
+
+ for(int i = 0; i < pFeatures.size(); i++) {
+ // Works, because the map contains a single entry
+ FederatedData featuresData = (FederatedData)
pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
+ FederatedData labelsData = (FederatedData)
pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+
+ Future<FederatedResponse> udfResponse =
featuresData.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+ featuresData.getVarID(), new
shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(),
labelsData.getVarID()})));
+
+ try {
+ FederatedResponse response = udfResponse.get();
+ if(!response.isSuccessful())
+ throw new
DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: shuffle
UDF returned fail");
+ }
+ catch(Exception e) {
+ throw new
DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: executing
shuffle UDF failed" + e.getMessage());
+ }
+ }
+
+ return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures));
+ }
+
+ /**
+ * Shuffle UDF executed on the federated worker
+ */
+	private static class shuffleDataOnFederatedWorker extends FederatedUDF {
+		private static final long serialVersionUID = 3228664618781333325L;
+
+		protected shuffleDataOnFederatedWorker(long[] inIDs) {
+			super(inIDs);
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			MatrixObject features = (MatrixObject) data[0];
+			MatrixObject labels = (MatrixObject) data[1];
+
+			// generate permutation matrix
+			MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+			shuffle(features, permutationMatrixBlock);
+			shuffle(labels, permutationMatrixBlock);
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+		}
}
-}
+}
\ No newline at end of file
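Conceptually, the shuffle boils down to one shared permutation applied to both matrices, i.e. P %*% X and P %*% y with the same P, so (feature, label) pairs survive the reordering. A minimal plain-Java sketch of this invariant (illustrative only; the committed code encodes the permutation as a MatrixBlock via ParamservUtils.generatePermutation):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class ShuffleSketch {
	// Output row i is input row perm[i], which is exactly what left-multiplying
	// with a permutation matrix does.
	static double[][] applyPermutation(int[] perm, double[][] x) {
		double[][] out = new double[x.length][];
		for(int i = 0; i < perm.length; i++)
			out[i] = x[perm[i]];
		return out;
	}

	public static void main(String[] args) {
		List<Integer> idx = new ArrayList<>(Arrays.asList(0, 1, 2, 3));
		Collections.shuffle(idx, new Random(7)); // one seed, one shared order
		int[] perm = idx.stream().mapToInt(Integer::intValue).toArray();
		double[][] X = {{1, 1}, {2, 2}, {3, 3}, {4, 4}};
		double[][] y = {{1}, {2}, {3}, {4}};
		// same perm for X and y, so row pairs stay aligned
		System.out.println(Arrays.deepToString(applyPermutation(perm, X)));
		System.out.println(Arrays.deepToString(applyPermutation(perm, y)));
	}
}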
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
new file mode 100644
index 0000000..9b62cc8
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+import java.util.List;
+import java.util.concurrent.Future;
+
+public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme {
+	@Override
+	public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+		List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
+		List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+
+		int min_rows = Integer.MAX_VALUE;
+		for (MatrixObject pFeature : pFeatures) {
+			min_rows = (pFeature.getNumRows() < min_rows) ? Math.toIntExact(pFeature.getNumRows()) : min_rows;
+		}
+
+		for(int i = 0; i < pFeatures.size(); i++) {
+			// Works, because the map contains a single entry
+			FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
+			FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+
+			Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+				featuresData.getVarID(), new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, min_rows)));
+
+			try {
+				FederatedResponse response = udfResponse.get();
+				if(!response.isSuccessful())
+					throw new DMLRuntimeException("FederatedDataPartitioner SubsampleFederatedScheme: subsample UDF returned failure");
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException("FederatedDataPartitioner SubsampleFederatedScheme: executing subsample UDF failed: " + e.getMessage());
+			}
+
+			DataCharacteristics update = pFeatures.get(i).getDataCharacteristics().setRows(min_rows);
+			pFeatures.get(i).updateDataCharacteristics(update);
+			update = pLabels.get(i).getDataCharacteristics().setRows(min_rows);
+			pLabels.get(i).updateDataCharacteristics(update);
+		}
+
+		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+	}
+
+ /**
+ * Subsample UDF executed on the federated worker
+ */
+	private static class subsampleDataOnFederatedWorker extends FederatedUDF {
+		private static final long serialVersionUID = 2213790859544004286L;
+		private final int _min_rows;
+
+		protected subsampleDataOnFederatedWorker(long[] inIDs, int min_rows) {
+			super(inIDs);
+			_min_rows = min_rows;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			MatrixObject features = (MatrixObject) data[0];
+			MatrixObject labels = (MatrixObject) data[1];
+
+			// subsample down to minimum
+			if(features.getNumRows() > _min_rows) {
+				// generate subsampling matrix
+				MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_min_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+				subsampleTo(features, subsampleMatrixBlock);
+				subsampleTo(labels, subsampleMatrixBlock);
+			}
+
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+		}
+	}
+}
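The subsample scheme is the mirror image of replication: each worker keeps exactly min_rows rows, selected once and applied to features and labels alike. A plain-Java sketch of that selection idea (illustrative only; ParamservUtils.generateSubsampleMatrix produces the equivalent selection MatrixBlock server-side):

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class SubsampleSketch {
	// Picks minRows distinct row indices once; reusing the same indices
	// for features and labels keeps (feature, label) pairs intact.
	static int[] sampleIndices(int numRows, int minRows, long seed) {
		List<Integer> idx = new ArrayList<>();
		for(int i = 0; i < numRows; i++)
			idx.add(i);
		Collections.shuffle(idx, new Random(seed));
		return idx.subList(0, minRows).stream().mapToInt(Integer::intValue).toArray();
	}

	// Selects the given rows, the effect of multiplying with a selection matrix.
	static double[][] selectRows(double[][] x, int[] indices) {
		double[][] out = new double[indices.length][];
		for(int i = 0; i < indices.length; i++)
			out[i] = x[indices[i]];
		return out;
	}

	public static void main(String[] args) {
		double[][] features = {{1, 1}, {2, 2}, {3, 3}};
		double[][] labels = {{1}, {2}, {3}};
		int[] sel = sampleIndices(3, 2, 42); // keep min_rows = 2 of 3 rows
		selectRows(features, sel);
		selectRows(labels, sel);
	}
}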
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 7a285f6..a2b8d9f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -32,6 +32,7 @@ import static org.apache.sysds.parser.Statement.PS_PARALLELISM;
 import static org.apache.sysds.parser.Statement.PS_SCHEME;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
+import static org.apache.sysds.parser.Statement.PS_RUNTIME_BALANCING;
import java.util.HashMap;
import java.util.HashSet;
@@ -52,11 +53,12 @@ import org.apache.spark.util.LongAccumulator;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.LopProperties;
-import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.Statement.PSFrequency;
import org.apache.sysds.parser.Statement.PSModeType;
import org.apache.sysds.parser.Statement.PSScheme;
+import org.apache.sysds.parser.Statement.FederatedPSScheme;
import org.apache.sysds.parser.Statement.PSUpdateType;
+import org.apache.sysds.parser.Statement.PSRuntimeBalancing;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -86,6 +88,8 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
	private static final int DEFAULT_BATCH_SIZE = 64;
	private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.EPOCH;
	private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS;
+	private static final PSRuntimeBalancing DEFAULT_RUNTIME_BALANCING = PSRuntimeBalancing.NONE;
+	private static final FederatedPSScheme DEFAULT_FEDERATED_SCHEME = FederatedPSScheme.KEEP_DATA_ON_WORKER;
	private static final PSModeType DEFAULT_MODE = PSModeType.LOCAL;
	private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP;
@@ -96,7 +100,6 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
@Override
public void processInstruction(ExecutionContext ec) {
// check if the input is federated
-
if(ec.getMatrixObject(getParam(PS_FEATURES)).isFederated() ||
ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
runFederated(ec);
@@ -124,15 +127,27 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
		// get inputs
		PSFrequency freq = getFrequency();
		PSUpdateType updateType = getUpdateType();
+		PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
+		FederatedPSScheme federatedPSScheme = getFederatedScheme();
		String updFunc = getParam(PS_UPDATE_FUN);
		String aggFunc = getParam(PS_AGGREGATION_FUN);
		// partition federated data
-		DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER)
+		DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme)
			.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
-		List<MatrixObject> pFeatures = result.pFeatures;
-		List<MatrixObject> pLabels = result.pLabels;
-		int workerNum = result.workerNum;
+		List<MatrixObject> pFeatures = result._pFeatures;
+		List<MatrixObject> pLabels = result._pLabels;
+		int workerNum = result._workerNum;
+
+		// calculate runtime balancing
+		int numBatchesPerEpoch = 0;
+		if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
+			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._minRows / (float) getBatchSize());
+		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG) {
+			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._avgRows / (float) getBatchSize());
+		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
+			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._maxRows / (float) getBatchSize());
+		}
// setup threading
BasicThreadFactory factory = new BasicThreadFactory.Builder()
@@ -141,8 +156,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
		// Get the compiled execution context
		LocalVariableMap newVarsMap = createVarsMap(ec);
-		// Level of par is 1 because one worker will be launched per task
-		// TODO: Fix recompilation
+		// Level of par is -1 so that each federated worker can scale to its CPU cores
		ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, -1, true);
		// Create workers' execution context
		List<ExecutionContext> federatedWorkerECs = ParamservUtils.copyExecutionContext(newEC, workerNum);
@@ -152,8 +166,9 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
		ListObject model = ec.getListObject(getParam(PS_MODEL));
		ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC);
		// Create the local workers
+		int finalNumBatchesPerEpoch = numBatchesPerEpoch;
		List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
-			.mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, getEpochs(), getBatchSize(), federatedWorkerECs.get(i), ps))
+			.mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
			.collect(Collectors.toList());
		if(workerNum != threads.size()) {
@@ -379,6 +394,18 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
		}
	}
+	private PSRuntimeBalancing getRuntimeBalancing() {
+		if (!getParameterMap().containsKey(PS_RUNTIME_BALANCING)) {
+			return DEFAULT_RUNTIME_BALANCING;
+		}
+		try {
+			return PSRuntimeBalancing.valueOf(getParam(PS_RUNTIME_BALANCING));
+		} catch (IllegalArgumentException e) {
+			throw new DMLRuntimeException(String.format("Paramserv function: "
+				+ "runtime balancing '%s' is not supported.", getParam(PS_RUNTIME_BALANCING)));
+		}
+	}
+
private static int getRemainingCores() {
return InfrastructureAnalyzer.getLocalParallelism();
}
@@ -469,4 +496,15 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
		return scheme;
	}
+	private FederatedPSScheme getFederatedScheme() {
+		FederatedPSScheme federated_scheme = DEFAULT_FEDERATED_SCHEME;
+		if (getParameterMap().containsKey(PS_SCHEME)) {
+			try {
+				federated_scheme = FederatedPSScheme.valueOf(getParam(PS_SCHEME));
+			} catch (IllegalArgumentException e) {
+				throw new DMLRuntimeException(String.format("Paramserv function in federated mode: data partition scheme '%s' is not supported.", getParam(PS_SCHEME)));
+			}
+		}
+		return federated_scheme;
+	}
}
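As a worked example of the balancing modes above, take the imbalanced two-worker split used in the tests (1 row on one worker, 3 on the other) with batch size 1: the balance metrics are then minRows=1, avgRows=2, maxRows=3. A small sketch reproducing the calculation (the concrete values are assumed from the test setup below, not taken from a run):

public class NumBatchesExample {
	public static void main(String[] args) {
		double minRows = 1, avgRows = 2, maxRows = 3;
		int batchSize = 1;
		System.out.println((int) Math.ceil(minRows / (float) batchSize)); // RUN_MIN   -> 1 batch/epoch
		System.out.println((int) Math.ceil(avgRows / (float) batchSize)); // CYCLE_AVG -> 2 batches/epoch
		System.out.println((int) Math.ceil(maxRows / (float) batchSize)); // CYCLE_MAX -> 3 batches/epoch
	}
}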
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
index 477379c..4de60b4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
@@ -61,7 +61,7 @@ public final class MatrixIndexingFEDInstruction extends IndexingFEDInstruction {
		//modify federated ranges in place
		Map<FederatedRange, IndexRange> ixs = new HashMap<>();
-		for(FederatedRange range : fedMap.getFedMapping().keySet()) {
+		for(FederatedRange range : fedMap.getMap().keySet()) {
			long rs = range.getBeginDims()[0], re = range.getEndDims()[0],
				cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
			long rsn = (ixrange.rowStart >= rs) ? (ixrange.rowStart - rs) : 0;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
index da25122..a479be7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
@@ -104,7 +104,7 @@ public class VariableFEDInstruction extends FEDInstruction implements LineageTra
		MatrixObject out = ec.getMatrixObject(_in.getOutput());
		FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID());
		Map<FederatedRange, FederatedData> newMap = new HashMap<>();
-		for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getFedMapping().entrySet()) {
+		for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getMap().entrySet()) {
			FederatedData om = pair.getValue();
			FederatedData nf = new FederatedData(Types.DataType.MATRIX, om.getAddress(), om.getFilepath(), om.getVarID());
@@ -131,7 +131,7 @@ public class VariableFEDInstruction extends FEDInstruction implements LineageTra
		out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
		FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID());
		Map<FederatedRange, FederatedData> newMap = new HashMap<>();
-		for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getFedMapping().entrySet()) {
+		for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getMap().entrySet()) {
			FederatedData om = pair.getValue();
			FederatedData nf = new FederatedData(Types.DataType.FRAME, om.getAddress(), om.getFilepath(), om.getVarID());
diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
index 0527d23..5252c5e 100644
--- a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
+++ b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
@@ -107,7 +107,7 @@ public class ReaderWriterFederated {
			FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
			DataOutputStream out = fs.create(path, true);
			ObjectMapper mapper = new ObjectMapper();
-			FederatedDataAddress[] outObjects = parseMap(fedMap.getFedMapping());
+			FederatedDataAddress[] outObjects = parseMap(fedMap.getMap());
			try(BufferedWriter pw = new BufferedWriter(new OutputStreamWriter(out))) {
				mapper.writeValue(pw, outObjects);
			}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 4cea827..9161410 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -5326,21 +5326,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
		if( resultBlock!=null )
			resultBlock.recomputeNonZeros();
	}
-
+
	/**
-	 * D = ctable(seq,A,w)
-	 * this <- seq; thatMatrix <- A; thatScalar <- w; result <- D
-	 *
-	 * (i1,j1,v1) from input1 (this)
-	 * (i1,j1,v2) from input2 (that)
-	 * (w) from scalar_input3 (scalarThat2)
-	 *
	 * @param thatMatrix matrix value
	 * @param thatScalar scalar double
	 * @param resultBlock result matrix block
+	 * @param updateClen set to false if the result block already has the desired number of columns
	 * @return resultBlock
	 */
-	public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar, MatrixBlock resultBlock) {
+	public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar, MatrixBlock resultBlock, boolean updateClen) {
		MatrixBlock that = checkType(thatMatrix);
		CTable ctable = CTable.getCTableFnObject();
		double w = thatScalar;
@@ -5357,9 +5351,28 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
		//update meta data (initially unknown number of columns)
		//note: nnz maintained in ctable (via quickset)
-		resultBlock.clen = maxCol;
+		if(updateClen) {
+			resultBlock.clen = maxCol;
+		}
		return resultBlock;
	}
+
+	/**
+	 * D = ctable(seq,A,w)
+	 * this <- seq; thatMatrix <- A; thatScalar <- w; result <- D
+	 *
+	 * (i1,j1,v1) from input1 (this)
+	 * (i1,j1,v2) from input2 (that)
+	 * (w) from scalar_input3 (scalarThat2)
+	 *
+	 * @param thatMatrix matrix value
+	 * @param thatScalar scalar double
+	 * @param resultBlock result matrix block
+	 * @return resultBlock
+	 */
+	public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar, MatrixBlock resultBlock) {
+		return ctableSeqOperations(thatMatrix, thatScalar, resultBlock, true);
+	}
/**
* D = ctable(A,B,W)
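For intuition on the ctable(seq,A,w) contract documented above: result row i holds w at column A[i], so a length-N vector that is a permutation of 1..N yields an N x N permutation matrix, a case where the caller already knows the final column count and can pass updateClen=false (the stated motivation here is an assumption from context). A plain-Java sketch of the semantics (illustrative only, not the MatrixBlock implementation):

import java.util.Arrays;

public class CtableSeqSketch {
	// D = ctable(seq, A, w): D[i][A[i]-1] = w, all other cells zero.
	static double[][] ctableSeq(double[] A, double w, int numCols) {
		double[][] D = new double[A.length][numCols];
		for(int i = 0; i < A.length; i++)
			D[i][(int) A[i] - 1] = w;
		return D;
	}

	public static void main(String[] args) {
		double[] A = {2, 1, 3}; // a permutation of 1..3
		// prints [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]] -- a 3x3 permutation matrix
		System.out.println(Arrays.deepToString(ctableSeq(A, 1.0, 3)));
	}
}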
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 4f08e88..9af4fcb 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test;
+import static java.lang.Math.ceil;
import static java.lang.Thread.sleep;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@@ -27,10 +28,12 @@ import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
+import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Properties;
@@ -43,6 +46,7 @@ import org.apache.hadoop.util.GenericOptionsParser;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession.Builder;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FileFormat;
@@ -52,13 +56,17 @@ import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.ParseException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.FrameReader;
import org.apache.sysds.runtime.io.FrameReaderFactory;
@@ -67,6 +75,7 @@ import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
@@ -588,6 +597,71 @@ public abstract class AutomatedTestBase {
/**
* <p>
+	 * Takes a matrix (double[][]) and writes it in row slices to local files. It then creates a federated
+	 * MatrixObject that points the given ports at those local paths; this federated matrix is also written
+	 * to disk under the provided name. With federated workers running locally on the specified ports, the
+	 * federated matrix can then be used for testing; simply read it via input(name).
+	 * </p>
+	 *
+	 * @param name name of the matrix when writing to disk
+	 * @param matrix two dimensional matrix
+	 * @param numFederatedWorkers the number of federated workers
+	 * @param ports a list of ports, one per federated worker
+	 * @param ranges an array of arrays of length two, each holding the lower and upper row bound of a slice
+	 */
+	protected void rowFederateLocallyAndWriteInputMatrixWithMTD(String name,
+		double[][] matrix, int numFederatedWorkers, List<Integer> ports, double[][] ranges)
+	{
+ // check matrix non empty
+ if(matrix.length == 0 || matrix[0].length == 0)
+ return;
+
+ int nrows = matrix.length;
+ int ncol = matrix[0].length;
+
+		// create federated MatrixObject
+		MatrixObject federatedMatrixObject = new MatrixObject(ValueType.FP64,
+			Dag.getNextUniqueVarname(Types.DataType.MATRIX));
+		federatedMatrixObject.setMetaData(new MetaDataFormat(
+			new MatrixCharacteristics(nrows, ncol), Types.FileFormat.BINARY));
+
+		// write parts and generate FederationMap
+		HashMap<FederatedRange, FederatedData> fedHashMap = new HashMap<>();
+		for(int i = 0; i < numFederatedWorkers; i++) {
+			double lowerBound = ranges[i][0];
+			double upperBound = ranges[i][1];
+			double examplesForWorkerI = upperBound - lowerBound;
+			String path = name + "_" + (i + 1);
+
+			// write slice
+			writeInputMatrixWithMTD(path, Arrays.copyOfRange(matrix, (int)lowerBound, (int)upperBound),
+				false, new MatrixCharacteristics((long) examplesForWorkerI, ncol,
+				OptimizerUtils.DEFAULT_BLOCKSIZE, (long) examplesForWorkerI * ncol));
+
+			// generate fedmap entry
+			FederatedRange range = new FederatedRange(new long[]{(long) lowerBound, 0}, new long[]{(long) upperBound, ncol});
+			FederatedData data = new FederatedData(DataType.MATRIX, new InetSocketAddress(ports.get(i)), input(path));
+			fedHashMap.put(range, data);
+		}
+
+		federatedMatrixObject.setFedMapping(new FederationMap(FederationUtils.getNextFedDataID(), fedHashMap));
+		federatedMatrixObject.getFedMapping().setType(FederationMap.FType.ROW);
+
+		writeInputFederatedWithMTD(name, federatedMatrixObject, null);
+	}
+
+	protected double[][] generateBalancedFederatedRowRanges(int numFederatedWorkers, int dataSetSize) {
+		double[][] ranges = new double[numFederatedWorkers][2];
+		double examplesPerWorker = ceil((double) dataSetSize / (double) numFederatedWorkers);
+		for(int i = 0; i < numFederatedWorkers; i++) {
+			ranges[i][0] = examplesPerWorker * i;
+			ranges[i][1] = Math.min(examplesPerWorker * (i + 1), dataSetSize);
+		}
+		return ranges;
+	}
+
+ /**
+ * <p>
* Adds a matrix to the input path and writes it to a file.
* </p>
*
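As a worked example of generateBalancedFederatedRowRanges: for 2 workers and 5 rows, examplesPerWorker = ceil(5/2) = 3, giving row ranges [0,3) and [3,5), i.e. the last worker absorbs the remainder. A small sketch replicating the logic:

public class BalancedRangesExample {
	public static void main(String[] args) {
		int workers = 2, dataSetSize = 5;
		double perWorker = Math.ceil((double) dataSetSize / workers);
		for(int i = 0; i < workers; i++) {
			double lo = perWorker * i;
			double hi = Math.min(perWorker * (i + 1), dataSetSize);
			System.out.println("worker " + i + ": rows [" + lo + ", " + hi + ")");
		}
	}
}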
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index cc0af07..6a52fc4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -27,7 +27,6 @@ import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -44,48 +43,66 @@ public class FederatedParamservTest extends AutomatedTestBase {
	private final static String TEST_DIR = "functions/federated/paramserv/";
	private final static String TEST_NAME = "FederatedParamservTest";
	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedParamservTest.class.getSimpleName() + "/";
- private final static int _blocksize = 1024;
private final String _networkType;
private final int _numFederatedWorkers;
- private final int _examplesPerWorker;
+ private final int _dataSetSize;
private final int _epochs;
private final int _batch_size;
private final double _eta;
private final String _utype;
private final String _freq;
+ private final String _scheme;
+ private final String _runtime_balancing;
+ private final String _data_distribution;
// parameters
@Parameterized.Parameters
public static Collection<Object[]> parameters() {
return Arrays.asList(new Object[][] {
-			// Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update
-			// type, update frequency
-			{"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
-			{"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
-			{"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
-			{"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
-			{"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH"},
-			// {"TwoNN", 5, 1000, 200, 2, 0.01, "ASP", "BATCH"},
-			// {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH"},
-			// {"TwoNN", 5, 1000, 200, 2, 0.01, "ASP", "EPOCH"},
-			// {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH"},
-			// {"CNN", 5, 1000, 200, 2, 0.01, "ASP", "BATCH"},
-			{"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH"},
-			// {"CNN", 5, 1000, 200, 2, 0.01, "ASP", "EPOCH"}
+			// Network type, number of federated workers, data set size, batch size, epochs, learning rate,
+			// update type, update frequency, scheme, runtime balancing, data distribution
+
+			// basic functionality
+			{"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "IMBALANCED"},
+			{"CNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "SHUFFLE", "NONE", "IMBALANCED"},
+			{"CNN", 2, 4, 1, 4, 0.01, "ASP", "BATCH", "REPLICATE_TO_MAX", "RUN_MIN", "IMBALANCED"},
+			{"TwoNN", 2, 4, 1, 4, 0.01, "ASP", "EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX", "IMBALANCED"},
+			{"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE", "BALANCED"},
+
+			/*
+			// runtime balancing
+			{"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN", "IMBALANCED"},
+			{"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN", "IMBALANCED"},
+			{"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "IMBALANCED"},
+			{"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "IMBALANCED"},
+			{"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX", "IMBALANCED"},
+			{"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX", "IMBALANCED"},
+
+			// data partitioning
+			{"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE_AVG", "IMBALANCED"},
+			{"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE_TO_MAX", "NONE", "IMBALANCED"},
+			{"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE", "IMBALANCED"},
+			{"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE_TO_AVG", "NONE", "IMBALANCED"},
+
+			// balanced tests
+			{"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE", "BALANCED"}
+			*/
});
}
-	public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size,
-		int epochs, double eta, String utype, String freq) {
+	public FederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size,
+		int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String data_distribution) {
		_networkType = networkType;
		_numFederatedWorkers = numFederatedWorkers;
-		_examplesPerWorker = examplesPerWorker;
+		_dataSetSize = dataSetSize;
		_batch_size = batch_size;
		_epochs = epochs;
		_eta = eta;
		_utype = utype;
		_freq = freq;
+		_scheme = scheme;
+		_runtime_balancing = runtime_balancing;
+		_data_distribution = data_distribution;
	}
@Override
@@ -111,68 +128,74 @@ public class FederatedParamservTest extends AutomatedTestBase {
setOutputBuffering(true);
int C = 1, Hin = 28, Win = 28;
- int numFeatures = C * Hin * Win;
int numLabels = 10;
ExecMode platformOld = setExecMode(mode);
try {
-
-			// dml name
-			fullDMLScriptName = HOME + TEST_NAME + ".dml";
-			// generate program args
-			List<String> programArgsList = new ArrayList<>(Arrays.asList("-stats",
-				"-nvargs",
-				"examples_per_worker=" + _examplesPerWorker,
-				"num_features=" + numFeatures,
-				"num_labels=" + numLabels,
-				"epochs=" + _epochs,
-				"batch_size=" + _batch_size,
-				"eta=" + _eta,
-				"utype=" + _utype,
-				"freq=" + _freq,
-				"network_type=" + _networkType,
-				"channels=" + C,
-				"hin=" + Hin,
-				"win=" + Win));
-
-			// for each worker
+			// start threads
			List<Integer> ports = new ArrayList<>();
			List<Thread> threads = new ArrayList<>();
			for(int i = 0; i < _numFederatedWorkers; i++) {
-				// write row partitioned features to disk
-				writeInputMatrixWithMTD("X" + i,
-					generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win),
-					false,
-					new MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize,
-						_examplesPerWorker * numFeatures));
-				// write row partitioned labels to disk
-				writeInputMatrixWithMTD("y" + i,
-					generateDummyMNISTLabels(_examplesPerWorker, numLabels),
-					false,
-					new MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize,
-						_examplesPerWorker * numLabels));
-
-				// start worker
				ports.add(getRandomAvailablePort());
				threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S));
+			}
-				// add worker to program args
-				programArgsList.add("X" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("X" + i)));
-				programArgsList.add("y" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("y" + i)));
+			// generate test data
+			double[][] features = generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win);
+			double[][] labels = generateDummyMNISTLabels(_dataSetSize, numLabels);
+			String featuresName = "";
+			String labelsName = "";
+
+			// federate test data balanced or imbalanced
+			if(_data_distribution.equals("IMBALANCED")) {
+				featuresName = "X_IMBALANCED_" + _numFederatedWorkers;
+				labelsName = "y_IMBALANCED_" + _numFederatedWorkers;
+				double[][] ranges = {{0,1}, {1,4}};
+				rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges);
+				rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges);
+			}
+			else {
+				featuresName = "X_BALANCED_" + _numFederatedWorkers;
+				labelsName = "y_BALANCED_" + _numFederatedWorkers;
+				double[][] ranges = generateBalancedFederatedRowRanges(_numFederatedWorkers, features.length);
+				rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges);
+				rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges);
			}
+
			try {
-				Thread.sleep(1000);
+				// wait for all workers to be set up
+				Thread.sleep(FED_WORKER_WAIT);
			}
			catch(InterruptedException e) {
				e.printStackTrace();
			}
-
+
+			// dml name
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			// generate program args
+			List<String> programArgsList = new ArrayList<>(Arrays.asList("-stats",
+				"-nvargs",
+				"features=" + input(featuresName),
+				"labels=" + input(labelsName),
+				"epochs=" + _epochs,
+				"batch_size=" + _batch_size,
+				"eta=" + _eta,
+				"utype=" + _utype,
+				"freq=" + _freq,
+				"scheme=" + _scheme,
+				"runtime_balancing=" + _runtime_balancing,
+				"network_type=" + _networkType,
+				"channels=" + C,
+				"hin=" + Hin,
+				"win=" + Win,
+				"seed=" + 25));
+
			programArgs = programArgsList.toArray(new String[0]);
			LOG.debug(runTest(null));
			Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
-			// cleanup
+			// shut down threads
			for(int i = 0; i < _numFederatedWorkers; i++) {
				TestUtils.shutdownThreads(threads.get(i));
			}
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml
index d622c13..69c7e76 100644
--- a/src/test/scripts/functions/federated/paramserv/CNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -67,8 +67,10 @@ source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
*/
train = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val,
-    int C, int Hin, int Win, int epochs, int batch_size, double learning_rate)
- return (list[unknown] model_trained) {
+ int epochs, int batch_size, double eta,
+ int C, int Hin, int Win,
+ int seed = -1)
+ return (list[unknown] model) {
N = nrow(X)
K = ncol(y)
@@ -84,74 +86,45 @@ train = function(matrix[double] X, matrix[double] y,
N3 = 512 # num nodes in affine3
  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
-  [W1, b1] = conv2d::init(F1, C, Hf, Wf, -1) # inputs: (N, C*Hin*Win)
-  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, -1) # inputs: (N, F1*(Hin/2)*(Win/2))
-  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, -1) # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
-  [W4, b4] = affine::init(N3, K, -1) # inputs: (N, N3)
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed) # inputs: (N, C*Hin*Win)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed) # inputs: (N, F1*(Hin/2)*(Win/2))
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed) # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  lseed = ifelse(seed==-1, -1, seed + 3);
+  [W4, b4] = affine::init(N3, K, seed = lseed) # inputs: (N, N3)
  W4 = W4 / sqrt(2) # different initialization, since being fed into softmax, instead of relu
  # Initialize SGD w/ Nesterov momentum optimizer
-  learning_rate = learning_rate # learning rate
  mu = 0.9 # momentum
  decay = 0.95 # learning rate decay constant
  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+
  # Regularization
  lambda = 5e-04
  # Create the hyper parameter list
-  hyperparams = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+  hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
  # Calculate iterations
  iters = ceil(N / batch_size)
-  print_interval = floor(iters / 25)
-
-  print("[+] Starting optimization")
-  print("[+] Learning rate: " + learning_rate)
-  print("[+] Batch size: " + batch_size)
-  print("[+] Iterations per epoch: " + iters + "\n")
  for (e in 1:epochs) {
-    print("[+] Starting epoch: " + e)
-    print("|")
    for(i in 1:iters) {
-      # Create the model list
-      model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
-
      # Get next batch
      beg = ((i-1) * batch_size) %% N + 1
      end = min(N, beg + batch_size - 1)
      X_batch = X[beg:end,]
      y_batch = y[beg:end,]
-      gradients_list = gradients(model_list, hyperparams, X_batch, y_batch)
-      model_updated = aggregation(model_list, hyperparams, gradients_list)
-
-      W1 = as.matrix(model_updated[1])
-      W2 = as.matrix(model_updated[2])
-      W3 = as.matrix(model_updated[3])
-      W4 = as.matrix(model_updated[4])
-      b1 = as.matrix(model_updated[5])
-      b2 = as.matrix(model_updated[6])
-      b3 = as.matrix(model_updated[7])
-      b4 = as.matrix(model_updated[8])
-      vW1 = as.matrix(model_updated[9])
-      vW2 = as.matrix(model_updated[10])
-      vW3 = as.matrix(model_updated[11])
-      vW4 = as.matrix(model_updated[12])
-      vb1 = as.matrix(model_updated[13])
-      vb2 = as.matrix(model_updated[14])
-      vb3 = as.matrix(model_updated[15])
-      vb4 = as.matrix(model_updated[16])
-      if((i %% print_interval) == 0) {
-        print("█")
-      }
+      gradients_list = gradients(model, hyperparams, X_batch, y_batch)
+      model = aggregation(model, hyperparams, gradients_list)
    }
-    print("|")
  }
-
-  model_trained = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
}
/*
@@ -190,9 +163,10 @@ train = function(matrix[double] X, matrix[double] y,
*/
train_paramserv = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val,
- int C, int Hin, int Win, int epochs, int workers,
-    string utype, string freq, int batch_size, string scheme, string mode, double learning_rate)
-  return (list[unknown] model_trained) {
+    int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing,
+ double eta, int C, int Hin, int Win,
+ int seed = -1)
+ return (list[unknown] model) {
N = nrow(X)
K = ncol(y)
@@ -208,14 +182,17 @@ train_paramserv = function(matrix[double] X, matrix[double] y,
  N3 = 512 # num nodes in affine3
  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
-  [W1, b1] = conv2d::init(F1, C, Hf, Wf, -1) # inputs: (N, C*Hin*Win)
-  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, -1) # inputs: (N, F1*(Hin/2)*(Win/2))
-  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, -1) # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
-  [W4, b4] = affine::init(N3, K, -1) # inputs: (N, N3)
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed) # inputs: (N, C*Hin*Win)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed) # inputs: (N, F1*(Hin/2)*(Win/2))
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed) # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  lseed = ifelse(seed==-1, -1, seed + 3);
+  [W4, b4] = affine::init(N3, K, seed = lseed) # inputs: (N, N3)
  W4 = W4 / sqrt(2) # different initialization, since being fed into softmax, instead of relu
  # Initialize SGD w/ Nesterov momentum optimizer
-  learning_rate = learning_rate # learning rate
+  learning_rate = eta # learning rate
  mu = 0.9 # momentum
  decay = 0.95 # learning rate decay constant
  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
@@ -225,12 +202,16 @@ train_paramserv = function(matrix[double] X, matrix[double] y,
  # Regularization
  lambda = 5e-04
  # Create the model list
-  model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+  model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
  # Create the hyper parameter list
-  params = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+  hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
  # Use paramserv function
-  model_trained = paramserv(model=model_list, features=X, labels=y, val_features=X_val, val_labels=y_val, upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation", mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
+  model = paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val,
+    upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients",
+    agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
+    k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams)
}
/*
diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
index 16c72c4..10d2cc7 100644
--- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -23,35 +23,13 @@
source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN
source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
# create federated input matrices
-features = federated(addresses=list($X0, $X1),
-  ranges=list(list(0, 0), list($examples_per_worker, $num_features),
-  list($examples_per_worker, 0), list($examples_per_worker * 2, $num_features)))
+features = read($features)
+labels = read($labels)
-labels = federated(addresses=list($y0, $y1),
-  ranges=list(list(0, 0), list($examples_per_worker, $num_labels),
-  list($examples_per_worker, 0), list($examples_per_worker * 2, $num_labels)))
-
-epochs = $epochs
-batch_size = $batch_size
-learning_rate = $eta
-utype = $utype
-freq = $freq
-network_type = $network_type
-
-# currently ignored parameters
-workers = 1
-scheme = "DISJOINT_CONTIGUOUS"
-paramserv_mode = "LOCAL"
-
-# config for the cnn
-channels = $channels
-hin = $hin
-win = $win
-
-if(network_type == "TwoNN") {
-  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate)
+if($network_type == "TwoNN") {
+  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $seed)
 }
 else {
-  model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), channels, hin, win, epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate)
+  model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $channels, $hin, $win, $seed)
 }
print(toString(model))
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
index 31e889a..9bd49d8 100644
--- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -57,8 +57,9 @@ source("nn/optim/sgd.dml") as sgd
*/
train = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val,
- int epochs, int batch_size, double learning_rate)
- return (list[unknown] model_trained) {
+ int epochs, int batch_size, double eta,
+ int seed = -1)
+ return (list[unknown] model) {
N = nrow(X) # num examples
D = ncol(X) # num features
@@ -66,53 +67,31 @@ train = function(matrix[double] X, matrix[double] y,
# Create the network:
## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
- [W1, b1] = affine::init(D, 200, -1)
- [W2, b2] = affine::init(200, 200, -1)
- [W3, b3] = affine::init(200, K, -1)
+ [W1, b1] = affine::init(D, 200, seed = seed)
+ lseed = ifelse(seed==-1, -1, seed + 1);
+ [W2, b2] = affine::init(200, 200, seed = lseed)
+ lseed = ifelse(seed==-1, -1, seed + 2);
+ [W3, b3] = affine::init(200, K, seed = lseed)
  W3 = W3 / sqrt(2) # different initialization, since being fed into softmax, instead of relu
+ model = list(W1, W2, W3, b1, b2, b3)
# Create the hyper parameter list
- hyperparams = list(learning_rate=learning_rate)
+ hyperparams = list(learning_rate=eta)
# Calculate iterations
iters = ceil(N / batch_size)
- print_interval = floor(iters / 25)
-
- print("[+] Starting optimization")
- print("[+] Learning rate: " + learning_rate)
- print("[+] Batch size: " + batch_size)
- print("[+] Iterations per epoch: " + iters + "\n")
for (e in 1:epochs) {
- print("[+] Starting epoch: " + e)
- print("|")
for(i in 1:iters) {
- # Create the model list
- model_list = list(W1, W2, W3, b1, b2, b3)
-
# Get next batch
beg = ((i-1) * batch_size) %% N + 1
end = min(N, beg + batch_size - 1)
X_batch = X[beg:end,]
y_batch = y[beg:end,]
- gradients_list = gradients(model_list, hyperparams, X_batch, y_batch)
- model_updated = aggregation(model_list, hyperparams, gradients_list)
-
- W1 = as.matrix(model_updated[1])
- W2 = as.matrix(model_updated[2])
- W3 = as.matrix(model_updated[3])
- b1 = as.matrix(model_updated[4])
- b2 = as.matrix(model_updated[5])
- b3 = as.matrix(model_updated[6])
-
- if((i %% print_interval) == 0) {
- print("█")
- }
+ gradients_list = gradients(model, hyperparams, X_batch, y_batch)
+ model = aggregation(model, hyperparams, gradients_list)
}
- print("|")
}
-
- model_trained = list(W1, W2, W3, b1, b2, b3)
}
/*
@@ -146,9 +125,9 @@ train = function(matrix[double] X, matrix[double] y,
*/
train_paramserv = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val,
-    int epochs, int workers,
-    string utype, string freq, int batch_size, string scheme, string mode, double learning_rate)
-  return (list[unknown] model_trained) {
+    int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing,
+    double eta, int seed = -1)
+  return (list[unknown] model) {
N = nrow(X) # num examples
D = ncol(X) # num features
@@ -156,16 +135,27 @@ train_paramserv = function(matrix[double] X, matrix[double] y,
  # Create the network:
  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
-  [W1, b1] = affine::init(D, 200, -1)
-  [W2, b2] = affine::init(200, 200, -1)
-  [W3, b3] = affine::init(200, K, -1)
+  [W1, b1] = affine::init(D, 200, seed = seed)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = affine::init(200, 200, seed = lseed)
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(200, K, seed = lseed)
+  # W3 = W3 / sqrt(2) # different initialization, since being fed into softmax, instead of relu
+
+  # [W1, b1] = affine::init(D, 200)
+  # [W2, b2] = affine::init(200, 200)
+  # [W3, b3] = affine::init(200, K)
  # Create the model list
-  model_list = list(W1, W2, W3, b1, b2, b3)
+  model = list(W1, W2, W3, b1, b2, b3)
  # Create the hyper parameter list
-  params = list(learning_rate=learning_rate)
+  hyperparams = list(learning_rate=eta)
  # Use paramserv function
-  model_trained = paramserv(model=model_list, features=X, labels=y, val_features=X_val, val_labels=y_val, upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation", mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
+  model = paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val,
+    upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
+    agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
+    k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams)
}
/*