[SYSTEMML-1292] Support for codegen spark ops w/ multiple RDD inputs 

So far, all codegen spark operations (all templates) could only consume one
RDD input, while all side inputs were transferred via broadcasts. This
captures the common case but easily fails with OOMs for multiple large
inputs. This patch adds support for arbitrary combinations of RDD and
broadcast inputs. Instead of generating these RDD operations and
functions (as was the original plan), we have now generalized the respective
spark instruction to handle these decisions at runtime via additional
pre-processing that joins (and replicates) all RDD inputs and a generic
function signature that is independent of the number of input RDDs.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/95de2358
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/95de2358
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/95de2358

Branch: refs/heads/master
Commit: 95de23586988db56de57b731c03a751d5052a18e
Parents: 57dff5d
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri Aug 11 23:04:45 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Aug 11 23:04:45 2017 -0700

----------------------------------------------------------------------
 .../spark/QuaternarySPInstruction.java          |  58 +--
 .../instructions/spark/SpoofSPInstruction.java  | 387 ++++++++++++-------
 .../spark/functions/ReplicateBlockFunction.java |  86 +++++
 3 files changed, 336 insertions(+), 195 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/95de2358/src/main/java/org/apache/sysml/runtime/instructions/spark/QuaternarySPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuaternarySPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuaternarySPInstruction.java
index 7dfe4be..711e3d0 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuaternarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuaternarySPInstruction.java
@@ -23,7 +23,6 @@ package org.apache.sysml.runtime.instructions.spark;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Iterator;
-import java.util.LinkedList;
 
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
@@ -53,6 +52,7 @@ import org.apache.sysml.runtime.instructions.cp.DoubleObject;
 import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator;
 import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
 import 
org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
+import 
org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -251,12 +251,12 @@ public class QuaternarySPInstruction extends 
ComputationSPInstruction
 
                        //preparation of transposed and replicated U
                        if( inU != null )
-                               inU = inU.flatMapToPair(new 
ReplicateBlocksFunction(clen, bclen, true));
+                               inU = inU.flatMapToPair(new 
ReplicateBlockFunction(clen, bclen, true));
 
                        //preparation of transposed and replicated V
                        if( inV != null )
                                inV = inV.mapToPair(new 
TransposeFactorIndexesFunction())
-                                        .flatMapToPair(new 
ReplicateBlocksFunction(rlen, brlen, false));
+                                        .flatMapToPair(new 
ReplicateBlockFunction(rlen, brlen, false));
                        
                        //functions calls w/ two rdd inputs             
                        if( inU != null && inV == null && inW == null )
@@ -529,57 +529,5 @@ public class QuaternarySPInstruction extends 
ComputationSPInstruction
                        //output new tuple
                        return new Tuple2<MatrixIndexes, 
MatrixBlock>(ixOut,blkOut);
                }
