This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit e9a3cb36b0c637cabe739a9c9340b38789db686f
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Apr 4 21:17:08 2023 +0200

    [SYSTEMDS-3514] Fix parfor checkpoint injection and recompilation
    
    This patch fixes an issue where every parfor optimization added one
    more statement block with additional checkpoint operators. Additionally,
    we make selected modifications to the compilation of ifelse and ctable.
    Together, these changes completely eliminate the decisionTree
    recompilation overhead:
    
    OLD:
    Total execution time:            17.148 sec.
    HOP DAGs recompiled (PRED, SB):  136/71630.
    HOP DAGs recompile time:         26.572 sec.
    Functions recompiled:            58.
    Functions recompile time:        0.231 sec.
    
    NEW:
    Total execution time:            6.314 sec.
    HOP DAGs recompiled (PRED, SB):  136/321.
    HOP DAGs recompile time:         0.159 sec.
    Functions recompiled:            58.
    Functions recompile time:        0.623 sec.
---
 src/main/java/org/apache/sysds/hops/DataGenOp.java      | 12 ++++++++++--
 src/main/java/org/apache/sysds/hops/TernaryOp.java      | 13 ++++++-------
 .../org/apache/sysds/hops/rewrite/HopRewriteUtils.java  |  4 ++++
 .../org/apache/sysds/lops/ParameterizedBuiltin.java     | 17 +++++++----------
 .../runtime/controlprogram/ParForProgramBlock.java      |  7 +++++--
 .../controlprogram/parfor/opt/OptimizationWrapper.java  | 12 ++++++------
 .../runtime/controlprogram/parfor/opt/Optimizer.java    |  3 ++-
 .../controlprogram/parfor/opt/OptimizerConstrained.java |  5 +++--
 .../controlprogram/parfor/opt/OptimizerRuleBased.java   |  5 +++--
 9 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/DataGenOp.java 
b/src/main/java/org/apache/sysds/hops/DataGenOp.java
index 77a05cbe75..8bab19be1b 100644
--- a/src/main/java/org/apache/sysds/hops/DataGenOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataGenOp.java
@@ -29,6 +29,7 @@ import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.lops.DataGen;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.common.Types.OpOp3;
 import org.apache.sysds.parser.DataExpression;
 import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.Statement;
@@ -348,16 +349,15 @@ public class DataGenOp extends MultiThreadedHop
                }
                else if (_op == OpOpDG.SEQ ) 
                {
+                       //bounds computation
                        input1 = 
getInput().get(_paramIndexMap.get(Statement.SEQ_FROM));
                        input2 = 
getInput().get(_paramIndexMap.get(Statement.SEQ_TO)); 
                        input3 = 
getInput().get(_paramIndexMap.get(Statement.SEQ_INCR)); 
 
                        double from = computeBoundsInformation(input1);
                        boolean fromKnown = (from != Double.MAX_VALUE);
-                       
                        double to = computeBoundsInformation(input2);
                        boolean toKnown = (to != Double.MAX_VALUE);
-                       
                        double incr = computeBoundsInformation(input3);
                        boolean incrKnown = (incr != Double.MAX_VALUE);
                        if(  fromKnown && toKnown && incr == 1) {
@@ -370,6 +370,14 @@ public class DataGenOp extends MultiThreadedHop
                                setDim2(1);
                                _incr = incr;
                        }
+                       
+                       //leverage high-probability information of output
+                       if( getDim1() == -1 && getParent().size() == 1 ) {
+                               Hop p = getParent().get(0);
+                               p.refreshSizeInformation();
+                               setDim1((HopRewriteUtils.isTernary(p, 
OpOp3.CTABLE)
+                                       && p.getDim1() >= 0 ) ? p.getDim1() : 
-1);
+                       }
                }
                else if (_op == OpOpDG.TIME ) {
                        setDim1(0);
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java 
b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index 58c1f564ee..e6387b429c 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.OpOp2;
 import org.apache.sysds.common.Types.OpOp3;
 import org.apache.sysds.common.Types.OpOpDG;
+import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.common.Types.ParamBuiltinOp;
 import org.apache.sysds.common.Types.ReOrgOp;
 import org.apache.sysds.common.Types.ValueType;
@@ -346,8 +347,8 @@ public class TernaryOp extends MultiThreadedHop
                setLineNumbers(plusmult);
                setLops(plusmult);
                
-               if( _op==OpOp3.IFELSE && getInput(0).getDataType().isScalar() )
-                       setRequiresRecompile(); //good chance of removing ops
+               if( _op==OpOp3.IFELSE && HopRewriteUtils.isData(getInput(0), 
OpOpData.TRANSIENTREAD, DataType.SCALAR))
+                       setRequiresRecompile(); //good chance of removing ops 
via literal replacements + rewrites
        }
        
        @Override
