[MINOR] Reduced recompilation overhead and various cleanups. This minor patch reduces the recompilation overhead through (1) better memoization across the size expressions of operations, and (2) the removal of unnecessary list copies in specific rewrites and operator cloning.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/fe5ed594 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/fe5ed594 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/fe5ed594 Branch: refs/heads/master Commit: fe5ed59474e14214a14ce0b5dbd5f1d162821a62 Parents: 3737542 Author: Matthias Boehm <[email protected]> Authored: Thu Jan 11 20:46:06 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Thu Jan 11 20:46:06 2018 -0800 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/Hop.java | 89 ++++++++------------ .../apache/sysml/hops/recompile/Recompiler.java | 59 ++++++------- .../RewriteSplitDagDataDependentOperators.java | 2 +- .../runtime/compress/ColGroupUncompressed.java | 5 +- .../ColumnGroupPartitionerBinPacking.java | 6 +- .../compress/cocode/PlanningCoCoder.java | 3 +- .../runtime/compress/utils/ConverterUtils.java | 18 +--- .../controlprogram/ParForProgramBlock.java | 3 +- .../sysml/runtime/util/UtilFunctions.java | 8 ++ 9 files changed, 85 insertions(+), 108 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/fe5ed594/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 948394b..0f11175 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -1648,88 +1648,73 @@ public abstract class Hop implements ParseInfo return ret; } - - public void refreshRowsParameterInformation( Hop input, LocalVariableMap vars ) - { - long size = computeSizeInformation(input, vars); - - //always set the computed size not just if known (positive) in order to allow - //recompile with unknowns to reset sizes 
(otherwise potential for incorrect results) - setDim1( size ); + + //always set the computed size not just if known (positive) in order to allow + //recompile with unknowns to reset sizes (otherwise potential for incorrect results) + + public void refreshRowsParameterInformation( Hop input, LocalVariableMap vars ) { + setDim1(computeSizeInformation(input, vars)); } - public void refreshColsParameterInformation( Hop input, LocalVariableMap vars ) - { - long size = computeSizeInformation(input, vars); - - //always set the computed size not just if known (positive) in order to allow - //recompile with unknowns to reset sizes (otherwise potential for incorrect results) - setDim2( size ); + public void refreshRowsParameterInformation( Hop input, LocalVariableMap vars, HashMap<Long,Long> memo ) { + setDim1(computeSizeInformation(input, vars, memo)); + } + + public void refreshColsParameterInformation( Hop input, LocalVariableMap vars ) { + setDim2(computeSizeInformation(input, vars)); + } + + public void refreshColsParameterInformation( Hop input, LocalVariableMap vars, HashMap<Long,Long> memo ) { + setDim2(computeSizeInformation(input, vars, memo)); } - public long computeSizeInformation( Hop input, LocalVariableMap vars ) + public long computeSizeInformation( Hop input, LocalVariableMap vars ) { + return computeSizeInformation(input, vars, new HashMap<Long,Long>()); + } + + public long computeSizeInformation( Hop input, LocalVariableMap vars, HashMap<Long,Long> memo ) { long ret = -1; - - try - { - long tmp = OptimizerUtils.rEvalSimpleLongExpression(input, new HashMap<Long,Long>(), vars); + try { + long tmp = OptimizerUtils.rEvalSimpleLongExpression(input, memo, vars); if( tmp!=Long.MAX_VALUE ) ret = tmp; } - catch(Exception ex) - { + catch(Exception ex) { LOG.error("Failed to compute size information.", ex); ret = -1; } - return ret; } - public double computeBoundsInformation( Hop input ) - { + public double computeBoundsInformation( Hop input ) { double ret = 
Double.MAX_VALUE; - - try - { + try { ret = OptimizerUtils.rEvalSimpleDoubleExpression(input, new HashMap<Long, Double>()); } - catch(Exception ex) - { + catch(Exception ex) { LOG.error("Failed to compute bounds information.", ex); ret = Double.MAX_VALUE; } - return ret; } - /** - * Computes bound information for sequence if possible, otherwise returns - * Double.MAX_VALUE - * - * @param input high-level operator - * @param vars local variable map - * @return bounds information - */ - public double computeBoundsInformation( Hop input, LocalVariableMap vars ) - { + public double computeBoundsInformation( Hop input, LocalVariableMap vars ) { + return computeBoundsInformation(input, vars, new HashMap<Long, Double>()); + } + + public double computeBoundsInformation( Hop input, LocalVariableMap vars, HashMap<Long, Double> memo ) { double ret = Double.MAX_VALUE; - - try - { - ret = OptimizerUtils.rEvalSimpleDoubleExpression(input, new HashMap<Long, Double>(), vars); - + try { + ret = OptimizerUtils.rEvalSimpleDoubleExpression(input, memo, vars); } - catch(Exception ex) - { + catch(Exception ex) { LOG.error("Failed to compute bounds information.", ex); ret = Double.MAX_VALUE; } - return ret; } - /** * Compute worst case estimate for size expression based on worst-case * statistics of inputs. 
Limited set of supported operations in comparison @@ -1860,8 +1845,8 @@ public abstract class Hop implements ParseInfo _updateType = that._updateType; //no copy of lops (regenerated) - _parent = new ArrayList<>(); - _input = new ArrayList<>(); + _parent = new ArrayList<>(_parent.size()); + _input = new ArrayList<>(_input.size()); _lops = null; _etype = that._etype; http://git-wip-us.apache.org/repos/asf/systemml/blob/fe5ed594/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java index 360e261..8751daa 100644 --- a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java +++ b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java @@ -740,31 +740,22 @@ public class Recompiler return ret; } - private static Hop rDeepCopyHopsDag( Hop hops, HashMap<Long,Hop> memo ) + private static Hop rDeepCopyHopsDag( Hop hop, HashMap<Long,Hop> memo ) throws CloneNotSupportedException { - Hop ret = memo.get(hops.getHopID()); + Hop ret = memo.get(hop.getHopID()); //create clone if required - if( ret == null ) - { - ret = (Hop) hops.clone(); - ArrayList<Hop> tmp = new ArrayList<>(); + if( ret == null ) { + ret = (Hop) hop.clone(); - //create new childs - for( Hop in : hops.getInput() ) - { - Hop newIn = rDeepCopyHopsDag(in, memo); - tmp.add(newIn); - } - //modify references of childs - for( Hop in : tmp ) - { - ret.getInput().add(in); - in.getParent().add(ret); + //create new childs and modify references + for( Hop in : hop.getInput() ) { + Hop tmp = rDeepCopyHopsDag(in, memo); + ret.getInput().add(tmp); + tmp.getParent().add(ret); } - - memo.put(hops.getHopID(), ret); + memo.put(hop.getHopID(), ret); } return ret; @@ -1548,8 +1539,9 @@ public class Recompiler int ix1 = params.get(DataExpression.RAND_ROWS); int ix2 = params.get(DataExpression.RAND_COLS); 
//update rows/cols by evaluating simple expression of literals, nrow, ncol, scalars, binaryops - d.refreshRowsParameterInformation(d.getInput().get(ix1), vars); - d.refreshColsParameterInformation(d.getInput().get(ix2), vars); + HashMap<Long, Long> memo = new HashMap<>(); + d.refreshRowsParameterInformation(d.getInput().get(ix1), vars, memo); + d.refreshColsParameterInformation(d.getInput().get(ix2), vars, memo); updatedSizeExpr = initUnknown & d.dimsKnown(); } else if ( d.getOp() == DataGenMethod.SEQ ) @@ -1558,9 +1550,10 @@ public class Recompiler int ix1 = params.get(Statement.SEQ_FROM); int ix2 = params.get(Statement.SEQ_TO); int ix3 = params.get(Statement.SEQ_INCR); - double from = d.computeBoundsInformation(d.getInput().get(ix1), vars); - double to = d.computeBoundsInformation(d.getInput().get(ix2), vars); - double incr = d.computeBoundsInformation(d.getInput().get(ix3), vars); + HashMap<Long, Double> memo = new HashMap<>(); + double from = d.computeBoundsInformation(d.getInput().get(ix1), vars, memo); + double to = d.computeBoundsInformation(d.getInput().get(ix2), vars, memo); + double incr = d.computeBoundsInformation(d.getInput().get(ix3), vars, memo); //special case increment if ( from!=Double.MAX_VALUE && to!=Double.MAX_VALUE ) { @@ -1584,8 +1577,9 @@ public class Recompiler { ReorgOp d = (ReorgOp) hop; boolean initUnknown = !d.dimsKnown(); - d.refreshRowsParameterInformation(d.getInput().get(1), vars); - d.refreshColsParameterInformation(d.getInput().get(2), vars); + HashMap<Long, Long> memo = new HashMap<>(); + d.refreshRowsParameterInformation(d.getInput().get(1), vars, memo); + d.refreshColsParameterInformation(d.getInput().get(2), vars, memo); updatedSizeExpr = initUnknown & d.dimsKnown(); } //update size expression for indexing according to symbol table entries @@ -1597,10 +1591,11 @@ public class Recompiler Hop input4 = iop.getInput().get(3); //inpColL Hop input5 = iop.getInput().get(4); //inpColU boolean initUnknown = !iop.dimsKnown(); - double 
rl = iop.computeBoundsInformation(input2, vars); - double ru = iop.computeBoundsInformation(input3, vars); - double cl = iop.computeBoundsInformation(input4, vars); - double cu = iop.computeBoundsInformation(input5, vars); + HashMap<Long, Double> memo = new HashMap<>(); + double rl = iop.computeBoundsInformation(input2, vars, memo); + double ru = iop.computeBoundsInformation(input3, vars, memo); + double cl = iop.computeBoundsInformation(input4, vars, memo); + double cu = iop.computeBoundsInformation(input5, vars, memo); if( rl!=Double.MAX_VALUE && ru!=Double.MAX_VALUE ) iop.setDim1( (long)(ru-rl+1) ); if( cl!=Double.MAX_VALUE && cu!=Double.MAX_VALUE ) @@ -1727,14 +1722,14 @@ public class Recompiler { ret = false; break; - } + } } //default case (known dimensions) else { long nnz = mo.getNnz(); double sp = OptimizerUtils.getSparsity(rows, cols, nnz); - double mem = MatrixBlock.estimateSizeInMemory(rows, cols, sp); + double mem = MatrixBlock.estimateSizeInMemory(rows, cols, sp); if( !OptimizerUtils.isValidCPDimensions(rows, cols) || !OptimizerUtils.isValidCPMatrixSize(rows, cols, sp) || mem >= OptimizerUtils.getLocalMemBudget() ) http://git-wip-us.apache.org/repos/asf/systemml/blob/fe5ed594/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java index 6fbd953..a55ea41 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java @@ -77,7 +77,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite //DAG splits not required for forced single node if( DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE || 
!HopRewriteUtils.isLastLevelStatementBlock(sb) ) - return new ArrayList<>(Arrays.asList(sb)); + return Arrays.asList(sb); ArrayList<StatementBlock> ret = new ArrayList<>(); http://git-wip-us.apache.org/repos/asf/systemml/blob/fe5ed594/src/main/java/org/apache/sysml/runtime/compress/ColGroupUncompressed.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/compress/ColGroupUncompressed.java b/src/main/java/org/apache/sysml/runtime/compress/ColGroupUncompressed.java index 1032d89..2cc5b4b 100644 --- a/src/main/java/org/apache/sysml/runtime/compress/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysml/runtime/compress/ColGroupUncompressed.java @@ -23,7 +23,6 @@ package org.apache.sysml.runtime.compress; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; @@ -128,7 +127,7 @@ public class ColGroupUncompressed extends ColGroup * compressed columns to subsume. Must contain at least one * element. 
*/ - public ColGroupUncompressed(ArrayList<ColGroup> groupsToDecompress) + public ColGroupUncompressed(List<ColGroup> groupsToDecompress) { super(mergeColIndices(groupsToDecompress), groupsToDecompress.get(0)._numRows); @@ -186,7 +185,7 @@ public class ColGroupUncompressed extends ColGroup * UncompressedColGroup * @return a merged set of column indices across all those groups */ - private static int[] mergeColIndices(ArrayList<ColGroup> groupsToDecompress) + private static int[] mergeColIndices(List<ColGroup> groupsToDecompress) { // Pass 1: Determine number of columns int sz = 0; http://git-wip-us.apache.org/repos/asf/systemml/blob/fe5ed594/src/main/java/org/apache/sysml/runtime/compress/cocode/ColumnGroupPartitionerBinPacking.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/compress/cocode/ColumnGroupPartitionerBinPacking.java b/src/main/java/org/apache/sysml/runtime/compress/cocode/ColumnGroupPartitionerBinPacking.java index 908da20..d365e71 100644 --- a/src/main/java/org/apache/sysml/runtime/compress/cocode/ColumnGroupPartitionerBinPacking.java +++ b/src/main/java/org/apache/sysml/runtime/compress/cocode/ColumnGroupPartitionerBinPacking.java @@ -87,12 +87,12 @@ public class ColumnGroupPartitionerBinPacking extends ColumnGroupPartitioner assigned = true; break; } } - + //create new bin at end of list if( !assigned ) { - bins.add(new ArrayList<>(Arrays.asList(items[i]))); + bins.add(Arrays.asList(items[i])); binWeights.add(BIN_CAPACITY-itemWeights[i]); - } + } } return bins; http://git-wip-us.apache.org/repos/asf/systemml/blob/fe5ed594/src/main/java/org/apache/sysml/runtime/compress/cocode/PlanningCoCoder.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/compress/cocode/PlanningCoCoder.java b/src/main/java/org/apache/sysml/runtime/compress/cocode/PlanningCoCoder.java index 298571d..80a5504 100644 --- 
a/src/main/java/org/apache/sysml/runtime/compress/cocode/PlanningCoCoder.java +++ b/src/main/java/org/apache/sysml/runtime/compress/cocode/PlanningCoCoder.java @@ -135,8 +135,7 @@ public class PlanningCoCoder if( LOG.isTraceEnabled() ) LOG.trace("Cocoding: process "+singletonGroups.length); - List<PlanningCoCodingGroup> workset = - new ArrayList<>(Arrays.asList(singletonGroups)); + List<PlanningCoCodingGroup> workset = Arrays.asList(singletonGroups); //establish memo table for extracted column groups PlanningMemoTable memo = new PlanningMemoTable(); http://git-wip-us.apache.org/repos/asf/systemml/blob/fe5ed594/src/main/java/org/apache/sysml/runtime/compress/utils/ConverterUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/compress/utils/ConverterUtils.java b/src/main/java/org/apache/sysml/runtime/compress/utils/ConverterUtils.java index bf15d37..2873275 100644 --- a/src/main/java/org/apache/sysml/runtime/compress/utils/ConverterUtils.java +++ b/src/main/java/org/apache/sysml/runtime/compress/utils/ConverterUtils.java @@ -19,7 +19,6 @@ package org.apache.sysml.runtime.compress.utils; -import java.util.ArrayList; import java.util.Arrays; import org.apache.sysml.runtime.compress.ColGroup; @@ -68,18 +67,9 @@ public class ConverterUtils return DataConverter.convertToDoubleVector(vector, false); } - public static MatrixBlock getUncompressedColBlock( ColGroup group ) - { - MatrixBlock ret = null; - if( group instanceof ColGroupUncompressed ) { - ret = ((ColGroupUncompressed) group).getData(); - } - else { - ArrayList<ColGroup> tmpGroup = new ArrayList<>(Arrays.asList(group)); - ColGroupUncompressed decompressedCols = new ColGroupUncompressed(tmpGroup); - ret = decompressedCols.getData(); - } - - return ret; + public static MatrixBlock getUncompressedColBlock( ColGroup group ) { + return (group instanceof ColGroupUncompressed) ? 
+ ((ColGroupUncompressed) group).getData() : + new ColGroupUncompressed(Arrays.asList(group)).getData(); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/fe5ed594/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java index a0c3ef1..291fb9e 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java @@ -29,6 +29,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.stream.IntStream; import org.apache.hadoop.fs.FileSystem; @@ -1301,7 +1302,7 @@ public class ParForProgramBlock extends ForProgramBlock throws CacheException { ParForStatementBlock sb = (ParForStatementBlock)getStatementBlock(); - HashSet<String> blacklist = new HashSet<>(Arrays.asList(blacklistNames)); + Set<String> blacklist = UtilFunctions.asSet(blacklistNames); if( LIVEVAR_AWARE_EXPORT && sb != null) { http://git-wip-us.apache.org/repos/asf/systemml/blob/fe5ed594/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java index b876300..200b2eb 100644 --- a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java +++ b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java @@ -606,4 +606,12 @@ public class UtilFunctions ret.add(element); return ret; } + + @SafeVarargs + public static <T> Set<T> asSet(T... 
inputs) { + Set<T> ret = new HashSet<>(); + for( T element : inputs ) + ret.add(element); + return ret; + } }
