This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit e9a3cb36b0c637cabe739a9c9340b38789db686f Author: Matthias Boehm <[email protected]> AuthorDate: Tue Apr 4 21:17:08 2023 +0200 [SYSTEMDS-3514] Fix parfor checkpoint injection and recompilation This patch fixes an issue where every parfor optimization added one more statement block with additional checkpoint operators. Additionally, we make selected modifications to the compilation of ifelse and ctable. Together, these changes virtually eliminate the decisionTree recompilation overhead: OLD: Total execution time: 17.148 sec. HOP DAGs recompiled (PRED, SB): 136/71630. HOP DAGs recompile time: 26.572 sec. Functions recompiled: 58. Functions recompile time: 0.231 sec. NEW: Total execution time: 6.314 sec. HOP DAGs recompiled (PRED, SB): 136/321. HOP DAGs recompile time: 0.159 sec. Functions recompiled: 58. Functions recompile time: 0.623 sec. --- src/main/java/org/apache/sysds/hops/DataGenOp.java | 12 ++++++++++-- src/main/java/org/apache/sysds/hops/TernaryOp.java | 13 ++++++------- .../org/apache/sysds/hops/rewrite/HopRewriteUtils.java | 4 ++++ .../org/apache/sysds/lops/ParameterizedBuiltin.java | 17 +++++++---------- .../runtime/controlprogram/ParForProgramBlock.java | 7 +++++-- .../controlprogram/parfor/opt/OptimizationWrapper.java | 12 ++++++------ .../runtime/controlprogram/parfor/opt/Optimizer.java | 3 ++- .../controlprogram/parfor/opt/OptimizerConstrained.java | 5 +++-- .../controlprogram/parfor/opt/OptimizerRuleBased.java | 5 +++-- 9 files changed, 46 insertions(+), 32 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/DataGenOp.java b/src/main/java/org/apache/sysds/hops/DataGenOp.java index 77a05cbe75..8bab19be1b 100644 --- a/src/main/java/org/apache/sysds/hops/DataGenOp.java +++ b/src/main/java/org/apache/sysds/hops/DataGenOp.java @@ -29,6 +29,7 @@ import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.sysds.lops.DataGen; import org.apache.sysds.lops.Lop; import org.apache.sysds.common.Types.ExecType; +import 
org.apache.sysds.common.Types.OpOp3; import org.apache.sysds.parser.DataExpression; import org.apache.sysds.parser.DataIdentifier; import org.apache.sysds.parser.Statement; @@ -348,16 +349,15 @@ public class DataGenOp extends MultiThreadedHop } else if (_op == OpOpDG.SEQ ) { + //bounds computation input1 = getInput().get(_paramIndexMap.get(Statement.SEQ_FROM)); input2 = getInput().get(_paramIndexMap.get(Statement.SEQ_TO)); input3 = getInput().get(_paramIndexMap.get(Statement.SEQ_INCR)); double from = computeBoundsInformation(input1); boolean fromKnown = (from != Double.MAX_VALUE); - double to = computeBoundsInformation(input2); boolean toKnown = (to != Double.MAX_VALUE); - double incr = computeBoundsInformation(input3); boolean incrKnown = (incr != Double.MAX_VALUE); if( fromKnown && toKnown && incr == 1) { @@ -370,6 +370,14 @@ public class DataGenOp extends MultiThreadedHop setDim2(1); _incr = incr; } + + //leverage high-probability information of output + if( getDim1() == -1 && getParent().size() == 1 ) { + Hop p = getParent().get(0); + p.refreshSizeInformation(); + setDim1((HopRewriteUtils.isTernary(p, OpOp3.CTABLE) + && p.getDim1() >= 0 ) ? 
p.getDim1() : -1); + } } else if (_op == OpOpDG.TIME ) { setDim1(0); diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java b/src/main/java/org/apache/sysds/hops/TernaryOp.java index 58c1f564ee..e6387b429c 100644 --- a/src/main/java/org/apache/sysds/hops/TernaryOp.java +++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java @@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.OpOp2; import org.apache.sysds.common.Types.OpOp3; import org.apache.sysds.common.Types.OpOpDG; +import org.apache.sysds.common.Types.OpOpData; import org.apache.sysds.common.Types.ParamBuiltinOp; import org.apache.sysds.common.Types.ReOrgOp; import org.apache.sysds.common.Types.ValueType; @@ -346,8 +347,8 @@ public class TernaryOp extends MultiThreadedHop setLineNumbers(plusmult); setLops(plusmult); - if( _op==OpOp3.IFELSE && getInput(0).getDataType().isScalar() ) - setRequiresRecompile(); //good chance of removing ops + if( _op==OpOp3.IFELSE && HopRewriteUtils.isData(getInput(0), OpOpData.TRANSIENTREAD, DataType.SCALAR)) + setRequiresRecompile(); //good chance of removing ops via literal replacements + rewrites } @Override @@ -518,7 +519,7 @@ public class TernaryOp extends MultiThreadedHop // additional condition: when execType=CP and additional dimension inputs // are provided (and those values are unknown at initial compile time). 
setRequiresRecompileIfNecessary(); - if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) + if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown() && _etype == ExecType.CP && _dimInputsPresent) { setRequiresRecompile(); } @@ -572,10 +573,8 @@ public class TernaryOp extends MultiThreadedHop // if output dimensions are provided, update _dim1 and _dim2 if( getInput().size() >= 5 ) { - if( getInput().get(3) instanceof LiteralOp ) - setDim1( HopRewriteUtils.getIntValueSafe((LiteralOp)getInput().get(3)) ); - if( getInput().get(4) instanceof LiteralOp ) - setDim2( HopRewriteUtils.getIntValueSafe((LiteralOp)getInput().get(4)) ); + refreshRowsParameterInformation(getInput(3)); + refreshColsParameterInformation(getInput(4)); } } diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java index d10a43e810..210a9f152b 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java @@ -1133,6 +1133,10 @@ public class HopRewriteUtils { return hop instanceof DataOp && ((DataOp)hop).getOp()==type; } + public static boolean isData(Hop hop, OpOpData type, DataType dt) { + return isData(hop, type) && hop.getDataType()==dt; + } + public static boolean isBinaryMatrixColVectorOperation(Hop hop) { return hop instanceof BinaryOp && hop.getInput().get(0).getDataType().isMatrix() && hop.getInput().get(1).getDataType().isMatrix() diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java index fdf23a7b62..fc9e60419c 100644 --- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java +++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java @@ -135,14 +135,14 @@ public class ParameterizedBuiltin extends Lop case LOWER_TRI: { sb.append( "lowertri" ); sb.append( OPERAND_DELIMITOR ); - 
sb.append(compileGenericParamMap(_inputParams)); + compileGenericParamMap(sb, _inputParams); break; } case UPPER_TRI: { sb.append( "uppertri" ); sb.append( OPERAND_DELIMITOR ); - sb.append(compileGenericParamMap(_inputParams)); + compileGenericParamMap(sb, _inputParams); break; } @@ -178,25 +178,25 @@ public class ParameterizedBuiltin extends Lop case PARAMSERV: { sb.append(_operation.name().toLowerCase()); //opcode sb.append(OPERAND_DELIMITOR); - sb.append(compileGenericParamMap(_inputParams)); + compileGenericParamMap(sb, _inputParams); break; } case AUTODIFF: { sb.append("autoDiff"); //opcode sb.append(OPERAND_DELIMITOR); - sb.append(compileGenericParamMap(_inputParams)); + compileGenericParamMap(sb, _inputParams); break; } case LIST: { sb.append("nvlist"); //opcode sb.append(OPERAND_DELIMITOR); - sb.append(compileGenericParamMap(_inputParams)); + compileGenericParamMap(sb, _inputParams); break; } case TOSTRING: { sb.append("toString"); //opcode sb.append(OPERAND_DELIMITOR); - sb.append(compileGenericParamMap(_inputParams)); + compileGenericParamMap(sb, _inputParams); break; } @@ -243,8 +243,7 @@ public class ParameterizedBuiltin extends Lop return sb.toString(); } - private static String compileGenericParamMap(HashMap<String, Lop> params) { - StringBuilder sb = InstructionUtils.getStringBuilder(); + private static void compileGenericParamMap(StringBuilder sb, HashMap<String, Lop> params) { for ( Entry<String, Lop> e : params.entrySet() ) { sb.append(e.getKey()); sb.append(NAME_VALUE_SEPARATOR); @@ -254,7 +253,5 @@ public class ParameterizedBuiltin extends Lop sb.append( e.getValue().prepScalarLabel() ); sb.append(OPERAND_DELIMITOR); } - - return sb.toString(); } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java index b287b4957c..4df1a7052e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java +++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java @@ -359,6 +359,7 @@ public class ParForProgramBlock extends ForProgramBlock protected long _ID = -1; protected int _IDPrefix = -1; protected boolean _monitorReport = false; + protected int _numRuns = -1; // local parworker data protected HashMap<Long,ArrayList<ProgramBlock>> _pbcache = null; @@ -422,6 +423,7 @@ public class ParForProgramBlock extends ForProgramBlock //created profiling report after parfor exec _monitorReport = _monitor; + _numRuns = 0; //materialized meta data (reused for all invocations) _hasFunctions = ProgramRecompiler.containsAtLeastOneFunction(this); @@ -574,7 +576,7 @@ public class ParForProgramBlock extends ForProgramBlock public void execute(ExecutionContext ec) { ParForStatementBlock sb = (ParForStatementBlock)getStatementBlock(); - + // evaluate from, to, incr only once (assumption: known at for entry) ScalarObject from0 = executePredicateInstructions(1, _fromInstructions, ec, false); ScalarObject to0 = executePredicateInstructions(2, _toInstructions, ec, false); @@ -608,7 +610,7 @@ public class ParForProgramBlock extends ForProgramBlock /////// if( _optMode != POptMode.NONE ) { OptimizationWrapper.setLogLevel(_optLogLevel); //set optimizer log level - OptimizationWrapper.optimize(_optMode, sb, this, ec, _monitor); //core optimize + OptimizationWrapper.optimize(_optMode, sb, this, ec, _monitor, _numRuns); //core optimize } /////// @@ -704,6 +706,7 @@ public class ParForProgramBlock extends ForProgramBlock //print profiling report (only if top-level parfor because otherwise in parallel context) if( _monitorReport ) LOG.info("\n"+StatisticMonitor.createReport()); + _numRuns ++; //reset flags/modifications made by optimizer //TODO reset of hop parallelism constraint (e.g., ba+*) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java index 05681982da..6dbc952172 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java @@ -89,8 +89,9 @@ public class OptimizationWrapper * @param pb parfor program block * @param ec execution context * @param monitor ? + * @param numRuns number of optimizations performed so far */ - public static void optimize( POptMode type, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor ) + public static void optimize( POptMode type, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor, int numRuns ) { Timing time = new Timing(true); @@ -103,7 +104,7 @@ public class OptimizationWrapper double cm = InfrastructureAnalyzer.getCmMax() * OptimizerUtils.MEM_UTIL_FACTOR; //execute optimizer - optimize( type, ck, cm, sb, pb, ec, monitor ); + optimize( type, ck, cm, sb, pb, ec, monitor, numRuns ); double timeVal = time.stop(); LOG.debug("ParFOR Opt: Finished optimization for PARFOR("+pb.getID()+") in "+timeVal+"ms."); @@ -118,7 +119,7 @@ public class OptimizationWrapper } @SuppressWarnings("unused") - private static void optimize( POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor ) + private static void optimize( POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor, int numRuns ) { Timing time = new Timing(true); @@ -161,7 +162,6 @@ public class OptimizationWrapper throw new DMLRuntimeException(ex); } - //program rewrites (e.g., constant folding, branch removal) according to replaced literals try { ProgramRewriter rewriter = createProgramRewriterWithRuleSets(); @@ -225,7 +225,7 @@ public class OptimizationWrapper LOG.trace("ParFOR Opt: Created cost 
estimator ("+cmtype+")"); //core optimize - opt.optimize(sb, pb, tree, est, ec); + opt.optimize(sb, pb, tree, est, numRuns, ec); LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false)); //assert plan correctness @@ -238,7 +238,7 @@ public class OptimizationWrapper throw new DMLRuntimeException("Failed to check program correctness.", ex); } } - + long ltime = (long) time.stop(); LOG.trace("ParFOR Opt: Optimized plan in "+ltime+"ms."); if( DMLScript.STATISTICS ) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java index d04f324abd..fdf7f23556 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java @@ -62,10 +62,11 @@ public abstract class Optimizer * @param pb parfor program block * @param plan complete plan of a top-level parfor * @param est cost estimator + * @param numRuns * @param ec execution context * @return true if plan changed, false otherwise */ - public abstract boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, OptTree plan, CostEstimator est, ExecutionContext ec); + public abstract boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, OptTree plan, CostEstimator est, int numRuns, ExecutionContext ec); public abstract PlanInputType getPlanInputType(); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java index c039f932cb..0665357ad1 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java @@ -75,7 +75,7 @@ public class OptimizerConstrained extends OptimizerRuleBased { * (no use 
of sb, direct change of pb). */ @Override - public boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, OptTree plan, CostEstimator est, ExecutionContext ec) + public boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, OptTree plan, CostEstimator est, int numRuns, ExecutionContext ec) { LOG.debug("--- "+getOptMode()+" OPTIMIZER -------"); _cost = est; @@ -179,7 +179,8 @@ public class OptimizerConstrained extends OptimizerRuleBased { super.rewriteSetInPlaceResultIndexing(pn, _cost, ec.getVariables(), inplaceResultVars, ec); //rewrite 17: checkpoint injection for parfor loop body - super.rewriteInjectSparkLoopCheckpointing( pn ); + if( numRuns <= 0 ) //only on first + super.rewriteInjectSparkLoopCheckpointing( pn ); //rewrite 18: repartition read-only inputs for zipmm super.rewriteInjectSparkRepartition( pn, ec.getVariables() ); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java index 13b846ea02..fb97c6a15c 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java @@ -193,7 +193,7 @@ public class OptimizerRuleBased extends Optimizer { * (no use of sb, direct change of pb). 
*/ @Override - public boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, OptTree plan, CostEstimator est, ExecutionContext ec) + public boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, OptTree plan, CostEstimator est, int numRuns, ExecutionContext ec) { LOG.debug("--- "+getOptMode()+" OPTIMIZER -------"); @@ -298,7 +298,8 @@ public class OptimizerRuleBased extends Optimizer { rewriteSetInPlaceResultIndexing(pn, _cost, ec.getVariables(), inplaceResultVars, ec); //rewrite 17: checkpoint injection for parfor loop body - rewriteInjectSparkLoopCheckpointing( pn ); + if( numRuns <= 0 ) //only on first + rewriteInjectSparkLoopCheckpointing( pn ); //rewrite 18: repartition read-only inputs for zipmm rewriteInjectSparkRepartition( pn, ec.getVariables() );
