[SYSTEMML-2502] Performance spark cumagg offset aggregation (zero-copy)

This patch avoids unnecessary copies of the input data blocks, which were
previously made to avoid corrupting the inputs when aggregating the offsets
into the first row. Instead, we now pass the offset vector directly into the
dedicated cumulative aggregate operations. On our running example of 100
distributed sum(cumsum(X)) operations, this patch reduced the total
runtime from 887s to 732s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7a3447a5
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7a3447a5
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7a3447a5

Branch: refs/heads/master
Commit: 7a3447a50b6d2abdbaf6dce9d021a3ce7c2717d7
Parents: fee20fb
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sat Dec 1 21:06:04 2018 +0100
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sat Dec 1 21:06:04 2018 +0100

----------------------------------------------------------------------
 .../spark/CumulativeOffsetSPInstruction.java    | 62 +++++++-------------
 .../sysml/runtime/matrix/data/LibMatrixAgg.java | 10 +++-
 2 files changed, 27 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/7a3447a5/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
index 952a6d0..53e6e91 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
@@ -32,50 +32,40 @@ import scala.Tuple2;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.functionobjects.Builtin;
-import org.apache.sysml.runtime.functionobjects.Multiply;
-import org.apache.sysml.runtime.functionobjects.Plus;
-import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.LibMatrixAgg;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysml.runtime.matrix.operators.Operator;
 import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
