Repository: systemml
Updated Branches:
  refs/heads/master 847e5bcab -> c33e066ac


[SYSTEMML-2229] Add missing support for rowProds/colProds aggregates

So far, we only supported full aggregates for products but not the
row/column-wise aggregates that exist for all other unary aggregates.
This patch adds the missing compiler and runtime support for all
backends. As before, these operations are not NaN-aware because they
exit early as soon as the temporary output contains a zero, but the
runtime is configurable in that regard.

Furthermore, this patch improves the dense block abstraction by making
all set methods return the block itself, which allows a more fluent
API for initialization (e.g., c.set(1).valuesAt(0)).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c33e066a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c33e066a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c33e066a

Branch: refs/heads/master
Commit: c33e066ac1537c32cafdcb5275b7d36052b8e311
Parents: 847e5bc
Author: Matthias Boehm <[email protected]>
Authored: Mon Apr 2 23:44:01 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Apr 2 23:44:01 2018 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/lops/PartialAggregate.java |   8 +-
 .../sysml/parser/BuiltinFunctionExpression.java |   6 +
 .../org/apache/sysml/parser/DMLTranslator.java  |  12 +-
 .../org/apache/sysml/parser/Expression.java     |   2 +
 .../instructions/CPInstructionParser.java       |   2 +
 .../runtime/instructions/InstructionUtils.java  |  18 +-
 .../instructions/MRInstructionParser.java       |   2 +
 .../instructions/SPInstructionParser.java       |   2 +
 .../sysml/runtime/matrix/data/DenseBlock.java   |  26 ++-
 .../runtime/matrix/data/DenseBlockDRB.java      |  18 +-
 .../runtime/matrix/data/DenseBlockLDRB.java     |  18 +-
 .../sysml/runtime/matrix/data/LibMatrixAgg.java |  82 ++++++++-
 .../runtime/matrix/data/LibMatrixMult.java      |  18 ++
 .../aggregate/RowColProdsAggregateTest.java     | 165 +++++++++++++++++++
 src/test/scripts/functions/aggregate/ColProds.R |  34 ++++
 .../scripts/functions/aggregate/ColProds.dml    |  24 +++
 src/test/scripts/functions/aggregate/RowProds.R |  34 ++++
 .../scripts/functions/aggregate/RowProds.dml    |  24 +++
 .../functions/aggregate/ZPackageSuite.java      |   1 +
 19 files changed, 465 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/lops/PartialAggregate.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/PartialAggregate.java 
b/src/main/java/org/apache/sysml/lops/PartialAggregate.java
index 5afa805..8358a1d 100644
--- a/src/main/java/org/apache/sysml/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysml/lops/PartialAggregate.java
@@ -370,9 +370,12 @@ public class PartialAggregate extends Lop
                        }
 
                        case Product: {
-                               if( dir == DirectionTypes.RowCol )
-                                       return "ua*";
+                               switch( dir ) {
+                                       case RowCol: return "ua*";
+                                       case Row:    return "uar*";
+                                       case Col:    return "uac*";
+                               }
                                break;
                        }
                        
                        case Max: {

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index d7d3a4c..e1ac9da 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -389,6 +389,7 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                case COLMAX:
                case COLMIN:
                case COLMEAN:
+               case COLPROD:
                case COLSD:
                case COLVAR:
                        // colSums(X);
@@ -405,6 +406,7 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                case ROWMIN:
                case ROWINDEXMIN:
                case ROWMEAN:
+               case ROWPROD:
                case ROWSD:
                case ROWVAR:
                        //rowSums(X);
@@ -1698,6 +1700,10 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        bifop = Expression.BuiltinFunctionOp.ROWVAR;
                else if (functionName.equals("colVars"))
                        bifop = Expression.BuiltinFunctionOp.COLVAR;
+               else if (functionName.equals("rowProds"))
+                       bifop = Expression.BuiltinFunctionOp.ROWPROD;
+               else if (functionName.equals("colProds"))
+                       bifop = Expression.BuiltinFunctionOp.COLPROD;
                else if (functionName.equals("cummax"))
                         bifop = Expression.BuiltinFunctionOp.CUMMAX;
                else if (functionName.equals("cummin"))

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 3250883..9b5f02e 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2351,7 +2351,12 @@ public class DMLTranslator
                        currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(), target.getValueType(), AggOp.MEAN,
                                        Direction.Col, expr);
                        break;
