This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ac6e2a  [SYSTEMDS-2855] Rework function recompilation on entry (w/ 
rewrites)
7ac6e2a is described below

commit 7ac6e2a87ffb447dcf0f8064e979bcc2800cec38
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Feb 10 12:50:39 2021 +0100

    [SYSTEMDS-2855] Rework function recompilation on entry (w/ rewrites)
    
    This patch makes a major change to the recompilation of functions (i.e.,
    functions that have been marked during inter-procedural analysis for
    recompile-once). So far, we applied in-place recompilation to update
    size information, but without rewrites, in order to allow a reset for
    future function invocations. This caused inconsistencies where function
    recompilation did not apply the same rewrites as normal block
    recompilation (e.g., it failed to rewrite -t(X) %*% y to -(t(X) %*% y)
    and then to -t(t(y)%*%X)). Now we first apply the logic for size
    propagation (with a potential reset) and, if applicable (i.e., if no part
    of the function still requires recompilation), apply rewrites in a
    second pass. Furthermore, this patch cleans up the somewhat messy
    passing of recompilation configurations via a reworked RecompileStatus,
    and fixes the local size propagation of matrix multiplications to allow
    a reset with unknown sizes (to ensure correct results).
---
 .../java/org/apache/sysds/hops/AggBinaryOp.java    |  25 +--
 .../sysds/hops/recompile/RecompileStatus.java      |  53 ++++-
 .../apache/sysds/hops/recompile/Recompiler.java    | 214 +++++++++------------
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |  11 ++
 .../RewriteAlgebraicSimplificationDynamic.java     |   2 +-
 .../RewriteAlgebraicSimplificationStatic.java      |  11 +-
 .../org/apache/sysds/parser/ForStatementBlock.java |   4 +
 .../apache/sysds/parser/ParForStatementBlock.java  |   1 -
 .../controlprogram/FunctionProgramBlock.java       |   4 +-
 .../parfor/opt/OptimizationWrapper.java            |   4 +-
 .../fed/AggregateUnaryFEDInstruction.java          |   1 -
 .../federated/algorithms/FederatedLmPipeline.java  |  10 +-
 12 files changed, 188 insertions(+), 152 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 5dcc5ee..c279071 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -977,19 +977,19 @@ public class AggBinaryOp extends MultiThreadedHop
                
                //right side cached (no agg if left has just one column block)
                if(  method == MMultMethod.MAPMM_R && 
getInput().get(0).getDim2() >= 0 //known num columns
-                && getInput().get(0).getDim2() <= 
getInput().get(0).getBlocksize() ) 
-        {
-            ret = false;
-        }
-        
+                       && getInput().get(0).getDim2() <= 
getInput().get(0).getBlocksize() ) 
+               {
+                       ret = false;
+               }
+
                //left side cached (no agg if right has just one row block)