@@ -518,7 +519,7 @@ public class TernaryOp extends MultiThreadedHop
                // additional condition: when execType=CP and additional 
dimension inputs 
                // are provided (and those values are unknown at initial 
compile time).
                setRequiresRecompileIfNecessary();
-               if( ConfigurationManager.isDynamicRecompilation() && 
!dimsKnown(true) 
+               if( ConfigurationManager.isDynamicRecompilation() && 
!dimsKnown() 
                        && _etype == ExecType.CP && _dimInputsPresent) {
                        setRequiresRecompile();
                }
@@ -572,10 +573,8 @@ public class TernaryOp extends MultiThreadedHop
                                                
                                                // if output dimensions are 
provided, update _dim1 and _dim2
                                                if( getInput().size() >= 5 ) {
-                                                       if( getInput().get(3) 
instanceof LiteralOp )
-                                                               setDim1( 
HopRewriteUtils.getIntValueSafe((LiteralOp)getInput().get(3)) );
-                                                       if( getInput().get(4) 
instanceof LiteralOp )
-                                                               setDim2( 
HopRewriteUtils.getIntValueSafe((LiteralOp)getInput().get(4)) );
+                                                       
refreshRowsParameterInformation(getInput(3));
+                                                       
refreshColsParameterInformation(getInput(4));
                                                }
                                        }
 
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index d10a43e810..210a9f152b 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1133,6 +1133,10 @@ public class HopRewriteUtils {
                return hop instanceof DataOp && ((DataOp)hop).getOp()==type;
        }
        
+       public static boolean isData(Hop hop, OpOpData type, DataType dt) {
+               return isData(hop, type) && hop.getDataType()==dt;
+       }
+       
        public static boolean isBinaryMatrixColVectorOperation(Hop hop) {
                return hop instanceof BinaryOp 
                        && hop.getInput().get(0).getDataType().isMatrix() && 
hop.getInput().get(1).getDataType().isMatrix()
diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java 
b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
index fdf23a7b62..fc9e60419c 100644
--- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
+++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
@@ -135,14 +135,14 @@ public class ParameterizedBuiltin extends Lop
                        case LOWER_TRI: {
                                sb.append( "lowertri" );
                                sb.append( OPERAND_DELIMITOR );
-                               sb.append(compileGenericParamMap(_inputParams));
+                               compileGenericParamMap(sb, _inputParams);
                                break;
                        }
                        
                        case UPPER_TRI: {
                                sb.append( "uppertri" );
                                sb.append( OPERAND_DELIMITOR );
-                               sb.append(compileGenericParamMap(_inputParams));
+                               compileGenericParamMap(sb, _inputParams);
                                break;
                        }
                        
@@ -178,25 +178,25 @@ public class ParameterizedBuiltin extends Lop
                        case PARAMSERV: { 
                                sb.append(_operation.name().toLowerCase()); 
//opcode
                                sb.append(OPERAND_DELIMITOR);
-                               sb.append(compileGenericParamMap(_inputParams));
+                               compileGenericParamMap(sb, _inputParams);
                                break;
                        }
                        case AUTODIFF: {
                                sb.append("autoDiff"); //opcode
                                sb.append(OPERAND_DELIMITOR);
-                               sb.append(compileGenericParamMap(_inputParams));
+                               compileGenericParamMap(sb, _inputParams);
                                break;
                        }
                        case LIST: {
                                sb.append("nvlist"); //opcode
                                sb.append(OPERAND_DELIMITOR);
-                               sb.append(compileGenericParamMap(_inputParams));
+                               compileGenericParamMap(sb, _inputParams);
                                break;
                        }
                        case TOSTRING: {
                                sb.append("toString"); //opcode
                                sb.append(OPERAND_DELIMITOR);
-                               sb.append(compileGenericParamMap(_inputParams));
+                               compileGenericParamMap(sb, _inputParams);
                                break;
                        }
                        
@@ -243,8 +243,7 @@ public class ParameterizedBuiltin extends Lop
                return sb.toString();
        }
        
