Repository: systemml
Updated Branches:
  refs/heads/master 340c99d26 -> ee5f61307


[SYSTEMML-2174] Performance spark cpmm over ultra-sparse inputs

This patch improves the performance of Spark cross-product matrix
multiply (cpmm) operators with single-block outputs by pruning empty
blocks before and after the join. Furthermore, we now also avoid
unnecessary deep copies of input blocks.

In an end-to-end scenario with large ultra-sparse matrices and repeated
cpmm executions, this reduced the time spent in cpmm from 290s to 135s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ee5f6130
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ee5f6130
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ee5f6130

Branch: refs/heads/master
Commit: ee5f6130759f307e931a9336b6d74c8d1e153aae
Parents: 340c99d
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Wed Mar 7 18:40:37 2018 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Wed Mar 7 18:40:37 2018 -0800

----------------------------------------------------------------------
 .../instructions/spark/CpmmSPInstruction.java   | 93 +++++++++-----------
 1 file changed, 44 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ee5f6130/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
index 1d74edc..c51a6ee 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
@@ -32,6 +32,7 @@ import org.apache.sysml.runtime.functionobjects.Multiply;
 import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import 
org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
@@ -46,7 +47,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator;
  * 
  * NOTE: There is additional optimization potential by preventing aggregation 
for a single
  * block on the common dimension. However, in such a case we would never pick 