-               
-       }
-
-       private static class ReplicateBlocksFunction implements 
PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, 
MatrixBlock> 
-       {
-               private static final long serialVersionUID = 
-1184696764516975609L;
-               
-               private long _len = -1;
-               private long _blen = -1;
-               private boolean _left = false;
-               
-               public ReplicateBlocksFunction(long len, long blen, boolean 
left)
-               {
-                       _len = len;
-                       _blen = blen;
-                       _left = left;
-               }
-               
-               @Override
-               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( 
Tuple2<MatrixIndexes, MatrixBlock> arg0 ) 
-                       throws Exception 
-               {
-                       LinkedList<Tuple2<MatrixIndexes, MatrixBlock>> ret = 
new LinkedList<Tuple2<MatrixIndexes, MatrixBlock>>();
-                       MatrixIndexes ixIn = arg0._1();
-                       MatrixBlock blkIn = arg0._2();
-                       
-                       long numBlocks = (long) Math.ceil((double)_len/_blen); 
-                       
-                       if( _left ) //LHS MATRIX
-                       {
-                               //replicate wrt # column blocks in RHS
-                               long i = ixIn.getRowIndex();
-                               for( long j=1; j<=numBlocks; j++ ) {
-                                       MatrixIndexes tmpix = new 
MatrixIndexes(i, j);
-                                       MatrixBlock tmpblk = new 
MatrixBlock(blkIn);
-                                       ret.add( new Tuple2<MatrixIndexes, 
MatrixBlock>(tmpix, tmpblk) );
-                               }
-                       } 
-                       else // RHS MATRIX
-                       {
-                               //replicate wrt # row blocks in LHS
-                               long j = ixIn.getColumnIndex();
-                               for( long i=1; i<=numBlocks; i++ ) {
-                                       MatrixIndexes tmpix = new 
MatrixIndexes(i, j);
-                                       MatrixBlock tmpblk = new 
MatrixBlock(blkIn);
-                                       ret.add( new Tuple2<MatrixIndexes, 
MatrixBlock>(tmpix, tmpblk) );
-                               }
-                       }
-                       
-                       //output list of new tuples
-                       return ret.iterator();
-               }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/95de2358/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
index 90e2184..0af46df 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
@@ -19,14 +19,20 @@
 
 package org.apache.sysml.runtime.instructions.spark;
 
+import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Iterator;
+import java.util.LinkedList;
 import java.util.List;
 
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.spark.api.java.function.PairFunction;
+import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.lops.PartialAggregate.CorrectionLocationType;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.DMLRuntimeException;
@@ -51,6 +57,7 @@ import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.DoubleObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObject;
 import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
+import 
org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -99,21 +106,23 @@ public class SpoofSPInstruction extends SPInstruction
                throws DMLRuntimeException 
        {       
                SparkExecutionContext sec = (SparkExecutionContext)ec;
-
-               //get input rdd and variable name
-               ArrayList<String> bcVars = new ArrayList<String>();
+               
+               //decide upon broadcast side inputs
+               boolean[] bcVect = determineBroadcastInputs(sec, _in);
+               boolean[] bcVect2 = getMatrixBroadcastVector(sec, _in, bcVect);
+               
+               //create joined input rdd w/ replication if needed
                MatrixCharacteristics mcIn = 
sec.getMatrixCharacteristics(_in[0].getName());
-               JavaPairRDD<MatrixIndexes, MatrixBlock> in = 
sec.getBinaryBlockRDDHandleForVariable( _in[0].getName() );
+               JavaPairRDD<MatrixIndexes, MatrixBlock[]> in = 
createJoinedInputRDD(
+                       sec, _in, bcVect, (_class.getSuperclass() == 
SpoofOuterProduct.class));
                JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
-                               
-               //simple case: map-side only operation (one rdd input, 
broadcast all)
-               //keep track of broadcast variables
-               ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new 
ArrayList<PartitionedBroadcast<MatrixBlock>>();
-               ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>();
+               
+               //create lists of input broadcasts and scalars
+               ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new 
ArrayList<>();
+               ArrayList<ScalarObject> scalars = new ArrayList<>();
                for( int i=1; i<_in.length; i++ ) {
-                       if( _in[i].getDataType()==DataType.MATRIX) {
+                       if( _in[i].getDataType()==DataType.MATRIX && bcVect[i] 
) {
                                
bcMatrices.add(sec.getBroadcastForVariable(_in[i].getName()));
-                               bcVars.add(_in[i].getName());
                        }
                        else if(_in[i].getDataType()==DataType.SCALAR) {
                                //note: even if literal, it might be compiled 
as scalar placeholder
@@ -121,8 +130,8 @@ public class SpoofSPInstruction extends SPInstruction
                        }
                }
                
-               //initialize Spark Operator
-               if(_class.getSuperclass() == SpoofCellwise.class) // cellwise 
operator
+               //execute generated operator
+               if(_class.getSuperclass() == SpoofCellwise.class) //CELL
                {
                        SpoofCellwise op = (SpoofCellwise) 
CodegenUtils.createInstance(_class);         
                        AggregateOperator aggop = 
getAggregateOperator(op.getAggOp());
@@ -130,7 +139,7 @@ public class SpoofSPInstruction extends SPInstruction
                        if( _out.getDataType()==DataType.MATRIX ) {
                                //execute codegen block operation
                                out = in.mapPartitionsToPair(new 
CellwiseFunction(
-                                       _class.getName(), _classBytes, 
bcMatrices, scalars), true);
+                                       _class.getName(), _classBytes, bcVect2, 
bcMatrices, scalars), true);
                                
                                if( (op.getCellType()==CellType.ROW_AGG && 
mcIn.getCols() > mcIn.getColsPerBlock())
                                        || (op.getCellType()==CellType.COL_AGG 
&& mcIn.getRows() > mcIn.getRowsPerBlock())) {
@@ -144,33 +153,29 @@ public class SpoofSPInstruction extends SPInstruction
                                }
                                sec.setRDDHandleForVariable(_out.getName(), 
out);
                                
-                               //maintain lineage information for output rdd
-                               sec.addLineageRDD(_out.getName(), 
_in[0].getName());
-                               for( String bcVar : bcVars )
-                                       sec.addLineageBroadcast(_out.getName(), 
bcVar);
-                               
-                               //update matrix characteristics
+                               //maintain lineage info and output 
characteristics
+                               maintainLineageInfo(sec, _in, bcVect, _out);
                                updateOutputMatrixCharacteristics(sec, op);     
                        }
                        else { //SCALAR
-                               out = in.mapPartitionsToPair(new 
CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
+                               out = in.mapPartitionsToPair(new 
CellwiseFunction(
+                                       _class.getName(), _classBytes, bcVect2, 
bcMatrices, scalars), true);
                                MatrixBlock tmpMB = 
RDDAggregateUtils.aggStable(out, aggop);
                                sec.setVariable(_out.getName(), new 
DoubleObject(tmpMB.getValue(0, 0)));
                        }
                }
-               else if(_class.getSuperclass() == SpoofMultiAggregate.class)
+               else if(_class.getSuperclass() == SpoofMultiAggregate.class) 
//MAGG
                {
                        SpoofMultiAggregate op = (SpoofMultiAggregate) 
CodegenUtils.createInstance(_class);     
                        AggOp[] aggOps = op.getAggOps();
                        
-                       MatrixBlock tmpMB = in
-                                       .mapToPair(new 
MultiAggregateFunction(_class.getName(), _classBytes, bcMatrices, scalars))
-                                       .values().fold(new MatrixBlock(), new 
MultiAggAggregateFunction(aggOps) );
+                       MatrixBlock tmpMB = in.mapToPair(new 
MultiAggregateFunction(
+                               _class.getName(), _classBytes,  bcVect2, 
bcMatrices, scalars))
+                               .values().fold(new MatrixBlock(), new 
MultiAggAggregateFunction(aggOps) );
                        
                        sec.setMatrixOutput(_out.getName(), tmpMB, 
getExtendedOpcode());
-                       return;
                }
-               else if(_class.getSuperclass() == SpoofOuterProduct.class) // 
outer product operator
+               else if(_class.getSuperclass() == SpoofOuterProduct.class) 
//OUTER
                {
                        if( _out.getDataType()==DataType.MATRIX ) {
                                SpoofOperator op = (SpoofOperator) 
CodegenUtils.createInstance(_class);         
@@ -180,7 +185,8 @@ public class SpoofSPInstruction extends SPInstruction
                                updateOutputMatrixCharacteristics(sec, op);     
                
                                MatrixCharacteristics mcOut = 
sec.getMatrixCharacteristics(_out.getName());
                                
-                               out = in.mapPartitionsToPair(new 
OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
+                               out = in.mapPartitionsToPair(new 
OuterProductFunction(
+                                       _class.getName(), _classBytes, bcVect2, 
bcMatrices, scalars), true);
                                if(type == OutProdType.LEFT_OUTER_PRODUCT || 
type == OutProdType.RIGHT_OUTER_PRODUCT ) {
                                        //TODO investigate if some other side 
effect of correct blocks
                                        if( in.partitions().size() > 
mcOut.getNumRowBlocks()*mcOut.getNumColBlocks() )
@@ -190,26 +196,24 @@ public class SpoofSPInstruction extends SPInstruction
                                }
                                sec.setRDDHandleForVariable(_out.getName(), 
out);
                                
-                               //maintain lineage information for output rdd
-                               sec.addLineageRDD(_out.getName(), 
_in[0].getName());
-                               for( String bcVar : bcVars )
-                                       sec.addLineageBroadcast(_out.getName(), 
bcVar);
-                               
+                               //maintain lineage info and output 
characteristics
+                               maintainLineageInfo(sec, _in, bcVect, _out);
                        }
                        else {
-                               out = in.mapPartitionsToPair(new 
OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
+                               out = in.mapPartitionsToPair(new 
OuterProductFunction(
+                                       _class.getName(), _classBytes, bcVect2, 
bcMatrices, scalars), true);
                                MatrixBlock tmp = 
RDDAggregateUtils.sumStable(out);
                                sec.setVariable(_out.getName(), new 
DoubleObject(tmp.getValue(0, 0)));
                        }
                }
-               else if( _class.getSuperclass() == SpoofRowwise.class ) { //row 
aggregate operator
+               else if( _class.getSuperclass() == SpoofRowwise.class ) { //ROW
                        if( mcIn.getCols() > mcIn.getColsPerBlock() ) {
                                throw new DMLRuntimeException("Invalid spark 
rowwise operator w/ ncol=" + 
                                        mcIn.getCols()+", 
ncolpb="+mcIn.getColsPerBlock()+".");
                        }
                        SpoofRowwise op = (SpoofRowwise) 
CodegenUtils.createInstance(_class);   
                        RowwiseFunction fmmc = new 
RowwiseFunction(_class.getName(),
-                               _classBytes, bcMatrices, scalars, 
(int)mcIn.getCols());
+                               _classBytes, bcVect2, bcMatrices, scalars, 
(int)mcIn.getCols());
                        out = in.mapPartitionsToPair(fmmc, 
op.getRowType()==RowType.ROW_AGG
                                        || op.getRowType() == RowType.NO_AGG);
                        
@@ -233,21 +237,89 @@ public class SpoofSPInstruction extends SPInstruction
                                
                                sec.setRDDHandleForVariable(_out.getName(), 
out);
                                
-                               //maintain lineage information for output rdd
-                               sec.addLineageRDD(_out.getName(), 
_in[0].getName());
-                               for( String bcVar : bcVars )
-                                       sec.addLineageBroadcast(_out.getName(), 
bcVar);
-                               
-                               //update matrix characteristics
+                               //maintain lineage info and output 
characteristics
+                               maintainLineageInfo(sec, _in, bcVect, _out);
                                updateOutputMatrixCharacteristics(sec, op);
                        }
-                       return;
                }
                else {
                        throw new DMLRuntimeException("Operator " + 
_class.getSuperclass() + " is not supported on Spark");
                }
        }
        
+       private static boolean[] determineBroadcastInputs(SparkExecutionContext 
sec, CPOperand[] inputs) 
+               throws DMLRuntimeException 
+       {
+               boolean[] ret = new boolean[inputs.length];
+               double localBudget = OptimizerUtils.getLocalMemBudget();
+               double bcBudget = 
SparkExecutionContext.getBroadcastMemoryBudget();
+               
+               //decided for each matrix input if it fits into remaining memory
+               //budget; the major input, i.e., inputs[0] is always an RDD
+               for( int i=1; i<inputs.length; i++ ) 
+                       if( inputs[i].getDataType().isMatrix() ) {
+                               MatrixCharacteristics mc = 
sec.getMatrixCharacteristics(inputs[i].getName());
+                               double sizeL = 
OptimizerUtils.estimateSizeExactSparsity(mc);
+                               double sizeP = 
OptimizerUtils.estimatePartitionedSizeExactSparsity(mc);
+                               ret[i] = localBudget > sizeL && bcBudget > 
sizeP;
+                               localBudget -= ret[i] ? sizeL : 0;
+                               bcBudget -= ret[i] ? sizeP : 0;
+                       }
+               
+               return ret;
+       }
+       
+       private static boolean[] getMatrixBroadcastVector(SparkExecutionContext 
sec, CPOperand[] inputs, boolean[] bcVect) 
+                       throws DMLRuntimeException 
+       {
+               int numMtx = (int) Arrays.stream(inputs)
+                       .filter(in -> in.getDataType().isMatrix()).count();
+               boolean[] ret = new boolean[numMtx];
+               for(int i=0, pos=0; i<inputs.length; i++)
+                       if( inputs[i].getDataType().isMatrix() )
+                               ret[pos++] = bcVect[i];
+               return ret;
+       }
+       
+       private static JavaPairRDD<MatrixIndexes, MatrixBlock[]> 
createJoinedInputRDD(SparkExecutionContext sec, CPOperand[] inputs, boolean[] 
bcVect, boolean outer) 
+               throws DMLRuntimeException
+       {
+               //get input rdd for main input
+               MatrixCharacteristics mcIn = 
sec.getMatrixCharacteristics(inputs[0].getName());
+               JavaPairRDD<MatrixIndexes, MatrixBlock> in = 
sec.getBinaryBlockRDDHandleForVariable(inputs[0].getName());
+               JavaPairRDD<MatrixIndexes, MatrixBlock[]> ret = 
in.mapValues(new MapInputSignature());
+               
+               for( int i=1; i<inputs.length; i++ )
+                       if( inputs[i].getDataType().isMatrix() && !bcVect[i] ) {
+                               //create side input rdd 
+                               String varname = inputs[i].getName();
+                               JavaPairRDD<MatrixIndexes, MatrixBlock> tmp = 
sec
+                                       
.getBinaryBlockRDDHandleForVariable(varname);
+                               MatrixCharacteristics mcTmp = 
sec.getMatrixCharacteristics(varname);
+                               //replicate blocks if mismatch with main input
+                               if( outer && i==2 )
+                                       tmp = tmp.flatMapToPair(new 
ReplicateRightFactorFunction(mcIn.getRows(), mcIn.getRowsPerBlock()));
+                               else if( mcIn.getNumRowBlocks() > 
mcTmp.getNumRowBlocks() )
+                                       tmp = tmp.flatMapToPair(new 
ReplicateBlockFunction(mcIn.getRows(), mcIn.getRowsPerBlock(), false));
+                               else if( mcIn.getNumColBlocks() > 
mcTmp.getNumColBlocks() )
+                                       tmp = tmp.flatMapToPair(new 
ReplicateBlockFunction(mcIn.getCols(), mcIn.getColsPerBlock(), true));
+                               //join main and side inputs and consolidate 
signature
+                               ret = ret.join(tmp)
+                                       .mapValues(new MapJoinSignature());
+                       }
+               
+               return ret;
+       }
+       
+       private static void maintainLineageInfo(SparkExecutionContext sec, 
CPOperand[] inputs, boolean[] bcVect, CPOperand output) 
+               throws DMLRuntimeException 
+       {
+               //add lineage info for all rdd/broadcast inputs 
+               for( int i=0; i<inputs.length; i++ )
+                       if( inputs[i].getDataType().isMatrix() )
+                               sec.addLineage(output.getName(), 
inputs[i].getName(), bcVect[i]);
+       }
+       
        private void updateOutputMatrixCharacteristics(SparkExecutionContext 
sec, SpoofOperator op) 
                throws DMLRuntimeException 
        {
@@ -290,30 +362,89 @@ public class SpoofSPInstruction extends SPInstruction
                                mcOut.set(mcIn.getCols(), 1, 
mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
                }
        }
+       
+       private static class MapInputSignature implements Function<MatrixBlock, 
MatrixBlock[]> 
+       {
+               private static final long serialVersionUID = 
-816443970067626102L;
                
-       private static class RowwiseFunction implements 
PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, 
MatrixIndexes, MatrixBlock> 
+               @Override
+               public MatrixBlock[] call(MatrixBlock v1) throws Exception {
+                       return new MatrixBlock[]{ v1 };
+               }
+       }
+       
+       private static class MapJoinSignature implements 
Function<Tuple2<MatrixBlock[],MatrixBlock>, MatrixBlock[]> 
+       {
+               private static final long serialVersionUID = 
-704403012606821854L;
+
+               @Override
+               public MatrixBlock[] call(Tuple2<MatrixBlock[], MatrixBlock> 
v1) throws Exception {
+                       return ArrayUtils.add(v1._1(), v1._2());
+               }
+       }
+       
+       private static class SpoofFunction implements Serializable 
+       {       
+               private static final long serialVersionUID = 
2953479427746463003L;
+               
+               protected final boolean[] _bcInd;
+               protected final ArrayList<PartitionedBroadcast<MatrixBlock>> 
_inputs;
+               protected final ArrayList<ScalarObject> _scalars;
+               protected final byte[] _classBytes;
+               protected final String _className;
+               
+               protected SpoofFunction(String className, byte[] classBytes, 
boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, 
ArrayList<ScalarObject> scalars) {
+                       _bcInd = bcInd;
+                       _inputs = bcMatrices;
+                       _scalars = scalars;
+                       _classBytes = classBytes;
+                       _className = className;
+               }
+               
+               protected ArrayList<MatrixBlock> 
getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn) 
+                       throws DMLRuntimeException 
+               {
+                       return getAllMatrixInputs(ixIn, blkIn, false);
+               }
+               
+               protected ArrayList<MatrixBlock> 
getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn, boolean outer) 
+                       throws DMLRuntimeException 
+               {
+                       ArrayList<MatrixBlock> ret = new 
ArrayList<MatrixBlock>();
+                       //add all rdd/broadcast inputs (main and side inputs)
+                       for( int i=0, posRdd=0, posBc=0; i<_bcInd.length; i++ ) 
{
+                               if( _bcInd[i] ) {
+                                       PartitionedBroadcast<MatrixBlock> pb = 
_inputs.get(posBc++);
+                                       int rowIndex = (int) ((outer && i==2) ? 
ixIn.getColumnIndex() : 
+                                               
(pb.getNumRowBlocks()>=ixIn.getRowIndex())?ixIn.getRowIndex():1);
+                                       int colIndex = (int) ((outer && i==2) ? 
1 : 
+                                               
(pb.getNumColumnBlocks()>=ixIn.getColumnIndex())?ixIn.getColumnIndex():1);
+                                       ret.add(pb.getBlock(rowIndex, 
colIndex));
+                               }
+                               else
+                                       ret.add(blkIn[posRdd++]);
+                       }
+                       return ret;
+               }
+       }
+       
+       private static class RowwiseFunction extends SpoofFunction
+               implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, 
MatrixBlock[]>>, MatrixIndexes, MatrixBlock> 
        {
                private static final long serialVersionUID = 
-7926980450209760212L;
 
-               private final ArrayList<PartitionedBroadcast<MatrixBlock>> 
_vectors;
-               private final ArrayList<ScalarObject> _scalars;
-               private final byte[] _classBytes;
-               private final String _className;
                private final int _clen;
                private SpoofRowwise _op = null;
                
-               public RowwiseFunction(String className, byte[] classBytes, 
ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, 
ArrayList<ScalarObject> scalars, int clen) 
+               public RowwiseFunction(String className, byte[] classBytes, 
boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, 
ArrayList<ScalarObject> scalars, int clen) 
                        throws DMLRuntimeException
                {                       
-                       _className = className;
-                       _classBytes = classBytes;
-                       _vectors = bcMatrices;
-                       _scalars = scalars;
+                       super(className, classBytes, bcInd, bcMatrices, 
scalars);
                        _clen = clen;
                }
                
                @Override
-               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( 
Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg ) 
+               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( 
Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg ) 
                        throws Exception 
                {
                        //lazy load of shipped class
@@ -323,7 +454,7 @@ public class SpoofSPInstruction extends SPInstruction
                        }
                        
                        //setup local memory for reuse
-                       int clen2 = (int) (_op.getRowType().isRowTypeB1() ? 
_vectors.get(0).getNumCols() : -1);
+                       int clen2 = (int) (_op.getRowType().isRowTypeB1() ? 
_inputs.get(0).getNumCols() : -1);
                        
LibSpoofPrimitives.setupThreadLocalMemory(_op.getNumIntermediates(), _clen, 
clen2);
                        
                        ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new 
ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>();
@@ -333,13 +464,12 @@ public class SpoofSPInstruction extends SPInstruction
                        
                        while( arg.hasNext() ) {
                                //get main input block and indexes
-                               Tuple2<MatrixIndexes,MatrixBlock> e = 
arg.next();
+                               Tuple2<MatrixIndexes,MatrixBlock[]> e = 
arg.next();
                                MatrixIndexes ixIn = e._1();
-                               MatrixBlock blkIn = e._2();
-                               int rowIx = (int)ixIn.getRowIndex();
+                               MatrixBlock[] blkIn = e._2();
                                
                                //prepare output and execute single-threaded 
operator
-                               ArrayList<MatrixBlock> inputs = 
getVectorInputsFromBroadcast(blkIn, rowIx);
+                               ArrayList<MatrixBlock> inputs = 
getAllMatrixInputs(ixIn, blkIn);
                                blkOut = aggIncr ? blkOut : new MatrixBlock();
                                _op.execute(inputs, _scalars, blkOut, false, 
aggIncr);
                                if( !aggIncr ) {
@@ -356,39 +486,23 @@ public class SpoofSPInstruction extends SPInstruction
                        
                        return ret.iterator();
                }
-               
-               private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, int rowIndex) 
-                       throws DMLRuntimeException 
-               {
-                       ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>();
-                       ret.add(blkIn);
-                       for( PartitionedBroadcast<MatrixBlock> vector : _vectors )
-                               ret.add(vector.getBlock((vector.getNumRowBlocks()>=rowIndex)?rowIndex:1, 1));
-                       return ret;
-               }
        }
        
