Repository: systemml Updated Branches: refs/heads/master 847e5bcab -> c33e066ac
[SYSTEMML-2229] Add missing support for rowProds/colProds aggregates So far, we only supported full aggregates for products, but not the row/column-wise aggregates that are available for all other unary aggregates. This patch adds the missing compiler and runtime support for all backends. As before, these operations are not NaN-aware because they exit early as soon as the temporary output contains a zero, but the runtime is configurable in that regard. Furthermore, this patch also includes an improvement of the dense block abstraction to make the API more fluent for initialization. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c33e066a Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c33e066a Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c33e066a Branch: refs/heads/master Commit: c33e066ac1537c32cafdcb5275b7d36052b8e311 Parents: 847e5bc Author: Matthias Boehm <[email protected]> Authored: Mon Apr 2 23:44:01 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Apr 2 23:44:01 2018 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/lops/PartialAggregate.java | 8 +- .../sysml/parser/BuiltinFunctionExpression.java | 6 + .../org/apache/sysml/parser/DMLTranslator.java | 12 +- .../org/apache/sysml/parser/Expression.java | 2 + .../instructions/CPInstructionParser.java | 2 + .../runtime/instructions/InstructionUtils.java | 18 +- .../instructions/MRInstructionParser.java | 2 + .../instructions/SPInstructionParser.java | 2 + .../sysml/runtime/matrix/data/DenseBlock.java | 26 ++- .../runtime/matrix/data/DenseBlockDRB.java | 18 +- .../runtime/matrix/data/DenseBlockLDRB.java | 18 +- .../sysml/runtime/matrix/data/LibMatrixAgg.java | 82 ++++++++- .../runtime/matrix/data/LibMatrixMult.java | 18 ++ .../aggregate/RowColProdsAggregateTest.java | 165 +++++++++++++++++++ src/test/scripts/functions/aggregate/ColProds.R | 34 ++++ 
.../scripts/functions/aggregate/ColProds.dml | 24 +++ src/test/scripts/functions/aggregate/RowProds.R | 34 ++++ .../scripts/functions/aggregate/RowProds.dml | 24 +++ .../functions/aggregate/ZPackageSuite.java | 1 + 19 files changed, 465 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/lops/PartialAggregate.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/PartialAggregate.java b/src/main/java/org/apache/sysml/lops/PartialAggregate.java index 5afa805..8358a1d 100644 --- a/src/main/java/org/apache/sysml/lops/PartialAggregate.java +++ b/src/main/java/org/apache/sysml/lops/PartialAggregate.java @@ -370,9 +370,11 @@ public class PartialAggregate extends Lop } case Product: { - if( dir == DirectionTypes.RowCol ) - return "ua*"; - break; + switch( dir ) { + case RowCol: return "ua*"; + case Row: return "uar*"; + case Col: return "uac*"; + } } case Max: { http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index d7d3a4c..e1ac9da 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -389,6 +389,7 @@ public class BuiltinFunctionExpression extends DataIdentifier case COLMAX: case COLMIN: case COLMEAN: + case COLPROD: case COLSD: case COLVAR: // colSums(X); @@ -405,6 +406,7 @@ public class BuiltinFunctionExpression extends DataIdentifier case ROWMIN: case ROWINDEXMIN: case ROWMEAN: + case ROWPROD: case ROWSD: case ROWVAR: //rowSums(X); @@ -1698,6 +1700,10 @@ 
public class BuiltinFunctionExpression extends DataIdentifier bifop = Expression.BuiltinFunctionOp.ROWVAR; else if (functionName.equals("colVars")) bifop = Expression.BuiltinFunctionOp.COLVAR; + else if (functionName.equals("rowProds")) + bifop = Expression.BuiltinFunctionOp.ROWPROD; + else if (functionName.equals("colProds")) + bifop = Expression.BuiltinFunctionOp.COLPROD; else if (functionName.equals("cummax")) bifop = Expression.BuiltinFunctionOp.CUMMAX; else if (functionName.equals("cummin")) http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 3250883..9b5f02e 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -2351,7 +2351,12 @@ public class DMLTranslator currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, Direction.Col, expr); break; - + + case COLPROD: + currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.PROD, + Direction.Col, expr); + break; + case COLSD: // colStdDevs = sqrt(colVariances) currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), @@ -2395,6 +2400,11 @@ public class DMLTranslator Direction.Row, expr); break; + case ROWPROD: + currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.PROD, + Direction.Row, expr); + break; + case ROWSD: // rowStdDevs = sqrt(rowVariances) currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index e20a908..0381f2d 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -75,6 +75,7 @@ public abstract class Expression implements ParseInfo COLMAX, COLMEAN, COLMIN, + COLPROD, COLSD, COLSUM, COLVAR, @@ -121,6 +122,7 @@ public abstract class Expression implements ParseInfo ROWMAX, ROWMEAN, ROWMIN, + ROWPROD, ROWSD, ROWSUM, ROWVAR, http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index 00ad286..9dbb24e 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -98,6 +98,8 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "uar+" , CPType.AggregateUnary); String2CPInstructionType.put( "uac+" , CPType.AggregateUnary); String2CPInstructionType.put( "ua*" , CPType.AggregateUnary); + String2CPInstructionType.put( "uar*" , CPType.AggregateUnary); + String2CPInstructionType.put( "uac*" , CPType.AggregateUnary); String2CPInstructionType.put( "uatrace" , CPType.AggregateUnary); String2CPInstructionType.put( "uaktrace", CPType.AggregateUnary); String2CPInstructionType.put( "nrow" ,CPType.AggregateUnary); http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java index 294827d..3b531dd 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java @@ -260,14 +260,12 @@ public class InstructionUtils if ( opcode.equalsIgnoreCase("uak+") ) { AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN); aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); - } - else if ( opcode.equalsIgnoreCase("uark+") ) { - // RowSums + } + else if ( opcode.equalsIgnoreCase("uark+") ) { // RowSums AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN); aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); } - else if ( opcode.equalsIgnoreCase("uack+") ) { - // ColSums + else if ( opcode.equalsIgnoreCase("uack+") ) { // ColSums AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTROW); aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); } @@ -339,6 +337,14 @@ public class InstructionUtils AggregateOperator agg = new AggregateOperator(1, Multiply.getMultiplyFnObject()); aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); } + else if ( opcode.equalsIgnoreCase("uar*") ) { + AggregateOperator agg = new AggregateOperator(1, Multiply.getMultiplyFnObject()); + aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); + } + else if ( opcode.equalsIgnoreCase("uac*") ) { + AggregateOperator agg = new AggregateOperator(1, Multiply.getMultiplyFnObject()); + aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); + } else if ( opcode.equalsIgnoreCase("uamax") ) { AggregateOperator agg = new 
AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject("max")); aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); @@ -808,7 +814,7 @@ public class InstructionUtils return "avar"; else if ( opcode.equalsIgnoreCase("ua+") || opcode.equalsIgnoreCase("uar+") || opcode.equalsIgnoreCase("uac+") ) return "a+"; - else if ( opcode.equalsIgnoreCase("ua*") ) + else if ( opcode.equalsIgnoreCase("ua*") || opcode.equalsIgnoreCase("uar*") || opcode.equalsIgnoreCase("uac*") ) return "a*"; else if ( opcode.equalsIgnoreCase("uatrace") || opcode.equalsIgnoreCase("uaktrace") ) return "aktrace"; http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java index e587fc6..eb16464 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java @@ -122,6 +122,8 @@ public class MRInstructionParser extends InstructionParser String2MRInstructionType.put( "uarvar", MRType.AggregateUnary); String2MRInstructionType.put( "uacvar", MRType.AggregateUnary); String2MRInstructionType.put( "ua*" , MRType.AggregateUnary); + String2MRInstructionType.put( "uar*" , MRType.AggregateUnary); + String2MRInstructionType.put( "uac*" , MRType.AggregateUnary); String2MRInstructionType.put( "uamax" , MRType.AggregateUnary); String2MRInstructionType.put( "uamin" , MRType.AggregateUnary); String2MRInstructionType.put( "uatrace" , MRType.AggregateUnary); http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java index 799a77a..5a201cb 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java @@ -115,6 +115,8 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( "uar+" , SPType.AggregateUnary); String2SPInstructionType.put( "uac+" , SPType.AggregateUnary); String2SPInstructionType.put( "ua*" , SPType.AggregateUnary); + String2SPInstructionType.put( "uar*" , SPType.AggregateUnary); + String2SPInstructionType.put( "uac*" , SPType.AggregateUnary); String2SPInstructionType.put( "uatrace" , SPType.AggregateUnary); String2SPInstructionType.put( "uaktrace", SPType.AggregateUnary); http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlock.java index 50beb3c..2489f16 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlock.java @@ -225,8 +225,9 @@ public abstract class DenseBlock implements Serializable * Set the given value for the entire dense block (fill). 
* * @param v value + * @return self */ - public abstract void set(double v); + public abstract DenseBlock set(double v); /** * Set the given value for an entire index range of the @@ -237,8 +238,9 @@ public abstract class DenseBlock implements Serializable * @param cl column lower index * @param cu column upper index (exclusive) * @param v value + * @return self */ - public abstract void set(int rl, int ru, int cl, int cu, double v); + public abstract DenseBlock set(int rl, int ru, int cl, int cu, double v); /** @@ -247,23 +249,26 @@ public abstract class DenseBlock implements Serializable * @param r row index * @param c column index * @param v value + * @return self */ - public abstract void set(int r, int c, double v); + public abstract DenseBlock set(int r, int c, double v); /** * Copy the given vector into the given row. * * @param r row index * @param v value vector + * @return self */ - public abstract void set(int r, double[] v); + public abstract DenseBlock set(int r, double[] v); /** * Copy the given dense block. * * @param db dense block + * @return self */ - public abstract void set(DenseBlock db); + public abstract DenseBlock set(DenseBlock db); /** * Copy the given dense block into the specified @@ -274,18 +279,21 @@ public abstract class DenseBlock implements Serializable * @param cl column lower index * @param cu column upper index (exclusive) * @param db dense block + * @return self */ - public abstract void set(int rl, int ru, int cl, int cu, DenseBlock db); + public abstract DenseBlock set(int rl, int ru, int cl, int cu, DenseBlock db); /** * Copy the given kahan object sum and correction. 
* * @param kbuff kahan object + * @return self */ - public void set(KahanObject kbuff) { + public DenseBlock set(KahanObject kbuff) { set(0, 0, kbuff._sum); set(0, 1, kbuff._correction); + return this; } /** @@ -294,10 +302,12 @@ public abstract class DenseBlock implements Serializable * * @param r row index * @param kbuff kahan object + * @return self */ - public void set(int r, KahanObject kbuff) { + public DenseBlock set(int r, KahanObject kbuff) { set(r, 0, kbuff._sum); set(r, 1, kbuff._correction); + return this; } /** http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockDRB.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockDRB.java b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockDRB.java index 7f2ddb0..979f233 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockDRB.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockDRB.java @@ -171,31 +171,35 @@ public class DenseBlockDRB extends DenseBlock } @Override - public void set(double v) { + public DenseBlock set(double v) { Arrays.fill(data, 0, rlen*clen, v); + return this; } @Override - public void set(int rl, int ru, int cl, int cu, double v) { + public DenseBlock set(int rl, int ru, int cl, int cu, double v) { if( cl==0 && cu == clen ) Arrays.fill(data, rl*clen, ru*clen, v); else for(int i=rl, ix=rl*clen; i<ru; i++, ix+=clen) Arrays.fill(data, ix+cl, ix+cu, v); + return this; } @Override - public void set(int r, int c, double v) { + public DenseBlock set(int r, int c, double v) { data[pos(r, c)] = v; + return this; } @Override - public void set(DenseBlock db) { + public DenseBlock set(DenseBlock db) { System.arraycopy(db.valuesAt(0), 0, data, 0, rlen*clen); + return this; } @Override - public void set(int rl, int ru, int cl, int cu, DenseBlock db) { + public DenseBlock set(int rl, int ru, 
int cl, int cu, DenseBlock db) { double[] a = db.valuesAt(0); if( cl == 0 && cu == clen) System.arraycopy(a, 0, data, rl*clen+cl, (int)db.size()); @@ -204,11 +208,13 @@ public class DenseBlockDRB extends DenseBlock for(int i=rl, ix1=0, ix2=rl*clen+cl; i<ru; i++, ix1+=len, ix2+=clen) System.arraycopy(a, ix1, data, ix2, len); } + return this; } @Override - public void set(int r, double[] v) { + public DenseBlock set(int r, double[] v) { System.arraycopy(v, 0, data, pos(r), clen); + return this; } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockLDRB.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockLDRB.java b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockLDRB.java index e041f39..4662c69 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockLDRB.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockLDRB.java @@ -200,13 +200,14 @@ public class DenseBlockLDRB extends DenseBlock } @Override - public void set(double v) { + public DenseBlock set(double v) { for(int i=0; i<numBlocks(); i++) Arrays.fill(data[i], v); + return this; } @Override - public void set(int rl, int ru, int cl, int cu, double v) { + public DenseBlock set(int rl, int ru, int cl, int cu, double v) { boolean rowBlock = (cl == 0 && cu == clen); final int bil = index(rl); final int biu = index(ru-1); @@ -219,30 +220,35 @@ public class DenseBlockLDRB extends DenseBlock for(int i=lpos; i<lpos+len; i+=clen) Arrays.fill(data[bi], i+cl, i+cu, v); } + return this; } @Override - public void set(int r, int c, double v) { + public DenseBlock set(int r, int c, double v) { data[index(r)][pos(r, c)] = v; + return this; } @Override - public void set(int r, double[] v) { + public DenseBlock set(int r, double[] v) { System.arraycopy(v, 0, data[index(r)], pos(r), clen); + return this; } 
@Override - public void set(DenseBlock db) { + public DenseBlock set(DenseBlock db) { for(int bi=0; bi<numBlocks(); bi++) System.arraycopy(db.valuesAt(bi), 0, data[bi], 0, size(bi)); + return this; } @Override - public void set(int rl, int ru, int cl, int cu, DenseBlock db) { + public DenseBlock set(int rl, int ru, int cl, int cu, DenseBlock db) { for(int i=rl; i<ru; i++) { System.arraycopy(db.values(i-rl), db.pos(i-rl), values(i), pos(i, cl), cu-cl); } + return this; } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java index 2d81255..72b90fa 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java @@ -589,7 +589,8 @@ public class LibMatrixAgg } //prod - if( vfn instanceof Multiply && ifn instanceof ReduceAll ) + if( vfn instanceof Multiply + && (ifn instanceof ReduceAll || ifn instanceof ReduceCol || ifn instanceof ReduceRow)) { return AggType.PROD; } @@ -1350,6 +1351,10 @@ public class LibMatrixAgg case PROD: { //PROD if( ixFn instanceof ReduceAll ) // PROD d_uam(a, c, n, rl, ru ); + else if( ixFn instanceof ReduceCol ) //ROWPROD + d_uarm(a, c, n, rl, ru); + else if( ixFn instanceof ReduceRow ) //COLPROD + d_uacm(a, c, n, rl, ru); break; } @@ -1452,6 +1457,10 @@ public class LibMatrixAgg case PROD: { //PROD if( ixFn instanceof ReduceAll ) // PROD s_uam(a, c, n, rl, ru ); + else if( ixFn instanceof ReduceCol ) // ROWPROD + s_uarm(a, c, n, rl, ru ); + else if( ixFn instanceof ReduceRow ) // COLPROD + s_uacm(a, c, n, rl, ru ); break; } @@ -2084,7 +2093,36 @@ public class LibMatrixAgg } c.set(0, 0, tmp); } + + /** + * ROWPROD, opcode: uar*, dense input. 
+ * + * @param a ? + * @param c ? + * @param n ? + * @param rl row lower index + * @param ru row upper index + */ + private static void d_uarm( DenseBlock a, DenseBlock c, int n, int rl, int ru ) { + double[] lc = c.valuesAt(0); + for( int i=rl; i<ru; i++ ) + lc[i] = product(a.values(i), a.pos(i), n); + } + /** + * COLPROD, opcode: uac*, dense input. + * + * @param a ? + * @param c ? + * @param n ? + * @param rl row lower index + * @param ru row upper index + */ + private static void d_uacm( DenseBlock a, DenseBlock c, int n, int rl, int ru ) { + double[] lc = c.set(1).valuesAt(0); //guaranteed single row + for( int i=rl; i<ru; i++ ) + LibMatrixMult.vectMultiplyWrite(a.values(i), lc, lc, a.pos(i), 0, 0, n); + } /** * SUM, opcode: uak+, sparse input. @@ -2808,6 +2846,48 @@ public class LibMatrixAgg c.set(0, 0, ret); } + /** + * ROWPROD, opcode: uar*, sparse input. + * + * @param a ? + * @param c ? + * @param n ? + * @param rl row lower index + * @param ru row upper index + */ + private static void s_uarm( SparseBlock a, DenseBlock c, int n, int rl, int ru ) { + double[] lc = c.valuesAt(0); + for( int i=rl; i<ru; i++ ) { + if( !a.isEmpty(i) ) { + int alen = a.size(i); + double tmp = product(a.values(i), 0, alen); + lc[i] = tmp * ((alen<n) ? 0 : 1); + } + } + } + + /** + * COLPROD, opcode: uac*, sparse input. + * + * @param a ? + * @param c ? + * @param n ? 
+ * @param rl row lower index + * @param ru row upper index + */ + private static void s_uacm( SparseBlock a, DenseBlock c, int n, int rl, int ru ) { + double[] lc = c.set(1).valuesAt(0); + int[] cnt = new int[ n ]; + for( int i=rl; i<ru; i++ ) { + if( a.isEmpty(i) ) continue; + countAgg(a.values(i), cnt, a.indexes(i), a.pos(i), a.size(i)); + LibMatrixMult.vectMultiplyWrite(lc, a.values(i), lc, 0, a.pos(i), 0, a.size(i)); + } + for( int j=0; j<n; j++ ) + if( cnt[j] < ru-rl ) + lc[j] *= 0; + } + //////////////////////////////////////////// // performance-relevant utility functions // http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java index ef273f6..adc955e 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java @@ -3252,6 +3252,24 @@ public class LibMatrixMult c[ ci+7 ] = a[ ai+7 ] * b[ bi+7 ]; } } + + public static void vectMultiplyWrite( final double[] a, double[] b, double[] c, int[] bix, final int ai, final int bi, final int ci, final int len ) { + final int bn = len%8; + //rest, not aligned to 8-blocks + for( int j = bi; j < bi+bn; j++ ) + c[ ci+bix[j] ] = a[ ai+bix[j] ] * b[ j ]; + //unrolled 8-block (for better instruction-level parallelism) + for( int j = bi+bn; j < bi+len; j+=8 ) { + c[ ci+bix[j+0] ] = a[ ai+bix[j+0] ] * b[ j+0 ]; + c[ ci+bix[j+1] ] = a[ ai+bix[j+1] ] * b[ j+1 ]; + c[ ci+bix[j+2] ] = a[ ai+bix[j+2] ] * b[ j+2 ]; + c[ ci+bix[j+3] ] = a[ ai+bix[j+3] ] * b[ j+3 ]; + c[ ci+bix[j+4] ] = a[ ai+bix[j+4] ] * b[ j+4 ]; + c[ ci+bix[j+5] ] = a[ ai+bix[j+5] ] * b[ j+5 ]; + c[ ci+bix[j+6] ] = a[ ai+bix[j+6] ] * b[ j+6 ]; + c[ ci+bix[j+7] ] = a[ 
ai+bix[j+7] ] * b[ j+7 ]; + } + } private static void vectMultiply( double[] a, double[] c, int ai, int ci, final int len ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/RowColProdsAggregateTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/RowColProdsAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/RowColProdsAggregateTest.java new file mode 100644 index 0000000..97ae324 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/RowColProdsAggregateTest.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.aggregate; + +import java.util.HashMap; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +public class RowColProdsAggregateTest extends AutomatedTestBase +{ + private final static String TEST_NAME1 = "RowProds"; + private final static String TEST_NAME2 = "ColProds"; + + private final static String TEST_DIR = "functions/aggregate/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RowColProdsAggregateTest.class.getSimpleName() + "/"; + private final static double eps = 1e-10; + + private final static int dim1 = 1079; + private final static int dim2 = 15; + private final static double sparsity1 = 0.1; + private final static double sparsity2 = 1.0; //otherwise 0 output + + @Override + public void setUp() + { + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"B"})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"B"})); + if (TEST_CACHE_ENABLED) { + setOutAndExpectedDeletionDisabled(true); + } + } + + @BeforeClass + public static void init() { + TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR); + } + + @AfterClass + public static void cleanUp() { + if (TEST_CACHE_ENABLED) { + TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR); + } + } + + @Test + public void testRowProdsDenseMatrixCP() { + runProdsAggregateTest(TEST_NAME1, false, true, ExecType.CP); + } + + @Test + public void testRowProdsSparseMatrixCP() { + 
runProdsAggregateTest(TEST_NAME1, true, true, ExecType.CP); + } + + @Test + public void testRowProdsDenseMatrixSP() { + runProdsAggregateTest(TEST_NAME1, false, true, ExecType.SPARK); + } + + @Test + public void testRowProdsSparseMatrixSP() { + runProdsAggregateTest(TEST_NAME1, true, true, ExecType.SPARK); + } + + @Test + public void testColProdsDenseMatrixCP() { + runProdsAggregateTest(TEST_NAME2, false, true, ExecType.CP); + } + + @Test + public void testColProdsSparseMatrixCP() { + runProdsAggregateTest(TEST_NAME2, true, true, ExecType.CP); + } + + @Test + public void testColProdsDenseMatrixSP() { + runProdsAggregateTest(TEST_NAME2, false, true, ExecType.SPARK); + } + + @Test + public void testColProdsSparseMatrixSP() { + runProdsAggregateTest(TEST_NAME2, true, true, ExecType.SPARK); + } + + private void runProdsAggregateTest(String TEST_NAME, boolean sparse, boolean rewrites, ExecType instType) + { + RUNTIME_PLATFORM platformOld = rtplatform; + switch( instType ){ + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + boolean oldRewritesFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + try + { + double sparsity = (sparse) ? 
sparsity1 : sparsity2; + TestConfiguration config = getTestConfiguration(TEST_NAME); + + String TEST_CACHE_DIR = ""; + if (TEST_CACHE_ENABLED) { + TEST_CACHE_DIR = TEST_NAME + "_" + sparsity + "/"; + } + + loadTestConfiguration(config, TEST_CACHE_DIR); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{"-explain", "-args", input("A"), output("B") }; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir(); + + int rows = TEST_NAME.equals(TEST_NAME1) ? dim1 : dim2; + int cols = TEST_NAME.equals(TEST_NAME1) ? dim2 : dim1; + double[][] A = getRandomMatrix(rows, cols, 0.9, 1, sparsity, 1234); + writeInputMatrixWithMTD("A", A, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("B"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("B"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewritesFlag; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test/scripts/functions/aggregate/ColProds.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/aggregate/ColProds.R b/src/test/scripts/functions/aggregate/ColProds.R new file mode 100644 index 0000000..2729498 --- /dev/null +++ b/src/test/scripts/functions/aggregate/ColProds.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +if(!("matrixStats" %in% rownames(installed.packages()))){ + install.packages("matrixStats") +} + +library("Matrix") +library("matrixStats") + +A <- as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +B <- t(colProds(A)); + +writeMM(as(B, "CsparseMatrix"), paste(args[2], "B", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test/scripts/functions/aggregate/ColProds.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/aggregate/ColProds.dml b/src/test/scripts/functions/aggregate/ColProds.dml new file mode 100644 index 0000000..f2d3b3c --- /dev/null +++ b/src/test/scripts/functions/aggregate/ColProds.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = read($1); +B = colProds(A); +write(B, $2); \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test/scripts/functions/aggregate/RowProds.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/aggregate/RowProds.R b/src/test/scripts/functions/aggregate/RowProds.R new file mode 100644 index 0000000..0cc6689 --- /dev/null +++ b/src/test/scripts/functions/aggregate/RowProds.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +if(!("matrixStats" %in% rownames(installed.packages()))){ + install.packages("matrixStats") +} + +library("Matrix") +library("matrixStats") + +A <- as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +B <- rowProds(A); + +writeMM(as(B, "CsparseMatrix"), paste(args[2], "B", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test/scripts/functions/aggregate/RowProds.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/aggregate/RowProds.dml b/src/test/scripts/functions/aggregate/RowProds.dml new file mode 100644 index 0000000..bb9ceb1 --- /dev/null +++ b/src/test/scripts/functions/aggregate/RowProds.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +A = read($1); +B = rowProds(A); +write(B, $2); \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java index c99fe0a..4aedf4d 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java @@ -43,6 +43,7 @@ import org.junit.runners.Suite; NRowTest.class, ProdTest.class, PushdownSumBinaryTest.class, + RowColProdsAggregateTest.class, RowStdDevsTest.class, RowSumsSqTest.class, RowSumTest.class,
