[SYSTEMML-1287] Code generator runtime integration 

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/982ecb1a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/982ecb1a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/982ecb1a

Branch: refs/heads/master
Commit: 982ecb1a4be69685a8e124eccfa3a12331f998b0
Parents: d7fd587
Author: Matthias Boehm <[email protected]>
Authored: Sun Feb 26 19:01:36 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Sun Feb 26 19:01:36 2017 -0800

----------------------------------------------------------------------
 .../instructions/CPInstructionParser.java       |  19 +-
 .../instructions/SPInstructionParser.java       |  14 +-
 .../runtime/instructions/cp/CPInstruction.java  |   8 +-
 .../instructions/cp/SpoofCPInstruction.java     |  98 +++++
 .../instructions/spark/SPInstruction.java       |   2 +-
 .../instructions/spark/SpoofSPInstruction.java  | 407 +++++++++++++++++++
 .../spark/utils/RDDAggregateUtils.java          |   8 +-
 7 files changed, 541 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index f3c1605..f0603b4 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -61,6 +61,7 @@ import 
org.apache.sysml.runtime.instructions.cp.QuantileSortCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.QuaternaryCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.RelationalBinaryCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.SpoofCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.StringInitCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.TernaryCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.UaggOuterChainCPInstruction;
@@ -271,8 +272,9 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "lu",    
CPINSTRUCTION_TYPE.MultiReturnBuiltin);
                String2CPInstructionType.put( "eigen", 
CPINSTRUCTION_TYPE.MultiReturnBuiltin);
                
-               String2CPInstructionType.put( "partition", 
CPINSTRUCTION_TYPE.Partition);
-               String2CPInstructionType.put( "compress", 
CPINSTRUCTION_TYPE.Compression);
+               String2CPInstructionType.put( "partition",      
CPINSTRUCTION_TYPE.Partition);
+               String2CPInstructionType.put( "compress",       
CPINSTRUCTION_TYPE.Compression);
+               String2CPInstructionType.put( "spoof",          
CPINSTRUCTION_TYPE.SpoofFused);
                
                //CP FILE instruction
                String2CPFileInstructionType = new HashMap<String, 
CPINSTRUCTION_TYPE>();
@@ -424,16 +426,19 @@ public class CPInstructionParser extends InstructionParser
                        
                        case Partition:
                                return 
DataPartitionCPInstruction.parseInstruction(str);        
-       
-                       case Compression:
-                               return (CPInstruction) 
CompressionCPInstruction.parseInstruction(str);  
-                               
+               
                        case CentralMoment:
                                return 
CentralMomentCPInstruction.parseInstruction(str);
        
                        case Covariance:
                                return 
CovarianceCPInstruction.parseInstruction(str);
-                               
+       
+                       case Compression:
+                               return (CPInstruction) 
CompressionCPInstruction.parseInstruction(str);  
+                       
+                       case SpoofFused:
+                               return SpoofCPInstruction.parseInstruction(str);
+                       
                        case INVALID:
                        
                        default: 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index 6658a88..5ca3847 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -73,6 +73,7 @@ import 
org.apache.sysml.runtime.instructions.spark.ReorgSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.RmmSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.SPInstruction;
 import 
org.apache.sysml.runtime.instructions.spark.SPInstruction.SPINSTRUCTION_TYPE;
+import org.apache.sysml.runtime.instructions.spark.SpoofSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.TernarySPInstruction;
 import org.apache.sysml.runtime.instructions.spark.Tsmm2SPInstruction;
 import org.apache.sysml.runtime.instructions.spark.TsmmSPInstruction;
@@ -277,10 +278,12 @@ public class SPInstructionParser extends InstructionParser
                
                String2SPInstructionType.put( "binuaggchain", 
SPINSTRUCTION_TYPE.BinUaggChain);
                
-               String2SPInstructionType.put( "write"   , 
SPINSTRUCTION_TYPE.Write);
+               String2SPInstructionType.put( "write"   , 
SPINSTRUCTION_TYPE.Write);
        
-               String2SPInstructionType.put( "castdtm"   , 
SPINSTRUCTION_TYPE.Cast);
-               String2SPInstructionType.put( "castdtf"   , 
SPINSTRUCTION_TYPE.Cast);
+               String2SPInstructionType.put( "castdtm" , 
SPINSTRUCTION_TYPE.Cast);
+               String2SPInstructionType.put( "castdtf" , 
SPINSTRUCTION_TYPE.Cast);
+               
+               String2SPInstructionType.put( "spoof"   , 
SPINSTRUCTION_TYPE.SpoofFused);
        }
 
        public static SPInstruction parseSingleInstruction (String str ) 