-       private static class CellwiseFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> 
+       private static class CellwiseFunction extends SpoofFunction
+               implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> 
        {
                private static final long serialVersionUID = -8209188316939435099L;
                
-               private ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors = null;
-               private ArrayList<ScalarObject> _scalars = null;
-               private byte[] _classBytes = null;
-               private String _className = null;
                private SpoofOperator _op = null;
                
-               public CellwiseFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) 
+               public CellwiseFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) 
                        throws DMLRuntimeException
                {
-                       _className = className;
-                       _classBytes = classBytes;
-                       _vectors = bcMatrices;
-                       _scalars = scalars;
+                       super(className, classBytes, bcInd, bcMatrices, scalars);
                }
                
                @Override
-               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg)
+               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg)
                        throws Exception 
                {
                        //lazy load of shipped class
@@ -400,13 +514,13 @@ public class SpoofSPInstruction extends SPInstruction
                        List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>();
                        while(arg.hasNext()) 
                        {
-                               Tuple2<MatrixIndexes,MatrixBlock> tmp = arg.next();
+                               Tuple2<MatrixIndexes,MatrixBlock[]> tmp = arg.next();
                                MatrixIndexes ixIn = tmp._1();
-                               MatrixBlock blkIn = tmp._2();
+                               MatrixBlock[] blkIn = tmp._2();
                                MatrixIndexes ixOut = ixIn; 
                                MatrixBlock blkOut = new MatrixBlock();
-                               ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(blkIn, ixIn);
-                                       
+                               ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn);
+                               
                                //execute core operation
                                if(((SpoofCellwise)_op).getCellType()==CellType.FULL_AGG) {
                                        ScalarObject obj = _op.execute(inputs, _scalars, 1);
@@ -424,42 +538,23 @@ public class SpoofSPInstruction extends SPInstruction
                        }
                        return ret.iterator();
                }
