This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 3f8ff2a  [SYSTEMDS-2918] Improved IPA/recompile for parameter server 
functions
3f8ff2a is described below

commit 3f8ff2a87bdd061f9c1dcc4f137ac51fe595cb54
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Mar 27 00:58:56 2021 +0100

    [SYSTEMDS-2918] Improved IPA/recompile for parameter server functions
    
    This patch addresses shortcomings of compiling parameter server
    functions for gradients, model updates, and validation. For correctness,
    the paramserv builtin function was treated as a second-order function
    and thus called unoptimized functions because the used functions were
    invisible to inter-procedural analysis (IPA). Now, whenever the
    functions are given as literal strings, we expose these functions with
    pseudo function calls in the function call graph, and thus the full IPA
    and its passes apply. In order to make use of that with functions
    that wrap all model matrices in lists, we further extended the core
    recompiler to better handle lists and infer sizes if possible.
    
    On a scenario of running a moderately complex CNN architecture with a
    parameter server of 1 worker, 2 epochs, and batch size 128 on MNIST,
    this patch improved end-to-end performance from 174s to 143s, which is
    very close to the manually inlined nested for loop without the parameter
    server (1 worker, multi-threaded operations).
---
 .../java/org/apache/sysds/hops/FunctionOp.java     |  13 ++-
 .../apache/sysds/hops/ParameterizedBuiltinOp.java  |  39 +++++++
 .../apache/sysds/hops/ipa/FunctionCallGraph.java   | 120 +++++++++++++--------
 .../sysds/hops/ipa/IPAPassInlineFunctions.java     |  16 ++-
 .../sysds/hops/ipa/InterProceduralAnalysis.java    |  20 +++-
 .../apache/sysds/hops/recompile/Recompiler.java    |  52 +++++++--
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java |  22 +++-
 .../paramserv/FederatedPSControlThread.java        |  20 ++--
 .../runtime/controlprogram/paramserv/PSWorker.java |   5 +-
 .../controlprogram/paramserv/ParamServer.java      |  10 +-
 .../controlprogram/paramserv/ParamservUtils.java   |   8 +-
 .../cp/ParamservBuiltinCPInstruction.java          |   2 +-
 src/main/java/org/apache/sysds/utils/Explain.java  |   2 +-
 13 files changed, 245 insertions(+), 84 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/FunctionOp.java 
b/src/main/java/org/apache/sysds/hops/FunctionOp.java
index 62fdeb5..1b6c2fc 100644
--- a/src/main/java/org/apache/sysds/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysds/hops/FunctionOp.java
@@ -54,6 +54,7 @@ public class FunctionOp extends Hop
        private String _fnamespace = null;
        private String _fname = null;
        private boolean _opt = true; //call to optimized/unoptimized
+       private boolean _pseudo = false;
        
        private String[] _inputNames = null;  // A,B in C = foo(A=X, B=Y)
        private String[] _outputNames = null; // C in C = foo(A=X, B=Y)
@@ -67,8 +68,11 @@ public class FunctionOp extends Hop
                this(type, fnamespace, fname, inputNames, inputs, outputNames, 
false);
                _outputHops = outputHops;
        }