@@ -443,10 +446,13 @@ public class SPInstructionParser extends InstructionParser
                                
                        case Checkpoint:
                                return 
CheckpointSPInstruction.parseInstruction(str);
-                       
+
                        case Compression:
                                return 
CompressionSPInstruction.parseInstruction(str);
                        
+                       case SpoofFused:
+                               return SpoofSPInstruction.parseInstruction(str);
+                               
                        case Cast:
                                return CastSPInstruction.parseInstruction(str);
                                

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
index 1d192d5..dcd8d89 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
@@ -29,7 +29,13 @@ import org.apache.sysml.runtime.matrix.operators.Operator;
 
 public abstract class CPInstruction extends Instruction 
 {
-       public enum CPINSTRUCTION_TYPE { INVALID, AggregateUnary, 
AggregateBinary, AggregateTernary, ArithmeticBinary, Ternary, Quaternary, 
BooleanBinary, BooleanUnary, BuiltinBinary, BuiltinUnary, BuiltinMultiple, 
MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, 
Builtin, Reorg, RelationalBinary, File, Variable, External, Append, Rand, 
QSort, QPick, MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, 
Compression, StringInit, CentralMoment, Covariance, UaggOuterChain, Convolution 
};
+       public enum CPINSTRUCTION_TYPE { INVALID, 
+               AggregateUnary, AggregateBinary, AggregateTernary, 
ArithmeticBinary, 
+               Ternary, Quaternary, BooleanBinary, BooleanUnary, 
BuiltinBinary, BuiltinUnary, 
+               BuiltinMultiple, MultiReturnParameterizedBuiltin, 
ParameterizedBuiltin, MultiReturnBuiltin, 
+               Builtin, Reorg, RelationalBinary, File, Variable, External, 
Append, Rand, QSort, QPick, 
+               MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, 
Compression, SpoofFused,
+               StringInit, CentralMoment, Covariance, UaggOuterChain, 
Convolution }; // SpoofFused is the type that CPInstructionParser registers for the spoof opcode and dispatches to SpoofCPInstruction
        
        protected CPINSTRUCTION_TYPE _cptype;
        protected Operator _optr;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java
new file mode 100644
index 0000000..61313d7
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.cp;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.codegen.CodegenUtils;
+import org.apache.sysml.runtime.codegen.SpoofOperator;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.ComputationCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ScalarObject;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+public class SpoofCPInstruction extends ComputationCPInstruction // CP instruction that executes a generated SpoofOperator class on its inputs
+{
+       private Class<?> _class = null; // generated operator class; a new instance is created on every processInstruction call
+       private int _numThreads = 1; // thread count handed to SpoofOperator.execute
+       private CPOperand[] _in = null; // input operands in instruction order (matrices and scalars)
+       // cla: operator class, k: number of threads, in: input operands, out: output operand, opcode and str: forwarded to the superclass
+       public SpoofCPInstruction(Class<?> cla, int k, CPOperand[] in, 
CPOperand out, String opcode, String str) {
+               super(null, null, null, out, opcode, str); // the first three superclass arguments are left null; all inputs are kept in _in
+               _class = cla;
+               _numThreads = k;
+               _in = in;
+       }
+
+       public static SpoofCPInstruction parseInstruction(String str) 
+               throws DMLRuntimeException 
+       {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               // layout read below: parts[0] opcode prefix, parts[1] class name, parts[2..n-3] inputs, parts[n-2] output, parts[n-1] thread count
+               //the final opcode is parts[0] extended by the spoof type of the loaded class
+               ArrayList<CPOperand> inlist = new ArrayList<CPOperand>();
+               Class<?> cla = CodegenUtils.loadClass(parts[1], null); // parts[1] names the generated class
+               String opcode =  parts[0] + CodegenUtils.getSpoofType(cla);
+               // every part between the class name and the last two parts is parsed as an input operand
+               for( int i=2; i<parts.length-2; i++ )
+                       inlist.add(new CPOperand(parts[i]));
+               CPOperand out = new CPOperand(parts[parts.length-2]);
+               int k = Integer.parseInt(parts[parts.length-1]); // last part: number of threads
+               
+               return new SpoofCPInstruction(cla, k, inlist.toArray(new 
CPOperand[0]), out, opcode, str);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec)
+               throws DMLRuntimeException 
+       { // instantiate the operator, collect inputs, execute, bind the output, release the matrix inputs
+               SpoofOperator op = (SpoofOperator) 
CodegenUtils.createInstance(_class);
+               
+               //get input matrices and scalars, incl pinning of matrices
+               ArrayList<MatrixBlock> inputs = new ArrayList<MatrixBlock>();
+               ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>();
+               for (CPOperand input : _in) {
+                       if(input.getDataType()==DataType.MATRIX)
+                               inputs.add(ec.getMatrixInput(input.getName()));
+                       else if(input.getDataType()==DataType.SCALAR)
+                               scalars.add(ec.getScalarInput(input.getName(), 
input.getValueType(), input.isLiteral()));
+               } // operands of any other data type are skipped
+               
+               // execute the operator: a matrix result is written into a fresh MatrixBlock, a scalar result is returned by execute
+               if( output.getDataType() == DataType.MATRIX) {
+                       MatrixBlock out = new MatrixBlock();
+                       op.execute(inputs, scalars, out, _numThreads);
+                       ec.setMatrixOutput(output.getName(), out);
+               }
+               else if (output.getDataType() == DataType.SCALAR) {
+                       ScalarObject out = op.execute(inputs, scalars, 
_numThreads);
+                       ec.setScalarOutput(output.getName(), out);
+               } // for any other output data type nothing is executed and no output is set
+               
+               // release input matrices (NOTE(review): not reached if execute throws, there is no try/finally here)
+               for (CPOperand input : _in)
+                       if(input.getDataType()==DataType.MATRIX)
+                               ec.releaseMatrixInput(input.getName());
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
index b28e408..17d1561 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
@@ -37,7 +37,7 @@ public abstract class SPInstruction extends Instruction
                CentralMoment, Covariance, QSort, QPick, 
                ParameterizedBuiltin, MAppend, RAppend, GAppend, 
GAlignedAppend, Rand, 
                MatrixReshape, Ternary, Quaternary, CumsumAggregate, 
CumsumOffset, BinUaggChain, UaggOuterChain, 
-               Write, INVALID, 
+               Write, SpoofFused, INVALID, 
                Convolution
        };
        

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
new file mode 100644
index 0000000..15b0751
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.spark;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.codegen.CodegenUtils;
+import org.apache.sysml.runtime.codegen.SpoofCellwise;
+import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
+import org.apache.sysml.runtime.codegen.SpoofOperator;
+import org.apache.sysml.runtime.codegen.SpoofOuterProduct;
+import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType;
+import org.apache.sysml.runtime.codegen.SpoofRowAggregate;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.DoubleObject;
+import org.apache.sysml.runtime.instructions.cp.ScalarObject;
+import org.apache.sysml.runtime.instructions.spark.SPInstruction;
+import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+
+import scala.Tuple2;
+
+public class SpoofSPInstruction extends SPInstruction
+{
+       private final Class<?> _class; // generated operator class; its superclass selects the execution path in processInstruction
+       private final byte[] _classBytes; // result of CodegenUtils.getClassAsByteArray for the class name; passed to the Spark functions together with the class name
+       private final CPOperand[] _in; // input operands; _in[0] is read as the RDD input, the others as broadcasts or scalars
+       private final CPOperand _out; // output operand (matrix or scalar)
+       
+       public SpoofSPInstruction(Class<?> cls , byte[] classBytes, CPOperand[] 
in, CPOperand out, String opcode, String str) {
+               super(opcode, str); // opcode and instruction string are forwarded to SPInstruction
+               _class = cls;
+               _classBytes = classBytes;
+               _sptype = SPINSTRUCTION_TYPE.SpoofFused; // tag this instruction with the SpoofFused instruction type
+               _in = in;
+               _out = out;
+       }
+       
+       public static SpoofSPInstruction parseInstruction(String str) 
+               throws DMLRuntimeException
+       {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               // layout read below: parts[0] opcode prefix, parts[1] class name, parts[2..n-3] inputs, parts[n-2] output, parts[n-1] thread count
+               //the final opcode is parts[0] extended by the spoof type of the loaded class
+               ArrayList<CPOperand> inlist = new ArrayList<CPOperand>();
+               Class<?> cls = CodegenUtils.loadClass(parts[1], null); // parts[1] names the generated class
+               byte[] classBytes = CodegenUtils.getClassAsByteArray(parts[1]); // bytes are kept next to the class and handed to the Spark functions later
+               String opcode =  parts[0] + CodegenUtils.getSpoofType(cls);
+               // every part between the class name and the last two parts is parsed as an input operand
+               for( int i=2; i<parts.length-2; i++ )
+                       inlist.add(new CPOperand(parts[i]));
+               CPOperand out = new CPOperand(parts[parts.length-2]);
+               //note: number of threads parts[parts.length-1] always ignored
+               
+               return new SpoofSPInstruction(cls, classBytes, 
inlist.toArray(new CPOperand[0]), out, opcode, str);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec)
+               throws DMLRuntimeException 
+       { // runs the generated operator over the RDD of _in[0]; the path is chosen by the superclass of _class
+               SparkExecutionContext sec = (SparkExecutionContext)ec; // the RDD and broadcast accessors used below are called on the Spark context
+
+               //the first operand (_in[0]) is the main input and is read as a binary-block RDD
+               ArrayList<String> bcVars = new ArrayList<String>();
+               MatrixCharacteristics mcIn = 
sec.getMatrixCharacteristics(_in[0].getName());
+               JavaPairRDD<MatrixIndexes,MatrixBlock> in = 
sec.getBinaryBlockRDDHandleForVariable( _in[0].getName() );
+               JavaPairRDD<MatrixIndexes, MatrixBlock> out = null; // assigned in the cellwise and outer-product branches
+                               
+               //simple case: map-side only operation (one rdd input, 
broadcast all)
+               //remaining operands: matrices are fetched as broadcasts (names recorded in bcVars for lineage), scalars as ScalarObjects
+               ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new 
ArrayList<PartitionedBroadcast<MatrixBlock>>();
+               ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>();
+               for( int i=1; i<_in.length; i++ ) {
+                       if( _in[i].getDataType()==DataType.MATRIX) {
+                               
bcMatrices.add(sec.getBroadcastForVariable(_in[i].getName()));
+                               bcVars.add(_in[i].getName());
+                       }
+                       else if(_in[i].getDataType()==DataType.SCALAR) {
+                               
scalars.add(sec.getScalarInput(_in[i].getName(), _in[i].getValueType(), 
_in[i].isLiteral()));
+                       }
+               } // operands of any other data type are skipped
+               
+               //dispatch on the direct superclass of the generated class
+               if(_class.getSuperclass() == SpoofCellwise.class) // cellwise 
operator
+               {
+                       if( _out.getDataType()==DataType.MATRIX ) { // matrix output: the result is registered as an RDD
+                               SpoofOperator op = (SpoofOperator) 
CodegenUtils.createInstance(_class);         
+                               // this local instance is used below for getCellType() and for updating the output characteristics
+                               out = in.mapPartitionsToPair(new 
CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true); // second argument (true) is the preservesPartitioning flag of mapPartitionsToPair
+                               if( 
((SpoofCellwise)op).getCellType()==CellType.ROW_AGG && mcIn.getCols() > 
mcIn.getColsPerBlock() ) { // row aggregation with more columns than one block holds: outputs are combined via sumByKeyStable
+                                       //NOTE: workaround with partition size 
needed due to potential bug in SPARK
+                                       //TODO investigate if some other side 
effect of correct blocks
+                                       if( out.partitions().size() > 
mcIn.getNumRowBlocks() )
+                                               out = 
RDDAggregateUtils.sumByKeyStable(out, (int)mcIn.getNumRowBlocks());
+                                       else
+                                               out = 
RDDAggregateUtils.sumByKeyStable(out);
+                               }
+                               sec.setRDDHandleForVariable(_out.getName(), 
out);
+                               
+                               //maintain lineage information for output rdd
+                               sec.addLineageRDD(_out.getName(), 
_in[0].getName());
+                               for( String bcVar : bcVars )
+                                       sec.addLineageBroadcast(_out.getName(), 
bcVar);
+                               
+                               //update matrix characteristics
+                               updateOutputMatrixCharacteristics(sec, op);     
+                       }
+                       else { //SCALAR: sum all output blocks and take cell (0,0) of the sum as the scalar result
+                               out = in.mapPartitionsToPair(new 
CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
+                               MatrixBlock tmpMB = 
RDDAggregateUtils.sumStable(out);
+                               sec.setVariable(_out.getName(), new 
DoubleObject(tmpMB.getValue(0, 0)));
+                       }
+               }
+               else if(_class.getSuperclass() == SpoofOuterProduct.class) // 
outer product operator
+               {
+                       if( _out.getDataType()==DataType.MATRIX ) { // matrix output: the result is registered as an RDD
+                               SpoofOperator op = (SpoofOperator) 
CodegenUtils.createInstance(_class);         
+                               OutProdType type = 
((SpoofOuterProduct)op).getOuterProdType();
+
+                               //update matrix characteristics first, because mcOut is read for the aggregation below
+                               updateOutputMatrixCharacteristics(sec, op);     
                
+                               MatrixCharacteristics mcOut = 
sec.getMatrixCharacteristics(_out.getName());
+                               
+                               out = in.mapPartitionsToPair(new 
OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
+                               if(type == OutProdType.LEFT_OUTER_PRODUCT || 
type == OutProdType.RIGHT_OUTER_PRODUCT ) {
+                                       //NOTE: workaround with partition size 
needed due to potential bug in SPARK
+                                       //TODO investigate if some other side 
effect of correct blocks
+                                       if( in.partitions().size() > 
mcOut.getNumRowBlocks()*mcOut.getNumColBlocks() )
+                                               out = 
RDDAggregateUtils.sumByKeyStable( out, 
(int)(mcOut.getNumRowBlocks()*mcOut.getNumColBlocks()) );
+                                       else
+                                               out = 
RDDAggregateUtils.sumByKeyStable( out );  
+                               } // NOTE(review): this branch checks in.partitions() while the cellwise branch checks out.partitions() - confirm both are meant to be equivalent
+                               sec.setRDDHandleForVariable(_out.getName(), 
out);
+                               
+                               //maintain lineage information for output rdd
+                               sec.addLineageRDD(_out.getName(), 
_in[0].getName());
+                               for( String bcVar : bcVars )
+                                       sec.addLineageBroadcast(_out.getName(), 
bcVar);
+                               
+                       }
+                       else { // scalar output: sum all output blocks and take cell (0,0) of the sum
+                               out = in.mapPartitionsToPair(new 
OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
+                               MatrixBlock tmp = 
RDDAggregateUtils.sumStable(out);
+                               sec.setVariable(_out.getName(), new 
DoubleObject(tmp.getValue(0, 0)));
+                       }
+               }
+               else if( _class.getSuperclass() == SpoofRowAggregate.class ) { 
//row aggregate operator
+                       RowAggregateFunction fmmc = new 
RowAggregateFunction(_class.getName(), _classBytes, bcMatrices, scalars);
+                       JavaPairRDD<MatrixIndexes,MatrixBlock> tmpRDD = 
in.mapToPair(fmmc);
+                       MatrixBlock tmpMB = 
RDDAggregateUtils.sumStable(tmpRDD);                
+                       sec.setMatrixOutput(_out.getName(), tmpMB); // the summed block is set as the matrix output; no RDD handle or lineage is registered in this branch
+                       return;
+               }
+               else { // any other superclass is rejected
+                       throw new DMLRuntimeException("Operator " + 
_class.getSuperclass() + " is not supported on Spark");
+               }
+       }
+       
+       /**
+        * Updates the matrix characteristics registered for the output variable
+        * from those of the inputs, depending on the kind of fused operator.
+        * <p>
+        * Cell-wise operators: ROW_AGG yields a single-column output with the rows
+        * and block sizes of the first input, NO_AGG copies the characteristics of
+        * the first input, and any other cell type leaves the output untouched.
+        * <p>
+        * Outer-product operators: CELLWISE_OUTER_PRODUCT takes dimensions and block
+        * sizes from the first input (X), LEFT_OUTER_PRODUCT from the third input (V),
+        * and RIGHT_OUTER_PRODUCT from the second input (U); any other type leaves
+        * the output untouched. Operators of any other class are ignored.
+        *
+        * @param sec spark execution context used to look up the characteristics
+        * @param op instantiated fused operator
+        * @throws DMLRuntimeException declared by the signature; presumably raised by
+        *         the execution context lookups
+        */
+       private void updateOutputMatrixCharacteristics(SparkExecutionContext sec, SpoofOperator op)
+               throws DMLRuntimeException
+       {
+               if( op instanceof SpoofCellwise ) {
+                       MatrixCharacteristics dimsIn = sec.getMatrixCharacteristics(_in[0].getName());
+                       MatrixCharacteristics dimsOut = sec.getMatrixCharacteristics(_out.getName());
+                       SpoofCellwise cellOp = (SpoofCellwise) op;
+                       if( cellOp.getCellType() == CellType.ROW_AGG ) {
+                               //row aggregation collapses all columns into a single one
+                               dimsOut.set(dimsIn.getRows(), 1, dimsIn.getRowsPerBlock(), dimsIn.getColsPerBlock());
+                       }
+                       else if( cellOp.getCellType() == CellType.NO_AGG ) {
+                               dimsOut.set(dimsIn);
+                       }
+                       return;
+               }
+
+               if( !(op instanceof SpoofOuterProduct) )
+                       return;
+
+               //all lookups happen up front, independent of the outer product type
+               MatrixCharacteristics dimsX = sec.getMatrixCharacteristics(_in[0].getName()); //X
+               MatrixCharacteristics dimsU = sec.getMatrixCharacteristics(_in[1].getName()); //U
+               MatrixCharacteristics dimsV = sec.getMatrixCharacteristics(_in[2].getName()); //V
+               MatrixCharacteristics dimsOut = sec.getMatrixCharacteristics(_out.getName());
+               OutProdType type = ((SpoofOuterProduct) op).getOuterProdType();
+
+               boolean cellwise = (type == OutProdType.CELLWISE_OUTER_PRODUCT);
+               boolean left = (type == OutProdType.LEFT_OUTER_PRODUCT);
+               boolean right = (type == OutProdType.RIGHT_OUTER_PRODUCT);
+               if( cellwise || left || right ) {
+                       //pick the input whose shape the output inherits
+                       MatrixCharacteristics src = cellwise ? dimsX : (left ? dimsV : dimsU);
+                       dimsOut.set(src.getRows(), src.getCols(), src.getRowsPerBlock(), src.getColsPerBlock());
+               }
+       }
+               
+       /**
+        * Spark pair function for row aggregate operators: applies the shipped
+        * fused operator to one input block (plus the matching blocks of all
+        * broadcast inputs) and emits the result under the constant block
+        * index (1,1).
+        */
+       private static class RowAggregateFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
+       {
+               private static final long serialVersionUID = -7926980450209760212L;
+
+               private ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors;
+               private ArrayList<ScalarObject> _scalars;
+               private byte[] _classBytes;
+               private String _className;
+               //instantiated lazily on first call from the shipped class bytes
+               private SpoofOperator _op;
+
+               /**
+                * @param className name of the generated operator class
+                * @param classBytes byte code of the generated operator class
+                * @param bcMatrices broadcast side inputs
+                * @param scalars scalar inputs handed to the operator
+                * @throws DMLRuntimeException declared, but not raised by this constructor body
+                */
+               public RowAggregateFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars)
+                       throws DMLRuntimeException
+               {
+                       _className = className;
+                       _classBytes = classBytes;
+                       _vectors = bcMatrices;
+                       _scalars = scalars;
+               }
+
+               /**
+                * Executes the operator on a single (index, block) pair.
+                *
+                * @param arg0 input block index and block
+                * @return the operator output keyed by block index (1,1)
+                */
+               @Override
+               public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 )
+                       throws Exception
+               {
+                       //materialize the shipped operator on first use
+                       if( _op == null ) {
+                               Class<?> opClass = CodegenUtils.loadClass(_className, _classBytes);
+                               _op = (SpoofOperator) CodegenUtils.createInstance(opClass);
+                       }
+
+                       //main input block followed by the broadcast blocks for its block row
+                       MatrixIndexes inIx = arg0._1();
+                       ArrayList<MatrixBlock> operands =
+                               getVectorInputsFromBroadcast(arg0._2(), (int) inIx.getRowIndex());
+
+                       //run the operator into a fresh output block
+                       MatrixBlock result = new MatrixBlock();
+                       _op.execute(operands, _scalars, result);
+
+                       return new Tuple2<MatrixIndexes, MatrixBlock>(new MatrixIndexes(1,1), result);
+               }
+
+               /**
+                * Assembles the operator inputs: the main block first, then one block
+                * per broadcast input. The broadcast block at row index rowIndex is used
+                * if the broadcast has at least that many row blocks, otherwise its
+                * first row block; the column block index is always 1.
+                */
+               private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, int rowIndex)
+                       throws DMLRuntimeException
+               {
+                       ArrayList<MatrixBlock> operands = new ArrayList<MatrixBlock>();
+                       operands.add(blkIn);
+                       for( PartitionedBroadcast<MatrixBlock> bc : _vectors ) {
+                               int blockRow = 1;
+                               if( bc.getNumRowBlocks() >= rowIndex )
+                                       blockRow = rowIndex;
+                               operands.add(bc.getBlock(blockRow, 1));
+                       }
+                       return operands;
+               }
+       }
+       
+       /**
+        * Spark partition function for cell-wise fused operators: applies the
+        * shipped operator to every block of a partition, together with the
+        * matching blocks of all broadcast inputs.
+        */
+       private static class CellwiseFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock>
+       {
+               private static final long serialVersionUID = -8209188316939435099L;
+
+               private ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors;
+               private ArrayList<ScalarObject> _scalars;
+               private byte[] _classBytes;
+               private String _className;
+               //instantiated lazily on first call from the shipped class bytes
+               private SpoofOperator _op;
+
+               /**
+                * @param className name of the generated operator class
+                * @param classBytes byte code of the generated operator class
+                * @param bcMatrices broadcast side inputs
+                * @param scalars scalar inputs handed to the operator
+                * @throws DMLRuntimeException declared, but not raised by this constructor body
+                */
+               public CellwiseFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars)
+                       throws DMLRuntimeException
+               {
+                       _className = className;
+                       _classBytes = classBytes;
+                       _vectors = bcMatrices;
+                       _scalars = scalars;
+               }
+
+               /**
+                * Processes all blocks of one partition and returns one output tuple
+                * per input block. For FULL_AGG the scalar result is wrapped into a
+                * 1x1 block under the input index; for ROW_AGG the output is keyed by
+                * (input row index, 1); otherwise the input index is reused.
+                *
+                * @param arg iterator over the (index, block) pairs of the partition
+                * @return iterator over the produced (index, block) pairs
+                */
+               @Override
+               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg)
+                       throws Exception
+               {
+                       //materialize the shipped operator on first use
+                       if( _op == null ) {
+                               Class<?> opClass = CodegenUtils.loadClass(_className, _classBytes);
+                               _op = (SpoofOperator) CodegenUtils.createInstance(opClass);
+                       }
+
+                       List<Tuple2<MatrixIndexes, MatrixBlock>> results = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>();
+                       while( arg.hasNext() )
+                       {
+                               Tuple2<MatrixIndexes,MatrixBlock> entry = arg.next();
+                               MatrixIndexes inIx = entry._1();
+                               MatrixBlock result = new MatrixBlock();
+                               ArrayList<MatrixBlock> operands =
+                                       getVectorInputsFromBroadcast(entry._2(), (int) inIx.getRowIndex());
+
+                               //execute core operation
+                               SpoofCellwise cellOp = (SpoofCellwise) _op;
+                               MatrixIndexes outIx;
+                               if( cellOp.getCellType() == CellType.FULL_AGG ) {
+                                       //scalar result, wrapped into a 1x1 block
+                                       ScalarObject agg = _op.execute(operands, _scalars, 1);
+                                       result.reset(1, 1);
+                                       result.quickSetValue(0, 0, agg.getDoubleValue());
+                                       outIx = inIx;
+                               }
+                               else {
+                                       //row aggregates land in the first column block
+                                       outIx = (cellOp.getCellType() == CellType.ROW_AGG) ?
+                                               new MatrixIndexes(inIx.getRowIndex(), 1) : inIx;
+                                       _op.execute(operands, _scalars, result);
+                               }
+                               results.add(new Tuple2<MatrixIndexes,MatrixBlock>(outIx, result));
+                       }
+                       return results.iterator();
+               }
+
+               /**
+                * Assembles the operator inputs: the main block first, then one block
+                * per broadcast input. The broadcast block at row index rowIndex is used
+                * if the broadcast has at least that many row blocks, otherwise its
+                * first row block; the column block index is always 1.
+                */
+               private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, int rowIndex)
+                       throws DMLRuntimeException
+               {
+                       ArrayList<MatrixBlock> operands = new ArrayList<MatrixBlock>();
+                       operands.add(blkIn);
+                       for( PartitionedBroadcast<MatrixBlock> bc : _vectors ) {
+                               int blockRow = 1;
+                               if( bc.getNumRowBlocks() >= rowIndex )
+                                       blockRow = rowIndex;
+                               operands.add(bc.getBlock(blockRow, 1));
+                       }
+                       return operands;
+               }
+       }
+       
+       private static class OuterProductFunction implements 
PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, 
MatrixIndexes, MatrixBlock> 
+       {
+               private static final long serialVersionUID = 
-8209188316939435099L;
+               
+               private ArrayList<PartitionedBroadcast<MatrixBlock>> 
_bcMatrices = null;
+               private ArrayList<ScalarObject> _scalars = null;
+               private byte[] _classBytes = null;
+               private String _className = null;
+               private SpoofOperator _op = null;
+               
+               public OuterProductFunction(String className, byte[] 
classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, 
ArrayList<ScalarObject> scalars) 
+                               throws DMLRuntimeException
+               {
+                       _className = className;
+                       _classBytes = classBytes;
+                       _bcMatrices = bcMatrices;
+                       _scalars = scalars;
+               }
+               
+               @Override
+               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> 
call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg)
+                       throws Exception 
+               {
+                       //lazy load of shipped class
+                       if( _op == null ) {
+                               Class<?> loadedClass = 
CodegenUtils.loadClass(_className, _classBytes);
+                               _op = (SpoofOperator) 
CodegenUtils.createInstance(loadedClass); 
+                       }
+                       
+                       List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new 
ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>();
+                       while(arg.hasNext())
+                       {
+                               Tuple2<MatrixIndexes,MatrixBlock> tmp = 
arg.next();
+                               MatrixIndexes ixIn = tmp._1();
+                               MatrixBlock blkIn = tmp._2();
+                               MatrixBlock blkOut = new MatrixBlock();
+
+                               ArrayList<MatrixBlock> inputs = new 
ArrayList<MatrixBlock>();
+                               inputs.add(blkIn);
+                               
inputs.add(_bcMatrices.get(0).getBlock((int)ixIn.getRowIndex(), 1)); // U
+                               
inputs.add(_bcMatrices.get(1).getBlock((int)ixIn.getColumnIndex(), 1)); // V
+                                               
+                               //execute core operation
+                               
if(((SpoofOuterProduct)_op).getOuterProdType()==OutProdType.AGG_OUTER_PRODUCT) {
+                                       ScalarObject obj = _op.execute(inputs, 
_scalars,1);
+                                       blkOut.reset(1, 1);
+                                       blkOut.quickSetValue(0, 0, 
obj.getDoubleValue());
+                               }
+                               else {
+                                       _op.execute(inputs, _scalars, blkOut);
+                               }
+                               
+                               ret.add(new 
Tuple2<MatrixIndexes,MatrixBlock>(createOutputIndexes(ixIn,_op), blkOut));      
                    
+                       }
+                       
+                       return ret.iterator();
+               }
+               
+               private MatrixIndexes createOutputIndexes(MatrixIndexes in, 
SpoofOperator spoofOp) {
+                       if( ((SpoofOuterProduct)spoofOp).getOuterProdType() == 
OutProdType.LEFT_OUTER_PRODUCT ) 
+                               return new MatrixIndexes(in.getColumnIndex(), 
1);
+                       else if ( 
((SpoofOuterProduct)spoofOp).getOuterProdType() == 
OutProdType.RIGHT_OUTER_PRODUCT)
+                               return new MatrixIndexes(in.getRowIndex(), 1);
+                       else 
+                               return in;
+               }               
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
index 61c950a..2dfff74 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
@@ -69,13 +69,17 @@ public class RDDAggregateUtils
                }
        }
 
-       public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable( 
JavaPairRDD<MatrixIndexes, MatrixBlock> in )
+       public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable( 
JavaPairRDD<MatrixIndexes, MatrixBlock> in ) {
+               return sumByKeyStable(in, in.getNumPartitions()); //delegate, passing the input's own partition count as numPartitions
+       }
+       
+       public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable( 
JavaPairRDD<MatrixIndexes, MatrixBlock> in, int numPartitions  )
        {
                //stable sum of blocks per key, by passing correction blocks 
along with aggregates              
                JavaPairRDD<MatrixIndexes, CorrMatrixBlock> tmp = 
                                in.combineByKey( new 
CreateCorrBlockCombinerFunction(), 
                                                             new 
MergeSumBlockValueFunction(), 
-                                                            new 
MergeSumBlockCombinerFunction() );
+                                                            new 
MergeSumBlockCombinerFunction(), numPartitions );
                
                //strip-off correction blocks from                              
             
                JavaPairRDD<MatrixIndexes, MatrixBlock> out =  

Reply via email to