-               
-               private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, MatrixIndexes ixIn) 
-                       throws DMLRuntimeException 
-               {
-                       ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>();
-                       ret.add(blkIn);
-                       for( PartitionedBroadcast<MatrixBlock> in : _vectors ) {
-                               int rowIndex = (int)((in.getNumRowBlocks()>=ixIn.getRowIndex())?ixIn.getRowIndex():1);
-                               int colIndex = (int)((in.getNumColumnBlocks()>=ixIn.getColumnIndex())?ixIn.getColumnIndex():1);
-                               ret.add(in.getBlock(rowIndex, colIndex));
-                       }
-                       return ret;
-               }
        }
        
-       private static class MultiAggregateFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> 
+       private static class MultiAggregateFunction extends SpoofFunction
+               implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixIndexes, MatrixBlock> 
        {
                private static final long serialVersionUID = -5224519291577332734L;
                
-               private ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors = null;
-               private ArrayList<ScalarObject> _scalars = null;
-               private byte[] _classBytes = null;
-               private String _className = null;
                private SpoofOperator _op = null;
                
-               public MultiAggregateFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) 
+               public MultiAggregateFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) 
                        throws DMLRuntimeException
                {
-                       _className = className;
-                       _classBytes = classBytes;
-                       _vectors = bcMatrices;
-                       _scalars = scalars;
+                       super(className, classBytes, bcInd, bcMatrices, scalars);
                }
                
                @Override