-        if(  method == MMultMethod.MAPMM_L && getInput().get(1).getDim1() >= 0 
//known num rows
-             && getInput().get(1).getDim1() <= 
getInput().get(1).getBlocksize() ) 
-        {
-                   ret = false;
-        }
-        
-        return ret;
+               if(  method == MMultMethod.MAPMM_L && 
getInput().get(1).getDim1() >= 0 //known num rows
+                       && getInput().get(1).getDim1() <= 
getInput().get(1).getBlocksize() ) 
+               {
+                       ret = false;
+               }
+
+               return ret;
        }
        
        /**
@@ -1274,6 +1274,7 @@ public class AggBinaryOp extends MultiThreadedHop
                if( isMatrixMultiply() ) {
                        setDim1(input1.getDim1());
                        setDim2(input2.getDim2());
+                       setNnz(-1); // for reset on recompile w/ unknowns 
                        if( input1.getNnz() == 0 || input2.getNnz() == 0 )
                                setNnz(0);
                }
diff --git a/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java 
b/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java
index bdb91fd..edb03d2 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java
@@ -19,21 +19,39 @@
 
 package org.apache.sysds.hops.recompile;
 
+import org.apache.sysds.hops.recompile.Recompiler.ResetType;
+import org.apache.sysds.runtime.controlprogram.ProgramBlock;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 
 import java.util.HashMap;
 
 public class RecompileStatus 
 {
+       //immutable flags for recompilation configurations
+       private final long _tid;               // thread-id, 0 if main thread
+       private final boolean _inplace;        // in-place recompilation, false 
for rewrites
+       private final ResetType _reset;        // reset type for program 
compilation
+       private final boolean _initialCodegen; // initial codegen compilation 
(no recompilation)
+       
+       //track if parts of recompiled program still require recompilation
+       private boolean _requiresRecompile = false;
+       
+       //collection of extracted statistics for control flow reconciliation
        private final HashMap<String, DataCharacteristics> _lastTWrites;
-       private final boolean _initialCodegen;
        
        public RecompileStatus() {
-               this(false);
+               this(0, true, ResetType.NO_RESET, false);
        }
        
        public RecompileStatus(boolean initialCodegen) {
+               this(0, true, ResetType.NO_RESET, initialCodegen);
+       }
+       
+       public RecompileStatus(long tid, boolean inplace, ResetType reset, 
boolean initialCodegen) {
                _lastTWrites = new HashMap<>();
+               _tid = tid;
+               _inplace = inplace;
+               _reset = reset;
                _initialCodegen = initialCodegen;
        }
        
@@ -41,13 +59,42 @@ public class RecompileStatus
                return _lastTWrites;
        }
        
+       public long getTID() {
+               return _tid;
+       }
+       
+       public boolean hasThreadID() {
+               return ProgramBlock.isThreadID(_tid);
+       }
+       
+       public boolean isInPlace() {
+               return _inplace;
+       }
+       
+       public boolean isReset() {
+               return _reset.isReset();
+       }
+       
+       public ResetType getReset() {
+               return _reset;
+       }
+       
        public boolean isInitialCodegen() {
                return _initialCodegen;
        }
+       
+       public void trackRecompile(boolean flag) {
+               _requiresRecompile |= flag;
+       }
+       
+       public boolean requiresRecompile() {
+               return _requiresRecompile;
+       }
 
        @Override
        public Object clone() {
-               RecompileStatus ret = new RecompileStatus();
+               RecompileStatus ret = new RecompileStatus(
+                       _tid, _inplace, _reset, _initialCodegen);
                ret._lastTWrites.putAll(_lastTWrites);
                return ret;
        }
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java 
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index 3b15b44..a714266 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -173,28 +173,6 @@ public class Recompiler
        {
                return recompileHopsDag(sb, hops, new ExecutionContext(vars), 
status, inplace, replaceLit, tid);
        }
-       
-       public static ArrayList<Instruction> recompileHopsDag( Hop hop, 
ExecutionContext ec, 
-                       RecompileStatus status, boolean inplace, boolean 
replaceLit, long tid ) 
-       {
-               ArrayList<Instruction> newInst = null;
-
-               //need for synchronization as we do temp changes in shared 
hops/lops
-               synchronized( hop ) {
-                       newInst = recompile(null, new 
ArrayList<>(Arrays.asList(hop)),
-                               ec, status, inplace, replaceLit, true, false, 
true, null, tid);
-               }
-               
-               // replace thread ids in new instructions
-               if( ProgramBlock.isThreadID(tid) ) //only in parfor context
-                       newInst = 
ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null, 
null, false, false);
-               
-               // explain recompiled instructions
-               if( DMLScript.EXPLAIN == ExplainType.RECOMPILE_RUNTIME )
-                       logExplainPred(hop, newInst);
-               
-               return newInst;
-       }
 
        public static ArrayList<Instruction> recompileHopsDag( Hop hop, 
LocalVariableMap vars, 
                        RecompileStatus status, boolean inplace, boolean 
replaceLit, long tid ) 
@@ -447,11 +425,25 @@ public class Recompiler
                        System.out.println("EXPLAIN RECOMPILE \nPRED (line 
"+hops.getBeginLine()+"):\n" + Explain.explain(inst,1));
        }
 
-       public static void recompileProgramBlockHierarchy( 
ArrayList<ProgramBlock> pbs, LocalVariableMap vars, long tid, ResetType 
resetRecompile ) {
-               RecompileStatus status = new RecompileStatus();
+       public static void recompileProgramBlockHierarchy( 
ArrayList<ProgramBlock> pbs, LocalVariableMap vars, long tid, boolean inplace, 
ResetType resetRecompile ) {
+               //function recompilation via two-phase approach due to 
challenges 
+               //of unclear reconciliation of arbitrary complex control flow
+               
+               // phase 1: normal inplace=true w/o rewrite as usual, but track 
requiresRecompile
+               // (preserve variables for potential second pass, otherwise 
corrupted stats)
+               RecompileStatus status1 = new RecompileStatus(tid, true, 
resetRecompile, false);
                synchronized( pbs ) {
                        for( ProgramBlock pb : pbs )
-                               rRecompileProgramBlock(pb, vars, status, tid, 
resetRecompile);
+                               rRecompileProgramBlock(pb, vars, status1);
+               
+                       // phase 2: if called with inplace-false, run a second 
in-place=false pass in
+                       // order to apply rewrites (at this point sizes are 
already propagated, but for
+                       // correctness we call it with an empty symbol table to 
avoid invalid size updates)
+                       if( !status1.requiresRecompile() && !inplace ) {
+                               RecompileStatus status2 = new 
RecompileStatus(tid, false, resetRecompile, false);
+                               for( ProgramBlock pb : pbs )
+                                       rRecompileProgramBlock(pb, new 
LocalVariableMap(), status2);
+                       }
                }
        }
        
@@ -635,27 +627,26 @@ public class Recompiler
        // private helper functions //
        //////////////////////////////
        
-       private static void rRecompileProgramBlock( ProgramBlock pb, 
LocalVariableMap vars, 
-               RecompileStatus status, long tid, ResetType resetRecompile ) 
+       private static void rRecompileProgramBlock( ProgramBlock pb, 
LocalVariableMap vars, RecompileStatus status )
        {
                if (pb instanceof WhileProgramBlock) {
                        WhileProgramBlock wpb = (WhileProgramBlock)pb;
                        WhileStatementBlock wsb = (WhileStatementBlock) 
wpb.getStatementBlock();
                        //recompile predicate
-                       recompileWhilePredicate(wpb, wsb, vars, status, tid, 
resetRecompile);
+                       recompileWhilePredicate(wpb, wsb, vars, status);
                        //remove updated scalars because in loop
                        removeUpdatedScalars(vars, wsb); 
                        //copy vars for later compare
                        LocalVariableMap oldVars = (LocalVariableMap) 
vars.clone();
                        RecompileStatus oldStatus = (RecompileStatus) 
status.clone();
                        for (ProgramBlock pb2 : wpb.getChildBlocks())
-                               rRecompileProgramBlock(pb2, vars, status, tid, 
resetRecompile);
+                               rRecompileProgramBlock(pb2, vars, status);
                        if( reconcileUpdatedCallVarsLoops(oldVars, vars, wsb) 
                                | reconcileUpdatedCallVarsLoops(oldStatus, 
status, wsb) ) {
                                //second pass with unknowns if required
-                               recompileWhilePredicate(wpb, wsb, vars, status, 
tid, resetRecompile);
+                               recompileWhilePredicate(wpb, wsb, vars, status);
                                for (ProgramBlock pb2 : wpb.getChildBlocks())
-                                       rRecompileProgramBlock(pb2, vars, 
status, tid, resetRecompile);
+                                       rRecompileProgramBlock(pb2, vars, 
status);
                        }
                        removeUpdatedScalars(vars, wsb);
                }
@@ -663,16 +654,16 @@ public class Recompiler
                        IfProgramBlock ipb = (IfProgramBlock)pb;
                        IfStatementBlock isb = 
(IfStatementBlock)ipb.getStatementBlock();
                        //recompile predicate
-                       recompileIfPredicate(ipb, isb, vars, status, tid, 
resetRecompile);
+                       recompileIfPredicate(ipb, isb, vars, status);
                        //copy vars for later compare
                        LocalVariableMap oldVars = (LocalVariableMap) 
vars.clone();
                        LocalVariableMap varsElse = (LocalVariableMap) 
vars.clone();
                        RecompileStatus oldStatus = 
(RecompileStatus)status.clone();
                        RecompileStatus statusElse = 
(RecompileStatus)status.clone();
                        for( ProgramBlock pb2 : ipb.getChildBlocksIfBody() )
-                               rRecompileProgramBlock(pb2, vars, status, tid, 
resetRecompile);
+                               rRecompileProgramBlock(pb2, vars, status);
                        for( ProgramBlock pb2 : ipb.getChildBlocksElseBody() )
-                               rRecompileProgramBlock(pb2, varsElse, 
statusElse, tid, resetRecompile);
+                               rRecompileProgramBlock(pb2, varsElse, 
statusElse);
                        reconcileUpdatedCallVarsIf(oldVars, vars, varsElse, 
isb);
                        reconcileUpdatedCallVarsIf(oldStatus, status, 
statusElse, isb);
                        removeUpdatedScalars(vars, ipb.getStatementBlock());
@@ -681,20 +672,20 @@ public class Recompiler
                        ForProgramBlock fpb = (ForProgramBlock)pb;
                        ForStatementBlock fsb = (ForStatementBlock) 
fpb.getStatementBlock();
                        //recompile predicates
-                       recompileForPredicates(fpb, fsb, vars, status, tid, 
resetRecompile);
+                       recompileForPredicates(fpb, fsb, vars, status);
                        //remove updated scalars because in loop
                        removeUpdatedScalars(vars, fpb.getStatementBlock());
                        //copy vars for later compare
                        LocalVariableMap oldVars = (LocalVariableMap) 
vars.clone();
                        RecompileStatus oldStatus = (RecompileStatus) 
status.clone();
                        for( ProgramBlock pb2 : fpb.getChildBlocks() )
-                               rRecompileProgramBlock(pb2, vars, status, tid, 
resetRecompile);
+                               rRecompileProgramBlock(pb2, vars, status);
                        if( reconcileUpdatedCallVarsLoops(oldVars, vars, fsb) 
                                | reconcileUpdatedCallVarsLoops(oldStatus, 
status, fsb)) {
                                //second pass with unknowns if required
-                               recompileForPredicates(fpb, fsb, vars, status, 
tid, resetRecompile);
+                               recompileForPredicates(fpb, fsb, vars, status);
                                for( ProgramBlock pb2 : fpb.getChildBlocks() )
-                                       rRecompileProgramBlock(pb2, vars, 
status, tid, resetRecompile);
+                                       rRecompileProgramBlock(pb2, vars, 
status);
                        }
                        removeUpdatedScalars(vars, fpb.getStatementBlock());
                }
@@ -711,20 +702,22 @@ public class Recompiler
                        
                        //recompile all for stats propagation and recompile 
flags
                        tmp = Recompiler.recompileHopsDag(
-                               sb, sb.getHops(), vars, status, true, false, 
tid);
+                               sb, sb.getHops(), vars, status, 
status.isInPlace(), false, status.getTID());
                        bpb.setInstructions( tmp );
                        
                        //propagate stats across hops (should be executed on 
clone of vars)
-                       Recompiler.extractDAGOutputStatistics(sb.getHops(), 
vars);
+                       if( status.isInPlace() )
+                               
Recompiler.extractDAGOutputStatistics(sb.getHops(), vars);
                        
                        //reset recompilation flags (w/ special handling 
functions)
                        if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs 
                                && !containsRootFunctionOp(sb.getHops())
-                               && resetRecompile.isReset() )
+                               && status.isReset() )
                        {
-                               Hop.resetRecompilationFlag(sb.getHops(), 
ExecType.CP, resetRecompile);
+                               Hop.resetRecompilationFlag(sb.getHops(), 
ExecType.CP, status.getReset());
                                sb.updateRecompilationFlag();
                        }
+                       status.trackRecompile(sb.requiresRecompilation());
                }
        }
        
@@ -952,91 +945,70 @@ public class Recompiler
        
        //helper functions for predicate recompile
        
-       private static void recompileIfPredicate( IfProgramBlock ipb, 
IfStatementBlock isb, LocalVariableMap vars, RecompileStatus status, long tid, 
ResetType resetRecompile ) 
-       {
-               if( isb == null )
+       private static void recompileIfPredicate( IfProgramBlock ipb, 
IfStatementBlock isb, LocalVariableMap vars, RecompileStatus status ) {
+               if( isb == null || isb.getPredicateHops() == null )
                        return;
-               
                Hop hops = isb.getPredicateHops();
-               if( hops != null ) {
-                       ArrayList<Instruction> tmp = recompileHopsDag(
-                               hops, vars, status, true, false, tid);
-                       ipb.setPredicate( tmp );
-                       if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs
-                               && resetRecompile.isReset() ) {
-                               Hop.resetRecompilationFlag(hops, ExecType.CP, 
resetRecompile);
-                               isb.updatePredicateRecompilationFlag();
-                       }
-               }
+               ArrayList<Instruction> tmp = recompileHopsDag(
+                       hops, vars, status, status.isInPlace(), false, 
status.getTID());
+               ipb.setPredicate( tmp );
+               if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs && 
status.isReset() ) {
+                       Hop.resetRecompilationFlag(hops, ExecType.CP, 
status.getReset());
+                       isb.updatePredicateRecompilationFlag();
+               }
+               status.trackRecompile(isb.requiresPredicateRecompilation());
        }
        
-       private static void recompileWhilePredicate( WhileProgramBlock wpb, 
WhileStatementBlock wsb, LocalVariableMap vars, RecompileStatus status, long 
tid, ResetType resetRecompile ) {
-               if( wsb == null )
+       private static void recompileWhilePredicate( WhileProgramBlock wpb, 
WhileStatementBlock wsb, LocalVariableMap vars, RecompileStatus status ) {
+               if( wsb == null || wsb.getPredicateHops() == null )
                        return;
-               
                Hop hops = wsb.getPredicateHops();
-               if( hops != null ) {
-                       ArrayList<Instruction> tmp = recompileHopsDag(
-                               hops, vars, status, true, false, tid);
-                       wpb.setPredicate( tmp );
-                       if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs 
-                               && resetRecompile.isReset() ) {
-                               Hop.resetRecompilationFlag(hops, ExecType.CP, 
resetRecompile);
-                               wsb.updatePredicateRecompilationFlag();
-                       }
-               }
+               ArrayList<Instruction> tmp = recompileHopsDag(
+                       hops, vars, status, status.isInPlace(), false, 
status.getTID());
+               wpb.setPredicate( tmp );
+               if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs && 
status.isReset() ) {
+                       Hop.resetRecompilationFlag(hops, ExecType.CP, 
status.getReset());
+                       wsb.updatePredicateRecompilationFlag();
+               }
+               status.trackRecompile(wsb.requiresPredicateRecompilation());
        }
        
-       private static void recompileForPredicates( ForProgramBlock fpb, 
ForStatementBlock fsb, LocalVariableMap vars, RecompileStatus status, long tid, 
ResetType resetRecompile ) {
-               if( fsb != null )
-               {
-                       Hop fromHops = fsb.getFromHops();
-                       Hop toHops = fsb.getToHops();
-                       Hop incrHops = fsb.getIncrementHops();
-                       
-                       //handle recompilation flags
-                       if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs 
-                               && resetRecompile.isReset() ) 
-                       {
-                               if( fromHops != null ) {
-                                       ArrayList<Instruction> tmp = 
recompileHopsDag(
-                                               fromHops, vars, status, true, 
false, tid);
-                                       fpb.setFromInstructions(tmp);
-                                       
Hop.resetRecompilationFlag(fromHops,ExecType.CP, resetRecompile);
-                               }
-                               if( toHops != null ) {
-                                       ArrayList<Instruction> tmp = 
recompileHopsDag(
-                                               toHops, vars, status, true, 
false, tid);
-                                       fpb.setToInstructions(tmp);
-                                       
Hop.resetRecompilationFlag(toHops,ExecType.CP, resetRecompile);
-                               }
-                               if( incrHops != null ) {
-                                       ArrayList<Instruction> tmp = 
recompileHopsDag(
-                                               incrHops, vars, status, true, 
false, tid);
-                                       fpb.setIncrementInstructions(tmp);
-                                       
Hop.resetRecompilationFlag(incrHops,ExecType.CP, resetRecompile);
-                               }
-                               fsb.updatePredicateRecompilationFlags();
-                       }
-                       else //no reset of recompilation flags
-                       {
-                               if( fromHops != null ) {
-                                       ArrayList<Instruction> tmp = 
recompileHopsDag(
-                                               fromHops, vars, status, true, 
false, tid);
-                                       fpb.setFromInstructions(tmp);
-                               }
-                               if( toHops != null ) {
-                                       ArrayList<Instruction> tmp = 
recompileHopsDag(
-                                               toHops, vars, status, true, 
false, tid);
-                                       fpb.setToInstructions(tmp);
-                               }
-                               if( incrHops != null ) {
-                                       ArrayList<Instruction> tmp = 
recompileHopsDag(
-                                               incrHops, vars, status, true, 
false, tid);
-                                       fpb.setIncrementInstructions(tmp);
-                               }
-                       }
+       private static void recompileForPredicates( ForProgramBlock fpb, 
ForStatementBlock fsb, LocalVariableMap vars, RecompileStatus status ) {
+               if( fsb == null )
+                       return;
+               
+               Hop fromHops = fsb.getFromHops();
+               Hop toHops = fsb.getToHops();
+               Hop incrHops = fsb.getIncrementHops();
+               
+               // recompile predicates
+               if( fromHops != null ) {
+                       ArrayList<Instruction> tmp = recompileHopsDag(
+                               fromHops, vars, status, status.isInPlace(), 
false, status.getTID());
+                       fpb.setFromInstructions(tmp);
                }
+               if( toHops != null ) {
+                       ArrayList<Instruction> tmp = recompileHopsDag(
+                               toHops, vars, status, status.isInPlace(), 
false, status.getTID());
+                       fpb.setToInstructions(tmp);
+               }
+               if( incrHops != null ) {
+                       ArrayList<Instruction> tmp = recompileHopsDag(
+                               incrHops, vars, status, status.isInPlace(), 
false, status.getTID());
+                       fpb.setIncrementInstructions(tmp);
+               }
+               
+               //handle recompilation flags
+               if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs && 
status.isReset() ) {
+                       if( fromHops != null )
+                               Hop.resetRecompilationFlag(fromHops, 
ExecType.CP, status.getReset());
+                       if( toHops != null )
+                               Hop.resetRecompilationFlag(toHops, ExecType.CP, 
status.getReset());
+                       if( incrHops != null )
+                               Hop.resetRecompilationFlag(incrHops, 
ExecType.CP, status.getReset());
+                       fsb.updatePredicateRecompilationFlags();
+               }
+               status.trackRecompile(fsb.requiresPredicateRecompilation());
        }
        
        public static void rRecompileProgramBlock2Forced( ProgramBlock pb, long 
tid, HashSet<String> fnStack, ExecType et ) {
@@ -1142,13 +1114,11 @@ public class Recompiler
                }
        }
        
-       public static void extractDAGOutputStatistics(ArrayList<Hop> hops, 
LocalVariableMap vars)
-       {
+       public static void extractDAGOutputStatistics(ArrayList<Hop> hops, 
LocalVariableMap vars) {
                extractDAGOutputStatistics(hops, vars, true);
        }
        
-       public static void extractDAGOutputStatistics(ArrayList<Hop> hops, 
LocalVariableMap vars, boolean overwrite)
-       {
+       public static void extractDAGOutputStatistics(ArrayList<Hop> hops, 
LocalVariableMap vars, boolean overwrite) {
                for( Hop hop : hops ) //for all hop roots
                        extractDAGOutputStatistics(hop, vars, overwrite);
        }
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 5127384..62ba4ae 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -22,6 +22,8 @@ package org.apache.sysds.hops.rewrite;
 import java.util.ArrayList;
 import java.util.List;
 
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.CompilerConfig.ConfigType;
@@ -48,11 +50,20 @@ import org.apache.sysds.runtime.lineage.LineageCacheConfig;
  */
 public class ProgramRewriter
 {
+       private static final boolean LDEBUG = false; //internal local debug 
level
        private static final boolean CHECK = false;
        
        private ArrayList<HopRewriteRule> _dagRuleSet = null;
        private ArrayList<StatementBlockRewriteRule> _sbRuleSet = null;
 
+       static {
+               // for internal debugging only
+               if( LDEBUG ) {
+                       Logger.getLogger("org.apache.sysds.hops.rewrite")
+                               .setLevel(Level.DEBUG);
+               }
+       }
+       
        public ProgramRewriter() {
                // by default which is used during initial compile 
                // apply all (static and dynamic) rewrites
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 519c400..71a7240 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2495,7 +2495,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                hi = minus;
                                
                                LOG.debug("Applied reorderMinusMatrixMult (line 
"+hi.getBeginLine()+").");
-                       }       
+                       }
                }
                
                return hi;
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 5ab97bf..f0d9dea 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -254,7 +254,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
         * handle removal of unnecessary binary operations
         * 
         * X/1 or X*1 or 1*X or X-0 -> X