-
-       public FunctionOp(FunctionType type, String fnamespace, String fname, 
String[] inputNames, List<Hop> inputs, String[] outputNames, boolean singleOut) 
+       public FunctionOp(FunctionType type, String fnamespace, String fname, 
String[] inputNames, List<Hop> inputs, String[] outputNames, boolean singleOut) 
{
+               this(type, fnamespace, fname, inputNames, inputs, outputNames, 
singleOut, false);
+       }
+       
+       public FunctionOp(FunctionType type, String fnamespace, String fname, 
String[] inputNames, List<Hop> inputs, String[] outputNames, boolean singleOut, 
boolean pseudo) 
        {
                super(fnamespace + Program.KEY_DELIM + fname, DataType.UNKNOWN, 
ValueType.UNKNOWN );
                
@@ -77,6 +81,7 @@ public class FunctionOp extends Hop
                _fname = fname;
                _inputNames = inputNames;
                _outputNames = outputNames;
+               _pseudo = pseudo;
                
                for( Hop in : inputs ) {
                        getInput().add(in);
@@ -137,6 +142,10 @@ public class FunctionOp extends Hop
        public void setCallOptimized(boolean opt) {
                _opt = opt;
        }
+       
+       public boolean isPseudoFunctionCall() {
+               return _pseudo;
+       }
 
        @Override
        public boolean allowsAllExecTypes() {
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java 
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 68128e0..cc51375 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -19,10 +19,14 @@
 
 package org.apache.sysds.hops;
 
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map.Entry;
 
+import org.apache.commons.lang3.ObjectUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.AggOp;
@@ -33,6 +37,7 @@ import org.apache.sysds.common.Types.OpOp2;
 import org.apache.sysds.common.Types.ParamBuiltinOp;
 import org.apache.sysds.common.Types.ReOrgOp;
 import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.FunctionOp.FunctionType;
 import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.lops.Data;
 import org.apache.sysds.lops.GroupedAggregate;
@@ -40,7 +45,9 @@ import org.apache.sysds.lops.GroupedAggregateM;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.lops.LopProperties.ExecType;
 import org.apache.sysds.lops.ParameterizedBuiltin;
+import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.Statement;
+import org.apache.sysds.runtime.instructions.cp.ParamservBuiltinCPInstruction;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.UtilFunctions;
@@ -967,6 +974,38 @@ public class ParameterizedBuiltinOp extends 
MultiThreadedHop {
                        && ((ReorgOp)targetHop).getOp()==ReOrgOp.DIAG 
                        && targetHop.getInput().get(0).getDim2() == 1 ); 
        }
+       
+       public List<FunctionOp> getParamservPseudoFunctionCalls() {
+               try {
+                       String supd[] = 
DMLProgram.splitFunctionKey(((LiteralOp)getParameterHop("upd")).getStringValue());
+                       String sagg[] = 
DMLProgram.splitFunctionKey(((LiteralOp)getParameterHop("agg")).getStringValue());
+                       String sval[] = getParameterHop("val") == null ? null :
+                               
DMLProgram.splitFunctionKey(((LiteralOp)getParameterHop("val")).getStringValue());
+                       Hop model = getParameterHop("model");
+                       Hop hyp = getParameterHop("hyperparams");
+                       Hop batch = 
ObjectUtils.defaultIfNull(getParameterHop("batchsize"),
+                               new 
LiteralOp(ParamservBuiltinCPInstruction.DEFAULT_BATCH_SIZE));
+                       Hop X = 
HopRewriteUtils.createIndexingOp(getParameterHop("features"), batch);
+                       Hop y = 
HopRewriteUtils.createIndexingOp(getParameterHop("labels"), batch);
+                       FunctionOp fupd = new FunctionOp(FunctionType.DML, 
supd[0], supd[1],
+                               new String[] 
{"model","hyperparams","features","labels"}, Arrays.asList(model, hyp, X, y),
+                               new String[] {"gradients"}, false, true); 
//pseudo fcall
+                       FunctionOp fagg = new FunctionOp(FunctionType.DML, 
sagg[0], sagg[1],
+                               new String[] 
{"model","hyperparams","gradients"}, Arrays.asList(model, hyp, fupd),
+                               new String[] {"model"}, false, true); //pseudo 
fcall
+                       FunctionOp fval = sval == null ? null : new 
FunctionOp(FunctionType.DML, sval[0], sval[1],
+                               new String[] 
{"model","hyperparams","valfeatures","vallabels"}, Arrays.asList(model,
+                                       hyp, getParameterHop("val_features"), 
getParameterHop("val_labels")),
+                               new String[] {"loss","accuracy"}, false, true); 
//pseudo fcall
+                       return (sval == null) ? 
+                               Arrays.asList(fupd, fagg) : Arrays.asList(fupd, 
fagg, fval);
+               }
+               catch(Exception ex) {
+                       // silent error handling for robustness (e.g., wrong 
parameters)
+                       // later handled consistenty by the runtime instruction
+                       return Collections.emptyList();
+               }
+       }
 
        /**
         * This will check if there is sufficient memory locally (twice the 
size of second matrix, for original and sort data), and remotely (size of 
second matrix (sorted data)).  
diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java 
b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
index 394cd70..23bd233 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
@@ -32,9 +32,11 @@ import java.util.stream.Collectors;
 import org.apache.sysds.common.Types.OpOp1;
 import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.common.Types.OpOpN;
+import org.apache.sysds.common.Types.ParamBuiltinOp;
 import org.apache.sysds.hops.FunctionOp;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.HopsException;
+import org.apache.sysds.hops.ParameterizedBuiltinOp;
 import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.ForStatement;
@@ -409,58 +411,90 @@ public class FunctionCallGraph
                        if( hopsDAG == null || hopsDAG.isEmpty() ) 
                                return false; //nothing to do
 
-                       //function ops can only occur as root nodes of the dag
                        ret = 
HopRewriteUtils.containsSecondOrderBuiltin(hopsDAG);
+                       Hop.resetVisitStatus(hopsDAG);
                        for( Hop h : hopsDAG ) {
+                               //function ops can only occur as root nodes of 
the dag
                                if( h instanceof FunctionOp ) {
-                                       FunctionOp fop = (FunctionOp) h;
-                                       String lfkey = fop.getFunctionKey();
-                                       //keep all function operators
-                                       if( !_fCalls.containsKey(lfkey) ) {
-                                               _fCalls.put(lfkey, new 
ArrayList<>());
-                                               _fCallsSB.put(lfkey, new 
ArrayList<>());
-                                       }
-                                       _fCalls.get(lfkey).add(fop);
-                                       _fCallsSB.get(lfkey).add(sb);
-                                       
-                                       //prevent redundant call edges
-                                       if( lfset.contains(lfkey) || 
fop.getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE) )
-                                               continue;
-                                       
-                                       if( !_fGraph.containsKey(lfkey) )
-                                               _fGraph.put(lfkey, new 
HashSet<String>());
-                                       
-                                       //recursively construct function call 
dag
-                                       if( !fstack.contains(lfkey) ) {
-                                               fstack.push(lfkey);
-                                               _fGraph.get(fkey).add(lfkey);
-                                               
-                                               FunctionStatementBlock fsb = 
sb.getDMLProg()
-                                                       
.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
-                                               FunctionStatement fs = 
(FunctionStatement) fsb.getStatement(0);
-                                               for( StatementBlock csb : 
fs.getBody() )
-                                                       ret |= 
rConstructFunctionCallGraph(lfkey, csb, fstack, new HashSet<String>());
-                                               fstack.pop();
-                                       }
-                                       //recursive function call
-                                       else {
-                                               _fGraph.get(fkey).add(lfkey);
-                                               _fRecursive.add(lfkey);
-                                       
-                                               //mark indirectly recursive 
functions as recursive
-                                               int ix = fstack.indexOf(lfkey);
-                                               for( int i=ix+1; 
i<fstack.size(); i++ )
-                                                       
_fRecursive.add(fstack.get(i));
-                                       }
-                                       
-                                       //mark as visited for current function 
call context
-                                       lfset.add( lfkey );
+                                       ret |= 
addFunctionOpToGraph((FunctionOp) h, fkey, sb, fstack, lfset);
                                }
+                               
+                               //recursive processing for paramserv functions
+                               rConstructFunctionCallGraph(h, fkey, sb, 
fstack, lfset);
                        }
                }
                
                return ret;
        }
+       
+       private boolean rConstructFunctionCallGraph(Hop hop, String fkey, 
StatementBlock sb, Stack<String> fstack, HashSet<String> lfset) {
+               boolean ret = false;
+               if( hop.isVisited() )
+                       return ret;
+               
+               //recursively process all child nodes
+               for( Hop h : hop.getInput() )
+                       rConstructFunctionCallGraph(h, fkey, sb, fstack, lfset);
+               
+               if( HopRewriteUtils.isParameterBuiltinOp(hop, 
ParamBuiltinOp.PARAMSERV)
+                       && HopRewriteUtils.knownParamservFunctions(hop))
+               {
+                       ParameterizedBuiltinOp pop = (ParameterizedBuiltinOp) 
hop;
+                       List<FunctionOp> fps = 
pop.getParamservPseudoFunctionCalls();
+                       //include artificial function ops into functional call 
graph
+                       for( FunctionOp fop : fps )
+                               ret |= addFunctionOpToGraph(fop, fkey, sb, 
fstack, lfset);
+               }
+               
+               hop.setVisited();
+               return ret;
+       }
+       
+       private boolean addFunctionOpToGraph(FunctionOp fop, String fkey, 
StatementBlock sb, Stack<String> fstack, HashSet<String> lfset) {
+               boolean ret = false;
+               String lfkey = fop.getFunctionKey();
+               //keep all function operators
+               if( !_fCalls.containsKey(lfkey) ) {
+                       _fCalls.put(lfkey, new ArrayList<>());
+                       _fCallsSB.put(lfkey, new ArrayList<>());
+               }
+               _fCalls.get(lfkey).add(fop);
+               _fCallsSB.get(lfkey).add(sb);
+               
+               //prevent redundant call edges
+               if( lfset.contains(lfkey) || 
fop.getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE) )
+                       return ret;
+               
+               if( !_fGraph.containsKey(lfkey) )
+                       _fGraph.put(lfkey, new HashSet<String>());
+               
+               //recursively construct function call dag
+               if( !fstack.contains(lfkey) ) {
+                       fstack.push(lfkey);
+                       _fGraph.get(fkey).add(lfkey);
+                       System.out.println(fkey+" -> "+lfkey);
+                       FunctionStatementBlock fsb = sb.getDMLProg()
+                               
.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
+                       FunctionStatement fs = (FunctionStatement) 
fsb.getStatement(0);
+                       for( StatementBlock csb : fs.getBody() )
+                               ret |= rConstructFunctionCallGraph(lfkey, csb, 
fstack, new HashSet<String>());
+                       fstack.pop();
+               }
+               //recursive function call
+               else {
+                       _fGraph.get(fkey).add(lfkey);
+                       _fRecursive.add(lfkey);
+               
+                       //mark indirectly recursive functions as recursive
+                       int ix = fstack.indexOf(lfkey);
+                       for( int i=ix+1; i<fstack.size(); i++ )
+                               _fRecursive.add(fstack.get(i));
+               }
+               
+               //mark as visited for current function call context
+               lfset.add( lfkey );
+               return ret;
+       }
 
        private boolean rAnalyzeSecondOrderCall(StatementBlock sb) {
                boolean ret = false;
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
index 3a465db..77ca7ba 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
@@ -58,11 +58,14 @@ public class IPAPassInlineFunctions extends IPAPass
                //NOTE: we inline single-statement-block (i.e., last-level 
block) functions
                //that do not contain other functions, and either are small or 
called once
                
+               boolean ret = false; //rebuild fgraph
                for( String fkey : fgraph.getReachableFunctions() ) {
                        FunctionStatementBlock fsb = 
prog.getFunctionStatementBlock(fkey);
                        FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
-                       if( fstmt.getBody().size() == 1 
-                               && 
HopRewriteUtils.isLastLevelStatementBlock(fstmt.getBody().get(0)) 
+                       if( fgraph.getFunctionCalls(fkey)==null )
+                               ret = true; //inlining might have remove 
paramserv-fcalls
+                       else if( fstmt.getBody().size() == 1 
+                               && 
HopRewriteUtils.isLastLevelStatementBlock(fstmt.getBody().get(0))
                                && 
!containsFunctionOp(fstmt.getBody().get(0).getHops())
                                && (fgraph.getFunctionCalls(fkey).size() == 1
                                        || 
countOperators(fstmt.getBody().get(0).getHops()) 
@@ -81,9 +84,10 @@ public class IPAPassInlineFunctions extends IPAPass
                                        if( LOG.isDebugEnabled() )
                                                LOG.debug("-- inline '"+fkey+"' 
at line "+op.getBeginLine());
                                        
-                                       //step 0: robustness for special cases
+                                       //step 0: robustness for special cases 
(named args, paramserv)
                                        if( op.getInput().size() != 
fstmt.getInputParams().size()
-                                               || 
op.getOutputVariableNames().length != fstmt.getOutputParams().size() ) {
+                                               || 
op.getOutputVariableNames().length != fstmt.getOutputParams().size()
+                                               || op.isPseudoFunctionCall() ) {
                                                removedAll = false;
                                                continue;
                                        }
@@ -130,10 +134,12 @@ public class IPAPassInlineFunctions extends IPAPass
                                        for( String fkeyTrans : fkeysTrans )
                                                if( 
!fgraph.isReachableFunction(fkeyTrans, true) )
                                                        
fgraph.removeFunctionCalls(fkeyTrans);
+                                       ret = true; //rebuild fgraph in next 
iteration
                                }
                        }
                }
-               return false;
+               
+               return ret;
        }
        
        private static boolean containsFunctionOp(ArrayList<Hop> hops) {
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java 
b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index e1c9a44..35e04f4 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -21,6 +21,8 @@ package org.apache.sysds.hops.ipa;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
@@ -52,6 +54,7 @@ import 
org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.utils.Explain;
 
 import java.util.ArrayList;
 import java.util.HashSet;
@@ -74,6 +77,7 @@ import java.util.Set;
  */
 public class InterProceduralAnalysis 
 {
+       private static final boolean LDEBUG = false; //internal local debug 
level
        private static final Log LOG = 
LogFactory.getLog(InterProceduralAnalysis.class.getName());
 
        //internal configuration parameters
@@ -101,6 +105,15 @@ public class InterProceduralAnalysis
        
        //set IPA passes to apply in order 
        private final ArrayList<IPAPass> _passes;
+
+       static {
+               // for internal debugging only
+               if( LDEBUG ) {
+                       Logger.getLogger("org.apache.sysds.hops.ipa")
+                               .setLevel(Level.TRACE);
+               }
+       }
+
        
        /**
         * Creates a handle for performing inter-procedural analysis
@@ -112,10 +125,14 @@ public class InterProceduralAnalysis
         * @param dmlp The DML program to analyze
         */
        public InterProceduralAnalysis(DMLProgram dmlp) {
-               //analyzes the function call graph 
+               //analyzes the function call graph
                _prog = dmlp;
                _sb = null;
                _fgraph = new FunctionCallGraph(dmlp);
+               if( LOG.isDebugEnabled() ) {
+                       LOG.debug("IPA: Initial FunctionCallGraph: \n--MAIN 
PROGRAM\n" + 
+                               Explain.explainFunctionCallGraph(_fgraph, new 
HashSet<String>(), null, 1));
+               }
                
                //create ordered list of IPA passes
                _passes = new ArrayList<>();
@@ -157,7 +174,6 @@ public class InterProceduralAnalysis
         * 
         * @param repetitions number of IPA rounds 
         */
-       @SuppressWarnings("null")
        public void analyzeProgram(int repetitions) {
                //sanity check for valid number of repetitions
                if( repetitions <= 0 )
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java 
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index a714266..583788d 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -79,6 +79,7 @@ import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.IntObject;
+import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -1303,7 +1304,13 @@ public class Recompiler
                                        FrameObject fo = (FrameObject) dat;
                                        d.setDim1(fo.getNumRows());
                                        d.setDim2(fo.getNumColumns());
-                               } else if( dat instanceof TensorObject) {
+                               }
+                               else if( dat instanceof ListObject ) {
+                                       ListObject lo = (ListObject) dat;
+                                       d.setDim1(lo.getLength());
+                                       d.setDim2(1);
+                               }
+                               else if( dat instanceof TensorObject) {
                                        TensorObject to = (TensorObject) dat;
                                        // TODO: correct dimensions
                                        d.setDim1(to.getNumRows());
@@ -1387,18 +1394,41 @@ public class Recompiler
                        }
                }
                //update size expression for indexing according to symbol table 
entries
-               else if( hop instanceof IndexingOp && 
hop.getDataType()!=DataType.LIST ) {
+               else if( hop instanceof IndexingOp ) {
                        hop.refreshSizeInformation(); //update, incl reset
                        if( !hop.dimsKnown() ) {
-                               HashMap<Long, Double> memo = new HashMap<>();
-                               double rl = 
Hop.computeBoundsInformation(hop.getInput().get(1), vars, memo);
-                               double ru = 
Hop.computeBoundsInformation(hop.getInput().get(2), vars, memo);
-                               double cl = 
Hop.computeBoundsInformation(hop.getInput().get(3), vars, memo);
-                               double cu = 
Hop.computeBoundsInformation(hop.getInput().get(4), vars, memo);
-                               if( rl!=Double.MAX_VALUE && 
ru!=Double.MAX_VALUE )
-                                       hop.setDim1( (long)(ru-rl+1) );
-                               if( cl!=Double.MAX_VALUE && 
cu!=Double.MAX_VALUE )
-                                       hop.setDim2( (long)(cu-cl+1) );
+                               if( hop.getDataType().isList() 
+                                       && hop.getInput().get(1).getValueType() 
== ValueType.STRING ) {
+                                       hop.setDim1(1);
+                                       hop.setDim2(1);
+                               }
+                               else {
+                                       HashMap<Long, Double> memo = new 
HashMap<>();
+                                       double rl = 
Hop.computeBoundsInformation(hop.getInput().get(1), vars, memo);
+                                       double ru = 
Hop.computeBoundsInformation(hop.getInput().get(2), vars, memo);
+                                       double cl = 
Hop.computeBoundsInformation(hop.getInput().get(3), vars, memo);
+                                       double cu = 
Hop.computeBoundsInformation(hop.getInput().get(4), vars, memo);
+                                       if( rl!=Double.MAX_VALUE && 
ru!=Double.MAX_VALUE )
+                                               hop.setDim1( (long)(ru-rl+1) );
+                                       if( cl!=Double.MAX_VALUE && 
cu!=Double.MAX_VALUE )
+                                               hop.setDim2( (long)(cu-cl+1) );
+                               }
+                       }
+               }
+               else if(HopRewriteUtils.isUnary(hop, OpOp1.CAST_AS_MATRIX)
+                       && hop.getInput(0) instanceof IndexingOp && 
hop.getInput(0).getDataType().isList()
+                       && HopRewriteUtils.isData(hop.getInput(0).getInput(0), 
OpOpData.TRANSIENTREAD) ) {
+                       ListObject list = (ListObject) 
vars.get(hop.getInput(0).getInput(0).getName());
+                       Hop rix = hop.getInput(0);
+                       if( list != null
+                               && rix.getInput(1) instanceof LiteralOp
+                               && rix.getInput(2) instanceof LiteralOp
+                               && 
HopRewriteUtils.isEqualValue(rix.getInput(1), rix.getInput(2))) {
+                               MatrixObject mo = (MatrixObject) 
((rix.getInput(1).getValueType() == ValueType.STRING) ? 
+                                       
list.getData(((LiteralOp)rix.getInput(1)).getStringValue()) :
+                                       
list.getData((int)HopRewriteUtils.getIntValueSafe(rix.getInput(1))-1));
+                               hop.setDim1(mo.getNumRows());
+                               hop.setDim2(mo.getNumColumns());
                        }
                }
                else {
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 391cc0d..0b00ffd 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -707,6 +707,13 @@ public class HopRewriteUtils
                return createUnary(ix, OpOp1.CAST_AS_SCALAR);
        }
        
+       public static IndexingOp createIndexingOp(Hop input, Hop batchsize) {
+               LiteralOp rl = new LiteralOp(1);
+               LiteralOp cl = new LiteralOp(1);
+               Hop cu = createUnary(input, OpOp1.NCOL);
+               return createIndexingOp(input, rl, batchsize, cl, cu);
+       }
+       
        public static IndexingOp createIndexingOp(Hop input, long rix, long 
cix) {
                LiteralOp row = new LiteralOp(rix);
                LiteralOp col = new LiteralOp(cix);
@@ -956,6 +963,10 @@ public class HopRewriteUtils
        public static boolean isSparse( Hop hop, double threshold ) {
                return hop.getSparsity() < threshold;
        }
+
+       public static boolean isEqualValue( Hop hop1, Hop hop2 ) {
+               return isEqualValue((LiteralOp)hop1, (LiteralOp)hop2);
+       }
        
        public static boolean isEqualValue( LiteralOp hop1, LiteralOp hop2 ) {
                //check for string (no defined double value)
@@ -1588,9 +1599,18 @@ public class HopRewriteUtils
                if( hop.isVisited() ) return false;
                hop.setVisited();
                return HopRewriteUtils.isNary(hop, OpOpN.EVAL)
-                       || HopRewriteUtils.isParameterBuiltinOp(hop, ParamBuiltinOp.PARAMSERV)
+                       || (HopRewriteUtils.isParameterBuiltinOp(hop, ParamBuiltinOp.PARAMSERV) 
+                               && !knownParamservFunctions(hop))
                        || hop.getInput().stream().anyMatch(c -> containsSecondOrderBuiltin(c));
        }
+       
+       public static boolean knownParamservFunctions(Hop hop) {
+               ParameterizedBuiltinOp pop = (ParameterizedBuiltinOp) hop;
+               return pop.getParameterHop("upd") instanceof LiteralOp
+                       && pop.getParameterHop("agg") instanceof LiteralOp
+                       && (pop.getParameterHop("val") == null 
+                        || pop.getParameterHop("val") instanceof LiteralOp);
+       }
 
        public static void setUnoptimizedFunctionCalls(StatementBlock sb) {
                if( sb instanceof FunctionStatementBlock ) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 10fee56..79ce52c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -151,9 +151,10 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                        pbs.add(aggProgramBlock);
                }
 
+               boolean opt = _ec.getProgram().getFunctionProgramBlocks(false).isEmpty();
                programSerialized = InstructionUtils.concatStrings(
                        PROG_BEGIN, NEWLINE,
-                       ProgramConverter.serializeProgram(_ec.getProgram(), pbs, new HashMap<>(), false),
+                       ProgramConverter.serializeProgram(_ec.getProgram(), pbs, new HashMap<>(), opt),
                        PROG_END);
 
                // write program and meta data to worker
@@ -467,11 +468,12 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                        long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue();
                        int possibleBatchesPerLocalEpoch = (int) ((IntObject) ec.getVariable(Statement.PS_FED_POSS_BATCHES_LOCAL)).getLongValue();
                        String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
-                       String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
-                       String aggregationFuctionName = ((StringObject) ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue();
+                       String gradientsFunc = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
+                       String aggFunc = ((StringObject) ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue();
 
                        // recreate gradient instruction and output
-                       FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunctionName, false);
+                       boolean opt = !ec.getProgram().containsFunctionProgramBlock(namespace, gradientsFunc, false);
+                       FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunc, opt);
                        ArrayList<DataIdentifier> inputs = func.getInputParams();
                        ArrayList<DataIdentifier> outputs = func.getOutputParams();
                        CPOperand[] boundInputs = inputs.stream()
@@ -479,15 +481,15 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                                .toArray(CPOperand[]::new);
                        ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
                                .collect(Collectors.toCollection(ArrayList::new));
-                       Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
-                               func.getInputParamNames(), outputNames, "gradient function");
+                       Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunc,
+                               opt, boundInputs,func.getInputParamNames(), outputNames, "gradient function");
                        DataIdentifier gradientsOutput = outputs.get(0);
 
                        // recreate aggregation instruction and output if needed
                        Instruction aggregationInstruction = null;
                        DataIdentifier aggregationOutput = null;
                        if(_localUpdate && _numBatchesToCompute > 1) {
-                               func = ec.getProgram().getFunctionProgramBlock(namespace, aggregationFuctionName, false);
+                               func = ec.getProgram().getFunctionProgramBlock(namespace, aggFunc, opt);
                                inputs = func.getInputParams();
                                outputs = func.getOutputParams();
                                boundInputs = inputs.stream()
@@ -495,8 +497,8 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                                        .toArray(CPOperand[]::new);
                                outputNames = outputs.stream().map(DataIdentifier::getName)
                                        .collect(Collectors.toCollection(ArrayList::new));
-                               aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs,
-                                       func.getInputParamNames(), outputNames, "aggregation function");
+                               aggregationInstruction = new FunctionCallCPInstruction(namespace, aggFunc,
+                                       opt, boundInputs, func.getInputParamNames(), outputNames, "aggregation function");
                                aggregationOutput = outputs.get(0);
                        }
 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index cc75e52..99ec9e2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -76,7 +76,8 @@ public abstract class PSWorker implements Serializable
                String[] cfn = DMLProgram.splitFunctionKey(updFunc);
                String ns = cfn[0];
                String fname = cfn[1];
-               FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname, false);
+               boolean opt = !ec.getProgram().containsFunctionProgramBlock(ns, fname, false);
+               FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname, opt);
                ArrayList<DataIdentifier> inputs = func.getInputParams();
                ArrayList<DataIdentifier> outputs = func.getOutputParams();
                CPOperand[] boundInputs = inputs.stream()