-               public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg)
+               public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg)
                        throws Exception 
                {
                        //lazy load of shipped class
@@ -469,25 +564,12 @@ public class SpoofSPInstruction extends SPInstruction
                        }
                                
                        //execute core operation
-                       ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(arg._2(), arg._1());
+                       ArrayList<MatrixBlock> inputs = getAllMatrixInputs(arg._1(), arg._2());
                        MatrixBlock blkOut = new MatrixBlock();
                        _op.execute(inputs, _scalars, blkOut);
                        
                        return new Tuple2<MatrixIndexes,MatrixBlock>(arg._1(), blkOut);
                }
-               
-               private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, MatrixIndexes ixIn) 
-                       throws DMLRuntimeException 
-               {
-                       ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>();
-                       ret.add(blkIn);
-                       for( PartitionedBroadcast<MatrixBlock> in : _vectors ) {
-                               int rowIndex = (int)((in.getNumRowBlocks()>=ixIn.getRowIndex())?ixIn.getRowIndex():1);
-                               int colIndex = (int)((in.getNumColumnBlocks()>=ixIn.getColumnIndex())?ixIn.getColumnIndex():1);
-                               ret.add(in.getBlock(rowIndex, colIndex));
-                       }
-                       return ret;
-               }
        }
        
        private static class MultiAggAggregateFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> 
