Repository: systemml Updated Branches: refs/heads/master b586d1691 -> 50ddddb90
[SYSTEMML-2470] New cumsumprod cumulative aggregate (compiler/runtime) This patch introduces a new cumulative aggregate builtin function for C_i = A_i + B_i * C_{i-1}, where we pass cbind(A,B) as input. In detail, this includes the compiler and runtime integration for CP and Spark. However, although the respective tests are passing, the distributed Spark operations still have correctness issues in the general case (tested with special input data). Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e90af572 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e90af572 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e90af572 Branch: refs/heads/master Commit: e90af572c41edf5f215f83d08bd5bcb4b342f55e Parents: b586d16 Author: Matthias Boehm <mboe...@gmail.com> Authored: Sat Jul 28 00:58:30 2018 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Mon Jul 30 14:36:10 2018 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/AggUnaryOp.java | 2 +- src/main/java/org/apache/sysml/hops/Hop.java | 3 +- .../java/org/apache/sysml/hops/UnaryOp.java | 33 +++-- .../codegen/opt/PlanSelectionFuseCostBased.java | 1 + .../opt/PlanSelectionFuseCostBasedV2.java | 1 + .../java/org/apache/sysml/lops/Aggregate.java | 5 +- .../sysml/lops/CumulativeOffsetBinary.java | 16 ++- .../sysml/lops/CumulativePartialAggregate.java | 26 ++-- src/main/java/org/apache/sysml/lops/Unary.java | 8 +- .../sysml/parser/BuiltinFunctionExpression.java | 5 + .../org/apache/sysml/parser/DMLTranslator.java | 1 + .../org/apache/sysml/parser/Expression.java | 1 + .../sysml/runtime/functionobjects/Builtin.java | 11 +- .../runtime/functionobjects/KahanPlus.java | 2 +- .../instructions/CPInstructionParser.java | 1 + .../runtime/instructions/InstructionUtils.java | 41 +++--- .../instructions/SPInstructionParser.java | 4 +- .../instructions/cp/UnaryCPInstruction.java | 2 +- 
.../spark/CumulativeAggregateSPInstruction.java | 29 +++-- .../spark/CumulativeOffsetSPInstruction.java | 63 +++++---- .../sysml/runtime/matrix/data/LibMatrixAgg.java | 54 ++++++-- .../sysml/runtime/matrix/data/MatrixBlock.java | 19 ++- .../unary/matrix/FullCumsumprodTest.java | 129 +++++++++++++++++++ .../functions/unary/matrix/Cumsumprod.dml | 37 ++++++ 24 files changed, 371 insertions(+), 123 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/hops/AggUnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java index 47943d0..af9d936 100644 --- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java @@ -264,7 +264,7 @@ public class AggUnaryOp extends MultiThreadedHop setLops(unary1); } - } + } else //default { boolean needAgg = requiresAggregation(input, _direction); http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 41a32c3..6466575 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -1046,7 +1046,7 @@ public abstract class Hop implements ParseInfo PRINT, ASSERT, EIGEN, NROW, NCOL, LENGTH, ROUND, IQM, STOP, CEIL, FLOOR, MEDIAN, INVERSE, CHOLESKY, SVD, EXISTS, //cumulative sums, products, extreme values - CUMSUM, CUMPROD, CUMMIN, CUMMAX, + CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, //fused ML-specific operators for performance SPROP, //sample proportion: P * (1 - P) SIGMOID, //sigmoid function: 1 / (1 + exp(-X)) @@ -1303,6 +1303,7 @@ public abstract 
class Hop implements ParseInfo HopsOpOp1LopsU.put(OpOp1.CUMPROD, org.apache.sysml.lops.Unary.OperationTypes.CUMPROD); HopsOpOp1LopsU.put(OpOp1.CUMMIN, org.apache.sysml.lops.Unary.OperationTypes.CUMMIN); HopsOpOp1LopsU.put(OpOp1.CUMMAX, org.apache.sysml.lops.Unary.OperationTypes.CUMMAX); + HopsOpOp1LopsU.put(OpOp1.CUMSUMPROD, org.apache.sysml.lops.Unary.OperationTypes.CUMSUMPROD); HopsOpOp1LopsU.put(OpOp1.INVERSE, org.apache.sysml.lops.Unary.OperationTypes.INVERSE); HopsOpOp1LopsU.put(OpOp1.CHOLESKY, org.apache.sysml.lops.Unary.OperationTypes.CHOLESKY); HopsOpOp1LopsU.put(OpOp1.CAST_AS_SCALAR, org.apache.sysml.lops.Unary.OperationTypes.NOTSUPPORTED); http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/hops/UnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java index 19b6ed3..f93d40d 100644 --- a/src/main/java/org/apache/sysml/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java @@ -456,8 +456,7 @@ public class UnaryOp extends MultiThreadedHop //recursive preaggregation until aggregates fit into CP memory budget while( ((2*OptimizerUtils.estimateSize(TEMP.getOutputParameters().getNumRows(), clen) + OptimizerUtils.estimateSize(1, clen)) - > OptimizerUtils.getLocalMemBudget() - && TEMP.getOutputParameters().getNumRows()>1) || force ) + > OptimizerUtils.getLocalMemBudget() && TEMP.getOutputParameters().getNumRows()>1) || force ) { DATA.add(TEMP); @@ -468,7 +467,7 @@ public class UnaryOp extends MultiThreadedHop preagg.getOutputParameters().setDimensions(rlenAgg, clen, brlen, bclen, -1); setLineNumbers(preagg); - TEMP = preagg; + TEMP = preagg; level++; force = false; //in case of unknowns, generate one level } @@ -497,19 +496,20 @@ public class UnaryOp extends MultiThreadedHop return TEMP; } - private OperationTypes getCumulativeAggType() - { + private OperationTypes 
getCumulativeAggType() { switch( _op ) { - case CUMSUM: return OperationTypes.KahanSum; - case CUMPROD: return OperationTypes.Product; - case CUMMIN: return OperationTypes.Min; - case CUMMAX: return OperationTypes.Max; - default: return null; + case CUMSUM: return OperationTypes.KahanSum; + case CUMPROD: return OperationTypes.Product; + case CUMSUMPROD: return OperationTypes.SumProduct; + case CUMMIN: return OperationTypes.Min; + case CUMMAX: return OperationTypes.Max; + default: return null; } } private double getCumulativeInitValue() { switch( _op ) { + case CUMSUMPROD: case CUMSUM: return 0; case CUMPROD: return 1; case CUMMIN: return Double.POSITIVE_INFINITY; @@ -580,6 +580,8 @@ public class UnaryOp extends MultiThreadedHop { ret = new long[]{mc.getRows(), mc.getCols(), mc.getNonZeros()}; } + else if( _op==OpOp1.CUMSUMPROD ) + ret = new long[]{mc.getRows(), 1, -1}; else ret = new long[]{mc.getRows(), mc.getCols(), -1}; } @@ -603,7 +605,8 @@ public class UnaryOp extends MultiThreadedHop return (_op == OpOp1.CUMSUM || _op == OpOp1.CUMPROD || _op == OpOp1.CUMMIN - || _op == OpOp1.CUMMAX); + || _op == OpOp1.CUMMAX + || _op == OpOp1.CUMSUMPROD); } public boolean isCastUnaryOperation() { @@ -701,6 +704,10 @@ public class UnaryOp extends MultiThreadedHop setDim1( 1 ); setDim2( 1 ); } + else if ( _op==OpOp1.CUMSUMPROD ) { + setDim1(input.getDim1()); + setDim2(1); + } else //general case { // If output is a Matrix then this operation is of type (B = op(A)) @@ -721,7 +728,7 @@ public class UnaryOp extends MultiThreadedHop @Override public Object clone() throws CloneNotSupportedException { - UnaryOp ret = new UnaryOp(); + UnaryOp ret = new UnaryOp(); //copy generic attributes ret.clone(this, false); @@ -749,7 +756,7 @@ public class UnaryOp extends MultiThreadedHop if( _op == OpOp1.PRINT ) return false; - UnaryOp that2 = (UnaryOp)that; + UnaryOp that2 = (UnaryOp)that; return ( _op == that2._op && getInput().get(0) == that2.getInput().get(0)); } 
http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java index a4afa6d..ed37084 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java @@ -665,6 +665,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection case CUMMIN: case CUMMAX: case CUMPROD: costs = 1; break; + case CUMSUMPROD: costs = 2; break; default: LOG.warn("Cost model not " + "implemented yet for: "+((UnaryOp)current).getOp()); http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 0ef255a..1a18d4d 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -1048,6 +1048,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection case CUMMIN: case CUMMAX: case CUMPROD: costs = 1; break; + case CUMSUMPROD: costs = 2; break; default: LOG.warn("Cost model not " + "implemented yet for: "+((UnaryOp)current).getOp()); http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/lops/Aggregate.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Aggregate.java 
b/src/main/java/org/apache/sysml/lops/Aggregate.java index 2631415..57139cc 100644 --- a/src/main/java/org/apache/sysml/lops/Aggregate.java +++ b/src/main/java/org/apache/sysml/lops/Aggregate.java @@ -33,12 +33,11 @@ import org.apache.sysml.parser.Expression.*; public class Aggregate extends Lop { - - /** Aggregate operation types **/ public enum OperationTypes { - Sum, Product, Min, Max, Trace, KahanSum, KahanSumSq, KahanTrace, Mean, Var, MaxIndex, MinIndex + Sum, Product, SumProduct, Min, Max, Trace, + KahanSum, KahanSumSq, KahanTrace, Mean, Var, MaxIndex, MinIndex } OperationTypes operation; http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/lops/CumulativeOffsetBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/CumulativeOffsetBinary.java b/src/main/java/org/apache/sysml/lops/CumulativeOffsetBinary.java index 5d07c44..78443e0 100644 --- a/src/main/java/org/apache/sysml/lops/CumulativeOffsetBinary.java +++ b/src/main/java/org/apache/sysml/lops/CumulativeOffsetBinary.java @@ -88,8 +88,9 @@ public class CumulativeOffsetBinary extends Lop private static void checkSupportedOperations(OperationTypes op) { //sanity check for supported aggregates - if( !(op == OperationTypes.KahanSum || op == OperationTypes.Product || - op == OperationTypes.Min || op == OperationTypes.Max) ) + if( !( op == OperationTypes.KahanSum || op == OperationTypes.Product + || op == OperationTypes.SumProduct + || op == OperationTypes.Min || op == OperationTypes.Max) ) { throw new LopsException("Unsupported aggregate operation type: "+op); } @@ -97,11 +98,12 @@ public class CumulativeOffsetBinary extends Lop private String getOpcode() { switch( _op ) { - case KahanSum: return "bcumoffk+"; - case Product: return "bcumoff*"; - case Min: return "bcumoffmin"; - case Max: return "bcumoffmax"; - default: return null; + case KahanSum: return "bcumoffk+"; + case Product: return 
"bcumoff*"; + case SumProduct: return "bcumoff+*"; + case Min: return "bcumoffmin"; + case Max: return "bcumoffmax"; + default: return null; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/lops/CumulativePartialAggregate.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/CumulativePartialAggregate.java b/src/main/java/org/apache/sysml/lops/CumulativePartialAggregate.java index 015bcd0..f50bf0c 100644 --- a/src/main/java/org/apache/sysml/lops/CumulativePartialAggregate.java +++ b/src/main/java/org/apache/sysml/lops/CumulativePartialAggregate.java @@ -34,13 +34,13 @@ public class CumulativePartialAggregate extends Lop super(Lop.Type.CumulativePartialAggregate, dt, vt); //sanity check for supported aggregates - if( !(op == OperationTypes.KahanSum || op == OperationTypes.Product || - op == OperationTypes.Min || op == OperationTypes.Max) ) + if( !( op == OperationTypes.KahanSum || op == OperationTypes.Product + || op == OperationTypes.SumProduct + || op == OperationTypes.Min || op == OperationTypes.Max) ) { throw new LopsException("Unsupported aggregate operation type: "+op); } _op = op; - init(input, dt, vt, et); } @@ -62,12 +62,8 @@ public class CumulativePartialAggregate extends Lop } else //Spark/CP { - //setup Spark parameters - boolean breaksAlignment = false; - boolean aligner = false; - boolean definesMRJob = false; lps.addCompatibility(JobType.INVALID); - lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob ); + lps.setProperties( inputs, et, ExecLocation.ControlProgram, false, false, false ); } } @@ -76,14 +72,14 @@ public class CumulativePartialAggregate extends Lop return "CumulativePartialAggregate"; } - private String getOpcode() - { + private String getOpcode() { switch( _op ) { - case KahanSum: return "ucumack+"; - case Product: return "ucumac*"; - case Min: return "ucumacmin"; - case 
Max: return "ucumacmax"; - default: return null; + case KahanSum: return "ucumack+"; + case Product: return "ucumac*"; + case SumProduct: return "ucumac+*"; + case Min: return "ucumacmin"; + case Max: return "ucumacmax"; + default: return null; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/lops/Unary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Unary.java b/src/main/java/org/apache/sysml/lops/Unary.java index a5403b2..c6f3151 100644 --- a/src/main/java/org/apache/sysml/lops/Unary.java +++ b/src/main/java/org/apache/sysml/lops/Unary.java @@ -42,7 +42,7 @@ public class Unary extends Lop LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS, AND, OR, XOR, BW_AND, BW_OR, BW_XOR, BW_SHIFTL, BW_SHIFTR, ROUND, CEIL, FLOOR, MR_IQM, INVERSE, CHOLESKY, - CUMSUM, CUMPROD, CUMMIN, CUMMAX, + CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, SPROP, SIGMOID, SUBTRACT_NZ, LOG_NZ, CAST_AS_MATRIX, CAST_AS_FRAME, NOTSUPPORTED @@ -288,7 +288,10 @@ public class Unary extends Lop case CUMMAX: return "ucummax"; - + + case CUMSUMPROD: + return "ucumk+*"; + case INVERSE: return "inverse"; @@ -330,6 +333,7 @@ public class Unary extends Lop || op==OperationTypes.CUMPROD || op==OperationTypes.CUMMIN || op==OperationTypes.CUMMAX + || op==OperationTypes.CUMSUMPROD || op==OperationTypes.EXP || op==OperationTypes.LOG || op==OperationTypes.SIGMOID; http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index fe67fb6..9f3a1e2 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ 
b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -627,11 +627,14 @@ public class BuiltinFunctionExpression extends DataIdentifier case CUMSUM: case CUMPROD: + case CUMSUMPROD: case CUMMIN: case CUMMAX: // cumsum(X); checkNumParameters(1); checkMatrixParam(getFirstExpr()); + if( getOpCode() == BuiltinFunctionOp.CUMSUMPROD && id.getDim2() > 2 ) + raiseValidateError("Cumsumprod only supported over two-column matrices", conditional); output.setDataType(DataType.MATRIX); output.setDimensions(id.getDim1(), id.getDim2()); @@ -1910,6 +1913,8 @@ public class BuiltinFunctionExpression extends DataIdentifier bifop = Expression.BuiltinFunctionOp.CUMPROD; else if (functionName.equals("cumsum")) bifop = Expression.BuiltinFunctionOp.CUMSUM; + else if (functionName.equals("cumsumprod")) + bifop = Expression.BuiltinFunctionOp.CUMSUMPROD; //'castAsScalar' for backwards compatibility else if (functionName.equals("as.scalar") || functionName.equals("castAsScalar")) bifop = Expression.BuiltinFunctionOp.CAST_AS_SCALAR; http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 22d152d..bdfdf8f 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -2597,6 +2597,7 @@ public class DMLTranslator case FLOOR: case CUMSUM: case CUMPROD: + case CUMSUMPROD: case CUMMIN: case CUMMAX: currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java 
b/src/main/java/org/apache/sysml/parser/Expression.java index b394372..8299fbe 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -86,6 +86,7 @@ public abstract class Expression implements ParseInfo CUMMIN, CUMPROD, CUMSUM, + CUMSUMPROD, DIAG, EIGEN, EVAL, http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java index 7ad0808..a4923ff 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java @@ -50,7 +50,7 @@ public class Builtin extends ValueFunction public enum BuiltinCode { SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, - STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, EVAL, LIST } + STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST } public BuiltinCode bFunc; private static final boolean FASTMATH = true; @@ -91,6 +91,7 @@ public class Builtin extends ValueFunction String2BuiltinCode.put( "floor" , BuiltinCode.FLOOR); String2BuiltinCode.put( "ucumk+" , BuiltinCode.CUMSUM); String2BuiltinCode.put( "ucum*" , BuiltinCode.CUMPROD); + String2BuiltinCode.put( "ucumk+*", BuiltinCode.CUMSUMPROD); String2BuiltinCode.put( "ucummin", BuiltinCode.CUMMIN); String2BuiltinCode.put( "ucummax", BuiltinCode.CUMMAX); String2BuiltinCode.put( "inverse", BuiltinCode.INVERSE); @@ -103,7 +104,7 @@ public class Builtin extends ValueFunction private static Builtin logObj = null, lognzObj = null, minObj = null, maxObj = null, 
maxindexObj = null, minindexObj=null; private static Builtin absObj = null, signObj = null, sqrtObj = null, expObj = null, plogpObj = null, printObj = null, printfObj; private static Builtin nrowObj = null, ncolObj = null, lengthObj = null, roundObj = null, ceilObj=null, floorObj=null; - private static Builtin inverseObj=null, cumsumObj=null, cumprodObj=null, cumminObj=null, cummaxObj=null; + private static Builtin inverseObj=null, cumsumObj=null, cumprodObj=null, cumminObj=null, cummaxObj=null, cumsprodObj=null; private static Builtin stopObj = null, spropObj = null, sigmoidObj = null; private Builtin(BuiltinCode bf) { @@ -256,7 +257,11 @@ public class Builtin extends ValueFunction case CUMPROD: if ( cumprodObj == null ) cumprodObj = new Builtin(BuiltinCode.CUMPROD); - return cumprodObj; + return cumprodObj; + case CUMSUMPROD: + if ( cumsprodObj == null ) + cumsprodObj = new Builtin(BuiltinCode.CUMSUMPROD); + return cumsprodObj; case CUMMIN: if ( cumminObj == null ) cumminObj = new Builtin(BuiltinCode.CUMMIN); http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/runtime/functionobjects/KahanPlus.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/KahanPlus.java b/src/main/java/org/apache/sysml/runtime/functionobjects/KahanPlus.java index 28b1b44..b5b59ea 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/KahanPlus.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/KahanPlus.java @@ -98,7 +98,7 @@ public class KahanPlus extends KahanFunction implements Serializable //default path for any other value double correction = in2 + in1._correction; double sum = in1._sum + correction; - in1.set(sum, correction-(sum-in1._sum)); //prevent eager JIT opt + in1.set(sum, correction-(sum-in1._sum)); //prevent eager JIT opt } @Override 
http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index 4663671..9acb56e 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -173,6 +173,7 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "floor" , CPType.Unary); String2CPInstructionType.put( "ucumk+", CPType.Unary); String2CPInstructionType.put( "ucum*" , CPType.Unary); + String2CPInstructionType.put( "ucumk+*" , CPType.Unary); String2CPInstructionType.put( "ucummin", CPType.Unary); String2CPInstructionType.put( "ucummax", CPType.Unary); String2CPInstructionType.put( "stop" , CPType.Unary); http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java index 588b8e6..17c2632 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java @@ -453,10 +453,8 @@ public class InstructionUtils return agg; } - public static AggregateUnaryOperator parseBasicCumulativeAggregateUnaryOperator(UnaryOperator uop) - { + public static AggregateUnaryOperator parseBasicCumulativeAggregateUnaryOperator(UnaryOperator uop) { Builtin f = (Builtin)uop.fn; - if( f.getBuiltinCode()==BuiltinCode.CUMSUM ) return 
parseBasicAggregateUnaryOperator("uack+") ; else if( f.getBuiltinCode()==BuiltinCode.CUMPROD ) @@ -465,31 +463,24 @@ public class InstructionUtils return parseBasicAggregateUnaryOperator("uacmin") ; else if( f.getBuiltinCode()==BuiltinCode.CUMMAX ) return parseBasicAggregateUnaryOperator("uacmax" ) ; - + else if( f.getBuiltinCode()==BuiltinCode.CUMSUMPROD ) + return parseBasicAggregateUnaryOperator("uack+*" ) ; throw new RuntimeException("Unsupported cumulative aggregate unary operator: "+f.getBuiltinCode()); } - public static AggregateUnaryOperator parseCumulativeAggregateUnaryOperator(String opcode) - { - AggregateUnaryOperator aggun = null; - if( "ucumack+".equals(opcode) ) { - AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTROW); - aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject()); - } - else if ( "ucumac*".equals(opcode) ) { - AggregateOperator agg = new AggregateOperator(0, Multiply.getMultiplyFnObject(), false, CorrectionLocationType.NONE); - aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject()); - } - else if ( "ucumacmin".equals(opcode) ) { - AggregateOperator agg = new AggregateOperator(0, Builtin.getBuiltinFnObject("min"), false, CorrectionLocationType.NONE); - aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject()); - } - else if ( "ucumacmax".equals(opcode) ) { - AggregateOperator agg = new AggregateOperator(0, Builtin.getBuiltinFnObject("max"), false, CorrectionLocationType.NONE); - aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject()); - } - - return aggun; + public static AggregateUnaryOperator parseCumulativeAggregateUnaryOperator(String opcode) { + AggregateOperator agg = null; + if( "ucumack+".equals(opcode) ) + agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTROW); + else if ( "ucumac*".equals(opcode) ) + agg = new AggregateOperator(1, 
Multiply.getMultiplyFnObject(), false, CorrectionLocationType.NONE); + else if ( "ucumac+*".equals(opcode) ) + agg = new AggregateOperator(0, PlusMultiply.getFnObject(), false, CorrectionLocationType.NONE); + else if ( "ucumacmin".equals(opcode) ) + agg = new AggregateOperator(0, Builtin.getBuiltinFnObject("min"), false, CorrectionLocationType.NONE); + else if ( "ucumacmax".equals(opcode) ) + agg = new AggregateOperator(0, Builtin.getBuiltinFnObject("max"), false, CorrectionLocationType.NONE); + return new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject()); } public static UnaryOperator parseUnaryOperator(String opcode) { http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java index efec463..dd32ce0 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java @@ -297,10 +297,12 @@ public class SPInstructionParser extends InstructionParser //cumsum/cumprod/cummin/cummax String2SPInstructionType.put( "ucumack+" , SPType.CumsumAggregate); String2SPInstructionType.put( "ucumac*" , SPType.CumsumAggregate); + String2SPInstructionType.put( "ucumac+*" , SPType.CumsumAggregate); String2SPInstructionType.put( "ucumacmin" , SPType.CumsumAggregate); String2SPInstructionType.put( "ucumacmax" , SPType.CumsumAggregate); String2SPInstructionType.put( "bcumoffk+" , SPType.CumsumOffset); String2SPInstructionType.put( "bcumoff*" , SPType.CumsumOffset); + String2SPInstructionType.put( "bcumoff+*" , SPType.CumsumOffset); String2SPInstructionType.put( "bcumoffmin", SPType.CumsumOffset); String2SPInstructionType.put( "bcumoffmax", 
SPType.CumsumOffset); @@ -352,7 +354,7 @@ public class SPInstructionParser extends InstructionParser case TSMM: return TsmmSPInstruction.parseInstruction(str); case TSMM2: - return Tsmm2SPInstruction.parseInstruction(str); + return Tsmm2SPInstruction.parseInstruction(str); case PMM: return PmmSPInstruction.parseInstruction(str); case ZIPMM: http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryCPInstruction.java index 8e023ea..9f5d71e 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryCPInstruction.java @@ -62,7 +62,7 @@ public abstract class UnaryCPInstruction extends ComputationCPInstruction { out.split(parts[2]); func = Builtin.getBuiltinFnObject(opcode); - if( Arrays.asList(new String[]{"ucumk+","ucum*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode) ) + if( Arrays.asList(new String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode) ) return new UnaryMatrixCPInstruction(new UnaryOperator(func,Integer.parseInt(parts[3])), in, out, opcode, str); else return new UnaryScalarCPInstruction(null, in, out, opcode, str); http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java index bf99add..74390e1 100644 --- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java @@ -27,6 +27,7 @@ import scala.Tuple2; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.functionobjects.PlusMultiply; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; @@ -45,13 +46,10 @@ public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstructio public static CumulativeAggregateSPInstruction parseInstruction( String str ) { String[] parts = InstructionUtils.getInstructionPartsWithValueType( str ); InstructionUtils.checkNumFields ( parts, 2 ); - String opcode = parts[0]; CPOperand in1 = new CPOperand(parts[1]); CPOperand out = new CPOperand(parts[2]); - AggregateUnaryOperator aggun = InstructionUtils.parseCumulativeAggregateUnaryOperator(opcode); - return new CumulativeAggregateSPInstruction(aggun, in1, out, opcode, str); } @@ -81,13 +79,12 @@ public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstructio { private static final long serialVersionUID = 11324676268945117L; - private AggregateUnaryOperator _op = null; + private AggregateUnaryOperator _op = null; private long _rlen = -1; private int _brlen = -1; private int _bclen = -1; - public RDDCumAggFunction( AggregateUnaryOperator op, long rlen, int brlen, int bclen ) - { + public RDDCumAggFunction( AggregateUnaryOperator op, long rlen, int brlen, int bclen ) { _op = op; _rlen = rlen; _brlen = brlen; @@ -97,7 +94,7 @@ public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstructio @Override public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) 
throws Exception - { + { MatrixIndexes ixIn = arg0._1(); MatrixBlock blkIn = arg0._2(); @@ -105,10 +102,20 @@ public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstructio MatrixBlock blkOut = new MatrixBlock(); //process instruction - OperationsOnMatrixValues.performAggregateUnary( ixIn, blkIn, ixOut, blkOut, - ((AggregateUnaryOperator)_op), _brlen, _bclen); - if( ((AggregateUnaryOperator)_op).aggOp.correctionExists ) - blkOut.dropLastRowsOrColumns(((AggregateUnaryOperator)_op).aggOp.correctionLocation); + AggregateUnaryOperator aop = (AggregateUnaryOperator)_op; + if( aop.aggOp.increOp.fn instanceof PlusMultiply ) { //cumsumprod + aop.indexFn.execute(ixIn, ixOut); + MatrixBlock t1 = blkIn.slice(0, blkIn.getNumRows()-1, 0, 0, new MatrixBlock()); + MatrixBlock t2 = blkIn.slice(0, blkIn.getNumRows()-1, 1, 1, new MatrixBlock()); + blkOut.reset(1, 2); + blkOut.quickSetValue(0, 0, t1.sum()); + blkOut.quickSetValue(0, 1, t2.prod()); + } + else { //general case + OperationsOnMatrixValues.performAggregateUnary( ixIn, blkIn, ixOut, blkOut, aop, _brlen, _bclen); + if( aop.aggOp.correctionExists ) + blkOut.dropLastRowsOrColumns(aop.aggOp.correctionLocation); + } //cumsum expand partial aggregates long rlenOut = (long)Math.ceil((double)_rlen/_brlen); http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeOffsetSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeOffsetSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeOffsetSPInstruction.java index 6e151c6..36f02b1 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeOffsetSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeOffsetSPInstruction.java @@ -33,6 +33,7 @@ import 
org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.functionobjects.Builtin; import org.apache.sysml.runtime.functionobjects.Multiply; import org.apache.sysml.runtime.functionobjects.Plus; +import org.apache.sysml.runtime.functionobjects.PlusMultiply; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -53,13 +54,20 @@ public class CumulativeOffsetSPInstruction extends BinarySPInstruction { if ("bcumoffk+".equals(opcode)) { _bop = new BinaryOperator(Plus.getPlusFnObject()); _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+")); - } else if ("bcumoff*".equals(opcode)) { + } + else if ("bcumoff*".equals(opcode)) { _bop = new BinaryOperator(Multiply.getMultiplyFnObject()); _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucum*")); - } else if ("bcumoffmin".equals(opcode)) { + } + else if ("bcumoff+*".equals(opcode)) { + _bop = new BinaryOperator(PlusMultiply.getFnObject()); + _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+*")); + } + else if ("bcumoffmin".equals(opcode)) { _bop = new BinaryOperator(Builtin.getBuiltinFnObject("min")); _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucummin")); - } else if ("bcumoffmax".equals(opcode)) { + } + else if ("bcumoffmax".equals(opcode)) { _bop = new BinaryOperator(Builtin.getBuiltinFnObject("max")); _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucummax")); } @@ -81,9 +89,10 @@ public class CumulativeOffsetSPInstruction extends BinarySPInstruction { @Override public void processInstruction(ExecutionContext ec) { SparkExecutionContext sec = (SparkExecutionContext)ec; - MatrixCharacteristics mc = sec.getMatrixCharacteristics(input2.getName()); - long rlen = mc.getRows(); - int brlen = mc.getRowsPerBlock(); + MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName()); + MatrixCharacteristics mc2 
= sec.getMatrixCharacteristics(input2.getName()); + long rlen = mc2.getRows(); + int brlen = mc2.getRowsPerBlock(); //get inputs JavaPairRDD<MatrixIndexes,MatrixBlock> inData = sec.getBinaryBlockRDDHandleForVariable( input1.getName() ); @@ -93,12 +102,15 @@ public class CumulativeOffsetSPInstruction extends BinarySPInstruction { inAgg = inAgg.flatMapToPair(new RDDCumSplitFunction(_initValue, rlen, brlen)); //execute cumulative offset (apply cumulative op w/ offsets) - JavaPairRDD<MatrixIndexes,MatrixBlock> out = - inData.join( inAgg ) - .mapValues(new RDDCumOffsetFunction(_uop, _bop)); + JavaPairRDD<MatrixIndexes,MatrixBlock> out = inData + .join( inAgg ).mapValues(new RDDCumOffsetFunction(_uop, _bop)); - updateUnaryOutputMatrixCharacteristics(sec); //put output handle in symbol table + if( _bop.fn instanceof PlusMultiply ) + sec.getMatrixCharacteristics(output.getName()) + .set(mc1.getRows(), 1, mc1.getRowsPerBlock(), mc1.getColsPerBlock()); + else //general case + updateUnaryOutputMatrixCharacteristics(sec); sec.setRDDHandleForVariable(output.getName(), out); sec.addLineageRDD(output.getName(), input1.getName()); sec.addLineageRDD(output.getName(), input2.getName()); @@ -149,7 +161,7 @@ public class CumulativeOffsetSPInstruction extends BinarySPInstruction { { MatrixIndexes tmpix = new MatrixIndexes(rixOffset+i+2, ixIn.getColumnIndex()); MatrixBlock tmpblk = new MatrixBlock(1, blkIn.getNumColumns(), blkIn.isInSparseFormat()); - blkIn.slice(i, i, 0, blkIn.getNumColumns()-1, tmpblk); + blkIn.slice(i, i, 0, blkIn.getNumColumns()-1, tmpblk); ret.add(new Tuple2<>(tmpix, tmpblk)); } @@ -164,27 +176,34 @@ public class CumulativeOffsetSPInstruction extends BinarySPInstruction { private UnaryOperator _uop = null; private BinaryOperator _bop = null; - public RDDCumOffsetFunction(UnaryOperator uop, BinaryOperator bop) - { + public RDDCumOffsetFunction(UnaryOperator uop, BinaryOperator bop) { _uop = uop; _bop = bop; } @Override - public MatrixBlock call(Tuple2<MatrixBlock, 
MatrixBlock> arg0) - throws Exception - { + public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> arg0) throws Exception { //prepare inputs and outputs MatrixBlock dblkIn = arg0._1(); //original data MatrixBlock oblkIn = arg0._2(); //offset row vector - MatrixBlock blkOut = new MatrixBlock(dblkIn.getNumRows(), dblkIn.getNumColumns(), dblkIn.isInSparseFormat()); + MatrixBlock data2 = new MatrixBlock(dblkIn); //cp data + boolean cumsumprod = _bop.fn instanceof PlusMultiply; //blockwise offset aggregation and prefix sum computation - MatrixBlock data2 = new MatrixBlock(dblkIn); //cp data - MatrixBlock fdata2 = data2.slice(0, 0); - fdata2.binaryOperationsInPlace(_bop, oblkIn); //sum offset to first row - data2.copy(0, 0, 0, data2.getNumColumns()-1, fdata2, true); //0-based - data2.unaryOperations(_uop, blkOut); //compute columnwise prefix sums/prod/min/max + if( cumsumprod ) { + data2.quickSetValue(0, 0, data2.quickGetValue(0, 0) + + data2.quickGetValue(0, 1) * oblkIn.quickGetValue(0, 0)); + } + else { + MatrixBlock fdata2 = data2.slice(0, 0); + fdata2.binaryOperationsInPlace(_bop, oblkIn); //sum offset to first row + data2.copy(0, 0, 0, data2.getNumColumns()-1, fdata2, true); //0-based + } + + //compute columnwise prefix sums/prod/min/max + MatrixBlock blkOut = new MatrixBlock(dblkIn.getNumRows(), + cumsumprod ? 
1 : dblkIn.getNumColumns(), dblkIn.isInSparseFormat()); + data2.unaryOperations(_uop, blkOut); return blkOut; } http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java index 174f2a5..5ae7eda 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java @@ -99,6 +99,7 @@ public class LibMatrixAgg CUM_MIN, CUM_MAX, CUM_PROD, + CUM_SUM_PROD, MIN, MAX, MEAN, @@ -625,23 +626,19 @@ public class LibMatrixAgg return AggType.INVALID; } - private static AggType getAggType( UnaryOperator op ) - { + private static AggType getAggType( UnaryOperator op ) { ValueFunction vfn = op.fn; - - //cumsum/cumprod/cummin/cummax if( vfn instanceof Builtin ) { BuiltinCode bfunc = ((Builtin) vfn).bFunc; - switch( bfunc ) - { - case CUMSUM: return AggType.CUM_KAHAN_SUM; - case CUMPROD: return AggType.CUM_PROD; - case CUMMIN: return AggType.CUM_MIN; - case CUMMAX: return AggType.CUM_MAX; - default: return AggType.INVALID; + switch( bfunc ) { + case CUMSUM: return AggType.CUM_KAHAN_SUM; + case CUMPROD: return AggType.CUM_PROD; + case CUMMIN: return AggType.CUM_MIN; + case CUMMAX: return AggType.CUM_MAX; + case CUMSUMPROD: return AggType.CUM_SUM_PROD; + default: return AggType.INVALID; } } - return AggType.INVALID; } @@ -1483,6 +1480,12 @@ public class LibMatrixAgg d_ucumkp(da, agg, dc, n, kbuff, kplus, rl, ru); break; } + case CUM_SUM_PROD: { //CUMSUMPROD + if( n != 2 ) + throw new DMLRuntimeException("Cumsumprod expects two-column input (n="+n+")."); + d_ucumkpp(da, agg, dc, rl, ru); + break; + } case CUM_PROD: { //CUMPROD d_ucumm(a, agg, c, n, rl, ru); break; @@ -1762,6 +1765,29 @@ public class LibMatrixAgg 
} /** + * CUMSUMPROD, opcode: ucumk+*, dense input. + * + * Computes c[i] = a[i,0]+a[i,1]*c[i-1] for rows rl (inclusive) to ru (exclusive), + * reading a and writing c via valuesAt(0) (presumably the first physical block; confirm for multi-block inputs). + * + * @param a dense two-column input, read at offsets 2*i and 2*i+1 + * @param agg carry-in used as the value preceding row rl (agg[0]); null means 0 + * @param c dense output, one value written per row at offset i + * @param rl row lower index + * @param ru row upper index + */ + private static void d_ucumkpp( DenseBlock a, double[] agg, DenseBlock c, int rl, int ru ) { + //init running value from carry-in agg[0], or neutral 0 if absent + double sum = (agg != null) ? agg[0] : 0; + //scan once and compute c[i] = a[i,0]+a[i,1]*c[i-1] + double[] avals = a.valuesAt(0); + double[] cvals = c.valuesAt(0); + for( int i=rl, ix=rl*2; i<ru; i++, ix+=2 ) { + sum = cvals[i] = avals[ix] + avals[ix+1] * sum; + } + } + + /** * CUMPROD, opcode: ucum*, dense input. * * @param a ? @@ -2918,7 +2944,7 @@ public class LibMatrixAgg corr[pos1+i] = kbuff._correction; } } - + private static void sumAgg(double[] a, DenseBlock c, int[] aix, int ai, final int len, final int n, KahanObject kbuff, KahanFunction kplus) { //note: output might span multiple physical blocks double[] sum = c.values(0); @@ -2933,7 +2959,7 @@ public class LibMatrixAgg corr[pos1+ix] = kbuff._correction; } } - + private static double product( double[] a, int ai, final int len ) { double val = 1; if( NAN_AWARENESS ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index 5c50326..8f0821c 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ -794,6 +794,18 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab } /** + * Wrapper method for reduceall-product of a matrix. + * + * @return reduceall-product, read from cell (0,0) of the 1x1 ua* aggregate result
+ */ + public double prod() { + MatrixBlock out = new MatrixBlock(1, 1, false); + LibMatrixAgg.aggregateUnaryMatrix(this, out, + InstructionUtils.parseBasicAggregateUnaryOperator("ua*", 1)); + return out.quickGetValue(0, 0); + } + + /** * Wrapper method for reduceall-mean of a matrix. * * @return ? @@ -2612,14 +2624,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab boolean sp = this.sparse && op.sparseSafe; //allocate output + int n = Builtin.isBuiltinCode(op.fn, BuiltinCode.CUMSUMPROD) ? 1 : clen; if( ret == null ) - ret = new MatrixBlock(rlen, clen, sp, this.nonZeros); + ret = new MatrixBlock(rlen, n, sp, sp ? nonZeros : rlen*n); else - ret.reset(rlen, clen, sp); + ret.reset(rlen, n, sp); //core execute if( LibMatrixAgg.isSupportedUnaryOperator(op) ) { - //e.g., cumsum/cumprod/cummin/cumax + //e.g., cumsum/cumprod/cummin/cummax/cumsumprod if( op.getNumThreads() > 1 ) LibMatrixAgg.cumaggregateUnaryMatrix(this, ret, op, op.getNumThreads()); else http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java new file mode 100644 index 0000000..7f02055 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.unary.matrix; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; + +public class FullCumsumprodTest extends AutomatedTestBase +{ + private final static String TEST_NAME = "Cumsumprod"; + private final static String TEST_DIR = "functions/unary/matrix/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FullCumsumprodTest.class.getSimpleName() + "/"; + + private final static int rows = 1201; + private final static double spDense = 1.0; + private final static double spSparse = 0.3; + + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"})); + } + + @Test + public void testCumsumprodForwardDenseCP() { + runCumsumprodTest(false, false, ExecType.CP); + } + + @Test + public void testCumsumprodForwardSparseCP() { + runCumsumprodTest(false, true, ExecType.CP); + } + + @Test + public void testCumsumprodBackwardDenseCP() { + runCumsumprodTest(true, false, ExecType.CP); + } + + @Test + public void testCumsumprodBackwardSparseCP() { + runCumsumprodTest(true, true, ExecType.CP); + } + + @Test + public void testCumsumprodForwardDenseSP() { + runCumsumprodTest(false, false, 
ExecType.SPARK); + } + + @Test + public void testCumsumprodForwardSparseSP() { + runCumsumprodTest(false, true, ExecType.SPARK); + } + + @Test + public void testCumsumprodBackwardDenseSP() { + runCumsumprodTest(true, false, ExecType.SPARK); + } + + @Test + public void testCumsumprodBackwardSparseSP() { + runCumsumprodTest(true, true, ExecType.SPARK); + } + + private void runCumsumprodTest(boolean reverse, boolean sparse, ExecType instType) + { + RUNTIME_PLATFORM platformOld = rtplatform; + switch( instType ){ + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK || rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + try + { + double sparsity = sparse ? spSparse : spDense; + + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + // This is for running the junit test the new way, i.e., construct the arguments directly + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{"-explain", "-stats", "-args", input("A"), input("B"), + String.valueOf(reverse).toUpperCase(), output("C") }; + + double[][] A = getRandomMatrix(rows, 1, -10, 10, sparsity, 3); + double[][] B = getRandomMatrix(rows, 1, -1, 1, 0.1, 7); + //FIXME switch to the denser B below once general inputs pass (presumably blocked by the Spark correctness issues noted in the commit message): double[][] B = getRandomMatrix(rows, 1, -1, 1, 0.9, 7); + writeInputMatrixWithMTD("A", A, false); + writeInputMatrixWithMTD("B", B, false); + + runTest(true, false, null, -1); + + Assert.assertEquals(new Double(rows), + readDMLMatrixFromHDFS("C").get(new CellIndex(1,1))); + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/e90af572/src/test/scripts/functions/unary/matrix/Cumsumprod.dml
---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/unary/matrix/Cumsumprod.dml b/src/test/scripts/functions/unary/matrix/Cumsumprod.dml new file mode 100644 index 0000000..3089451 --- /dev/null +++ b/src/test/scripts/functions/unary/matrix/Cumsumprod.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +cumSumProd = externalFunction (Matrix[double] X, Matrix[double] C, double start, boolean isReverse) + return (Matrix[double] Y) implemented in (classname = "org.apache.sysml.udf.lib.CumSumProd", exectype = "mem"); + + +A = read($1); +B = read($2); + +# old external function +C1 = cumSumProd(A, B, 0, $3); + +# new builtin function +AB = cbind(A,B); +C2 = ifelse($3, rev(cumsumprod(rev(AB))), cumsumprod(AB)); + +C = as.matrix(sum(abs(C1-C2)<=1e-8)); +write(C, $4);