[SYSTEMML-1287] Code generator runtime integration Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/982ecb1a Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/982ecb1a Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/982ecb1a
Branch: refs/heads/master Commit: 982ecb1a4be69685a8e124eccfa3a12331f998b0 Parents: d7fd587 Author: Matthias Boehm <[email protected]> Authored: Sun Feb 26 19:01:36 2017 -0800 Committer: Matthias Boehm <[email protected]> Committed: Sun Feb 26 19:01:36 2017 -0800 ---------------------------------------------------------------------- .../instructions/CPInstructionParser.java | 19 +- .../instructions/SPInstructionParser.java | 14 +- .../runtime/instructions/cp/CPInstruction.java | 8 +- .../instructions/cp/SpoofCPInstruction.java | 98 +++++ .../instructions/spark/SPInstruction.java | 2 +- .../instructions/spark/SpoofSPInstruction.java | 407 +++++++++++++++++++ .../spark/utils/RDDAggregateUtils.java | 8 +- 7 files changed, 541 insertions(+), 15 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index f3c1605..f0603b4 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -61,6 +61,7 @@ import org.apache.sysml.runtime.instructions.cp.QuantileSortCPInstruction; import org.apache.sysml.runtime.instructions.cp.QuaternaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.RelationalBinaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.ReorgCPInstruction; +import org.apache.sysml.runtime.instructions.cp.SpoofCPInstruction; import org.apache.sysml.runtime.instructions.cp.StringInitCPInstruction; import org.apache.sysml.runtime.instructions.cp.TernaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.UaggOuterChainCPInstruction; @@ -271,8 +272,9 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "lu", CPINSTRUCTION_TYPE.MultiReturnBuiltin); String2CPInstructionType.put( "eigen", CPINSTRUCTION_TYPE.MultiReturnBuiltin); - String2CPInstructionType.put( "partition", CPINSTRUCTION_TYPE.Partition); - String2CPInstructionType.put( "compress", CPINSTRUCTION_TYPE.Compression); + String2CPInstructionType.put( "partition", CPINSTRUCTION_TYPE.Partition); + String2CPInstructionType.put( "compress", CPINSTRUCTION_TYPE.Compression); + String2CPInstructionType.put( "spoof", CPINSTRUCTION_TYPE.SpoofFused); //CP FILE instruction String2CPFileInstructionType = new HashMap<String, CPINSTRUCTION_TYPE>(); @@ -424,16 +426,19 @@ public class CPInstructionParser extends InstructionParser case Partition: return DataPartitionCPInstruction.parseInstruction(str); - - case Compression: - return (CPInstruction) CompressionCPInstruction.parseInstruction(str); - + case CentralMoment: return CentralMomentCPInstruction.parseInstruction(str); case Covariance: return CovarianceCPInstruction.parseInstruction(str); - + + case Compression: + return (CPInstruction) CompressionCPInstruction.parseInstruction(str); + + case SpoofFused: + return SpoofCPInstruction.parseInstruction(str); + case INVALID: default: http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java index 6658a88..5ca3847 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java @@ -73,6 +73,7 @@ import org.apache.sysml.runtime.instructions.spark.ReorgSPInstruction; import org.apache.sysml.runtime.instructions.spark.RmmSPInstruction; import org.apache.sysml.runtime.instructions.spark.SPInstruction; import org.apache.sysml.runtime.instructions.spark.SPInstruction.SPINSTRUCTION_TYPE; +import org.apache.sysml.runtime.instructions.spark.SpoofSPInstruction; import org.apache.sysml.runtime.instructions.spark.TernarySPInstruction; import org.apache.sysml.runtime.instructions.spark.Tsmm2SPInstruction; import org.apache.sysml.runtime.instructions.spark.TsmmSPInstruction; @@ -277,10 +278,12 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( "binuaggchain", SPINSTRUCTION_TYPE.BinUaggChain); - String2SPInstructionType.put( "write" , SPINSTRUCTION_TYPE.Write); + String2SPInstructionType.put( "write" , SPINSTRUCTION_TYPE.Write); - String2SPInstructionType.put( "castdtm" , SPINSTRUCTION_TYPE.Cast); - String2SPInstructionType.put( "castdtf" , SPINSTRUCTION_TYPE.Cast); + String2SPInstructionType.put( "castdtm" , SPINSTRUCTION_TYPE.Cast); + String2SPInstructionType.put( "castdtf" , SPINSTRUCTION_TYPE.Cast); + + String2SPInstructionType.put( "spoof" , SPINSTRUCTION_TYPE.SpoofFused); } public static SPInstruction parseSingleInstruction (String str ) @@ -443,10 +446,13 @@ public class SPInstructionParser extends InstructionParser case Checkpoint: return CheckpointSPInstruction.parseInstruction(str); - + case Compression: return CompressionSPInstruction.parseInstruction(str); + case SpoofFused: + return SpoofSPInstruction.parseInstruction(str); + case Cast: return CastSPInstruction.parseInstruction(str); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java index 1d192d5..dcd8d89 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java @@ -29,7 +29,13 @@ import org.apache.sysml.runtime.matrix.operators.Operator; public abstract class CPInstruction extends Instruction { - public enum CPINSTRUCTION_TYPE { INVALID, AggregateUnary, AggregateBinary, AggregateTernary, ArithmeticBinary, Ternary, Quaternary, BooleanBinary, BooleanUnary, BuiltinBinary, BuiltinUnary, BuiltinMultiple, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, Builtin, Reorg, RelationalBinary, File, Variable, External, Append, Rand, QSort, QPick, MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, Compression, StringInit, CentralMoment, Covariance, UaggOuterChain, Convolution }; + public enum CPINSTRUCTION_TYPE { INVALID, + AggregateUnary, AggregateBinary, AggregateTernary, ArithmeticBinary, + Ternary, Quaternary, BooleanBinary, BooleanUnary, BuiltinBinary, BuiltinUnary, + BuiltinMultiple, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, + Builtin, Reorg, RelationalBinary, File, Variable, External, Append, Rand, QSort, QPick, + MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, Compression, SpoofFused, + StringInit, CentralMoment, Covariance, UaggOuterChain, Convolution }; protected CPINSTRUCTION_TYPE _cptype; protected Operator _optr; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java new file mode 100644 index 0000000..61313d7 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.cp; + +import java.util.ArrayList; + +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.codegen.CodegenUtils; +import org.apache.sysml.runtime.codegen.SpoofOperator; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.cp.ComputationCPInstruction; +import org.apache.sysml.runtime.instructions.cp.ScalarObject; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +public class SpoofCPInstruction extends ComputationCPInstruction +{ + private Class<?> _class = null; + private int _numThreads = 1; + private CPOperand[] _in = null; + + public SpoofCPInstruction(Class<?> cla, int k, CPOperand[] in, CPOperand out, String opcode, String str) { + super(null, null, null, out, opcode, str); + _class = cla; + _numThreads = k; + _in = in; + } + + public static SpoofCPInstruction parseInstruction(String str) + throws DMLRuntimeException + { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + + //String opcode = parts[0]; + ArrayList<CPOperand> inlist = new ArrayList<CPOperand>(); + Class<?> cla = CodegenUtils.loadClass(parts[1], null); + String opcode = parts[0] + CodegenUtils.getSpoofType(cla); + + for( int i=2; i<parts.length-2; i++ ) + inlist.add(new CPOperand(parts[i])); + CPOperand out = new CPOperand(parts[parts.length-2]); + int k = Integer.parseInt(parts[parts.length-1]); + + return new SpoofCPInstruction(cla, k, inlist.toArray(new CPOperand[0]), out, opcode, str); + } + + @Override + public void processInstruction(ExecutionContext ec) + throws DMLRuntimeException + { + SpoofOperator op = (SpoofOperator) CodegenUtils.createInstance(_class); + + //get input matrices and scalars, incl pinning of matrices + ArrayList<MatrixBlock> inputs = new ArrayList<MatrixBlock>(); + ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>(); + for (CPOperand input : _in) { + if(input.getDataType()==DataType.MATRIX) + inputs.add(ec.getMatrixInput(input.getName())); + else if(input.getDataType()==DataType.SCALAR) + scalars.add(ec.getScalarInput(input.getName(), input.getValueType(), input.isLiteral())); + } + + // set the output dimensions to the hop node matrix dimensions + if( output.getDataType() == DataType.MATRIX) { + MatrixBlock out = new MatrixBlock(); + op.execute(inputs, scalars, out, _numThreads); + ec.setMatrixOutput(output.getName(), out); + } + else if (output.getDataType() == DataType.SCALAR) { + ScalarObject out = op.execute(inputs, scalars, _numThreads); + ec.setScalarOutput(output.getName(), out); + } + + // release input matrices + for (CPOperand input : _in) + if(input.getDataType()==DataType.MATRIX) + ec.releaseMatrixInput(input.getName()); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java index b28e408..17d1561 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java @@ -37,7 +37,7 @@ public abstract class SPInstruction extends Instruction CentralMoment, Covariance, QSort, QPick, ParameterizedBuiltin, MAppend, RAppend, GAppend, GAlignedAppend, Rand, MatrixReshape, Ternary, Quaternary, CumsumAggregate, CumsumOffset, BinUaggChain, UaggOuterChain, - Write, INVALID, + Write, SpoofFused, INVALID, Convolution }; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java new file mode 100644 index 0000000..15b0751 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.spark; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.codegen.CodegenUtils; +import org.apache.sysml.runtime.codegen.SpoofCellwise; +import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; +import org.apache.sysml.runtime.codegen.SpoofOperator; +import org.apache.sysml.runtime.codegen.SpoofOuterProduct; +import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType; +import org.apache.sysml.runtime.codegen.SpoofRowAggregate; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.cp.DoubleObject; +import org.apache.sysml.runtime.instructions.cp.ScalarObject; +import org.apache.sysml.runtime.instructions.spark.SPInstruction; +import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast; +import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; + +import scala.Tuple2; + +public class SpoofSPInstruction extends SPInstruction +{ + private final Class<?> _class; + private final byte[] _classBytes; + private final CPOperand[] _in; + private final CPOperand _out; + + public SpoofSPInstruction(Class<?> cls , byte[] classBytes, CPOperand[] in, CPOperand out, String opcode, String str) { + super(opcode, str); + _class = cls; + _classBytes = classBytes; + _sptype = SPINSTRUCTION_TYPE.SpoofFused; + _in = in; + _out = out; + } + + public static SpoofSPInstruction parseInstruction(String str) + throws DMLRuntimeException + { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + + //String opcode = parts[0]; + ArrayList<CPOperand> inlist = new ArrayList<CPOperand>(); + Class<?> cls = CodegenUtils.loadClass(parts[1], null); + byte[] classBytes = CodegenUtils.getClassAsByteArray(parts[1]); + String opcode = parts[0] + CodegenUtils.getSpoofType(cls); + + for( int i=2; i<parts.length-2; i++ ) + inlist.add(new CPOperand(parts[i])); + CPOperand out = new CPOperand(parts[parts.length-2]); + //note: number of threads parts[parts.length-1] always ignored + + return new SpoofSPInstruction(cls, classBytes, inlist.toArray(new CPOperand[0]), out, opcode, str); + } + + @Override + public void processInstruction(ExecutionContext ec) + throws DMLRuntimeException + { + SparkExecutionContext sec = (SparkExecutionContext)ec; + + //get input rdd and variable name + ArrayList<String> bcVars = new ArrayList<String>(); + MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[0].getName()); + JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable( _in[0].getName() ); + JavaPairRDD<MatrixIndexes, MatrixBlock> out = null; + + //simple case: map-side only operation (one rdd input, broadcast all) + //keep track of broadcast variables + ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new ArrayList<PartitionedBroadcast<MatrixBlock>>(); + ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>(); + for( int i=1; i<_in.length; i++ ) { + if( _in[i].getDataType()==DataType.MATRIX) { + bcMatrices.add(sec.getBroadcastForVariable(_in[i].getName())); + bcVars.add(_in[i].getName()); + } + else if(_in[i].getDataType()==DataType.SCALAR) { + scalars.add(sec.getScalarInput(_in[i].getName(), _in[i].getValueType(), _in[i].isLiteral())); + } + } + + //initialize Spark Operator + if(_class.getSuperclass() == SpoofCellwise.class) // cellwise operator + { + if( _out.getDataType()==DataType.MATRIX ) { + SpoofOperator op = (SpoofOperator) CodegenUtils.createInstance(_class); + + out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true); + if( ((SpoofCellwise)op).getCellType()==CellType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock() ) { + //NOTE: workaround with partition size needed due to potential bug in SPARK + //TODO investigate if some other side effect of correct blocks + if( out.partitions().size() > mcIn.getNumRowBlocks() ) + out = RDDAggregateUtils.sumByKeyStable(out, (int)mcIn.getNumRowBlocks()); + else + out = RDDAggregateUtils.sumByKeyStable(out); + } + sec.setRDDHandleForVariable(_out.getName(), out); + + //maintain lineage information for output rdd + sec.addLineageRDD(_out.getName(), _in[0].getName()); + for( String bcVar : bcVars ) + sec.addLineageBroadcast(_out.getName(), bcVar); + + //update matrix characteristics + updateOutputMatrixCharacteristics(sec, op); + } + else { //SCALAR + out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true); + MatrixBlock tmpMB = RDDAggregateUtils.sumStable(out); + sec.setVariable(_out.getName(), new DoubleObject(tmpMB.getValue(0, 0))); + } + } + else if(_class.getSuperclass() == SpoofOuterProduct.class) // outer product operator + { + if( _out.getDataType()==DataType.MATRIX ) { + SpoofOperator op = (SpoofOperator) CodegenUtils.createInstance(_class); + OutProdType type = ((SpoofOuterProduct)op).getOuterProdType(); + + //update matrix characteristics + updateOutputMatrixCharacteristics(sec, op); + MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName()); + + out = in.mapPartitionsToPair(new OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true); + if(type == OutProdType.LEFT_OUTER_PRODUCT || type == OutProdType.RIGHT_OUTER_PRODUCT ) { + //NOTE: workaround with partition size needed due to potential bug in SPARK + //TODO investigate if some other side effect of correct blocks + if( in.partitions().size() > mcOut.getNumRowBlocks()*mcOut.getNumColBlocks() ) + out = RDDAggregateUtils.sumByKeyStable( out, (int)(mcOut.getNumRowBlocks()*mcOut.getNumColBlocks()) ); + else + out = RDDAggregateUtils.sumByKeyStable( out ); + } + sec.setRDDHandleForVariable(_out.getName(), out); + + //maintain lineage information for output rdd + sec.addLineageRDD(_out.getName(), _in[0].getName()); + for( String bcVar : bcVars ) + sec.addLineageBroadcast(_out.getName(), bcVar); + + } + else { + out = in.mapPartitionsToPair(new OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true); + MatrixBlock tmp = RDDAggregateUtils.sumStable(out); + sec.setVariable(_out.getName(), new DoubleObject(tmp.getValue(0, 0))); + } + } + else if( _class.getSuperclass() == SpoofRowAggregate.class ) { //row aggregate operator + RowAggregateFunction fmmc = new RowAggregateFunction(_class.getName(), _classBytes, bcMatrices, scalars); + JavaPairRDD<MatrixIndexes,MatrixBlock> tmpRDD = in.mapToPair(fmmc); + MatrixBlock tmpMB = RDDAggregateUtils.sumStable(tmpRDD); + sec.setMatrixOutput(_out.getName(), tmpMB); + return; + } + else { + throw new DMLRuntimeException("Operator " + _class.getSuperclass() + " is not supported on Spark"); + } + } + + private void updateOutputMatrixCharacteristics(SparkExecutionContext sec, SpoofOperator op) + throws DMLRuntimeException + { + if(op instanceof SpoofCellwise) + { + MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[0].getName()); + MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName()); + if( ((SpoofCellwise)op).getCellType()==CellType.ROW_AGG ) + mcOut.set(mcIn.getRows(), 1, mcIn.getRowsPerBlock(), mcIn.getColsPerBlock()); + else if( ((SpoofCellwise)op).getCellType()==CellType.NO_AGG ) + mcOut.set(mcIn); + } + else if(op instanceof SpoofOuterProduct) + { + MatrixCharacteristics mcIn1 = sec.getMatrixCharacteristics(_in[0].getName()); //X + MatrixCharacteristics mcIn2 = sec.getMatrixCharacteristics(_in[1].getName()); //U + MatrixCharacteristics mcIn3 = sec.getMatrixCharacteristics(_in[2].getName()); //V + MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName()); + OutProdType type = ((SpoofOuterProduct)op).getOuterProdType(); + + if( type == OutProdType.CELLWISE_OUTER_PRODUCT) + mcOut.set(mcIn1.getRows(), mcIn1.getCols(), mcIn1.getRowsPerBlock(), mcIn1.getColsPerBlock()); + else if( type == OutProdType.LEFT_OUTER_PRODUCT) + mcOut.set(mcIn3.getRows(), mcIn3.getCols(), mcIn3.getRowsPerBlock(), mcIn3.getColsPerBlock()); + else if( type == OutProdType.RIGHT_OUTER_PRODUCT ) + mcOut.set(mcIn2.getRows(), mcIn2.getCols(), mcIn2.getRowsPerBlock(), mcIn2.getColsPerBlock()); + } + } + + private static class RowAggregateFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> + { + private static final long serialVersionUID = -7926980450209760212L; + + private ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors = null; + private ArrayList<ScalarObject> _scalars = null; + private byte[] _classBytes = null; + private String _className = null; + private SpoofOperator _op = null; + + public RowAggregateFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) + throws DMLRuntimeException + { + _className = className; + _classBytes = classBytes; + _vectors = bcMatrices; + _scalars = scalars; + } + + @Override + public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) + throws Exception + { + //lazy load of shipped class + if( _op == null ) { + Class<?> loadedClass = CodegenUtils.loadClass(_className, _classBytes); + _op = (SpoofOperator) CodegenUtils.createInstance(loadedClass); + } + + //get main input block and indexes + MatrixIndexes ixIn = arg0._1(); + MatrixBlock blkIn = arg0._2(); + int rowIx = (int)ixIn.getRowIndex(); + + //prepare output and execute single-threaded operator + ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(blkIn, rowIx); + MatrixIndexes ixOut = new MatrixIndexes(1,1); + MatrixBlock blkOut = new MatrixBlock(); + _op.execute(inputs, _scalars, blkOut); + + //output new tuple + return new Tuple2<MatrixIndexes, MatrixBlock>(ixOut, blkOut); + } + + private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, int rowIndex) + throws DMLRuntimeException + { + ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>(); + ret.add(blkIn); + for( PartitionedBroadcast<MatrixBlock> vector : _vectors ) + ret.add(vector.getBlock((vector.getNumRowBlocks()>=rowIndex)?rowIndex:1, 1)); + return ret; + } + } + + private static class CellwiseFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> + { + private static final long serialVersionUID = -8209188316939435099L; + + private ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors = null; + private ArrayList<ScalarObject> _scalars = null; + private byte[] _classBytes = null; + private String _className = null; + private SpoofOperator _op = null; + + public CellwiseFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) + throws DMLRuntimeException + { + _className = className; + _classBytes = classBytes; + _vectors = bcMatrices; + _scalars = scalars; + } + + @Override + public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg) + throws Exception + { + //lazy load of shipped class + if( _op == null ) { + Class<?> loadedClass = CodegenUtils.loadClass(_className, _classBytes); + _op = (SpoofOperator) CodegenUtils.createInstance(loadedClass); + } + + List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>(); + while(arg.hasNext()) + { + Tuple2<MatrixIndexes,MatrixBlock> tmp = arg.next(); + MatrixIndexes ixIn = tmp._1(); + MatrixBlock blkIn = tmp._2(); + MatrixIndexes ixOut = ixIn; + MatrixBlock blkOut = new MatrixBlock(); + ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(blkIn, (int)ixIn.getRowIndex()); + + //execute core operation + if(((SpoofCellwise)_op).getCellType()==CellType.FULL_AGG) { + ScalarObject obj = _op.execute(inputs, _scalars, 1); + blkOut.reset(1, 1); + blkOut.quickSetValue(0, 0, obj.getDoubleValue()); + } + else { + if(((SpoofCellwise)_op).getCellType()==CellType.ROW_AGG) + ixOut = new MatrixIndexes(ixOut.getRowIndex(), 1); + _op.execute(inputs, _scalars, blkOut); + } + ret.add(new Tuple2<MatrixIndexes,MatrixBlock>(ixOut, blkOut)); + } + return ret.iterator(); + } + + private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, int rowIndex) + throws DMLRuntimeException + { + ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>(); + ret.add(blkIn); + for( PartitionedBroadcast<MatrixBlock> vector : _vectors ) + ret.add(vector.getBlock((vector.getNumRowBlocks()>=rowIndex)?rowIndex:1, 1)); + return ret; + } + } + + private static class OuterProductFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> + { + private static final long serialVersionUID = -8209188316939435099L; + + private ArrayList<PartitionedBroadcast<MatrixBlock>> _bcMatrices = null; + private ArrayList<ScalarObject> _scalars = null; + private byte[] _classBytes = null; + private String _className = null; + private SpoofOperator _op = null; + + public OuterProductFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) + throws DMLRuntimeException + { + _className = className; + _classBytes = classBytes; + _bcMatrices = bcMatrices; + _scalars = scalars; + } + + @Override + public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg) + throws Exception + { + //lazy load of shipped class + if( _op == null ) { + Class<?> loadedClass = CodegenUtils.loadClass(_className, _classBytes); + _op = (SpoofOperator) CodegenUtils.createInstance(loadedClass); + } + + List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>(); + while(arg.hasNext()) + { + Tuple2<MatrixIndexes,MatrixBlock> tmp = arg.next(); + MatrixIndexes ixIn = tmp._1(); + MatrixBlock blkIn = tmp._2(); + MatrixBlock blkOut = new MatrixBlock(); + + ArrayList<MatrixBlock> inputs = new ArrayList<MatrixBlock>(); + inputs.add(blkIn); + inputs.add(_bcMatrices.get(0).getBlock((int)ixIn.getRowIndex(), 1)); // U + inputs.add(_bcMatrices.get(1).getBlock((int)ixIn.getColumnIndex(), 1)); // V + + //execute core operation + if(((SpoofOuterProduct)_op).getOuterProdType()==OutProdType.AGG_OUTER_PRODUCT) { + ScalarObject obj = _op.execute(inputs, _scalars,1); + blkOut.reset(1, 1); + blkOut.quickSetValue(0, 0, obj.getDoubleValue()); + } + else { + _op.execute(inputs, _scalars, blkOut); + } + + ret.add(new Tuple2<MatrixIndexes,MatrixBlock>(createOutputIndexes(ixIn,_op), blkOut)); + } + + return ret.iterator(); + } + + private MatrixIndexes createOutputIndexes(MatrixIndexes in, SpoofOperator spoofOp) { + if( ((SpoofOuterProduct)spoofOp).getOuterProdType() == OutProdType.LEFT_OUTER_PRODUCT ) + return new MatrixIndexes(in.getColumnIndex(), 1); + else if ( ((SpoofOuterProduct)spoofOp).getOuterProdType() == OutProdType.RIGHT_OUTER_PRODUCT) + return new MatrixIndexes(in.getRowIndex(), 1); + else + return in; + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java index 61c950a..2dfff74 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java @@ -69,13 +69,17 @@ public class RDDAggregateUtils } } - public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in ) + public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in ) { + return sumByKeyStable(in, in.getNumPartitions()); + } + + public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in, int numPartitions ) { //stable sum of blocks per key, by passing correction blocks along with aggregates JavaPairRDD<MatrixIndexes, CorrMatrixBlock> tmp = in.combineByKey( new CreateCorrBlockCombinerFunction(), new MergeSumBlockValueFunction(), - new MergeSumBlockCombinerFunction() ); + new MergeSumBlockCombinerFunction(), numPartitions ); //strip-off correction blocks from JavaPairRDD<MatrixIndexes, MatrixBlock> out =