@@ -84,7 +85,7 @@ public abstract class PSWorker implements Serializable
                        .toArray(CPOperand[]::new);
                ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
                        .collect(Collectors.toCollection(ArrayList::new));
-               _inst = new FunctionCallCPInstruction(ns, fname, false, boundInputs,
+               _inst = new FunctionCallCPInstruction(ns, fname, opt, boundInputs,
                        func.getInputParamNames(), outputNames, "update function");
 
                // Check the inputs of the update function
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 96a08e3..9f5b126 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -112,7 +112,8 @@ public abstract class ParamServer
                String[] cfn = DMLProgram.splitFunctionKey(aggFunc);
                String ns = cfn[0];
                String fname = cfn[1];
-               FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname, false);
+               boolean opt = !ec.getProgram().containsFunctionProgramBlock(ns, fname, false);
+               FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname, opt);
                ArrayList<DataIdentifier> inputs = func.getInputParams();
                ArrayList<DataIdentifier> outputs = func.getOutputParams();
 
@@ -130,7 +131,7 @@ public abstract class ParamServer
                        .toArray(CPOperand[]::new);
                ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
                        .collect(Collectors.toCollection(ArrayList::new));
-               _inst = new FunctionCallCPInstruction(ns, fname, false, boundInputs,
+               _inst = new FunctionCallCPInstruction(ns, fname, opt, boundInputs,
                        func.getInputParamNames(), outputNames, "aggregate function");
        }
 