@@ -520,27 +602,21 @@ public class SpoofSPInstruction extends SPInstruction
                }
        }
        
-       private static class OuterProductFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> 
+       private static class OuterProductFunction extends SpoofFunction
+               implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> 
        {
                private static final long serialVersionUID = -8209188316939435099L;
                
-               private ArrayList<PartitionedBroadcast<MatrixBlock>> _bcMatrices = null;
-               private ArrayList<ScalarObject> _scalars = null;
-               private byte[] _classBytes = null;
-               private String _className = null;
                private SpoofOperator _op = null;
                
-               public OuterProductFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) 
+               public OuterProductFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) 
                                throws DMLRuntimeException
                {
-                       _className = className;
-                       _classBytes = classBytes;
-                       _bcMatrices = bcMatrices;
-                       _scalars = scalars;
+                       super(className, classBytes, bcInd, bcMatrices, scalars);
                }
                
                @Override
-               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg)
+               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg)
                        throws Exception 
                {
                        //lazy load of shipped class
@@ -552,16 +628,12 @@ public class SpoofSPInstruction extends SPInstruction
                        List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>();
                        while(arg.hasNext())
                        {
-                               Tuple2<MatrixIndexes,MatrixBlock> tmp = arg.next();
+                               Tuple2<MatrixIndexes,MatrixBlock[]> tmp = arg.next();
                                MatrixIndexes ixIn = tmp._1();
-                               MatrixBlock blkIn = tmp._2();
+                               MatrixBlock[] blkIn = tmp._2();
                                MatrixBlock blkOut = new MatrixBlock();
 
-                               ArrayList<MatrixBlock> inputs = new ArrayList<MatrixBlock>();
-                               inputs.add(blkIn);
-                               inputs.add(_bcMatrices.get(0).getBlock((int)ixIn.getRowIndex(), 1)); // U
-                               inputs.add(_bcMatrices.get(1).getBlock((int)ixIn.getColumnIndex(), 1)); // V
-                                               
+                               ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn, true);
                                //execute core operation
                                if(((SpoofOuterProduct)_op).getOuterProdType()==OutProdType.AGG_OUTER_PRODUCT) {
                                        ScalarObject obj = _op.execute(inputs, _scalars,1);
@@ -588,6 +660,41 @@ public class SpoofSPInstruction extends SPInstruction
                }               
        }
        
