Repository: systemml Updated Branches: refs/heads/master f3e3bdd78 -> 642a00638
[MINOR] Cleanup spark codegen instructions (number of partitions) Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6c4cc170 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6c4cc170 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6c4cc170 Branch: refs/heads/master Commit: 6c4cc17006864bfb7dc3d4d4df8325bbee314735 Parents: f3e3bdd Author: Matthias Boehm <[email protected]> Authored: Mon Jan 1 19:28:49 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Mon Jan 1 19:28:49 2018 -0800 ---------------------------------------------------------------------- .../instructions/spark/SpoofSPInstruction.java | 23 ++++++-------------- .../spark/utils/RDDAggregateUtils.java | 12 +++++----- 2 files changed, 13 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/6c4cc170/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java index f5b1576..a1b368d 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java @@ -144,13 +144,10 @@ public class SpoofSPInstruction extends SPInstruction { if( (op.getCellType()==CellType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock()) || (op.getCellType()==CellType.COL_AGG && mcIn.getRows() > mcIn.getRowsPerBlock())) { - //TODO investigate if some other side effect of correct blocks long numBlocks = (op.getCellType()==CellType.ROW_AGG ) ? mcIn.getNumRowBlocks() : mcIn.getNumColBlocks(); - if( out.partitions().size() > numBlocks ) - out = RDDAggregateUtils.aggByKeyStable(out, aggop, (int)numBlocks, false); - else - out = RDDAggregateUtils.aggByKeyStable(out, aggop, false); + out = RDDAggregateUtils.aggByKeyStable(out, aggop, + (int)Math.min(out.getNumPartitions(), numBlocks), false); } sec.setRDDHandleForVariable(_out.getName(), out); @@ -189,11 +186,9 @@ public class SpoofSPInstruction extends SPInstruction { out = in.mapPartitionsToPair(new OuterProductFunction( _class.getName(), _classBytes, bcVect2, bcMatrices, scalars), true); if(type == OutProdType.LEFT_OUTER_PRODUCT || type == OutProdType.RIGHT_OUTER_PRODUCT ) { - //TODO investigate if some other side effect of correct blocks - if( in.partitions().size() > mcOut.getNumRowBlocks()*mcOut.getNumColBlocks() ) - out = RDDAggregateUtils.sumByKeyStable(out, (int)(mcOut.getNumRowBlocks()*mcOut.getNumColBlocks()), false); - else - out = RDDAggregateUtils.sumByKeyStable(out, false); + long numBlocks = mcOut.getNumRowBlocks() * mcOut.getNumColBlocks(); + out = RDDAggregateUtils.sumByKeyStable(out, + (int)Math.min(out.getNumPartitions(), numBlocks), false); } sec.setRDDHandleForVariable(_out.getName(), out); @@ -231,13 +226,9 @@ public class SpoofSPInstruction extends SPInstruction { else //row-agg or no-agg { if( op.getRowType()==RowType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock() ) { - //TODO investigate if some other side effect of correct blocks - if( out.partitions().size() > mcIn.getNumRowBlocks() ) - out = RDDAggregateUtils.sumByKeyStable(out, (int)mcIn.getNumRowBlocks(), false); - else - out = RDDAggregateUtils.sumByKeyStable(out, false); + out = RDDAggregateUtils.sumByKeyStable(out, + (int)Math.min(out.getNumPartitions(), mcIn.getNumRowBlocks()), false); } - sec.setRDDHandleForVariable(_out.getName(), out); //maintain lineage info and output characteristics http://git-wip-us.apache.org/repos/asf/systemml/blob/6c4cc170/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java index 5ed5df8..70174a9 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java @@ -96,14 +96,14 @@ public class RDDAggregateUtils public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable( JavaPairRDD<MatrixIndexes, Double> in ) { - //stable sum of blocks per key, by passing correction blocks along with aggregates + //stable sum of blocks per key, by passing correction blocks along with aggregates JavaPairRDD<MatrixIndexes, KahanObject> tmp = in.combineByKey( new CreateCellCombinerFunction(), new MergeSumCellValueFunction(), new MergeSumCellCombinerFunction() ); - //strip-off correction blocks from - JavaPairRDD<MatrixIndexes, Double> out = + //strip-off correction blocks from + JavaPairRDD<MatrixIndexes, Double> out = tmp.mapValues( new ExtractDoubleCell() ); //return the aggregate rdd @@ -152,13 +152,13 @@ public class RDDAggregateUtils public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator aop, int numPartitions, boolean deepCopyCombiner ) { - //stable sum of blocks per key, by passing correction blocks along with aggregates + //stable sum of blocks per key, by passing correction blocks along with aggregates JavaPairRDD<MatrixIndexes, CorrMatrixBlock> tmp = - in.combineByKey( new CreateCorrBlockCombinerFunction(deepCopyCombiner), + in.combineByKey( new CreateCorrBlockCombinerFunction(deepCopyCombiner), new MergeAggBlockValueFunction(aop), new MergeAggBlockCombinerFunction(aop), numPartitions ); - //strip-off correction blocks from + //strip-off correction blocks from JavaPairRDD<MatrixIndexes, MatrixBlock> out = tmp.mapValues( new ExtractMatrixBlock() );