-       private static String compileGenericParamMap(HashMap<String, Lop> 
params) {
-               StringBuilder sb = InstructionUtils.getStringBuilder();
+       private static void compileGenericParamMap(StringBuilder sb, 
HashMap<String, Lop> params) {
                for ( Entry<String, Lop> e : params.entrySet() ) {
                        sb.append(e.getKey());
                        sb.append(NAME_VALUE_SEPARATOR);
@@ -254,7 +253,5 @@ public class ParameterizedBuiltin extends Lop
                                sb.append( e.getValue().prepScalarLabel() );
                        sb.append(OPERAND_DELIMITOR);
                }
-               
-               return sb.toString();
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index b287b4957c..4df1a7052e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -359,6 +359,7 @@ public class ParForProgramBlock extends ForProgramBlock
        protected long _ID = -1;
        protected int _IDPrefix = -1;
        protected boolean _monitorReport = false;
+       protected int _numRuns = -1;
        
        // local parworker data
        protected HashMap<Long,ArrayList<ProgramBlock>> _pbcache = null;
@@ -422,6 +423,7 @@ public class ParForProgramBlock extends ForProgramBlock
                
                //created profiling report after parfor exec
                _monitorReport = _monitor;
+               _numRuns = 0;
                
                //materialized meta data (reused for all invocations)
                _hasFunctions = 
ProgramRecompiler.containsAtLeastOneFunction(this);
@@ -574,7 +576,7 @@ public class ParForProgramBlock extends ForProgramBlock
        public void execute(ExecutionContext ec)
        {
                ParForStatementBlock sb = 
(ParForStatementBlock)getStatementBlock();
-
+               
                // evaluate from, to, incr only once (assumption: known at for 
entry)
                ScalarObject from0 = executePredicateInstructions(1, 
_fromInstructions, ec, false);
                ScalarObject to0   = executePredicateInstructions(2, 
_toInstructions, ec, false);
@@ -608,7 +610,7 @@ public class ParForProgramBlock extends ForProgramBlock
                ///////
                if( _optMode != POptMode.NONE ) {
                        OptimizationWrapper.setLogLevel(_optLogLevel); //set 
optimizer log level
-                       OptimizationWrapper.optimize(_optMode, sb, this, ec, 
_monitor); //core optimize
+                       OptimizationWrapper.optimize(_optMode, sb, this, ec, 
_monitor, _numRuns); //core optimize
                }
 
                ///////
@@ -704,6 +706,7 @@ public class ParForProgramBlock extends ForProgramBlock
                //print profiling report (only if top-level parfor because 
otherwise in parallel context)
                if( _monitorReport )
                        LOG.info("\n"+StatisticMonitor.createReport());
+               _numRuns ++;
                
                //reset flags/modifications made by optimizer
                //TODO reset of hop parallelism constraint (e.g., ba+*)
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
index 05681982da..6dbc952172 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
@@ -89,8 +89,9 @@ public class OptimizationWrapper
         * @param pb parfor program block
         * @param ec execution context
         * @param monitor ?
+        * @param numRuns number of optimizations performed so far
         */
-       public static void optimize( POptMode type, ParForStatementBlock sb, 
ParForProgramBlock pb, ExecutionContext ec, boolean monitor ) 
+       public static void optimize( POptMode type, ParForStatementBlock sb, 
ParForProgramBlock pb, ExecutionContext ec, boolean monitor, int numRuns ) 
        {
                Timing time = new Timing(true);
                
@@ -103,7 +104,7 @@ public class OptimizationWrapper
                double cm = InfrastructureAnalyzer.getCmMax() * 
OptimizerUtils.MEM_UTIL_FACTOR; 
                
                //execute optimizer
-               optimize( type, ck, cm, sb, pb, ec, monitor );
+               optimize( type, ck, cm, sb, pb, ec, monitor, numRuns );
                
                double timeVal = time.stop();
                LOG.debug("ParFOR Opt: Finished optimization for 
PARFOR("+pb.getID()+") in "+timeVal+"ms.");
@@ -118,7 +119,7 @@ public class OptimizationWrapper
        }
 
        @SuppressWarnings("unused")
-       private static void optimize( POptMode otype, int ck, double cm, 
ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean 
monitor ) 
+       private static void optimize( POptMode otype, int ck, double cm, 
ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean 
monitor, int numRuns ) 
        {
                Timing time = new Timing(true);
                
@@ -161,7 +162,6 @@ public class OptimizationWrapper
                                throw new DMLRuntimeException(ex);
                        }
                        
-                       
                        //program rewrites (e.g., constant folding, branch 
removal) according to replaced literals
                        try {
                                ProgramRewriter rewriter = 
createProgramRewriterWithRuleSets();
@@ -225,7 +225,7 @@ public class OptimizationWrapper
                LOG.trace("ParFOR Opt: Created cost estimator ("+cmtype+")");
                
                //core optimize
-               opt.optimize(sb, pb, tree, est, ec);
+               opt.optimize(sb, pb, tree, est, numRuns, ec);
                LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" 
+ tree.explain(false));
                
                //assert plan correctness
@@ -238,7 +238,7 @@ public class OptimizationWrapper
                                throw new DMLRuntimeException("Failed to check 
program correctness.", ex);
                        }
                }
