[SYSTEMML-1274] Improved nnz maintenance on spark rdd write

We now consistently piggyback any nnz maintenance on write operations in
order to avoid unnecessary RDD computation. Furthermore, this change
also removes the utility primitive that computed the nnz in isolation,
in order to prevent such inefficiencies from being reintroduced.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/73afc2c1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/73afc2c1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/73afc2c1

Branch: refs/heads/master
Commit: 73afc2c19fe34caf08ec2c63bdbfb0b42aab881f
Parents: ee7591c
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu Feb 16 12:12:57 2017 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Thu Feb 16 12:13:48 2017 -0800

----------------------------------------------------------------------
 .../context/SparkExecutionContext.java          |  9 ++++---
 .../instructions/spark/WriteSPInstruction.java  | 13 ++++++----
 .../instructions/spark/utils/SparkUtils.java    | 25 --------------------
 3 files changed, 15 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/73afc2c1/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 66fab1e..77bcc8d 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -35,6 +35,7 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.storage.RDDInfo;
 import org.apache.spark.storage.StorageLevel;
+import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.MLContextProxy;
 import org.apache.sysml.conf.ConfigurationManager;
@@ -55,6 +56,7 @@ import 
org.apache.sysml.runtime.instructions.spark.data.LineageObject;
 import org.apache.sysml.runtime.instructions.spark.data.PartitionedBlock;
 import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
 import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
+import 
org.apache.sysml.runtime.instructions.spark.functions.ComputeBinaryBlockNnzFunction;
 import 
org.apache.sysml.runtime.instructions.spark.functions.CopyBinaryCellFunction;
 import 
org.apache.sysml.runtime.instructions.spark.functions.CopyFrameBlockPairFunction;
 import 
org.apache.sysml.runtime.instructions.spark.functions.CopyTextInputFunction;
@@ -966,8 +968,9 @@ public class SparkExecutionContext extends ExecutionContext
        {
                JavaPairRDD<MatrixIndexes,MatrixBlock> lrdd = 
(JavaPairRDD<MatrixIndexes, MatrixBlock>) rdd.getRDD();
                
-               //recompute nnz 
-               long nnz = SparkUtils.computeNNZFromBlocks(lrdd);
+               //piggyback nnz maintenance on write
+               LongAccumulator aNnz = 
getSparkContextStatic().sc().longAccumulator("nnz");
+               lrdd = lrdd.mapValues(new ComputeBinaryBlockNnzFunction(aNnz));
                
                //save file is an action which also triggers nnz maintenance
                lrdd.saveAsHadoopFile(path, 
@@ -976,7 +979,7 @@ public class SparkExecutionContext extends ExecutionContext
                                oinfo.outputFormatClass);
                
                //return nnz aggregate of all blocks
-               return nnz;
+               return aNnz.value();
        }
 
        @SuppressWarnings("unchecked")

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/73afc2c1/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
index 3387770..c30c85f 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
@@ -39,7 +39,6 @@ import 
org.apache.sysml.runtime.instructions.spark.functions.ComputeBinaryBlockN
 import 
org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
 import 
org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
-import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties;
 import org.apache.sysml.runtime.matrix.data.FileFormatProperties;
@@ -179,9 +178,12 @@ public class WriteSPInstruction extends SPInstruction
                if(    oi == OutputInfo.MatrixMarketOutputInfo
                        || oi == OutputInfo.TextCellOutputInfo     ) 
                {
-                       //recompute nnz if necessary (required for header if 
matrix market)
-                       if ( isInputMatrixBlock && !mc.nnzKnown() )
-                               mc.setNonZeros( 
SparkUtils.computeNNZFromBlocks(in1) );
+                       //piggyback nnz maintenance on write
+                       LongAccumulator aNnz = null;
+                       if ( isInputMatrixBlock && !mc.nnzKnown() ) {
+                               aNnz = 
sec.getSparkContext().sc().longAccumulator("nnz");
+                               in1 = in1.mapValues(new 
ComputeBinaryBlockNnzFunction(aNnz));
+                       }
                        
                        JavaRDD<String> header = null;                          
                        if( oi == OutputInfo.MatrixMarketOutputInfo  ) {
@@ -199,6 +201,9 @@ public class WriteSPInstruction extends SPInstruction
                                customSaveTextFile(header.union(ijv), fname, 
true);
                        else
                                customSaveTextFile(ijv, fname, false);
+                       
+                       if ( isInputMatrixBlock && !mc.nnzKnown() )
+                               mc.setNonZeros( aNnz.value() );
                }
                else if( oi == OutputInfo.CSVOutputInfo ) 
                {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/73afc2c1/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
index d53f3cf..d27e37a 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
@@ -249,29 +249,4 @@ public class SparkUtils
                                        arg0.getNonZeros() + arg1.getNonZeros() 
); //sum
                }       
        }
-
-       /**
-        * Utility to compute number of non-zeros from the given RDD of 
MatrixBlocks
-        * 
-        * @param rdd matrix as {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
-        * @return number of non-zeros
-        */
-       public static long computeNNZFromBlocks(JavaPairRDD<MatrixIndexes, 
MatrixBlock> rdd) {
-               long nnz = rdd.values().aggregate(      0L, 
-                                               new 
Function2<Long,MatrixBlock,Long>() {
-                                                       private static final 
long serialVersionUID = 4907645080949985267L;
-                                                       @Override
-                                                       public Long call(Long 
v1, MatrixBlock v2) throws Exception {
-                                                               return (v1 + 
v2.getNonZeros());
-                                                       } 
-                                               }, 
-                                               new Function2<Long,Long,Long>() 
{
-                                                       private static final 
long serialVersionUID = 333028431986883739L;
-                                                       @Override
-                                                       public Long call(Long 
v1, Long v2) throws Exception {
-                                                               return v1+v2;
-                                                       }
-                                               } );
-               return nnz;
-       }
 }

Reply via email to