Repository: systemml
Updated Branches:
  refs/heads/master f3e3bdd78 -> 642a00638


[MINOR] Cleanup spark codegen instructions (number of partitions)

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6c4cc170
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6c4cc170
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6c4cc170

Branch: refs/heads/master
Commit: 6c4cc17006864bfb7dc3d4d4df8325bbee314735
Parents: f3e3bdd
Author: Matthias Boehm <[email protected]>
Authored: Mon Jan 1 19:28:49 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Mon Jan 1 19:28:49 2018 -0800

----------------------------------------------------------------------
 .../instructions/spark/SpoofSPInstruction.java  | 23 ++++++--------------
 .../spark/utils/RDDAggregateUtils.java          | 12 +++++-----
 2 files changed, 13 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/6c4cc170/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
index f5b1576..a1b368d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
@@ -144,13 +144,10 @@ public class SpoofSPInstruction extends SPInstruction {
                                
                                if( (op.getCellType()==CellType.ROW_AGG && 
mcIn.getCols() > mcIn.getColsPerBlock())
                                        || (op.getCellType()==CellType.COL_AGG 
&& mcIn.getRows() > mcIn.getRowsPerBlock())) {
-                                       //TODO investigate if some other side 
effect of correct blocks
                                        long numBlocks = 
(op.getCellType()==CellType.ROW_AGG ) ? 
                                                mcIn.getNumRowBlocks() : 
mcIn.getNumColBlocks();
-                                       if( out.partitions().size() > numBlocks 
)
-                                               out = 
RDDAggregateUtils.aggByKeyStable(out, aggop, (int)numBlocks, false);
-                                       else
-                                               out = 
RDDAggregateUtils.aggByKeyStable(out, aggop, false);
+                                       out = 
RDDAggregateUtils.aggByKeyStable(out, aggop,
+                                               
(int)Math.min(out.getNumPartitions(), numBlocks), false);
                                }
                                sec.setRDDHandleForVariable(_out.getName(), 
out);
                                
@@ -189,11 +186,9 @@ public class SpoofSPInstruction extends SPInstruction {
                                out = in.mapPartitionsToPair(new 
OuterProductFunction(
                                        _class.getName(), _classBytes, bcVect2, 
bcMatrices, scalars), true);
                                if(type == OutProdType.LEFT_OUTER_PRODUCT || 
type == OutProdType.RIGHT_OUTER_PRODUCT ) {
-                                       //TODO investigate if some other side 
effect of correct blocks
-                                       if( in.partitions().size() > 
mcOut.getNumRowBlocks()*mcOut.getNumColBlocks() )
-                                               out = 
RDDAggregateUtils.sumByKeyStable(out, 
(int)(mcOut.getNumRowBlocks()*mcOut.getNumColBlocks()), false);
-                                       else
-                                               out = 
RDDAggregateUtils.sumByKeyStable(out, false);     
+                                       long numBlocks = 
mcOut.getNumRowBlocks() * mcOut.getNumColBlocks();
+                                       out = 
RDDAggregateUtils.sumByKeyStable(out,
+                                               
(int)Math.min(out.getNumPartitions(), numBlocks), false);
                                }
                                sec.setRDDHandleForVariable(_out.getName(), 
out);
                                
@@ -231,13 +226,9 @@ public class SpoofSPInstruction extends SPInstruction {
                        else //row-agg or no-agg 
                        {
                                if( op.getRowType()==RowType.ROW_AGG && 
mcIn.getCols() > mcIn.getColsPerBlock() ) {
-                                       //TODO investigate if some other side 
effect of correct blocks
-                                       if( out.partitions().size() > 
mcIn.getNumRowBlocks() )
-                                               out = 
RDDAggregateUtils.sumByKeyStable(out, (int)mcIn.getNumRowBlocks(), false);
-                                       else
-                                               out = 
RDDAggregateUtils.sumByKeyStable(out, false);
+                                       out = 
RDDAggregateUtils.sumByKeyStable(out,
+                                               
(int)Math.min(out.getNumPartitions(), mcIn.getNumRowBlocks()), false);
                                }
-                               
                                sec.setRDDHandleForVariable(_out.getName(), 
out);
                                
                                //maintain lineage info and output 
characteristics

http://git-wip-us.apache.org/repos/asf/systemml/blob/6c4cc170/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
index 5ed5df8..70174a9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
@@ -96,14 +96,14 @@ public class RDDAggregateUtils
 
        public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable( 
JavaPairRDD<MatrixIndexes, Double> in )
        {
-               //stable sum of blocks per key, by passing correction blocks 
along with aggregates              
+               //stable sum of blocks per key, by passing correction blocks 
along with aggregates
                JavaPairRDD<MatrixIndexes, KahanObject> tmp = 
                                in.combineByKey( new 
CreateCellCombinerFunction(), 
                                                             new 
MergeSumCellValueFunction(), 
                                                             new 
MergeSumCellCombinerFunction() );
                
-               //strip-off correction blocks from                              
             
-               JavaPairRDD<MatrixIndexes, Double> out =  
+               //strip-off correction blocks from
+               JavaPairRDD<MatrixIndexes, Double> out =
                                tmp.mapValues( new ExtractDoubleCell() );
                
                //return the aggregate rdd
@@ -152,13 +152,13 @@ public class RDDAggregateUtils
        public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable( 
JavaPairRDD<MatrixIndexes, MatrixBlock> in, 
                        AggregateOperator aop, int numPartitions, boolean 
deepCopyCombiner )
        {
-               //stable sum of blocks per key, by passing correction blocks 
along with aggregates              
+               //stable sum of blocks per key, by passing correction blocks 
along with aggregates
                JavaPairRDD<MatrixIndexes, CorrMatrixBlock> tmp = 
-                               in.combineByKey( new 
CreateCorrBlockCombinerFunction(deepCopyCombiner), 
+                               in.combineByKey( new 
CreateCorrBlockCombinerFunction(deepCopyCombiner),
                                                             new 
MergeAggBlockValueFunction(aop), 
                                                             new 
MergeAggBlockCombinerFunction(aop), numPartitions );
                
-               //strip-off correction blocks from                              
             
+               //strip-off correction blocks from
                JavaPairRDD<MatrixIndexes, MatrixBlock> out =  
                                tmp.mapValues( new ExtractMatrixBlock() );
                

Reply via email to