Repository: systemml
Updated Branches:
  refs/heads/master 25be6a686 -> b06f390ec


[SYSTEMML-2392/8,2401/2/6] Paramserv statistics and various fixes

Closes #787.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b06f390e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b06f390e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b06f390e

Branch: refs/heads/master
Commit: b06f390ecf75dcf24e8143aafdba533440326861
Parents: 25be6a6
Author: EdgarLGB <[email protected]>
Authored: Sun Jun 17 18:21:48 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sun Jun 17 18:21:48 2018 -0700

----------------------------------------------------------------------
 .../controlprogram/paramserv/LocalPSWorker.java |  78 ++++++--
 .../controlprogram/paramserv/PSWorker.java      |   4 +-
 .../controlprogram/paramserv/ParamServer.java   |  62 +++++--
 .../paramserv/ParamservUtils.java               | 132 +++++++++++++-
 .../cp/ParamservBuiltinCPInstruction.java       | 181 +++++--------------
 .../java/org/apache/sysml/utils/Statistics.java |  47 ++++-
 6 files changed, 327 insertions(+), 177 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index a64df72..beb5e45 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -20,14 +20,21 @@
 package org.apache.sysml.runtime.controlprogram.paramserv;
 
 import java.util.concurrent.Callable;
+import java.util.stream.IntStream;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.utils.Statistics;
 
 public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