+import org.apache.sysml.runtime.util.DataConverter;
 import org.apache.sysml.runtime.util.UtilFunctions;
 import org.apache.sysml.utils.IntUtils;
 
 public class CumulativeOffsetSPInstruction extends BinarySPInstruction {
-       private BinaryOperator _bop = null;
        private UnaryOperator _uop = null;
+       private boolean _cumsumprod = false;
        private final double _initValue ;
        private final boolean _broadcast;
 
        private CumulativeOffsetSPInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand out, double init, boolean broadcast, String opcode, 
String istr) {
                super(SPType.CumsumOffset, op, in1, in2, out, opcode, istr);
 
-               if ("bcumoffk+".equals(opcode)) {
-                       _bop = new BinaryOperator(Plus.getPlusFnObject());
+               if ("bcumoffk+".equals(opcode))
                        _uop = new 
UnaryOperator(Builtin.getBuiltinFnObject("ucumk+"));
-               }
-               else if ("bcumoff*".equals(opcode)) {
-                       _bop = new 
BinaryOperator(Multiply.getMultiplyFnObject());
+               else if ("bcumoff*".equals(opcode))
                        _uop = new 
UnaryOperator(Builtin.getBuiltinFnObject("ucum*"));
-               }
                else if ("bcumoff+*".equals(opcode)) {
-                       _bop = new BinaryOperator(PlusMultiply.getFnObject());
                        _uop = new 
UnaryOperator(Builtin.getBuiltinFnObject("ucumk+*"));
+                       _cumsumprod = true;
                }
-               else if ("bcumoffmin".equals(opcode)) {
-                       _bop = new 
BinaryOperator(Builtin.getBuiltinFnObject("min"));
+               else if ("bcumoffmin".equals(opcode))
                        _uop = new 
UnaryOperator(Builtin.getBuiltinFnObject("ucummin"));
-               }
-               else if ("bcumoffmax".equals(opcode)) {
-                       _bop = new 
BinaryOperator(Builtin.getBuiltinFnObject("max"));
+               else if ("bcumoffmax".equals(opcode))
                        _uop = new 
UnaryOperator(Builtin.getBuiltinFnObject("ucummax"));
-               }
 
                _initValue = init;
                _broadcast = broadcast;
@@ -119,10 +109,10 @@ public class CumulativeOffsetSPInstruction extends 
BinarySPInstruction {
                
                //execute cumulative offset (apply cumulative op w/ offsets)
                JavaPairRDD<MatrixIndexes,MatrixBlock> out = joined
-                       .mapValues(new RDDCumOffsetFunction(_uop, _bop));
+                       .mapValues(new RDDCumOffsetFunction(_uop, _cumsumprod));
                
                //put output handle in symbol table
-               if( _bop.fn instanceof PlusMultiply )
+               if( _cumsumprod )
                        sec.getMatrixCharacteristics(output.getName())
                                .set(mc1.getRows(), 1, mc1.getRowsPerBlock(), 
mc1.getColsPerBlock());
                else //general case
@@ -219,12 +209,12 @@ public class CumulativeOffsetSPInstruction extends 
BinarySPInstruction {
        {
                private static final long serialVersionUID = 
-5804080263258064743L;
 
-               private UnaryOperator _uop = null;
-               private BinaryOperator _bop = null;
+               private final UnaryOperator _uop;
+               private final boolean _cumsumprod;
                
-               public RDDCumOffsetFunction(UnaryOperator uop, BinaryOperator 
bop) {
+               public RDDCumOffsetFunction(UnaryOperator uop, boolean 
cumsumprod) {
                        _uop = uop;
-                       _bop = bop;
+                       _cumsumprod = cumsumprod;
                }
 
                @Override
@@ -232,26 +222,14 @@ public class CumulativeOffsetSPInstruction extends 
BinarySPInstruction {
                        //prepare inputs and outputs
                        MatrixBlock dblkIn = arg0._1(); //original data 
                        MatrixBlock oblkIn = arg0._2(); //offset row vector
-                       MatrixBlock data2 = new MatrixBlock(dblkIn); //cp data
-                       boolean cumsumprod = _bop.fn instanceof PlusMultiply;
                        
-                       //blockwise offset aggregation and prefix sum 
computation
-                       if( cumsumprod ) {
-                               data2.quickSetValue(0, 0, 
data2.quickGetValue(0, 0)
-                                       + data2.quickGetValue(0, 1) * 
oblkIn.quickGetValue(0, 0));
-                       }
-                       else {
-                               MatrixBlock fdata2 = data2.slice(0, 0);
-                               fdata2.binaryOperationsInPlace(_bop, oblkIn); 
//sum offset to first row
-                               data2.copy(0, 0, 0, data2.getNumColumns()-1, 
fdata2, true); //0-based
-                       }
-                       
-                       //compute columnwise prefix sums/prod/min/max
+                       //allocate output block
                        MatrixBlock blkOut = new 
MatrixBlock(dblkIn.getNumRows(),
-                               cumsumprod ? 1 : dblkIn.getNumColumns(), 
dblkIn.isInSparseFormat());
-                       data2.unaryOperations(_uop, blkOut);
-
-                       return blkOut;
+                               _cumsumprod ? 1 : dblkIn.getNumColumns(), 
false);
+                       
+                       //blockwise cumagg computation, incl offset aggregation
+                       return LibMatrixAgg.cumaggregateUnaryMatrix(dblkIn, 
blkOut, _uop,
+                               DataConverter.convertToDoubleVector(oblkIn));
                }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/7a3447a5/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
index e3e9dd7..c817a26 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
@@ -283,6 +283,10 @@ public class LibMatrixAgg
        }
 
        public static MatrixBlock cumaggregateUnaryMatrix(MatrixBlock in, 
MatrixBlock out, UnaryOperator uop) {
+               return cumaggregateUnaryMatrix(in, out, uop, null);
+       }
+       
+       public static MatrixBlock cumaggregateUnaryMatrix(MatrixBlock in, 
MatrixBlock out, UnaryOperator uop, double[] agg) {
                //prepare meta data 
                AggType aggtype = getAggType(uop);
                final int m = in.rlen;
@@ -290,7 +294,7 @@ public class LibMatrixAgg
                final int n2 = out.clen;
                
                //filter empty input blocks (incl special handling for 
sparse-unsafe operations)
-               if( in.isEmptyBlock(false) ){
+               if( in.isEmptyBlock(false) && (agg == null || aggtype == 
AggType.CUM_SUM_PROD ) ) {
                        return aggregateUnaryMatrixEmpty(in, out, aggtype, 
null);
                }
                
@@ -301,9 +305,9 @@ public class LibMatrixAgg
                //Timing time = new Timing(true);
                
                if( !in.sparse )
-                       cumaggregateUnaryMatrixDense(in, out, aggtype, uop.fn, 
null, 0, m);
+                       cumaggregateUnaryMatrixDense(in, out, aggtype, uop.fn, 
agg, 0, m);
                else
-                       cumaggregateUnaryMatrixSparse(in, out, aggtype, uop.fn, 
null, 0, m);
+                       cumaggregateUnaryMatrixSparse(in, out, aggtype, uop.fn, 
agg, 0, m);
                
                //cleanup output and change representation (if necessary)
                out.recomputeNonZeros();

Reply via email to