Repository: incubator-systemml
Updated Branches:
  refs/heads/master d127dfa2d -> 30f72e83f


[SYSTEMML-1315] Robustness spark mapmm (handle large output partitions)

Our Spark mapmm (broadcast-based matrix multiply) instruction already
repartitions the input RDD for shapes that potentially create large
outputs, in order to control the size of output partitions and thereby
avoid failures due to Spark's 2GB limit per partition. However, with
our perftest-stratstats script we encountered special cases of
dense-sparse matrix multiplications where this was not enough and we
ran unnecessarily into OOMs. Specifically, the sparse matrix was the
broadcast input (because it was smaller), but the dense input had a
small number of blocks, which limited the maximum number of partitions.

This patch extends the mapmm instruction to consider these cases and
potentially flip the broadcast and RDD inputs at runtime if doing so is
beneficial for controlling the output partition sizes.
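
As a rough illustration of this decision only (hypothetical sizes and
block counts, not code from this commit; the actual logic uses
OptimizerUtils.estimatePartitionedSizeExactSparsity and the helper
methods shown in the diff below):

    public class MapmmFlipSketch {
        // Spark's hard limit on the size of a single partition.
        private static final long MAX_PARTITION_SIZE = 2L * 1024 * 1024 * 1024; // 2GB

        // Partitions needed to keep each output partition below 2GB, capped by the
        // number of blocks in the RDD input (each input block lands in exactly one
        // partition, so more partitions than blocks would not spread the output further).
        static int numRepartitioning(long estOutputSize, long numRddBlocks) {
            long parts = (estOutputSize + MAX_PARTITION_SIZE - 1) / MAX_PARTITION_SIZE;
            return (int) Math.min(parts, numRddBlocks);
        }

        public static void main(String[] args) {
            long estOutputSize = 20L * 1024 * 1024 * 1024; // ~20GB estimated output
            long denseRddBlocks = 4;    // dense RDD input: few blocks
            long sparseBcBlocks = 1000; // sparse broadcast input: many blocks

            int partsCurrent = numRepartitioning(estOutputSize, denseRddBlocks);
            int partsFlipped = numRepartitioning(estOutputSize, sparseBcBlocks);

            // Flip broadcast and RDD inputs if that allows more (and thus smaller)
            // output partitions.
            boolean flip = partsFlipped > partsCurrent;
            System.out.println("repartition to " + (flip ? partsFlipped : partsCurrent)
                    + " partitions, flip inputs: " + flip);
        }
    }

In this example, flipping raises the achievable partition count from 4
to 10, keeping each output partition under the 2GB limit.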


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/073b7024
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/073b7024
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/073b7024

Branch: refs/heads/master
Commit: 073b70241547a840518a785054cd319711d2d2a3
Parents: d127dfa
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri Mar 17 17:38:50 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sat Mar 18 14:34:39 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/lops/MapMult.java     |  12 +-
 .../mr/AggregateBinaryInstruction.java          |   8 +-
 .../instructions/spark/MapmmSPInstruction.java  | 113 +++++++++++--------
 3 files changed, 80 insertions(+), 53 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/073b7024/src/main/java/org/apache/sysml/lops/MapMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/MapMult.java b/src/main/java/org/apache/sysml/lops/MapMult.java
index 1cb8e46..6597136 100644
--- a/src/main/java/org/apache/sysml/lops/MapMult.java
+++ b/src/main/java/org/apache/sysml/lops/MapMult.java
@@ -38,9 +38,19 @@ public class MapMult extends Lop
                LEFT,
                LEFT_PART;
                
-               public boolean isRightCache(){
+               public boolean isRight() {
                        return (this == RIGHT || this == RIGHT_PART);
                }
+               
+               public CacheType getFlipped() {
+                       switch( this ) {
+                               case RIGHT: return LEFT;
+                               case RIGHT_PART: return LEFT_PART;
+                               case LEFT: return RIGHT;
+                               case LEFT_PART: return RIGHT_PART;
+                               default: return null;
+                       }
+               }
        }
        
        private CacheType _cacheType = null;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/073b7024/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java
index 63f72bd..e9a023a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java
@@ -119,7 +119,7 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen
        @Override //IDistributedCacheConsumer
        public boolean isDistCacheOnlyIndex( String inst, byte index )
        {
-               return _cacheType.isRightCache() ? 
+               return _cacheType.isRight() ? 
                                (index==input2 && index!=input1) : 
                                (index==input1 && index!=input2);
        }
@@ -127,7 +127,7 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen
        @Override //IDistributedCacheConsumer
        public void addDistCacheIndex( String inst, ArrayList<Byte> indexes )
        {
-               indexes.add( _cacheType.isRightCache() ? input2 : input1 );
+               indexes.add( _cacheType.isRight() ? input2 : input1 );
        }
        
        @Override
@@ -142,7 +142,7 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen
                if ( _opcode.equals(MapMult.OPCODE) ) 
                {
                        //check empty inputs (data for different instructions)
-                       if( _cacheType.isRightCache() ? in1==null : in2==null )
+                       if( _cacheType.isRight() ? in1==null : in2==null )
                                return;
                        
                        // one of the input is from distributed cache.
@@ -190,7 +190,7 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen
        {
                boolean removeOutput = true;
                
-               if( _cacheType.isRightCache() )
+               if( _cacheType.isRight() )
                {
                        DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(input2);
                        

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/073b7024/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java
index 5baffb0..c1fdea6 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java
@@ -55,19 +55,14 @@ import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysml.runtime.matrix.operators.Operator;
 
-/**
- * TODO: we need to reason about multiple broadcast variables for chains of mapmults (sum of operations until cleanup) 
- * 
- */
 public class MapmmSPInstruction extends BinarySPInstruction 
-{
-       
+{      
        private CacheType _type = null;
        private boolean _outputEmpty = true;
        private SparkAggType _aggtype;
        
        public MapmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, CacheType type, 
-                                           boolean outputEmpty, SparkAggType aggtype, String opcode, String istr )
+                       boolean outputEmpty, SparkAggType aggtype, String opcode, String istr )
        {
                super(op, in1, in2, out, opcode, istr);
                _sptype = SPINSTRUCTION_TYPE.MAPMM;
@@ -83,23 +78,19 @@ public class MapmmSPInstruction extends BinarySPInstruction
                String parts[] = InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
 
-               if ( opcode.equalsIgnoreCase(MapMult.OPCODE)) 
-               {
-                       CPOperand in1 = new CPOperand(parts[1]);
-                       CPOperand in2 = new CPOperand(parts[2]);
-                       CPOperand out = new CPOperand(parts[3]);
-                       CacheType type = CacheType.valueOf(parts[4]);
-                       boolean outputEmpty = Boolean.parseBoolean(parts[5]);
-                       SparkAggType aggtype = SparkAggType.valueOf(parts[6]);
-                       
-                       AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
-                       AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
-                       return new MapmmSPInstruction(aggbin, in1, in2, out, type, outputEmpty, aggtype, opcode, str);
-               } 
-               else {
+               if(!opcode.equalsIgnoreCase(MapMult.OPCODE)) 
                        throw new DMLRuntimeException("MapmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
-               }
+                       
+               CPOperand in1 = new CPOperand(parts[1]);
+               CPOperand in2 = new CPOperand(parts[2]);
+               CPOperand out = new CPOperand(parts[3]);
+               CacheType type = CacheType.valueOf(parts[4]);
+               boolean outputEmpty = Boolean.parseBoolean(parts[5]);
+               SparkAggType aggtype = SparkAggType.valueOf(parts[6]);
                
+               AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
+               AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+               return new MapmmSPInstruction(aggbin, in1, in2, out, type, outputEmpty, aggtype, opcode, str);
        }
        
        @Override
@@ -108,14 +99,37 @@ public class MapmmSPInstruction extends BinarySPInstruction
        {       
                SparkExecutionContext sec = (SparkExecutionContext)ec;
                
-               String rddVar = (_type==CacheType.LEFT) ? input2.getName() : input1.getName();
-               String bcastVar = (_type==CacheType.LEFT) ? input1.getName() : input2.getName();
+               CacheType type = _type;
+               String rddVar = type.isRight() ? input1.getName() : input2.getName();
+               String bcastVar = type.isRight() ? input2.getName() : input1.getName();
                MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(rddVar);
                MatrixCharacteristics mcBc = sec.getMatrixCharacteristics(bcastVar);
                
+               //get input rdd
+               JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
+               
+               //investigate if a repartitioning - including a potential flip of broadcast and rdd 
+               //inputs - is required to ensure moderately sized output partitions (2GB limitation)
+               if( requiresFlatMapFunction(type, mcBc) &&
+                       requiresRepartitioning(type, mcRdd, mcBc, in1.getNumPartitions()) ) 
+               {
+                       int numParts = getNumRepartitioning(type, mcRdd, mcBc);
+                       int numParts2 = getNumRepartitioning(type.getFlipped(), mcBc, mcRdd);
+                       if( numParts2 > numParts ) { //flip required
+                               type = type.getFlipped();
+                               rddVar = type.isRight() ? input1.getName() : input2.getName();
+                               bcastVar = type.isRight() ? input2.getName() : input1.getName();
+                               mcRdd = sec.getMatrixCharacteristics(rddVar);
+                               mcBc = sec.getMatrixCharacteristics(bcastVar);
+                               in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
+                               LOG.warn("Mapmm: Switching rdd ('"+bcastVar+"') and broadcast ('"+rddVar+"') inputs "
+                                               + "for repartitioning because this allows better control of output partition "
+                                               + "sizes ("+numParts+" < "+numParts2+").");
+                       }
+               }
+               
                //get inputs
-               JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( rddVar );
-               PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable( bcastVar ); 
+               PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(bcastVar); 
                
                //empty input block filter
                if( !_outputEmpty )
@@ -124,7 +138,7 @@ public class MapmmSPInstruction extends BinarySPInstruction
                //execute mapmm and aggregation if necessary and put output into symbol table
                if( _aggtype == SparkAggType.SINGLE_BLOCK )
                {
-                       JavaRDD<MatrixBlock> out = in1.map(new RDDMapMMFunction2(_type, in2));
+                       JavaRDD<MatrixBlock> out = in1.map(new RDDMapMMFunction2(type, in2));
                        MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
                        
                        //put output block into symbol table (no lineage because single block)
@@ -134,15 +148,19 @@ public class MapmmSPInstruction extends BinarySPInstruction
                else //MULTI_BLOCK or NONE
                {
                        JavaPairRDD<MatrixIndexes,MatrixBlock> out = null;
-                       if( requiresFlatMapFunction(_type, mcBc) ) {
-                               if( requiresRepartitioning(_type, mcRdd, mcBc, in1.getNumPartitions()) )
-                                       in1 = in1.repartition(getNumRepartitioning(_type, mcRdd, mcBc, in1.getNumPartitions()));
-                               out = in1.flatMapToPair( new RDDFlatMapMMFunction(_type, in2) );
+                       if( requiresFlatMapFunction(type, mcBc) ) {
+                               if( requiresRepartitioning(type, mcRdd, mcBc, in1.getNumPartitions()) ) {
+                                       int numParts = getNumRepartitioning(type, mcRdd, mcBc);
+                                       LOG.warn("Mapmm: Repartition input rdd '"+rddVar+"' from "+in1.getNumPartitions()+" to "
+                                                       +numParts+" partitions to satisfy size restrictions of output partitions.");
+                                       in1 = in1.repartition(numParts);
+                               }
+                               out = in1.flatMapToPair( new RDDFlatMapMMFunction(type, in2) );
                        }
-                       else if( preservesPartitioning(mcRdd, _type) )
-                               out = in1.mapPartitionsToPair(new RDDMapMMPartitionFunction(_type, in2), true);
+                       else if( preservesPartitioning(mcRdd, type) )
+                               out = in1.mapPartitionsToPair(new RDDMapMMPartitionFunction(type, in2), true);
                        else
-                               out = in1.mapToPair( new RDDMapMMFunction(_type, in2) );
+                               out = in1.mapToPair( new RDDMapMMFunction(type, in2) );
                        
                        //empty output block filter
                        if( !_outputEmpty )
@@ -216,10 +234,9 @@ public class MapmmSPInstruction extends BinarySPInstruction
         * @param type cache type
         * @param mcRdd rdd matrix characteristics
         * @param mcBc ?
-        * @param numPartitions number of partitions
         * @return number of target partitions for repartitioning
         */
-       private static int getNumRepartitioning( CacheType type, MatrixCharacteristics mcRdd, MatrixCharacteristics mcBc, int numPartitions ) {
+       private static int getNumRepartitioning( CacheType type, MatrixCharacteristics mcRdd, MatrixCharacteristics mcBc ) {
                boolean isLeft = (type == CacheType.LEFT);
                long sizeOutput = (OptimizerUtils.estimatePartitionedSizeExactSparsity(isLeft?mcBc.getRows():mcRdd.getRows(),
                                isLeft?mcRdd.getCols():mcBc.getCols(), isLeft?mcBc.getRowsPerBlock():mcRdd.getRowsPerBlock(),
@@ -232,9 +249,9 @@ public class MapmmSPInstruction extends BinarySPInstruction
        {
                private static final long serialVersionUID = 8197406787010296291L;
 
-               private CacheType _type = null;
-               private AggregateBinaryOperator _op = null;
-               private PartitionedBroadcast<MatrixBlock> _pbc = null;
+               private final CacheType _type;
+               private final AggregateBinaryOperator _op;
+               private final PartitionedBroadcast<MatrixBlock> _pbc;
                
                public RDDMapMMFunction( CacheType type, PartitionedBroadcast<MatrixBlock> binput )
                {
@@ -287,9 +304,9 @@ public class MapmmSPInstruction extends BinarySPInstruction
        {
                private static final long serialVersionUID = -2753453898072910182L;
                
-               private CacheType _type = null;
-               private AggregateBinaryOperator _op = null;
-               private PartitionedBroadcast<MatrixBlock> _pbc = null;
+               private final CacheType _type;
+               private final AggregateBinaryOperator _op;
+               private final PartitionedBroadcast<MatrixBlock> _pbc;
                
                public RDDMapMMFunction2( CacheType type, PartitionedBroadcast<MatrixBlock> binput )
                {
@@ -333,9 +350,9 @@ public class MapmmSPInstruction extends BinarySPInstruction
        {
                private static final long serialVersionUID = 1886318890063064287L;
        
-               private CacheType _type = null;
-               private AggregateBinaryOperator _op = null;
-               private PartitionedBroadcast<MatrixBlock> _pbc = null;
+               private final CacheType _type;
+               private final AggregateBinaryOperator _op;
+               private final PartitionedBroadcast<MatrixBlock> _pbc;
                
                public RDDMapMMPartitionFunction( CacheType type, PartitionedBroadcast<MatrixBlock> binput )
                {
@@ -399,9 +416,9 @@ public class MapmmSPInstruction extends BinarySPInstruction
        {
                private static final long serialVersionUID = -6076256569118957281L;
                
-               private CacheType _type = null;
-               private AggregateBinaryOperator _op = null;
-               private PartitionedBroadcast<MatrixBlock> _pbc = null;
+               private final CacheType _type;
+               private final AggregateBinaryOperator _op;
+               private final PartitionedBroadcast<MatrixBlock> _pbc;
                
                public RDDFlatMapMMFunction( CacheType type, PartitionedBroadcast<MatrixBlock> binput )
                {