@@ -40,6 +47,9 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
        @Override
        public Void call() throws Exception {
+               if (DMLScript.STATISTICS)
+                       Statistics.incWorkerNumber();
+               
                try {
                        long dataSize = _features.getNumRows();
                        int totalIter = (int) Math.ceil((double) dataSize / 
_batchSize);
@@ -65,26 +75,28 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
        private void computeEpoch(long dataSize, int totalIter) {
                for (int i = 0; i < _epochs; i++) {
                        // Pull the global parameters from ps
-                       ListObject globalParams = pullModel();
-
+                       ListObject params = pullModel();
+                       ListObject accGradients = null;
+                       
                        for (int j = 0; j < totalIter; j++) {
-                               _ec.setVariable(Statement.PS_MODEL, 
globalParams);
+                               _ec.setVariable(Statement.PS_MODEL, params);
 
                                ListObject gradients = 
computeGradients(dataSize, totalIter, i, j);
 
-                               if (j == totalIter - 1) {
-                                       // Push the gradients to ps
-                                       pushGradients(gradients);
-                                       ParamservUtils.cleanupListObject(_ec, 
globalParams);
-                               } else {
-                                       // Update the local model with gradients
-                                       globalParams = 
_ps.updateModel(gradients, globalParams);
-                                       if (LOG.isDebugEnabled()) {
-                                               LOG.debug(String.format("Local 
worker_%d: Local global parameter [size:%d kb] updated.",
-                                                       _workerID, 
globalParams.getDataSize()));
-                                       }
-                               }
+                               // Accumulate the intermediate gradients
+                               accGradients = (accGradients==null) ?
+                                       ParamservUtils.copyList(gradients) :
+                                       accrueGradients(accGradients, 
gradients);
+
+                               // Update the local model with gradients
+                               if( j < totalIter - 1 )
+                                       params = updateModel(params, gradients, 
i, j, totalIter);
                        }
+
+                       // Push the gradients to ps
+                       pushGradients(accGradients);
+                       ParamservUtils.cleanupListObject(_ec, 
Statement.PS_MODEL);
+
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("Local worker_%d: 
Finished %d epoch.", _workerID, i + 1));
                        }
@@ -92,6 +104,22 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
        }
 
+       private ListObject updateModel(ListObject globalParams, ListObject 
gradients, int i, int j, int totalIter) {
+               Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
+
+               globalParams = _ps.updateModel(gradients, globalParams);
+
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSLocalModelUpdateTime((long) 
tUpd.stop());
+               
+               if (LOG.isDebugEnabled()) {
+                       LOG.debug(String.format("Local worker_%d: Local global 
parameter [size:%d kb] updated. "
+                               + "[Epoch:%d  Total epoch:%d  Iteration:%d  
Total iteration:%d]",
+                               _workerID, globalParams.getDataSize(), i + 1, 
_epochs, j + 1, totalIter));
+               }
+               return globalParams;
+       }
+
        private void computeBatch(long dataSize, int totalIter) {
                for (int i = 0; i < _epochs; i++) {
                        for (int j = 0; j < totalIter; j++) {
@@ -103,7 +131,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                                // Push the gradients to ps
                                pushGradients(gradients);
 
-                               ParamservUtils.cleanupListObject(_ec, 
globalParams);
+                               ParamservUtils.cleanupListObject(_ec, 
Statement.PS_MODEL);
                        }
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("Local worker_%d: 
Finished %d epoch.", _workerID, i + 1));
@@ -135,8 +163,12 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                long end = Math.min((j + 1) * _batchSize, dataSize);
 
                // Get batch features and labels
+               Timing tSlic = DMLScript.STATISTICS ? new Timing(true) : null;
                MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, 
begin, end);
                MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, 
begin, end);
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSBatchIndexingTime((long) tSlic.stop());
+
                _ec.setVariable(Statement.PS_FEATURES, bFeatures);
                _ec.setVariable(Statement.PS_LABELS, bLabels);
 
@@ -148,7 +180,10 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                }
 
                // Invoke the update function
+               Timing tGrad = DMLScript.STATISTICS ? new Timing(true) : null;
                _inst.processInstruction(_ec);
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSGradientComputeTime((long) 
tGrad.stop());
 
                // Get the gradients
                ListObject gradients = (ListObject) 
_ec.getVariable(_output.getName());
@@ -157,4 +192,15 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                ParamservUtils.cleanupData(bLabels);
                return gradients;
        }
+
+       private ListObject accrueGradients(ListObject accGradients, ListObject 
gradients) {
+               IntStream.range(0, accGradients.getLength()).forEach(i -> {
+                       MatrixBlock mb1 = ((MatrixObject) 
accGradients.getData().get(i)).acquireRead();
+                       MatrixBlock mb2 = ((MatrixObject) 
gradients.getData().get(i)).acquireRead();
+                       mb1.binaryOperationsInPlace(new 
BinaryOperator(Plus.getPlusFnObject()), mb2);
+                       ((MatrixObject) 
accGradients.getData().get(i)).release();
+                       ((MatrixObject) gradients.getData().get(i)).release();
+               });
+               return accGradients;
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index d94831e..f46370d 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv;
 
+import static 
org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.UPDATE_FUNC_PREFIX;
+
 import java.util.ArrayList;
 import java.util.stream.Collectors;
 
@@ -71,7 +73,7 @@ public abstract class PSWorker {
                        funcNS = keys[0];
                        funcName = keys[1];
                }
-               FunctionProgramBlock func = 
ec.getProgram().getFunctionProgramBlock(funcNS, funcName);
+               FunctionProgramBlock func = 
ec.getProgram().getFunctionProgramBlock(funcNS, UPDATE_FUNC_PREFIX + _workerID 
+ "_" + funcName);
                ArrayList<DataIdentifier> inputs = func.getInputParams();
                ArrayList<DataIdentifier> outputs = func.getOutputParams();
                CPOperand[] boundInputs = inputs.stream()

http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
index 7a39bec..0ddfb40 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv;
 
+import static 
org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.AGG_FUNC_PREFIX;
+
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
@@ -37,6 +39,7 @@ import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.concurrent.BasicThreadFactory;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.parser.DMLProgram;
 import org.apache.sysml.parser.DataIdentifier;
 import org.apache.sysml.parser.Expression;
@@ -44,13 +47,16 @@ import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.utils.Statistics;
 
 public abstract class ParamServer {
-       
+
        final BlockingQueue<Gradient> _gradientsQueue;
        final Map<Integer, BlockingQueue<ListObject>> _modelMap;
        private final AggregationService _aggService;
@@ -73,8 +79,7 @@ public abstract class ParamServer {
                        throw new DMLRuntimeException("Param server: failed to 
broadcast the initial model.", e);
                }
                BasicThreadFactory factory = new BasicThreadFactory.Builder()
-                       .namingPattern("agg-service-pool-thread-%d")
-                       .build();
+                       .namingPattern("agg-service-pool-thread-%d").build();
                _es = Executors.newSingleThreadExecutor(factory);
        }
 
@@ -91,11 +96,17 @@ public abstract class ParamServer {
        }
 
        public ListObject getResult() {
+               // All the model updating work has terminated,
+               // so we could return directly the result model
                return _model;
        }
 
        public ListObject updateModel(ListObject gradients, ListObject model) {
-               return _aggService.updateModel(gradients, model);
+               //note: we use a new execution context to allow for concurrent 
execution of ASP local updates; 
+               //otherwise synchronized on the aggService instance would 
serialize those
+               ExecutionContext ec = 
ExecutionContextFactory.createContext(_aggService._ec.getProgram());
+               ec.setVariable(Statement.PS_HYPER_PARAMS, 
_aggService._ec.getVariable(Statement.PS_HYPER_PARAMS));
+               return _aggService.updateModel(ec, gradients, model);
        }
 
        public static class Gradient {
@@ -115,11 +126,11 @@ public abstract class ParamServer {
 
                protected final Log LOG = 
LogFactory.getLog(AggregationService.class.getName());
 
-               protected ExecutionContext _ec;
-               private Statement.PSUpdateType _updateType;
-               private FunctionCallCPInstruction _inst;
-               private DataIdentifier _output;
-               private boolean[] _finishedStates;  // Workers' finished states
+               protected final ExecutionContext _ec;
+               private final Statement.PSUpdateType _updateType;
+               private final FunctionCallCPInstruction _inst;
+               private final DataIdentifier _output;
+               private final boolean[] _finishedStates;  // Workers' finished 
states
 
                AggregationService(String aggFunc, Statement.PSUpdateType 
updateType, ExecutionContext ec, int workerNum) {
                        _ec = ec;
@@ -134,7 +145,7 @@ public abstract class ParamServer {
                                funcNS = keys[0];
                                funcName = keys[1];
                        }
-                       FunctionProgramBlock func = 
_ec.getProgram().getFunctionProgramBlock(funcNS, funcName);
+                       FunctionProgramBlock func = 
_ec.getProgram().getFunctionProgramBlock(funcNS, AGG_FUNC_PREFIX + funcName);
                        ArrayList<DataIdentifier> inputs = 
func.getInputParams();
                        ArrayList<DataIdentifier> outputs = 
func.getOutputParams();
 
@@ -170,14 +181,24 @@ public abstract class ParamServer {
                }
 
                private void broadcastModel() throws InterruptedException {
+                       Timing tBroad = DMLScript.STATISTICS ? new Timing(true) 
: null;
+
                        //broadcast copy of the model to all workers, cleaned 
up by workers
                        for (BlockingQueue<ListObject> q : _modelMap.values())
                                q.put(ParamservUtils.copyList(_model));
+
+                       if (DMLScript.STATISTICS)
+                               Statistics.accPSModelBroadcastTime((long) 
tBroad.stop());
                }
-               
+
                private void broadcastModel(int workerID) throws 
InterruptedException {
+                       Timing tBroad = DMLScript.STATISTICS ? new Timing(true) 
: null;
+
                        //broadcast copy of model to specific worker, cleaned 
up by worker
                        
_modelMap.get(workerID).put(ParamservUtils.copyList(_model));
+
+                       if (DMLScript.STATISTICS)
+                               Statistics.accPSModelBroadcastTime((long) 
tBroad.stop());
                }
 
                @Override
@@ -195,7 +216,10 @@ public abstract class ParamServer {
                                }
 
                                // Update and redistribute the model
+                               Timing tAgg = DMLScript.STATISTICS ? new 
Timing(true) : null;
                                _model = updateModel(grad._gradients, _model);
+                               if (DMLScript.STATISTICS)
+                                       Statistics.accPSAggregationTime((long) 
tAgg.stop());
 
                                // Redistribute model according to update type
                                switch(_updateType) {
@@ -231,19 +255,23 @@ public abstract class ParamServer {
                 * @return A updated list object of model
                 */
                private synchronized ListObject updateModel(ListObject 
gradients, ListObject model) {
+                       return updateModel(_ec, gradients, model);
+               }
+               
+               private ListObject updateModel(ExecutionContext ec, ListObject 
gradients, ListObject model) {
                        // Populate the variables table with the gradients and 
model
-                       _ec.setVariable(Statement.PS_GRADIENTS, gradients);
-                       _ec.setVariable(Statement.PS_MODEL, model);
+                       ec.setVariable(Statement.PS_GRADIENTS, gradients);
+                       ec.setVariable(Statement.PS_MODEL, model);
 
                        // Invoke the aggregate function
-                       _inst.processInstruction(_ec);
+                       _inst.processInstruction(ec);
 
                        // Get the output
-                       ListObject newModel = (ListObject) 
_ec.getVariable(_output.getName());
+                       ListObject newModel = (ListObject) 
ec.getVariable(_output.getName());
 
                        // Update the model with the new output
-                       ParamservUtils.cleanupListObject(_ec, model);
-                       ParamservUtils.cleanupListObject(_ec, gradients);
+                       ParamservUtils.cleanupListObject(ec, 
Statement.PS_MODEL);
+                       ParamservUtils.cleanupListObject(ec, 
Statement.PS_GRADIENTS);
                        return newModel;
                }
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index 2b1dca5..c4a3d98 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -19,18 +19,35 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv;
 
+import java.io.IOException;
+import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.MultiThreadedHop;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.recompile.Recompiler;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.DMLTranslator;
 import org.apache.sysml.parser.Expression;
+import org.apache.sysml.parser.StatementBlock;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
+import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
+import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
+import org.apache.sysml.runtime.controlprogram.Program;
+import org.apache.sysml.runtime.controlprogram.ProgramBlock;
+import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
 import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -41,6 +58,9 @@ import org.apache.sysml.runtime.matrix.data.OutputInfo;
 
 public class ParamservUtils {
 
+       public static final String UPDATE_FUNC_PREFIX = "_worker_";
+       public static final String AGG_FUNC_PREFIX = "_agg_";
+
        /**
         * Deep copy the list object
         *
@@ -65,8 +85,8 @@ public class ParamservUtils {
                return new ListObject(newData, lo.getNames());
        }
 
-       public static void cleanupListObject(ExecutionContext ec, ListObject 
lo) {
-               ec.getVariables().removeAllIn(new HashSet<>(lo.getNames()));
+       public static void cleanupListObject(ExecutionContext ec, String lName) 
{
+               ListObject lo = (ListObject) ec.removeVariable(lName);
                lo.getData().forEach(ParamservUtils::cleanupData);
        }
 
@@ -110,4 +130,112 @@ public class ParamservUtils {
                seq.ctableOperations(null, sample, 1.0, permutation);
                return permutation;
        }
+
+       public static ExecutionContext createExecutionContext(ExecutionContext 
ec, String updFunc, String aggFunc, int workerNum, int k) {
+               FunctionProgramBlock updPB = getFunctionBlock(ec, updFunc);
+               FunctionProgramBlock aggPB = getFunctionBlock(ec, aggFunc);
+
+               Program prog = ec.getProgram();
+
+               // 1. Recompile the internal program blocks
+               recompileProgramBlocks(k, prog.getProgramBlocks());
+               // 2. Recompile the imported function blocks
+               prog.getFunctionProgramBlocks().forEach((fname, fvalue) -> 
recompileProgramBlocks(k, fvalue.getChildBlocks()));
+
+               // Copy function for workers
+               IntStream.range(0, workerNum).forEach(i -> 
copyFunction(updFunc, updPB, prog, UPDATE_FUNC_PREFIX + i + "_"));
+
+               // Copy function for agg service
+               copyFunction(aggFunc, aggPB, prog, AGG_FUNC_PREFIX);
+
+               return ExecutionContextFactory.createContext(prog);
+       }
+
+       private static void copyFunction(String funcName, FunctionProgramBlock 
updPB, Program prog, String prefix) {
+               String[] keys = DMLProgram.splitFunctionKey(funcName);
+               String namespace = null;
+               String func = keys[0];
+               if (keys.length == 2) {
+                       namespace = keys[0];
+                       func = keys[1];
+               }
+               FunctionProgramBlock copiedFunc = ProgramConverter
+                       .createDeepCopyFunctionProgramBlock(updPB, new 
HashSet<>(), new HashSet<>());
+               String fnameNew = prefix + func;
+               prog.addFunctionProgramBlock(namespace, fnameNew, copiedFunc);
+       }
+
+       private static void recompileProgramBlocks(int k, 
ArrayList<ProgramBlock> pbs) {
+               // Reset the visit status from root
+               for (ProgramBlock pb : pbs)
+                       
DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock());
+
+               // Should recursively assign the level of parallelism
+               // and recompile the program block
+               try {
+                       rAssignParallelism(pbs, k, false);
+               } catch (IOException e) {
+                       throw new DMLRuntimeException(e);
+               }
+       }
+
+       private static boolean rAssignParallelism(ArrayList<ProgramBlock> pbs, 
int k, boolean recompiled) throws IOException {
+               for (ProgramBlock pb : pbs) {
+                       if (pb instanceof ParForProgramBlock) {
+                               ParForProgramBlock pfpb = (ParForProgramBlock) 
pb;
+                               pfpb.setDegreeOfParallelism(k);
+                               recompiled |= 
rAssignParallelism(pfpb.getChildBlocks(), 1, recompiled);
+                       } else if (pb instanceof ForProgramBlock) {
+                               recompiled |= 
rAssignParallelism(((ForProgramBlock) pb).getChildBlocks(), k, recompiled);
+                       } else if (pb instanceof WhileProgramBlock) {
+                               recompiled |= 
rAssignParallelism(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled);
+                       } else if (pb instanceof FunctionProgramBlock) {
+                               recompiled |= 
rAssignParallelism(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled);
+                       } else if (pb instanceof IfProgramBlock) {
+                               IfProgramBlock ipb = (IfProgramBlock) pb;
+                               recompiled |= 
rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled);
+                               if (ipb.getChildBlocksElseBody() != null)
+                                       recompiled |= 
rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled);
+                       } else {
+                               StatementBlock sb = pb.getStatementBlock();
+                               for (Hop hop : sb.getHops())
+                                       recompiled |= rAssignParallelism(hop, 
k, recompiled);
+                       }
+                       // Recompile the program block
+                       if (recompiled) {
+                               
Recompiler.recompileProgramBlockInstructions(pb);
+                       }
+               }
+               return recompiled;
+       }
+
+       private static boolean rAssignParallelism(Hop hop, int k, boolean 
recompiled) {
+               if (hop.isVisited()) {
+                       return recompiled;
+               }
+               if (hop instanceof MultiThreadedHop) {
+                       // Reassign the level of parallelism
+                       MultiThreadedHop mhop = (MultiThreadedHop) hop;
+                       mhop.setMaxNumThreads(k);
+                       recompiled = true;
+               }
+               ArrayList<Hop> inputs = hop.getInput();
+               for (Hop h : inputs) {
+                       recompiled |= rAssignParallelism(h, k, recompiled);
+               }
+               hop.setVisited();
+               return recompiled;
+       }
+
+
+       private static FunctionProgramBlock getFunctionBlock(ExecutionContext 
ec, String funcName) {
+               String[] keys = DMLProgram.splitFunctionKey(funcName);
+               String namespace = null;
+               String func = keys[0];
+               if (keys.length == 2) {
+                       namespace = keys[0];
+                       func = keys[1];
+               }
+               return ec.getProgram().getFunctionProgramBlock(namespace, func);
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 65a1930..d2c0edd 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -39,9 +39,6 @@ import static org.apache.sysml.parser.Statement.PS_UPDATE_TYPE;
 import static org.apache.sysml.parser.Statement.PS_VAL_FEATURES;
 import static org.apache.sysml.parser.Statement.PS_VAL_LABELS;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.concurrent.ExecutionException;
@@ -56,21 +53,10 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
-import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.MultiThreadedHop;
-import org.apache.sysml.hops.recompile.Recompiler;
-import org.apache.sysml.parser.DMLProgram;
-import org.apache.sysml.parser.DMLTranslator;
-import org.apache.sysml.parser.StatementBlock;
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
-import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
-import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
 import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
-import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
 import org.apache.sysml.runtime.controlprogram.Program;
-import org.apache.sysml.runtime.controlprogram.ProgramBlock;
-import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
@@ -82,9 +68,11 @@ import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerOR;
 import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
 import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
-import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
 import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.matrix.operators.Operator;
+import org.apache.sysml.utils.Statistics;
 
 public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruction {
 
@@ -111,21 +99,26 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
 
        @Override
        public void processInstruction(ExecutionContext ec) {
+               Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
+
                PSModeType mode = getPSMode();
                int workerNum = getWorkerNum(mode);
                BasicThreadFactory factory = new BasicThreadFactory.Builder()
-                       .namingPattern("workers-pool-thread-%d")
-                       .build();
+                       .namingPattern("workers-pool-thread-%d").build();
                ExecutorService es = Executors.newFixedThreadPool(workerNum, 
factory);
                String updFunc = getParam(PS_UPDATE_FUN);
                String aggFunc = getParam(PS_AGGREGATION_FUN);
 
-               // Create the workers' execution context
                int k = getParLevel(workerNum);
-               List<ExecutionContext> workerECs = createExecutionContext(ec, 
updFunc, workerNum, k);
+
+               // Get the compiled execution context
+               ExecutionContext newEC = 
ParamservUtils.createExecutionContext(ec, updFunc, aggFunc, workerNum, k);
+
+               // Create workers' execution context
+               List<ExecutionContext> workerECs = 
createExecutionContext(workerNum, ec, newEC.getProgram());
 
                // Create the agg service's execution context
-               ExecutionContext aggServiceEC = createExecutionContext(ec, 
aggFunc, 1, 1).get(0);
+               ExecutionContext aggServiceEC = createExecutionContext(1, ec, 
newEC.getProgram()).get(0);
 
                PSFrequency freq = getFrequency();
                PSUpdateType updateType = getUpdateType();
@@ -145,6 +138,9 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                // Do data partition
                PSScheme scheme = getScheme();
                doDataPartitioning(scheme, ec, workers);
+               
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSSetupTime((long) tSetup.stop());
 
                if (LOG.isDebugEnabled()) {
                        LOG.debug(String.format("\nConfiguration of paramserv 
func: "
@@ -169,6 +165,18 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                }
        }
 
+       private List<ExecutionContext> createExecutionContext(int size, 
ExecutionContext ec, Program program) {
+               return IntStream.range(0, size).mapToObj(i -> {
+                       // Put the hyperparam into the variables table
+                       LocalVariableMap varsMap = new LocalVariableMap();
+                       ListObject hyperParams = getHyperParams(ec);
+                       if (hyperParams != null) {
+                               varsMap.put(PS_HYPER_PARAMS, hyperParams);
+                       }
+                       return ExecutionContextFactory.createContext(varsMap, 
program);
+               }).collect(Collectors.toList());
+       }
+
        private PSModeType getPSMode() {
                PSModeType mode;
                try {
@@ -194,106 +202,6 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                return 
Math.max((int)Math.ceil((double)getRemainingCores()/workerNum), 1);
        }
 
-       private List<ExecutionContext> createExecutionContext(ExecutionContext 
ec, String funcName, int workerNum, int k) {
-               // Fetch the target function
-               String[] keys = DMLProgram.splitFunctionKey(funcName);
-               String namespace = null;
-               String func = keys[0];
-               if (keys.length == 2) {
-                       namespace = keys[0];
-                       func = keys[1];
-               }
-               return createExecutionContext(ec, namespace, func, workerNum, 
k);
-       }
-
-       private List<ExecutionContext> createExecutionContext(ExecutionContext 
ec, String namespace, String func,
-                       int workerNum, int k) {
-               FunctionProgramBlock targetFunc = 
ec.getProgram().getFunctionProgramBlock(namespace, func);
-               return IntStream.range(0, workerNum).mapToObj(i -> {
-                       // Put the hyperparam into the variables table
-                       LocalVariableMap varsMap = new LocalVariableMap();
-                       ListObject hyperParams = getHyperParams(ec);
-                       if (hyperParams != null) {
-                               varsMap.put(PS_HYPER_PARAMS, hyperParams);
-                       }
-
-                       // Deep copy the target func
-                       FunctionProgramBlock copiedFunc = ProgramConverter
-                               .createDeepCopyFunctionProgramBlock(targetFunc, 
new HashSet<>(), new HashSet<>());
-
-                       // Reset the visit status from root
-                       for( ProgramBlock pb : copiedFunc.getChildBlocks() )
-                               
DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock());
-
-                       // Should recursively assign the level of parallelism
-                       // and recompile the program block
-                       try {
-                               rAssignParallelism(copiedFunc.getChildBlocks(), 
k, false);
-                       } catch (IOException e) {
-                               throw new DMLRuntimeException(e);
-                       }
-
-                       Program prog = new Program();
-                       prog.addProgramBlock(copiedFunc);
-                       prog.addFunctionProgramBlock(namespace, func, 
copiedFunc);
-                       return ExecutionContextFactory.createContext(varsMap, 
prog);
-
-               }).collect(Collectors.toList());
-       }
-
-       private boolean rAssignParallelism(ArrayList<ProgramBlock> pbs, int k, 
boolean recompiled) throws IOException {
-               for (ProgramBlock pb : pbs) {
-                       if (pb instanceof ParForProgramBlock) {
-                               ParForProgramBlock pfpb = (ParForProgramBlock) 
pb;
-                               pfpb.setDegreeOfParallelism(k);
-                               recompiled |= 
rAssignParallelism(pfpb.getChildBlocks(), 1, recompiled);
-                       }
-                       else if (pb instanceof ForProgramBlock) {
-                               recompiled |= 
rAssignParallelism(((ForProgramBlock) pb).getChildBlocks(), k, recompiled);
-                       }
-                       else if (pb instanceof WhileProgramBlock) {
-                               recompiled |= 
rAssignParallelism(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled);
-                       }
-                       else if (pb instanceof FunctionProgramBlock) {
-                               recompiled |= 
rAssignParallelism(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled);
-                       }
-                       else if (pb instanceof IfProgramBlock) {
-                               IfProgramBlock ipb = (IfProgramBlock) pb;
-                               recompiled |= 
rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled);
-                               if (ipb.getChildBlocksElseBody() != null)
-                                       recompiled |= 
rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled);
-                       }
-                       else {
-                               StatementBlock sb = pb.getStatementBlock();
-                               for (Hop hop : sb.getHops())
-                                       recompiled |= rAssignParallelism(hop, 
k, recompiled);
-                       }
-                       // Recompile the program block
-                       if (recompiled) {
-                               
Recompiler.recompileProgramBlockInstructions(pb);
-                       }
-               }
-               return recompiled;
-       }
-
-       private boolean rAssignParallelism(Hop hop, int k, boolean recompiled) {
-               if (hop.isVisited()) {
-                       return recompiled;
-               }
-               if (hop instanceof MultiThreadedHop) {
-                       // Reassign the level of parallelism
-                       MultiThreadedHop mhop = (MultiThreadedHop) hop;
-                       mhop.setMaxNumThreads(k);
-                       recompiled = true;
-               }
-               ArrayList<Hop> inputs = hop.getInput();
-               for (Hop h : inputs) {
-                       recompiled |= rAssignParallelism(h, k, recompiled);
-               }
-               hop.setVisited();
-               return recompiled;
-       }
-
        private PSUpdateType getUpdateType() {
                PSUpdateType updType;
                try {
@@ -310,13 +218,12 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                if (!getParameterMap().containsKey(PS_FREQUENCY)) {
                        return DEFAULT_UPDATE_FREQUENCY;
                }
-               PSFrequency freq;
                try {
-                       freq = PSFrequency.valueOf(getParam(PS_FREQUENCY));
+                       return PSFrequency.valueOf(getParam(PS_FREQUENCY));
                } catch (IllegalArgumentException e) {
-                       throw new DMLRuntimeException(String.format("Paramserv 
function: not support '%s' update frequency.", getParam(PS_FREQUENCY)));
+                       throw new DMLRuntimeException(String.format("Paramserv 
function: "
+                               + "not support '%s' update frequency.", 
getParam(PS_FREQUENCY)));
                }
-               return freq;
        }
 
        private int getRemainingCores() {
@@ -330,19 +237,16 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
         * @return worker numbers
         */
        private int getWorkerNum(PSModeType mode) {
-               int workerNum = -1;
                switch (mode) {
                        case LOCAL:
                                // default worker number: available cores - 1 
(assign one process for agg service)
-                               workerNum = getRemainingCores();
-                               if 
(getParameterMap().containsKey(PS_PARALLELISM)) {
-                                       workerNum = Math.min(workerNum, 
Integer.valueOf(getParam(PS_PARALLELISM)));
-                               }
-                               break;
-                       case REMOTE_SPARK:
-                               throw new DMLRuntimeException("Do not support 
remote spark.");
+                               int workerNum = getRemainingCores();
+                               if 
(getParameterMap().containsKey(PS_PARALLELISM))
+                                       workerNum = 
Integer.valueOf(getParam(PS_PARALLELISM));
+                               return workerNum;
+                       default:
+                               throw new DMLRuntimeException("Unsupported 
parameter server: "+mode.name());
                }
-               return workerNum;
        }
 
        /**
@@ -351,15 +255,12 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
         * @return parameter server
         */
        private ParamServer createPS(PSModeType mode, String aggFunc, 
PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) {
-               ParamServer ps = null;
                switch (mode) {
                        case LOCAL:
-                               ps = new LocalParamServer(model, aggFunc, 
updateType, ec, workerNum);
-                               break;
-                       case REMOTE_SPARK:
-                               throw new DMLRuntimeException("Do not support 
remote spark.");
+                               return new LocalParamServer(model, aggFunc, 
updateType, ec, workerNum);
+                       default:
+                               throw new DMLRuntimeException("Unsupported 
parameter server: "+mode.name());
                }
-               return ps;
        }
 
        private long getBatchSize() {

http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java 
b/src/main/java/org/apache/sysml/utils/Statistics.java
index 8cf3f02..0618b38 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -107,6 +107,15 @@ public class Statistics
        private static final LongAdder sparkBroadcast = new LongAdder();
        private static final LongAdder sparkBroadcastCount = new LongAdder();
 
+       // Paramserv function stats (time is in milli sec)
+       private static final LongAdder psNumWorkers = new LongAdder();
+       private static final LongAdder psSetupTime = new LongAdder();
+       private static final LongAdder psGradientComputeTime = new LongAdder();
+       private static final LongAdder psAggregationTime = new LongAdder();
+       private static final LongAdder psLocalModelUpdateTime = new LongAdder();
+       private static final LongAdder psModelBroadcastTime = new LongAdder();
+       private static final LongAdder psBatchIndexTime = new LongAdder();
+
        //PARFOR optimization stats (low frequency updates)
        private static long parforOptTime = 0; //in milli sec
        private static long parforOptCount = 0; //count
@@ -516,7 +525,34 @@ public class Statistics
        public static void incSparkBroadcastCount(long c) {
                sparkBroadcastCount.add(c);
        }
-       
+
+       public static void incWorkerNumber() {
+               psNumWorkers.increment();
+       }
+
+       public static void accPSSetupTime(long t) {
+               psSetupTime.add(t);
+       }
+
+       public static void accPSGradientComputeTime(long t) {
+               psGradientComputeTime.add(t);
+       }
+
+       public static void accPSAggregationTime(long t) {
+               psAggregationTime.add(t);
+       }
+
+       public static void accPSLocalModelUpdateTime(long t) {
+               psLocalModelUpdateTime.add(t);
+       }
+
+       public static void accPSModelBroadcastTime(long t) {
+               psModelBroadcastTime.add(t);
+       }
+
+       public static void accPSBatchIndexingTime(long t) {
+               psBatchIndexTime.add(t);
+       }
        
        public static String getCPHeavyHitterCode( Instruction inst )
        {
@@ -850,6 +886,15 @@ public class Statistics
                                                                 
((double)sparkBroadcast.longValue())*1e-9,
                                                                 
((double)sparkCollect.longValue())*1e-9));
                        }
+                       if (psNumWorkers.longValue() > 0) {
+                               sb.append(String.format("Paramserv total num 
workers:\t%d.\n", psNumWorkers.longValue()));
+                               sb.append(String.format("Paramserv setup 
time:\t\t%.3f secs.\n", psSetupTime.doubleValue() / 1000));
+                               sb.append(String.format("Paramserv grad compute 
time:\t%.3f secs.\n", psGradientComputeTime.doubleValue() / 1000));
+                               sb.append(String.format("Paramserv model update 
time:\t%.3f/%.3f secs.\n",
+                                       psLocalModelUpdateTime.doubleValue() / 
1000, psAggregationTime.doubleValue() / 1000));
+                               sb.append(String.format("Paramserv model 
broadcast time:\t%.3f secs.\n", psModelBroadcastTime.doubleValue() / 1000));
+                               sb.append(String.format("Paramserv batch slice 
time:\t%.3f secs.\n", psBatchIndexTime.doubleValue() / 1000));
+                       }
                        if( parforOptCount>0 ){
                                sb.append("ParFor loops optimized:\t\t" + 
getParforOptCount() + ".\n");
                                sb.append("ParFor optimize time:\t\t" + 
String.format("%.3f", ((double)getParforOptTime())/1000) + " sec.\n");  

Reply via email to