This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 5dec562  [SYSTEMDS-2550] Federated paramserv balancing and data partitioning
5dec562 is described below

commit 5dec5627398a7eb58facb7be2be0973a2428b345
Author: Tobias Rieger <[email protected]>
AuthorDate: Sun Dec 20 19:44:18 2020 +0100

    [SYSTEMDS-2550] Federated paramserv balancing and data partitioning
    
    Closes #1131.
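    
    The user-facing surface of this change: paramserv() accepts a new string
    parameter 'runtime_balancing' (NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX,
    SCALE_BATCH, SCALE_BATCH_AND_WEIGH), and the federated partitioning
    schemes are extended by REPLICATE_TO_MAX, SUBSAMPLE_TO_MIN, and
    BALANCE_TO_AVG. Below is a minimal Java sketch of how such a string could
    map onto the new PSRuntimeBalancing enum; the parse helper and its NONE
    default are illustrative assumptions, not the committed parsing logic:
    
        import org.apache.sysds.parser.Statement.PSRuntimeBalancing;
        
        public class RuntimeBalancingSketch {
            // Hypothetical helper: map a user-supplied string to the enum.
            static PSRuntimeBalancing parse(String value) {
                return (value == null || value.isEmpty())
                    ? PSRuntimeBalancing.NONE
                    : PSRuntimeBalancing.valueOf(value.toUpperCase());
            }
            public static void main(String[] args) {
                System.out.println(parse("cycle_avg")); // prints CYCLE_AVG
            }
        }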
---
 .../ParameterizedBuiltinFunctionExpression.java    |   7 +-
 .../java/org/apache/sysds/parser/Statement.java    |   8 +-
 .../sysds/runtime/controlprogram/ProgramBlock.java |   2 +-
 .../controlprogram/federated/FederationMap.java    |   2 +-
 .../paramserv/FederatedPSControlThread.java        | 342 +++++++++------------
 .../controlprogram/paramserv/ParamservUtils.java   |  33 +-
 .../paramserv/dp/BalanceToAvgFederatedScheme.java  | 106 +++++++
 .../paramserv/dp/DataPartitionFederatedScheme.java |  98 +++++-
 .../paramserv/dp/FederatedDataPartitioner.java     |   9 +
 .../dp/KeepDataOnWorkerFederatedScheme.java        |   2 +-
 .../dp/ReplicateToMaxFederatedScheme.java          | 103 +++++++
 .../paramserv/dp/ShuffleFederatedScheme.java       |  56 +++-
 .../dp/SubsampleToMinFederatedScheme.java          | 102 ++++++
 .../cp/ParamservBuiltinCPInstruction.java          |  56 +++-
 .../fed/MatrixIndexingFEDInstruction.java          |   2 +-
 .../instructions/fed/VariableFEDInstruction.java   |   4 +-
 .../sysds/runtime/io/ReaderWriterFederated.java    |   2 +-
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  33 +-
 .../org/apache/sysds/test/AutomatedTestBase.java   |  74 +++++
 .../paramserv/FederatedParamservTest.java          | 145 +++++----
 .../scripts/functions/federated/paramserv/CNN.dml  |  91 +++---
 .../federated/paramserv/FederatedParamservTest.dml |  32 +-
 .../functions/federated/paramserv/TwoNN.dml        |  74 ++---
 23 files changed, 965 insertions(+), 418 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 2e33676..5171f21 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -286,7 +286,11 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
                        raiseValidateError("Should provide more arguments for function " + fname, false, LanguageErrorCodes.INVALID_PARAMETERS);
                }
                //check for invalid parameters
-               Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS, Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
+               Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS,
+                       Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
+                       Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
+                       Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_RUNTIME_BALANCING,
+                       Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
                checkInvalidParameters(getOpCode(), getVarParams(), valid);
 
                // check existence and correctness of parameters
@@ -304,6 +308,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
                checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT64, conditional);
                checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT64, conditional);
                checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
+               checkStringParam(true, fname, Statement.PS_RUNTIME_BALANCING, conditional);
                checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional);
                checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional);
 
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java
index b61b0d6..6767d85 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -87,6 +87,10 @@ public abstract class Statement implements ParseInfo
        public enum PSFrequency {
                BATCH, EPOCH
        }
+       public static final String PS_RUNTIME_BALANCING = "runtime_balancing";
+       public enum PSRuntimeBalancing {
+               NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH, SCALE_BATCH_AND_WEIGH
+       }
        public static final String PS_EPOCHS = "epochs";
        public static final String PS_BATCH_SIZE = "batchsize";
        public static final String PS_PARALLELISM = "k";
@@ -95,7 +99,7 @@ public abstract class Statement implements ParseInfo
                DISJOINT_CONTIGUOUS, DISJOINT_ROUND_ROBIN, DISJOINT_RANDOM, OVERLAP_RESHUFFLE
        }
        public enum FederatedPSScheme {
-               KEEP_DATA_ON_WORKER, SHUFFLE
+               KEEP_DATA_ON_WORKER, SHUFFLE, REPLICATE_TO_MAX, SUBSAMPLE_TO_MIN, BALANCE_TO_AVG
        }
        public static final String PS_HYPER_PARAMS = "hyperparams";
        public static final String PS_CHECKPOINTING = "checkpointing";
@@ -107,7 +111,7 @@ public abstract class Statement implements ParseInfo
        // prefixed with code: "1701-NCC-" to not overwrite anything
        public static final String PS_FED_BATCH_SIZE = "1701-NCC-batch_size";
        public static final String PS_FED_DATA_SIZE = "1701-NCC-data_size";
-       public static final String PS_FED_NUM_BATCHES = "1701-NCC-num_batches";
+       public static final String PS_FED_POSS_BATCHES_LOCAL = "1701-NCC-poss_batches_local";
        public static final String PS_FED_NAMESPACE = "1701-NCC-namespace";
        public static final String PS_FED_GRADIENTS_FNAME = "1701-NCC-gradients_fname";
        public static final String PS_FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname";
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
index 263ecf4..a555a5d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
@@ -376,7 +376,7 @@ public abstract class ProgramBlock implements ParseInfo
                        
                        CacheableData<?> mo = (CacheableData<?>)dat;
                        if( mo.isFederated() ) {
-                               if( mo.getFedMapping().getFedMapping().isEmpty() )
+                               if( mo.getFedMapping().getMap().isEmpty() )
                                        throw new DMLRuntimeException("Invalid empty FederationMap for: "+mo);
                        }
                }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 9590510..482ade7 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -114,7 +114,7 @@ public class FederationMap {
                return _fedMap.keySet().toArray(new FederatedRange[0]);
        }
 
-       public Map<FederatedRange, FederatedData> getFedMapping() {
+       public Map<FederatedRange, FederatedData> getMap() {
                return _fedMap;
        }
 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 80418ac..393b131 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -19,6 +19,9 @@
 
 package org.apache.sysds.runtime.controlprogram.paramserv;
 
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -29,6 +32,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
@@ -53,17 +57,24 @@ import static org.apache.sysds.runtime.util.ProgramConverter.*;
 
 public class FederatedPSControlThread extends PSWorker implements Callable<Void> {
        private static final long serialVersionUID = 6846648059569648791L;
+       protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());
+       
+       Statement.PSRuntimeBalancing _runtimeBalancing;
        FederatedData _featuresData;
        FederatedData _labelsData;
-       final long _batchCounterVarID;
+       final long _localStartBatchNumVarID;
        final long _modelVarID;
-       int _totalNumBatches;
+       int _numBatchesPerGlobalEpoch;
+       int _possibleBatchesPerLocalEpoch;
+       boolean _cycleStartAt0 = false;
 
-       public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) {
+       public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) {
                super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
-               
+
+               _numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch;
+               _runtimeBalancing = runtimeBalancing;
                // generate the IDs for model and batch counter. These get overwritten on the federated worker each time
-               _batchCounterVarID = FederationUtils.getNextFedDataID();
+               _localStartBatchNumVarID = FederationUtils.getNextFedDataID();
                _modelVarID = FederationUtils.getNextFedDataID();
        }
 
@@ -72,18 +83,22 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
         */
        public void setup() {
                // prepare features and labels
-               _features.getFedMapping().forEachParallel((range, data) -> {
-                       _featuresData = data;
-                       return null;
-               });
-               _labels.getFedMapping().forEachParallel((range, data) -> {
-                       _labelsData = data;
-                       return null;
-               });
+               _featuresData = (FederatedData) _features.getFedMapping().getMap().values().toArray()[0];
+               _labelsData = (FederatedData) _labels.getFedMapping().getMap().values().toArray()[0];
 
                // calculate number of batches and get data size
                long dataSize = _features.getNumRows();
-               _totalNumBatches = (int) Math.ceil((double) dataSize / _batchSize);
+               _possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);
+               if(!(_runtimeBalancing == Statement.PSRuntimeBalancing.RUN_MIN
+                       || _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_AVG
+                       || _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_MAX)) {
+                       _numBatchesPerGlobalEpoch = _possibleBatchesPerLocalEpoch;
+               }
+
+               if(_runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH
+                       || _runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH_AND_WEIGH) {
+                       throw new NotImplementedException();
+               }
 
                // serialize program
                // create program blocks for the instruction filtering
@@ -112,17 +127,17 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                programSerialized = sb.toString();
 
                // write program and meta data to worker
-               Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-                               _featuresData.getVarID(),
-                               new setupFederatedWorker(_batchSize,
+               Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
+                               new SetupFederatedWorker(_batchSize,
                                                dataSize,
-                                               _totalNumBatches,
+                                               _possibleBatchesPerLocalEpoch,
                                                programSerialized,
                                                _inst.getNamespace(),
                                                _inst.getFunctionName(),
                                                _ps.getAggInst().getFunctionName(),
                                                _ec.getListObject("hyperparams"),
-                                               _batchCounterVarID,
+                                               _localStartBatchNumVarID,
                                                _modelVarID
                                )
                ));
@@ -140,24 +155,27 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
        /**
         * Setup UDF executed on the federated worker
         */
-       private static class setupFederatedWorker extends FederatedUDF {
+       private static class SetupFederatedWorker extends FederatedUDF {
                private static final long serialVersionUID = -3148991224792675607L;
-               long _batchSize;
-               long _dataSize;
-               long _numBatches;
-               String _programString;
-               String _namespace;
-               String _gradientsFunctionName;
-               String _aggregationFunctionName;
-               ListObject _hyperParams;
-               long _batchCounterVarID;
-               long _modelVarID;
-
-               protected setupFederatedWorker(long batchSize, long dataSize, long numBatches, String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName, ListObject hyperParams, long batchCounterVarID, long modelVarID) {
+               private final long _batchSize;
+               private final long _dataSize;
+               private final int _possibleBatchesPerLocalEpoch;
+               private final String _programString;
+               private final String _namespace;
+               private final String _gradientsFunctionName;
+               private final String _aggregationFunctionName;
+               private final ListObject _hyperParams;
+               private final long _batchCounterVarID;
+               private final long _modelVarID;
+
+               protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch,
+                       String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName,
+                       ListObject hyperParams, long batchCounterVarID, long modelVarID)
+               {
                        super(new long[]{});
                        _batchSize = batchSize;
                        _dataSize = dataSize;
-                       _numBatches = numBatches;
+                       _possibleBatchesPerLocalEpoch = possibleBatchesPerLocalEpoch;
                        _programString = programString;
                        _namespace = namespace;
                        _gradientsFunctionName = gradientsFunctionName;
@@ -175,7 +193,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                        // set variables to ec
                        ec.setVariable(Statement.PS_FED_BATCH_SIZE, new IntObject(_batchSize));
                        ec.setVariable(Statement.PS_FED_DATA_SIZE, new IntObject(_dataSize));
-                       ec.setVariable(Statement.PS_FED_NUM_BATCHES, new IntObject(_numBatches));
+                       ec.setVariable(Statement.PS_FED_POSS_BATCHES_LOCAL, new IntObject(_possibleBatchesPerLocalEpoch));
                        ec.setVariable(Statement.PS_FED_NAMESPACE, new StringObject(_namespace));
                        ec.setVariable(Statement.PS_FED_GRADIENTS_FNAME, new StringObject(_gradientsFunctionName));
                        ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName));
@@ -192,9 +210,9 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
         */
        public void teardown() {
                // write program and meta data to worker
-               Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-                               _featuresData.getVarID(),
-                               new teardownFederatedWorker()
+               Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
+                       new TeardownFederatedWorker()
                ));
 
                try {
@@ -210,10 +228,10 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
        /**
         * Teardown UDF executed on the federated worker
         */
-       private static class teardownFederatedWorker extends FederatedUDF {
+       private static class TeardownFederatedWorker extends FederatedUDF {
                private static final long serialVersionUID = -153650281873318969L;
 
-               protected teardownFederatedWorker() {
+               protected TeardownFederatedWorker() {
                        super(new long[]{});
                }
 
@@ -222,14 +240,13 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                        // remove variables from ec
                        ec.removeVariable(Statement.PS_FED_BATCH_SIZE);
                        ec.removeVariable(Statement.PS_FED_DATA_SIZE);
-                       ec.removeVariable(Statement.PS_FED_NUM_BATCHES);
+                       ec.removeVariable(Statement.PS_FED_POSS_BATCHES_LOCAL);
                        ec.removeVariable(Statement.PS_FED_NAMESPACE);
                        ec.removeVariable(Statement.PS_FED_GRADIENTS_FNAME);
                        ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME);
                        ec.removeVariable(Statement.PS_FED_BATCHCOUNTER_VARID);
                        ec.removeVariable(Statement.PS_FED_MODEL_VARID);
                        ParamservUtils.cleanupListObject(ec, Statement.PS_HYPER_PARAMS);
-                       ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
                        
                        return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
                }
@@ -246,10 +263,13 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                try {
                        switch (_freq) {
                                case BATCH:
-                                       computeBatch(_totalNumBatches);
+                                       computeWithBatchUpdates();
                                        break;
+                               /*case NBATCH:
+                                       computeWithNBatchUpdates();
+                                       break; */
                                case EPOCH:
-                                       computeEpoch();
+                                       computeWithEpochUpdates();
                                        break;
                                default:
                                        throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq));
@@ -271,154 +291,82 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                _ps.push(_workerID, gradients);
        }
 
+       static protected int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) {
+               return currentLocalBatchNumber % possibleBatchesPerLocalEpoch;
+       }
+
        /**
-        * Computes all epochs and synchronizes after each batch
-        *
-        * @param numBatches the number of batches per epoch
+        * Computes all epochs and updates after each batch
         */
-       protected void computeBatch(int numBatches) {
+       protected void computeWithBatchUpdates() {
                for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
-                       for (int batchCounter = 0; batchCounter < numBatches; batchCounter++) {
+                       int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+
+                       for (int batchCounter = 0; batchCounter < _numBatchesPerGlobalEpoch; batchCounter++) {
+                               int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
                                ListObject model = pullModel();
-                               ListObject gradients = computeBatchGradients(model, batchCounter);
+                               ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum);
                                pushGradients(gradients);
                                ParamservUtils.cleanupListObject(model);
                                ParamservUtils.cleanupListObject(gradients);
                        }
-                       System.out.println("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+                       if( LOG.isInfoEnabled() )
+                               LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
                }
        }
 
        /**
-        * Computes a single specified batch on the federated worker
-        *
-        * @param model the current model from the parameter server
-        * @param batchCounter the current batch number needed for slicing the features and labels
-        * @return the gradient vector
+        * Computes all epochs and updates after N batches
         */
-       protected ListObject computeBatchGradients(ListObject model, int batchCounter) {
-               // put batch counter on federated worker
-               Future<FederatedResponse> putBatchCounterResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _batchCounterVarID, new IntObject(batchCounter)));
-
-               // put current model on federated worker
-               Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model));
-
-               try {
-                       if(!putParamsResponse.get().isSuccessful() || !putBatchCounterResponse.get().isSuccessful())
-                               throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful");
-               }
-               catch(Exception e) {
-                       throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute put" + e.getMessage());
-               }
-
-               // create and execute the udf on the remote worker
-               Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-                               _featuresData.getVarID(),
-                               new federatedComputeBatchGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _batchCounterVarID, _modelVarID})
-               ));
-
-               try {
-                       Object[] responseData = udfResponse.get().getData();
-                       return (ListObject) responseData[0];
-               }
-               catch(Exception e) {
-                       throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF" + e.getMessage());
-               }
+       protected void computeWithNBatchUpdates() {
+               throw new NotImplementedException();
        }
 
        /**
-        * This is the code that will be executed on the federated Worker when computing a single batch
+        * Computes all epochs and updates after each epoch
         */
-       private static class federatedComputeBatchGradients extends FederatedUDF {
-               private static final long serialVersionUID = -3652112393963053475L;
-
-               protected federatedComputeBatchGradients(long[] inIDs) {
-                       super(inIDs);
-               }
-
-               @Override
-               public FederatedResponse execute(ExecutionContext ec, Data... data) {
-                       // read in data by varid
-                       MatrixObject features = (MatrixObject) data[0];
-                       MatrixObject labels = (MatrixObject) data[1];
-                       long batchCounter = ((IntObject) data[2]).getLongValue();
-                       ListObject model = (ListObject) data[3];
-
-                       // get data from execution context
-                       long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
-                       long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue();
-                       String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
-                       String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
-
-                       // slice batch from feature and label matrix
-                       long begin = batchCounter * batchSize + 1;
-                       long end = Math.min((batchCounter + 1) * batchSize, dataSize);
-                       MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end);
-                       MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end);
-
-                       // prepare execution context
-                       ec.setVariable(Statement.PS_MODEL, model);
-                       ec.setVariable(Statement.PS_FEATURES, bFeatures);
-                       ec.setVariable(Statement.PS_LABELS, bLabels);
-
-                       // recreate gradient instruction and output
-                       FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunctionName, false);
-                       ArrayList<DataIdentifier> inputs = func.getInputParams();
-                       ArrayList<DataIdentifier> outputs = func.getOutputParams();
-                       CPOperand[] boundInputs = inputs.stream()
-                                       .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
-                                       .toArray(CPOperand[]::new);
-                       ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
-                                       .collect(Collectors.toCollection(ArrayList::new));
-                       Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
-                                       func.getInputParamNames(), outputNames, "gradient function");
-                       DataIdentifier gradientsOutput = outputs.get(0);
-
-                       // calculate and gradients
-                       gradientsInstruction.processInstruction(ec);
-                       ListObject gradients = ec.getListObject(gradientsOutput.getName());
-
-                       // clean up sliced batch
-                       ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
-                       ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
-                       ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
-
-                       // model clean up - doing this twice is not an issue
-                       ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
-                       ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
-
-                       // return
-                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, gradients);
-               }
-       }
-
-       /**
-        * Computes all epochs and synchronizes after each one
-        */
-       protected void computeEpoch() {
+       protected void computeWithEpochUpdates() {
                for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
+                       int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+
                        // Pull the global parameters from ps
                        ListObject model = pullModel();
-                       ListObject gradients = computeEpochGradients(model);
+                       ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerGlobalEpoch, localStartBatchNum, true);
                        pushGradients(gradients);
-                       System.out.println("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+                       
+                       if( LOG.isInfoEnabled() )
+                               LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
                        ParamservUtils.cleanupListObject(model);
                        ParamservUtils.cleanupListObject(gradients);
                }
        }
 
+       protected ListObject computeGradientsForNBatches(ListObject model, int numBatchesToCompute, int localStartBatchNum) {
+               return computeGradientsForNBatches(model, numBatchesToCompute, localStartBatchNum, false);
+       }
+
        /**
-        * Computes one epoch on the federated worker and updates the model local
+        * Computes the gradients of n batches on the federated worker and is able to update the model locally.
+        * Returns the gradients.
+        *
         * @param model the current model from the parameter server
+        * @param numBatchesToCompute the number of batches to compute
+        * @param localStartBatchNum the batch to start from
+        * @param localUpdate whether to update the model locally
+        *
         * @return the gradient vector
         */
-       protected ListObject computeEpochGradients(ListObject model) {
+       protected ListObject computeGradientsForNBatches(ListObject model,
+               int numBatchesToCompute, int localStartBatchNum, boolean localUpdate)
+       {
+               // put local start batch num on federated worker
+               Future<FederatedResponse> putBatchCounterResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(RequestType.PUT_VAR, _localStartBatchNumVarID, new IntObject(localStartBatchNum)));
                // put current model on federated worker
-               Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model));
+               Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(RequestType.PUT_VAR, _modelVarID, model));
 
                try {
-                       if(!putParamsResponse.get().isSuccessful())
+                       if(!putParamsResponse.get().isSuccessful() || !putBatchCounterResponse.get().isSuccessful())
                                throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful");
                }
                catch(Exception e) {
@@ -426,9 +374,10 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                }
 
                // create and execute the udf on the remote worker
-               Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-                               _featuresData.getVarID(),
-                               new federatedComputeEpochGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _modelVarID})
+               Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
+                               new federatedComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID(),
+                               _localStartBatchNumVarID, _modelVarID}, numBatchesToCompute, localUpdate)
                ));
 
                try {
@@ -441,13 +390,17 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
        }
 
        /**
-        * This is the code that will be executed on the federated Worker when computing one epoch
+        * This is the code that will be executed on the federated Worker when computing gradients for n batches
         */
-       private static class federatedComputeEpochGradients extends FederatedUDF {
+       private static class federatedComputeGradientsForNBatches extends FederatedUDF {
                private static final long serialVersionUID = -3075901536748794832L;
+               int _numBatchesToCompute;
+               boolean _localUpdate;
 
-               protected federatedComputeEpochGradients(long[] inIDs) {
+               protected federatedComputeGradientsForNBatches(long[] inIDs, int numBatchesToCompute, boolean localUpdate) {
                        super(inIDs);
+                       _numBatchesToCompute = numBatchesToCompute;
+                       _localUpdate = localUpdate;
                }
 
                @Override
@@ -455,12 +408,13 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                        // read in data by varid
                        MatrixObject features = (MatrixObject) data[0];
                        MatrixObject labels = (MatrixObject) data[1];
-                       ListObject model = (ListObject) data[2];
+                       int localStartBatchNum = (int) ((IntObject) data[2]).getLongValue();
+                       ListObject model = (ListObject) data[3];
 
                        // get data from execution context
                        long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
                        long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue();
-                       long numBatches = ((IntObject) ec.getVariable(Statement.PS_FED_NUM_BATCHES)).getLongValue();
+                       int possibleBatchesPerLocalEpoch = (int) ((IntObject) ec.getVariable(Statement.PS_FED_POSS_BATCHES_LOCAL)).getLongValue();
                        String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
                        String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
                        String aggregationFuctionName = ((StringObject) ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue();
@@ -478,61 +432,71 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                                        func.getInputParamNames(), outputNames, "gradient function");
                        DataIdentifier gradientsOutput = outputs.get(0);
 
-                       // recreate aggregation instruction and output
-                       func = ec.getProgram().getFunctionProgramBlock(namespace, aggregationFuctionName, false);
-                       inputs = func.getInputParams();
-                       outputs = func.getOutputParams();
-                       boundInputs = inputs.stream()
-                                       .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
-                                       .toArray(CPOperand[]::new);
-                       outputNames = outputs.stream().map(DataIdentifier::getName)
-                                       .collect(Collectors.toCollection(ArrayList::new));
-                       Instruction aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs,
-                                       func.getInputParamNames(), outputNames, "aggregation function");
-                       DataIdentifier aggregationOutput = outputs.get(0);
-
+                       // recreate aggregation instruction and output if needed
+                       Instruction aggregationInstruction = null;
+                       DataIdentifier aggregationOutput = null;
+                       if(_localUpdate && _numBatchesToCompute > 1) {
+                               func = ec.getProgram().getFunctionProgramBlock(namespace, aggregationFuctionName, false);
+                               inputs = func.getInputParams();
+                               outputs = func.getOutputParams();
+                               boundInputs = inputs.stream()
+                                               .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+                                               .toArray(CPOperand[]::new);
+                               outputNames = outputs.stream().map(DataIdentifier::getName)
+                                               .collect(Collectors.toCollection(ArrayList::new));
+                               aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs,
+                                               func.getInputParamNames(), outputNames, "aggregation function");
+                               aggregationOutput = outputs.get(0);
+                       }
 
                        ListObject accGradients = null;
+                       int currentLocalBatchNumber = localStartBatchNum;
                        // prepare execution context
                        ec.setVariable(Statement.PS_MODEL, model);
-                       for (int batchCounter = 0; batchCounter < numBatches; batchCounter++) {
+                       for (int batchCounter = 0; batchCounter < _numBatchesToCompute; batchCounter++) {
+                               int localBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, possibleBatchesPerLocalEpoch);
+
                                // slice batch from feature and label matrix
-                               long begin = batchCounter * batchSize + 1;
-                               long end = Math.min((batchCounter + 1) * batchSize, dataSize);
+                               long begin = localBatchNum * batchSize + 1;
+                               long end = Math.min((localBatchNum + 1) * batchSize, dataSize);
                                MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end);
                                MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end);
 
                                // prepare execution context
                                ec.setVariable(Statement.PS_FEATURES, bFeatures);
                                ec.setVariable(Statement.PS_LABELS, bLabels);
-                               boolean localUpdate = batchCounter < numBatches - 1;
 
-                               // calculate intermediate gradients
+                               // calculate gradients for batch
                                gradientsInstruction.processInstruction(ec);
                                ListObject gradients = ec.getListObject(gradientsOutput.getName());
 
-                               // TODO: is this equivalent for momentum based and AMS prob?
+                               // accrue the computed gradients - In the single batch case this is just a list copy
+                               // is this equivalent for momentum based and AMS prob?
                                accGradients = ParamservUtils.accrueGradients(accGradients, gradients, false);
 
-                               // Update the local model with gradients
-                               if(localUpdate) {
+                               // update the local model with gradients if needed
+                               if(_localUpdate && batchCounter < _numBatchesToCompute - 1) {
                                        // Invoke the aggregate function
+                                       assert aggregationInstruction != null;
                                        aggregationInstruction.processInstruction(ec);
                                        // Get the new model
                                        model = ec.getListObject(aggregationOutput.getName());
                                        // Set new model in execution context
                                        ec.setVariable(Statement.PS_MODEL, model);
                                        // clean up gradients and result
-                                       ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
                                        ParamservUtils.cleanupListObject(ec, aggregationOutput.getName());
                                }
 
-                               // clean up sliced batch
+                               // clean up
+                               ParamservUtils.cleanupListObject(ec, gradientsOutput.getName());
                                ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
                                ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
+                               ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
+                               if( LOG.isInfoEnabled() )
+                                       LOG.info("[+]" + " completed batch " + localBatchNum);
                        }
 
-                       // model clean up - doing this twice is not an issue
+                       // model clean up
                        ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
                        ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
 
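
A minimal sketch of the batch-cycling arithmetic introduced above (getNextLocalBatchNum plus the epoch start offset in computeWithBatchUpdates/computeWithEpochUpdates): each worker cycles through its locally available batches via modulo, so a global epoch that runs more batches than a worker holds simply wraps around. Plain self-contained Java; the batch counts are made up for illustration:

    public class BatchCycleSketch {
        // mirrors FederatedPSControlThread.getNextLocalBatchNum
        static int getNextLocalBatchNum(int current, int possibleBatchesPerLocalEpoch) {
            return current % possibleBatchesPerLocalEpoch;
        }
        public static void main(String[] args) {
            int possibleLocal = 4;  // batches the worker holds locally
            int globalBatches = 6;  // batches per global epoch
            for (int epoch = 0; epoch < 2; epoch++) {
                // same start-offset formula as in the control thread
                int current = globalBatches * epoch % possibleLocal;
                for (int b = 0; b < globalBatches; b++)
                    System.out.print(getNextLocalBatchNum(current++, possibleLocal) + " ");
                System.out.println(); // epoch 0: 0 1 2 3 0 1, epoch 1: 2 3 0 1 2 3
            }
        }
    }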
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index e63fb14..51600d2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -210,12 +210,39 @@ public class ParamservUtils {
                MatrixBlock sample = MatrixBlock.sampleOperations(numEntries, numEntries, false, seed);
 
                // Combine the sequence and sample as a table
-               return seq.ctableSeqOperations(sample, 1.0,
-                       new MatrixBlock(numEntries, numEntries, true));
+               return seq.ctableSeqOperations(sample, 1.0, new MatrixBlock(numEntries, numEntries, true));
+       }
+
+       /**
+        * Generates a matrix which, when left-multiplied with the input matrix, subsamples its rows
+        * @param nsamples number of samples
+        * @param nrows number of rows in input matrix
+        * @param seed seed used to generate random numbers
+        * @return subsample matrix
+        */
+       public static MatrixBlock generateSubsampleMatrix(int nsamples, int nrows, long seed) {
+               MatrixBlock seq = new MatrixBlock(nsamples, nrows, false);
+               // No replacement to preserve as much of the original data as possible
+               MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, false, seed);
+               return seq.ctableSeqOperations(sample, 1.0, new MatrixBlock(nsamples, nrows, true), false);
+       }
+
+       /**
+        * Generates a matrix which, when left-multiplied with the input matrix, replicates n data rows
+        * @param nsamples number of samples
+        * @param nrows number of rows in input matrix
+        * @param seed seed used to generate random numbers
+        * @return replication matrix
+        */
+       public static MatrixBlock generateReplicationMatrix(int nsamples, int nrows, long seed) {
+               MatrixBlock seq = new MatrixBlock(nsamples, nrows, false);
+               // Replacement set to true to provide random replication
+               MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, true, seed);
+               return seq.ctableSeqOperations(sample, 1.0, new MatrixBlock(nsamples, nrows, true), false);
        }
 
        public static ExecutionContext createExecutionContext(ExecutionContext ec,
-               LocalVariableMap varsMap, String updFunc, String aggFunc, int k)
+               LocalVariableMap varsMap, String updFunc, String aggFunc, int k)
        {
                return createExecutionContext(ec, varsMap, updFunc, aggFunc, k, false);
        }
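
The two generators added above build 0/1 selection matrices via ctable: a sequence crossed with sampled row indices gives a matrix S with S[i, sample[i]] = 1, so the product S %*% X picks (or, when sampling with replacement, replicates) rows of X. A self-contained sketch of that idea using plain arrays rather than SystemDS MatrixBlocks, with hard-coded sample indices for illustration:

    public class SelectionMatrixSketch {
        public static void main(String[] args) {
            double[][] X = {{1, 1}, {2, 2}, {3, 3}, {4, 4}};
            int[] sample = {2, 0}; // pick rows 2 and 0
            double[][] S = new double[sample.length][X.length];
            for (int i = 0; i < sample.length; i++)
                S[i][sample[i]] = 1; // exactly one 1 per output row
            double[][] out = new double[S.length][X[0].length];
            for (int i = 0; i < S.length; i++)      // out = S %*% X
                for (int k = 0; k < X.length; k++)
                    for (int j = 0; j < X[0].length; j++)
                        out[i][j] += S[i][k] * X[k][j];
            for (double[] row : out)
                System.out.println(row[0] + " " + row[1]); // 3.0 3.0 then 1.0 1.0
        }
    }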
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
new file mode 100644
index 0000000..460faba
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+import java.util.List;
+import java.util.concurrent.Future;
+
+public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
+       @Override
+       public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+               List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
+               List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+
+               int average_num_rows = (int) Math.round(pFeatures.stream().map(CacheableData::getNumRows).mapToInt(Long::intValue).average().orElse(Double.NaN));
+
+               for(int i = 0; i < pFeatures.size(); i++) {
+                       // Works, because the map contains a single entry
+                       FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
+                       FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+
+                       Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+                                       featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, average_num_rows)));
+
+                       try {
+                               FederatedResponse response = udfResponse.get();
+                               if(!response.isSuccessful())
+                                       throw new DMLRuntimeException("FederatedDataPartitioner BalanceFederatedScheme: balance UDF returned fail");
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException("FederatedDataPartitioner BalanceFederatedScheme: executing balance UDF failed" + e.getMessage());
+                       }
+
+                       DataCharacteristics update = pFeatures.get(i).getDataCharacteristics().setRows(average_num_rows);
+                       pFeatures.get(i).updateDataCharacteristics(update);
+                       update = pLabels.get(i).getDataCharacteristics().setRows(average_num_rows);
+                       pLabels.get(i).updateDataCharacteristics(update);
+               }
+
+               return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+       }
+
+       /**
+        * Balance UDF executed on the federated worker
+        */
+       private static class balanceDataOnFederatedWorker extends FederatedUDF {
+               private static final long serialVersionUID = 6631958250346625546L;
+               private final int _average_num_rows;
+               
+               protected balanceDataOnFederatedWorker(long[] inIDs, int average_num_rows) {
+                       super(inIDs);
+                       _average_num_rows = average_num_rows;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixObject features = (MatrixObject) data[0];
+                       MatrixObject labels = (MatrixObject) data[1];
+
+                       if(features.getNumRows() > _average_num_rows) {
+                               // generate subsampling matrix
+                               MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+                               subsampleTo(features, subsampleMatrixBlock);
+                               subsampleTo(labels, subsampleMatrixBlock);
+                       }
+                       else if(features.getNumRows() < _average_num_rows) {
+                               int num_rows_needed = _average_num_rows - Math.toIntExact(features.getNumRows());
+                               // generate replication matrix
+                               MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+                               replicateTo(features, replicateMatrixBlock);
+                               replicateTo(labels, replicateMatrixBlock);
+                       }
+
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+               }
+       }
+}
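
The decision logic of BalanceToAvgFederatedScheme in isolation: every worker is pushed to the average row count, subsampling when above it and replicating when below it. A standalone Java sketch with made-up row counts:

    import java.util.Arrays;

    public class BalanceToAvgSketch {
        public static void main(String[] args) {
            long[] rowsPerWorker = {100, 40, 70};
            long avg = Math.round(Arrays.stream(rowsPerWorker).average().orElse(0));
            for (long rows : rowsPerWorker) { // avg is 70 here
                if (rows > avg)
                    System.out.println(rows + " rows -> subsample down to " + avg);
                else if (rows < avg)
                    System.out.println(rows + " rows -> replicate up to " + avg);
                else
                    System.out.println(rows + " rows -> unchanged");
            }
        }
    }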
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
index 4183372..f5c9638 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -26,6 +26,10 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.meta.MetaDataFormat;
 
@@ -37,14 +41,16 @@ import java.util.List;
 public abstract class DataPartitionFederatedScheme {
 
        public static final class Result {
-               public final List<MatrixObject> pFeatures;
-               public final List<MatrixObject> pLabels;
-               public final int workerNum;
-
-               public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum) {
-                       this.pFeatures = pFeatures;
-                       this.pLabels = pLabels;
-                       this.workerNum = workerNum;
+               public final List<MatrixObject> _pFeatures;
+               public final List<MatrixObject> _pLabels;
+               public final int _workerNum;
+               public final BalanceMetrics _balanceMetrics;
+
+               public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum, BalanceMetrics balanceMetrics) {
+                       this._pFeatures = pFeatures;
+                       this._pLabels = pLabels;
+                       this._workerNum = workerNum;
+                       this._balanceMetrics = balanceMetrics;
                }
        }
 
@@ -57,12 +63,10 @@ public abstract class DataPartitionFederatedScheme {
         */
        static List<MatrixObject> sliceFederatedMatrix(MatrixObject fedMatrix) {
                if (fedMatrix.isFederated(FederationMap.FType.ROW)) {
-
                        List<MatrixObject> slices = Collections.synchronizedList(new ArrayList<>());
                        fedMatrix.getFedMapping().forEachParallel((range, data) -> {
                                // Create sliced matrix object
                                MatrixObject slice = new MatrixObject(fedMatrix.getValueType(), Dag.getNextUniqueVarname(Types.DataType.MATRIX));
-                               // Warning needs MetaDataFormat instead of MetaData
                                slice.setMetaData(new MetaDataFormat(
                                                new MatrixCharacteristics(range.getSize(0), range.getSize(1)),
                                                Types.FileFormat.BINARY)
@@ -85,4 +89,78 @@ public abstract class DataPartitionFederatedScheme {
                                        "currently only supports row federated 
data");
                }
        }
+
+       static BalanceMetrics getBalanceMetrics(List<MatrixObject> slices) {
+               if (slices == null || slices.size() == 0)
+                       return new BalanceMetrics(0, 0, 0);
+
+               long minRows = slices.get(0).getNumRows();
+               long maxRows = minRows;
+               long sum = 0;
+
+               for (MatrixObject slice : slices) {
+                       if (slice.getNumRows() < minRows)
+                               minRows = slice.getNumRows();
+                       else if (slice.getNumRows() > maxRows)
+                               maxRows = slice.getNumRows();
+
+                       sum += slice.getNumRows();
+               }
+
+               return new BalanceMetrics(minRows, sum / slices.size(), 
maxRows);
+       }
+
+       public static final class BalanceMetrics {
+               public final long _minRows;
+               public final long _avgRows;
+               public final long _maxRows;
+
+               public BalanceMetrics(long minRows, long avgRows, long maxRows) 
{
+                       this._minRows = minRows;
+                       this._avgRows = avgRows;
+                       this._maxRows = maxRows;
+               }
+       }
+
+       /**
+        * Shuffles a MatrixObject by a single matrix multiply with a provided permutation MatrixBlock.
+        *
+        * @param m the input matrix object
+        * @param P the permutation matrix for shuffling
+        */
+       static void shuffle(MatrixObject m, MatrixBlock P) {
+               int k = InfrastructureAnalyzer.getLocalParallelism();
+               AggregateBinaryOperator mm = 
InstructionUtils.getMatMultOperator(k);
+               MatrixBlock out = P.aggregateBinaryOperations(P, 
m.acquireReadAndRelease(), new MatrixBlock(), mm);
+               m.acquireModify(out);
+               m.release();
+       }
+
+       /**
+        * Takes a MatrixObject and extends it to the chosen number of rows by random replication of existing rows.
+        *
+        * @param m the input matrix object
+        * @param R the replication matrix (row-selection matrix used to sample the rows to append)
+        */
+       static void replicateTo(MatrixObject m, MatrixBlock R) {
+               int k = InfrastructureAnalyzer.getLocalParallelism();
+               AggregateBinaryOperator mm = 
InstructionUtils.getMatMultOperator(k);
+               MatrixBlock out = R.aggregateBinaryOperations(R, 
m.acquireReadAndRelease(), new MatrixBlock(), mm);
+               m.acquireModify(m.acquireReadAndRelease().append(out, new 
MatrixBlock(), false));
+               m.release();
+       }
+
+       /**
+        * Subsamples a MatrixObject down to a chosen number of rows by a single matrix multiply with a provided row-selection MatrixBlock.
+        *
+        * @param m the input matrix object
+        * @param S the subsampling (row-selection) matrix
+        */
+       static void subsampleTo(MatrixObject m, MatrixBlock S) {
+               int k = InfrastructureAnalyzer.getLocalParallelism();
+               AggregateBinaryOperator mm = 
InstructionUtils.getMatMultOperator(k);
+               MatrixBlock out = S.aggregateBinaryOperations(S, 
m.acquireReadAndRelease(), new MatrixBlock(), mm);
+               m.acquireModify(out);
+               m.release();
+       }
 }
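
The three helpers above all reduce to a single multiply with a left-hand
selection or permutation matrix. As a minimal standalone sketch of the shuffle
case, assuming only that ParamservUtils.generatePermutation(n, seed) returns an
n x n permutation matrix (consistent with its use in ShuffleFederatedScheme
below):

    import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
    import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
    import org.apache.sysds.runtime.instructions.InstructionUtils;
    import org.apache.sysds.runtime.matrix.data.MatrixBlock;
    import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
    import org.apache.sysds.runtime.util.DataConverter;

    public class ShuffleSketch {
        public static void main(String[] args) {
            MatrixBlock X = DataConverter.convertToMatrixBlock(
                new double[][] {{1,1},{2,2},{3,3},{4,4}});
            // P is 4 x 4 with a single 1 per row and column; P %*% X reorders rows
            MatrixBlock P = ParamservUtils.generatePermutation(4, 42);
            int k = InfrastructureAnalyzer.getLocalParallelism();
            AggregateBinaryOperator mm = InstructionUtils.getMatMultOperator(k);
            MatrixBlock shuffled = P.aggregateBinaryOperations(P, X, new MatrixBlock(), mm);
            System.out.println(shuffled.getNumRows()); // still 4 rows
        }
    }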
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
index 4cdfb95..d1ebb6c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
@@ -35,6 +35,15 @@ public class FederatedDataPartitioner {
                        case SHUFFLE:
                                _scheme = new ShuffleFederatedScheme();
                                break;
+                       case REPLICATE_TO_MAX:
+                               _scheme = new ReplicateToMaxFederatedScheme();
+                               break;
+                       case SUBSAMPLE_TO_MIN:
+                               _scheme = new SubsampleToMinFederatedScheme();
+                               break;
+                       case BALANCE_TO_AVG:
+                               _scheme = new BalanceToAvgFederatedScheme();
+                               break;
                        default:
                                throw new 
DMLRuntimeException(String.format("FederatedDataPartitioner: not support data 
partition scheme '%s'", scheme));
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
index 06feded..e306f25 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
@@ -27,6 +27,6 @@ public class KeepDataOnWorkerFederatedScheme extends 
DataPartitionFederatedSchem
        public Result doPartitioning(MatrixObject features, MatrixObject 
labels) {
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
-               return new Result(pFeatures, pLabels, pFeatures.size());
+               return new Result(pFeatures, pLabels, pFeatures.size(), 
getBalanceMetrics(pFeatures));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
new file mode 100644
index 0000000..068cfa9
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+import java.util.List;
+import java.util.concurrent.Future;
+
+public class ReplicateToMaxFederatedScheme extends 
DataPartitionFederatedScheme {
+       @Override
+       public Result doPartitioning(MatrixObject features, MatrixObject 
labels) {
+               List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
+               List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+
+               int max_rows = 0;
+               for (MatrixObject pFeature : pFeatures) {
+                       max_rows = (pFeature.getNumRows() > max_rows) ? 
Math.toIntExact(pFeature.getNumRows()) : max_rows;
+               }
+
+               for(int i = 0; i < pFeatures.size(); i++) {
+                       // Works because the federated map contains a single entry
+                       FederatedData featuresData = (FederatedData) 
pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
+                       FederatedData labelsData = (FederatedData) 
pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+
+                       Future<FederatedResponse> udfResponse = 
featuresData.executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+                                       featuresData.getVarID(), new 
replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), 
labelsData.getVarID()}, max_rows)));
+
+                       try {
+                               FederatedResponse response = udfResponse.get();
+                               if(!response.isSuccessful())
+                                       throw new DMLRuntimeException("FederatedDataPartitioner ReplicateToMaxFederatedScheme: replicate UDF failed");
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException("FederatedDataPartitioner ReplicateToMaxFederatedScheme: executing replicate UDF failed: " + e.getMessage());
+                       }
+
+                       DataCharacteristics update = 
pFeatures.get(i).getDataCharacteristics().setRows(max_rows);
+                       pFeatures.get(i).updateDataCharacteristics(update);
+                       update = 
pLabels.get(i).getDataCharacteristics().setRows(max_rows);
+                       pLabels.get(i).updateDataCharacteristics(update);
+               }
+
+               return new Result(pFeatures, pLabels, pFeatures.size(), 
getBalanceMetrics(pFeatures));
+       }
+
+       /**
+        * Replicate UDF executed on the federated worker
+        */
+       private static class replicateDataOnFederatedWorker extends 
FederatedUDF {
+               private static final long serialVersionUID = 
-6930898456315100587L;
+               private final int _max_rows;
+               
+               protected replicateDataOnFederatedWorker(long[] inIDs, int 
max_rows) {
+                       super(inIDs);
+                       _max_rows = max_rows;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       MatrixObject features = (MatrixObject) data[0];
+                       MatrixObject labels = (MatrixObject) data[1];
+
+                       // replicate up to the max
+                       if(features.getNumRows() < _max_rows) {
+                               int num_rows_needed = _max_rows - 
Math.toIntExact(features.getNumRows());
+                               // generate replication matrix
+                               MatrixBlock replicateMatrixBlock = 
ParamservUtils.generateReplicationMatrix(num_rows_needed, 
Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+                               replicateTo(features, replicateMatrixBlock);
+                               replicateTo(labels, replicateMatrixBlock);
+                       }
+
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+               }
+       }
+}
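
Note that replicateTo appends the sampled rows instead of overwriting them, so
every original row survives and only the duplicates are new. A short sketch of
that append semantics (cbind=false selects row-wise append), reusing the
hypothetical replication-matrix helper sketched earlier:

    // X: (m x c) original data, extra: (n x c) rows sampled via R %*% X
    // grown: (m+n x c), originals first, sampled duplicates appended below
    MatrixBlock grown = X.append(extra, new MatrixBlock(), false);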
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index d6d8cfc..65ef69d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -19,15 +19,67 @@
 
 package org.apache.sysds.runtime.controlprogram.paramserv.dp;
 
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
 import java.util.List;
+import java.util.concurrent.Future;
 
 public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
        @Override
        public Result doPartitioning(MatrixObject features, MatrixObject 
labels) {
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
-               return new Result(pFeatures, pLabels, pFeatures.size());
+
+               for(int i = 0; i < pFeatures.size(); i++) {
+                       // Works because the federated map contains a single entry
+                       FederatedData featuresData = (FederatedData) 
pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
+                       FederatedData labelsData = (FederatedData) 
pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+
+                       Future<FederatedResponse> udfResponse = 
featuresData.executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+                                       featuresData.getVarID(), new 
shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), 
labelsData.getVarID()})));
+
+                       try {
+                               FederatedResponse response = udfResponse.get();
+                               if(!response.isSuccessful())
+                                       throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: shuffle UDF failed");
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: executing shuffle UDF failed: " + e.getMessage());
+                       }
+               }
+
+               return new Result(pFeatures, pLabels, pFeatures.size(), 
getBalanceMetrics(pFeatures));
+       }
+
+       /**
+        * Shuffle UDF executed on the federated worker
+        */
+       private static class shuffleDataOnFederatedWorker extends FederatedUDF {
+               private static final long serialVersionUID = 
3228664618781333325L;
+
+               protected shuffleDataOnFederatedWorker(long[] inIDs) {
+                       super(inIDs);
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       MatrixObject features = (MatrixObject) data[0];
+                       MatrixObject labels = (MatrixObject) data[1];
+
+                       // generate permutation matrix
+                       MatrixBlock permutationMatrixBlock = 
ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), 
System.currentTimeMillis());
+                       shuffle(features, permutationMatrixBlock);
+                       shuffle(labels, permutationMatrixBlock);
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+               }
        }
-}
+}
\ No newline at end of file
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
new file mode 100644
index 0000000..9b62cc8
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+import java.util.List;
+import java.util.concurrent.Future;
+
+public class SubsampleToMinFederatedScheme extends 
DataPartitionFederatedScheme {
+       @Override
+       public Result doPartitioning(MatrixObject features, MatrixObject 
labels) {
+               List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
+               List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+
+               int min_rows = Integer.MAX_VALUE;
+               for (MatrixObject pFeature : pFeatures) {
+                       min_rows = (pFeature.getNumRows() < min_rows) ? 
Math.toIntExact(pFeature.getNumRows()) : min_rows;
+               }
+
+               for(int i = 0; i < pFeatures.size(); i++) {
+                       // Works because the federated map contains a single entry
+                       FederatedData featuresData = (FederatedData) 
pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
+                       FederatedData labelsData = (FederatedData) 
pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+
+                       Future<FederatedResponse> udfResponse = 
featuresData.executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+                                       featuresData.getVarID(), new 
subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), 
labelsData.getVarID()}, min_rows)));
+
+                       try {
+                               FederatedResponse response = udfResponse.get();
+                               if(!response.isSuccessful())
+                                       throw new DMLRuntimeException("FederatedDataPartitioner SubsampleToMinFederatedScheme: subsample UDF failed");
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException("FederatedDataPartitioner SubsampleToMinFederatedScheme: executing subsample UDF failed: " + e.getMessage());
+                       }
+
+                       DataCharacteristics update = 
pFeatures.get(i).getDataCharacteristics().setRows(min_rows);
+                       pFeatures.get(i).updateDataCharacteristics(update);
+                       update = 
pLabels.get(i).getDataCharacteristics().setRows(min_rows);
+                       pLabels.get(i).updateDataCharacteristics(update);
+               }
+
+               return new Result(pFeatures, pLabels, pFeatures.size(), 
getBalanceMetrics(pFeatures));
+       }
+
+       /**
+        * Subsample UDF executed on the federated worker
+        */
+       private static class subsampleDataOnFederatedWorker extends 
FederatedUDF {
+               private static final long serialVersionUID = 
2213790859544004286L;
+               private final int _min_rows;
+               
+               protected subsampleDataOnFederatedWorker(long[] inIDs, int 
min_rows) {
+                       super(inIDs);
+                       _min_rows = min_rows;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       MatrixObject features = (MatrixObject) data[0];
+                       MatrixObject labels = (MatrixObject) data[1];
+
+                       // subsample down to minimum
+                       if(features.getNumRows() > _min_rows) {
+                               // generate subsampling matrix
+                               MatrixBlock subsampleMatrixBlock = 
ParamservUtils.generateSubsampleMatrix(_min_rows, 
Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+                               subsampleTo(features, subsampleMatrixBlock);
+                               subsampleTo(labels, subsampleMatrixBlock);
+                       }
+
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+               }
+       }
+}
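
The subsample case differs from replication only in that the selection matrix
must pick distinct rows. Assuming ParamservUtils.generateSubsampleMatrix(min,
n, seed) returns a (min x n) 0/1 matrix with one 1 per row in distinct columns,
a hypothetical equivalent could be built as follows:

    import java.util.Collections;
    import java.util.List;
    import java.util.Random;
    import java.util.stream.Collectors;
    import java.util.stream.IntStream;
    import org.apache.sysds.runtime.matrix.data.MatrixBlock;

    public class SubsampleMatrixSketch {
        // (min x n) selection matrix over distinct, randomly ordered rows;
        // S %*% X keeps exactly min of X's n rows
        static MatrixBlock randomSubsampleMatrix(int min, int n, long seed) {
            List<Integer> cols = IntStream.range(0, n).boxed()
                .collect(Collectors.toList());
            Collections.shuffle(cols, new Random(seed));
            MatrixBlock S = new MatrixBlock(min, n, true); // sparse all-zero
            for(int i = 0; i < min; i++)
                S.quickSetValue(i, cols.get(i), 1);
            return S;
        }
    }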
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 7a285f6..a2b8d9f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -32,6 +32,7 @@ import static 
org.apache.sysds.parser.Statement.PS_PARALLELISM;
 import static org.apache.sysds.parser.Statement.PS_SCHEME;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
+import static org.apache.sysds.parser.Statement.PS_RUNTIME_BALANCING;
 
 import java.util.HashMap;
 import java.util.HashSet;
@@ -52,11 +53,12 @@ import org.apache.spark.util.LongAccumulator;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.hops.recompile.Recompiler;
 import org.apache.sysds.lops.LopProperties;
-import org.apache.sysds.parser.Statement;
 import org.apache.sysds.parser.Statement.PSFrequency;
 import org.apache.sysds.parser.Statement.PSModeType;
 import org.apache.sysds.parser.Statement.PSScheme;
+import org.apache.sysds.parser.Statement.FederatedPSScheme;
 import org.apache.sysds.parser.Statement.PSUpdateType;
+import org.apache.sysds.parser.Statement.PSRuntimeBalancing;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -86,6 +88,8 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        private static final int DEFAULT_BATCH_SIZE = 64;
        private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = 
PSFrequency.EPOCH;
        private static final PSScheme DEFAULT_SCHEME = 
PSScheme.DISJOINT_CONTIGUOUS;
+       private static final PSRuntimeBalancing DEFAULT_RUNTIME_BALANCING = 
PSRuntimeBalancing.NONE;
+       private static final FederatedPSScheme DEFAULT_FEDERATED_SCHEME = 
FederatedPSScheme.KEEP_DATA_ON_WORKER;
        private static final PSModeType DEFAULT_MODE = PSModeType.LOCAL;
        private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP;
 
@@ -96,7 +100,6 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        @Override
        public void processInstruction(ExecutionContext ec) {
                // check if the input is federated
-               
                if(ec.getMatrixObject(getParam(PS_FEATURES)).isFederated() ||
                                
ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
                        runFederated(ec);
@@ -124,15 +127,27 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                // get inputs
                PSFrequency freq = getFrequency();
                PSUpdateType updateType = getUpdateType();
+               PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
+               FederatedPSScheme federatedPSScheme = getFederatedScheme();
                String updFunc = getParam(PS_UPDATE_FUN);
                String aggFunc = getParam(PS_AGGREGATION_FUN);
 
                // partition federated data
-               DataPartitionFederatedScheme.Result result = new 
FederatedDataPartitioner(Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER)
+               DataPartitionFederatedScheme.Result result = new 
FederatedDataPartitioner(federatedPSScheme)
                                
.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), 
ec.getMatrixObject(getParam(PS_LABELS)));
-               List<MatrixObject> pFeatures = result.pFeatures;
-               List<MatrixObject> pLabels = result.pLabels;
-               int workerNum = result.workerNum;
+               List<MatrixObject> pFeatures = result._pFeatures;
+               List<MatrixObject> pLabels = result._pLabels;
+               int workerNum = result._workerNum;
+
+               // calculate runtime balancing
+               int numBatchesPerEpoch = 0;
+               if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
+                       numBatchesPerEpoch = (int) 
Math.ceil(result._balanceMetrics._minRows / (float) getBatchSize());
+               } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG) {
+                       numBatchesPerEpoch = (int) 
Math.ceil(result._balanceMetrics._avgRows / (float) getBatchSize());
+               } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
+                       numBatchesPerEpoch = (int) 
Math.ceil(result._balanceMetrics._maxRows / (float) getBatchSize());
+               }
 
                // setup threading
                BasicThreadFactory factory = new BasicThreadFactory.Builder()
@@ -141,8 +156,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
 
                // Get the compiled execution context
                LocalVariableMap newVarsMap = createVarsMap(ec);
-               // Level of par is 1 because one worker will be launched per 
task
-               // TODO: Fix recompilation
+               // Level of par is -1 so each federated worker can scale to its CPU cores
                ExecutionContext newEC = 
ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, -1, 
true);
                // Create workers' execution context
                List<ExecutionContext> federatedWorkerECs = 
ParamservUtils.copyExecutionContext(newEC, workerNum);
@@ -152,8 +166,9 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                ListObject model = ec.getListObject(getParam(PS_MODEL));
                ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, 
updateType, workerNum, model, aggServiceEC);
                // Create the local workers
+               int finalNumBatchesPerEpoch = numBatchesPerEpoch;
                List<FederatedPSControlThread> threads = IntStream.range(0, 
workerNum)
-                               .mapToObj(i -> new FederatedPSControlThread(i, 
updFunc, freq, getEpochs(), getBatchSize(), federatedWorkerECs.get(i), ps))
+                               .mapToObj(i -> new FederatedPSControlThread(i, 
updFunc, freq, runtimeBalancing, getEpochs(), getBatchSize(), 
finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
                                .collect(Collectors.toList());
 
                if(workerNum != threads.size()) {
@@ -379,6 +394,18 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                }
        }
 
+       private PSRuntimeBalancing getRuntimeBalancing() {
+               if (!getParameterMap().containsKey(PS_RUNTIME_BALANCING)) {
+                       return DEFAULT_RUNTIME_BALANCING;
+               }
+               try {
+                       return 
PSRuntimeBalancing.valueOf(getParam(PS_RUNTIME_BALANCING));
+               } catch (IllegalArgumentException e) {
+                       throw new DMLRuntimeException(String.format("Paramserv function: "
+                                       + "runtime balancing '%s' is not supported.", getParam(PS_RUNTIME_BALANCING)));
+               }
+       }
+
        private static int getRemainingCores() {
                return InfrastructureAnalyzer.getLocalParallelism();
        }
@@ -469,4 +496,15 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                return scheme;
        }
 
+       private FederatedPSScheme getFederatedScheme() {
+               FederatedPSScheme federated_scheme = DEFAULT_FEDERATED_SCHEME;
+               if (getParameterMap().containsKey(PS_SCHEME)) {
+                       try {
+                               federated_scheme = 
FederatedPSScheme.valueOf(getParam(PS_SCHEME));
+                       } catch (IllegalArgumentException e) {
+                               throw new DMLRuntimeException(String.format("Paramserv function in federated mode: data partition scheme '%s' is not supported", getParam(PS_SCHEME)));
+                       }
+               }
+               return federated_scheme;
+       }
 }
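
For illustration, with balance metrics of minRows=1000 and a batch size of 100,
RUN_MIN yields ceil(1000/100) = 10 batches per epoch on every worker. A hedged
DML sketch of invoking paramserv with the new options (function names are
placeholders, and the exact key behind PS_RUNTIME_BALANCING is assumed here to
be "runtime_balancing"):

    # X and y are row-federated; scheme picks the federated partitioner,
    # runtime_balancing picks how many batches each worker runs per epoch
    model = paramserv(model=model_list, features=X, labels=y,
        upd="./network.dml::gradients", agg="./network.dml::aggregation",
        utype="BSP", freq="BATCH", epochs=4, batchsize=1,
        scheme="REPLICATE_TO_MAX", runtime_balancing="RUN_MIN",
        hyperparams=list(learning_rate=0.01))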
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
index 477379c..4de60b4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
@@ -61,7 +61,7 @@ public final class MatrixIndexingFEDInstruction extends 
IndexingFEDInstruction {
                
                //modify federated ranges in place
                Map<FederatedRange, IndexRange> ixs = new HashMap<>();
-               for(FederatedRange range : fedMap.getFedMapping().keySet()) {
+               for(FederatedRange range : fedMap.getMap().keySet()) {
                        long rs = range.getBeginDims()[0], re = 
range.getEndDims()[0],
                                cs = range.getBeginDims()[1], ce = 
range.getEndDims()[1];
                        long rsn = (ixrange.rowStart >= rs) ? (ixrange.rowStart 
- rs) : 0;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
index da25122..a479be7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
@@ -104,7 +104,7 @@ public class VariableFEDInstruction extends FEDInstruction 
implements LineageTra
                MatrixObject out = ec.getMatrixObject(_in.getOutput());
                FederationMap outMap = 
mo1.getFedMapping().copyWithNewID(fr1.getID());
                Map<FederatedRange, FederatedData> newMap = new HashMap<>();
-               for(Map.Entry<FederatedRange, FederatedData> pair : 
outMap.getFedMapping().entrySet()) {
+               for(Map.Entry<FederatedRange, FederatedData> pair : 
outMap.getMap().entrySet()) {
                        FederatedData om = pair.getValue();
                        FederatedData nf = new 
FederatedData(Types.DataType.MATRIX, om.getAddress(), om.getFilepath(),
                                om.getVarID());
@@ -131,7 +131,7 @@ public class VariableFEDInstruction extends FEDInstruction 
implements LineageTra
                out.getDataCharacteristics().set(mo1.getNumRows(), 
mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
                FederationMap outMap = 
mo1.getFedMapping().copyWithNewID(fr1.getID());
                Map<FederatedRange, FederatedData> newMap = new HashMap<>();
-               for(Map.Entry<FederatedRange, FederatedData> pair : 
outMap.getFedMapping().entrySet()) {
+               for(Map.Entry<FederatedRange, FederatedData> pair : 
outMap.getMap().entrySet()) {
                        FederatedData om = pair.getValue();
                        FederatedData nf = new 
FederatedData(Types.DataType.FRAME, om.getAddress(), om.getFilepath(),
                                om.getVarID());
diff --git 
a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java 
b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
index 0527d23..5252c5e 100644
--- a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
+++ b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
@@ -107,7 +107,7 @@ public class ReaderWriterFederated {
                        FileSystem fs = IOUtilFunctions.getFileSystem(path, 
job);
                        DataOutputStream out = fs.create(path, true);
                        ObjectMapper mapper = new ObjectMapper();
-                       FederatedDataAddress[] outObjects = 
parseMap(fedMap.getFedMapping());
+                       FederatedDataAddress[] outObjects = 
parseMap(fedMap.getMap());
                        try(BufferedWriter pw = new BufferedWriter(new 
OutputStreamWriter(out))) {
                                mapper.writeValue(pw, outObjects);
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 4cea827..9161410 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -5326,21 +5326,15 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                if( resultBlock!=null )
                        resultBlock.recomputeNonZeros();
        }
-       
+
        /**
-        *  D = ctable(seq,A,w)
-        *  this &lt;- seq; thatMatrix &lt;- A; thatScalar &lt;- w; result 
&lt;- D
-        *  
-        * (i1,j1,v1) from input1 (this)
-        * (i1,j1,v2) from input2 (that)
-        * (w)  from scalar_input3 (scalarThat2)
-        * 
         * @param thatMatrix matrix value
         * @param thatScalar scalar double
         * @param resultBlock result matrix block
+        * @param updateClen if false, the result's column count is left unchanged; set this when the result block already has the desired number of columns
         * @return resultBlock
         */
-       public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double 
thatScalar, MatrixBlock resultBlock) {
+       public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double 
thatScalar, MatrixBlock resultBlock, boolean updateClen) {
                MatrixBlock that = checkType(thatMatrix);
                CTable ctable = CTable.getCTableFnObject();
                double w = thatScalar;
@@ -5357,9 +5351,28 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                
                //update meta data (initially unknown number of columns)
                //note: nnz maintained in ctable (via quickset)
-               resultBlock.clen = maxCol;
+               if(updateClen) {
+                       resultBlock.clen = maxCol;
+               }
                return resultBlock;
        }
+
+       /**
+        *  D = ctable(seq,A,w)
+        *  this &lt;- seq; thatMatrix &lt;- A; thatScalar &lt;- w; result 
&lt;- D
+        *
+        * (i1,j1,v1) from input1 (this)
+        * (i1,j1,v2) from input2 (that)
+        * (w)  from scalar_input3 (scalarThat2)
+        *
+        * @param thatMatrix matrix value
+        * @param thatScalar scalar double
+        * @param resultBlock result matrix block
+        * @return resultBlock
+        */
+       public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double 
thatScalar, MatrixBlock resultBlock) {
+               return ctableSeqOperations(thatMatrix, thatScalar, resultBlock, 
true);
+       }
        
        /**
         *  D = ctable(A,B,W)
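
A small hedged sketch of the new overload: with updateClen=false a caller can
pre-allocate the result with a known column count (e.g., a fixed number of
groups) and keep it stable even if the observed maximum group id is smaller:

    // seq = [1,2,3]^T, A = group ids per row, D pre-allocated with 5 columns
    MatrixBlock seq = DataConverter.convertToMatrixBlock(new double[][]{{1},{2},{3}});
    MatrixBlock A   = DataConverter.convertToMatrixBlock(new double[][]{{2},{1},{2}});
    MatrixBlock D   = new MatrixBlock(3, 5, true);
    seq.ctableSeqOperations(A, 1.0, D, false); // D keeps clen=5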
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java 
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 4f08e88..9af4fcb 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.test;
 
+import static java.lang.Math.ceil;
 import static java.lang.Thread.sleep;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
@@ -27,10 +28,12 @@ import java.io.ByteArrayOutputStream;
 import java.io.File;
 import java.io.IOException;
 import java.io.PrintStream;
+import java.net.InetSocketAddress;
 import java.net.ServerSocket;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Properties;
 
@@ -43,6 +46,7 @@ import org.apache.hadoop.util.GenericOptionsParser;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.SparkSession.Builder;
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.FileFormat;
@@ -52,13 +56,17 @@ import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.lops.compile.Dag;
 import org.apache.sysds.parser.DataExpression;
 import org.apache.sysds.parser.ParseException;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.DMLScriptException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
 import org.apache.sysds.runtime.io.FrameReader;
 import org.apache.sysds.runtime.io.FrameReaderFactory;
@@ -67,6 +75,7 @@ import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
 import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
@@ -588,6 +597,71 @@ public abstract class AutomatedTestBase {
 
        /**
         * <p>
+        * Takes a matrix (double[][]) and writes it to disk in row slices. Then it creates a federated MatrixObject
+        * referencing the local slice paths and the given ports; this federated matrix is also written to disk under
+        * the provided name. When federated workers are running locally on the specified ports, the federated matrix
+        * can then be used for testing purposes via read on input(name).
+        * </p>
+        *
+        * @param name name of the matrix when writing to disk
+        * @param matrix two dimensional matrix
+        * @param numFederatedWorkers the number of federated workers
+        * @param ports a list of ports, one per federated worker
+        * @param ranges an array of two-element arrays holding the lower and upper row bounds (rows) of each slice
+        */
+       protected void rowFederateLocallyAndWriteInputMatrixWithMTD(String name,
+               double[][] matrix, int numFederatedWorkers, List<Integer> 
ports, double[][] ranges)
+       {
+               // check matrix non empty
+               if(matrix.length == 0 || matrix[0].length == 0)
+                       return;
+
+               int nrows = matrix.length;
+               int ncol = matrix[0].length;
+
+               // create federated MatrixObject
+               MatrixObject federatedMatrixObject = new 
MatrixObject(ValueType.FP64, 
+                       Dag.getNextUniqueVarname(Types.DataType.MATRIX));
+               federatedMatrixObject.setMetaData(new MetaDataFormat(
+                       new MatrixCharacteristics(nrows, ncol), 
Types.FileFormat.BINARY));
+
+               // write parts and generate FederationMap
+               HashMap<FederatedRange, FederatedData> fedHashMap = new 
HashMap<>();
+               for(int i = 0; i < numFederatedWorkers; i++) {
+                       double lowerBound = ranges[i][0];
+                       double upperBound = ranges[i][1];
+                       double examplesForWorkerI = upperBound - lowerBound;
+                       String path = name + "_" + (i + 1);
+
+                       // write slice
+                       writeInputMatrixWithMTD(path, 
Arrays.copyOfRange(matrix, (int)lowerBound, (int)upperBound),
+                               false, new MatrixCharacteristics((long) 
examplesForWorkerI, ncol,
+                               OptimizerUtils.DEFAULT_BLOCKSIZE, (long) 
examplesForWorkerI * ncol));
+
+                       // generate fedmap entry
+                       FederatedRange range = new FederatedRange(new 
long[]{(long) lowerBound, 0}, new long[]{(long) upperBound, ncol});
+                       FederatedData data = new FederatedData(DataType.MATRIX, 
new InetSocketAddress(ports.get(i)), input(path));
+                       fedHashMap.put(range, data);
+               }
+               
+               federatedMatrixObject.setFedMapping(new 
FederationMap(FederationUtils.getNextFedDataID(), fedHashMap));
+               
federatedMatrixObject.getFedMapping().setType(FederationMap.FType.ROW);
+
+               writeInputFederatedWithMTD(name, federatedMatrixObject, null);
+       }
+
+       protected double[][] generateBalancedFederatedRowRanges(int 
numFederatedWorkers, int dataSetSize) {
+               double[][] ranges = new double[numFederatedWorkers][2];
+               double examplesPerWorker = ceil( (double) dataSetSize / 
(double) numFederatedWorkers);
+               for(int i = 0; i < numFederatedWorkers; i++) {
+                       ranges[i][0] = examplesPerWorker * i;
+                       ranges[i][1] = Math.min(examplesPerWorker * (i + 1), 
dataSetSize);
+               }
+               return ranges;
+       }
+
+       /**
+        * <p>
         * Adds a matrix to the input path and writes it to a file.
         * </p>
         *
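
A hedged usage sketch for these two helpers inside an AutomatedTestBase
subclass (the ports are assumed to host already-started local federated
workers):

    List<Integer> ports = Arrays.asList(getRandomAvailablePort(), getRandomAvailablePort());
    // ... startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S) per worker ...
    double[][] X = getRandomMatrix(100, 10, 0, 1, 1.0, 42);
    double[][] ranges = generateBalancedFederatedRowRanges(2, X.length);
    rowFederateLocallyAndWriteInputMatrixWithMTD("X", X, 2, ports, ranges);
    // a DML script can then read the federated matrix from input("X")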
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index cc0af07..6a52fc4 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -27,7 +27,6 @@ import java.util.List;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
@@ -44,48 +43,66 @@ public class FederatedParamservTest extends 
AutomatedTestBase {
        private final static String TEST_DIR = "functions/federated/paramserv/";
        private final static String TEST_NAME = "FederatedParamservTest";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedParamservTest.class.getSimpleName() + "/";
-       private final static int _blocksize = 1024;
 
        private final String _networkType;
        private final int _numFederatedWorkers;
-       private final int _examplesPerWorker;
+       private final int _dataSetSize;
        private final int _epochs;
        private final int _batch_size;
        private final double _eta;
        private final String _utype;
        private final String _freq;
+       private final String _scheme;
+       private final String _runtime_balancing;
+       private final String _data_distribution;
 
        // parameters
        @Parameterized.Parameters
        public static Collection<Object[]> parameters() {
                return Arrays.asList(new Object[][] {
-                       // Network type, number of federated workers, examples 
per worker, batch size, epochs, learning rate, update
-                       // type, update frequency
-                       {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, {"TwoNN", 
2, 2, 1, 5, 0.01, "ASP", "BATCH"},
-                       {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, {"TwoNN", 
2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
-                       {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, {"CNN", 2, 
2, 1, 5, 0.01, "ASP", "BATCH"},
-                       {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, {"CNN", 2, 
2, 1, 5, 0.01, "ASP", "EPOCH"},
-                       {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH"},
-                       // {"TwoNN", 5, 1000, 200, 2, 0.01, "ASP", "BATCH"},
-                       // {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH"},
-                       // {"TwoNN", 5, 1000, 200, 2, 0.01, "ASP", "EPOCH"},
-                       // {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH"},
-                       // {"CNN", 5, 1000, 200, 2, 0.01, "ASP", "BATCH"},
-                       {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH"},
-                       // {"CNN", 5, 1000, 200, 2, 0.01, "ASP", "EPOCH"}
+                       // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update type, update frequency, scheme, runtime balancing, data distribution
+
+                       // basic functionality
+                       {"TwoNN", 2, 4, 1, 4, 0.01,       "BSP", "BATCH", 
"KEEP_DATA_ON_WORKER", "CYCLE_AVG", "IMBALANCED"},
+                       {"CNN",   2, 4, 1, 4, 0.01,       "BSP", "EPOCH", 
"SHUFFLE",             "NONE" ,     "IMBALANCED"},
+                       {"CNN",   2, 4, 1, 4, 0.01,       "ASP", "BATCH", 
"REPLICATE_TO_MAX",    "RUN_MIN" ,  "IMBALANCED"},
+                       {"TwoNN", 2, 4, 1, 4, 0.01,       "ASP", "EPOCH", 
"BALANCE_TO_AVG",      "CYCLE_MAX", "IMBALANCED"},
+                       {"TwoNN", 5, 1000, 100, 2, 0.01,  "BSP", "BATCH", 
"KEEP_DATA_ON_WORKER", "NONE" ,     "BALANCED"},
+
+                       /*
+                               // runtime balancing
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "IMBALANCED"},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "IMBALANCED"},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "IMBALANCED"},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "IMBALANCED"},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "IMBALANCED"},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "IMBALANCED"},
+
+                               // data partitioning
+                               {"TwoNN", 2, 4, 1, 1, 0.01,             "BSP", 
"BATCH", "SHUFFLE",                              "CYCLE_AVG" ,   "IMBALANCED"},
+                               {"TwoNN", 2, 4, 1, 1, 0.01,             "BSP", 
"BATCH", "REPLICATE_TO_MAX",             "NONE" ,                "IMBALANCED"},
+                               {"TwoNN", 2, 4, 1, 1, 0.01,             "BSP", 
"BATCH", "SUBSAMPLE_TO_MIN",             "NONE" ,                "IMBALANCED"},
+                               {"TwoNN", 2, 4, 1, 1, 0.01,             "BSP", 
"BATCH", "BALANCE_TO_AVG",               "NONE" ,                "IMBALANCED"},
+
+                               // balanced tests
+                               {"CNN",         5, 1000, 100, 2, 0.01,  "BSP", 
"EPOCH", "KEEP_DATA_ON_WORKER",  "NONE" ,                "BALANCED"}
+                        */
                });
        }
 
-       public FederatedParamservTest(String networkType, int 
numFederatedWorkers, int examplesPerWorker, int batch_size,
-               int epochs, double eta, String utype, String freq) {
+       public FederatedParamservTest(String networkType, int 
numFederatedWorkers, int dataSetSize, int batch_size,
+               int epochs, double eta, String utype, String freq, String 
scheme, String runtime_balancing, String data_distribution) {
                _networkType = networkType;
                _numFederatedWorkers = numFederatedWorkers;
-               _examplesPerWorker = examplesPerWorker;
+               _dataSetSize = dataSetSize;
                _batch_size = batch_size;
                _epochs = epochs;
                _eta = eta;
                _utype = utype;
                _freq = freq;
+               _scheme = scheme;
+               _runtime_balancing = runtime_balancing;
+               _data_distribution = data_distribution;
        }
 
        @Override
@@ -111,68 +128,74 @@ public class FederatedParamservTest extends 
AutomatedTestBase {
                setOutputBuffering(true);
 
                int C = 1, Hin = 28, Win = 28;
-               int numFeatures = C * Hin * Win;
                int numLabels = 10;
 
                ExecMode platformOld = setExecMode(mode);
 
                try {
-
-                       // dml name
-                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       // generate program args
-                       List<String> programArgsList = new 
ArrayList<>(Arrays.asList("-stats",
-                               "-nvargs",
-                               "examples_per_worker=" + _examplesPerWorker,
-                               "num_features=" + numFeatures,
-                               "num_labels=" + numLabels,
-                               "epochs=" + _epochs,
-                               "batch_size=" + _batch_size,
-                               "eta=" + _eta,
-                               "utype=" + _utype,
-                               "freq=" + _freq,
-                               "network_type=" + _networkType,
-                               "channels=" + C,
-                               "hin=" + Hin,
-                               "win=" + Win));
-
-                       // for each worker
+                       // start threads
                        List<Integer> ports = new ArrayList<>();
                        List<Thread> threads = new ArrayList<>();
                        for(int i = 0; i < _numFederatedWorkers; i++) {
-                               // write row partitioned features to disk
-                               writeInputMatrixWithMTD("X" + i,
-                                       
generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win),
-                                       false,
-                                       new 
MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize,
-                                               _examplesPerWorker * 
numFeatures));
-                               // write row partitioned labels to disk
-                               writeInputMatrixWithMTD("y" + i,
-                                       
generateDummyMNISTLabels(_examplesPerWorker, numLabels),
-                                       false,
-                                       new 
MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize,
-                                               _examplesPerWorker * 
numLabels));
-
-                               // start worker
                                ports.add(getRandomAvailablePort());
                                
threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S));
+                       }
 
-                               // add worker to program args
-                               programArgsList.add("X" + i + "=" + 
TestUtils.federatedAddress(ports.get(i), input("X" + i)));
-                               programArgsList.add("y" + i + "=" + 
TestUtils.federatedAddress(ports.get(i), input("y" + i)));
+                       // generate test data
+                       double[][] features = 
generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win);
+                       double[][] labels = 
generateDummyMNISTLabels(_dataSetSize, numLabels);
+                       String featuresName = "";
+                       String labelsName = "";
+
+                       // federate test data balanced or imbalanced
+                       if(_data_distribution.equals("IMBALANCED")) {
+                               featuresName = "X_IMBALANCED_" + 
_numFederatedWorkers;
+                               labelsName = "y_IMBALANCED_" + 
_numFederatedWorkers;
+                               double[][] ranges = {{0,1}, {1,4}};
+                               
rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, 
_numFederatedWorkers, ports, ranges);
+                               
rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, 
_numFederatedWorkers, ports, ranges);
+                       }
+                       else {
+                               featuresName = "X_BALANCED_" + 
_numFederatedWorkers;
+                               labelsName = "y_BALANCED_" + 
_numFederatedWorkers;
+                               double[][] ranges = 
generateBalancedFederatedRowRanges(_numFederatedWorkers, features.length);
+                               
rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, 
_numFederatedWorkers, ports, ranges);
+                               
rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, 
_numFederatedWorkers, ports, ranges);
                        }
+
                        try {
-                               Thread.sleep(1000);
+                               // wait for all workers to be set up
+                               Thread.sleep(FED_WORKER_WAIT);
                        }
                        catch(InterruptedException e) {
                                e.printStackTrace();
                        }
-       
+
+                       // dml name
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       // generate program args
+                       List<String> programArgsList = new ArrayList<>(Arrays.asList("-stats",
+                                       "-nvargs",
+                                       "features=" + input(featuresName),
+                                       "labels=" + input(labelsName),
+                                       "epochs=" + _epochs,
+                                       "batch_size=" + _batch_size,
+                                       "eta=" + _eta,
+                                       "utype=" + _utype,
+                                       "freq=" + _freq,
+                                       "scheme=" + _scheme,
+                                       "runtime_balancing=" + _runtime_balancing,
+                                       "network_type=" + _networkType,
+                                       "channels=" + C,
+                                       "hin=" + Hin,
+                                       "win=" + Win,
+                                       "seed=" + 25));
+
                        programArgs = programArgsList.toArray(new String[0]);
                        LOG.debug(runTest(null));
                        Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
                        
-                       // cleanup
+                       // shut down threads
                        for(int i = 0; i < _numFederatedWorkers; i++) {
                                TestUtils.shutdownThreads(threads.get(i));
                        }
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml
index d622c13..69c7e76 100644
--- a/src/test/scripts/functions/federated/paramserv/CNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -67,8 +67,10 @@ source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
  */
 train = function(matrix[double] X, matrix[double] y,
                  matrix[double] X_val, matrix[double] y_val,
-                 int C, int Hin, int Win, int epochs, int batch_size, double learning_rate)
-    return (list[unknown] model_trained) {
+                 int epochs, int batch_size, double eta,
+                 int C, int Hin, int Win,
+                 int seed = -1)
+    return (list[unknown] model) {
 
   N = nrow(X)
   K = ncol(y)
@@ -84,74 +86,45 @@ train = function(matrix[double] X, matrix[double] y,
   N3 = 512  # num nodes in affine3
   # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
 
-  [W1, b1] = conv2d::init(F1, C, Hf, Wf, -1)  # inputs: (N, C*Hin*Win)
-  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, -1)  # inputs: (N, F1*(Hin/2)*(Win/2))
-  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, -1)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
-  [W4, b4] = affine::init(N3, K, -1)  # inputs: (N, N3)
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed)  # inputs: (N, C*Hin*Win)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  lseed = ifelse(seed==-1, -1, seed + 3);
+  [W4, b4] = affine::init(N3, K, seed = lseed)  # inputs: (N, N3)
   W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
 
   # Initialize SGD w/ Nesterov momentum optimizer
-  learning_rate = learning_rate  # learning rate
   mu = 0.9  # momentum
   decay = 0.95  # learning rate decay constant
   vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
   vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
   vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
   vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+
   # Regularization
   lambda = 5e-04
 
   # Create the hyper parameter list
-  hyperparams = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+  hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
   # Calculate iterations
   iters = ceil(N / batch_size)
-  print_interval = floor(iters / 25)
-
-  print("[+] Starting optimization")
-  print("[+]  Learning rate: " + learning_rate)
-  print("[+]  Batch size: " + batch_size)
-  print("[+]  Iterations per epoch: " + iters + "\n")
 
   for (e in 1:epochs) {
-    print("[+] Starting epoch: " + e)
-    print("|")
     for(i in 1:iters) {
-      # Create the model list
-      model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
-
       # Get next batch
       beg = ((i-1) * batch_size) %% N + 1
       end = min(N, beg + batch_size - 1)
       X_batch = X[beg:end,]
       y_batch = y[beg:end,]
 
-      gradients_list = gradients(model_list, hyperparams, X_batch, y_batch)
-      model_updated = aggregation(model_list, hyperparams, gradients_list)
-
-      W1 = as.matrix(model_updated[1])
-      W2 = as.matrix(model_updated[2])
-      W3 = as.matrix(model_updated[3])
-      W4 = as.matrix(model_updated[4])
-      b1 = as.matrix(model_updated[5])
-      b2 = as.matrix(model_updated[6])
-      b3 = as.matrix(model_updated[7])
-      b4 = as.matrix(model_updated[8])
-      vW1 = as.matrix(model_updated[9])
-      vW2 = as.matrix(model_updated[10])
-      vW3 = as.matrix(model_updated[11])
-      vW4 = as.matrix(model_updated[12])
-      vb1 = as.matrix(model_updated[13])
-      vb2 = as.matrix(model_updated[14])
-      vb3 = as.matrix(model_updated[15])
-      vb4 = as.matrix(model_updated[16])
-      if((i %% print_interval) == 0) {
-        print("█")
-      }
+      gradients_list = gradients(model, hyperparams, X_batch, y_batch)
+      model = aggregation(model, hyperparams, gradients_list)
     }
-    print("|")
   }
-
-  model_trained = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
 }
 
 /*
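
Note: the per-layer seed derivation introduced above keeps seed == -1 as a
pass-through for fully random initialization, and otherwise offsets the seed
per layer so no two layers are initialized identically. A standalone DML
sketch of the pattern (values illustrative):

    seed = 42;
    lseed1 = ifelse(seed == -1, -1, seed + 1);  # 43
    lseed2 = ifelse(seed == -1, -1, seed + 2);  # 44
    print("layer seeds: " + seed + ", " + lseed1 + ", " + lseed2);
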
@@ -190,9 +163,10 @@ train = function(matrix[double] X, matrix[double] y,
  */
 train_paramserv = function(matrix[double] X, matrix[double] y,
                  matrix[double] X_val, matrix[double] y_val,
-                 int C, int Hin, int Win, int epochs, int workers,
-                 string utype, string freq, int batch_size, string scheme, string mode, double learning_rate)
-    return (list[unknown] model_trained) {
+                 int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing,
+                 double eta, int C, int Hin, int Win,
+                 int seed = -1)
+    return (list[unknown] model) {
 
   N = nrow(X)
   K = ncol(y)
@@ -208,14 +182,17 @@ train_paramserv = function(matrix[double] X, matrix[double] y,
   N3 = 512  # num nodes in affine3
   # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
 
-  [W1, b1] = conv2d::init(F1, C, Hf, Wf, -1)  # inputs: (N, C*Hin*Win)
-  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, -1)  # inputs: (N, F1*(Hin/2)*(Win/2))
-  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, -1)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
-  [W4, b4] = affine::init(N3, K, -1)  # inputs: (N, N3)
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed)  # inputs: (N, C*Hin*Win)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  lseed = ifelse(seed==-1, -1, seed + 3);
+  [W4, b4] = affine::init(N3, K, seed = lseed)  # inputs: (N, N3)
   W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
 
   # Initialize SGD w/ Nesterov momentum optimizer
-  learning_rate = learning_rate  # learning rate
+  learning_rate = eta  # learning rate
   mu = 0.9  # momentum
   decay = 0.95  # learning rate decay constant
   vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
@@ -225,12 +202,16 @@ train_paramserv = function(matrix[double] X, matrix[double] y,
   # Regularization
   lambda = 5e-04
   # Create the model list
-  model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+  model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
   # Create the hyper parameter list
-  params = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+  hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
 
   # Use paramserv function
-  model_trained = paramserv(model=model_list, features=X, labels=y, val_features=X_val, val_labels=y_val, upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation", mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
+  model = paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val,
+    upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients",
+    agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
+    k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams)
 }
 
 /*
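
Note: the rewritten call above switches paramserv to one named argument per
line and threads the new runtime_balancing parameter through. A minimal,
self-contained DML sketch of the call shape (the update/aggregation function
paths, the one-matrix model, and the runtime_balancing value string are
placeholders, not the actual test setup):

    model = list(matrix(0, rows=10, cols=10));
    X = rand(rows=100, cols=10);
    y = rand(rows=100, cols=1);
    trained = paramserv(model=model, features=X, labels=y,
      upd="./myScript.dml::gradients",    # hypothetical update function
      agg="./myScript.dml::aggregation",  # hypothetical aggregation function
      k=2, utype="BSP", freq="BATCH", epochs=1, batchsize=32,
      scheme="DISJOINT_CONTIGUOUS",
      runtime_balancing="NONE",           # assumed value string
      hyperparams=list(learning_rate=0.01));
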
diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
index 16c72c4..10d2cc7 100644
--- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -23,35 +23,13 @@ source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN
 source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
 
 # create federated input matrices
-features = federated(addresses=list($X0, $X1),
-    ranges=list(list(0, 0), list($examples_per_worker, $num_features),
-                list($examples_per_worker, 0), list($examples_per_worker * 2, $num_features)))
+features = read($features)
+labels = read($labels)
 
-labels = federated(addresses=list($y0, $y1),
-    ranges=list(list(0, 0), list($examples_per_worker, $num_labels),
-                list($examples_per_worker, 0), list($examples_per_worker * 2, $num_labels)))
-
-epochs = $epochs
-batch_size = $batch_size
-learning_rate = $eta
-utype = $utype
-freq = $freq
-network_type = $network_type
-
-# currently ignored parameters
-workers = 1
-scheme = "DISJOINT_CONTIGUOUS"
-paramserv_mode = "LOCAL"
-
-# config for the cnn
-channels = $channels
-hin = $hin
-win = $win
-
-if(network_type == "TwoNN") {
-  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate)
+if($network_type == "TwoNN") {
+  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $seed)
 }
 else {
-  model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), channels, hin, win, epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate)
+  model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $channels, $hin, $win, $seed)
 }
 }
 print(toString(model))
\ No newline at end of file
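
Note: the test script now reads its pre-federated inputs through named script
arguments instead of constructing a federated() matrix inline, so the Java
harness controls the partitioning. For reference, a tiny DML sketch of
named-argument handling (file name illustrative):

    fname = ifdef($features, "data/X.mtx");  # $features is supplied via -nvargs
    X = read(fname);
    print("read " + nrow(X) + " x " + ncol(X));
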
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
index 31e889a..9bd49d8 100644
--- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -57,8 +57,9 @@ source("nn/optim/sgd.dml") as sgd
  */
 train = function(matrix[double] X, matrix[double] y,
                  matrix[double] X_val, matrix[double] y_val,
-                 int epochs, int batch_size, double learning_rate)
-    return (list[unknown] model_trained) {
+                 int epochs, int batch_size, double eta,
+                 int seed = -1)
+    return (list[unknown] model) {
 
   N = nrow(X)  # num examples
   D = ncol(X)  # num features
@@ -66,53 +67,31 @@ train = function(matrix[double] X, matrix[double] y,
 
   # Create the network:
   ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
-  [W1, b1] = affine::init(D, 200, -1)
-  [W2, b2] = affine::init(200, 200, -1)
-  [W3, b3] = affine::init(200, K, -1)
+  [W1, b1] = affine::init(D, 200, seed = seed)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = affine::init(200, 200, seed = lseed)
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(200, K, seed = lseed)
   W3 = W3 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+  model = list(W1, W2, W3, b1, b2, b3)
 
   # Create the hyper parameter list
-  hyperparams = list(learning_rate=learning_rate)
+  hyperparams = list(learning_rate=eta)
   # Calculate iterations
   iters = ceil(N / batch_size)
-  print_interval = floor(iters / 25)
-
-  print("[+] Starting optimization")
-  print("[+]  Learning rate: " + learning_rate)
-  print("[+]  Batch size: " + batch_size)
-  print("[+]  Iterations per epoch: " + iters + "\n")
 
   for (e in 1:epochs) {
-    print("[+] Starting epoch: " + e)
-    print("|")
     for(i in 1:iters) {
-      # Create the model list
-      model_list = list(W1, W2, W3, b1, b2, b3)
-
       # Get next batch
       beg = ((i-1) * batch_size) %% N + 1
       end = min(N, beg + batch_size - 1)
       X_batch = X[beg:end,]
       y_batch = y[beg:end,]
 
-      gradients_list = gradients(model_list, hyperparams, X_batch, y_batch)
-      model_updated = aggregation(model_list, hyperparams, gradients_list)
-
-      W1 = as.matrix(model_updated[1])
-      W2 = as.matrix(model_updated[2])
-      W3 = as.matrix(model_updated[3])
-      b1 = as.matrix(model_updated[4])
-      b2 = as.matrix(model_updated[5])
-      b3 = as.matrix(model_updated[6])
-
-      if((i %% print_interval) == 0) {
-        print("█")
-      }
+      gradients_list = gradients(model, hyperparams, X_batch, y_batch)
+      model = aggregation(model, hyperparams, gradients_list)
     }
-    print("|")
   }
-
-  model_trained = list(W1, W2, W3, b1, b2, b3)
 }
 
 /*
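
Note: the batch indexing retained in the loops above selects one contiguous
slice per iteration, with min() clamping the final, possibly smaller batch.
Worked values for beg = ((i-1) * batch_size) %% N + 1 and
end = min(N, beg + batch_size - 1), assuming N = 100 and batch_size = 32:

    # i=1 -> beg=1,  end=32
    # i=2 -> beg=33, end=64
    # i=3 -> beg=65, end=96
    # i=4 -> beg=97, end=100  (last batch clamped)
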
@@ -146,9 +125,9 @@ train = function(matrix[double] X, matrix[double] y,
  */
 train_paramserv = function(matrix[double] X, matrix[double] y,
                  matrix[double] X_val, matrix[double] y_val,
-                 int epochs, int workers,
-                 string utype, string freq, int batch_size, string scheme, string mode, double learning_rate)
-    return (list[unknown] model_trained) {
+                 int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing,
+                 double eta, int seed = -1)
+    return (list[unknown] model) {
 
   N = nrow(X)  # num examples
   D = ncol(X)  # num features
@@ -156,16 +135,27 @@ train_paramserv = function(matrix[double] X, matrix[double] y,
 
   # Create the network:
   ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
-  [W1, b1] = affine::init(D, 200, -1)
-  [W2, b2] = affine::init(200, 200, -1)
-  [W3, b3] = affine::init(200, K, -1)
+  [W1, b1] = affine::init(D, 200, seed = seed)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = affine::init(200, 200, seed = lseed)
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(200, K, seed = lseed)
+  # W3 = W3 / sqrt(2) # different initialization, since being fed into softmax, instead of relu
+
+  # [W1, b1] = affine::init(D, 200)
+  # [W2, b2] = affine::init(200, 200)
+  # [W3, b3] = affine::init(200, K)
 
   # Create the model list
-  model_list = list(W1, W2, W3, b1, b2, b3)
+  model = list(W1, W2, W3, b1, b2, b3)
   # Create the hyper parameter list
-  params = list(learning_rate=learning_rate)
+  hyperparams = list(learning_rate=eta)
   # Use paramserv function
-  model_trained = paramserv(model=model_list, features=X, labels=y, val_features=X_val, val_labels=y_val, upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation", mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
+  model = paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val,
+    upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
+    agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
+    k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams)
 }
 
 /*