-
+               
+               case COLPROD:
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(), target.getValueType(), AggOp.PROD,
+                                       Direction.Col, expr);
+                       break;
+               
                case COLSD:
                        // colStdDevs = sqrt(colVariances)
                        currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(),
@@ -2395,6 +2400,11 @@ public class DMLTranslator
                                        Direction.Row, expr);
                        break;
 
+               case ROWPROD:
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(), target.getValueType(), AggOp.PROD,
+                                       Direction.Row, expr);
+                       break;
+
                case ROWSD:
                        // rowStdDevs = sqrt(rowVariances)
                        currBuiltinOp = new AggUnaryOp(target.getName(), 
target.getDataType(),

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java 
b/src/main/java/org/apache/sysml/parser/Expression.java
index e20a908..0381f2d 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -75,6 +75,7 @@ public abstract class Expression implements ParseInfo
                COLMAX,
                COLMEAN,
                COLMIN,
+               COLPROD,
                COLSD,
                COLSUM,
                COLVAR,
@@ -121,6 +122,7 @@ public abstract class Expression implements ParseInfo
                ROWMAX,
                ROWMEAN, 
                ROWMIN,
+               ROWPROD,
                ROWSD,
                ROWSUM,
                ROWVAR,

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index 00ad286..9dbb24e 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -98,6 +98,8 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "uar+"    , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uac+"    , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "ua*"     , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "uar*"    , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "uac*"    , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uatrace" , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uaktrace", 
CPType.AggregateUnary);
                String2CPInstructionType.put( "nrow"    ,CPType.AggregateUnary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
index 294827d..3b531dd 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
@@ -260,14 +260,12 @@ public class InstructionUtils
                if ( opcode.equalsIgnoreCase("uak+") ) {
                        AggregateOperator agg = new AggregateOperator(0, 
KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
                        aggun = new AggregateUnaryOperator(agg, 
ReduceAll.getReduceAllFnObject(), numThreads);
-               }               
-               else if ( opcode.equalsIgnoreCase("uark+") ) {
-                       // RowSums
+               }
+               else if ( opcode.equalsIgnoreCase("uark+") ) { // RowSums
                        AggregateOperator agg = new AggregateOperator(0, 
KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
                        aggun = new AggregateUnaryOperator(agg, 
ReduceCol.getReduceColFnObject(), numThreads);
                } 
-               else if ( opcode.equalsIgnoreCase("uack+") ) {
-                       // ColSums
+               else if ( opcode.equalsIgnoreCase("uack+") ) { // ColSums
                        AggregateOperator agg = new AggregateOperator(0, 
KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTROW);
                        aggun = new AggregateUnaryOperator(agg, 
ReduceRow.getReduceRowFnObject(), numThreads);
                }
@@ -339,6 +337,14 @@ public class InstructionUtils
                        AggregateOperator agg = new AggregateOperator(1, 
Multiply.getMultiplyFnObject());
                        aggun = new AggregateUnaryOperator(agg, 
ReduceAll.getReduceAllFnObject(), numThreads);
                } 
+               else if ( opcode.equalsIgnoreCase("uar*") ) {
+                       AggregateOperator agg = new AggregateOperator(1, 
Multiply.getMultiplyFnObject());
+                       aggun = new AggregateUnaryOperator(agg, 
ReduceCol.getReduceColFnObject(), numThreads);
+               } 
+               else if ( opcode.equalsIgnoreCase("uac*") ) {
+                       AggregateOperator agg = new AggregateOperator(1, 
Multiply.getMultiplyFnObject());
+                       aggun = new AggregateUnaryOperator(agg, 
ReduceRow.getReduceRowFnObject(), numThreads);
+               }
                else if ( opcode.equalsIgnoreCase("uamax") ) {
                        AggregateOperator agg = new 
AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject("max"));
                        aggun = new AggregateUnaryOperator(agg, 
ReduceAll.getReduceAllFnObject(), numThreads);
@@ -808,7 +814,7 @@ public class InstructionUtils
                        return "avar";
                else if ( opcode.equalsIgnoreCase("ua+") || 
opcode.equalsIgnoreCase("uar+") || opcode.equalsIgnoreCase("uac+") )
                        return "a+";
-               else if ( opcode.equalsIgnoreCase("ua*") )
+               else if ( opcode.equalsIgnoreCase("ua*") || 
opcode.equalsIgnoreCase("uar*") || opcode.equalsIgnoreCase("uac*") )
                        return "a*";
                else if ( opcode.equalsIgnoreCase("uatrace") || 
opcode.equalsIgnoreCase("uaktrace") ) 
                        return "aktrace";

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
index e587fc6..eb16464 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
@@ -122,6 +122,8 @@ public class MRInstructionParser extends InstructionParser
                String2MRInstructionType.put( "uarvar", MRType.AggregateUnary);
                String2MRInstructionType.put( "uacvar", MRType.AggregateUnary);
                String2MRInstructionType.put( "ua*"   , MRType.AggregateUnary);
+               String2MRInstructionType.put( "uar*"  , MRType.AggregateUnary);
+               String2MRInstructionType.put( "uac*"  , MRType.AggregateUnary);
                String2MRInstructionType.put( "uamax" , MRType.AggregateUnary);
                String2MRInstructionType.put( "uamin" , MRType.AggregateUnary);
                String2MRInstructionType.put( "uatrace" , 
MRType.AggregateUnary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index 799a77a..5a201cb 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -115,6 +115,8 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "uar+"    , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uac+"    , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "ua*"     , 
SPType.AggregateUnary);
+               String2SPInstructionType.put( "uar*"    , 
SPType.AggregateUnary);
+               String2SPInstructionType.put( "uac*"    , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uatrace" , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uaktrace", 
SPType.AggregateUnary);
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlock.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlock.java
index 50beb3c..2489f16 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlock.java
@@ -225,8 +225,9 @@ public abstract class DenseBlock implements Serializable
         * Set the given value for the entire dense block (fill).
         * 
         * @param v value
+        * @return self
         */
-       public abstract void set(double v);
+       public abstract DenseBlock set(double v);
        
        /**
         * Set the given value for an entire index range of the 
@@ -237,8 +238,9 @@ public abstract class DenseBlock implements Serializable
         * @param cl column lower index 
         * @param cu column upper index (exclusive)
         * @param v value
+        * @return self
         */
-       public abstract void set(int rl, int ru, int cl, int cu, double v);
+       public abstract DenseBlock set(int rl, int ru, int cl, int cu, double 
v);
        
        
        /**
@@ -247,23 +249,26 @@ public abstract class DenseBlock implements Serializable
         * @param r row index
         * @param c column index
         * @param v value
+        * @return self
         */
-       public abstract void set(int r, int c, double v);
+       public abstract DenseBlock set(int r, int c, double v);
        
        /**
         * Copy the given vector into the given row.
         * 
         * @param r row index
         * @param v value vector
+        * @return self
         */
-       public abstract void set(int r, double[] v);
+       public abstract DenseBlock set(int r, double[] v);
        
        /**
         * Copy the given dense block.
         * 
         * @param db dense block
+        * @return self
         */
-       public abstract void set(DenseBlock db);
+       public abstract DenseBlock set(DenseBlock db);
        
        /**
         * Copy the given dense block into the specified
@@ -274,18 +279,21 @@ public abstract class DenseBlock implements Serializable
         * @param cl column lower index 
         * @param cu column upper index (exclusive)
         * @param db dense block
+        * @return self
         */
-       public abstract void set(int rl, int ru, int cl, int cu, DenseBlock db);
+       public abstract DenseBlock set(int rl, int ru, int cl, int cu, 
DenseBlock db);
        
        
        /**
         * Copy the given kahan object sum and correction.
         * 
         * @param kbuff kahan object
+        * @return self
         */
-       public void set(KahanObject kbuff) {
+       public DenseBlock set(KahanObject kbuff) {
                set(0, 0, kbuff._sum);
                set(0, 1, kbuff._correction);
+               return this;
        }
        
        /**
@@ -294,10 +302,12 @@ public abstract class DenseBlock implements Serializable
         * 
         * @param r row index
         * @param kbuff kahan object
+        * @return self
         */
-       public void set(int r, KahanObject kbuff) {
+       public DenseBlock set(int r, KahanObject kbuff) {
                set(r, 0, kbuff._sum);
                set(r, 1, kbuff._correction);
+               return this;
        }
        
        /**

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockDRB.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockDRB.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockDRB.java
index 7f2ddb0..979f233 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockDRB.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockDRB.java
@@ -171,31 +171,35 @@ public class DenseBlockDRB extends DenseBlock
        }
 
        @Override
-       public void set(double v) {
+       public DenseBlock set(double v) {
                Arrays.fill(data, 0, rlen*clen, v);
+               return this;
        }
        
        @Override
-       public void set(int rl, int ru, int cl, int cu, double v) {
+       public DenseBlock set(int rl, int ru, int cl, int cu, double v) {
                if( cl==0 && cu == clen )
                        Arrays.fill(data, rl*clen, ru*clen, v);
                else
                        for(int i=rl, ix=rl*clen; i<ru; i++, ix+=clen)
                                Arrays.fill(data, ix+cl, ix+cu, v);
+               return this;
        }
 
        @Override
-       public void set(int r, int c, double v) {
+       public DenseBlock set(int r, int c, double v) {
                data[pos(r, c)] = v;
+               return this;
        }
        
        @Override
-       public void set(DenseBlock db) {
+       public DenseBlock set(DenseBlock db) {
                System.arraycopy(db.valuesAt(0), 0, data, 0, rlen*clen);
+               return this;
        }
        
        @Override
-       public void set(int rl, int ru, int cl, int cu, DenseBlock db) {
+       public DenseBlock set(int rl, int ru, int cl, int cu, DenseBlock db) {
                double[] a = db.valuesAt(0);
                if( cl == 0 && cu == clen)
                        System.arraycopy(a, 0, data, rl*clen+cl, 
(int)db.size());
@@ -204,11 +208,13 @@ public class DenseBlockDRB extends DenseBlock
                        for(int i=rl, ix1=0, ix2=rl*clen+cl; i<ru; i++, 
ix1+=len, ix2+=clen)
                                System.arraycopy(a, ix1, data, ix2, len);
                }
+               return this;
        }
 
        @Override
-       public void set(int r, double[] v) {
+       public DenseBlock set(int r, double[] v) {
                System.arraycopy(v, 0, data, pos(r), clen);
+               return this;
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockLDRB.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockLDRB.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockLDRB.java
index e041f39..4662c69 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockLDRB.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/DenseBlockLDRB.java
@@ -200,13 +200,14 @@ public class DenseBlockLDRB extends DenseBlock
        }
 
        @Override
-       public void set(double v) {
+       public DenseBlock set(double v) {
                for(int i=0; i<numBlocks(); i++)
                        Arrays.fill(data[i], v);
+               return this;
        }
        
        @Override
-       public void set(int rl, int ru, int cl, int cu, double v) {
+       public DenseBlock set(int rl, int ru, int cl, int cu, double v) {
                boolean rowBlock = (cl == 0 && cu == clen);
                final int bil = index(rl);
                final int biu = index(ru-1);
@@ -219,30 +220,35 @@ public class DenseBlockLDRB extends DenseBlock
                                for(int i=lpos; i<lpos+len; i+=clen)
                                        Arrays.fill(data[bi], i+cl, i+cu, v);
                }
+               return this;
        }
 
        @Override
-       public void set(int r, int c, double v) {
+       public DenseBlock set(int r, int c, double v) {
                data[index(r)][pos(r, c)] = v;
+               return this;
        }
 
        @Override
-       public void set(int r, double[] v) {
+       public DenseBlock set(int r, double[] v) {
                System.arraycopy(v, 0, data[index(r)], pos(r), clen);
+               return this;
        }
        
        @Override
-       public void set(DenseBlock db) {
+       public DenseBlock set(DenseBlock db) {
                for(int bi=0; bi<numBlocks(); bi++)
                        System.arraycopy(db.valuesAt(bi), 0, data[bi], 0, 
size(bi));
+               return this;
        }
        
        @Override
-       public void set(int rl, int ru, int cl, int cu, DenseBlock db) {
+       public DenseBlock set(int rl, int ru, int cl, int cu, DenseBlock db) {
                for(int i=rl; i<ru; i++) {
                        System.arraycopy(db.values(i-rl),
                                db.pos(i-rl), values(i), pos(i, cl), cu-cl);
                }
+               return this;
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
index 2d81255..72b90fa 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
@@ -589,7 +589,8 @@ public class LibMatrixAgg
                }
 
                //prod
-               if( vfn instanceof Multiply && ifn instanceof ReduceAll )
+               if( vfn instanceof Multiply 
+                       && (ifn instanceof ReduceAll || ifn instanceof 
ReduceCol || ifn instanceof ReduceRow))
                {
                        return AggType.PROD;
                }
@@ -1350,6 +1351,10 @@ public class LibMatrixAgg
                        case PROD: { //PROD
                                if( ixFn instanceof ReduceAll ) // PROD
                                        d_uam(a, c, n, rl, ru );
+                               else if( ixFn instanceof ReduceCol ) //ROWPROD
+                                       d_uarm(a, c, n, rl, ru);
+                               else if( ixFn instanceof ReduceRow ) //COLPROD
+                                       d_uacm(a, c, n, rl, ru);
                                break;
                        }
                        
@@ -1452,6 +1457,10 @@ public class LibMatrixAgg
                        case PROD: { //PROD
                                if( ixFn instanceof ReduceAll ) // PROD
                                        s_uam(a, c, n, rl, ru );
+                               else if( ixFn instanceof ReduceCol ) // ROWPROD
+                                       s_uarm(a, c, n, rl, ru );
+                               else if( ixFn instanceof ReduceRow ) // COLPROD
+                                       s_uacm(a, c, n, rl, ru );
                                break;
                        }
 
@@ -2084,7 +2093,36 @@ public class LibMatrixAgg
                }
                c.set(0, 0, tmp);
        }
+
+       /**
+        * ROWPROD, opcode: uar*, dense input; writes the product of each row i in [rl,ru) to c[i].
+        * 
+        * @param a dense input block
+        * @param c dense output block; row products are written to valuesAt(0)[i]
+        * @param n number of columns of a, i.e., cells multiplied per row
+        * @param rl row lower index
+        * @param ru row upper index (exclusive)
+        */
+       private static void d_uarm( DenseBlock a, DenseBlock c, int n, int rl, 
int ru ) {
+               double[] lc = c.valuesAt(0);
+               for( int i=rl; i<ru; i++ )
+                       lc[i] = product(a.values(i), a.pos(i), n);
+       }
        
+       /**
+        * COLPROD, opcode: uac*, dense input; multiplies rows [rl,ru) element-wise into one output row initialized to 1.
+        * 
+        * @param a dense input block
+        * @param c dense output block; filled with 1 (c.set(1)) before aggregation
+        * @param n number of columns of a and of the output row
+        * @param rl row lower index
+        * @param ru row upper index (exclusive)
+        */
+       private static void d_uacm( DenseBlock a, DenseBlock c, int n, int rl, 
int ru ) {
+               double[] lc = c.set(1).valuesAt(0); //guaranteed single row
+               for( int i=rl; i<ru; i++ )
+                       LibMatrixMult.vectMultiplyWrite(a.values(i), lc, lc, 
a.pos(i), 0, 0, n);
+       }
        
        /**
         * SUM, opcode: uak+, sparse input. 
@@ -2808,6 +2846,48 @@ public class LibMatrixAgg
                c.set(0, 0, ret);
        }
        
+       /**
+        * ROWPROD, opcode: uar*, sparse input; writes the product of each non-empty row i in [rl,ru) to c[i].
+        * 
+        * @param a sparse input block (values of row i start at a.pos(i))
+        * @param c dense output block; row products are written to valuesAt(0)[i]
+        * @param n number of columns of a; rows with fewer nonzeros contain implicit zeros
+        * @param rl row lower index
+        * @param ru row upper index (exclusive)
+        */
+       private static void s_uarm( SparseBlock a, DenseBlock c, int n, int rl, 
int ru ) {
+               double[] lc = c.valuesAt(0);
+               for( int i=rl; i<ru; i++ ) {
+                       if( !a.isEmpty(i) ) { //empty rows keep their initial output value
+                               int alen = a.size(i);
+                               double tmp = product(a.values(i), a.pos(i), alen);
+                               lc[i] = tmp * ((alen<n) ? 0 : 1); //multiply, not assign, to propagate NaN/Inf
+                       }
+               }
+       }
+       
+       /**
+        * COLPROD, opcode: uac*, sparse input; multiplies rows [rl,ru) element-wise into one output row initialized to 1.
+        * 
+        * @param a sparse input block (values/column indexes of row i start at a.pos(i))
+        * @param c dense output block; filled with 1 (c.set(1)) before aggregation
+        * @param n number of columns of a and of the output row
+        * @param rl row lower index
+        * @param ru row upper index (exclusive)
+        */
+       private static void s_uacm( SparseBlock a, DenseBlock c, int n, int rl, 
int ru ) {
+               double[] lc = c.set(1).valuesAt(0);
+               int[] cnt = new int[ n ]; //stored entries per column; fewer than ru-rl means an implicit zero
+               for( int i=rl; i<ru; i++ ) {
+                       if( a.isEmpty(i) ) continue;
+                       countAgg(a.values(i), cnt, a.indexes(i), a.pos(i), 
a.size(i));
+                       LibMatrixMult.vectMultiplyWrite(lc, a.values(i), lc, a.indexes(i), 
0, a.pos(i), 0, a.size(i));
+               }
+               for( int j=0; j<n; j++ )
+                       if( cnt[j] < ru-rl )
+                               lc[j] *= 0; //multiply, not assign, to propagate NaN/Inf
+       }
+       
        
        ////////////////////////////////////////////
        // performance-relevant utility functions //

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index ef273f6..adc955e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -3252,6 +3252,24 @@ public class LibMatrixMult
                        c[ ci+7 ] = a[ ai+7 ] * b[ bi+7 ];
                }
        }
+       
+       public static void vectMultiplyWrite( final double[] a, double[] b, 
double[] c, int[] bix, final int ai, final int bi, final int ci, final int len 
) {
+               final int bn = len%8;
+               //c[ci+bix[j]] = a[ai+bix[j]] * b[j] for j in [bi,bi+len); rest, not aligned to 8-blocks
+               for( int j = bi; j < bi+bn; j++ )
+                       c[ ci+bix[j] ] = a[ ai+bix[j] ] * b[ j ];
+               //unrolled 8-block (for better instruction-level parallelism)
+               for( int j = bi+bn; j < bi+len; j+=8 ) {
+                       c[ ci+bix[j+0] ] = a[ ai+bix[j+0] ] * b[ j+0 ];
+                       c[ ci+bix[j+1] ] = a[ ai+bix[j+1] ] * b[ j+1 ];
+                       c[ ci+bix[j+2] ] = a[ ai+bix[j+2] ] * b[ j+2 ];
+                       c[ ci+bix[j+3] ] = a[ ai+bix[j+3] ] * b[ j+3 ];
+                       c[ ci+bix[j+4] ] = a[ ai+bix[j+4] ] * b[ j+4 ];
+                       c[ ci+bix[j+5] ] = a[ ai+bix[j+5] ] * b[ j+5 ];
+                       c[ ci+bix[j+6] ] = a[ ai+bix[j+6] ] * b[ j+6 ];
+                       c[ ci+bix[j+7] ] = a[ ai+bix[j+7] ] * b[ j+7 ];
+               }
+       }
 
        private static void vectMultiply( double[] a, double[] c, int ai, int 
ci, final int len )
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/RowColProdsAggregateTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/RowColProdsAggregateTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/RowColProdsAggregateTest.java
new file mode 100644
index 0000000..97ae324
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/RowColProdsAggregateTest.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.aggregate;
+
+import java.util.HashMap;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+public class RowColProdsAggregateTest extends AutomatedTestBase 
+{
+       private final static String TEST_NAME1 = "RowProds";
+       private final static String TEST_NAME2 = "ColProds";
+       
+       private final static String TEST_DIR = "functions/aggregate/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + RowColProdsAggregateTest.class.getSimpleName() + "/";
+       private final static double eps = 1e-10;
+       
+       //RowProds reads a dim1 x dim2 input and ColProds a dim2 x dim1 input,
+       //so every output cell is a product over dim2 values
+       private final static int dim1 = 1079;
+       private final static int dim2 = 15;
+       private final static double sparsity1 = 0.1;
+       private final static double sparsity2 = 1.0; //otherwise 0 output
+       
+       @Override
+       public void setUp() 
+       {
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"B"})); 
+               addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"B"})); 
+               if (TEST_CACHE_ENABLED) {
+                       setOutAndExpectedDeletionDisabled(true);
+               }
+       }
+
+       @BeforeClass
+       public static void init() {
+               TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+       }
+
+       @AfterClass
+       public static void cleanUp() {
+               if (TEST_CACHE_ENABLED) {
+                       TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+               }
+       }
+       
+       @Test
+       public void testRowProdsDenseMatrixCP() {
+               runProdsAggregateTest(TEST_NAME1, false, true, ExecType.CP);
+       }
+       
+       @Test
+       public void testRowProdsSparseMatrixCP() {
+               runProdsAggregateTest(TEST_NAME1, true, true, ExecType.CP);
+       }
+       
+       @Test
+       public void testRowProdsDenseMatrixSP() {
+               runProdsAggregateTest(TEST_NAME1, false, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testRowProdsSparseMatrixSP() {
+               runProdsAggregateTest(TEST_NAME1, true, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testRowProdsDenseMatrixMR() {
+               runProdsAggregateTest(TEST_NAME1, false, true, ExecType.MR);
+       }
+       
+       @Test
+       public void testRowProdsSparseMatrixMR() {
+               runProdsAggregateTest(TEST_NAME1, true, true, ExecType.MR);
+       }
+       
+       @Test
+       public void testColProdsDenseMatrixCP() {
+               runProdsAggregateTest(TEST_NAME2, false, true, ExecType.CP);
+       }
+       
+       @Test
+       public void testColProdsSparseMatrixCP() {
+               runProdsAggregateTest(TEST_NAME2, true, true, ExecType.CP);
+       }
+       
+       @Test
+       public void testColProdsDenseMatrixSP() {
+               runProdsAggregateTest(TEST_NAME2, false, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testColProdsSparseMatrixSP() {
+               runProdsAggregateTest(TEST_NAME2, true, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testColProdsDenseMatrixMR() {
+               runProdsAggregateTest(TEST_NAME2, false, true, ExecType.MR);
+       }
+       
+       @Test
+       public void testColProdsSparseMatrixMR() {
+               runProdsAggregateTest(TEST_NAME2, true, true, ExecType.MR);
+       }
+
+       /**
+        * Runs the given DML script (RowProds or ColProds) on a random input,
+        * runs the matching R script, and compares both outputs within eps.
+        * The global runtime platform, local-spark flag, and algebraic
+        * simplification flag are restored afterwards.
+        * 
+        * @param TEST_NAME test configuration / script name (TEST_NAME1 or TEST_NAME2)
+        * @param sparse    if true, generate the input with sparsity1, otherwise sparsity2
+        * @param rewrites  value for OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION
+        * @param instType  backend to execute on (CP/HYBRID, SPARK, or MR/HADOOP)
+        */
+       private void runProdsAggregateTest(String TEST_NAME, boolean sparse, boolean rewrites, ExecType instType)
+       {
+               RUNTIME_PLATFORM platformOld = rtplatform;
+               switch( instType ){
+                       case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+                       case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+                       default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+               }
+       
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if( rtplatform == RUNTIME_PLATFORM.SPARK )
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               boolean oldRewritesFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+               
+               try
+               {
+                       double sparsity = (sparse) ? sparsity1 : sparsity2;
+                       TestConfiguration config = getTestConfiguration(TEST_NAME);
+                       
+                       String TEST_CACHE_DIR = "";
+                       if (TEST_CACHE_ENABLED) {
+                               TEST_CACHE_DIR = TEST_NAME + "_" + sparsity + "/";
+                       }
+                       
+                       loadTestConfiguration(config, TEST_CACHE_DIR);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-explain", "-args", input("A"), output("B") };
+                       fullRScriptName = HOME + TEST_NAME + ".R";
+                       rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
+                       
+                       //transposed shapes so both variants aggregate over dim2 values
+                       int rows = TEST_NAME.equals(TEST_NAME1) ? dim1 : dim2;
+                       int cols = TEST_NAME.equals(TEST_NAME1) ? dim2 : dim1;
+                       double[][] A = getRandomMatrix(rows, cols, 0.9, 1, sparsity, 1234);
+                       writeInputMatrixWithMTD("A", A, true);
+                       
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+                       
+                       //compare matrices
+                       HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("B");
+                       HashMap<CellIndex, Double> rfile = readRMatrixFromFS("B");
+                       TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+               }
+               finally {
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewritesFlag;
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test/scripts/functions/aggregate/ColProds.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/ColProds.R 
b/src/test/scripts/functions/aggregate/ColProds.R
new file mode 100644
index 0000000..2729498
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/ColProds.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+# matrixStats provides colProds; install it on demand. An explicit CRAN
+# mirror is passed because Rscript runs non-interactively and cannot
+# prompt for one.
+if(!("matrixStats" %in% rownames(installed.packages()))){
+   install.packages("matrixStats", repos="https://cloud.r-project.org")
+}
+
+library("Matrix")
+library("matrixStats") 
+
+# read input matrix from <args[1]>A.mtx
+A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+# colProds returns a plain vector; t() turns it into a 1 x ncol(A) matrix
+B <- t(colProds(A));
+
+writeMM(as(B, "CsparseMatrix"), paste(args[2], "B", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test/scripts/functions/aggregate/ColProds.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/ColProds.dml 
b/src/test/scripts/functions/aggregate/ColProds.dml
new file mode 100644
index 0000000..f2d3b3c
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/ColProds.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# read input matrix, compute per-column products, write result
+X = read($1);
+colProdsX = colProds(X);
+write(colProdsX, $2);
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test/scripts/functions/aggregate/RowProds.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/RowProds.R 
b/src/test/scripts/functions/aggregate/RowProds.R
new file mode 100644
index 0000000..0cc6689
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/RowProds.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+# matrixStats provides rowProds; install it on demand. An explicit CRAN
+# mirror is passed because Rscript runs non-interactively and cannot
+# prompt for one.
+if(!("matrixStats" %in% rownames(installed.packages()))){
+   install.packages("matrixStats", repos="https://cloud.r-project.org")
+}
+
+library("Matrix")
+library("matrixStats") 
+
+# read input matrix from <args[1]>A.mtx
+A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+# rowProds returns a plain vector; as.matrix makes it an explicit nrow(A) x 1 matrix
+B <- as.matrix(rowProds(A));
+
+writeMM(as(B, "CsparseMatrix"), paste(args[2], "B", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test/scripts/functions/aggregate/RowProds.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/RowProds.dml 
b/src/test/scripts/functions/aggregate/RowProds.dml
new file mode 100644
index 0000000..bb9ceb1
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/RowProds.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# read input matrix, compute per-row products, write result
+X = read($1);
+rowProdsX = rowProds(X);
+write(rowProdsX, $2);
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/c33e066a/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java
----------------------------------------------------------------------
diff --git 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java
 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java
index c99fe0a..4aedf4d 100644
--- 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java
+++ 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java
@@ -43,6 +43,7 @@ import org.junit.runners.Suite;
        NRowTest.class,
        ProdTest.class,
        PushdownSumBinaryTest.class,
+       RowColProdsAggregateTest.class,
        RowStdDevsTest.class,
        RowSumsSqTest.class,
        RowSumTest.class,

Reply via email to