[SYSTEMML-562] Generalized spark checkpoint instruction (frame support)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/507172a1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/507172a1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/507172a1
Branch: refs/heads/master
Commit: 507172a1085f47a6a4338ae7a76f72d055201fca
Parents: 4c058fa
Author: Matthias Boehm <[email protected]>
Authored: Wed Jun 1 00:45:00 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Jun 1 09:48:45 2016 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/Hop.java     |  2 +-
 .../controlprogram/caching/FrameObject.java      |  3 +-
 .../context/SparkExecutionContext.java           | 94 +++++++++++++++-----
 .../spark/CheckpointSPInstruction.java           | 42 +++++----
 .../spark/functions/CopyFrameBlockFunction.java  | 52 +++++++++++
 .../functions/CopyFrameBlockPairFunction.java    | 60 +++++++++++++
 .../sysml/runtime/matrix/data/InputInfo.java     |  3 +-
 7 files changed, 217 insertions(+), 39 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/507172a1/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 088ddb4..8519055 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -373,7 +373,7 @@ public abstract class Hop
 		//investigate need for serialized storage of large sparse matrices
 		//(compile- instead of runtime-level for better debugging)
 		boolean serializedStorage = false;
-		if( dimsKnown(true) && !Checkpoint.CHECKPOINT_SPARSE_CSR ) {
+		if( getDataType()==DataType.MATRIX && dimsKnown(true) && !Checkpoint.CHECKPOINT_SPARSE_CSR ) {
 			double matrixPSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(_dim1, _dim2, _rows_in_block, _cols_in_block, _nnz);
 			double dataCache = SparkExecutionContext.getDataMemoryBudget(true, true);
 			serializedStorage = (MatrixBlock.evalSparseFormatInMemory(_dim1, _dim2, _nnz)

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/507172a1/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
index 78f58be..2df32c6 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/FrameObject.java
@@ -148,8 +148,7 @@ public class FrameObject extends CacheableData<FrameBlock>
 		FrameBlock data = null;
 		try {
-			FrameReader reader = FrameReaderFactory.createFrameReader(
-					iimd.getInputInfo());
+			FrameReader reader = FrameReaderFactory.createFrameReader(iimd.getInputInfo());
 			data = reader.readFrameFromHDFS(fname, _schema, mc.getRows(), mc.getCols());
 		}
 		catch( DMLRuntimeException ex ) {
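----------------------------------------------------------------------
Context for the Hop.java change above: the serialized-storage decision is now
guarded by getDataType()==DataType.MATRIX because frames have no serialized
sparse checkpoint format. A minimal sketch of the heuristic, assuming the
truncated condition compares the partitioned matrix size against the data
cache budget (the exact predicate is cut off in the hunk above, so this
combination is an assumption, not the committed code):

    //serialized storage only pays off for sparse matrices whose
    //partitioned representation exceeds the aggregate data cache
    boolean sparse = MatrixBlock.evalSparseFormatInMemory(_dim1, _dim2, _nnz);
    double matrixPSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(
        _dim1, _dim2, _rows_in_block, _cols_in_block, _nnz);
    double dataCache = SparkExecutionContext.getDataMemoryBudget(true, true);
    boolean serializedStorage = sparse && (matrixPSize > dataCache);
----------------------------------------------------------------------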
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/507172a1/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index d6da82f..f282579 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -41,6 +41,7 @@ import scala.Tuple2;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.MLContext;
 import org.apache.sysml.api.MLContextProxy;
+import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.lops.Checkpoint;
 import org.apache.sysml.runtime.DMLRuntimeException;
@@ -48,6 +49,7 @@ import org.apache.sysml.runtime.controlprogram.Program;
 import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.spark.CheckpointSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.SPInstruction;
 import org.apache.sysml.runtime.instructions.spark.data.BlockPartitioner;
@@ -58,12 +60,14 @@ import org.apache.sysml.runtime.instructions.spark.data.PartitionedMatrixBlock;
 import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysml.runtime.instructions.spark.functions.CopyBinaryCellFunction;
 import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockPairFunction;
+import org.apache.sysml.runtime.instructions.spark.functions.CopyFrameBlockPairFunction;
 import org.apache.sysml.runtime.instructions.spark.functions.CopyTextInputFunction;
 import org.apache.sysml.runtime.instructions.spark.functions.CreateSparseBlockFunction;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixCell;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
@@ -252,8 +256,18 @@ public class SparkExecutionContext extends ExecutionContext
 	public JavaPairRDD<?,?> getRDDHandleForVariable( String varname, InputInfo inputInfo )
 		throws DMLRuntimeException
 	{
-		MatrixObject mo = getMatrixObject(varname);
-		return getRDDHandleForMatrixObject(mo, inputInfo);
+		Data dat = getVariable(varname);
+		if( dat instanceof MatrixObject ) {
+			MatrixObject mo = getMatrixObject(varname);
+			return getRDDHandleForMatrixObject(mo, inputInfo);
+		}
+		else if( dat instanceof FrameObject ) {
+			FrameObject fo = getFrameObject(varname);
+			return getRDDHandleForFrameObject(fo, inputInfo);
+		}
+		else {
+			throw new DMLRuntimeException("Failed to obtain RDD for data type other than matrix or frame.");
+		}
 	}
 
 	/**
@@ -300,7 +314,7 @@ public class SparkExecutionContext extends ExecutionContext
 		}
 		else { //default case
 			MatrixBlock mb = mo.acquireRead(); //pin matrix in memory
-			rdd = toJavaPairRDD(getSparkContext(), mb, (int)mo.getNumRowsPerBlock(), (int)mo.getNumColumnsPerBlock());
+			rdd = toMatrixJavaPairRDD(getSparkContext(), mb, (int)mo.getNumRowsPerBlock(), (int)mo.getNumColumnsPerBlock());
 			mo.release(); //unpin matrix
 		}
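----------------------------------------------------------------------
Usage sketch for the generalized handle lookup above (illustrative; assumes a
SparkExecutionContext sec with a bound frame variable "F"). Callers no longer
need to know whether a variable is a matrix or a frame; the method dispatches
on the runtime Data type and returns a wildcard-typed pair RDD that is cast
at the call site, exactly as CheckpointSPInstruction does further below:

    JavaPairRDD<?,?> in = sec.getRDDHandleForVariable("F", InputInfo.BinaryBlockInputInfo);
    @SuppressWarnings("unchecked")
    JavaPairRDD<Long,FrameBlock> frame = (JavaPairRDD<Long,FrameBlock>) in;
----------------------------------------------------------------------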
@@ -353,12 +367,15 @@ public class SparkExecutionContext extends ExecutionContext
 	@SuppressWarnings("unchecked")
 	public JavaPairRDD<?,?> getRDDHandleForFrameObject( FrameObject fo, InputInfo inputInfo )
 		throws DMLRuntimeException
-	{
+	{
 		//NOTE: MB this logic should be integrated into FrameObject
 		//However, for now we cannot assume that spark libraries are
 		//always available and hence only store generic references in
 		//matrix object while all the logic is in the SparkExecContext
 
+		InputInfo inputInfo2 = (inputInfo==InputInfo.BinaryBlockInputInfo) ?
+			InputInfo.BinaryBlockFrameInputInfo : inputInfo;
+
 		JavaPairRDD<?,?> rdd = null;
 
 		//CASE 1: rdd already existing (reuse if checkpoint or trigger
 		//pending rdd operations if not yet cached but prevent to re-evaluate
@@ -379,15 +396,14 @@ public class SparkExecutionContext extends ExecutionContext
 			if( fo.isDirty() ) { //write only if necessary
 				fo.exportData();
 			}
-			rdd = getSparkContext().hadoopFile( fo.getFileName(), inputInfo.inputFormatClass, inputInfo.inputKeyClass, inputInfo.inputValueClass);
-			rdd = ((JavaPairRDD<MatrixIndexes, MatrixBlock>)rdd).mapToPair( new CopyBlockPairFunction() ); //cp is workaround for read bug
+			rdd = getSparkContext().hadoopFile( fo.getFileName(), inputInfo2.inputFormatClass, inputInfo2.inputKeyClass, inputInfo2.inputValueClass);
+			rdd = ((JavaPairRDD<LongWritable, FrameBlock>)rdd).mapToPair( new CopyFrameBlockPairFunction() ); //cp is workaround for read bug
 			fromFile = true;
 		}
 		else { //default case
-			//MatrixBlock mb = mo.acquireRead(); //pin matrix in memory
-			//rdd = toJavaPairRDD(getSparkContext(), mb, (int)mo.getNumRowsPerBlock(), (int)mo.getNumColumnsPerBlock());
-			//fo.release(); //unpin matrix
-			throw new RuntimeException("Not implemented yet.");
+			FrameBlock fb = fo.acquireRead(); //pin frame in memory
+			rdd = toFrameJavaPairRDD(getSparkContext(), fb);
+			fo.release(); //unpin frame
 		}
 
 		//keep rdd handle for future operations on it
@@ -400,19 +416,18 @@ public class SparkExecutionContext extends ExecutionContext
 		{
 			// parallelize hdfs-resident file
 			// For binary block, these are: SequenceFileInputFormat.class, MatrixIndexes.class, MatrixBlock.class
-			if(inputInfo == InputInfo.BinaryBlockInputInfo) {
-				rdd = getSparkContext().hadoopFile( fo.getFileName(), inputInfo.inputFormatClass, inputInfo.inputKeyClass, inputInfo.inputValueClass);
+			if(inputInfo2 == InputInfo.BinaryBlockFrameInputInfo) {
+				rdd = getSparkContext().hadoopFile( fo.getFileName(), inputInfo2.inputFormatClass, inputInfo2.inputKeyClass, inputInfo2.inputValueClass);
 				//note: this copy is still required in Spark 1.4 because spark hands out whatever the inputformat
 				//recordreader returns; the javadoc explicitly recommends to copy all key/value pairs
-				rdd = ((JavaPairRDD<MatrixIndexes, MatrixBlock>)rdd).mapToPair( new CopyBlockPairFunction() ); //cp is workaround for read bug
+				rdd = ((JavaPairRDD<LongWritable, FrameBlock>)rdd).mapToPair( new CopyFrameBlockPairFunction() ); //cp is workaround for read bug
 			}
-			else if(inputInfo == InputInfo.TextCellInputInfo || inputInfo == InputInfo.CSVInputInfo || inputInfo == InputInfo.MatrixMarketInputInfo) {
-				rdd = getSparkContext().hadoopFile( fo.getFileName(), inputInfo.inputFormatClass, inputInfo.inputKeyClass, inputInfo.inputValueClass);
+			else if(inputInfo2 == InputInfo.TextCellInputInfo || inputInfo2 == InputInfo.CSVInputInfo || inputInfo2 == InputInfo.MatrixMarketInputInfo) {
+				rdd = getSparkContext().hadoopFile( fo.getFileName(), inputInfo2.inputFormatClass, inputInfo2.inputKeyClass, inputInfo2.inputValueClass);
 				rdd = ((JavaPairRDD<LongWritable, Text>)rdd).mapToPair( new CopyTextInputFunction() ); //cp is workaround for read bug
 			}
-			else if(inputInfo == InputInfo.BinaryCellInputInfo) {
-				rdd = getSparkContext().hadoopFile( fo.getFileName(), inputInfo.inputFormatClass, inputInfo.inputKeyClass, inputInfo.inputValueClass);
-				rdd = ((JavaPairRDD<MatrixIndexes, MatrixCell>)rdd).mapToPair( new CopyBinaryCellFunction() ); //cp is workaround for read bug
+			else if(inputInfo2 == InputInfo.BinaryCellInputInfo) {
+				throw new DMLRuntimeException("Binarycell not supported for frames.");
 			}
 			else {
 				throw new DMLRuntimeException("Incorrect input format in getRDDHandleForVariable");
@@ -541,7 +556,7 @@ public class SparkExecutionContext extends ExecutionContext
 	 * @return
 	 * @throws DMLRuntimeException
 	 */
-	public static JavaPairRDD<MatrixIndexes,MatrixBlock> toJavaPairRDD(JavaSparkContext sc, MatrixBlock src, int brlen, int bclen)
+	public static JavaPairRDD<MatrixIndexes,MatrixBlock> toMatrixJavaPairRDD(JavaSparkContext sc, MatrixBlock src, int brlen, int bclen)
 		throws DMLRuntimeException
 	{
 		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
@@ -588,6 +603,45 @@ public class SparkExecutionContext extends ExecutionContext
 	}
 
 	/**
+	 * Converts an in-memory frame into a pair rdd of 1-based row offsets and frame blocks.
+	 * @param sc java spark context
+	 * @param src in-memory frame block
+	 * @return pair rdd of row offsets and frame blocks
+	 * @throws DMLRuntimeException
+	 */
+	public static JavaPairRDD<Long,FrameBlock> toFrameJavaPairRDD(JavaSparkContext sc, FrameBlock src)
+		throws DMLRuntimeException
+	{
+		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
+		LinkedList<Tuple2<Long,FrameBlock>> list = new LinkedList<Tuple2<Long,FrameBlock>>();
+
+		//create row-wise subblocks of the frame
+		int blksize = ConfigurationManager.getBlocksize();
+		for(int blockRow = 0; blockRow < (int)Math.ceil(src.getNumRows()/(double)blksize); blockRow++)
+		{
+			int maxRow = (blockRow*blksize + blksize < src.getNumRows()) ? blksize : src.getNumRows() - blockRow*blksize;
+			int roffset = blockRow*blksize;
+
+			FrameBlock block = new FrameBlock(src.getSchema());
+
+			//copy sub-frame to block
+			src.sliceOperations( roffset, roffset+maxRow-1,
+					0, src.getNumColumns()-1, block );
+
+			//append block to output list, keyed by 1-based row offset
+			list.addLast(new Tuple2<Long,FrameBlock>(new Long(roffset+1), block));
+		}
+
+		JavaPairRDD<Long,FrameBlock> result = sc.parallelizePairs(list);
+		if (DMLScript.STATISTICS) {
+			Statistics.accSparkParallelizeTime(System.nanoTime() - t0);
+			Statistics.incSparkParallelizeCount(1);
+		}
+
+		return result;
+	}
+
+	/**
 	 * This method is a generic abstraction for calls from the buffer pool.
 	 * See toMatrixBlock(JavaPairRDD<MatrixIndexes,MatrixBlock> rdd, int numRows, int numCols);
 	 *
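----------------------------------------------------------------------
The row blocking in toFrameJavaPairRDD keys each block by its 1-based starting
row offset. A self-contained illustration of that arithmetic, assuming
blksize=1000 and a frame with 2500 rows (yielding keys 1, 1001, 2001 with
block heights 1000, 1000, 500):

    int blksize = 1000, numRows = 2500;
    for( int blockRow = 0; blockRow < (int)Math.ceil(numRows/(double)blksize); blockRow++ ) {
        int roffset = blockRow * blksize;                  //0-based start row
        int maxRow = Math.min(blksize, numRows - roffset); //rows in this block
        System.out.println("key=" + (roffset+1) + ", rows=" + maxRow);
    }
----------------------------------------------------------------------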
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/507172a1/src/main/java/org/apache/sysml/runtime/instructions/spark/CheckpointSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/CheckpointSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/CheckpointSPInstruction.java
index 31d7665..24c614a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/CheckpointSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/CheckpointSPInstruction.java
@@ -22,8 +22,9 @@ package org.apache.sysml.runtime.instructions.spark;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.storage.StorageLevel;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
@@ -32,8 +33,11 @@ import org.apache.sysml.runtime.instructions.cp.BooleanObject;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockFunction;
+import org.apache.sysml.runtime.instructions.spark.functions.CopyFrameBlockFunction;
 import org.apache.sysml.runtime.instructions.spark.functions.CreateSparseBlockFunction;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.data.SparseBlock;
@@ -45,7 +49,7 @@ public class CheckpointSPInstruction extends UnarySPInstruction
 	//default storage level
 	private StorageLevel _level = null;
 
-	public CheckpointSPInstruction(Operator op, CPOperand in, CPOperand out, StorageLevel level, String opcode, String istr){
+	public CheckpointSPInstruction(Operator op, CPOperand in, CPOperand out, StorageLevel level, String opcode, String istr) {
 		super(op, in, out, opcode, istr);
 		_sptype = SPINSTRUCTION_TYPE.Reorg;
 
@@ -61,13 +65,13 @@ public class CheckpointSPInstruction extends UnarySPInstruction
 		String opcode = parts[0];
 		CPOperand in = new CPOperand(parts[1]);
 		CPOperand out = new CPOperand(parts[2]);
-
 		StorageLevel level = StorageLevel.fromString(parts[3]);
 
 		return new CheckpointSPInstruction(null, in, out, level, opcode, str);
 	}
 
 	@Override
+	@SuppressWarnings("unchecked")
 	public void processInstruction(ExecutionContext ec)
 		throws DMLRuntimeException
 	{
@@ -85,8 +89,8 @@ public class CheckpointSPInstruction extends UnarySPInstruction
 			return;
 		}
 
-		//get input rdd handle
-		JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable( input1.getName() );
+		//get input rdd handle (for matrix or frame)
+		JavaPairRDD<?,?> in = sec.getRDDHandleForVariable( input1.getName(), InputInfo.BinaryBlockInputInfo );
 		MatrixCharacteristics mcIn = sec.getMatrixCharacteristics( input1.getName() );
 
 		// Step 2: Checkpoint given rdd (only if currently in different storage level to prevent redundancy)
@@ -94,7 +98,7 @@ public class CheckpointSPInstruction extends UnarySPInstruction
 		// Note that persist is a transformation which will be triggered on-demand with the next rdd operations
 		// This prevents unnecessary overhead if the dataset is only consumed by cp operations.
 
-		JavaPairRDD<MatrixIndexes,MatrixBlock> out = null;
+		JavaPairRDD<?,?> out = null;
 		if( !in.getStorageLevel().equals( _level ) )
 		{
 			//investigate issue of unnecessarily large number of partitions
@@ -109,12 +113,20 @@ public class CheckpointSPInstruction extends UnarySPInstruction
 			else {
 				//since persist is an in-place marker for a storage level, we
 				//apply a narrow shallow copy to allow for short-circuit collects
-				out = in.mapValues(new CopyBlockFunction(false));
+				if( input1.getDataType() == DataType.MATRIX )
+					out = ((JavaPairRDD<MatrixIndexes,MatrixBlock>)in)
+						.mapValues(new CopyBlockFunction(false));
+				else if( input1.getDataType() == DataType.FRAME )
+					out = ((JavaPairRDD<Long,FrameBlock>)in)
+						.mapValues(new CopyFrameBlockFunction(false));
 			}
 
 			//convert mcsr into memory-efficient csr if potentially sparse
-			if( OptimizerUtils.checkSparseBlockCSRConversion(mcIn) ) {
-				out = out.mapValues(new CreateSparseBlockFunction(SparseBlock.Type.CSR));
+			if( input1.getDataType()==DataType.MATRIX
+				&& OptimizerUtils.checkSparseBlockCSRConversion(mcIn) )
+			{
+				out = ((JavaPairRDD<MatrixIndexes,MatrixBlock>)out)
+					.mapValues(new CreateSparseBlockFunction(SparseBlock.Type.CSR));
 			}
 
 			//actual checkpoint into given storage level
@@ -124,7 +136,7 @@ public class CheckpointSPInstruction extends UnarySPInstruction
 			out = in; //pass-through
 		}
 
-		// Step 3: In-place update of input matrix rdd handle and set as output
+		// Step 3: In-place update of input matrix/frame rdd handle and set as output
 		// -------
 		// We use this in-place approach for two reasons. First, it is correct because our checkpoint
 		// injection rewrites guarantee that after checkpoint instructions there are no consumers on the
 		// caching and subsequent collects. Note that in-place update requires us to explicitly handle
 		// lineage information in order to prevent cycles on cleanup.
-		MatrixObject mo = sec.getMatrixObject( input1.getName() );
+		CacheableData<?> cd = sec.getCacheableData( input1.getName() );
 		if( out != in ) {                        //prevent unnecessary lineage info
-			RDDObject inro = mo.getRDDHandle();  //guaranteed to exist (see above)
+			RDDObject inro = cd.getRDDHandle();  //guaranteed to exist (see above)
 			RDDObject outro = new RDDObject(out, output.getName()); //create new rdd object
 			outro.setCheckpointRDD(true);        //mark as checkpointed
 			outro.addLineageChild(inro);         //keep lineage to prevent cycles on cleanup
-			mo.setRDDHandle(outro);
+			cd.setRDDHandle(outro);
 		}
-		sec.setVariable( output.getName(), mo);
+		sec.setVariable( output.getName(), cd);
 	}
 
 	/**
@@ -150,7 +162,7 @@ public class CheckpointSPInstruction extends UnarySPInstruction
 	 * @param in
 	 * @return
 	 */
-	public static int getNumCoalescePartitions(MatrixCharacteristics mc, JavaPairRDD<MatrixIndexes,MatrixBlock> in)
+	public static int getNumCoalescePartitions(MatrixCharacteristics mc, JavaPairRDD<?,?> in)
 	{
 		if( mc.dimsKnown(true) ) {
 			double hdfsBlockSize = InfrastructureAnalyzer.getHDFSBlockSize();
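----------------------------------------------------------------------
Condensed view of the frame path through processInstruction above
(illustrative; assumes 'in' obtained via getRDDHandleForVariable and a target
storage level 'level'). The shallow block copy decouples the persisted RDD
from its input so that short-circuit collects remain possible, and persist
only marks the RDD; materialization happens with the next action:

    JavaPairRDD<Long,FrameBlock> out = ((JavaPairRDD<Long,FrameBlock>) in)
        .mapValues(new CopyFrameBlockFunction(false)); //shallow copy
    out = out.persist(level); //e.g., StorageLevel.MEMORY_AND_DISK()
----------------------------------------------------------------------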
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/507172a1/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyFrameBlockFunction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyFrameBlockFunction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyFrameBlockFunction.java
new file mode 100644
index 0000000..10beceb
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyFrameBlockFunction.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.spark.functions;
+
+import org.apache.spark.api.java.function.Function;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
+
+/**
+ * General purpose copy function for binary block rdds. This function can be used in
+ * mapValues (copy frame blocks). It supports both deep and shallow copies of values.
+ */
+public class CopyFrameBlockFunction implements Function<FrameBlock,FrameBlock>
+{
+	private static final long serialVersionUID = 612972882700587381L;
+
+	private boolean _deepCopy = true;
+
+	public CopyFrameBlockFunction() {
+		this(true);
+	}
+
+	public CopyFrameBlockFunction(boolean deepCopy) {
+		_deepCopy = deepCopy;
+	}
+
+	@Override
+	public FrameBlock call(FrameBlock arg0)
+		throws Exception
+	{
+		if( _deepCopy )
+			return new FrameBlock(arg0);
+		else
+			return arg0;
+	}
+}
\ No newline at end of file
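----------------------------------------------------------------------
Usage sketch for CopyFrameBlockFunction (illustrative; assumes an existing
JavaPairRDD<Long,FrameBlock> frameRDD). The default deep copy protects
against downstream mutation of shared blocks, while the shallow variant is a
cheap pass-through used by the checkpoint instruction above:

    JavaPairRDD<Long,FrameBlock> deep = frameRDD.mapValues(new CopyFrameBlockFunction());
    JavaPairRDD<Long,FrameBlock> shallow = frameRDD.mapValues(new CopyFrameBlockFunction(false));
----------------------------------------------------------------------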
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/507172a1/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyFrameBlockPairFunction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyFrameBlockPairFunction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyFrameBlockPairFunction.java
new file mode 100644
index 0000000..9e31878
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyFrameBlockPairFunction.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.spark.functions;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.spark.api.java.function.PairFunction;
+
+import scala.Tuple2;
+
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
+
+/**
+ * General purpose copy function for binary block rdds. This function can be used in
+ * mapToPair (copy frame indexes and blocks). It supports both deep and shallow copies
+ * of key/value pairs.
+ */
+public class CopyFrameBlockPairFunction implements PairFunction<Tuple2<LongWritable, FrameBlock>,Long, FrameBlock>
+{
+	private static final long serialVersionUID = -4686652006382558021L;
+
+	private boolean _deepCopy = true;
+
+	public CopyFrameBlockPairFunction() {
+		this(true);
+	}
+
+	public CopyFrameBlockPairFunction(boolean deepCopy) {
+		_deepCopy = deepCopy;
+	}
+
+	@Override
+	public Tuple2<Long, FrameBlock> call(Tuple2<LongWritable, FrameBlock> arg0)
+		throws Exception
+	{
+		if( _deepCopy ) {
+			FrameBlock block = new FrameBlock(arg0._2());
+			return new Tuple2<Long,FrameBlock>(arg0._1().get(), block);
+		}
+		else {
+			return new Tuple2<Long,FrameBlock>(arg0._1().get(), arg0._2());
+		}
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/507172a1/src/main/java/org/apache/sysml/runtime/matrix/data/InputInfo.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/InputInfo.java b/src/main/java/org/apache/sysml/runtime/matrix/data/InputInfo.java
index ee64cdc..a80ca34 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/InputInfo.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/InputInfo.java
@@ -72,8 +72,9 @@ public class InputInfo implements Serializable
 	public static final InputInfo BinaryCellInputInfo=new InputInfo(SequenceFileInputFormat.class,
 		MatrixIndexes.class, MatrixCell.class);
 	public static final InputInfo BinaryBlockInputInfo=new InputInfo(
-		//for jobs like GMR, we use CombineSequenceFileInputFormat (which requires to specify the maxsplitsize, hence not included here)
 		SequenceFileInputFormat.class, MatrixIndexes.class, MatrixBlock.class);
+	public static final InputInfo BinaryBlockFrameInputInfo=new InputInfo(
+		SequenceFileInputFormat.class, LongWritable.class, FrameBlock.class);
 
 	// Format that denotes the input of a SORT job
 	public static final InputInfo InputInfoForSort=new InputInfo(SequenceFileInputFormat.class,
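----------------------------------------------------------------------
End-to-end read sketch combining the new BinaryBlockFrameInputInfo with
CopyFrameBlockPairFunction (illustrative; assumes a JavaSparkContext sc and a
binary-block frame file at path fname), mirroring getRDDHandleForFrameObject
above:

    InputInfo ii = InputInfo.BinaryBlockFrameInputInfo;
    JavaPairRDD<?,?> raw = sc.hadoopFile(fname,
        ii.inputFormatClass, ii.inputKeyClass, ii.inputValueClass);
    JavaPairRDD<Long,FrameBlock> frame = ((JavaPairRDD<LongWritable,FrameBlock>) raw)
        .mapToPair(new CopyFrameBlockPairFunction()); //deep copy; read-bug workaround
----------------------------------------------------------------------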
