[SYSTEMML-1274] Improved nnz maintenance on spark rdd write

We now consistently piggyback nnz maintenance on write operations in order
to avoid unnecessary RDD computation: instead of triggering a separate job
to recompute the number of non-zeros, the write itself counts them as a
side effect. Furthermore, this change removes the utils primitive for
computing the nnz in isolation, to prevent reintroducing such
inefficiencies.
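At its core, the change swaps a dedicated aggregate job over the blocks for a
pass-through mapValues that feeds each block's non-zero count into a
LongAccumulator, so the save action materializes the count for free. A minimal
sketch of the pattern follows; the class name NnzOnWriteSketch and the
saveAsObjectFile call are illustrative stand-ins (the actual code below uses
saveAsHadoopFile with SystemML's OutputInfo handles):

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;

public class NnzOnWriteSketch {
	/** Writes the rdd and returns its nnz, computed as a side effect of the write. */
	public static long writeAndCountNnz(JavaSparkContext jsc,
		JavaPairRDD<MatrixIndexes, MatrixBlock> rdd, String path)
	{
		//register a named accumulator on the driver
		LongAccumulator aNnz = jsc.sc().longAccumulator("nnz");

		//pass-through transformation: count non-zeros per block, return block unchanged
		JavaPairRDD<MatrixIndexes, MatrixBlock> counted =
			rdd.mapValues(blk -> { aNnz.add(blk.getNonZeros()); return blk; });

		//the save is an action, so it triggers both the write and the counting
		counted.saveAsObjectFile(path); //stand-in for saveAsHadoopFile w/ OutputInfo

		//the accumulator value is only reliable after the action has completed
		return aNnz.value();
	}
}

One caveat the approach inherits from Spark: accumulator updates made inside
transformations can be double-counted if tasks are retried or recomputed; for
a one-shot write of deterministic blocks this is generally acceptable, which
presumably informed the trade-off here.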
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/73afc2c1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/73afc2c1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/73afc2c1

Branch: refs/heads/master
Commit: 73afc2c19fe34caf08ec2c63bdbfb0b42aab881f
Parents: ee7591c
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu Feb 16 12:12:57 2017 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Thu Feb 16 12:13:48 2017 -0800

----------------------------------------------------------------------
 .../context/SparkExecutionContext.java          |  9 ++++---
 .../instructions/spark/WriteSPInstruction.java  | 13 ++++++----
 .../instructions/spark/utils/SparkUtils.java    | 25 --------------------
 3 files changed, 15 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/73afc2c1/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 66fab1e..77bcc8d 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -35,6 +35,7 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.storage.RDDInfo;
 import org.apache.spark.storage.StorageLevel;
+import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.MLContextProxy;
 import org.apache.sysml.conf.ConfigurationManager;
@@ -55,6 +56,7 @@ import org.apache.sysml.runtime.instructions.spark.data.LineageObject;
 import org.apache.sysml.runtime.instructions.spark.data.PartitionedBlock;
 import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
 import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
+import org.apache.sysml.runtime.instructions.spark.functions.ComputeBinaryBlockNnzFunction;
 import org.apache.sysml.runtime.instructions.spark.functions.CopyBinaryCellFunction;
 import org.apache.sysml.runtime.instructions.spark.functions.CopyFrameBlockPairFunction;
 import org.apache.sysml.runtime.instructions.spark.functions.CopyTextInputFunction;
@@ -966,8 +968,9 @@ public class SparkExecutionContext extends ExecutionContext
 	{
 		JavaPairRDD<MatrixIndexes,MatrixBlock> lrdd = (JavaPairRDD<MatrixIndexes, MatrixBlock>) rdd.getRDD();
 
-		//recompute nnz
-		long nnz = SparkUtils.computeNNZFromBlocks(lrdd);
+		//piggyback nnz maintenance on write
+		LongAccumulator aNnz = getSparkContextStatic().sc().longAccumulator("nnz");
+		lrdd = lrdd.mapValues(new ComputeBinaryBlockNnzFunction(aNnz));
 
 		//save file is an action which also triggers nnz maintenance
 		lrdd.saveAsHadoopFile(path,
@@ -976,7 +979,7 @@ public class SparkExecutionContext extends ExecutionContext
 				oinfo.outputFormatClass);
 
 		//return nnz aggregate of all blocks
-		return nnz;
+		return aNnz.value();
 	}
 
 	@SuppressWarnings("unchecked")
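For reference, the ComputeBinaryBlockNnzFunction applied in the mapValues
above is essentially that pass-through counter. A sketch of what it plausibly
looks like after this change; the real class lives in
org.apache.sysml.runtime.instructions.spark.functions and its details
(including the serialVersionUID) may differ:

import org.apache.spark.api.java.function.Function;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;

public class ComputeBinaryBlockNnzFunction implements Function<MatrixBlock,MatrixBlock>
{
	private static final long serialVersionUID = 1L; //illustrative value

	private final LongAccumulator _aNnz;

	public ComputeBinaryBlockNnzFunction( LongAccumulator aNnz ) {
		_aNnz = aNnz;
	}

	@Override
	public MatrixBlock call(MatrixBlock arg0) throws Exception {
		//piggybacked nnz maintenance: count as a side effect,
		//pass the block through unchanged so the write is unaffected
		_aNnz.add( arg0.getNonZeros() );
		return arg0;
	}
}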
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/73afc2c1/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
index 3387770..c30c85f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
@@ -39,7 +39,6 @@ import org.apache.sysml.runtime.instructions.spark.functions.ComputeBinaryBlockN
 import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
-import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties;
 import org.apache.sysml.runtime.matrix.data.FileFormatProperties;
@@ -179,9 +178,12 @@ public class WriteSPInstruction extends SPInstruction
 		if( oi == OutputInfo.MatrixMarketOutputInfo || oi == OutputInfo.TextCellOutputInfo )
 		{
-			//recompute nnz if necessary (required for header if matrix market)
-			if ( isInputMatrixBlock && !mc.nnzKnown() )
-				mc.setNonZeros( SparkUtils.computeNNZFromBlocks(in1) );
+			//piggyback nnz maintenance on write
+			LongAccumulator aNnz = null;
+			if ( isInputMatrixBlock && !mc.nnzKnown() ) {
+				aNnz = sec.getSparkContext().sc().longAccumulator("nnz");
+				in1 = in1.mapValues(new ComputeBinaryBlockNnzFunction(aNnz));
+			}
 
 			JavaRDD<String> header = null;
 			if( oi == OutputInfo.MatrixMarketOutputInfo ) {
@@ -199,6 +201,9 @@ public class WriteSPInstruction extends SPInstruction
 				customSaveTextFile(header.union(ijv), fname, true);
 			else
 				customSaveTextFile(ijv, fname, false);
+
+			if ( isInputMatrixBlock && !mc.nnzKnown() )
+				mc.setNonZeros( aNnz.value() );
 		}
 		else if( oi == OutputInfo.CSVOutputInfo ) {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/73afc2c1/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
index d53f3cf..d27e37a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
@@ -249,29 +249,4 @@ public class SparkUtils
 				arg0.getNonZeros() + arg1.getNonZeros() ); //sum
 		}
 	}
-
-	/**
-	 * Utility to compute number of non-zeros from the given RDD of MatrixBlocks
-	 *
-	 * @param rdd matrix as {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
-	 * @return number of non-zeros
-	 */
-	public static long computeNNZFromBlocks(JavaPairRDD<MatrixIndexes, MatrixBlock> rdd) {
-		long nnz = rdd.values().aggregate( 0L,
-				new Function2<Long,MatrixBlock,Long>() {
-					private static final long serialVersionUID = 4907645080949985267L;
-					@Override
-					public Long call(Long v1, MatrixBlock v2) throws Exception {
-						return (v1 + v2.getNonZeros());
-					}
-				},
-				new Function2<Long,Long,Long>() {
-					private static final long serialVersionUID = 333028431986883739L;
-					@Override
-					public Long call(Long v1, Long v2) throws Exception {
-						return v1+v2;
-					}
-				} );
-		return nnz;
-	}
 }
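With the utility removed, a caller that genuinely needed the nnz without a
write could still express the deleted method as a short aggregation over
standard RDD operations. A hedged Java 8 equivalent, deliberately kept out of
SparkUtils since a separate counting job is exactly the inefficiency this
commit avoids (the method name computeNnzInIsolation is illustrative):

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;

public static long computeNnzInIsolation( JavaPairRDD<MatrixIndexes,MatrixBlock> rdd ) {
	//note: reduce is an action, so this triggers a full, separate job over the rdd
	return rdd.values()
		.map(MatrixBlock::getNonZeros)
		.reduce(Long::sum);
}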