-        * -1*X or X*-1-> -X            
+        * -1*X or X*-1-> -X
         * 
         * @param parent parent high-level operator
         * @param hi high-level operator
@@ -777,7 +777,6 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
         */
        private static Hop simplifyDistributiveBinaryOperation( Hop parent, Hop 
hi, int pos )
        {
-               
                if( hi instanceof BinaryOp )
                {
                        BinaryOp bop = (BinaryOp)hi;
@@ -810,9 +809,9 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                                hi = mult;
                                                applied = true;
                                                
-                                               LOG.debug("Applied 
simplifyDistributiveBinaryOperation1");
-                                       }                                       
-                               }       
+                                               LOG.debug("Applied 
simplifyDistributiveBinaryOperation1 (line "+hi.getBeginLine()+").");
+                                       }
+                               }
                                
                                if( !applied && HopRewriteUtils.isBinary(right, 
OpOp2.MULT) ) //(X-Y*X) -> (1-Y)*X
                                {
@@ -831,7 +830,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                                
HopRewriteUtils.cleanupUnreferenced(hi, right);
                                                hi = mult;
 
-                                               LOG.debug("Applied 
simplifyDistributiveBinaryOperation2");
+                                               LOG.debug("Applied 
simplifyDistributiveBinaryOperation2 (line "+hi.getBeginLine()+").");
                                        }
                                }       
                        }
