[SYSTEMML-2470] Fix distributed spark cumsumprod (aggregate 1st pass)

This patch fixes incorrect results of the distributed Spark implementation
of the new cumulative aggregate cumsumprod. In detail, we now use
cumsumprod(AB)[n], i.e., the last row of the block-local cumsumprod,
instead of sum(AB) as the aggregation function during the forward pass
of the generic two-pass algorithm.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/50ddddb9
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/50ddddb9
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/50ddddb9

Branch: refs/heads/master
Commit: 50ddddb90b28c6e28e97195dded9696edcdc3b45
Parents: 252e498
Author: Matthias Boehm <[email protected]>
Authored: Mon Jul 30 14:34:49 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Jul 30 14:36:19 2018 -0700

----------------------------------------------------------------------
 .../spark/CumulativeAggregateSPInstruction.java    | 17 +++++++++++------
 .../functions/unary/matrix/FullCumsumprodTest.java |  3 +--
 2 files changed, 12 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/50ddddb9/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
index 74390e1..8514acc 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
@@ -27,6 +27,7 @@ import scala.Tuple2;
 
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.functionobjects.Builtin;
 import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
@@ -36,6 +37,7 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
 import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
 
 public class CumulativeAggregateSPInstruction extends 
AggregateUnarySPInstruction {
 
@@ -79,10 +81,11 @@ public class CumulativeAggregateSPInstruction extends 
AggregateUnarySPInstructio
        {
                private static final long serialVersionUID = 11324676268945117L;
                
-               private AggregateUnaryOperator _op = null;
-               private long _rlen = -1;
-               private int _brlen = -1;
-               private int _bclen = -1;
+               private final AggregateUnaryOperator _op;
+               private UnaryOperator _uop = null;
+               private final long _rlen;
+               private final int _brlen;
+               private final int _bclen;
                
                public RDDCumAggFunction( AggregateUnaryOperator op, long rlen, 
int brlen, int bclen ) {
                        _op = op;
@@ -105,10 +108,12 @@ public class CumulativeAggregateSPInstruction extends 
AggregateUnarySPInstructio
                        AggregateUnaryOperator aop = 
(AggregateUnaryOperator)_op;
                        if( aop.aggOp.increOp.fn instanceof PlusMultiply ) { 
//cumsumprod
                                aop.indexFn.execute(ixIn, ixOut);
-                               MatrixBlock t1 = blkIn.slice(0, 
blkIn.getNumRows()-1, 0, 0, new MatrixBlock());
+                               if( _uop == null )
+                                       _uop = new 
UnaryOperator(Builtin.getBuiltinFnObject("ucumk+*"));
+                               MatrixBlock t1 = (MatrixBlock) 
blkIn.unaryOperations(_uop, new MatrixBlock());
                                MatrixBlock t2 = blkIn.slice(0, 
blkIn.getNumRows()-1, 1, 1, new MatrixBlock());
                                blkOut.reset(1, 2);
-                               blkOut.quickSetValue(0, 0, t1.sum());
+                               blkOut.quickSetValue(0, 0, 
t1.quickGetValue(t1.getNumRows()-1, 0));
                                blkOut.quickSetValue(0, 1, t2.prod());
                        }
                        else { //general case

http://git-wip-us.apache.org/repos/asf/systemml/blob/50ddddb9/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java
index 7f02055..f13e765 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java
@@ -111,8 +111,7 @@ public class FullCumsumprodTest extends AutomatedTestBase
                                String.valueOf(reverse).toUpperCase(), 
output("C") };
                        
                        double[][] A = getRandomMatrix(rows, 1, -10, 10, 
sparsity, 3);
-                       double[][] B = getRandomMatrix(rows, 1, -1, 1, 0.1, 7);
-                       //FIXME double[][] B = getRandomMatrix(rows, 1, -1, 1, 
0.9, 7);
+                       double[][] B = getRandomMatrix(rows, 1, -1, 1, 0.9, 7);
                        writeInputMatrixWithMTD("A", A, false);
                        writeInputMatrixWithMTD("B", B, false);
                        

Reply via email to