Repository: systemml
Updated Branches:
  refs/heads/master 1cc72d797 -> 686e3831d
[SYSTEMML-2279] Performance spark unary aggregates (empty block filter) This patch improves the performance of sparse-safe spark unary aggregate operations such as sum(X) by filtering empty blocks before the actual unary aggregate operations. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/18cc576d Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/18cc576d Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/18cc576d Branch: refs/heads/master Commit: 18cc576dcad813a059322d9f0bb83208ed0bb646 Parents: 1cc72d7 Author: Matthias Boehm <[email protected]> Authored: Thu Apr 26 20:39:24 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Apr 27 00:03:12 2018 -0700 ---------------------------------------------------------------------- .../instructions/spark/AggregateUnarySPInstruction.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/18cc576d/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java index 860384f..266db7b 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java @@ -34,6 +34,7 @@ import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction; import 
org.apache.sysml.runtime.instructions.spark.functions.FilterDiagBlocksFunction; +import org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; @@ -91,6 +92,9 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction { //perform aggregation if necessary and put output into symbol table if( _aggtype == SparkAggType.SINGLE_BLOCK ) { + if( auop.sparseSafe ) + out = out.filter(new FilterNonEmptyBlocksFunction()); + JavaRDD<MatrixBlock> out2 = out.map( new RDDUAggFunction2(auop, mc.getRowsPerBlock(), mc.getColsPerBlock())); MatrixBlock out3 = RDDAggregateUtils.aggStable(out2, aggop); @@ -111,7 +115,7 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction { } else if( _aggtype == SparkAggType.MULTI_BLOCK ) { //in case of multi-block aggregation, we always keep the correction - out = out.mapToPair(new RDDUAggFunction(auop, mc.getRowsPerBlock(), mc.getColsPerBlock())); + out = out.mapToPair(new RDDUAggFunction(auop, mc.getRowsPerBlock(), mc.getColsPerBlock())); out = RDDAggregateUtils.aggByKeyStable(out, aggop, false); //drop correction after aggregation if required (aggbykey creates @@ -124,7 +128,7 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction { updateUnaryAggOutputMatrixCharacteristics(sec, auop.indexFn); sec.setRDDHandleForVariable(output.getName(), out); sec.addLineageRDD(output.getName(), input1.getName()); - } + } } private static class RDDUAggFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> @@ -164,7 +168,7 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction { /** * Similar to RDDUAggFunction but single output block. 
*/ - private static class RDDUAggFunction2 implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock> + public static class RDDUAggFunction2 implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock> { private static final long serialVersionUID = 2672082409287856038L;