diff --git a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
index 092fbb7..c2686da 100644
--- a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
@@ -410,6 +410,10 @@ public class ForStatementBlock extends StatementBlock
                        _requiresToRecompile = 
Recompiler.requiresRecompilation(getToHops());
                        _requiresIncrementRecompile = 
Recompiler.requiresRecompilation(getIncrementHops());
                }
+               return requiresPredicateRecompilation();
+       }
+       
+       public boolean requiresPredicateRecompilation() {
                return (_requiresFromRecompile || _requiresToRecompile || 
_requiresIncrementRecompile);
        }
        
diff --git a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
index af88d60..9219ba1 100644
--- a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
@@ -152,7 +152,6 @@ public class ParForStatementBlock extends ForStatementBlock
                if( LDEBUG ) {
                        
Logger.getLogger("org.apache.sysds.parser.ParForStatementBlock")
                                .setLevel(Level.TRACE);
-                       System.out.println();
                }
        }
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
index cd7f0cc..32ba97f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
@@ -116,8 +116,8 @@ public class FunctionProgramBlock extends ProgramBlock 
implements FunctionBlock
                                boolean codegen = 
ConfigurationManager.isCodegenEnabled();
                                boolean singlenode = 
DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE;
                                ResetType reset = (codegen || singlenode) ? 
