[SYSTEMML-1292] Support for codegen spark ops w/ multiple RDD inputs So far all codegen spark operations (all templates) can only consume one RDD input, while all side inputs are transferred via broadcasts. This captures the common case but easily fails with OOMs for multiple large inputs. This patch adds support for arbitrary combinations of RDD and broadcast inputs. Instead of generating these RDD operations and functions (as was the original plan), we have now generalized the respective spark instruction to handle these decisions at runtime via additional pre-processing that joins (and replicates) all RDD inputs and a generic function signature that is independent of the number of input RDDs.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/95de2358 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/95de2358 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/95de2358 Branch: refs/heads/master Commit: 95de23586988db56de57b731c03a751d5052a18e Parents: 57dff5d Author: Matthias Boehm <mboe...@gmail.com> Authored: Fri Aug 11 23:04:45 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Fri Aug 11 23:04:45 2017 -0700 ---------------------------------------------------------------------- .../spark/QuaternarySPInstruction.java | 58 +-- .../instructions/spark/SpoofSPInstruction.java | 387 ++++++++++++------- .../spark/functions/ReplicateBlockFunction.java | 86 +++++ 3 files changed, 336 insertions(+), 195 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/95de2358/src/main/java/org/apache/sysml/runtime/instructions/spark/QuaternarySPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuaternarySPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuaternarySPInstruction.java index 7dfe4be..711e3d0 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuaternarySPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuaternarySPInstruction.java @@ -23,7 +23,6 @@ package org.apache.sysml.runtime.instructions.spark; import java.io.Serializable; import java.util.ArrayList; import java.util.Iterator; -import java.util.LinkedList; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.PairFlatMapFunction; @@ -53,6 +52,7 @@ import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator; import 
org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast; import org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction; +import org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; @@ -251,12 +251,12 @@ public class QuaternarySPInstruction extends ComputationSPInstruction //preparation of transposed and replicated U if( inU != null ) - inU = inU.flatMapToPair(new ReplicateBlocksFunction(clen, bclen, true)); + inU = inU.flatMapToPair(new ReplicateBlockFunction(clen, bclen, true)); //preparation of transposed and replicated V if( inV != null ) inV = inV.mapToPair(new TransposeFactorIndexesFunction()) - .flatMapToPair(new ReplicateBlocksFunction(rlen, brlen, false)); + .flatMapToPair(new ReplicateBlockFunction(rlen, brlen, false)); //functions calls w/ two rdd inputs if( inU != null && inV == null && inW == null ) @@ -529,57 +529,5 @@ public class QuaternarySPInstruction extends ComputationSPInstruction //output new tuple return new Tuple2<MatrixIndexes, MatrixBlock>(ixOut,blkOut); } - - } - - private static class ReplicateBlocksFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> - { - private static final long serialVersionUID = -1184696764516975609L; - - private long _len = -1; - private long _blen = -1; - private boolean _left = false; - - public ReplicateBlocksFunction(long len, long blen, boolean left) - { - _len = len; - _blen = blen; - _left = left; - } - - @Override - public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) - throws Exception - { - LinkedList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new LinkedList<Tuple2<MatrixIndexes, MatrixBlock>>(); - MatrixIndexes ixIn = arg0._1(); - MatrixBlock blkIn = 
arg0._2(); - - long numBlocks = (long) Math.ceil((double)_len/_blen); - - if( _left ) //LHS MATRIX - { - //replicate wrt # column blocks in RHS - long i = ixIn.getRowIndex(); - for( long j=1; j<=numBlocks; j++ ) { - MatrixIndexes tmpix = new MatrixIndexes(i, j); - MatrixBlock tmpblk = new MatrixBlock(blkIn); - ret.add( new Tuple2<MatrixIndexes, MatrixBlock>(tmpix, tmpblk) ); - } - } - else // RHS MATRIX - { - //replicate wrt # row blocks in LHS - long j = ixIn.getColumnIndex(); - for( long i=1; i<=numBlocks; i++ ) { - MatrixIndexes tmpix = new MatrixIndexes(i, j); - MatrixBlock tmpblk = new MatrixBlock(blkIn); - ret.add( new Tuple2<MatrixIndexes, MatrixBlock>(tmpix, tmpblk) ); - } - } - - //output list of new tuples - return ret.iterator(); - } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/95de2358/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java index 90e2184..0af46df 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java @@ -19,14 +19,20 @@ package org.apache.sysml.runtime.instructions.spark; +import java.io.Serializable; import java.util.ArrayList; +import java.util.Arrays; import java.util.Iterator; +import java.util.LinkedList; import java.util.List; +import org.apache.commons.lang3.ArrayUtils; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; +import org.apache.sysml.hops.OptimizerUtils; import 
org.apache.sysml.lops.PartialAggregate.CorrectionLocationType; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.DMLRuntimeException; @@ -51,6 +57,7 @@ import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast; +import org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; @@ -99,21 +106,23 @@ public class SpoofSPInstruction extends SPInstruction throws DMLRuntimeException { SparkExecutionContext sec = (SparkExecutionContext)ec; - - //get input rdd and variable name - ArrayList<String> bcVars = new ArrayList<String>(); + + //decide upon broadcast side inputs + boolean[] bcVect = determineBroadcastInputs(sec, _in); + boolean[] bcVect2 = getMatrixBroadcastVector(sec, _in, bcVect); + + //create joined input rdd w/ replication if needed MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[0].getName()); - JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable( _in[0].getName() ); + JavaPairRDD<MatrixIndexes, MatrixBlock[]> in = createJoinedInputRDD( + sec, _in, bcVect, (_class.getSuperclass() == SpoofOuterProduct.class)); JavaPairRDD<MatrixIndexes, MatrixBlock> out = null; - - //simple case: map-side only operation (one rdd input, broadcast all) - //keep track of broadcast variables - ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new ArrayList<PartitionedBroadcast<MatrixBlock>>(); - ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>(); + + //create lists of input broadcasts and scalars + ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new ArrayList<>(); 
+ ArrayList<ScalarObject> scalars = new ArrayList<>(); for( int i=1; i<_in.length; i++ ) { - if( _in[i].getDataType()==DataType.MATRIX) { + if( _in[i].getDataType()==DataType.MATRIX && bcVect[i] ) { bcMatrices.add(sec.getBroadcastForVariable(_in[i].getName())); - bcVars.add(_in[i].getName()); } else if(_in[i].getDataType()==DataType.SCALAR) { //note: even if literal, it might be compiled as scalar placeholder @@ -121,8 +130,8 @@ public class SpoofSPInstruction extends SPInstruction } } - //initialize Spark Operator - if(_class.getSuperclass() == SpoofCellwise.class) // cellwise operator + //execute generated operator + if(_class.getSuperclass() == SpoofCellwise.class) //CELL { SpoofCellwise op = (SpoofCellwise) CodegenUtils.createInstance(_class); AggregateOperator aggop = getAggregateOperator(op.getAggOp()); @@ -130,7 +139,7 @@ public class SpoofSPInstruction extends SPInstruction if( _out.getDataType()==DataType.MATRIX ) { //execute codegen block operation out = in.mapPartitionsToPair(new CellwiseFunction( - _class.getName(), _classBytes, bcMatrices, scalars), true); + _class.getName(), _classBytes, bcVect2, bcMatrices, scalars), true); if( (op.getCellType()==CellType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock()) || (op.getCellType()==CellType.COL_AGG && mcIn.getRows() > mcIn.getRowsPerBlock())) { @@ -144,33 +153,29 @@ public class SpoofSPInstruction extends SPInstruction } sec.setRDDHandleForVariable(_out.getName(), out); - //maintain lineage information for output rdd - sec.addLineageRDD(_out.getName(), _in[0].getName()); - for( String bcVar : bcVars ) - sec.addLineageBroadcast(_out.getName(), bcVar); - - //update matrix characteristics + //maintain lineage info and output characteristics + maintainLineageInfo(sec, _in, bcVect, _out); updateOutputMatrixCharacteristics(sec, op); } else { //SCALAR - out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true); + out = in.mapPartitionsToPair(new 
CellwiseFunction( + _class.getName(), _classBytes, bcVect2, bcMatrices, scalars), true); MatrixBlock tmpMB = RDDAggregateUtils.aggStable(out, aggop); sec.setVariable(_out.getName(), new DoubleObject(tmpMB.getValue(0, 0))); } } - else if(_class.getSuperclass() == SpoofMultiAggregate.class) + else if(_class.getSuperclass() == SpoofMultiAggregate.class) //MAGG { SpoofMultiAggregate op = (SpoofMultiAggregate) CodegenUtils.createInstance(_class); AggOp[] aggOps = op.getAggOps(); - MatrixBlock tmpMB = in - .mapToPair(new MultiAggregateFunction(_class.getName(), _classBytes, bcMatrices, scalars)) - .values().fold(new MatrixBlock(), new MultiAggAggregateFunction(aggOps) ); + MatrixBlock tmpMB = in.mapToPair(new MultiAggregateFunction( + _class.getName(), _classBytes, bcVect2, bcMatrices, scalars)) + .values().fold(new MatrixBlock(), new MultiAggAggregateFunction(aggOps) ); sec.setMatrixOutput(_out.getName(), tmpMB, getExtendedOpcode()); - return; } - else if(_class.getSuperclass() == SpoofOuterProduct.class) // outer product operator + else if(_class.getSuperclass() == SpoofOuterProduct.class) //OUTER { if( _out.getDataType()==DataType.MATRIX ) { SpoofOperator op = (SpoofOperator) CodegenUtils.createInstance(_class); @@ -180,7 +185,8 @@ public class SpoofSPInstruction extends SPInstruction updateOutputMatrixCharacteristics(sec, op); MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName()); - out = in.mapPartitionsToPair(new OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true); + out = in.mapPartitionsToPair(new OuterProductFunction( + _class.getName(), _classBytes, bcVect2, bcMatrices, scalars), true); if(type == OutProdType.LEFT_OUTER_PRODUCT || type == OutProdType.RIGHT_OUTER_PRODUCT ) { //TODO investigate if some other side effect of correct blocks if( in.partitions().size() > mcOut.getNumRowBlocks()*mcOut.getNumColBlocks() ) @@ -190,26 +196,24 @@ public class SpoofSPInstruction extends SPInstruction } 
sec.setRDDHandleForVariable(_out.getName(), out); - //maintain lineage information for output rdd - sec.addLineageRDD(_out.getName(), _in[0].getName()); - for( String bcVar : bcVars ) - sec.addLineageBroadcast(_out.getName(), bcVar); - + //maintain lineage info and output characteristics + maintainLineageInfo(sec, _in, bcVect, _out); } else { - out = in.mapPartitionsToPair(new OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true); + out = in.mapPartitionsToPair(new OuterProductFunction( + _class.getName(), _classBytes, bcVect2, bcMatrices, scalars), true); MatrixBlock tmp = RDDAggregateUtils.sumStable(out); sec.setVariable(_out.getName(), new DoubleObject(tmp.getValue(0, 0))); } } - else if( _class.getSuperclass() == SpoofRowwise.class ) { //row aggregate operator + else if( _class.getSuperclass() == SpoofRowwise.class ) { //ROW if( mcIn.getCols() > mcIn.getColsPerBlock() ) { throw new DMLRuntimeException("Invalid spark rowwise operator w/ ncol=" + mcIn.getCols()+", ncolpb="+mcIn.getColsPerBlock()+"."); } SpoofRowwise op = (SpoofRowwise) CodegenUtils.createInstance(_class); RowwiseFunction fmmc = new RowwiseFunction(_class.getName(), - _classBytes, bcMatrices, scalars, (int)mcIn.getCols()); + _classBytes, bcVect2, bcMatrices, scalars, (int)mcIn.getCols()); out = in.mapPartitionsToPair(fmmc, op.getRowType()==RowType.ROW_AGG || op.getRowType() == RowType.NO_AGG); @@ -233,21 +237,89 @@ public class SpoofSPInstruction extends SPInstruction sec.setRDDHandleForVariable(_out.getName(), out); - //maintain lineage information for output rdd - sec.addLineageRDD(_out.getName(), _in[0].getName()); - for( String bcVar : bcVars ) - sec.addLineageBroadcast(_out.getName(), bcVar); - - //update matrix characteristics + //maintain lineage info and output characteristics + maintainLineageInfo(sec, _in, bcVect, _out); updateOutputMatrixCharacteristics(sec, op); } - return; } else { throw new DMLRuntimeException("Operator " + _class.getSuperclass() + " is not 
supported on Spark"); } } + private static boolean[] determineBroadcastInputs(SparkExecutionContext sec, CPOperand[] inputs) + throws DMLRuntimeException + { + boolean[] ret = new boolean[inputs.length]; + double localBudget = OptimizerUtils.getLocalMemBudget(); + double bcBudget = SparkExecutionContext.getBroadcastMemoryBudget(); + + //decided for each matrix input if it fits into remaining memory + //budget; the major input, i.e., inputs[0] is always an RDD + for( int i=1; i<inputs.length; i++ ) + if( inputs[i].getDataType().isMatrix() ) { + MatrixCharacteristics mc = sec.getMatrixCharacteristics(inputs[i].getName()); + double sizeL = OptimizerUtils.estimateSizeExactSparsity(mc); + double sizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(mc); + ret[i] = localBudget > sizeL && bcBudget > sizeP; + localBudget -= ret[i] ? sizeL : 0; + bcBudget -= ret[i] ? sizeP : 0; + } + + return ret; + } + + private static boolean[] getMatrixBroadcastVector(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect) + throws DMLRuntimeException + { + int numMtx = (int) Arrays.stream(inputs) + .filter(in -> in.getDataType().isMatrix()).count(); + boolean[] ret = new boolean[numMtx]; + for(int i=0, pos=0; i<inputs.length; i++) + if( inputs[i].getDataType().isMatrix() ) + ret[pos++] = bcVect[i]; + return ret; + } + + private static JavaPairRDD<MatrixIndexes, MatrixBlock[]> createJoinedInputRDD(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect, boolean outer) + throws DMLRuntimeException + { + //get input rdd for main input + MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(inputs[0].getName()); + JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(inputs[0].getName()); + JavaPairRDD<MatrixIndexes, MatrixBlock[]> ret = in.mapValues(new MapInputSignature()); + + for( int i=1; i<inputs.length; i++ ) + if( inputs[i].getDataType().isMatrix() && !bcVect[i] ) { + //create side input rdd + String varname = 
inputs[i].getName(); + JavaPairRDD<MatrixIndexes, MatrixBlock> tmp = sec + .getBinaryBlockRDDHandleForVariable(varname); + MatrixCharacteristics mcTmp = sec.getMatrixCharacteristics(varname); + //replicate blocks if mismatch with main input + if( outer && i==2 ) + tmp = tmp.flatMapToPair(new ReplicateRightFactorFunction(mcIn.getRows(), mcIn.getRowsPerBlock())); + else if( mcIn.getNumRowBlocks() > mcTmp.getNumRowBlocks() ) + tmp = tmp.flatMapToPair(new ReplicateBlockFunction(mcIn.getRows(), mcIn.getRowsPerBlock(), false)); + else if( mcIn.getNumColBlocks() > mcTmp.getNumColBlocks() ) + tmp = tmp.flatMapToPair(new ReplicateBlockFunction(mcIn.getCols(), mcIn.getColsPerBlock(), true)); + //join main and side inputs and consolidate signature + ret = ret.join(tmp) + .mapValues(new MapJoinSignature()); + } + + return ret; + } + + private static void maintainLineageInfo(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect, CPOperand output) + throws DMLRuntimeException + { + //add lineage info for all rdd/broadcast inputs + for( int i=0; i<inputs.length; i++ ) + if( inputs[i].getDataType().isMatrix() ) + sec.addLineage(output.getName(), inputs[i].getName(), bcVect[i]); + } + private void updateOutputMatrixCharacteristics(SparkExecutionContext sec, SpoofOperator op) throws DMLRuntimeException { @@ -290,30 +362,89 @@ public class SpoofSPInstruction extends SPInstruction mcOut.set(mcIn.getCols(), 1, mcIn.getRowsPerBlock(), mcIn.getColsPerBlock()); } } + + private static class MapInputSignature implements Function<MatrixBlock, MatrixBlock[]> + { + private static final long serialVersionUID = -816443970067626102L; - private static class RowwiseFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> + @Override + public MatrixBlock[] call(MatrixBlock v1) throws Exception { + return new MatrixBlock[]{ v1 }; + } + } + + private static class MapJoinSignature implements Function<Tuple2<MatrixBlock[],MatrixBlock>, 
MatrixBlock[]> + { + private static final long serialVersionUID = -704403012606821854L; + + @Override + public MatrixBlock[] call(Tuple2<MatrixBlock[], MatrixBlock> v1) throws Exception { + return ArrayUtils.add(v1._1(), v1._2()); + } + } + + private static class SpoofFunction implements Serializable + { + private static final long serialVersionUID = 2953479427746463003L; + + protected final boolean[] _bcInd; + protected final ArrayList<PartitionedBroadcast<MatrixBlock>> _inputs; + protected final ArrayList<ScalarObject> _scalars; + protected final byte[] _classBytes; + protected final String _className; + + protected SpoofFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) { + _bcInd = bcInd; + _inputs = bcMatrices; + _scalars = scalars; + _classBytes = classBytes; + _className = className; + } + + protected ArrayList<MatrixBlock> getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn) + throws DMLRuntimeException + { + return getAllMatrixInputs(ixIn, blkIn, false); + } + + protected ArrayList<MatrixBlock> getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn, boolean outer) + throws DMLRuntimeException + { + ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>(); + //add all rdd/broadcast inputs (main and side inputs) + for( int i=0, posRdd=0, posBc=0; i<_bcInd.length; i++ ) { + if( _bcInd[i] ) { + PartitionedBroadcast<MatrixBlock> pb = _inputs.get(posBc++); + int rowIndex = (int) ((outer && i==2) ? ixIn.getColumnIndex() : + (pb.getNumRowBlocks()>=ixIn.getRowIndex())?ixIn.getRowIndex():1); + int colIndex = (int) ((outer && i==2) ? 
1 : + (pb.getNumColumnBlocks()>=ixIn.getColumnIndex())?ixIn.getColumnIndex():1); + ret.add(pb.getBlock(rowIndex, colIndex)); + } + else + ret.add(blkIn[posRdd++]); + } + return ret; + } + } + + private static class RowwiseFunction extends SpoofFunction + implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = -7926980450209760212L; - private final ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors; - private final ArrayList<ScalarObject> _scalars; - private final byte[] _classBytes; - private final String _className; private final int _clen; private SpoofRowwise _op = null; - public RowwiseFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars, int clen) + public RowwiseFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars, int clen) throws DMLRuntimeException { - _className = className; - _classBytes = classBytes; - _vectors = bcMatrices; - _scalars = scalars; + super(className, classBytes, bcInd, bcMatrices, scalars); _clen = clen; } @Override - public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg ) + public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg ) throws Exception { //lazy load of shipped class @@ -323,7 +454,7 @@ public class SpoofSPInstruction extends SPInstruction } //setup local memory for reuse - int clen2 = (int) (_op.getRowType().isRowTypeB1() ? _vectors.get(0).getNumCols() : -1); + int clen2 = (int) (_op.getRowType().isRowTypeB1() ? 
_inputs.get(0).getNumCols() : -1); LibSpoofPrimitives.setupThreadLocalMemory(_op.getNumIntermediates(), _clen, clen2); ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>(); @@ -333,13 +464,12 @@ public class SpoofSPInstruction extends SPInstruction while( arg.hasNext() ) { //get main input block and indexes - Tuple2<MatrixIndexes,MatrixBlock> e = arg.next(); + Tuple2<MatrixIndexes,MatrixBlock[]> e = arg.next(); MatrixIndexes ixIn = e._1(); - MatrixBlock blkIn = e._2(); - int rowIx = (int)ixIn.getRowIndex(); + MatrixBlock[] blkIn = e._2(); //prepare output and execute single-threaded operator - ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(blkIn, rowIx); + ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn); blkOut = aggIncr ? blkOut : new MatrixBlock(); _op.execute(inputs, _scalars, blkOut, false, aggIncr); if( !aggIncr ) { @@ -356,39 +486,23 @@ public class SpoofSPInstruction extends SPInstruction return ret.iterator(); } - - private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, int rowIndex) - throws DMLRuntimeException - { - ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>(); - ret.add(blkIn); - for( PartitionedBroadcast<MatrixBlock> vector : _vectors ) - ret.add(vector.getBlock((vector.getNumRowBlocks()>=rowIndex)?rowIndex:1, 1)); - return ret; - } } - private static class CellwiseFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> + private static class CellwiseFunction extends SpoofFunction + implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = -8209188316939435099L; - private ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors = null; - private ArrayList<ScalarObject> _scalars = null; - private byte[] _classBytes = null; - private String _className = null; private SpoofOperator _op = null; 
- public CellwiseFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) + public CellwiseFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) throws DMLRuntimeException { - _className = className; - _classBytes = classBytes; - _vectors = bcMatrices; - _scalars = scalars; + super(className, classBytes, bcInd, bcMatrices, scalars); } @Override - public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg) + public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg) throws Exception { //lazy load of shipped class @@ -400,13 +514,13 @@ public class SpoofSPInstruction extends SPInstruction List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>(); while(arg.hasNext()) { - Tuple2<MatrixIndexes,MatrixBlock> tmp = arg.next(); + Tuple2<MatrixIndexes,MatrixBlock[]> tmp = arg.next(); MatrixIndexes ixIn = tmp._1(); - MatrixBlock blkIn = tmp._2(); + MatrixBlock[] blkIn = tmp._2(); MatrixIndexes ixOut = ixIn; MatrixBlock blkOut = new MatrixBlock(); - ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(blkIn, ixIn); - + ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn); + //execute core operation if(((SpoofCellwise)_op).getCellType()==CellType.FULL_AGG) { ScalarObject obj = _op.execute(inputs, _scalars, 1); @@ -424,42 +538,23 @@ public class SpoofSPInstruction extends SPInstruction } return ret.iterator(); } - - private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, MatrixIndexes ixIn) - throws DMLRuntimeException - { - ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>(); - ret.add(blkIn); - for( PartitionedBroadcast<MatrixBlock> in : _vectors ) { - int rowIndex = 
(int)((in.getNumRowBlocks()>=ixIn.getRowIndex())?ixIn.getRowIndex():1); - int colIndex = (int)((in.getNumColumnBlocks()>=ixIn.getColumnIndex())?ixIn.getColumnIndex():1); - ret.add(in.getBlock(rowIndex, colIndex)); - } - return ret; - } } - private static class MultiAggregateFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> + private static class MultiAggregateFunction extends SpoofFunction + implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = -5224519291577332734L; - private ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors = null; - private ArrayList<ScalarObject> _scalars = null; - private byte[] _classBytes = null; - private String _className = null; private SpoofOperator _op = null; - public MultiAggregateFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) + public MultiAggregateFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) throws DMLRuntimeException { - _className = className; - _classBytes = classBytes; - _vectors = bcMatrices; - _scalars = scalars; + super(className, classBytes, bcInd, bcMatrices, scalars); } @Override - public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg) + public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg) throws Exception { //lazy load of shipped class @@ -469,25 +564,12 @@ public class SpoofSPInstruction extends SPInstruction } //execute core operation - ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(arg._2(), arg._1()); + ArrayList<MatrixBlock> inputs = getAllMatrixInputs(arg._1(), arg._2()); MatrixBlock blkOut = new MatrixBlock(); _op.execute(inputs, _scalars, blkOut); return new Tuple2<MatrixIndexes,MatrixBlock>(arg._1(), 
blkOut); } - - private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, MatrixIndexes ixIn) - throws DMLRuntimeException - { - ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>(); - ret.add(blkIn); - for( PartitionedBroadcast<MatrixBlock> in : _vectors ) { - int rowIndex = (int)((in.getNumRowBlocks()>=ixIn.getRowIndex())?ixIn.getRowIndex():1); - int colIndex = (int)((in.getNumColumnBlocks()>=ixIn.getColumnIndex())?ixIn.getColumnIndex():1); - ret.add(in.getBlock(rowIndex, colIndex)); - } - return ret; - } } private static class MultiAggAggregateFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> @@ -520,27 +602,21 @@ public class SpoofSPInstruction extends SPInstruction } } - private static class OuterProductFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> + private static class OuterProductFunction extends SpoofFunction + implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = -8209188316939435099L; - private ArrayList<PartitionedBroadcast<MatrixBlock>> _bcMatrices = null; - private ArrayList<ScalarObject> _scalars = null; - private byte[] _classBytes = null; - private String _className = null; private SpoofOperator _op = null; - public OuterProductFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) + public OuterProductFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) throws DMLRuntimeException { - _className = className; - _classBytes = classBytes; - _bcMatrices = bcMatrices; - _scalars = scalars; + super(className, classBytes, bcInd, bcMatrices, scalars); } @Override - public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg) 
+ public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg) throws Exception { //lazy load of shipped class @@ -552,16 +628,12 @@ public class SpoofSPInstruction extends SPInstruction List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>(); while(arg.hasNext()) { - Tuple2<MatrixIndexes,MatrixBlock> tmp = arg.next(); + Tuple2<MatrixIndexes,MatrixBlock[]> tmp = arg.next(); MatrixIndexes ixIn = tmp._1(); - MatrixBlock blkIn = tmp._2(); + MatrixBlock[] blkIn = tmp._2(); MatrixBlock blkOut = new MatrixBlock(); - ArrayList<MatrixBlock> inputs = new ArrayList<MatrixBlock>(); - inputs.add(blkIn); - inputs.add(_bcMatrices.get(0).getBlock((int)ixIn.getRowIndex(), 1)); // U - inputs.add(_bcMatrices.get(1).getBlock((int)ixIn.getColumnIndex(), 1)); // V - + ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn, true); //execute core operation if(((SpoofOuterProduct)_op).getOuterProdType()==OutProdType.AGG_OUTER_PRODUCT) { ScalarObject obj = _op.execute(inputs, _scalars,1); @@ -588,6 +660,41 @@ public class SpoofSPInstruction extends SPInstruction } } + public static class ReplicateRightFactorFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> + { + private static final long serialVersionUID = -7295989688796126442L; + + private final long _len; + private final long _blen; + + public ReplicateRightFactorFunction(long len, long blen) { + _len = len; + _blen = blen; + } + + @Override + public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) + throws Exception + { + LinkedList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new LinkedList<Tuple2<MatrixIndexes, MatrixBlock>>(); + MatrixIndexes ixIn = arg0._1(); + MatrixBlock blkIn = arg0._2(); + + long numBlocks = (long) Math.ceil((double)_len/_blen); + + //replicate wrt # row blocks in LHS + long j = ixIn.getRowIndex(); + for( long i=1; 
i<=numBlocks; i++ ) { + MatrixIndexes tmpix = new MatrixIndexes(i, j); + MatrixBlock tmpblk = blkIn; + ret.add( new Tuple2<MatrixIndexes, MatrixBlock>(tmpix, tmpblk) ); + } + + //output list of new tuples + return ret.iterator(); + } + } + public static AggregateOperator getAggregateOperator(AggOp aggop) { if( aggop == AggOp.SUM || aggop == AggOp.SUM_SQ ) return new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.NONE); http://git-wip-us.apache.org/repos/asf/systemml/blob/95de2358/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ReplicateBlockFunction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ReplicateBlockFunction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ReplicateBlockFunction.java new file mode 100644 index 0000000..22ea33f --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ReplicateBlockFunction.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.instructions.spark.functions; + +import java.util.Iterator; +import java.util.LinkedList; + +import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; + +import scala.Tuple2; + +public class ReplicateBlockFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> +{ + private static final long serialVersionUID = -1184696764516975609L; + + private final long _len; + private final long _blen; + private final boolean _left; + private final boolean _deep; + + public ReplicateBlockFunction(long len, long blen, boolean left) { + //by default: shallow copy of blocks + this(len, blen, left, false); + } + + public ReplicateBlockFunction(long len, long blen, boolean left, boolean deep) { + _len = len; + _blen = blen; + _left = left; + _deep = deep; + } + + @Override + public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) + throws Exception + { + LinkedList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new LinkedList<Tuple2<MatrixIndexes, MatrixBlock>>(); + MatrixIndexes ixIn = arg0._1(); + MatrixBlock blkIn = arg0._2(); + + long numBlocks = (long) Math.ceil((double)_len/_blen); + + if( _left ) //LHS MATRIX + { + //replicate wrt # column blocks in RHS + long i = ixIn.getRowIndex(); + for( long j=1; j<=numBlocks; j++ ) { + MatrixIndexes tmpix = new MatrixIndexes(i, j); + MatrixBlock tmpblk = _deep ? new MatrixBlock(blkIn) : blkIn; + ret.add( new Tuple2<MatrixIndexes, MatrixBlock>(tmpix, tmpblk) ); + } + } + else // RHS MATRIX + { + //replicate wrt # row blocks in LHS + long j = ixIn.getColumnIndex(); + for( long i=1; i<=numBlocks; i++ ) { + MatrixIndexes tmpix = new MatrixIndexes(i, j); + MatrixBlock tmpblk = _deep ? 
new MatrixBlock(blkIn) : blkIn; + ret.add( new Tuple2<MatrixIndexes, MatrixBlock>(tmpix, tmpblk) ); + } + } + + //output list of new tuples + return ret.iterator(); + } +}