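[SYSTEMML-2281] Performance of Spark sumByKey incremental block aggregation

This patch improves the performance of the very common Spark sumByKey primitives, as used by many matrix multiplications and other operations with global aggregation. We now avoid the unnecessary creation of dense correction blocks: a correction block is only allocated when a non-empty block is actually aggregated, and, when deep copies are not required, an empty aggregate simply takes a shallow copy of the incoming block. This greatly reduces GC overhead in ultra-sparse scenarios.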
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3b359c39 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3b359c39 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3b359c39 Branch: refs/heads/master Commit: 3b359c39029e26cd188fd64f370d00eb102adcf8 Parents: 0bd08f2 Author: Matthias Boehm <[email protected]> Authored: Thu Apr 26 20:59:28 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Apr 27 00:03:15 2018 -0700 ---------------------------------------------------------------------- .../spark/utils/RDDAggregateUtils.java | 54 +++++++++++++++----- .../sysml/runtime/io/ReaderBinaryBlock.java | 2 +- .../sysml/runtime/matrix/data/CM_N_COVCell.java | 2 +- .../sysml/runtime/matrix/data/LibMatrixAgg.java | 25 +++++---- .../sysml/runtime/matrix/data/MatrixCell.java | 2 +- .../sysml/runtime/matrix/data/MatrixValue.java | 9 ++-- .../matrix/data/OperationsOnMatrixValues.java | 10 +++- 7 files changed, 69 insertions(+), 35 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/3b359c39/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java index 0101e26..23b6ad9 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java @@ -26,13 +26,16 @@ import org.apache.spark.api.java.function.Function2; import org.apache.sysml.lops.PartialAggregate.CorrectionLocationType; import org.apache.sysml.runtime.DMLRuntimeException; import 
org.apache.sysml.runtime.functionobjects.KahanPlus; +import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.KahanObject; +import org.apache.sysml.runtime.instructions.spark.AggregateUnarySPInstruction.RDDUAggFunction2; import org.apache.sysml.runtime.instructions.spark.data.CorrMatrixBlock; import org.apache.sysml.runtime.instructions.spark.data.RowMatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; import org.apache.sysml.runtime.matrix.operators.AggregateOperator; +import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator; /** * Collection of utility methods for aggregating binary block rdds. As a general @@ -82,25 +85,30 @@ public class RDDAggregateUtils { //stable sum of blocks per key, by passing correction blocks along with aggregates JavaPairRDD<MatrixIndexes, CorrMatrixBlock> tmp = - in.combineByKey( new CreateCorrBlockCombinerFunction(deepCopyCombiner), - new MergeSumBlockValueFunction(), - new MergeSumBlockCombinerFunction(), numPartitions ); + in.combineByKey( new CreateCorrBlockCombinerFunction(deepCopyCombiner), + new MergeSumBlockValueFunction(deepCopyCombiner), + new MergeSumBlockCombinerFunction(deepCopyCombiner), numPartitions ); //strip-off correction blocks from JavaPairRDD<MatrixIndexes, MatrixBlock> out = - tmp.mapValues( new ExtractMatrixBlock() ); + tmp.mapValues( new ExtractMatrixBlock() ); //return the aggregate rdd return out; } - public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable( JavaPairRDD<MatrixIndexes, Double> in ) + + public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable( JavaPairRDD<MatrixIndexes, Double> in ) { + return sumCellsByKeyStable(in, in.getNumPartitions()); + } + + public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable( JavaPairRDD<MatrixIndexes, Double> in, int 
numParts ) { //stable sum of blocks per key, by passing correction blocks along with aggregates - JavaPairRDD<MatrixIndexes, KahanObject> tmp = - in.combineByKey( new CreateCellCombinerFunction(), - new MergeSumCellValueFunction(), - new MergeSumCellCombinerFunction() ); + JavaPairRDD<MatrixIndexes, KahanObject> tmp = + in.combineByKey( new CreateCellCombinerFunction(), + new MergeSumCellValueFunction(), + new MergeSumCellCombinerFunction(), numParts); //strip-off correction blocks from JavaPairRDD<MatrixIndexes, Double> out = @@ -166,6 +174,12 @@ public class RDDAggregateUtils return out; } + public static double max(JavaPairRDD<MatrixIndexes, MatrixBlock> in) { + AggregateUnaryOperator auop = InstructionUtils.parseBasicAggregateUnaryOperator("uamax"); + MatrixBlock tmp = aggStable(in.map(new RDDUAggFunction2(auop, -1, -1)), auop.aggOp); + return tmp.quickGetValue(0, 0); + } + /** * Merges disjoint data of all blocks per key. * @@ -258,6 +272,12 @@ public class RDDAggregateUtils private AggregateOperator _op = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.NONE); + private final boolean _deep; + + public MergeSumBlockValueFunction(boolean deep) { + _deep = deep; + } + @Override public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1) throws Exception @@ -270,12 +290,12 @@ public class RDDAggregateUtils MatrixBlock corr = arg0.getCorrection(); //correction block allocation on demand - if( corr == null ) + if( corr == null && !arg1.isEmptyBlock(false) ) corr = new MatrixBlock(value.getNumRows(), value.getNumColumns(), false); //aggregate other input and maintain corrections //(existing value and corr are used in place) - OperationsOnMatrixValues.incrementalAggregation(value, corr, arg1, _op, false); + OperationsOnMatrixValues.incrementalAggregation(value, corr, arg1, _op, false, _deep); return arg0.set(value, corr); } } @@ -285,6 +305,11 @@ public class RDDAggregateUtils private static final long serialVersionUID = 
7664941774566119853L; private AggregateOperator _op = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.NONE); + private final boolean _deep; + + public MergeSumBlockCombinerFunction(boolean deep) { + _deep = deep; + } @Override public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) @@ -297,15 +322,16 @@ public class RDDAggregateUtils //correction block allocation on demand (but use second if exists) if( corr == null ) { - corr = (arg1.getCorrection()!=null)?arg1.getCorrection(): + corr = (arg1.getCorrection()!=null) ? arg1.getCorrection() : + value2.isEmptyBlock(false) || (!_deep && value1.isEmptyBlock(false)) ? null : new MatrixBlock(value1.getNumRows(), value1.getNumColumns(), false); } //aggregate other input and maintain corrections //(existing value and corr are used in place) - OperationsOnMatrixValues.incrementalAggregation(value1, corr, value2, _op, false); + OperationsOnMatrixValues.incrementalAggregation(value1, corr, value2, _op, false, _deep); return arg0.set(value1, corr); - } + } } private static class CreateBlockCombinerFunction implements Function<MatrixBlock, MatrixBlock> http://git-wip-us.apache.org/repos/asf/systemml/blob/3b359c39/src/main/java/org/apache/sysml/runtime/io/ReaderBinaryBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/io/ReaderBinaryBlock.java b/src/main/java/org/apache/sysml/runtime/io/ReaderBinaryBlock.java index 9461ca1..f7a5147 100644 --- a/src/main/java/org/apache/sysml/runtime/io/ReaderBinaryBlock.java +++ b/src/main/java/org/apache/sysml/runtime/io/ReaderBinaryBlock.java @@ -113,7 +113,7 @@ public class ReaderBinaryBlock extends MatrixReader //where ultra-sparse deserialization only reuses CSR blocks MatrixBlock value = new MatrixBlock(brlen, bclen, sparse); if( sparse ) { - value.allocateAndResetSparseRowsBlock(true, SparseBlock.Type.CSR); + value.allocateAndResetSparseBlock(true, 
SparseBlock.Type.CSR); value.getSparseBlock().allocate(0, brlen*bclen); } return value; http://git-wip-us.apache.org/repos/asf/systemml/blob/3b359c39/src/main/java/org/apache/sysml/runtime/matrix/data/CM_N_COVCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/CM_N_COVCell.java b/src/main/java/org/apache/sysml/runtime/matrix/data/CM_N_COVCell.java index a2b17ec..10956ad 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/CM_N_COVCell.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/CM_N_COVCell.java @@ -98,7 +98,7 @@ public class CM_N_COVCell extends MatrixValue implements WritableComparable @Override public void incrementalAggregate(AggregateOperator aggOp, - MatrixValue correction, MatrixValue newWithCorrection) { + MatrixValue correction, MatrixValue newWithCorrection, boolean deep) { throw new RuntimeException("operation not supported for CM_N_COVCell"); } http://git-wip-us.apache.org/repos/asf/systemml/blob/3b359c39/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java index 8de89c3..5dfddbd 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java @@ -122,12 +122,23 @@ public class LibMatrixAgg * @param in input matrix * @param aggVal current aggregate values (in/out) * @param aggCorr current aggregate correction (in/out) + * @param deep deep copy flag */ - public static void aggregateBinaryMatrix(MatrixBlock in, MatrixBlock aggVal, MatrixBlock aggCorr) { + public static void aggregateBinaryMatrix(MatrixBlock in, MatrixBlock aggVal, MatrixBlock aggCorr, boolean deep) { //Timing time = new Timing(true); //boolean 
saggVal = aggVal.sparse, saggCorr = aggCorr.sparse; //long naggVal = aggVal.nonZeros, naggCorr = aggCorr.nonZeros; + //common empty block handling + if( in.isEmptyBlock(false) ) { + return; + } + if( !deep && aggVal.isEmptyBlock(false) ) { + //shallow copy without correction allocation + aggVal.copyShallow(in); + return; + } + //ensure MCSR instead of CSR for update in-place if( aggVal.sparse && aggVal.isAllocated() && aggVal.getSparseBlock() instanceof SparseBlockCSR ) aggVal.sparseBlock = SparseBlockFactory.copySparseBlock(SparseBlock.Type.MCSR, aggVal.getSparseBlock(), true); @@ -977,9 +988,6 @@ public class LibMatrixAgg } private static void aggregateBinaryMatrixAllDense(MatrixBlock in, MatrixBlock aggVal, MatrixBlock aggCorr) { - if( in.denseBlock==null || in.isEmptyBlock(false) ) - return; - //allocate output arrays (if required) aggVal.allocateDenseBlock(); //should always stay in dense aggCorr.allocateDenseBlock(); //should always stay in dense @@ -1011,9 +1019,6 @@ public class LibMatrixAgg } private static void aggregateBinaryMatrixSparseDense(MatrixBlock in, MatrixBlock aggVal, MatrixBlock aggCorr) { - if( in.isEmptyBlock(false) ) - return; - //allocate output arrays (if required) aggVal.allocateDenseBlock(); //should always stay in dense aggCorr.allocateDenseBlock(); //should always stay in dense @@ -1055,9 +1060,6 @@ public class LibMatrixAgg } private static void aggregateBinaryMatrixSparseGeneric(MatrixBlock in, MatrixBlock aggVal, MatrixBlock aggCorr) { - if( in.isEmptyBlock(false) ) - return; - SparseBlock a = in.getSparseBlock(); KahanObject buffer1 = new KahanObject(0, 0); @@ -1095,9 +1097,6 @@ public class LibMatrixAgg } private static void aggregateBinaryMatrixDenseGeneric(MatrixBlock in, MatrixBlock aggVal, MatrixBlock aggCorr) { - if( in.denseBlock==null || in.isEmptyBlock(false) ) - return; - final int m = in.rlen; final int n = in.clen; 
http://git-wip-us.apache.org/repos/asf/systemml/blob/3b359c39/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixCell.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixCell.java index aa1a01d..d51fd74 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixCell.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixCell.java @@ -266,7 +266,7 @@ public class MatrixCell extends MatrixValue implements WritableComparable, Seria @Override public void incrementalAggregate(AggregateOperator aggOp, - MatrixValue correction, MatrixValue newWithCorrection) { + MatrixValue correction, MatrixValue newWithCorrection, boolean deep) { throw new DMLRuntimeException("MatrixCell.incrementalAggregate should never be called"); } http://git-wip-us.apache.org/repos/asf/systemml/blob/3b359c39/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixValue.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixValue.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixValue.java index 05f0634..82b09e0 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixValue.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixValue.java @@ -126,9 +126,12 @@ public abstract class MatrixValue implements WritableComparable int blockingFactorRow, int blockingFactorCol, MatrixIndexes indexesIn, boolean inCP); public abstract MatrixValue unaryOperations(UnaryOperator op, MatrixValue result); - - public abstract void incrementalAggregate(AggregateOperator aggOp, MatrixValue correction, - MatrixValue newWithCorrection); + + public void incrementalAggregate(AggregateOperator aggOp, MatrixValue correction, MatrixValue newWithCorrection) { + incrementalAggregate(aggOp, correction, 
newWithCorrection, true); + } + + public abstract void incrementalAggregate(AggregateOperator aggOp, MatrixValue correction, MatrixValue newWithCorrection, boolean deep); public abstract void incrementalAggregate(AggregateOperator aggOp, MatrixValue newWithCorrection); http://git-wip-us.apache.org/repos/asf/systemml/blob/3b359c39/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java index 0e77b8e..bc4e969 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java @@ -196,12 +196,18 @@ public class OperationsOnMatrixValues } public static void incrementalAggregation(MatrixValue valueAgg, MatrixValue correction, MatrixValue valueAdd, - AggregateOperator op, boolean imbededCorrection) + AggregateOperator op, boolean imbededCorrection) { + incrementalAggregation(valueAgg, correction, valueAdd, op, imbededCorrection, true); + } + + + public static void incrementalAggregation(MatrixValue valueAgg, MatrixValue correction, MatrixValue valueAdd, + AggregateOperator op, boolean imbededCorrection, boolean deep) { if(op.correctionExists) { if(!imbededCorrection || op.correctionLocation==CorrectionLocationType.NONE) - valueAgg.incrementalAggregate(op, correction, valueAdd); + valueAgg.incrementalAggregate(op, correction, valueAdd, deep); else valueAgg.incrementalAggregate(op, valueAdd); }