-               
+
                long ltime = (long) time.stop();
                LOG.trace("ParFOR Opt: Optimized plan in "+ltime+"ms.");
                if( DMLScript.STATISTICS )
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java
index d04f324abd..fdf7f23556 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java
@@ -62,10 +62,11 @@ public abstract class Optimizer
         * @param pb parfor program block
         * @param plan  complete plan of a top-level parfor
         * @param est cost estimator
+        * @param numRuns
         * @param ec execution context
         * @return true if plan changed, false otherwise
         */
-       public abstract boolean optimize(ParForStatementBlock sb, 
ParForProgramBlock pb, OptTree plan, CostEstimator est, ExecutionContext ec);
+       public abstract boolean optimize(ParForStatementBlock sb, 
ParForProgramBlock pb, OptTree plan, CostEstimator est, int numRuns, 
ExecutionContext ec);
        
        public abstract PlanInputType getPlanInputType();
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java
index c039f932cb..0665357ad1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java
@@ -75,7 +75,7 @@ public class OptimizerConstrained extends OptimizerRuleBased {
         * (no use of sb, direct change of pb).
         */
        @Override
-       public boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, 
OptTree plan, CostEstimator est, ExecutionContext ec)
+       public boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, 
OptTree plan, CostEstimator est, int numRuns, ExecutionContext ec)
        {
                LOG.debug("--- "+getOptMode()+" OPTIMIZER -------");
                _cost = est;
@@ -179,7 +179,8 @@ public class OptimizerConstrained extends 
OptimizerRuleBased {
                        super.rewriteSetInPlaceResultIndexing(pn, _cost, 
ec.getVariables(), inplaceResultVars, ec);
 
                        //rewrite 17: checkpoint injection for parfor loop body
-                       super.rewriteInjectSparkLoopCheckpointing( pn );
+                       if( numRuns <= 0 ) //only on first
+                               super.rewriteInjectSparkLoopCheckpointing( pn );
 
                        //rewrite 18: repartition read-only inputs for zipmm 
                        super.rewriteInjectSparkRepartition( pn, 
ec.getVariables() );
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
index 13b846ea02..fb97c6a15c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
@@ -193,7 +193,7 @@ public class OptimizerRuleBased extends Optimizer {
         * (no use of sb, direct change of pb).
         */
        @Override
-       public boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, 
OptTree plan, CostEstimator est, ExecutionContext ec) 
+       public boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, 
OptTree plan, CostEstimator est, int numRuns, ExecutionContext ec) 
        {
                LOG.debug("--- "+getOptMode()+" OPTIMIZER -------");
 
@@ -298,7 +298,8 @@ public class OptimizerRuleBased extends Optimizer {
                        rewriteSetInPlaceResultIndexing(pn, _cost, 
ec.getVariables(), inplaceResultVars, ec);
                        
                        //rewrite 17: checkpoint injection for parfor loop body
-                       rewriteInjectSparkLoopCheckpointing( pn );
+                       if( numRuns <= 0 ) //only on first
+                               rewriteInjectSparkLoopCheckpointing( pn );
                        
                        //rewrite 18: repartition read-only inputs for zipmm 
                        rewriteInjectSparkRepartition( pn, ec.getVariables() );

Reply via email to