@@ -138,7 +139,8 @@ public abstract class ParamServer
                String[] cfn = DMLProgram.splitFunctionKey(valFunc);
                String ns = cfn[0];
                String fname = cfn[1];
-               FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname, false);
+               boolean opt = !ec.getProgram().containsFunctionProgramBlock(ns, fname, false);
+               FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname, opt);
                ArrayList<DataIdentifier> inputs = func.getInputParams();
                ArrayList<DataIdentifier> outputs = func.getOutputParams();
 
@@ -157,7 +159,7 @@ public abstract class ParamServer
                        .toArray(CPOperand[]::new);
                ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
                        .collect(Collectors.toCollection(ArrayList::new));
-               _valInst = new FunctionCallCPInstruction(ns, fname, false, boundInputs,
+               _valInst = new FunctionCallCPInstruction(ns, fname, opt, boundInputs,
                        func.getInputParamNames(), outputNames, "validate function");
 
                // write validation data to execution context. hyper params are already in ec
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index 51600d2..30cfb64 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -255,7 +255,8 @@ public class ParamservUtils {
                // 1. Recompile the internal program blocks 
                recompileProgramBlocks(k, prog.getProgramBlocks(), forceExecTypeCP);
                // 2. Recompile the imported function blocks
-               prog.getFunctionProgramBlocks(false)
+               boolean opt = prog.getFunctionProgramBlocks(false).isEmpty();
+               prog.getFunctionProgramBlocks(opt)
                        .forEach((fname, fvalue) -> recompileProgramBlocks(k, fvalue.getChildBlocks(), forceExecTypeCP));
 
                // 3. Copy all functions 
@@ -273,11 +274,12 @@ public class ParamservUtils {
        
        private static Program copyProgramFunctions(Program prog) {
                Program newProg = new Program(prog.getDMLProg());
-               for( Entry<String, FunctionProgramBlock> e : prog.getFunctionProgramBlocks(false).entrySet() ) {
+               boolean opt = prog.getFunctionProgramBlocks(false).isEmpty();
+               for( Entry<String, FunctionProgramBlock> e : prog.getFunctionProgramBlocks(opt).entrySet() ) {
                        String[] parts = DMLProgram.splitFunctionKey(e.getKey());
                        FunctionProgramBlock fpb = ProgramConverter
                                .createDeepCopyFunctionProgramBlock(e.getValue(), new HashSet<>(), new HashSet<>());
-                       newProg.addFunctionProgramBlock(parts[0], parts[1], fpb, false);
+                       newProg.addFunctionProgramBlock(parts[0], parts[1], fpb, opt);
                }
                return newProg;
        }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 4057f73..086eb54 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -90,7 +90,7 @@ import org.apache.sysds.utils.Statistics;
 public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction {
        private static final Log LOG = LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName());
        
-       private static final int DEFAULT_BATCH_SIZE = 64;
+       public static final int DEFAULT_BATCH_SIZE = 64;
        private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.EPOCH;
        private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS;
        private static final PSRuntimeBalancing DEFAULT_RUNTIME_BALANCING = PSRuntimeBalancing.NONE;
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java
index e1ea9d8..ae6a523 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -926,7 +926,7 @@ public class Explain
                }
        }
 
-       private static String explainFunctionCallGraph(FunctionCallGraph fgraph, HashSet<String> fstack, String fkey, int level)
+       public static String explainFunctionCallGraph(FunctionCallGraph fgraph, HashSet<String> fstack, String fkey, int level)
        {
                StringBuilder builder = new StringBuilder();
                String offset = createOffset(level);

Reply via email to