[SYSTEMML-1716] Fix memory efficiency of Spark rmm matrix mult operations. This patch improves the memory efficiency and performance of our Spark rmm matrix multiplication operator:
(1) Shallow block replication: Since we only need modified keys, we now use shallow block replicas, which avoids (a potentially large number of) unnecessary block copies before the join. (2) Adjusted number of join partitions: The replication increases the data size of intermediates. Using the number of input partitions potentially causes very large partitions after the join, which creates memory pressure and limits the degree of parallelism. The join now uses a number of partitions derived from the estimated size of the replicated inputs (but at least the default parallelism). We now also parallelize the matrix multiply over IJK blocks instead of IK output blocks. (3) Hash function: In order to enable IJK parallelization, we now use a hash function over IJK instead of IK. Along with that, this patch also replaces the historically used hash function for triple indexes with our default hash function, because the old function was known to have issues with certain matrix shapes. Finally, there is additional potential to remove the shuffle for the final aggregation (which is not addressed here). Down the road, we should use two alternative rmm implementations: if the number of output blocks already saturates the degree of parallelism and is known to have limited memory requirements, we should parallelize over IK and provide a custom partitioner that can be carried through from the join to the final aggregation (note that the existing Hadoop partitioner for triple indexes is not used automatically). 
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f7a18fa7 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f7a18fa7 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f7a18fa7 Branch: refs/heads/master Commit: f7a18fa7b2606bda76aa5466b0b0d9a4d19f2565 Parents: a625c64 Author: Matthias Boehm <[email protected]> Authored: Sat Jun 17 18:07:48 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jun 17 18:07:48 2017 -0700 ---------------------------------------------------------------------- .../hops/codegen/template/PlanSelection.java | 2 +- .../instructions/spark/BinarySPInstruction.java | 5 ++- .../instructions/spark/RmmSPInstruction.java | 41 ++++++++++++++------ .../spark/utils/RDDAggregateUtils.java | 6 +-- .../instructions/spark/utils/SparkUtils.java | 6 +++ .../runtime/matrix/data/MatrixIndexes.java | 2 +- .../runtime/matrix/data/TripleIndexes.java | 2 +- .../runtime/util/LongLongDoubleHashMap.java | 2 +- .../sysml/runtime/util/UtilFunctions.java | 25 ++++++++++-- 9 files changed, 66 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/f7a18fa7/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java index 85126da..80ff725 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java @@ -109,7 +109,7 @@ public abstract class PlanSelection } @Override public int hashCode() { - return UtilFunctions.longlongHashCode( + return UtilFunctions.longHashCode( _hopID, (_type!=null)?_type.hashCode():0); } @Override 
http://git-wip-us.apache.org/repos/asf/systemml/blob/f7a18fa7/src/main/java/org/apache/sysml/runtime/instructions/spark/BinarySPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/BinarySPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/BinarySPInstruction.java index ac7ef4b..ef9da23 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/BinarySPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/BinarySPInstruction.java @@ -187,7 +187,7 @@ public abstract class BinarySPInstruction extends ComputationSPInstruction sec.addLineageRDD(output.getName(), rddVar); } - protected void updateBinaryMMOutputMatrixCharacteristics(SparkExecutionContext sec, boolean checkCommonDim) + protected MatrixCharacteristics updateBinaryMMOutputMatrixCharacteristics(SparkExecutionContext sec, boolean checkCommonDim) throws DMLRuntimeException { MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName()); @@ -203,7 +203,8 @@ public abstract class BinarySPInstruction extends ComputationSPInstruction else { mcOut.set(mc1.getRows(), mc2.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock()); } - } + } + return mcOut; } protected void updateBinaryAppendOutputMatrixCharacteristics(SparkExecutionContext sec, boolean cbind) http://git-wip-us.apache.org/repos/asf/systemml/blob/f7a18fa7/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java index 1fe025a..e1eb724 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java 
@@ -29,14 +29,17 @@ import org.apache.spark.api.java.function.PairFunction; import scala.Tuple2; +import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.functionobjects.Multiply; import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; +import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; @@ -83,26 +86,42 @@ public class RmmSPInstruction extends BinarySPInstruction MatrixCharacteristics mc2 = sec.getMatrixCharacteristics( input2.getName() ); JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( input1.getName() ); JavaPairRDD<MatrixIndexes,MatrixBlock> in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() ); + MatrixCharacteristics mcOut = updateBinaryMMOutputMatrixCharacteristics(sec, true); //execute Spark RMM instruction - //step 1: prepare join keys (w/ replication), i/j/k + //step 1: prepare join keys (w/ shallow replication), i/j/k JavaPairRDD<TripleIndexes,MatrixBlock> tmp1 = in1.flatMapToPair( - new RmmReplicateFunction(mc2.getCols(), mc2.getColsPerBlock(), true)); + new RmmReplicateFunction(mc2.getCols(), mc2.getColsPerBlock(), true)); JavaPairRDD<TripleIndexes,MatrixBlock> tmp2 = in2.flatMapToPair( - new RmmReplicateFunction(mc1.getRows(), mc1.getRowsPerBlock(), false)); + new RmmReplicateFunction(mc1.getRows(), 
mc1.getRowsPerBlock(), false)); //step 2: join prepared datasets, multiply, and aggregate - JavaPairRDD<MatrixIndexes,MatrixBlock> out = - tmp1.join( tmp2 ) //join by result block - .mapToPair( new RmmMultiplyFunction() ); //do matrix multiplication - out = RDDAggregateUtils.sumByKeyStable(out, false); //aggregation per result block + int numPartJoin = Math.max(getNumJoinPartitions(mc1, mc2), + SparkExecutionContext.getDefaultParallelism(true)); + int numPartOut = SparkUtils.getNumPreferredPartitions(mcOut); + JavaPairRDD<MatrixIndexes,MatrixBlock> out = tmp1 + .join( tmp2, numPartJoin ) //join by result block + .mapToPair( new RmmMultiplyFunction() ); //do matrix multiplication + out = RDDAggregateUtils.sumByKeyStable(out, //aggregation per result block + numPartOut, false); //put output block into symbol table (no lineage because single block) - updateBinaryMMOutputMatrixCharacteristics(sec, true); sec.setRDDHandleForVariable(output.getName(), out); sec.addLineageRDD(output.getName(), input1.getName()); sec.addLineageRDD(output.getName(), input2.getName()); } + + private static int getNumJoinPartitions(MatrixCharacteristics mc1, MatrixCharacteristics mc2) { + if( !mc1.dimsKnown() || !mc2.dimsKnown() ) + SparkExecutionContext.getDefaultParallelism(true); + //compute data size of replicated inputs + double hdfsBlockSize = InfrastructureAnalyzer.getHDFSBlockSize(); + double matrix1PSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(mc1) + * ((long) Math.ceil((double)mc2.getCols()/mc2.getColsPerBlock())); + double matrix2PSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(mc2) + * ((long) Math.ceil((double)mc1.getRows()/mc1.getRowsPerBlock())); + return (int) Math.max(Math.ceil((matrix1PSize+matrix2PSize)/hdfsBlockSize), 1); + } private static class RmmReplicateFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, TripleIndexes, MatrixBlock> { @@ -136,8 +155,7 @@ public class RmmSPInstruction extends BinarySPInstruction long k = 
ixIn.getColumnIndex(); for( long j=1; j<=numBlocks; j++ ) { TripleIndexes tmptix = new TripleIndexes(i, j, k); - MatrixBlock tmpblk = new MatrixBlock(blkIn); - ret.add( new Tuple2<TripleIndexes, MatrixBlock>(tmptix, tmpblk) ); + ret.add( new Tuple2<TripleIndexes, MatrixBlock>(tmptix, blkIn) ); } } else // RHS MATRIX @@ -147,8 +165,7 @@ public class RmmSPInstruction extends BinarySPInstruction long j = ixIn.getColumnIndex(); for( long i=1; i<=numBlocks; i++ ) { TripleIndexes tmptix = new TripleIndexes(i, j, k); - MatrixBlock tmpblk = new MatrixBlock(blkIn); - ret.add( new Tuple2<TripleIndexes, MatrixBlock>(tmptix, tmpblk) ); + ret.add( new Tuple2<TripleIndexes, MatrixBlock>(tmptix, blkIn) ); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/f7a18fa7/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java index 2759f7f..5ed5df8 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java @@ -80,14 +80,14 @@ public class RDDAggregateUtils public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, int numPartitions, boolean deepCopyCombiner) { - //stable sum of blocks per key, by passing correction blocks along with aggregates + //stable sum of blocks per key, by passing correction blocks along with aggregates JavaPairRDD<MatrixIndexes, CorrMatrixBlock> tmp = in.combineByKey( new CreateCorrBlockCombinerFunction(deepCopyCombiner), new MergeSumBlockValueFunction(), new MergeSumBlockCombinerFunction(), numPartitions ); - //strip-off correction blocks from - JavaPairRDD<MatrixIndexes, MatrixBlock> 
out = + //strip-off correction blocks from + JavaPairRDD<MatrixIndexes, MatrixBlock> out = tmp.mapValues( new ExtractMatrixBlock() ); //return the aggregate rdd http://git-wip-us.apache.org/repos/asf/systemml/blob/f7a18fa7/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java index 947c817..4e7866d 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java @@ -124,6 +124,12 @@ public class SparkUtils public static int getNumPreferredPartitions(MatrixCharacteristics mc, JavaPairRDD<?,?> in) { if( !mc.dimsKnown(true) && in != null ) return in.getNumPartitions(); + return getNumPreferredPartitions(mc); + } + + public static int getNumPreferredPartitions(MatrixCharacteristics mc) { + if( !mc.dimsKnown() ) + return SparkExecutionContext.getDefaultParallelism(true); double hdfsBlockSize = InfrastructureAnalyzer.getHDFSBlockSize(); double matrixPSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(mc); return (int) Math.max(Math.ceil(matrixPSize/hdfsBlockSize), 1); http://git-wip-us.apache.org/repos/asf/systemml/blob/f7a18fa7/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixIndexes.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixIndexes.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixIndexes.java index 7e68b61..3f52a5e 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixIndexes.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixIndexes.java @@ -100,7 +100,7 @@ public class MatrixIndexes implements 
WritableComparable<MatrixIndexes>, RawComp @Override public int hashCode() { - return UtilFunctions.longlongHashCode(_row, _col); + return UtilFunctions.longHashCode(_row, _col); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/f7a18fa7/src/main/java/org/apache/sysml/runtime/matrix/data/TripleIndexes.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/TripleIndexes.java b/src/main/java/org/apache/sysml/runtime/matrix/data/TripleIndexes.java index ac517be..cbcc9f8 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/TripleIndexes.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/TripleIndexes.java @@ -123,7 +123,7 @@ public class TripleIndexes implements WritableComparable<TripleIndexes>, Seriali @Override public int hashCode() { - return UtilFunctions.longHashCode((first<<32)+(second<<16)+third+UtilFunctions.ADD_PRIME1)%UtilFunctions.DIVIDE_PRIME; + return UtilFunctions.longHashCode(first, second, third); } public static class Comparator implements RawComparator<TripleIndexes> http://git-wip-us.apache.org/repos/asf/systemml/blob/f7a18fa7/src/main/java/org/apache/sysml/runtime/util/LongLongDoubleHashMap.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/LongLongDoubleHashMap.java b/src/main/java/org/apache/sysml/runtime/util/LongLongDoubleHashMap.java index f3b331a..d8c8011 100644 --- a/src/main/java/org/apache/sysml/runtime/util/LongLongDoubleHashMap.java +++ b/src/main/java/org/apache/sysml/runtime/util/LongLongDoubleHashMap.java @@ -112,7 +112,7 @@ public class LongLongDoubleHashMap } private static int hash(long key1, long key2) { - int h = UtilFunctions.longlongHashCode(key1, key2); + int h = UtilFunctions.longHashCode(key1, key2); // This function ensures that hashCodes that differ only by // constant multiples at each bit position have a bounded 
http://git-wip-us.apache.org/repos/asf/systemml/blob/f7a18fa7/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java index c15ace1..8a62476 100644 --- a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java +++ b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java @@ -43,8 +43,8 @@ public class UtilFunctions public static final long ADD_PRIME1 = 99991; public static final int DIVIDE_PRIME = 1405695061; - public static int longHashCode(long v) { - return (int)(v^(v>>>32)); + public static int longHashCode(long key1) { + return (int)(key1^(key1>>>32)); } /** @@ -55,13 +55,30 @@ public class UtilFunctions * @param key2 second long key * @return hash code */ - public static int longlongHashCode(long key1, long key2) { + public static int longHashCode(long key1, long key2) { //basic hash mixing of two longs hashes (similar to //Arrays.hashCode(long[]) but w/o array creation/copy) - int h = (int)(key1 ^ (key1 >>> 32)); + int h = 31 + (int)(key1 ^ (key1 >>> 32)); return h*31 + (int)(key2 ^ (key2 >>> 32)); } + /** + * Returns the hash code for a long-long-long triple. This is the default + * hash function for the keys of a distributed matrix in MR/Spark. + * + * @param key1 first long key + * @param key2 second long key + * @param key3 third long key + * @return hash code + */ + public static int longHashCode(long key1, long key2, long key3) { + //basic hash mixing of three longs hashes (similar to + //Arrays.hashCode(long[]) but w/o array creation/copy) + int h1 = 31 + (int)(key1 ^ (key1 >>> 32)); + int h2 = h1*31 + (int)(key2 ^ (key2 >>> 32)); + return h2*31 + (int)(key3 ^ (key3 >>> 32)); + } + public static int nextIntPow2( int in ) { int expon = (in==0) ? 
0 : 32-Integer.numberOfLeadingZeros(in-1); long pow2 = (long) Math.pow(2, expon);
