Repository: incubator-systemml Updated Branches: refs/heads/master d127dfa2d -> 30f72e83f
[SYSTEMML-1315] Robustness spark mapmm (handle large output partitions) Our spark mapmm (broadcast-based matrix multiply) instruction already repartitions the input RDD for shapes that potentially create large outputs in order to control the size of output partitions and with that avoid failures due to Spark's 2GB limitation per partition. However, with our perftest-stratstats script we encountered special cases of dense-sparse matrix multiplications, where this was not enough and we ran unnecessarily into OOMs. In detail, the sparse matrix was the broadcast (because it was smaller) but the dense input had a small number of blocks which limited the maximum number of partitions. This patch extends the mapmm instruction to consider these cases and potentially flip broadcast and rdd at runtime if this is beneficial for controlling the output sizes. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/073b7024 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/073b7024 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/073b7024 Branch: refs/heads/master Commit: 073b70241547a840518a785054cd319711d2d2a3 Parents: d127dfa Author: Matthias Boehm <mboe...@gmail.com> Authored: Fri Mar 17 17:38:50 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Sat Mar 18 14:34:39 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/lops/MapMult.java | 12 +- .../mr/AggregateBinaryInstruction.java | 8 +- .../instructions/spark/MapmmSPInstruction.java | 113 +++++++++++-------- 3 files changed, 80 insertions(+), 53 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/073b7024/src/main/java/org/apache/sysml/lops/MapMult.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/lops/MapMult.java b/src/main/java/org/apache/sysml/lops/MapMult.java index 1cb8e46..6597136 100644 --- a/src/main/java/org/apache/sysml/lops/MapMult.java +++ b/src/main/java/org/apache/sysml/lops/MapMult.java @@ -38,9 +38,19 @@ public class MapMult extends Lop LEFT, LEFT_PART; - public boolean isRightCache(){ + public boolean isRight() { return (this == RIGHT || this == RIGHT_PART); } + + public CacheType getFlipped() { + switch( this ) { + case RIGHT: return LEFT; + case RIGHT_PART: return LEFT_PART; + case LEFT: return RIGHT; + case LEFT_PART: return RIGHT_PART; + default: return null; + } + } } private CacheType _cacheType = null; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/073b7024/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java index 63f72bd..e9a023a 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java @@ -119,7 +119,7 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen @Override //IDistributedCacheConsumer public boolean isDistCacheOnlyIndex( String inst, byte index ) { - return _cacheType.isRightCache() ? + return _cacheType.isRight() ? (index==input2 && index!=input1) : (index==input1 && index!=input2); } @@ -127,7 +127,7 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen @Override //IDistributedCacheConsumer public void addDistCacheIndex( String inst, ArrayList<Byte> indexes ) { - indexes.add( _cacheType.isRightCache() ? input2 : input1 ); + indexes.add( _cacheType.isRight() ? 
input2 : input1 ); } @Override @@ -142,7 +142,7 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen if ( _opcode.equals(MapMult.OPCODE) ) { //check empty inputs (data for different instructions) - if( _cacheType.isRightCache() ? in1==null : in2==null ) + if( _cacheType.isRight() ? in1==null : in2==null ) return; // one of the input is from distributed cache. @@ -190,7 +190,7 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen { boolean removeOutput = true; - if( _cacheType.isRightCache() ) + if( _cacheType.isRight() ) { DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(input2); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/073b7024/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java index 5baffb0..c1fdea6 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java @@ -55,19 +55,14 @@ import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysml.runtime.matrix.operators.AggregateOperator; import org.apache.sysml.runtime.matrix.operators.Operator; -/** - * TODO: we need to reason about multiple broadcast variables for chains of mapmults (sum of operations until cleanup) - * - */ public class MapmmSPInstruction extends BinarySPInstruction -{ - +{ private CacheType _type = null; private boolean _outputEmpty = true; private SparkAggType _aggtype; public MapmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, CacheType type, - boolean outputEmpty, SparkAggType aggtype, String opcode, String istr ) + boolean 
outputEmpty, SparkAggType aggtype, String opcode, String istr ) { super(op, in1, in2, out, opcode, istr); _sptype = SPINSTRUCTION_TYPE.MAPMM; @@ -83,23 +78,19 @@ public class MapmmSPInstruction extends BinarySPInstruction String parts[] = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = parts[0]; - if ( opcode.equalsIgnoreCase(MapMult.OPCODE)) - { - CPOperand in1 = new CPOperand(parts[1]); - CPOperand in2 = new CPOperand(parts[2]); - CPOperand out = new CPOperand(parts[3]); - CacheType type = CacheType.valueOf(parts[4]); - boolean outputEmpty = Boolean.parseBoolean(parts[5]); - SparkAggType aggtype = SparkAggType.valueOf(parts[6]); - - AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); - AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); - return new MapmmSPInstruction(aggbin, in1, in2, out, type, outputEmpty, aggtype, opcode, str); - } - else { + if(!opcode.equalsIgnoreCase(MapMult.OPCODE)) throw new DMLRuntimeException("MapmmSPInstruction.parseInstruction():: Unknown opcode " + opcode); - } + + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand out = new CPOperand(parts[3]); + CacheType type = CacheType.valueOf(parts[4]); + boolean outputEmpty = Boolean.parseBoolean(parts[5]); + SparkAggType aggtype = SparkAggType.valueOf(parts[6]); + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); + return new MapmmSPInstruction(aggbin, in1, in2, out, type, outputEmpty, aggtype, opcode, str); } @Override @@ -108,14 +99,37 @@ public class MapmmSPInstruction extends BinarySPInstruction { SparkExecutionContext sec = (SparkExecutionContext)ec; - String rddVar = (_type==CacheType.LEFT) ? input2.getName() : input1.getName(); - String bcastVar = (_type==CacheType.LEFT) ? 
input1.getName() : input2.getName(); + CacheType type = _type; + String rddVar = type.isRight() ? input1.getName() : input2.getName(); + String bcastVar = type.isRight() ? input2.getName() : input1.getName(); MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(rddVar); MatrixCharacteristics mcBc = sec.getMatrixCharacteristics(bcastVar); + //get input rdd + JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar); + + //investigate if a repartitioning - including a potential flip of broadcast and rdd + //inputs - is required to ensure moderately sized output partitions (2GB limitation) + if( requiresFlatMapFunction(type, mcBc) && + requiresRepartitioning(type, mcRdd, mcBc, in1.getNumPartitions()) ) + { + int numParts = getNumRepartitioning(type, mcRdd, mcBc); + int numParts2 = getNumRepartitioning(type.getFlipped(), mcBc, mcRdd); + if( numParts2 > numParts ) { //flip required + type = type.getFlipped(); + rddVar = type.isRight() ? input1.getName() : input2.getName(); + bcastVar = type.isRight() ? 
input2.getName() : input1.getName(); + mcRdd = sec.getMatrixCharacteristics(rddVar); + mcBc = sec.getMatrixCharacteristics(bcastVar); + in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar); + LOG.warn("Mapmm: Switching rdd ('"+bcastVar+"') and broadcast ('"+rddVar+"') inputs " + + "for repartitioning because this allows better control of output partition " + + "sizes ("+numParts+" < "+numParts2+")."); + } + } + //get inputs - JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( rddVar ); - PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable( bcastVar ); + PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(bcastVar); //empty input block filter if( !_outputEmpty ) @@ -124,7 +138,7 @@ public class MapmmSPInstruction extends BinarySPInstruction //execute mapmm and aggregation if necessary and put output into symbol table if( _aggtype == SparkAggType.SINGLE_BLOCK ) { - JavaRDD<MatrixBlock> out = in1.map(new RDDMapMMFunction2(_type, in2)); + JavaRDD<MatrixBlock> out = in1.map(new RDDMapMMFunction2(type, in2)); MatrixBlock out2 = RDDAggregateUtils.sumStable(out); //put output block into symbol table (no lineage because single block) @@ -134,15 +148,19 @@ public class MapmmSPInstruction extends BinarySPInstruction else //MULTI_BLOCK or NONE { JavaPairRDD<MatrixIndexes,MatrixBlock> out = null; - if( requiresFlatMapFunction(_type, mcBc) ) { - if( requiresRepartitioning(_type, mcRdd, mcBc, in1.getNumPartitions()) ) - in1 = in1.repartition(getNumRepartitioning(_type, mcRdd, mcBc, in1.getNumPartitions())); - out = in1.flatMapToPair( new RDDFlatMapMMFunction(_type, in2) ); + if( requiresFlatMapFunction(type, mcBc) ) { + if( requiresRepartitioning(type, mcRdd, mcBc, in1.getNumPartitions()) ) { + int numParts = getNumRepartitioning(type, mcRdd, mcBc); + LOG.warn("Mapmm: Repartition input rdd '"+rddVar+"' from "+in1.getNumPartitions()+" to " + +numParts+" partitions to satisfy size restrictions of output 
partitions."); + in1 = in1.repartition(numParts); + } + out = in1.flatMapToPair( new RDDFlatMapMMFunction(type, in2) ); } - else if( preservesPartitioning(mcRdd, _type) ) - out = in1.mapPartitionsToPair(new RDDMapMMPartitionFunction(_type, in2), true); + else if( preservesPartitioning(mcRdd, type) ) + out = in1.mapPartitionsToPair(new RDDMapMMPartitionFunction(type, in2), true); else - out = in1.mapToPair( new RDDMapMMFunction(_type, in2) ); + out = in1.mapToPair( new RDDMapMMFunction(type, in2) ); //empty output block filter if( !_outputEmpty ) @@ -216,10 +234,9 @@ public class MapmmSPInstruction extends BinarySPInstruction * @param type cache type * @param mcRdd rdd matrix characteristics * @param mcBc ? - * @param numPartitions number of partitions * @return number of target partitions for repartitioning */ - private static int getNumRepartitioning( CacheType type, MatrixCharacteristics mcRdd, MatrixCharacteristics mcBc, int numPartitions ) { + private static int getNumRepartitioning( CacheType type, MatrixCharacteristics mcRdd, MatrixCharacteristics mcBc ) { boolean isLeft = (type == CacheType.LEFT); long sizeOutput = (OptimizerUtils.estimatePartitionedSizeExactSparsity(isLeft?mcBc.getRows():mcRdd.getRows(), isLeft?mcRdd.getCols():mcBc.getCols(), isLeft?mcBc.getRowsPerBlock():mcRdd.getRowsPerBlock(), @@ -232,9 +249,9 @@ public class MapmmSPInstruction extends BinarySPInstruction { private static final long serialVersionUID = 8197406787010296291L; - private CacheType _type = null; - private AggregateBinaryOperator _op = null; - private PartitionedBroadcast<MatrixBlock> _pbc = null; + private final CacheType _type; + private final AggregateBinaryOperator _op; + private final PartitionedBroadcast<MatrixBlock> _pbc; public RDDMapMMFunction( CacheType type, PartitionedBroadcast<MatrixBlock> binput ) { @@ -287,9 +304,9 @@ public class MapmmSPInstruction extends BinarySPInstruction { private static final long serialVersionUID = -2753453898072910182L; - private 
CacheType _type = null; - private AggregateBinaryOperator _op = null; - private PartitionedBroadcast<MatrixBlock> _pbc = null; + private final CacheType _type; + private final AggregateBinaryOperator _op; + private final PartitionedBroadcast<MatrixBlock> _pbc; public RDDMapMMFunction2( CacheType type, PartitionedBroadcast<MatrixBlock> binput ) { @@ -333,9 +350,9 @@ public class MapmmSPInstruction extends BinarySPInstruction { private static final long serialVersionUID = 1886318890063064287L; - private CacheType _type = null; - private AggregateBinaryOperator _op = null; - private PartitionedBroadcast<MatrixBlock> _pbc = null; + private final CacheType _type; + private final AggregateBinaryOperator _op; + private final PartitionedBroadcast<MatrixBlock> _pbc; public RDDMapMMPartitionFunction( CacheType type, PartitionedBroadcast<MatrixBlock> binput ) { @@ -399,9 +416,9 @@ public class MapmmSPInstruction extends BinarySPInstruction { private static final long serialVersionUID = -6076256569118957281L; - private CacheType _type = null; - private AggregateBinaryOperator _op = null; - private PartitionedBroadcast<MatrixBlock> _pbc = null; + private final CacheType _type; + private final AggregateBinaryOperator _op; + private final PartitionedBroadcast<MatrixBlock> _pbc; public RDDFlatMapMMFunction( CacheType type, PartitionedBroadcast<MatrixBlock> binput ) {