+       public static class ReplicateRightFactorFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> 
+       {
+               private static final long serialVersionUID = -7295989688796126442L;
+               
+               private final long _len;
+               private final long _blen;
+               
+               public ReplicateRightFactorFunction(long len, long blen) {
+                       _len = len;
+                       _blen = blen;
+               }
+               
+               @Override
+               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) 
+                       throws Exception 
+               {
+                       LinkedList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new LinkedList<Tuple2<MatrixIndexes, MatrixBlock>>();
+                       MatrixIndexes ixIn = arg0._1();
+                       MatrixBlock blkIn = arg0._2();
+                       
+                       long numBlocks = (long) Math.ceil((double)_len/_blen); 
+                       
+                       //replicate wrt # row blocks in LHS
+                       long j = ixIn.getRowIndex();
+                       for( long i=1; i<=numBlocks; i++ ) {
+                               MatrixIndexes tmpix = new MatrixIndexes(i, j);
+                               MatrixBlock tmpblk = blkIn;
+                               ret.add( new Tuple2<MatrixIndexes, MatrixBlock>(tmpix, tmpblk) );
+                       }
+                       
+                       //output list of new tuples
+                       return ret.iterator();
+               }
+       }
+       
        public static AggregateOperator getAggregateOperator(AggOp aggop) {
                if( aggop == AggOp.SUM || aggop == AggOp.SUM_SQ )
                        return new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.NONE);