cpmm because
- * this would result in a degree of parallelism of 1.  
+ * this would result in a degree of parallelism of 1.
  * 
  */
 public class CpmmSPInstruction extends BinarySPInstruction {
@@ -57,106 +58,100 @@ public class CpmmSPInstruction extends 
BinarySPInstruction {
                _aggtype = aggtype;
        }
 
-       public static CpmmSPInstruction parseInstruction( String str ) 
+       public static CpmmSPInstruction parseInstruction( String str )
                throws DMLRuntimeException 
        {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
-
-               if ( opcode.equalsIgnoreCase("cpmm")) {
-                       CPOperand in1 = new CPOperand(parts[1]);
-                       CPOperand in2 = new CPOperand(parts[2]);
-                       CPOperand out = new CPOperand(parts[3]);
-                       AggregateOperator agg = new AggregateOperator(0, 
Plus.getPlusFnObject());
-                       AggregateBinaryOperator aggbin = new 
AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
-                       SparkAggType aggtype = SparkAggType.valueOf(parts[4]);
-                       
-                       return new CpmmSPInstruction(aggbin, in1, in2, out, 
aggtype, opcode, str);
-               } 
-               else {
-                       throw new 
DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown 
opcode " + opcode);
-               }
                
+               if ( !opcode.equalsIgnoreCase("cpmm"))
+                       throw new 
DMLRuntimeException("CpmmSPInstruction.parseInstruction(): Unknown opcode " + 
opcode);
+               
+               CPOperand in1 = new CPOperand(parts[1]);
+               CPOperand in2 = new CPOperand(parts[2]);
+               CPOperand out = new CPOperand(parts[3]);
+               AggregateOperator agg = new AggregateOperator(0, 
Plus.getPlusFnObject());
+               AggregateBinaryOperator aggbin = new 
AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+               SparkAggType aggtype = SparkAggType.valueOf(parts[4]);
+               return new CpmmSPInstruction(aggbin, in1, in2, out, aggtype, 
opcode, str);
        }
        
        @Override
-       public void processInstruction(ExecutionContext ec) 
+       public void processInstruction(ExecutionContext ec)
                throws DMLRuntimeException
-       {       
+       {
                SparkExecutionContext sec = (SparkExecutionContext)ec;
                
                //get rdd inputs
                JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = 
sec.getBinaryBlockRDDHandleForVariable( input1.getName() );
                JavaPairRDD<MatrixIndexes,MatrixBlock> in2 = 
sec.getBinaryBlockRDDHandleForVariable( input2.getName() );
+               if( _aggtype == SparkAggType.SINGLE_BLOCK ) {
+                       //prune empty blocks of ultra-sparse matrices
+                       in1 = in1.filter(new FilterNonEmptyBlocksFunction());
+                       in2 = in2.filter(new FilterNonEmptyBlocksFunction());
+               }
                
                //process core cpmm matrix multiply 
                JavaPairRDD<Long, IndexedMatrixValue> tmp1 = in1.mapToPair(new 
CpmmIndexFunction(true));
                JavaPairRDD<Long, IndexedMatrixValue> tmp2 = in2.mapToPair(new 
CpmmIndexFunction(false));
                JavaPairRDD<MatrixIndexes,MatrixBlock> out = tmp1
-                                  .join(tmp2)                              // 
join over common dimension
-                                  .mapToPair(new CpmmMultiplyFunction());  // 
compute block multiplications
-                                  
-               //process cpmm aggregation and handle outputs                   
        
-               if( _aggtype == SparkAggType.SINGLE_BLOCK )
-               {
+                       .join(tmp2)                              // join over 
common dimension
+                       .mapToPair(new CpmmMultiplyFunction());  // compute 
block multiplications
+               
+               //process cpmm aggregation and handle outputs
+               if( _aggtype == SparkAggType.SINGLE_BLOCK ) {
+                       //prune empty blocks and aggregate all results
+                       out = out.filter(new FilterNonEmptyBlocksFunction());
                        MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
                        
                        //put output block into symbol table (no lineage 
because single block)
                        //this also includes implicit maintenance of matrix 
characteristics
-                       sec.setMatrixOutput(output.getName(), out2, 
getExtendedOpcode());       
+                       sec.setMatrixOutput(output.getName(), out2, 
getExtendedOpcode());
                }
-               else //DEFAULT: MULTI_BLOCK
-               {
-                       out = RDDAggregateUtils.sumByKeyStable(out, false); 
+               else { //DEFAULT: MULTI_BLOCK
+                       out = RDDAggregateUtils.sumByKeyStable(out, false);
                        
                        //put output RDD handle into symbol table
                        sec.setRDDHandleForVariable(output.getName(), out);
                        sec.addLineageRDD(output.getName(), input1.getName());
-                       sec.addLineageRDD(output.getName(), input2.getName());  
        
+                       sec.addLineageRDD(output.getName(), input2.getName());
                        
                        //update output statistics if not inferred
                        updateBinaryMMOutputMatrixCharacteristics(sec, true);
                }
        }
 
-       private static class CpmmIndexFunction implements 
PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, IndexedMatrixValue>  
+       private static class CpmmIndexFunction implements 
PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, IndexedMatrixValue>
        {
                private static final long serialVersionUID = 
-1187183128301671162L;
-
-               private boolean _left = false;
+               private final boolean _left;
                
                public CpmmIndexFunction( boolean left ) {
                        _left = left;
                }
                
                @Override
-               public Tuple2<Long, IndexedMatrixValue> 
call(Tuple2<MatrixIndexes, MatrixBlock> arg0) 
-                       throws Exception 
-               {
-                       IndexedMatrixValue value = new IndexedMatrixValue();
-                       value.set(arg0._1(), new MatrixBlock(arg0._2()));
-                       
+               public Tuple2<Long, IndexedMatrixValue> 
call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
+                       IndexedMatrixValue value = new 
IndexedMatrixValue(arg0._1(), arg0._2());
                        Long key = _left ? arg0._1.getColumnIndex() : 
arg0._1.getRowIndex();
                        return new Tuple2<>(key, value);
-               }       
+               }
        }
 
-       private static class CpmmMultiplyFunction implements 
PairFunction<Tuple2<Long, Tuple2<IndexedMatrixValue,IndexedMatrixValue>>, 
MatrixIndexes, MatrixBlock> 
+       private static class CpmmMultiplyFunction implements 
PairFunction<Tuple2<Long, Tuple2<IndexedMatrixValue,IndexedMatrixValue>>, 
MatrixIndexes, MatrixBlock>
        {
                private static final long serialVersionUID = 
-2009255629093036642L;
-               
                private AggregateBinaryOperator _op = null;
-               
-               public CpmmMultiplyFunction()
-               {
-                       AggregateOperator agg = new AggregateOperator(0, 
Plus.getPlusFnObject());
-                       _op = new 
AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
-               }
 
                @Override
                public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<Long, 
Tuple2<IndexedMatrixValue, IndexedMatrixValue>> arg0)
-                       throws Exception 
+                       throws Exception
                {
+                       if( _op == null ) { //lazy operator construction
+                               AggregateOperator agg = new 
AggregateOperator(0, Plus.getPlusFnObject());
+                               _op = new 
AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+                       }
+                       
                        MatrixBlock blkIn1 = 
(MatrixBlock)arg0._2()._1().getValue();
                        MatrixBlock blkIn2 = 
(MatrixBlock)arg0._2()._2().getValue();
                        MatrixIndexes ixOut = new MatrixIndexes();
@@ -166,7 +161,7 @@ public class CpmmSPInstruction extends BinarySPInstruction {
                        blkIn1.aggregateBinaryOperations(blkIn1, blkIn2, 
blkOut, _op);
                        
                        //return target block
-                       
ixOut.setIndexes(arg0._2()._1().getIndexes().getRowIndex(), 
+                       
ixOut.setIndexes(arg0._2()._1().getIndexes().getRowIndex(),
                                arg0._2()._2().getIndexes().getColumnIndex());
                        return new Tuple2<>( ixOut, blkOut );
                }

Reply via email to