ResetType.RESET_KNOWN_DIMS : ResetType.RESET;
-                               
Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, reset);
-                               
+                               
Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, false, 
reset);
+
                                if( DMLScript.STATISTICS ){
                                        long t1 = System.nanoTime();
                                        
Statistics.incrementFunRecompileTime(t1-t0);
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
index a329f9c..62f7e41 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
@@ -185,7 +185,7 @@ public class OptimizationWrapper
                                LocalVariableMap tmp = (LocalVariableMap) 
ec.getVariables().clone();
                                ResetType reset = 
ConfigurationManager.isCodegenEnabled() ? 
                                        ResetType.RESET_KNOWN_DIMS : 
ResetType.RESET;
-                               
Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, reset);
+                               
Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, true, 
reset);
                                
                                //inter-procedural optimization (based on 
previous recompilation)
                                if( pb.hasFunctions() ) {
@@ -201,7 +201,7 @@ public class OptimizationWrapper
                                                        //reset recompilation 
flags according to recompileOnce because it is only safe if function is 
recompileOnce 
                                                        //because then 
recompiled for every execution (otherwise potential issues if func also called 
outside parfor)
                                                        ResetType reset2 = 
fpb.isRecompileOnce() ? reset : ResetType.NO_RESET;
-                                                       
Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new 
LocalVariableMap(), 0, reset2);
+                                                       
Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new 
LocalVariableMap(), 0, true, reset2);
                                                }
                                        }
                                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 4fbe4e6..097b678 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -21,7 +21,6 @@ package org.apache.sysds.runtime.instructions.fed;
 
 import java.util.concurrent.Future;
 
