Repository: systemml Updated Branches: refs/heads/master 340c99d26 -> ee5f61307
[SYSTEMML-2174] Performance spark cpmm over ultra-sparse inputs This patch improves the performance of spark cross-product matrix multiply (cpmm) operators with single-block outputs to prune empty blocks before and after the join. Furthermore, we now also avoid unnecessary deep copies of input blocks. On an end-to-end scenario with large ultra-sparse matrices and repeated cpmm executions this reduced the time spent in cpmm from 290s to 135s. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ee5f6130 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ee5f6130 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ee5f6130 Branch: refs/heads/master Commit: ee5f6130759f307e931a9336b6d74c8d1e153aae Parents: 340c99d Author: Matthias Boehm <mboe...@gmail.com> Authored: Wed Mar 7 18:40:37 2018 -0800 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Wed Mar 7 18:40:37 2018 -0800 ---------------------------------------------------------------------- .../instructions/spark/CpmmSPInstruction.java | 93 +++++++++----------- 1 file changed, 44 insertions(+), 49 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/ee5f6130/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java index 1d74edc..c51a6ee 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java @@ -32,6 +32,7 @@ import org.apache.sysml.runtime.functionobjects.Multiply; import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; @@ -46,7 +47,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator; * * NOTE: There is additional optimization potential by preventing aggregation for a single * block on the common dimension. However, in such a case we would never pick cpmm because - * this would result in a degree of parallelism of 1. + * this would result in a degree of parallelism of 1. * */ public class CpmmSPInstruction extends BinarySPInstruction { @@ -57,106 +58,100 @@ public class CpmmSPInstruction extends BinarySPInstruction { _aggtype = aggtype; } - public static CpmmSPInstruction parseInstruction( String str ) + public static CpmmSPInstruction parseInstruction( String str ) throws DMLRuntimeException { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = parts[0]; - - if ( opcode.equalsIgnoreCase("cpmm")) { - CPOperand in1 = new CPOperand(parts[1]); - CPOperand in2 = new CPOperand(parts[2]); - CPOperand out = new CPOperand(parts[3]); - AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); - AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); - SparkAggType aggtype = SparkAggType.valueOf(parts[4]); - - return new CpmmSPInstruction(aggbin, in1, in2, out, aggtype, opcode, str); - } - else { - throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode); - } + if ( !opcode.equalsIgnoreCase("cpmm")) + throw new DMLRuntimeException("CpmmSPInstruction.parseInstruction(): Unknown opcode " + opcode); + + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand out = new CPOperand(parts[3]); + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); + SparkAggType aggtype = SparkAggType.valueOf(parts[4]); + return new CpmmSPInstruction(aggbin, in1, in2, out, aggtype, opcode, str); } @Override - public void processInstruction(ExecutionContext ec) + public void processInstruction(ExecutionContext ec) throws DMLRuntimeException - { + { SparkExecutionContext sec = (SparkExecutionContext)ec; //get rdd inputs JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( input1.getName() ); JavaPairRDD<MatrixIndexes,MatrixBlock> in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() ); + if( _aggtype == SparkAggType.SINGLE_BLOCK ) { + //prune empty blocks of ultra-sparse matrices + in1 = in1.filter(new FilterNonEmptyBlocksFunction()); + in2 = in2.filter(new FilterNonEmptyBlocksFunction()); + } //process core cpmm matrix multiply JavaPairRDD<Long, IndexedMatrixValue> tmp1 = in1.mapToPair(new CpmmIndexFunction(true)); JavaPairRDD<Long, IndexedMatrixValue> tmp2 = in2.mapToPair(new CpmmIndexFunction(false)); JavaPairRDD<MatrixIndexes,MatrixBlock> out = tmp1 - .join(tmp2) // join over common dimension - .mapToPair(new CpmmMultiplyFunction()); // compute block multiplications - - //process cpmm aggregation and handle outputs - if( _aggtype == SparkAggType.SINGLE_BLOCK ) - { + .join(tmp2) // join over common dimension + .mapToPair(new CpmmMultiplyFunction()); // compute block multiplications + + //process cpmm aggregation and handle outputs + if( _aggtype == SparkAggType.SINGLE_BLOCK ) { + //prune empty blocks and aggregate all results + out = out.filter(new FilterNonEmptyBlocksFunction()); MatrixBlock out2 = RDDAggregateUtils.sumStable(out); //put output block into symbol table (no lineage because single block) //this also includes implicit maintenance of matrix characteristics - sec.setMatrixOutput(output.getName(), out2, getExtendedOpcode()); + sec.setMatrixOutput(output.getName(), out2, getExtendedOpcode()); } - else //DEFAULT: MULTI_BLOCK - { - out = RDDAggregateUtils.sumByKeyStable(out, false); + else { //DEFAULT: MULTI_BLOCK + out = RDDAggregateUtils.sumByKeyStable(out, false); //put output RDD handle into symbol table sec.setRDDHandleForVariable(output.getName(), out); sec.addLineageRDD(output.getName(), input1.getName()); - sec.addLineageRDD(output.getName(), input2.getName()); + sec.addLineageRDD(output.getName(), input2.getName()); //update output statistics if not inferred updateBinaryMMOutputMatrixCharacteristics(sec, true); } } - private static class CpmmIndexFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, IndexedMatrixValue> + private static class CpmmIndexFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, IndexedMatrixValue> { private static final long serialVersionUID = -1187183128301671162L; - - private boolean _left = false; + private final boolean _left; public CpmmIndexFunction( boolean left ) { _left = left; } @Override - public Tuple2<Long, IndexedMatrixValue> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) - throws Exception - { - IndexedMatrixValue value = new IndexedMatrixValue(); - value.set(arg0._1(), new MatrixBlock(arg0._2())); - + public Tuple2<Long, IndexedMatrixValue> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception { + IndexedMatrixValue value = new IndexedMatrixValue(arg0._1(), arg0._2()); Long key = _left ? arg0._1.getColumnIndex() : arg0._1.getRowIndex(); return new Tuple2<>(key, value); - } + } } - private static class CpmmMultiplyFunction implements PairFunction<Tuple2<Long, Tuple2<IndexedMatrixValue,IndexedMatrixValue>>, MatrixIndexes, MatrixBlock> + private static class CpmmMultiplyFunction implements PairFunction<Tuple2<Long, Tuple2<IndexedMatrixValue,IndexedMatrixValue>>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = -2009255629093036642L; - private AggregateBinaryOperator _op = null; - - public CpmmMultiplyFunction() - { - AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); - _op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); - } @Override public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<Long, Tuple2<IndexedMatrixValue, IndexedMatrixValue>> arg0) - throws Exception + throws Exception { + if( _op == null ) { //lazy operator construction + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + _op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); + } + MatrixBlock blkIn1 = (MatrixBlock)arg0._2()._1().getValue(); MatrixBlock blkIn2 = (MatrixBlock)arg0._2()._2().getValue(); MatrixIndexes ixOut = new MatrixIndexes(); @@ -166,7 +161,7 @@ public class CpmmSPInstruction extends BinarySPInstruction { blkIn1.aggregateBinaryOperations(blkIn1, blkIn2, blkOut, _op); //return target block - ixOut.setIndexes(arg0._2()._1().getIndexes().getRowIndex(), + ixOut.setIndexes(arg0._2()._1().getIndexes().getRowIndex(), arg0._2()._2().getIndexes().getColumnIndex()); return new Tuple2<>( ixOut, blkOut ); }