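[SYSTEMML-2470] Fix distributed spark cumsumprod (aggregate 1st pass) This patch fixes result-correctness issues in the distributed Spark operations of the new cumulative aggregate cumsumprod. Specifically, the forward (first) pass of the generic two-pass algorithm now uses cumsumprod(AB)[n], i.e., the last value of the block-local cumulative sum-product, instead of sum(AB) as the per-block aggregation function. A plain sum does not account for the multiplicative factors in B, so it does not produce the correct value to carry into subsequent blocks.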
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/50ddddb9 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/50ddddb9 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/50ddddb9 Branch: refs/heads/master Commit: 50ddddb90b28c6e28e97195dded9696edcdc3b45 Parents: 252e498 Author: Matthias Boehm <[email protected]> Authored: Mon Jul 30 14:34:49 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Jul 30 14:36:19 2018 -0700 ---------------------------------------------------------------------- .../spark/CumulativeAggregateSPInstruction.java | 17 +++++++++++------ .../functions/unary/matrix/FullCumsumprodTest.java | 3 +-- 2 files changed, 12 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/50ddddb9/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java index 74390e1..8514acc 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java @@ -27,6 +27,7 @@ import scala.Tuple2; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.functionobjects.Builtin; import org.apache.sysml.runtime.functionobjects.PlusMultiply; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; @@ -36,6 +37,7 @@ import 
org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysml.runtime.matrix.operators.UnaryOperator; public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstruction { @@ -79,10 +81,11 @@ public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstructio { private static final long serialVersionUID = 11324676268945117L; - private AggregateUnaryOperator _op = null; - private long _rlen = -1; - private int _brlen = -1; - private int _bclen = -1; + private final AggregateUnaryOperator _op; + private UnaryOperator _uop = null; + private final long _rlen; + private final int _brlen; + private final int _bclen; public RDDCumAggFunction( AggregateUnaryOperator op, long rlen, int brlen, int bclen ) { _op = op; @@ -105,10 +108,12 @@ public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstructio AggregateUnaryOperator aop = (AggregateUnaryOperator)_op; if( aop.aggOp.increOp.fn instanceof PlusMultiply ) { //cumsumprod aop.indexFn.execute(ixIn, ixOut); - MatrixBlock t1 = blkIn.slice(0, blkIn.getNumRows()-1, 0, 0, new MatrixBlock()); + if( _uop == null ) + _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+*")); + MatrixBlock t1 = (MatrixBlock) blkIn.unaryOperations(_uop, new MatrixBlock()); MatrixBlock t2 = blkIn.slice(0, blkIn.getNumRows()-1, 1, 1, new MatrixBlock()); blkOut.reset(1, 2); - blkOut.quickSetValue(0, 0, t1.sum()); + blkOut.quickSetValue(0, 0, t1.quickGetValue(t1.getNumRows()-1, 0)); blkOut.quickSetValue(0, 1, t2.prod()); } else { //general case http://git-wip-us.apache.org/repos/asf/systemml/blob/50ddddb9/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java ---------------------------------------------------------------------- diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java index 7f02055..f13e765 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java @@ -111,8 +111,7 @@ public class FullCumsumprodTest extends AutomatedTestBase String.valueOf(reverse).toUpperCase(), output("C") }; double[][] A = getRandomMatrix(rows, 1, -10, 10, sparsity, 3); - double[][] B = getRandomMatrix(rows, 1, -1, 1, 0.1, 7); - //FIXME double[][] B = getRandomMatrix(rows, 1, -1, 1, 0.9, 7); + double[][] B = getRandomMatrix(rows, 1, -1, 1, 0.9, 7); writeInputMatrixWithMTD("A", A, false); writeInputMatrixWithMTD("B", B, false);