-import org.apache.sysds.lops.Lop;
 import org.apache.sysds.lops.LopProperties.ExecType;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
index 8d4ac8d..14e4de1 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
@@ -28,6 +28,8 @@ import 
org.apache.sysds.runtime.transform.encode.EncoderRecode;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
 import org.junit.Test;
 
 @net.jcip.annotations.NotThreadSafe
@@ -110,7 +112,7 @@ public class FederatedLmPipeline extends AutomatedTestBase {
 
                        TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                        loadTestConfiguration(config);
-
+                       
                        // Run reference dml script with normal matrix
                        fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
                        programArgs = new String[] {"-args", input("X1"), 
input("X2"), input("X3"), input("X4"), input("Y"),
@@ -119,7 +121,7 @@ public class FederatedLmPipeline extends AutomatedTestBase {
 
                        // Run actual dml script with federated matrix
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[] {"-nvargs", "in_X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
+                       programArgs = new String[] {"-stats", "-nvargs", 
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
                                "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
                                "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
                                "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")), "rows=" + rows, "cols=" + (cols + 1),
@@ -129,6 +131,10 @@ public class FederatedLmPipeline extends AutomatedTestBase 
{
                        // compare via files
                        compareResults(1e-2);
                        TestUtils.shutdownThreads(t1, t2, t3, t4);
+                       
+                       // check correct federated operations
+                       
Assert.assertTrue(Statistics.getCPHeavyHitterCount("fed_mmchain")>10);
+                       
Assert.assertTrue(Statistics.getCPHeavyHitterCount("fed_ba+*")==3);
                }
                finally {
                        resetExecMode(oldExec);

Reply via email to