http://git-wip-us.apache.org/repos/asf/systemml/blob/95de2358/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ReplicateBlockFunction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ReplicateBlockFunction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ReplicateBlockFunction.java
new file mode 100644
index 0000000..22ea33f
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ReplicateBlockFunction.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.spark.functions;
+
+import java.util.Iterator;
+import java.util.LinkedList;
+
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+
+import scala.Tuple2;
+
+public class ReplicateBlockFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> 
+{
+       private static final long serialVersionUID = -1184696764516975609L;
+       
+       private final long _len;
+       private final long _blen;
+       private final boolean _left;
+       private final boolean _deep;
+       
+       public ReplicateBlockFunction(long len, long blen, boolean left) {
+               //by default: shallow copy of blocks
+               this(len, blen, left, false);
+       }
+       
+       public ReplicateBlockFunction(long len, long blen, boolean left, boolean deep) {
+               _len = len;
+               _blen = blen;
+               _left = left;
+               _deep = deep;
+       }
+       
+       @Override
+       public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) 
+               throws Exception 
+       {
+               LinkedList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new LinkedList<Tuple2<MatrixIndexes, MatrixBlock>>();
+               MatrixIndexes ixIn = arg0._1();
+               MatrixBlock blkIn = arg0._2();
+               
+               long numBlocks = (long) Math.ceil((double)_len/_blen); 
+               
+               if( _left ) //LHS MATRIX
+               {
+                       //replicate wrt # column blocks in RHS
+                       long i = ixIn.getRowIndex();
+                       for( long j=1; j<=numBlocks; j++ ) {
+                               MatrixIndexes tmpix = new MatrixIndexes(i, j);
+                               MatrixBlock tmpblk = _deep ? new MatrixBlock(blkIn) : blkIn;
+                               ret.add( new Tuple2<MatrixIndexes, MatrixBlock>(tmpix, tmpblk) );
+                       }
+               } 
+               else // RHS MATRIX
+               {
+                       //replicate wrt # row blocks in LHS
+                       long j = ixIn.getColumnIndex();
+                       for( long i=1; i<=numBlocks; i++ ) {
+                               MatrixIndexes tmpix = new MatrixIndexes(i, j);
+                               MatrixBlock tmpblk = _deep ? new MatrixBlock(blkIn) : blkIn;
+                               ret.add( new Tuple2<MatrixIndexes, MatrixBlock>(tmpix, tmpblk) );
+                       }
+               }
+               
+               //output list of new tuples
+               return ret.iterator();
+       }
+}

Reply via email to